diff --git a/.env b/.env index 5f461fdebce86..e066cc2b2b350 100644 --- a/.env +++ b/.env @@ -1,10 +1,11 @@ IMAGE_REPO=milvusdb IMAGE_ARCH=amd64 OS_NAME=ubuntu20.04 -DATE_VERSION=20230830-30ca458 -LATEST_DATE_VERSION=20230830-30ca458 +DATE_VERSION=20231011-11b5213 +LATEST_DATE_VERSION=20231011-11b5213 GPU_DATE_VERSION=20230822-a64488a LATEST_GPU_DATE_VERSION=20230317-a1c7b0c MINIO_ADDRESS=minio:9000 PULSAR_ADDRESS=pulsar://pulsar:6650 ETCD_ENDPOINTS=etcd:2379 +AZURITE_CONNECTION_STRING="DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://azurite:10000/devstoreaccount1;" \ No newline at end of file diff --git a/.github/mergify.yml b/.github/mergify.yml index ff4a56afef4ac..6b6911ec92e3b 100644 --- a/.github/mergify.yml +++ b/.github/mergify.yml @@ -30,11 +30,12 @@ pull_request_rules: add: - dco-passed - - name: Test passed for code changed-master + - name: Test passed for code changed-master conditions: - or: - base=sql_beta - base=master + - base~=^2\.3\.\d+$ - 'status-success=Build and test AMD64 Ubuntu 20.04' - 'status-success=Code Checker AMD64 Ubuntu 20.04' - 'status-success=Code Checker MacOS 12' @@ -46,7 +47,7 @@ pull_request_rules: label: add: - ci-passed - - name: Test passed for code changed -2.*.* + - name: Test passed for code changed -2.2.* conditions: - base~=^2(\.\d+){2}$ - 'status-success=Code Checker AMD64 Ubuntu 20.04' @@ -60,6 +61,7 @@ pull_request_rules: label: add: - ci-passed + - name: Test passed for tests changed conditions: - or: @@ -103,21 +105,23 @@ pull_request_rules: - or: - base=master - base=sql_beta + - base~=^2\.3\.\d+$ - 'status-success=Build and test AMD64 Ubuntu 20.04' - 'status-success=Code Checker AMD64 Ubuntu 20.04' - 'status-success=Code Checker MacOS 12' - 'status-success=Code Checker Amazonlinux 2023' - 'status-success=UT for Go (20.04)' - or: - - -files~=^(?!pkg\/..*test\.go).*$ + - 
-files~=^(?!pkg\/.*_test\.go).*$ - -files~=^(?!internal\/.*_test\.go).*$ actions: label: add: - ci-passed - - name: Test passed for go unittest code changed -2.*.* + + - name: Test passed for go unittest code changed -2.2.* conditions: - - base~=^2(\.\d+){2}$ + - base~=^2\.2\.\d+$ - 'status-success=Code Checker AMD64 Ubuntu 20.04' - 'status-success=Build and test AMD64 Ubuntu 20.04' - 'status-success=Code Checker MacOS 12' @@ -197,8 +201,7 @@ pull_request_rules: - name: Blocking PR if missing a related master PR or doesn't have kind/branch-feature label conditions: - - base=2.2.0 - # - base~=^2(\.\d+){2}$ + - base~=^2(\.\d+){2}$ - and: - -body~=pr\:\ \#[0-9]{1,6}(\s+|$) - -body~=https://github.com/milvus-io/milvus/pull/[0-9]{1,6}(\s+|$) @@ -214,8 +217,7 @@ pull_request_rules: - name: Dismiss block label if related pr be added into PR conditions: - - base=2.2.0 - # - base~=^2(\.\d+){2}$ + - base~=^2(\.\d+){2}$ - or: - body~=pr\:\ \#[0-9]{1,6}(\s+|$) - body~=https://github.com/milvus-io/milvus/pull/[0-9]{1,6}(\s+|$) @@ -243,6 +245,7 @@ pull_request_rules: - or: - base=master - base=sql_beta + - base~=^2\.3\.\d+$ - title~=\[skip e2e\] - 'status-success=Code Checker AMD64 Ubuntu 20.04' - 'status-success=Build and test AMD64 Ubuntu 20.04' @@ -254,9 +257,9 @@ pull_request_rules: add: - ci-passed - - name: Test passed for skip e2e - 2.*.* + - name: Test passed for skip e2e - 2.2.* conditions: - - base~=^2(\.\d+){2}$ + - base~=^2\.2\.\d+$ - title~=\[skip e2e\] - 'status-success=Code Checker AMD64 Ubuntu 20.04' - 'status-success=Build and test AMD64 Ubuntu 20.04' @@ -268,12 +271,13 @@ pull_request_rules: add: - ci-passed - - name: Remove ci-passed label when status for code checker or ut is not success-master + - name: Remove ci-passed label when status for code checker or ut is not success-master conditions: - label!=manual-pass - or: - base=master - base=sql_beta + - base~=^2\.3\.\d+$ - files~=^(?=.*((\.(go|h|cpp)|CMakeLists.txt))).*$ - or: - 'status-success!=Code Checker 
AMD64 Ubuntu 20.04' @@ -284,10 +288,11 @@ pull_request_rules: label: remove: - ci-passed - - name: Remove ci-passed label when status for code checker or ut is not success-2.*.* + + - name: Remove ci-passed label when status for code checker or ut is not success-2.2.* conditions: - label!=manual-pass - - base~=^2(\.\d+){2}$ + - base~=^2\.2\.\d+$ - files~=^(?=.*((\.(go|h|cpp)|CMakeLists.txt))).*$ - or: - 'status-success!=Code Checker AMD64 Ubuntu 20.04' @@ -298,6 +303,7 @@ pull_request_rules: label: remove: - ci-passed + - name: Remove ci-passed label when status for jenkins job is not success conditions: - label!=manual-pass @@ -306,7 +312,7 @@ pull_request_rules: - base=sql_beta - base~=^2(\.\d+){2}$ - -title~=\[skip e2e\] - - files~=^(?!(internal\/.*_test\.go|.*\.md)).*$ + - files~=^(?!(.*_test\.go|.*\.md)).*$ - 'status-success!=cpu-e2e' actions: label: @@ -338,9 +344,9 @@ pull_request_rules: message: | @{{author}} ut workflow job failed, comment `rerun ut` can trigger the job again. - - name: Add comment when code checker or ut failed -2.*.* + - name: Add comment when code checker or ut failed -2.2.* conditions: - - base~=^2(\.\d+){2}$ + - base~=^2\.2\.\d+$ - or: - 'check-failure=Code Checker AMD64 Ubuntu 20.04' - 'check-failure=Build and test AMD64 Ubuntu 20.04' diff --git a/.github/workflows/check-issue.yaml b/.github/workflows/check-issue.yaml new file mode 100644 index 0000000000000..952c07641b91c --- /dev/null +++ b/.github/workflows/check-issue.yaml @@ -0,0 +1,48 @@ +name: Add Comment for issue + +on: + issues: + types: [opened] + +jobs: + check_issue_title: + name: Check issue + runs-on: ubuntu-latest + env: + TITLE_PASSED: "T" + permissions: + issues: write + timeout-minutes: 20 + steps: + - name: Checkout + uses: actions/checkout@v2 + - name: Check Issue + shell: bash + run: | + echo Issue title: ${{ github.event.issue.title }} + cat >> check_title.py << EOF + import re + import sys + check_str = sys.argv[1] + pattern = re.compile(r'[\u4e00-\u9fa5]') + 
match = pattern.search(check_str) + if match: + print("TITLE_PASSED=F") + else: + print("TITLE_PASSED=T") + EOF + + python3 check_title.py "${{ github.event.issue.title }}" >> "$GITHUB_ENV" + cat $GITHUB_ENV + + - name: Check env + shell: bash + run: | + echo ${{ env.TITLE_PASSED }} + - name: Add comment + if: ${{ env.TITLE_PASSED == 'F'}} + uses: peter-evans/create-or-update-comment@5f728c3dae25f329afbe34ee4d08eef25569d79f + with: + issue-number: ${{ github.event.issue.number }} + body: | + The title and description of this issue contains Chinese. Please use English to describe your issue. \ No newline at end of file diff --git a/.github/workflows/code-checker.yaml b/.github/workflows/code-checker.yaml index c801ad97bb957..d4f4416745f51 100644 --- a/.github/workflows/code-checker.yaml +++ b/.github/workflows/code-checker.yaml @@ -53,7 +53,7 @@ jobs: uses: actions/cache@v3 with: path: .docker/amd64-ubuntu20.04-go-mod - key: ubuntu20.04-go-mod-${{ hashFiles('**/go.sum') }} + key: ubuntu20.04-go-mod-${{ hashFiles('go.sum, */go.sum') }} restore-keys: ubuntu20.04-go-mod- - name: Cache Conan Packages uses: pat-s/always-upload-cache@v3 @@ -98,7 +98,7 @@ jobs: uses: actions/cache@v3 with: path: .docker/amd64-amazonlinux2023-go-mod - key: amazonlinux2023-go-mod-${{ hashFiles('**/go.sum') }} + key: amazonlinux2023-go-mod-${{ hashFiles('go.sum, */go.sum') }} restore-keys: amazonlinux2023-go-mod- - name: Cache Conan Packages uses: pat-s/always-upload-cache@v3 @@ -107,7 +107,6 @@ jobs: key: amazonlinux2023-conan-${{ hashFiles('internal/core/conanfile.*') }} restore-keys: amazonlinux2023-conan- - name: Code Check - env: - OS_NAME: 'amazonlinux2023' run: | + sed -i 's/ubuntu20.04/amazonlinux2023/g' .env ./build/builder.sh /bin/bash -c "make install" diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 79509edd62d0e..590e0b1cf17db 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -67,7 +67,7 @@ jobs: - name: 'Setup Use USE_ASAN' 
if: steps.changed-files-cpp.outputs.any_changed == 'true' run: | - echo "useasan=true" >> $GITHUB_ENV + echo "useasan=ON" >> $GITHUB_ENV echo "Setup USE_ASAN to true since cpp file(s) changed" - name: 'Generate CCache Hash' env: @@ -123,6 +123,12 @@ jobs: path: .docker/amd64-ubuntu${{ matrix.ubuntu }}-conan key: ubuntu${{ matrix.ubuntu }}-conan-${{ hashFiles('internal/core/conanfile.*') }} restore-keys: ubuntu${{ matrix.ubuntu }}-conan- + - name: Start Service + shell: bash + run: | + docker-compose up -d azurite +# - name: 'Setup upterm session' +# uses: lhotari/action-upterm@v1 - name: UT run: | chmod +x build/builder.sh @@ -166,7 +172,9 @@ jobs: - name: Start Service shell: bash run: | - docker-compose up -d pulsar etcd minio + docker-compose up -d pulsar etcd minio azurite +# - name: 'Setup upterm session' +# uses: lhotari/action-upterm@v1 - name: UT run: | chmod +x build/builder.sh diff --git a/.github/workflows/update-knowhere-commit.yaml b/.github/workflows/update-knowhere-commit.yaml index 6739c175bb2de..051f3b8d49290 100644 --- a/.github/workflows/update-knowhere-commit.yaml +++ b/.github/workflows/update-knowhere-commit.yaml @@ -39,7 +39,7 @@ jobs: continue-on-error: true shell: bash run: | - sed -i "s#( KNOWHERE_VERSION.*#( KNOWHERE_VERSION ${{ steps.get-knowhere-latest-commit.outputs.knowhere-commit }} )#g" internal/core/thirdparty/knowhere/CMakeLists.txt + sed -i "0,/(\ KNOWHERE_VERSION/ s#( KNOWHERE_VERSION.*#( KNOWHERE_VERSION ${{ steps.get-knowhere-latest-commit.outputs.knowhere-commit }} )#g" internal/core/thirdparty/knowhere/CMakeLists.txt head -n 17 internal/core/thirdparty/knowhere/CMakeLists.txt git config --local user.email "41898282+github-actions[bot]@users.noreply.github.com" git config --local user.name "github-actions[bot]" diff --git a/.gitignore b/.gitignore index 99a8e5736aa59..3becfcd2d0d55 100644 --- a/.gitignore +++ b/.gitignore @@ -5,11 +5,13 @@ # proxy/cmake-build-debug/ # a/b/c/cmake-build-debug/ **/cmake-build-debug/* 
+**/cmake-build-debug-coverage/* **/cmake-build-release/* **/cmake_build_release/* **/cmake_build/* **/CmakeFiles/* .cache +coverage_report/ internal/core/output/* internal/core/build/* diff --git a/.golangci.yml b/.golangci.yml index cca8c7a7b4024..3d6dd38ad6296 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,5 +1,5 @@ run: - go: '1.18' + go: "1.18" skip-dirs: - build - configs @@ -7,29 +7,80 @@ run: - docs - scripts - internal/core + - cmake_build linters: disable-all: true enable: - - typecheck - - goimports - - misspell + - gosimple - govet - ineffassign - - gosimple + - staticcheck + - decorder + - depguard + - gofmt + - goimports - gosec - revive - - durationcheck - unconvert + - misspell + - typecheck + - durationcheck - forbidigo - - depguard - # - gocritic + - gci + - whitespace + - gofumpt + - gocritic linters-settings: + gci: + sections: + - standard + - default + - prefix(github.com/milvus-io) + custom-order: true + gofumpt: + lang-version: "1.18" + module-path: github.com/milvus-io + goimports: + local-prefixes: github.com/milvus-io revive: rules: - name: unused-parameter disabled: true + - name: var-naming + severity: warning + disabled: false + arguments: + - ["ID"] # Allow list + - name: context-as-argument + severity: warning + disabled: false + arguments: + - allowTypesBefore: "*testing.T" + - name: datarace + severity: warning + disabled: false + - name: duplicated-imports + severity: warning + disabled: false + - name: waitgroup-by-value + severity: warning + disabled: false + - name: indent-error-flow + severity: warning + disabled: false + arguments: + - "preserveScope" + - name: range-val-in-closure + severity: warning + disabled: false + - name: range-val-address + severity: warning + disabled: false + - name: string-of-int + severity: warning + disabled: false misspell: locale: US gocritic: @@ -38,28 +89,40 @@ linters-settings: settings: ruleguard: failOnError: true - rules: 'rules.go' + rules: "rules.go" depguard: rules: main: deny: - 
- pkg: 'errors' - desc: not allowd, use github.com/cockroachdb/errors - - pkg: 'github.com/pkg/errors' - desc: not allowd, use github.com/cockroachdb/errors - - pkg: 'github.com/pingcap/errors' - desc: not allowd, use github.com/cockroachdb/errors - - pkg: 'golang.org/x/xerrors' - desc: not allowd, use github.com/cockroachdb/errors - - pkg: 'github.com/go-errors/errors' - desc: not allowd, use github.com/cockroachdb/errors + - pkg: "errors" + desc: not allowed, use github.com/cockroachdb/errors + - pkg: "github.com/pkg/errors" + desc: not allowed, use github.com/cockroachdb/errors + - pkg: "github.com/pingcap/errors" + desc: not allowed, use github.com/cockroachdb/errors + - pkg: "golang.org/x/xerrors" + desc: not allowed, use github.com/cockroachdb/errors + - pkg: "github.com/go-errors/errors" + desc: not allowed, use github.com/cockroachdb/errors + - pkg: "io/ioutil" + desc: ioutil is deprecated after 1.16, 1.17, use os and io package instead forbidigo: forbid: - '^time\.Tick$' + - 'return merr\.Err[a-zA-Z]+' + - 'merr\.Wrap\w+\(\)\.Error\(\)' + - '\.(ErrorCode|Reason) = ' + - 'Reason:\s+\w+\.Error\(\)' + - 'errors.New\((.+)\.GetReason\(\)\)' + - 'commonpb\.Status\{[\s\n]*ErrorCode:[\s\n]*.+[\s\S\n]*?\}' #- 'fmt\.Print.*' WIP issues: exclude-use-default: false + exclude-rules: + - path: .+_test\.go + linters: + - forbidigo exclude: - should have a package comment - should have comment @@ -81,10 +144,12 @@ issues: - G402 # Use of weak random number generator math/rand - G404 + # Unused parameters + - SA1019 + # defer return errors + - SA5001 + # Maximum issues count per one linter. Set to 0 to disable. Default is 50. max-issues-per-linter: 0 # Maximum count of issues with the same text. Set to 0 to disable. Default is 3. 
max-same-issues: 0 - -service: - golangci-lint-version: 1.27.0 # use the fixed version to not introduce new linters unexpectedly diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c74859be45915..4e6b15096ee19 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,4 +16,4 @@ repos: description: Detect secrets in your data. entry: bash -c 'trufflehog git file://. --max-depth 1 --since-commit HEAD --only-verified --fail' language: system - stages: ["commit", "push"] + stages: ["commit"] diff --git a/Makefile b/Makefile index 5ded97d8a72d5..f643f02edef59 100644 --- a/Makefile +++ b/Makefile @@ -19,11 +19,17 @@ INSTALL_PATH := $(PWD)/bin LIBRARY_PATH := $(PWD)/lib OS := $(shell uname -s) mode = Release -disk_index = OFF -useasan = false -ifeq (${USE_ASAN}, true) -useasan = true + +use_disk_index = OFF +ifdef disk_index + use_disk_index = ${disk_index} +endif + +use_asan = OFF +ifdef USE_ASAN + use_asan =${USE_ASAN} endif + use_dynamic_simd = OFF ifdef USE_DYNAMIC_SIMD use_dynamic_simd = ${USE_DYNAMIC_SIMD} @@ -36,9 +42,23 @@ INSTALL_GOLANGCI_LINT := $(findstring $(GOLANGCI_LINT_VERSION), $(GOLANGCI_LINT_ MOCKERY_VERSION := 2.32.4 MOCKERY_OUTPUT := $(shell $(INSTALL_PATH)/mockery --version 2>/dev/null) INSTALL_MOCKERY := $(findstring $(MOCKERY_VERSION),$(MOCKERY_OUTPUT)) +# gci +GCI_VERSION := 0.11.2 +GCI_OUTPUT := $(shell $(INSTALL_PATH)/gci --version 2>/dev/null) +INSTALL_GCI := $(findstring $(GCI_VERSION),$(GCI_OUTPUT)) +# gofumpt +GOFUMPT_VERSION := 0.5.0 +GOFUMPT_OUTPUT := $(shell $(INSTALL_PATH)/gofumpt --version 2>/dev/null) +INSTALL_GOFUMPT := $(findstring $(GOFUMPT_VERSION),$(GOFUMPT_OUTPUT)) + +index_engine = knowhere export GIT_BRANCH=master +ifeq (${ENABLE_AZURE}, false) + AZURE_OPTION := -Z +endif + milvus: build-cpp print-build-info @echo "Building Milvus ..." 
@source $(PWD)/scripts/setenv.sh && \ @@ -68,7 +88,7 @@ getdeps: echo "Installing mockery v$(MOCKERY_VERSION) to ./bin/" && GOBIN=$(INSTALL_PATH) go install github.com/vektra/mockery/v2@v$(MOCKERY_VERSION); \ else \ echo "Mockery v$(MOCKERY_VERSION) already installed"; \ - fi + fi tools/bin/revive: tools/check/go.mod cd tools/check; \ @@ -88,21 +108,39 @@ else @GO111MODULE=on env bash $(PWD)/scripts/gofmt.sh internal/ @GO111MODULE=on env bash $(PWD)/scripts/gofmt.sh tests/integration/ @GO111MODULE=on env bash $(PWD)/scripts/gofmt.sh tests/go/ - @GO111MODULE=on env bash $(PWD)/scripts/gofmt.sh pkg/ + @GO111MODULE=on env bash $(PWD)/scripts/gofmt.sh pkg/ endif -lint: tools/bin/revive - @echo "Running $@ check" - @tools/bin/revive -formatter friendly -config tools/check/revive.toml ./... +lint-fix: getdeps + @mkdir -p $(INSTALL_PATH) + @if [ -z "$(INSTALL_GCI)" ]; then \ + echo "Installing gci v$(GCI_VERSION) to ./bin/" && GOBIN=$(INSTALL_PATH) go install github.com/daixiang0/gci@v$(GCI_VERSION); \ + else \ + echo "gci v$(GCI_VERSION) already installed"; \ + fi + @if [ -z "$(INSTALL_GOFUMPT)" ]; then \ + echo "Installing gofumpt v$(GOFUMPT_VERSION) to ./bin/" && GOBIN=$(INSTALL_PATH) go install mvdan.cc/gofumpt@v$(GOFUMPT_VERSION); \ + else \ + echo "gofumpt v$(GOFUMPT_VERSION) already installed"; \ + fi + @echo "Running gofumpt fix" + @$(INSTALL_PATH)/gofumpt -l -w internal/ + @$(INSTALL_PATH)/gofumpt -l -w cmd/ + @$(INSTALL_PATH)/gofumpt -l -w pkg/ + @$(INSTALL_PATH)/gofumpt -l -w tests/integration/ + @echo "Running gci fix" + @$(INSTALL_PATH)/gci write cmd/ --skip-generated -s standard -s default -s "prefix(github.com/milvus-io)" --custom-order + @$(INSTALL_PATH)/gci write internal/ --skip-generated -s standard -s default -s "prefix(github.com/milvus-io)" --custom-order + @$(INSTALL_PATH)/gci write pkg/ --skip-generated -s standard -s default -s "prefix(github.com/milvus-io)" --custom-order + @$(INSTALL_PATH)/gci write tests/ --skip-generated -s standard -s default 
-s "prefix(github.com/milvus-io)" --custom-order + @echo "Running golangci-lint auto-fix" + @source $(PWD)/scripts/setenv.sh && GO111MODULE=on $(INSTALL_PATH)/golangci-lint run --fix --timeout=30m --config $(PWD)/.golangci.yml; cd pkg && GO111MODULE=on $(INSTALL_PATH)/golangci-lint run --fix --timeout=30m --config $(PWD)/.golangci.yml #TODO: Check code specifications by golangci-lint static-check: getdeps @echo "Running $@ check" - @GO111MODULE=on $(INSTALL_PATH)/golangci-lint cache clean - @source $(PWD)/scripts/setenv.sh && GO111MODULE=on $(INSTALL_PATH)/golangci-lint run --timeout=30m --config ./.golangci.yml ./internal/... - @source $(PWD)/scripts/setenv.sh && GO111MODULE=on $(INSTALL_PATH)/golangci-lint run --timeout=30m --config ./.golangci.yml ./cmd/... - @source $(PWD)/scripts/setenv.sh && GO111MODULE=on $(INSTALL_PATH)/golangci-lint run --timeout=30m --config ./.golangci.yml ./tests/integration/... - @source $(PWD)/scripts/setenv.sh && cd pkg && GO111MODULE=on $(INSTALL_PATH)/golangci-lint run --timeout=30m --config ../.golangci.yml ./... + @source $(PWD)/scripts/setenv.sh && GO111MODULE=on $(INSTALL_PATH)/golangci-lint run --timeout=30m --config $(PWD)/.golangci.yml + @source $(PWD)/scripts/setenv.sh && cd pkg && GO111MODULE=on $(INSTALL_PATH)/golangci-lint run --timeout=30m --config $(PWD)/.golangci.yml verifiers: build-cpp getdeps cppcheck fmt static-check @@ -169,19 +207,19 @@ generated-proto: download-milvus-proto build-3rdparty build-cpp: generated-proto @echo "Building Milvus cpp library ..." - @(env bash $(PWD)/scripts/core_build.sh -t ${mode} -f "$(CUSTOM_THIRDPARTY_PATH)" -n ${disk_index} -y ${use_dynamic_simd}) + @(env bash $(PWD)/scripts/core_build.sh -t ${mode} -n ${use_disk_index} -y ${use_dynamic_simd} ${AZURE_OPTION} -x ${index_engine}) build-cpp-gpu: generated-proto - @echo "Building Milvus cpp gpu library ..." 
- @(env bash $(PWD)/scripts/core_build.sh -t ${mode} -g -f "$(CUSTOM_THIRDPARTY_PATH)" -n ${disk_index} -y ${use_dynamic_simd}) + @echo "Building Milvus cpp gpu library ... " + @(env bash $(PWD)/scripts/core_build.sh -t ${mode} -g -n ${use_disk_index} -y ${use_dynamic_simd} ${AZURE_OPTION} -x ${index_engine}) build-cpp-with-unittest: generated-proto - @echo "Building Milvus cpp library with unittest ..." - @(env bash $(PWD)/scripts/core_build.sh -t ${mode} -u -f "$(CUSTOM_THIRDPARTY_PATH)" -n ${disk_index} -y ${use_dynamic_simd}) + @echo "Building Milvus cpp library with unittest ... " + @(env bash $(PWD)/scripts/core_build.sh -t ${mode} -u -n ${use_disk_index} -y ${use_dynamic_simd} ${AZURE_OPTION} -x ${index_engine}) build-cpp-with-coverage: generated-proto @echo "Building Milvus cpp library with coverage and unittest ..." - @(env bash $(PWD)/scripts/core_build.sh -t ${mode} -u -c -f "$(CUSTOM_THIRDPARTY_PATH)" -n ${disk_index} -y ${use_dynamic_simd}) + @(env bash $(PWD)/scripts/core_build.sh -t ${mode} -a ${use_asan} -u -c -n ${use_disk_index} -y ${use_dynamic_simd} ${AZURE_OPTION} -x ${index_engine}) check-proto-product: generated-proto @(env bash $(PWD)/scripts/check_proto_product.sh) @@ -296,13 +334,10 @@ gpu-install: milvus-gpu clean: @echo "Cleaning up all the generated files" - @find . -name '*.test' | xargs rm -fv - @find . 
-name '*~' | xargs rm -fv @rm -rf bin/ @rm -rf lib/ @rm -rf $(GOPATH)/bin/milvus @rm -rf cmake_build - @rm -rf cwrapper_build @rm -rf internal/core/output milvus-tools: print-build-info @@ -336,9 +371,9 @@ generate-mockery-types: getdeps # Proxy $(INSTALL_PATH)/mockery --name=ProxyComponent --dir=$(PWD)/internal/types --output=$(PWD)/internal/mocks --filename=mock_proxy.go --with-expecter --structname=MockProxy # QueryCoord - $(INSTALL_PATH)/mockery --name=QueryCoordComponent --dir=$(PWD)/internal/types --output=$(PWD)/internal/mocks --filename=mock_querycoord.go --with-expecter --structname=MockQueryCoord + $(INSTALL_PATH)/mockery --name=QueryCoordComponent --dir=$(PWD)/internal/types --output=$(PWD)/internal/mocks --filename=mock_querycoord.go --with-expecter --structname=MockQueryCoord # QueryNode - $(INSTALL_PATH)/mockery --name=QueryNodeComponent --dir=$(PWD)/internal/types --output=$(PWD)/internal/mocks --filename=mock_querynode.go --with-expecter --structname=MockQueryNode + $(INSTALL_PATH)/mockery --name=QueryNodeComponent --dir=$(PWD)/internal/types --output=$(PWD)/internal/mocks --filename=mock_querynode.go --with-expecter --structname=MockQueryNode # DataCoord $(INSTALL_PATH)/mockery --name=DataCoordComponent --dir=$(PWD)/internal/types --output=$(PWD)/internal/mocks --filename=mock_datacoord.go --with-expecter --structname=MockDataCoord # DataNode @@ -346,6 +381,15 @@ generate-mockery-types: getdeps # IndexNode $(INSTALL_PATH)/mockery --name=IndexNodeComponent --dir=$(PWD)/internal/types --output=$(PWD)/internal/mocks --filename=mock_indexnode.go --with-expecter --structname=MockIndexNode + # Clients + $(INSTALL_PATH)/mockery --name=RootCoordClient --dir=$(PWD)/internal/types --output=$(PWD)/internal/mocks --filename=mock_rootcoord_client.go --with-expecter --structname=MockRootCoordClient + $(INSTALL_PATH)/mockery --name=QueryCoordClient --dir=$(PWD)/internal/types --output=$(PWD)/internal/mocks --filename=mock_querycoord_client.go --with-expecter 
--structname=MockQueryCoordClient + $(INSTALL_PATH)/mockery --name=DataCoordClient --dir=$(PWD)/internal/types --output=$(PWD)/internal/mocks --filename=mock_datacoord_client.go --with-expecter --structname=MockDataCoordClient + $(INSTALL_PATH)/mockery --name=QueryNodeClient --dir=$(PWD)/internal/types --output=$(PWD)/internal/mocks --filename=mock_querynode_client.go --with-expecter --structname=MockQueryNodeClient + $(INSTALL_PATH)/mockery --name=DataNodeClient --dir=$(PWD)/internal/types --output=$(PWD)/internal/mocks --filename=mock_datanode_client.go --with-expecter --structname=MockDataNodeClient + $(INSTALL_PATH)/mockery --name=IndexNodeClient --dir=$(PWD)/internal/types --output=$(PWD)/internal/mocks --filename=mock_indexnode_client.go --with-expecter --structname=MockIndexNodeClient + $(INSTALL_PATH)/mockery --name=ProxyClient --dir=$(PWD)/internal/types --output=$(PWD)/internal/mocks --filename=mock_proxy_client.go --with-expecter --structname=MockProxyClient + generate-mockery-rootcoord: getdeps $(INSTALL_PATH)/mockery --name=IMetaTable --dir=$(PWD)/internal/rootcoord --output=$(PWD)/internal/rootcoord/mocks --filename=meta_table.go --with-expecter --outpkg=mockrootcoord $(INSTALL_PATH)/mockery --name=GarbageCollector --dir=$(PWD)/internal/rootcoord --output=$(PWD)/internal/rootcoord/mocks --filename=garbage_collector.go --with-expecter --outpkg=mockrootcoord @@ -356,6 +400,7 @@ generate-mockery-proxy: getdeps $(INSTALL_PATH)/mockery --name=LBPolicy --dir=$(PWD)/internal/proxy --output=$(PWD)/internal/proxy --filename=mock_lb_policy.go --structname=MockLBPolicy --with-expecter --outpkg=proxy --inpackage $(INSTALL_PATH)/mockery --name=LBBalancer --dir=$(PWD)/internal/proxy --output=$(PWD)/internal/proxy --filename=mock_lb_balancer.go --structname=MockLBBalancer --with-expecter --outpkg=proxy --inpackage $(INSTALL_PATH)/mockery --name=shardClientMgr --dir=$(PWD)/internal/proxy --output=$(PWD)/internal/proxy --filename=mock_shardclient_manager.go 
--structname=MockShardClientManager --with-expecter --outpkg=proxy --inpackage + $(INSTALL_PATH)/mockery --name=channelsMgr --dir=$(PWD)/internal/proxy --output=$(PWD)/internal/proxy --filename=mock_channels_manager.go --structname=MockChannelsMgr --with-expecter --outpkg=proxy --inpackage generate-mockery-querycoord: getdeps $(INSTALL_PATH)/mockery --name=QueryNodeServer --dir=$(PWD)/internal/proto/querypb/ --output=$(PWD)/internal/querycoordv2/mocks --filename=mock_querynode.go --with-expecter --structname=MockQueryNodeServer @@ -367,6 +412,7 @@ generate-mockery-querycoord: getdeps generate-mockery-querynode: getdeps build-cpp @source $(PWD)/scripts/setenv.sh # setup PKG_CONFIG_PATH + $(INSTALL_PATH)/mockery --name=QueryHook --dir=$(PWD)/internal/querynodev2/optimizers --output=$(PWD)/internal/querynodev2/optimizers --filename=mock_query_hook.go --with-expecter --outpkg=optimizers --structname=MockQueryHook --inpackage $(INSTALL_PATH)/mockery --name=Manager --dir=$(PWD)/internal/querynodev2/cluster --output=$(PWD)/internal/querynodev2/cluster --filename=mock_manager.go --with-expecter --outpkg=cluster --structname=MockManager --inpackage $(INSTALL_PATH)/mockery --name=SegmentManager --dir=$(PWD)/internal/querynodev2/segments --output=$(PWD)/internal/querynodev2/segments --filename=mock_segment_manager.go --with-expecter --outpkg=segments --structname=MockSegmentManager --inpackage $(INSTALL_PATH)/mockery --name=CollectionManager --dir=$(PWD)/internal/querynodev2/segments --output=$(PWD)/internal/querynodev2/segments --filename=mock_collection_manager.go --with-expecter --outpkg=segments --structname=MockCollectionManager --inpackage @@ -382,11 +428,12 @@ generate-mockery-datacoord: getdeps generate-mockery-datanode: getdeps $(INSTALL_PATH)/mockery --name=Allocator --dir=$(PWD)/internal/datanode/allocator --output=$(PWD)/internal/datanode/allocator --filename=mock_allocator.go --with-expecter --structname=MockAllocator --outpkg=allocator --inpackage + 
$(INSTALL_PATH)/mockery --name=Broker --dir=$(PWD)/internal/datanode/broker --output=$(PWD)/internal/datanode/broker/ --filename=mock_broker.go --with-expecter --structname=MockBroker --outpkg=broker --inpackage generate-mockery-metastore: getdeps - $(INSTALL_PATH)/mockery --name=RootCoordCatalog --dir=$(PWD)/internal/metastore --output=$(PWD)/internal/metastore/mocks --filename=mock_rootcoord_catalog.go --with-expecter --structname=RootCoordCatalog --outpkg=mocks - $(INSTALL_PATH)/mockery --name=DataCoordCatalog --dir=$(PWD)/internal/metastore --output=$(PWD)/internal/metastore/mocks --filename=mock_datacoord_catalog.go --with-expecter --structname=DataCoordCatalog --outpkg=mocks - $(INSTALL_PATH)/mockery --name=QueryCoordCatalog --dir=$(PWD)/internal/metastore --output=$(PWD)/internal/metastore/mocks --filename=mock_querycoord_catalog.go --with-expecter --structname=QueryCoordCatalog --outpkg=mocks + $(INSTALL_PATH)/mockery --name=RootCoordCatalog --dir=$(PWD)/internal/metastore --output=$(PWD)/internal/metastore/mocks --filename=mock_rootcoord_catalog.go --with-expecter --structname=RootCoordCatalog --outpkg=mocks + $(INSTALL_PATH)/mockery --name=DataCoordCatalog --dir=$(PWD)/internal/metastore --output=$(PWD)/internal/metastore/mocks --filename=mock_datacoord_catalog.go --with-expecter --structname=DataCoordCatalog --outpkg=mocks + $(INSTALL_PATH)/mockery --name=QueryCoordCatalog --dir=$(PWD)/internal/metastore --output=$(PWD)/internal/metastore/mocks --filename=mock_querycoord_catalog.go --with-expecter --structname=QueryCoordCatalog --outpkg=mocks generate-mockery-utils: getdeps # dependency.Factory @@ -399,9 +446,10 @@ generate-mockery-kv: getdeps $(INSTALL_PATH)/mockery --name=MetaKv --dir=$(PWD)/internal/kv --output=$(PWD)/internal/kv/mocks --filename=meta_kv.go --with-expecter $(INSTALL_PATH)/mockery --name=WatchKV --dir=$(PWD)/internal/kv --output=$(PWD)/internal/kv/mocks --filename=watch_kv.go --with-expecter $(INSTALL_PATH)/mockery --name=SnapShotKV 
--dir=$(PWD)/internal/kv --output=$(PWD)/internal/kv/mocks --filename=snapshot_kv.go --with-expecter + $(INSTALL_PATH)/mockery --name=Predicate --dir=$(PWD)/internal/kv/predicates --output=$(PWD)/internal/kv/predicates --filename=mock_predicate.go --with-expecter --inpackage generate-mockery-pkg: $(MAKE) -C pkg generate-mockery generate-mockery: generate-mockery-types generate-mockery-kv generate-mockery-rootcoord generate-mockery-proxy generate-mockery-querycoord generate-mockery-querynode generate-mockery-datacoord generate-mockery-pkg - + diff --git a/README.md b/README.md index 8eb5351170e63..707e0b419e5c8 100644 --- a/README.md +++ b/README.md @@ -169,7 +169,7 @@ Contributions to Milvus are welcome from everyone. See [Guidelines for Contribut ### All contributors
-
+
@@ -188,6 +188,7 @@ Contributions to Milvus are welcome from everyone. See [Guidelines for Contribut + @@ -262,6 +263,7 @@ Contributions to Milvus are welcome from everyone. See [Guidelines for Contribut + @@ -276,6 +278,7 @@ Contributions to Milvus are welcome from everyone. See [Guidelines for Contribut + @@ -319,6 +322,7 @@ Contributions to Milvus are welcome from everyone. See [Guidelines for Contribut + @@ -399,6 +403,7 @@ Contributions to Milvus are welcome from everyone. See [Guidelines for Contribut + @@ -409,6 +414,7 @@ Contributions to Milvus are welcome from everyone. See [Guidelines for Contribut + @@ -447,6 +453,7 @@ Contributions to Milvus are welcome from everyone. See [Guidelines for Contribut + @@ -472,6 +479,7 @@ Contributions to Milvus are welcome from everyone. See [Guidelines for Contribut + @@ -481,7 +489,9 @@ Contributions to Milvus are welcome from everyone. See [Guidelines for Contribut + + diff --git a/README_CN.md b/README_CN.md index 11785188531b1..3cf9a188af826 100644 --- a/README_CN.md +++ b/README_CN.md @@ -154,7 +154,7 @@ Milvus [训练营](https://github.com/milvus-io/bootcamp)能够帮助你了解 ### All contributors
-
+
@@ -173,6 +173,7 @@ Milvus [训练营](https://github.com/milvus-io/bootcamp)能够帮助你了解 + @@ -247,6 +248,7 @@ Milvus [训练营](https://github.com/milvus-io/bootcamp)能够帮助你了解 + @@ -261,6 +263,7 @@ Milvus [训练营](https://github.com/milvus-io/bootcamp)能够帮助你了解 + @@ -304,6 +307,7 @@ Milvus [训练营](https://github.com/milvus-io/bootcamp)能够帮助你了解 + @@ -384,6 +388,7 @@ Milvus [训练营](https://github.com/milvus-io/bootcamp)能够帮助你了解 + @@ -394,6 +399,7 @@ Milvus [训练营](https://github.com/milvus-io/bootcamp)能够帮助你了解 + @@ -432,6 +438,7 @@ Milvus [训练营](https://github.com/milvus-io/bootcamp)能够帮助你了解 + @@ -457,6 +464,7 @@ Milvus [训练营](https://github.com/milvus-io/bootcamp)能够帮助你了解 + @@ -466,7 +474,9 @@ Milvus [训练营](https://github.com/milvus-io/bootcamp)能够帮助你了解 + + diff --git a/build/builder.sh b/build/builder.sh index 9802764c7aa0e..5e58b7f07f088 100755 --- a/build/builder.sh +++ b/build/builder.sh @@ -9,8 +9,6 @@ if [[ -f "$toplevel/.env" ]]; then export $(cat $toplevel/.env | xargs) fi -export OS_NAME="${OS_NAME:-ubuntu20.04}" - pushd "${toplevel}" if [[ "${1-}" == "pull" ]]; then @@ -27,8 +25,6 @@ PLATFORM_ARCH="${PLATFORM_ARCH:-${IMAGE_ARCH}}" export IMAGE_ARCH=${PLATFORM_ARCH} -echo ${IMAGE_ARCH} - mkdir -p "${DOCKER_VOLUME_DIRECTORY:-.docker}/${IMAGE_ARCH}-${OS_NAME}-ccache" mkdir -p "${DOCKER_VOLUME_DIRECTORY:-.docker}/${IMAGE_ARCH}-${OS_NAME}-go-mod" mkdir -p "${DOCKER_VOLUME_DIRECTORY:-.docker}/${IMAGE_ARCH}-${OS_NAME}-vscode-extensions" diff --git a/build/docker/builder/cpu/amazonlinux2023/Dockerfile b/build/docker/builder/cpu/amazonlinux2023/Dockerfile index d7d0aba5ebf0b..c9f3f15cecdb2 100644 --- a/build/docker/builder/cpu/amazonlinux2023/Dockerfile +++ b/build/docker/builder/cpu/amazonlinux2023/Dockerfile @@ -13,31 +13,22 @@ FROM amazonlinux:2023 ARG TARGETARCH -RUN yum install -y wget g++ gcc gdb libstdc++-static git make zip unzip tar which \ +RUN yum install -y wget g++ gcc gdb libatomic libstdc++-static git make zip unzip tar which \ autoconf automake golang python3 python3-pip perl-FindBin 
texinfo \ - pkg-config libuuid-devel libaio && \ + pkg-config libuuid-devel libaio perl-IPC-Cmd && \ rm -rf /var/cache/yum/* RUN pip3 install conan==1.58.0 RUN echo "target arch $TARGETARCH" -RUN if [ "$TARGETARCH" = "amd64" ]; then CMAKE_SUFFIX=x86_64; else CMAKE_SUFFIX=aarch64; fi &&\ - wget -qO- "https://cmake.org/files/v3.24/cmake-3.24.4-linux-$CMAKE_SUFFIX.tar.gz" | tar --strip-components=1 -xz -C /usr/local +RUN wget -qO- "https://cmake.org/files/v3.24/cmake-3.24.4-linux-`uname -m`.tar.gz" | tar --strip-components=1 -xz -C /usr/local -RUN mkdir /tmp/ccache && wget -qO- https://github.com/ccache/ccache/releases/download/v4.8.2/ccache-4.8.2.tar.gz | tar --strip-components=1 -xz -C /tmp/ccache &&\ - cd /tmp/ccache && mkdir build && cd build && cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=/usr .. && make && make install &&\ +RUN mkdir /tmp/ccache && cd /tmp/ccache &&\ + wget https://dl.fedoraproject.org/pub/epel/9/Everything/`uname -m`/Packages/h/hiredis-1.0.2-1.el9.`uname -m`.rpm &&\ + wget https://dl.fedoraproject.org/pub/epel/9/Everything/`uname -m`/Packages/c/ccache-4.5.1-2.el9.`uname -m`.rpm &&\ + rpm -i hiredis-1.0.2-1.el9.`uname -m`.rpm ccache-4.5.1-2.el9.`uname -m`.rpm &&\ rm -rf /tmp/ccache -# https://github.com/golang/go/issues/22040 Due to this issue, ld.gold cannot be used in the ARM environment, so the official golang package cannot be used. You need to use the golang package that comes with yum. 
-# Install Go -# ENV GOPATH /go -# ENV GOROOT /usr/local/go -# ENV GO111MODULE on -# ENV PATH $GOPATH/bin:$GOROOT/bin:$PATH -# RUN mkdir -p /usr/local/go && wget -qO- "https://go.dev/dl/go1.20.7.linux-$TARGETARCH.tar.gz" | tar --strip-components=1 -xz -C /usr/local/go && \ -# mkdir -p "$GOPATH/src" "$GOPATH/bin" && \ -# go clean --modcache && \ -# chmod -R 777 "$GOPATH" && chmod -R a+w $(go env GOTOOLDIR) # refer: https://code.visualstudio.com/docs/remote/containers-advanced#_avoiding-extension-reinstalls-on-container-rebuild RUN mkdir -p /home/milvus/.vscode-server/extensions \ diff --git a/build/docker/builder/cpu/ubuntu20.04/Dockerfile b/build/docker/builder/cpu/ubuntu20.04/Dockerfile index 444a971132518..3e95294dfd5f2 100644 --- a/build/docker/builder/cpu/ubuntu20.04/Dockerfile +++ b/build/docker/builder/cpu/ubuntu20.04/Dockerfile @@ -13,6 +13,8 @@ FROM ubuntu:focal-20220426 ARG TARGETARCH +RUN if [ "$TARGETARCH" = "arm64" ]; then apt-get update && apt-get install -y ninja-build; fi + RUN apt-get update && apt-get install -y --no-install-recommends wget curl ca-certificates gnupg2 \ g++ gcc gdb gdbserver git make ccache libssl-dev zlib1g-dev zip unzip \ clang-format-10 clang-tidy-10 lcov libtool m4 autoconf automake python3 python3-pip \ @@ -23,8 +25,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends wget curl ca-ce RUN pip3 install conan==1.58.0 RUN echo "target arch $TARGETARCH" -RUN if [ "$TARGETARCH" = "amd64" ]; then CMAKE_SUFFIX=x86_64; else CMAKE_SUFFIX=aarch64; fi &&\ - wget -qO- "https://cmake.org/files/v3.24/cmake-3.24.0-linux-$CMAKE_SUFFIX.tar.gz" | tar --strip-components=1 -xz -C /usr/local +RUN wget -qO- "https://cmake.org/files/v3.24/cmake-3.24.4-linux-`uname -m`.tar.gz" | tar --strip-components=1 -xz -C /usr/local # Install Go ENV GOPATH /go diff --git a/ci/jenkins/PR.groovy b/ci/jenkins/PR.groovy index f28f2f7f48bf0..37d7f656cba17 100644 --- a/ci/jenkins/PR.groovy +++ b/ci/jenkins/PR.groovy @@ -3,7 +3,7 @@ int 
total_timeout_minutes = 60 * 5 int e2e_timeout_seconds = 120 * 60 def imageTag='' -int case_timeout_seconds = 10 * 60 +int case_timeout_seconds = 20 * 60 def chart_version='4.0.6' pipeline { options { diff --git a/ci/jenkins/pod/e2e.yaml b/ci/jenkins/pod/e2e.yaml index 3a87861174455..ebf856ba5eea3 100644 --- a/ci/jenkins/pod/e2e.yaml +++ b/ci/jenkins/pod/e2e.yaml @@ -9,7 +9,7 @@ spec: enableServiceLinks: false containers: - name: pytest - image: harbor.milvus.io/dockerhub/milvusdb/pytest:20230830-a8e5dc3 + image: harbor.milvus.io/dockerhub/milvusdb/pytest:20231019-020ad9a resources: limits: cpu: "6" diff --git a/ci/jenkins/pod/rte.yaml b/ci/jenkins/pod/rte.yaml index 43adc7d78688c..ae81cd7e6314e 100644 --- a/ci/jenkins/pod/rte.yaml +++ b/ci/jenkins/pod/rte.yaml @@ -45,7 +45,7 @@ spec: - mountPath: /ci-logs name: ci-logs - name: pytest - image: harbor.milvus.io/dockerhub/milvusdb/pytest:20230830-a8e5dc3 + image: harbor.milvus.io/dockerhub/milvusdb/pytest:20231019-020ad9a resources: limits: cpu: "6" diff --git a/cmd/components/data_coord.go b/cmd/components/data_coord.go index 2e556242c3a11..f7878314739f3 100644 --- a/cmd/components/data_coord.go +++ b/cmd/components/data_coord.go @@ -19,13 +19,14 @@ package components import ( "context" + "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" grpcdatacoordclient "github.com/milvus-io/milvus/internal/distributed/datacoord" "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/typeutil" - "go.uber.org/zap" ) // DataCoord implements grpc server of DataCoord server diff --git a/cmd/components/data_node.go b/cmd/components/data_node.go index 734d1e1686a54..25a7b9a91c37c 100644 --- a/cmd/components/data_node.go +++ b/cmd/components/data_node.go @@ -19,13 +19,14 @@ package components import ( "context" + "go.uber.org/zap" + 
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" grpcdatanode "github.com/milvus-io/milvus/internal/distributed/datanode" "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/typeutil" - "go.uber.org/zap" ) // DataNode implements DataNode grpc server diff --git a/cmd/components/index_coord.go b/cmd/components/index_coord.go index adaad1a2757b5..ff03f83789318 100644 --- a/cmd/components/index_coord.go +++ b/cmd/components/index_coord.go @@ -26,8 +26,7 @@ import ( ) // IndexCoord implements IndexCoord grpc server -type IndexCoord struct { -} +type IndexCoord struct{} // NewIndexCoord creates a new IndexCoord func NewIndexCoord(ctx context.Context, factory dependency.Factory) (*IndexCoord, error) { @@ -48,7 +47,6 @@ func (s *IndexCoord) Stop() error { // GetComponentStates returns indexnode's states func (s *IndexCoord) Health(ctx context.Context) commonpb.StateCode { - log.Info("IndexCoord is healthy") return commonpb.StateCode_Healthy } diff --git a/cmd/components/index_node.go b/cmd/components/index_node.go index 9c874a8801bac..4f947d35f4158 100644 --- a/cmd/components/index_node.go +++ b/cmd/components/index_node.go @@ -19,13 +19,14 @@ package components import ( "context" + "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" grpcindexnode "github.com/milvus-io/milvus/internal/distributed/indexnode" "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/typeutil" - "go.uber.org/zap" ) // IndexNode implements IndexNode grpc server @@ -43,7 +44,6 @@ func NewIndexNode(ctx context.Context, factory dependency.Factory) (*IndexNode, } n.svr = svr return n, nil - } // Run starts service diff --git a/cmd/components/proxy.go b/cmd/components/proxy.go index 37bb9ef13873b..61a62df495538 
100644 --- a/cmd/components/proxy.go +++ b/cmd/components/proxy.go @@ -19,13 +19,14 @@ package components import ( "context" + "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" grpcproxy "github.com/milvus-io/milvus/internal/distributed/proxy" "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/typeutil" - "go.uber.org/zap" ) // Proxy implements Proxy grpc server diff --git a/cmd/components/query_node.go b/cmd/components/query_node.go index f44cffa8ce1e0..50570ec152fe4 100644 --- a/cmd/components/query_node.go +++ b/cmd/components/query_node.go @@ -46,7 +46,6 @@ func NewQueryNode(ctx context.Context, factory dependency.Factory) (*QueryNode, ctx: ctx, svr: svr, }, nil - } // Run starts service diff --git a/cmd/components/root_coord.go b/cmd/components/root_coord.go index e26a5c50fd4e9..720511902a911 100644 --- a/cmd/components/root_coord.go +++ b/cmd/components/root_coord.go @@ -19,13 +19,14 @@ package components import ( "context" + "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" rc "github.com/milvus-io/milvus/internal/distributed/rootcoord" "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/typeutil" - "go.uber.org/zap" ) // RootCoord implements RoodCoord grpc server diff --git a/cmd/embedded/embedded.go b/cmd/embedded/embedded.go index 1f2148a093f28..34b979a612483 100644 --- a/cmd/embedded/embedded.go +++ b/cmd/embedded/embedded.go @@ -16,8 +16,9 @@ package main +import "C" + import ( - "C" "os" "github.com/milvus-io/milvus/cmd/milvus" diff --git a/cmd/main.go b/cmd/main.go index 875126af32c33..9e02d743555ee 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -17,11 +17,81 @@ package main import ( + "log" "os" + "os/exec" + "os/signal" + "strings" + + 
"golang.org/x/exp/slices" "github.com/milvus-io/milvus/cmd/milvus" + "github.com/milvus-io/milvus/internal/util/sessionutil" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) func main() { - milvus.RunMilvus(os.Args) + idx := slices.Index(os.Args, "--run-with-subprocess") + + // execute command as a subprocess if the command contains "--run-with-subprocess" + if idx > 0 { + args := slices.Delete(os.Args, idx, idx+1) + log.Println("run subprocess with cmd:", args) + + /* #nosec G204 */ + cmd := exec.Command(args[0], args[1:]...) + + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Start(); err != nil { + // Command not found on PATH, not executable, &c. + log.Fatal(err) + } + + // wait for the command to finish + waitCh := make(chan error, 1) + go func() { + waitCh <- cmd.Wait() + close(waitCh) + }() + + sc := make(chan os.Signal, 1) + signal.Notify(sc) + + // Need a for loop to handle multiple signals + for { + select { + case sig := <-sc: + if err := cmd.Process.Signal(sig); err != nil { + log.Println("error sending signal", sig, err) + } + case err := <-waitCh: + // clean session + paramtable.Init() + params := paramtable.Get() + if len(args) >= 3 { + metaPath := params.EtcdCfg.MetaRootPath.GetValue() + endpoints := params.EtcdCfg.Endpoints.GetValue() + etcdEndpoints := strings.Split(endpoints, ",") + + sessionSuffix := sessionutil.GetSessions(cmd.Process.Pid) + defer sessionutil.RemoveServerInfoFile(cmd.Process.Pid) + + if err := milvus.CleanSession(metaPath, etcdEndpoints, sessionSuffix); err != nil { + log.Println("clean session failed", err.Error()) + } + } + + if err != nil { + log.Println("subprocess exit, ", err.Error()) + } else { + log.Println("exit code:", cmd.ProcessState.ExitCode()) + } + return + } + } + } else { + milvus.RunMilvus(os.Args) + } } diff --git a/cmd/milvus/help.go b/cmd/milvus/help.go index c967b8228865b..3cb4d7c15912e 100644 --- a/cmd/milvus/help.go +++ b/cmd/milvus/help.go @@ -7,6 +7,11 @@ import ( 
"github.com/milvus-io/milvus/pkg/util/typeutil" ) +const ( + RunCmd = "run" + RoleMixture = "mixture" +) + var ( usageLine = fmt.Sprintf("Usage:\n"+ "%s\n%s\n%s\n%s\n", runLine, stopLine, mckLine, serverTypeLine) diff --git a/cmd/milvus/mck.go b/cmd/milvus/mck.go index c07758f7a7fc6..3775764bf07e9 100644 --- a/cmd/milvus/mck.go +++ b/cmd/milvus/mck.go @@ -364,7 +364,6 @@ func getTrashKey(taskType, key string) string { } func (c *mck) extractTask(prefix string, keys []string, values []string) { - for i := range keys { taskID, err := strconv.ParseInt(filepath.Base(keys[i]), 10, 64) if err != nil { @@ -520,7 +519,6 @@ func (c *mck) extractVecFieldIndexInfo(taskID int64, infos []*querypb.FieldIndex func (c *mck) unmarshalTask(taskID int64, t string) (string, []int64, []int64, error) { header := commonpb.MsgHeader{} err := proto.Unmarshal([]byte(t), &header) - if err != nil { return errReturn(taskID, "MsgHeader", err) } diff --git a/cmd/milvus/run.go b/cmd/milvus/run.go index b898e2d1c1a8e..bbb19eb88ab7a 100644 --- a/cmd/milvus/run.go +++ b/cmd/milvus/run.go @@ -10,96 +10,29 @@ import ( "go.uber.org/zap" - "github.com/milvus-io/milvus/cmd/roles" "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/util/hardware" "github.com/milvus-io/milvus/pkg/util/metricsinfo" - "github.com/milvus-io/milvus/pkg/util/typeutil" ) -const ( - RunCmd = "run" - roleMixture = "mixture" -) - -type run struct { - serverType string - // flags - svrAlias string - enableRootCoord bool - enableQueryCoord bool - enableDataCoord bool - enableIndexCoord bool - enableQueryNode bool - enableDataNode bool - enableIndexNode bool - enableProxy bool -} - -func (c *run) getHelp() string { - return runLine + "\n" + serverTypeLine -} +type run struct{} func (c *run) execute(args []string, flags *flag.FlagSet) { if len(args) < 3 { - fmt.Fprintln(os.Stderr, c.getHelp()) + fmt.Fprintln(os.Stderr, getHelp()) return } flags.Usage = func() { - fmt.Fprintln(os.Stderr, c.getHelp()) + 
fmt.Fprintln(os.Stderr, getHelp()) } - c.serverType = args[2] - c.formatFlags(args, flags) - - // make go ignore SIGPIPE when all cgo threads set mask of SIGPIPE + // make go ignore SIGPIPE when all cgo thread set mask SIGPIPE signal.Ignore(syscall.SIGPIPE) - role := roles.NewMilvusRoles() - role.Local = false - switch c.serverType { - case typeutil.RootCoordRole: - role.EnableRootCoord = true - case typeutil.ProxyRole: - role.EnableProxy = true - case typeutil.QueryCoordRole: - role.EnableQueryCoord = true - case typeutil.QueryNodeRole: - role.EnableQueryNode = true - case typeutil.DataCoordRole: - role.EnableDataCoord = true - case typeutil.DataNodeRole: - role.EnableDataNode = true - case typeutil.IndexCoordRole: - role.EnableIndexCoord = true - case typeutil.IndexNodeRole: - role.EnableIndexNode = true - case typeutil.StandaloneRole, typeutil.EmbeddedRole: - role.EnableRootCoord = true - role.EnableProxy = true - role.EnableQueryCoord = true - role.EnableQueryNode = true - role.EnableDataCoord = true - role.EnableDataNode = true - role.EnableIndexCoord = true - role.EnableIndexNode = true - role.Local = true - role.Embedded = c.serverType == typeutil.EmbeddedRole - case roleMixture: - role.EnableRootCoord = c.enableRootCoord - role.EnableQueryCoord = c.enableQueryCoord - role.EnableDataCoord = c.enableDataCoord - role.EnableIndexCoord = c.enableIndexCoord - role.EnableQueryNode = c.enableQueryNode - role.EnableDataNode = c.enableDataNode - role.EnableIndexNode = c.enableIndexNode - role.EnableProxy = c.enableProxy - default: - fmt.Fprintf(os.Stderr, "Unknown server type = %s\n%s", c.serverType, c.getHelp()) - os.Exit(-1) - } + serverType := args[2] + roles := GetMilvusRoles(args, flags) + // setup config for embedded milvus - runtimeDir := createRuntimeDir(c.serverType) - filename := getPidFileName(c.serverType, c.svrAlias) + runtimeDir := createRuntimeDir(serverType) + filename := getPidFileName(serverType, roles.Alias) c.printBanner(flags.Output()) 
c.injectVariablesToEnv() @@ -108,29 +41,7 @@ func (c *run) execute(args []string, flags *flag.FlagSet) { panic(err) } defer removePidFile(lock) - role.Run(c.svrAlias) -} - -func (c *run) formatFlags(args []string, flags *flag.FlagSet) { - flags.StringVar(&c.svrAlias, "alias", "", "set alias") - - flags.BoolVar(&c.enableRootCoord, typeutil.RootCoordRole, false, "enable root coordinator") - flags.BoolVar(&c.enableQueryCoord, typeutil.QueryCoordRole, false, "enable query coordinator") - flags.BoolVar(&c.enableIndexCoord, typeutil.IndexCoordRole, false, "enable index coordinator") - flags.BoolVar(&c.enableDataCoord, typeutil.DataCoordRole, false, "enable data coordinator") - - flags.BoolVar(&c.enableQueryNode, typeutil.QueryNodeRole, false, "enable query node") - flags.BoolVar(&c.enableDataNode, typeutil.DataNodeRole, false, "enable data node") - flags.BoolVar(&c.enableIndexNode, typeutil.IndexNodeRole, false, "enable index node") - flags.BoolVar(&c.enableProxy, typeutil.ProxyRole, false, "enable proxy node") - - if c.serverType == typeutil.EmbeddedRole { - flags.SetOutput(io.Discard) - } - hardware.InitMaxprocs(c.serverType, flags) - if err := flags.Parse(args[3:]); err != nil { - os.Exit(-1) - } + roles.Run() } func (c *run) printBanner(w io.Writer) { diff --git a/cmd/milvus/stop.go b/cmd/milvus/stop.go index 66517bd21b3d8..1392125e6f44e 100644 --- a/cmd/milvus/stop.go +++ b/cmd/milvus/stop.go @@ -61,7 +61,7 @@ func (c *stop) formatFlags(args []string, flags *flag.FlagSet) { func (c *stop) stopPid(filename string, runtimeDir string) error { var pid int - fd, err := os.OpenFile(path.Join(runtimeDir, filename), os.O_RDONLY, 0664) + fd, err := os.OpenFile(path.Join(runtimeDir, filename), os.O_RDONLY, 0o664) if err != nil { return err } diff --git a/cmd/milvus/util.go b/cmd/milvus/util.go index 1e40bb4048e69..e042190af2a3e 100644 --- a/cmd/milvus/util.go +++ b/cmd/milvus/util.go @@ -1,20 +1,34 @@ package milvus import ( + "context" + "encoding/json" + "flag" "fmt" "io" - 
"io/ioutil" "os" "path" "runtime" + "strconv" + "strings" + "time" + "github.com/cockroachdb/errors" "github.com/gofrs/flock" + "github.com/samber/lo" + clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/zap" + "github.com/milvus-io/milvus/cmd/roles" + "github.com/milvus-io/milvus/internal/util/sessionutil" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/etcd" + "github.com/milvus-io/milvus/pkg/util/hardware" "github.com/milvus-io/milvus/pkg/util/typeutil" ) func makeRuntimeDir(dir string) error { - perm := os.FileMode(0755) + perm := os.FileMode(0o755) // os.MkdirAll equal to `mkdir -p` err := os.MkdirAll(dir, perm) if err != nil { @@ -22,7 +36,7 @@ func makeRuntimeDir(dir string) error { return fmt.Errorf("create runtime dir %s failed, err: %s", dir, err.Error()) } - tmpFile, err := ioutil.TempFile(dir, "tmp") + tmpFile, err := os.CreateTemp(dir, "tmp") if err != nil { return err } @@ -63,7 +77,7 @@ func createRuntimeDir(sType string) string { func createPidFile(w io.Writer, filename string, runtimeDir string) (*flock.Flock, error) { fileFullName := path.Join(runtimeDir, filename) - fd, err := os.OpenFile(fileFullName, os.O_CREATE|os.O_RDWR, 0664) + fd, err := os.OpenFile(fileFullName, os.O_CREATE|os.O_RDWR, 0o664) if err != nil { return nil, fmt.Errorf("file %s is locked, error = %w", filename, err) } @@ -106,3 +120,190 @@ func removePidFile(lock *flock.Flock) { lock.Close() os.Remove(filename) } + +func GetMilvusRoles(args []string, flags *flag.FlagSet) *roles.MilvusRoles { + alias, enableRootCoord, enableQueryCoord, enableIndexCoord, enableDataCoord, enableQueryNode, + enableDataNode, enableIndexNode, enableProxy := formatFlags(args, flags) + + serverType := args[2] + role := roles.NewMilvusRoles() + role.Alias = alias + + switch serverType { + case typeutil.RootCoordRole: + role.EnableRootCoord = true + case typeutil.ProxyRole: + role.EnableProxy = true + case typeutil.QueryCoordRole: + role.EnableQueryCoord = true + case 
typeutil.QueryNodeRole: + role.EnableQueryNode = true + case typeutil.DataCoordRole: + role.EnableDataCoord = true + case typeutil.DataNodeRole: + role.EnableDataNode = true + case typeutil.IndexCoordRole: + role.EnableIndexCoord = true + case typeutil.IndexNodeRole: + role.EnableIndexNode = true + case typeutil.StandaloneRole, typeutil.EmbeddedRole: + role.EnableRootCoord = true + role.EnableProxy = true + role.EnableQueryCoord = true + role.EnableQueryNode = true + role.EnableDataCoord = true + role.EnableDataNode = true + role.EnableIndexCoord = true + role.EnableIndexNode = true + role.Local = true + role.Embedded = serverType == typeutil.EmbeddedRole + case RoleMixture: + role.EnableRootCoord = enableRootCoord + role.EnableQueryCoord = enableQueryCoord + role.EnableDataCoord = enableDataCoord + role.EnableIndexCoord = enableIndexCoord + role.EnableQueryNode = enableQueryNode + role.EnableDataNode = enableDataNode + role.EnableIndexNode = enableIndexNode + role.EnableProxy = enableProxy + default: + fmt.Fprintf(os.Stderr, "Unknown server type = %s\n%s", serverType, getHelp()) + os.Exit(-1) + } + + return role +} + +func formatFlags(args []string, flags *flag.FlagSet) (alias string, enableRootCoord, enableQueryCoord, + enableIndexCoord, enableDataCoord, enableQueryNode, enableDataNode, enableIndexNode, enableProxy bool, +) { + flags.StringVar(&alias, "alias", "", "set alias") + + flags.BoolVar(&enableRootCoord, typeutil.RootCoordRole, false, "enable root coordinator") + flags.BoolVar(&enableQueryCoord, typeutil.QueryCoordRole, false, "enable query coordinator") + flags.BoolVar(&enableIndexCoord, typeutil.IndexCoordRole, false, "enable index coordinator") + flags.BoolVar(&enableDataCoord, typeutil.DataCoordRole, false, "enable data coordinator") + + flags.BoolVar(&enableQueryNode, typeutil.QueryNodeRole, false, "enable query node") + flags.BoolVar(&enableDataNode, typeutil.DataNodeRole, false, "enable data node") + flags.BoolVar(&enableIndexNode, 
typeutil.IndexNodeRole, false, "enable index node") + flags.BoolVar(&enableProxy, typeutil.ProxyRole, false, "enable proxy node") + + serverType := args[2] + if serverType == typeutil.EmbeddedRole { + flags.SetOutput(io.Discard) + } + hardware.InitMaxprocs(serverType, flags) + if err := flags.Parse(args[3:]); err != nil { + os.Exit(-1) + } + return +} + +func getHelp() string { + return runLine + "\n" + serverTypeLine +} + +func CleanSession(metaPath string, etcdEndpoints []string, sessionSuffix []string) error { + if len(sessionSuffix) == 0 { + log.Warn("not found session info , skip to clean sessions") + return nil + } + + etcdCli, err := etcd.GetRemoteEtcdClient(etcdEndpoints) + if err != nil { + return err + } + defer etcdCli.Close() + + ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second) + defer cancel() + + keys := getSessionPaths(ctx, etcdCli, metaPath, sessionSuffix) + if len(keys) == 0 { + return nil + } + + for _, key := range keys { + _, _ = etcdCli.Delete(ctx, key, clientv3.WithPrefix()) + } + log.Info("clean sessions from etcd", zap.Any("keys", keys)) + return nil +} + +func getSessionPaths(ctx context.Context, client *clientv3.Client, metaPath string, sessionSuffix []string) []string { + sessionKeys := make([]string, 0) + sessionPathPrefix := path.Join(metaPath, sessionutil.DefaultServiceRoot) + newSessionSuffixSet := addActiveKeySuffix(ctx, client, sessionPathPrefix, sessionSuffix) + for _, suffix := range newSessionSuffixSet { + key := path.Join(sessionPathPrefix, suffix) + sessionKeys = append(sessionKeys, key) + } + return sessionKeys +} + +// filterUnmatchedKey skip active keys that don't match completed key, the latest active key may from standby server +func addActiveKeySuffix(ctx context.Context, client *clientv3.Client, sessionPathPrefix string, sessionSuffix []string) []string { + suffixSet := lo.SliceToMap(sessionSuffix, func(t string) (string, struct{}) { + return t, struct{}{} + }) + + for _, suffix := range sessionSuffix { 
+ if strings.Contains(suffix, "-") && (strings.HasPrefix(suffix, typeutil.RootCoordRole) || + strings.HasPrefix(suffix, typeutil.QueryCoordRole) || strings.HasPrefix(suffix, typeutil.DataCoordRole) || + strings.HasPrefix(suffix, typeutil.IndexCoordRole)) { + res := strings.Split(suffix, "-") + if len(res) != 2 { + // skip illegal keys + log.Warn("skip illegal key", zap.String("suffix", suffix)) + continue + } + + serverType := res[0] + targetServerID, err := strconv.ParseInt(res[1], 10, 64) + if err != nil { + log.Warn("get server id failed from key", zap.String("suffix", suffix), zap.Error(err)) + continue + } + + key := path.Join(sessionPathPrefix, serverType) + serverID, err := getServerID(ctx, client, key) + if err != nil { + log.Warn("get server id failed from key", zap.String("suffix", suffix), zap.Error(err)) + continue + } + + if serverID == targetServerID { + log.Info("add active serverID key", zap.String("suffix", suffix), zap.String("key", key)) + suffixSet[serverType] = struct{}{} + } + + // also remove a faked indexcoord seesion if role is a datacoord + if strings.HasPrefix(suffix, typeutil.DataCoordRole) { + suffixSet[typeutil.IndexCoordRole] = struct{}{} + } + } + } + + return lo.MapToSlice(suffixSet, func(key string, v struct{}) string { return key }) +} + +func getServerID(ctx context.Context, client *clientv3.Client, key string) (int64, error) { + resp, err := client.Get(ctx, key) + if err != nil { + return 0, err + } + + if len(resp.Kvs) == 0 { + return 0, errors.New("not found value") + } + + value := resp.Kvs[0].Value + session := &sessionutil.SessionRaw{} + err = json.Unmarshal(value, &session) + if err != nil { + return 0, err + } + + return session.ServerID, nil +} diff --git a/cmd/roles/roles.go b/cmd/roles/roles.go index 550b575719d11..7d3d62e49a22d 100644 --- a/cmd/roles/roles.go +++ b/cmd/roles/roles.go @@ -27,6 +27,10 @@ import ( "syscall" "time" + "github.com/prometheus/client_golang/prometheus" + 
"github.com/prometheus/client_golang/prometheus/promhttp" + "go.uber.org/zap" + "github.com/milvus-io/milvus/cmd/components" "github.com/milvus-io/milvus/internal/http" "github.com/milvus-io/milvus/internal/http/healthz" @@ -37,14 +41,12 @@ import ( "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/tracer" "github.com/milvus-io/milvus/pkg/util/etcd" + "github.com/milvus-io/milvus/pkg/util/generic" "github.com/milvus-io/milvus/pkg/util/logutil" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" _ "github.com/milvus-io/milvus/pkg/util/symbolizer" // support symbolizer and crash dump "github.com/milvus-io/milvus/pkg/util/typeutil" - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promhttp" - "go.uber.org/zap" ) // all milvus related metrics is in a separate registry @@ -91,7 +93,7 @@ func runComponent[T component](ctx context.Context, runWg *sync.WaitGroup, creator func(context.Context, dependency.Factory) (T, error), metricRegister func(*prometheus.Registry), -) T { +) component { var role T sign := make(chan struct{}) @@ -118,6 +120,9 @@ func runComponent[T component](ctx context.Context, healthz.Register(role) metricRegister(Registry.GoRegistry) + if generic.IsZero(role) { + return nil + } return role } @@ -133,9 +138,11 @@ type MilvusRoles struct { EnableIndexNode bool `env:"ENABLE_INDEX_NODE"` Local bool + Alias string Embedded bool - closed chan struct{} - once sync.Once + + closed chan struct{} + once sync.Once } // NewMilvusRoles creates a new MilvusRoles with private fields initialized. 
@@ -161,46 +168,52 @@ func (mr *MilvusRoles) printLDPreLoad() { } } -func (mr *MilvusRoles) runRootCoord(ctx context.Context, localMsg bool, wg *sync.WaitGroup) *components.RootCoord { +func (mr *MilvusRoles) runRootCoord(ctx context.Context, localMsg bool, wg *sync.WaitGroup) component { wg.Add(1) return runComponent(ctx, localMsg, wg, components.NewRootCoord, metrics.RegisterRootCoord) } -func (mr *MilvusRoles) runProxy(ctx context.Context, localMsg bool, wg *sync.WaitGroup) *components.Proxy { +func (mr *MilvusRoles) runProxy(ctx context.Context, localMsg bool, wg *sync.WaitGroup) component { wg.Add(1) return runComponent(ctx, localMsg, wg, components.NewProxy, metrics.RegisterProxy) } -func (mr *MilvusRoles) runQueryCoord(ctx context.Context, localMsg bool, wg *sync.WaitGroup) *components.QueryCoord { +func (mr *MilvusRoles) runQueryCoord(ctx context.Context, localMsg bool, wg *sync.WaitGroup) component { wg.Add(1) return runComponent(ctx, localMsg, wg, components.NewQueryCoord, metrics.RegisterQueryCoord) } -func (mr *MilvusRoles) runQueryNode(ctx context.Context, localMsg bool, wg *sync.WaitGroup) *components.QueryNode { +func (mr *MilvusRoles) runQueryNode(ctx context.Context, localMsg bool, wg *sync.WaitGroup) component { wg.Add(1) + // clear local storage rootPath := paramtable.Get().LocalStorageCfg.Path.GetValue() queryDataLocalPath := filepath.Join(rootPath, typeutil.QueryNodeRole) cleanLocalDir(queryDataLocalPath) + // clear mmap dir + mmapDir := paramtable.Get().QueryNodeCfg.MmapDirPath.GetValue() + if len(mmapDir) > 0 { + cleanLocalDir(mmapDir) + } return runComponent(ctx, localMsg, wg, components.NewQueryNode, metrics.RegisterQueryNode) } -func (mr *MilvusRoles) runDataCoord(ctx context.Context, localMsg bool, wg *sync.WaitGroup) *components.DataCoord { +func (mr *MilvusRoles) runDataCoord(ctx context.Context, localMsg bool, wg *sync.WaitGroup) component { wg.Add(1) return runComponent(ctx, localMsg, wg, components.NewDataCoord, 
metrics.RegisterDataCoord) } -func (mr *MilvusRoles) runDataNode(ctx context.Context, localMsg bool, wg *sync.WaitGroup) *components.DataNode { +func (mr *MilvusRoles) runDataNode(ctx context.Context, localMsg bool, wg *sync.WaitGroup) component { wg.Add(1) return runComponent(ctx, localMsg, wg, components.NewDataNode, metrics.RegisterDataNode) } -func (mr *MilvusRoles) runIndexCoord(ctx context.Context, localMsg bool, wg *sync.WaitGroup) *components.IndexCoord { +func (mr *MilvusRoles) runIndexCoord(ctx context.Context, localMsg bool, wg *sync.WaitGroup) component { wg.Add(1) return runComponent(ctx, localMsg, wg, components.NewIndexCoord, func(registry *prometheus.Registry) {}) } -func (mr *MilvusRoles) runIndexNode(ctx context.Context, localMsg bool, wg *sync.WaitGroup) *components.IndexNode { +func (mr *MilvusRoles) runIndexNode(ctx context.Context, localMsg bool, wg *sync.WaitGroup) component { wg.Add(1) rootPath := paramtable.Get().LocalStorageCfg.Path.GetValue() indexDataLocalPath := filepath.Join(rootPath, typeutil.IndexNodeRole) @@ -237,6 +250,7 @@ func (mr *MilvusRoles) setupLogger() { // Register serves prometheus http service func setupPrometheusHTTPServer(r *internalmetrics.MilvusRegistry) { + log.Info("setupPrometheusHTTPServer") http.Register(&http.Handler{ Path: "/metrics", Handler: promhttp.HandlerFor(r, promhttp.HandlerOpts{}), @@ -282,7 +296,7 @@ func (mr *MilvusRoles) handleSignals() func() { } // Run Milvus components. 
-func (mr *MilvusRoles) Run(alias string) { +func (mr *MilvusRoles) Run() { // start signal handler, defer close func closeFn := mr.handleSignals() defer closeFn() @@ -325,54 +339,49 @@ func (mr *MilvusRoles) Run(alias string) { } http.ServeHTTP() + setupPrometheusHTTPServer(Registry) - var rc *components.RootCoord var wg sync.WaitGroup local := mr.Local + + var rootCoord, queryCoord, indexCoord, dataCoord component + var proxy, dataNode, indexNode, queryNode component if mr.EnableRootCoord { - rc = mr.runRootCoord(ctx, local, &wg) + rootCoord = mr.runRootCoord(ctx, local, &wg) } - var pn *components.Proxy if mr.EnableProxy { - pn = mr.runProxy(ctx, local, &wg) + proxy = mr.runProxy(ctx, local, &wg) } - var qs *components.QueryCoord if mr.EnableQueryCoord { - qs = mr.runQueryCoord(ctx, local, &wg) + queryCoord = mr.runQueryCoord(ctx, local, &wg) } - var qn *components.QueryNode if mr.EnableQueryNode { - qn = mr.runQueryNode(ctx, local, &wg) + queryNode = mr.runQueryNode(ctx, local, &wg) } - var ds *components.DataCoord if mr.EnableDataCoord { - ds = mr.runDataCoord(ctx, local, &wg) + dataCoord = mr.runDataCoord(ctx, local, &wg) } - var dn *components.DataNode if mr.EnableDataNode { - dn = mr.runDataNode(ctx, local, &wg) + dataNode = mr.runDataNode(ctx, local, &wg) } - var is *components.IndexCoord if mr.EnableIndexCoord { - is = mr.runIndexCoord(ctx, local, &wg) + indexCoord = mr.runIndexCoord(ctx, local, &wg) } - var in *components.IndexNode if mr.EnableIndexNode { - in = mr.runIndexNode(ctx, local, &wg) + indexNode = mr.runIndexNode(ctx, local, &wg) } wg.Wait() mr.setupLogger() tracer.Init() - setupPrometheusHTTPServer(Registry) paramtable.SetCreateTime(time.Now()) paramtable.SetUpdateTime(time.Now()) @@ -381,9 +390,11 @@ func (mr *MilvusRoles) Run(alias string) { // stop coordinators first // var component - coordinators := []component{rc, qs, ds, is} - for _, coord := range coordinators { + coordinators := []component{rootCoord, queryCoord, dataCoord, 
indexCoord} + for idx, coord := range coordinators { + log.Warn("stop processing") if coord != nil { + log.Warn("stop coord", zap.Int("idx", idx), zap.Any("coord", coord)) wg.Add(1) go func(coord component) { defer wg.Done() @@ -395,7 +406,7 @@ func (mr *MilvusRoles) Run(alias string) { log.Info("All coordinators have stopped") // stop nodes - nodes := []component{qn, in, dn} + nodes := []component{queryNode, indexNode, dataNode} for _, node := range nodes { if node != nil { wg.Add(1) @@ -409,10 +420,39 @@ func (mr *MilvusRoles) Run(alias string) { log.Info("All nodes have stopped") // stop proxy - if pn != nil { - pn.Stop() + if proxy != nil { + proxy.Stop() log.Info("proxy stopped") } log.Info("Milvus components graceful stop done") } + +func (mr *MilvusRoles) GetRoles() []string { + roles := make([]string, 0) + if mr.EnableRootCoord { + roles = append(roles, typeutil.RootCoordRole) + } + if mr.EnableProxy { + roles = append(roles, typeutil.ProxyRole) + } + if mr.EnableQueryCoord { + roles = append(roles, typeutil.QueryCoordRole) + } + if mr.EnableQueryNode { + roles = append(roles, typeutil.QueryNodeRole) + } + if mr.EnableDataCoord { + roles = append(roles, typeutil.DataCoordRole) + } + if mr.EnableDataNode { + roles = append(roles, typeutil.DataNodeRole) + } + if mr.EnableIndexCoord { + roles = append(roles, typeutil.IndexCoordRole) + } + if mr.EnableIndexNode { + roles = append(roles, typeutil.IndexNodeRole) + } + return roles +} diff --git a/cmd/tools/config/generate.go b/cmd/tools/config/generate.go index a30f2a46c8b80..0e6a4d5571f52 100644 --- a/cmd/tools/config/generate.go +++ b/cmd/tools/config/generate.go @@ -7,12 +7,13 @@ import ( "reflect" "strings" - "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/milvus-io/milvus/pkg/util/typeutil" "github.com/samber/lo" "go.uber.org/zap" "golang.org/x/exp/slices" + + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/paramtable" + 
"github.com/milvus-io/milvus/pkg/util/typeutil" ) type DocContent struct { @@ -106,7 +107,7 @@ type YamlMarshaller struct { } func (m *YamlMarshaller) writeYamlRecursive(data []DocContent, level int) { - var topLevels = typeutil.NewOrderedMap[string, []DocContent]() + topLevels := typeutil.NewOrderedMap[string, []DocContent]() for _, d := range data { key := strings.Split(d.key, ".")[level] diff --git a/cmd/tools/config/main.go b/cmd/tools/config/main.go index 73aa670c3ef4c..8d6d0abfe1c2a 100644 --- a/cmd/tools/config/main.go +++ b/cmd/tools/config/main.go @@ -36,5 +36,4 @@ func main() { default: log.Error(fmt.Sprintf("unknown argument %s", args[1])) } - } diff --git a/cmd/tools/config/printer.go b/cmd/tools/config/printer.go index 54cd90b4bf4dd..20300694868dc 100644 --- a/cmd/tools/config/printer.go +++ b/cmd/tools/config/printer.go @@ -5,9 +5,10 @@ import ( "os" "sort" - "github.com/milvus-io/milvus/pkg/log" "github.com/spf13/viper" "go.uber.org/zap" + + "github.com/milvus-io/milvus/pkg/log" ) func ShowYaml(filepath string) { diff --git a/cmd/tools/datameta/main.go b/cmd/tools/datameta/main.go index c5e20035f6ac7..d6bdffc35c2cb 100644 --- a/cmd/tools/datameta/main.go +++ b/cmd/tools/datameta/main.go @@ -7,12 +7,13 @@ import ( "strings" "github.com/golang/protobuf/proto" + "go.uber.org/zap" + etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/tsoutil" - "go.uber.org/zap" ) var ( diff --git a/cmd/tools/migration/allocator/atomic_allocator.go b/cmd/tools/migration/allocator/atomic_allocator.go index 894747ec06d4e..cce5a4c0b3cba 100644 --- a/cmd/tools/migration/allocator/atomic_allocator.go +++ b/cmd/tools/migration/allocator/atomic_allocator.go @@ -1,8 +1,9 @@ package allocator import ( - "github.com/milvus-io/milvus/pkg/util/typeutil" "go.uber.org/atomic" + + 
"github.com/milvus-io/milvus/pkg/util/typeutil" ) const ( diff --git a/cmd/tools/migration/backend/backend.go b/cmd/tools/migration/backend/backend.go index 915477f359ee0..91c33e431d240 100644 --- a/cmd/tools/migration/backend/backend.go +++ b/cmd/tools/migration/backend/backend.go @@ -6,10 +6,8 @@ import ( "github.com/blang/semver/v4" "github.com/milvus-io/milvus/cmd/tools/migration/configs" - - "github.com/milvus-io/milvus/cmd/tools/migration/versions" - "github.com/milvus-io/milvus/cmd/tools/migration/meta" + "github.com/milvus-io/milvus/cmd/tools/migration/versions" "github.com/milvus-io/milvus/pkg/util" ) diff --git a/cmd/tools/migration/backend/backup_header.go b/cmd/tools/migration/backend/backup_header.go index 40e30ed8e9dd6..59436505ed7cb 100644 --- a/cmd/tools/migration/backend/backup_header.go +++ b/cmd/tools/migration/backend/backup_header.go @@ -3,9 +3,9 @@ package backend import ( "encoding/json" - "github.com/milvus-io/milvus/cmd/tools/migration/console" - "github.com/golang/protobuf/proto" + + "github.com/milvus-io/milvus/cmd/tools/migration/console" ) type BackupHeaderVersion int32 @@ -78,7 +78,7 @@ func (v *BackupHeaderExtra) ToJSONBytes() []byte { } func GetExtra(extra []byte) *BackupHeaderExtra { - var v = newDefaultBackupHeaderExtra() + v := newDefaultBackupHeaderExtra() err := json.Unmarshal(extra, v) if err != nil { console.Error(err.Error()) diff --git a/cmd/tools/migration/backend/backup_restore.go b/cmd/tools/migration/backend/backup_restore.go index 6980c25bee2e0..c21d096cdf1be 100644 --- a/cmd/tools/migration/backend/backup_restore.go +++ b/cmd/tools/migration/backend/backup_restore.go @@ -6,6 +6,7 @@ import ( "io" "github.com/golang/protobuf/proto" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" ) diff --git a/cmd/tools/migration/backend/etcd.go b/cmd/tools/migration/backend/etcd.go index b4ad68f56d49c..0f5e4d28a507c 100644 --- a/cmd/tools/migration/backend/etcd.go +++ b/cmd/tools/migration/backend/etcd.go @@ -1,11 +1,12 @@ 
package backend import ( + clientv3 "go.etcd.io/etcd/client/v3" + "github.com/milvus-io/milvus/cmd/tools/migration/configs" "github.com/milvus-io/milvus/internal/kv" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" "github.com/milvus-io/milvus/pkg/util/etcd" - clientv3 "go.etcd.io/etcd/client/v3" ) type etcdBasedBackend struct { diff --git a/cmd/tools/migration/backend/etcd210.go b/cmd/tools/migration/backend/etcd210.go index 527ec44fe6a83..e66cbd2a815b9 100644 --- a/cmd/tools/migration/backend/etcd210.go +++ b/cmd/tools/migration/backend/etcd210.go @@ -3,20 +3,18 @@ package backend import ( "context" "fmt" - "io/ioutil" + "os" "path" "strconv" "strings" + "github.com/golang/protobuf/proto" clientv3 "go.etcd.io/etcd/client/v3" "github.com/milvus-io/milvus/cmd/tools/migration/configs" + "github.com/milvus-io/milvus/cmd/tools/migration/console" "github.com/milvus-io/milvus/cmd/tools/migration/legacy" - "github.com/milvus-io/milvus/cmd/tools/migration/legacy/legacypb" - - "github.com/golang/protobuf/proto" - "github.com/milvus-io/milvus/cmd/tools/migration/console" "github.com/milvus-io/milvus/cmd/tools/migration/meta" "github.com/milvus-io/milvus/cmd/tools/migration/utils" "github.com/milvus-io/milvus/cmd/tools/migration/versions" @@ -56,7 +54,7 @@ func (b etcd210) loadTtAliases() (meta.TtAliasesMeta210, error) { tsKey := keys[i] tsValue := values[i] valueIsTombstone := rootcoord.IsTombstone(tsValue) - var aliasInfo = &pb.CollectionInfo{} // alias stored in collection info. + aliasInfo := &pb.CollectionInfo{} // alias stored in collection info. if valueIsTombstone { aliasInfo = nil } else { @@ -88,7 +86,7 @@ func (b etcd210) loadAliases() (meta.AliasesMeta210, error) { key := keys[i] value := values[i] valueIsTombstone := rootcoord.IsTombstone(value) - var aliasInfo = &pb.CollectionInfo{} // alias stored in collection info. + aliasInfo := &pb.CollectionInfo{} // alias stored in collection info. 
if valueIsTombstone { aliasInfo = nil } else { @@ -122,7 +120,7 @@ func (b etcd210) loadTtCollections() (meta.TtCollectionsMeta210, error) { } valueIsTombstone := rootcoord.IsTombstone(tsValue) - var coll = &pb.CollectionInfo{} + coll := &pb.CollectionInfo{} if valueIsTombstone { coll = nil } else { @@ -164,7 +162,7 @@ func (b etcd210) loadCollections() (meta.CollectionsMeta210, error) { } valueIsTombstone := rootcoord.IsTombstone(value) - var coll = &pb.CollectionInfo{} + coll := &pb.CollectionInfo{} if valueIsTombstone { coll = nil } else { @@ -213,7 +211,7 @@ func (b etcd210) loadCollectionIndexes() (meta.CollectionIndexesMeta210, error) key := keys[i] value := values[i] - var index = &pb.IndexInfo{} + index := &pb.IndexInfo{} if err := proto.Unmarshal([]byte(value), index); err != nil { return nil, err } @@ -240,7 +238,7 @@ func (b etcd210) loadSegmentIndexes() (meta.SegmentIndexesMeta210, error) { for i := 0; i < l; i++ { value := values[i] - var index = &pb.SegmentIndexInfo{} + index := &pb.SegmentIndexInfo{} if err := proto.Unmarshal([]byte(value), index); err != nil { return nil, err } @@ -263,7 +261,7 @@ func (b etcd210) loadIndexBuildMeta() (meta.IndexBuildMeta210, error) { for i := 0; i < l; i++ { value := values[i] - var record = &legacypb.IndexMeta{} + record := &legacypb.IndexMeta{} if err := proto.Unmarshal([]byte(value), record); err != nil { return nil, err } @@ -434,7 +432,7 @@ func (b etcd210) Backup(meta *meta.Meta, backupFile string) error { return err } console.Warning(fmt.Sprintf("backup to: %s", backupFile)) - return ioutil.WriteFile(backupFile, backup, 0600) + return os.WriteFile(backupFile, backup, 0o600) } func (b etcd210) BackupV2(file string) error { @@ -489,11 +487,11 @@ func (b etcd210) BackupV2(file string) error { } console.Warning(fmt.Sprintf("backup to: %s", file)) - return ioutil.WriteFile(file, backup, 0600) + return os.WriteFile(file, backup, 0o600) } func (b etcd210) Restore(backupFile string) error { - backup, err := 
ioutil.ReadFile(backupFile) + backup, err := os.ReadFile(backupFile) if err != nil { return err } diff --git a/cmd/tools/migration/backend/etcd220.go b/cmd/tools/migration/backend/etcd220.go index 4eae60d43e49a..d00805786537e 100644 --- a/cmd/tools/migration/backend/etcd220.go +++ b/cmd/tools/migration/backend/etcd220.go @@ -3,14 +3,11 @@ package backend import ( "fmt" - "github.com/milvus-io/milvus/internal/metastore/kv/querycoord" - "github.com/milvus-io/milvus/cmd/tools/migration/configs" - - "github.com/milvus-io/milvus/pkg/util" - "github.com/milvus-io/milvus/cmd/tools/migration/meta" + "github.com/milvus-io/milvus/internal/metastore/kv/querycoord" "github.com/milvus-io/milvus/internal/metastore/kv/rootcoord" + "github.com/milvus-io/milvus/pkg/util" ) // etcd220 implements Backend. diff --git a/cmd/tools/migration/command/main.go b/cmd/tools/migration/command/main.go index 5c393f4c5af15..40185ef4695f3 100644 --- a/cmd/tools/migration/command/main.go +++ b/cmd/tools/migration/command/main.go @@ -5,9 +5,8 @@ import ( "fmt" "os" - "github.com/milvus-io/milvus/cmd/tools/migration/console" - "github.com/milvus-io/milvus/cmd/tools/migration/configs" + "github.com/milvus-io/milvus/cmd/tools/migration/console" ) func Execute(args []string) { diff --git a/cmd/tools/migration/command/run.go b/cmd/tools/migration/command/run.go index e90faacfcd64e..d30abdbd36d67 100644 --- a/cmd/tools/migration/command/run.go +++ b/cmd/tools/migration/command/run.go @@ -4,9 +4,7 @@ import ( "context" "github.com/milvus-io/milvus/cmd/tools/migration/configs" - "github.com/milvus-io/milvus/cmd/tools/migration/console" - "github.com/milvus-io/milvus/cmd/tools/migration/migration" ) diff --git a/cmd/tools/migration/meta/210_to_220.go b/cmd/tools/migration/meta/210_to_220.go index 3ae5ee7722a7c..1fceb29d0e083 100644 --- a/cmd/tools/migration/meta/210_to_220.go +++ b/cmd/tools/migration/meta/210_to_220.go @@ -6,11 +6,11 @@ import ( "strconv" "strings" - 
"github.com/milvus-io/milvus/cmd/tools/migration/legacy/legacypb" - - "github.com/milvus-io/milvus/cmd/tools/migration/allocator" + "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus/cmd/tools/migration/allocator" + "github.com/milvus-io/milvus/cmd/tools/migration/legacy/legacypb" "github.com/milvus-io/milvus/cmd/tools/migration/versions" "github.com/milvus-io/milvus/internal/metastore/model" pb "github.com/milvus-io/milvus/internal/proto/etcdpb" @@ -19,7 +19,6 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/typeutil" - "go.uber.org/zap" ) func alias210ToAlias220(record *pb.CollectionInfo, ts Timestamp) *model.Alias { diff --git a/cmd/tools/migration/meta/meta.go b/cmd/tools/migration/meta/meta.go index b2d36d429196f..f76e6b174eb7c 100644 --- a/cmd/tools/migration/meta/meta.go +++ b/cmd/tools/migration/meta/meta.go @@ -2,11 +2,14 @@ package meta import ( "github.com/blang/semver/v4" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) -type UniqueID = typeutil.UniqueID -type Timestamp = typeutil.Timestamp +type ( + UniqueID = typeutil.UniqueID + Timestamp = typeutil.Timestamp +) type Meta struct { SourceVersion semver.Version diff --git a/cmd/tools/migration/meta/meta210.go b/cmd/tools/migration/meta/meta210.go index 9f3c65f97a057..23016210660f9 100644 --- a/cmd/tools/migration/meta/meta210.go +++ b/cmd/tools/migration/meta/meta210.go @@ -21,14 +21,20 @@ type FieldIndexesWithSchema struct { type FieldIndexes210 map[UniqueID]*FieldIndexesWithSchema // coll_id -> field indexes. 
-type TtCollectionsMeta210 map[UniqueID]map[Timestamp]*pb.CollectionInfo // coll_id -> ts -> coll -type CollectionsMeta210 map[UniqueID]*pb.CollectionInfo // coll_id -> coll +type ( + TtCollectionsMeta210 map[UniqueID]map[Timestamp]*pb.CollectionInfo // coll_id -> ts -> coll + CollectionsMeta210 map[UniqueID]*pb.CollectionInfo // coll_id -> coll +) -type TtAliasesMeta210 map[string]map[Timestamp]*pb.CollectionInfo // alias name -> ts -> coll -type AliasesMeta210 map[string]*pb.CollectionInfo // alias name -> coll +type ( + TtAliasesMeta210 map[string]map[Timestamp]*pb.CollectionInfo // alias name -> ts -> coll + AliasesMeta210 map[string]*pb.CollectionInfo // alias name -> coll +) -type CollectionIndexesMeta210 map[UniqueID]map[UniqueID]*pb.IndexInfo // coll_id -> index_id -> index -type SegmentIndexesMeta210 map[UniqueID]map[UniqueID]*pb.SegmentIndexInfo // seg_id -> index_id -> segment index +type ( + CollectionIndexesMeta210 map[UniqueID]map[UniqueID]*pb.IndexInfo // coll_id -> index_id -> index + SegmentIndexesMeta210 map[UniqueID]map[UniqueID]*pb.SegmentIndexInfo // seg_id -> index_id -> segment index +) type IndexBuildMeta210 map[UniqueID]*legacypb.IndexMeta // index_build_id -> index diff --git a/cmd/tools/migration/meta/meta220.go b/cmd/tools/migration/meta/meta220.go index 63044fdb8c1b8..f190b4061c651 100644 --- a/cmd/tools/migration/meta/meta220.go +++ b/cmd/tools/migration/meta/meta220.go @@ -13,23 +13,35 @@ import ( "github.com/milvus-io/milvus/pkg/util" ) -type TtCollectionsMeta220 map[UniqueID]map[Timestamp]*model.Collection // coll_id -> ts -> coll -type CollectionsMeta220 map[UniqueID]*model.Collection // coll_id -> coll +type ( + TtCollectionsMeta220 map[UniqueID]map[Timestamp]*model.Collection // coll_id -> ts -> coll + CollectionsMeta220 map[UniqueID]*model.Collection // coll_id -> coll +) -type TtAliasesMeta220 map[string]map[Timestamp]*model.Alias // alias name -> ts -> coll -type AliasesMeta220 map[string]*model.Alias // alias name -> coll 
+type ( + TtAliasesMeta220 map[string]map[Timestamp]*model.Alias // alias name -> ts -> coll + AliasesMeta220 map[string]*model.Alias // alias name -> coll +) -type TtPartitionsMeta220 map[UniqueID]map[Timestamp][]*model.Partition // coll_id -> ts -> partitions -type PartitionsMeta220 map[UniqueID][]*model.Partition // coll_id -> ts -> partitions +type ( + TtPartitionsMeta220 map[UniqueID]map[Timestamp][]*model.Partition // coll_id -> ts -> partitions + PartitionsMeta220 map[UniqueID][]*model.Partition // coll_id -> ts -> partitions +) -type TtFieldsMeta220 map[UniqueID]map[Timestamp][]*model.Field // coll_id -> ts -> fields -type FieldsMeta220 map[UniqueID][]*model.Field // coll_id -> ts -> fields +type ( + TtFieldsMeta220 map[UniqueID]map[Timestamp][]*model.Field // coll_id -> ts -> fields + FieldsMeta220 map[UniqueID][]*model.Field // coll_id -> ts -> fields +) -type CollectionIndexesMeta220 map[UniqueID]map[UniqueID]*model.Index // coll_id -> index_id -> index -type SegmentIndexesMeta220 map[UniqueID]map[UniqueID]*model.SegmentIndex // seg_id -> index_id -> segment index +type ( + CollectionIndexesMeta220 map[UniqueID]map[UniqueID]*model.Index // coll_id -> index_id -> index + SegmentIndexesMeta220 map[UniqueID]map[UniqueID]*model.SegmentIndex // seg_id -> index_id -> segment index +) -type CollectionLoadInfo220 map[UniqueID]*model.CollectionLoadInfo // collectionID -> CollectionLoadInfo -type PartitionLoadInfo220 map[UniqueID]map[UniqueID]*model.PartitionLoadInfo // collectionID, partitionID -> PartitionLoadInfo +type ( + CollectionLoadInfo220 map[UniqueID]*model.CollectionLoadInfo // collectionID -> CollectionLoadInfo + PartitionLoadInfo220 map[UniqueID]map[UniqueID]*model.PartitionLoadInfo // collectionID, partitionID -> PartitionLoadInfo +) func (meta *TtCollectionsMeta220) GenerateSaves(sourceVersion semver.Version) (map[string]string, error) { saves := make(map[string]string) diff --git a/cmd/tools/migration/migration/210_to_220.go 
b/cmd/tools/migration/migration/210_to_220.go index edb28f2673fbb..79aaba7a5f44a 100644 --- a/cmd/tools/migration/migration/210_to_220.go +++ b/cmd/tools/migration/migration/210_to_220.go @@ -4,8 +4,7 @@ import ( "github.com/milvus-io/milvus/cmd/tools/migration/meta" ) -type migrator210To220 struct { -} +type migrator210To220 struct{} func (m migrator210To220) Migrate(metas *meta.Meta) (*meta.Meta, error) { return meta.From210To220(metas) diff --git a/cmd/tools/migration/migration/migrator.go b/cmd/tools/migration/migration/migrator.go index e220f71dca82d..c02f35c343fa6 100644 --- a/cmd/tools/migration/migration/migrator.go +++ b/cmd/tools/migration/migration/migrator.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/blang/semver/v4" + "github.com/milvus-io/milvus/cmd/tools/migration/meta" "github.com/milvus-io/milvus/cmd/tools/migration/versions" ) diff --git a/cmd/tools/migration/migration/runner.go b/cmd/tools/migration/migration/runner.go index 204c0c2d76841..87d8d664dbfdf 100644 --- a/cmd/tools/migration/migration/runner.go +++ b/cmd/tools/migration/migration/runner.go @@ -7,20 +7,15 @@ import ( "sync" "time" - "github.com/milvus-io/milvus/internal/util/sessionutil" - "go.uber.org/atomic" - - "github.com/milvus-io/milvus/cmd/tools/migration/versions" - "github.com/blang/semver/v4" + clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/atomic" + "github.com/milvus-io/milvus/cmd/tools/migration/backend" "github.com/milvus-io/milvus/cmd/tools/migration/configs" - "github.com/milvus-io/milvus/cmd/tools/migration/console" - - "github.com/milvus-io/milvus/cmd/tools/migration/backend" - clientv3 "go.etcd.io/etcd/client/v3" - + "github.com/milvus-io/milvus/cmd/tools/migration/versions" + "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/util/etcd" ) diff --git a/cmd/tools/migration/utils/util.go b/cmd/tools/migration/utils/util.go index 6fa47a4f8997f..e9dc2caa5776e 100644 --- a/cmd/tools/migration/utils/util.go +++ 
b/cmd/tools/migration/utils/util.go @@ -5,13 +5,14 @@ import ( "strconv" "strings" - "github.com/milvus-io/milvus/pkg/util/typeutil" - "github.com/milvus-io/milvus/internal/metastore/kv/rootcoord" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) -type UniqueID = typeutil.UniqueID -type Timestamp = typeutil.Timestamp +type ( + UniqueID = typeutil.UniqueID + Timestamp = typeutil.Timestamp +) type errNotOfTsKey struct { key string diff --git a/configs/milvus.yaml b/configs/milvus.yaml index eff16f90ed9ec..adea48aee559e 100644 --- a/configs/milvus.yaml +++ b/configs/milvus.yaml @@ -45,9 +45,19 @@ etcd: metastore: # Default value: etcd - # Valid values: etcd + # Valid values: [etcd, tikv] type: etcd +# Related configuration of tikv, used to store Milvus metadata. +# Notice that when TiKV is enabled for metastore, you still need to have etcd for service discovery. +# TiKV is a good option when the metadata size requires better horizontal scalability. +tikv: + # Note that the default pd port of tikv is 2379, which conflicts with etcd. + endpoints: 127.0.0.1:2389 + rootPath: by-dev # The root path where data is stored + metaSubPath: meta # metaRootPath = rootPath + '/' + metaSubPath + kvSubPath: kv # kvRootPath = rootPath + '/' + kvSubPath + localStorage: path: /tmp/milvus/data/ # please adjust in embedded Milvus: /tmp/milvus/data/ @@ -84,6 +94,8 @@ minio: region: "" # Cloud whether use virtual host bucket mode useVirtualHost: false + # timeout for request time in milliseconds + requestTimeoutMs: 3000 # Milvus supports four MQ: rocksmq(based on RockDB), natsmq(embedded nats-server), Pulsar and Kafka. # You can change your mq by setting mq.type field. @@ -104,6 +116,7 @@ pulsar: tenant: public namespace: default requestTimeout: 60 # pulsar client global request timeout in seconds + enableClientMetrics: false # Whether to register pulsar client metrics into milvus metrics path. 
# If you want to enable kafka, needs to comment the pulsar configs # kafka: @@ -112,6 +125,7 @@ pulsar: # saslPassword: # saslMechanisms: PLAIN # securityProtocol: SASL_SSL +# readTimeout: 10 # read message timeout in seconds rocksmq: # The path where the message is stored in rocksmq @@ -157,6 +171,9 @@ rootCoord: importTaskExpiration: 900 # (in seconds) Duration after which an import task will expire (be killed). Default 900 seconds (15 minutes). importTaskRetention: 86400 # (in seconds) Milvus will keep the record of import tasks for at least `importTaskRetention` seconds. Default 86400, seconds (24 hours). enableActiveStandby: false + # can specify ip for example + # ip: 127.0.0.1 + ip: # if not specify address, will use the first unicastable address as local ip port: 53100 grpc: serverMaxSendSize: 536870912 @@ -188,6 +205,9 @@ proxy: http: enabled: true # Whether to enable the http server debug_mode: false # Whether to enable http server debug mode + # can specify ip for example + # ip: 127.0.0.1 + ip: # if not specify address, will use the first unicastable address as local ip port: 19530 internalPort: 19529 grpc: @@ -214,6 +234,9 @@ queryCoord: heartbeatAvailableInterval: 10000 # 10s, Only QueryNodes which fetched heartbeats within the duration are available loadTimeoutSeconds: 600 checkHandoffInterval: 5000 + # can specify ip for example + # ip: 127.0.0.1 + ip: # if not specify address, will use the first unicastable address as local ip port: 19531 grpc: serverMaxSendSize: 536870912 @@ -248,8 +271,9 @@ queryNode: enableDisk: false # enable querynode load disk index, and search on disk index maxDiskUsagePercentage: 95 cache: - enabled: true - memoryLimit: 2147483648 # 2 GB, 2 * 1024 *1024 *1024 + enabled: true # deprecated, TODO: remove it + memoryLimit: 2147483648 # 2 GB, 2 * 1024 *1024 *1024 # deprecated, TODO: remove it + readAheadPolicy: willneed # The read ahead policy of chunk cache, options: `normal, random, sequential, willneed, dontneed` grouping: 
enabled: true maxNQ: 1000 @@ -281,6 +305,9 @@ queryNode: enableCrossUserGrouping: false # false by default Enable Cross user grouping when using user-task-polling policy. (close it if task of any user can not merge others). maxPendingTaskPerUser: 1024 # 50 by default, max pending task in scheduler per user. + # can specify ip for example + # ip: 127.0.0.1 + ip: # if not specify address, will use the first unicastable address as local ip port: 21123 grpc: serverMaxSendSize: 536870912 @@ -302,6 +329,9 @@ indexNode: buildParallel: 1 enableDisk: true # enable index node build disk vector index maxDiskUsagePercentage: 95 + # can specify ip for example + # ip: 127.0.0.1 + ip: # if not specify address, will use the first unicastable address as local ip port: 21121 grpc: serverMaxSendSize: 536870912 @@ -354,6 +384,9 @@ dataCoord: missingTolerance: 3600 # file meta missing tolerance duration in seconds, 3600 dropTolerance: 10800 # file belongs to dropped entity tolerance duration in seconds. 10800 enableActiveStandby: false + # can specify ip for example + # ip: 127.0.0.1 + ip: # if not specify address, will use the first unicastable address as local ip port: 13333 grpc: serverMaxSendSize: 536870912 @@ -371,6 +404,9 @@ dataNode: insertBufSize: 16777216 # Max buffer size to flush for a single segment. deleteBufBytes: 67108864 # Max buffer size to flush del for a single channel syncPeriod: 600 # The period to sync segments if buffer is not empty. + # can specify ip for example + # ip: 127.0.0.1 + ip: # if not specify address, will use the first unicastable address as local ip port: 21124 grpc: serverMaxSendSize: 536870912 @@ -384,6 +420,11 @@ dataNode: watermarkCluster: 0.5 # memory watermark for cluster, upon reaching this watermark, segments will be synced. 
timetick: byRPC: true + channel: + # specify the size of global work pool of all channels + # if this parameter <= 0, will set it as the maximum number of CPUs that can be executing + # suggest to set it bigger on large collection numbers to avoid blocking + workPoolSize: -1 # Configures the system log output. log: @@ -402,14 +443,14 @@ grpc: serverMaxSendSize: 536870912 serverMaxRecvSize: 536870912 client: - compressionEnabled: false + compressionEnabled: true dialTimeout: 200 keepAliveTime: 10000 keepAliveTimeout: 20000 - maxMaxAttempts: 5 - initialBackoff: 1 - maxBackoff: 10 - backoffMultiplier: 2 + maxMaxAttempts: 10 + initialBackOff: 0.2 # seconds + maxBackoff: 10 # seconds + backoffMultiplier: 2.0 # deprecated clientMaxSendSize: 268435456 clientMaxRecvSize: 268435456 @@ -425,6 +466,7 @@ common: rootCoordTimeTick: rootcoord-timetick rootCoordStatistics: rootcoord-statistics rootCoordDml: rootcoord-dml + replicateMsg: replicate-msg rootCoordDelta: rootcoord-delta search: search searchResult: searchResult @@ -492,6 +534,7 @@ common: threshold: info: 500 # minimum milliseconds for printing durations in info level warn: 1000 # minimum milliseconds for printing durations in warn level + ttMsgEnabled: true # Whether the instance disable sending ts messages # QuotaConfig, configurations of Milvus quota and limits. 
# By default, we enable: diff --git a/docker-compose.yml b/docker-compose.yml index 6166fe8c7ba2e..c8b0b8f7a19e9 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -28,6 +28,7 @@ services: ETCD_ENDPOINTS: ${ETCD_ENDPOINTS} MINIO_ADDRESS: ${MINIO_ADDRESS} CONAN_USER_HOME: /home/milvus + AZURE_STORAGE_CONNECTION_STRING: ${AZURITE_CONNECTION_STRING} volumes: &builder-volumes - .:/go/src/github.com/milvus-io/milvus:delegated - ${DOCKER_VOLUME_DIRECTORY:-.docker}/${IMAGE_ARCH}-${OS_NAME}-ccache:/ccache:delegated @@ -39,6 +40,7 @@ services: - etcd - minio - pulsar + - azurite # Command command: &builder-command > /bin/bash -c " @@ -64,6 +66,7 @@ services: ETCD_ENDPOINTS: ${ETCD_ENDPOINTS} MINIO_ADDRESS: ${MINIO_ADDRESS} CONAN_USER_HOME: /home/milvus + AZURE_STORAGE_CONNECTION_STRING: ${AZURITE_CONNECTION_STRING} volumes: &builder-volumes-gpu - .:/go/src/github.com/milvus-io/milvus:delegated - ${DOCKER_VOLUME_DIRECTORY:-.docker-gpu}/${OS_NAME}-ccache:/ccache:delegated @@ -75,6 +78,7 @@ services: - etcd - minio - pulsar + - azurite # Command command: &builder-command-gpu > /bin/bash -c " @@ -110,6 +114,10 @@ services: timeout: 20s retries: 3 + azurite: + image: mcr.microsoft.com/azure-storage/azurite + command: azurite-blob --blobHost 0.0.0.0 + jaeger: image: jaegertracing/all-in-one:latest diff --git a/docs/design_docs/20230918-datanode_remove_datacoord_dependency.md b/docs/design_docs/20230918-datanode_remove_datacoord_dependency.md new file mode 100644 index 0000000000000..338aaa986ecbc --- /dev/null +++ b/docs/design_docs/20230918-datanode_remove_datacoord_dependency.md @@ -0,0 +1,121 @@ +# MEP: Datanode remove dependency of `Datacoord` + +Current state: "Accepted" + +ISSUE: https://github.com/milvus-io/milvus/issues/26758 + +Keywords: datacoord, datanode, flush, dependency, roll-upgrade + +## Summary + +Remove the dependency of `Datacoord` for `Datanodes`. + +## Motivation + +1. 
Datanodes shall always be running even when the data coordinator is not alive + +If a datanode performs `sync` during rolling upgrade, it needs datacoord to change the related meta in metastore. If datacoord happens to be offline or it is during some period of rolling-upgrade, datanode has to panic to ensure there is no data loss. + +2. Flush operation is complex and error-prone since the whole procedure involves datacoord, datanodes and grpc + +This proposal means to remove the dependency of datacoord ensuring: + +- the data is intact and no duplicate data is kept in records +- no compatibility issue during or after rolling upgrade +- `Datacoord` shall be able to detect the segment meta updates and provide recent targets for `QueryCoord` + +## Design Details + +The briefest description of this proposal is to: + +- Make `Datanode` operate on the segment meta directly +- Make `Datacoord` refresh the latest segment change periodically + + +### Preventing multiple writers + +There is a major concern that if multiple `Datanodes` are handling the same dml channel, only one `DataNode` shall be able to update segment meta successfully. + +This guarantee was previously implemented by the singleton writer in `Datacoord`: it checks for a valid watcher id before updating the segment meta when receiving the `SaveBinlogPaths` grpc call. + +In this proposal, `DataNodes` update segment meta on their own, so we need to introduce a new mechanism to prevent this error from happening: + +{% note %} + +**Note:** Like the "etcd lease for key", the ownership of each dml channel is bound to a lease id. This lease id shall be recorded in metastore (etcd/tikv or any other implementation). +When a `DataNode` starts to watch a dml channel, it shall read this lease id (via etcd or grpc call). ANY operations on this dml channel shall be performed under a transaction that checks the lease id is equal to the previously read value. 
+If a `datanode` finds the lease id is revoked or updated, it shall close the flowgraph/pipeline and cancel all pending operations instead of panicking. + +{% endnote %} + +- [] Add lease id field in etcd channel watch info/ grpc watch request +- [] Add `TransactionIf` like APIs in `TxnKV` interface + +### Updating channel checkpoint + +Likewise, all channel checkpoint update operations are performed by `Datacoord`, invoked by grpc calls from `DataNodes`. So it has the same problem in the previously stated scenarios. + +So, "updating channel checkpoint" shall also be processed in `DataNodes` while removing the dependency of `DataCoord`. + +The rule the system shall follow is: + +{% note %} + +**Note:** Segment meta shall be updated *BEFORE* changing the channel checkpoint in case the datanode crashes during the procedure. Under this premise, reconsuming from the old checkpoint shall recover all the data and duplicated entries will be discarded by segment checkpoints. + +{% endnote %} + +### Updating segment status in `DataCoord` + +As previously described, `DataCoord` shall refresh the segment meta and channel checkpoint periodically to provide a recent target for `QueryCoord`. + +The `watching via Etcd` strategy is ruled out first since the `Watch` operation shall be avoided in the future design: currently the Milvus system tends not to use the `Watch` operation and tries to remove it from metastore. +Also `Watch` is heavy and has caused lots of issues before. + +The winning option is to: + +{% note %} + +**Note:** `Datacoord` reloads from metastore periodically. +Optimization 1: reload channel checkpoint first, then reload segment meta if the newly read revision is greater than the in-memory one. +Optimization 2: After `L0 segment` is implemented, datacoord shall refresh growing segments only. 
+ +{% endnote %} + + +## Compatibility, Deprecation, and Migration Plan + +This change shall guarantee that: + +- When the new `Datacoord` starts, it shall be able to upgrade the old watch info and add the lease id into it + - For watch info, release then watch + - For grpc, `release then watch` is the second choice; try calling watch with the lease id +- Older `DataNodes` could invoke `SaveBinlogPaths` and other legacy grpc calls without panicking +- The new `DataNodes` receiving an old watch request (without lease id) shall fall back to the older strategy, which is to update meta via grpc +- `SaveBinlogPaths`, `UpdateChannelCheckpoints` APIs shall be kept until the next breaking change + +## Test Plan + +### Unit test +Coverage over 90% + +### Integration Test + +#### Datacoord offline + +1. Insert data without datanodes online +2. Start datanodes +3. Make datacoord go offline after channel assignment +4. Assert no datanode panicking and all data shall be intact +5. Bring back datacoord and test `GetRecoveryInfo`, which shall return the latest target + + +#### Compatibility + +1. Start mock datacoord +2. Construct a watch info (without lease) +3. 
Datanode start to watch dml channel and all meta update shall be performed via grpc + +## Rejected Alternatives + +DataCoord refresh meta via Etcd watch diff --git a/go.mod b/go.mod index 9f7771d16642e..3501b30ecb576 100644 --- a/go.mod +++ b/go.mod @@ -3,9 +3,12 @@ module github.com/milvus-io/milvus go 1.18 require ( + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.0 + github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0 + github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.1.0 github.com/aliyun/credentials-go v1.2.7 github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20210826220005-b48c857c3a0e - github.com/antonmedv/expr v1.8.9 github.com/apache/pulsar-client-go v0.6.1-0.20210728062540-29414db801a7 github.com/bits-and-blooms/bloom/v3 v3.0.1 github.com/blang/semver/v4 v4.0.0 @@ -20,7 +23,7 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 github.com/klauspost/compress v1.16.7 github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d - github.com/milvus-io/milvus-proto/go-api/v2 v2.3.1-0.20230907032509-23756009c643 + github.com/milvus-io/milvus-proto/go-api/v2 v2.3.2-0.20231019101159-a0a6f5e7eff8 github.com/milvus-io/milvus/pkg v0.0.1 github.com/minio/minio-go/v7 v7.0.61 github.com/prometheus/client_golang v1.14.0 @@ -35,6 +38,7 @@ require ( github.com/stretchr/testify v1.8.4 github.com/tecbot/gorocksdb v0.0.0-20191217155057-f0fad39f321c github.com/tidwall/gjson v1.14.4 + github.com/tikv/client-go/v2 v2.0.4 go.etcd.io/etcd/api/v3 v3.5.5 go.etcd.io/etcd/client/v3 v3.5.5 go.etcd.io/etcd/server/v3 v3.5.5 @@ -44,11 +48,11 @@ require ( go.uber.org/atomic v1.11.0 go.uber.org/multierr v1.11.0 go.uber.org/zap v1.24.0 - golang.org/x/crypto v0.11.0 + golang.org/x/crypto v0.14.0 golang.org/x/exp v0.0.0-20230728194245-b0cb94b80691 golang.org/x/oauth2 v0.8.0 golang.org/x/sync v0.3.0 - golang.org/x/text v0.11.0 + golang.org/x/text v0.13.0 google.golang.org/grpc v1.57.0 
google.golang.org/grpc/examples v0.0.0-20220617181431-3e7b97febc7f stathat.com/c/consistent v1.0.0 @@ -57,6 +61,7 @@ require ( require ( github.com/apache/arrow/go/v12 v12.0.0-20230223012627-e0e740bd7a24 github.com/milvus-io/milvus-storage/go v0.0.0-20231017063757-b4720fe2ec8f + github.com/pingcap/log v1.1.1-0.20221015072633-39906604fb81 ) require ( @@ -65,6 +70,7 @@ require ( github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 // indirect github.com/99designs/keyring v1.2.1 // indirect github.com/AthenZ/athenz v1.10.39 // indirect + github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0 // indirect github.com/DataDog/zstd v1.5.0 // indirect github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c // indirect github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible // indirect @@ -88,8 +94,10 @@ require ( github.com/containerd/cgroups v1.1.0 // indirect github.com/coreos/go-semver v0.3.0 // indirect github.com/coreos/go-systemd/v22 v22.3.2 // indirect + github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548 // indirect github.com/danieljoos/wincred v1.1.2 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 // indirect github.com/docker/go-units v0.4.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/dvsekhvalnov/jose2go v1.5.0 // indirect @@ -111,6 +119,7 @@ require ( github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2 // indirect github.com/godbus/dbus/v5 v5.0.4 // indirect github.com/golang-jwt/jwt v3.2.2+incompatible // indirect + github.com/golang-jwt/jwt/v4 v4.5.0 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/google/flatbuffers v2.0.8+incompatible // indirect github.com/google/uuid v1.3.0 // indirect @@ -127,6 +136,7 @@ require ( github.com/klauspost/cpuid/v2 v2.2.5 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/kr/text v0.2.0 // 
indirect + github.com/kylelemons/godebug v1.1.0 // indirect github.com/leodido/go-urn v1.2.4 // indirect github.com/linkedin/goavro/v2 v2.11.1 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect @@ -149,16 +159,22 @@ require ( github.com/nats-io/nkeys v0.4.4 // indirect github.com/nats-io/nuid v1.0.1 // indirect github.com/opencontainers/runtime-spec v1.0.2 // indirect + github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/panjf2000/ants/v2 v2.7.2 // indirect github.com/pelletier/go-toml v1.9.3 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect github.com/pierrec/lz4 v2.5.2+incompatible // indirect github.com/pierrec/lz4/v4 v4.1.18 // indirect github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect + github.com/pingcap/failpoint v0.0.0-20210918120811-547c13e3eb00 // indirect + github.com/pingcap/goleveldb v0.0.0-20191226122134-f82aafb29989 // indirect + github.com/pingcap/kvproto v0.0.0-20221129023506-621ec37aac7a // indirect + github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect github.com/prometheus/procfs v0.9.0 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rogpeppe/go-internal v1.10.0 // indirect github.com/rs/xid v1.5.0 // indirect github.com/shirou/gopsutil/v3 v3.22.9 // indirect @@ -167,15 +183,19 @@ require ( github.com/spf13/afero v1.6.0 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/stathat/consistent v1.0.0 // indirect github.com/streamnative/pulsarctl v0.5.0 // indirect github.com/stretchr/objx v0.5.0 // indirect github.com/subosito/gotenv v1.2.0 // indirect + github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a // 
indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect + github.com/tikv/pd/client v0.0.0-20221031025758-80f0d8ca4d07 // indirect github.com/tklauser/go-sysconf v0.3.10 // indirect github.com/tklauser/numcpus v0.4.0 // indirect github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/twmb/murmur3 v1.1.3 // indirect github.com/uber/jaeger-client-go v2.30.0+incompatible // indirect github.com/ugorji/go/codec v1.2.11 // indirect github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 // indirect @@ -197,9 +217,9 @@ require ( go.uber.org/automaxprocs v1.5.2 // indirect golang.org/x/arch v0.3.0 // indirect golang.org/x/mod v0.12.0 // indirect - golang.org/x/net v0.12.0 // indirect - golang.org/x/sys v0.10.0 // indirect - golang.org/x/term v0.10.0 // indirect + golang.org/x/net v0.17.0 // indirect + golang.org/x/sys v0.13.0 // indirect + golang.org/x/term v0.13.0 // indirect golang.org/x/time v0.3.0 // indirect golang.org/x/tools v0.11.0 // indirect golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect diff --git a/go.sum b/go.sum index 232315ee7f40b..8e5e5d28b0e25 100644 --- a/go.sum +++ b/go.sum @@ -49,12 +49,22 @@ github.com/99designs/keyring v1.2.1/go.mod h1:fc+wB5KTk9wQ9sDx0kFXB3A0MaeGHM9AwR github.com/AndreasBriese/bbloom v0.0.0-20190306092124-e2d15f34fcf9/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8= github.com/AthenZ/athenz v1.10.39 h1:mtwHTF/v62ewY2Z5KWhuZgVXftBej1/Tn80zx4DcawY= github.com/AthenZ/athenz v1.10.39/go.mod h1:3Tg8HLsiQZp81BJY58JBeU2BR6B/H4/0MQGfCwhHNEA= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.0 h1:8q4SaHjFsClSvuVne0ID/5Ka8u3fcIHyqkLjcFpNRHQ= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.0/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0 h1:vcYCAze6p19qBW7MhZybIsqD8sMV8js0NyQM8JDnVtg= 
+github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0/go.mod h1:OQeznEEkTZ9OrhHJoDD8ZDq51FHgXjqtP9z6bEwBq9U= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0 h1:sXr+ck84g/ZlZUOZiNELInmMgOsuGwdjjVkEIde0OtY= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0/go.mod h1:okt5dMMTOFjX/aovMlrjvvXoPMBVSPzk9185BT0+eZM= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.2.0 h1:Ma67P/GGprNwsslzEH6+Kb8nybI8jpDTm4Wmzu2ReK8= +github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.1.0 h1:nVocQV40OQne5613EeLayJiRAJuKlBGy+m22qWG+WRg= +github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.1.0/go.mod h1:7QJP7dr2wznCMeqIrhMgWGf7XpAQnVrJqDm9nvV3Cu4= +github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0 h1:OBhqkivkhkMqLPymWEppkm7vgPQY2XsHoEkaMQ0AdZY= +github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0/go.mod h1:kgDmCTgBzIEPFElEF+FK0SdjAor06dRq2Go927dnQ6o= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v1.2.1 h1:9F2/+DoOYIOksmaJFPw1tGFy1eDnIJXg+UHjuD8lTak= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/CloudyKit/fastprinter v0.0.0-20200109182630-33d98a066a53/go.mod h1:+3IMCy2vIlbG1XG/0ggNQv0SvxCAIpPM5b1nCz56Xno= github.com/CloudyKit/jet/v3 v3.0.0/go.mod h1:HKQPgSJmdK8hdoAbKUUWajkHyHo4RaU5rMdUywE7VMo= -github.com/DATA-DOG/go-sqlmock v1.3.3/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= github.com/DataDog/zstd v1.5.0 h1:+K/VEwIAaPcHiMtQvpLD4lqW7f0Gk3xdYZmI1hD+CXo= github.com/DataDog/zstd v1.5.0/go.mod h1:g4AWEaM3yOg3HYfnJ3YIawPnVdXJh9QME85blwSAmyw= github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c h1:RGWPOewvKIROun94nF7v2cua9qP+thov/7M50KEoeSU= @@ -82,11 +92,10 @@ github.com/aliyun/credentials-go v1.2.7 h1:gLtFylxLZ1TWi1pStIt1O6a53GFU1zkNwjtJi github.com/aliyun/credentials-go v1.2.7/go.mod h1:/KowD1cfGSLrLsH28Jr8W+xwoId0ywIy5lNzDz6O1vw= 
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= +github.com/antihax/optional v0.0.0-20180407024304-ca021399b1a6/go.mod h1:V8iCPQYkqmusNa815XgQio277wI47sdRh1dUOLdyC6Q= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20210826220005-b48c857c3a0e h1:GCzyKMDDjSGnlpl3clrdAK7I1AaVoaiKDOYkUzChZzg= github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20210826220005-b48c857c3a0e/go.mod h1:F7bn7fEU90QkQ3tnmaTx3LTKLEDqnwWODIYppRQ5hnY= -github.com/antonmedv/expr v1.8.9 h1:O9stiHmHHww9b4ozhPx7T6BK7fXfOCHJ8ybxf0833zw= -github.com/antonmedv/expr v1.8.9/go.mod h1:5qsM3oLGDND7sDmQGDXHkYfkjYMUX14qsgqmHhwGEk8= github.com/apache/arrow/go/v12 v12.0.0-20230223012627-e0e740bd7a24 h1:3klg6Gtrm0jGkiXWLYricKhI1pYYFuBFXhGzOT5B1eo= github.com/apache/arrow/go/v12 v12.0.0-20230223012627-e0e740bd7a24/go.mod h1:3JcT3bSZFdc7wLPKSlQXhf3L0GjPz0TOmLlG1YXnBfU= github.com/apache/thrift v0.18.1 h1:lNhK/1nqjbwbiOPDBPFJVKxgDEGSepKuTh6OLiXW8kg= @@ -186,19 +195,22 @@ github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwc github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.11/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548 h1:iwZdTE0PVqJCos1vaoKsclOGD3ADKpshg3SRtYBbwso= +github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548/go.mod h1:e6NPNENfs9mPDVNRekM7lKScauxd5kXTr1Mfyig6TDM= github.com/danieljoos/wincred v1.1.2 h1:QLdCxFs1/Yl4zduvBdcHB8goaYk9RARS2SgLLRuAyr0= github.com/danieljoos/wincred v1.1.2/go.mod h1:GijpziifJoIBfYh+S7BbkdUTU4LfM+QnGqR5Vl2tAx0= -github.com/davecgh/go-spew v0.0.0-20161028175848-04cdfd42973b/go.mod 
h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgraph-io/badger v1.6.0/go.mod h1:zwt7syl517jmP8s94KqSxTlM6IMsdhYy6psNgSztDR4= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= +github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= github.com/dimfeld/httptreemux v5.0.1+incompatible h1:Qj3gVcDNoOthBAqftuD596rm4wg/adLLz5xh5CmpiCA= github.com/dimfeld/httptreemux v5.0.1+incompatible/go.mod h1:rbUlSV+CCpv/SuqUTP/8Bk2O3LyUV436/yaRGkhP6Z0= +github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI= github.com/docker/go-units v0.4.0 h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw= github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= @@ -244,8 +256,6 @@ github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4 github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= github.com/gavv/httpexpect v2.0.0+incompatible/go.mod h1:x+9tiU1YnrOvnB725RkpoLv1M62hOWzwo5OXotisrKc= -github.com/gdamore/encoding v1.0.0/go.mod 
h1:alR0ol34c49FCSBLjhosxzcPHQbf2trDkoo5dl+VrEg= -github.com/gdamore/tcell v1.3.0/go.mod h1:Hjvr+Ofd+gLglo7RYKxxnzCBmev3BzsS67MebKS4zMM= github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ= github.com/getsentry/sentry-go v0.12.0 h1:era7g0re5iY13bHSdN/xMkyV+5zZppjRVQhZrXCaEIk= github.com/getsentry/sentry-go v0.12.0/go.mod h1:NSap0JBYWzHND8oMbyi0+XZhUalc1TBdRL1M71JZW2c= @@ -314,6 +324,8 @@ github.com/gogo/status v1.1.0/go.mod h1:BFv9nrluPLmrS0EmGVvLaPNmRosr9KapBYd5/hpY github.com/golang-jwt/jwt v3.2.1+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= +github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v1.0.0/go.mod h1:EWib/APOK0SL3dFbYqvxE3UYd8E6s1ouQ7iEp/0LWV4= @@ -414,6 +426,7 @@ github.com/grpc-ecosystem/go-grpc-middleware v1.3.0/go.mod h1:z0ButlSOZa5vEBq9m2 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 h1:Ovs26xHkKqVztRpIrF/92BcuyuQ/YW4NSIpoGtfXNho= github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= github.com/grpc-ecosystem/grpc-gateway v1.9.0/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= +github.com/grpc-ecosystem/grpc-gateway v1.12.1/go.mod h1:8XEsbTttt/W+VvjtQhLACqCisSPWTxCZ7sBRjU6iH9c= github.com/grpc-ecosystem/grpc-gateway v1.16.0 h1:gmcG1KaJ57LophUzW0Hy8NmPhnMZb4M0+kPpLofRdBo= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod 
h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/grpc-ecosystem/grpc-gateway/v2 v2.7.0 h1:BZHcxBETFHIdVyhyEfOvn/RdU/QGdLI4y34qQGjGWO0= @@ -524,6 +537,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kris-nova/logger v0.0.0-20181127235838-fd0d87064b06 h1:vN4d3jSss3ExzUn2cE0WctxztfOgiKvMKnDrydBsg00= github.com/kris-nova/lolgopher v0.0.0-20180921204813-313b3abb0d9b h1:xYEM2oBUhBEhQjrV+KJ9lEWDWYZoNVZUaBF++Wyljq4= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/labstack/echo/v4 v4.5.0/go.mod h1:czIriw4a0C1dFun+ObrXp7ok03xON0N1awStJ6ArI7Y= github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k= github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= @@ -535,8 +550,6 @@ github.com/linkedin/goavro/v2 v2.10.0/go.mod h1:UgQUb2N/pmueQYH9bfqFioWxzYCZXSfF github.com/linkedin/goavro/v2 v2.10.1/go.mod h1:UgQUb2N/pmueQYH9bfqFioWxzYCZXSfF8Jw03O5sjqA= github.com/linkedin/goavro/v2 v2.11.1 h1:4cuAtbDfqkKnBXp9E+tRkIJGa6W6iAjwonwt8O1f4U0= github.com/linkedin/goavro/v2 v2.11.1/go.mod h1:UgQUb2N/pmueQYH9bfqFioWxzYCZXSfF8Jw03O5sjqA= -github.com/lucasb-eyer/go-colorful v1.0.2/go.mod h1:0MS4r+7BZKSJ5mw4/S5MPN+qHFF1fYclkSPilDOKW0s= -github.com/lucasb-eyer/go-colorful v1.0.3/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= @@ -556,9 +569,7 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky 
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-runewidth v0.0.4/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= github.com/mattn/go-runewidth v0.0.8 h1:3tS41NlGYSmhhe/8fhGRzc+z3AYCw1Fe1WAyLuujKs0= -github.com/mattn/go-runewidth v0.0.8/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/mattn/goveralls v0.0.2/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpevwGNQEw= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= @@ -570,18 +581,8 @@ github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/le github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b h1:TfeY0NxYxZzUfIfYe5qYDBzt4ZYRqzUjTR6CvUzjat8= github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b/go.mod h1:iwW+9cWfIzzDseEBCCeDSN5SD16Tidvy8cwQ7ZY8Qj4= -github.com/milvus-io/milvus-proto/go-api/v2 v2.3.1-0.20230907032509-23756009c643 h1:3MXEYckliGnyepZeLDrhn+speelsoRKU1IwD8JrxXMo= -github.com/milvus-io/milvus-proto/go-api/v2 v2.3.1-0.20230907032509-23756009c643/go.mod h1:1OIl0v5PQeNxIJhCvY+K55CBUOYDZevw9g9380u1Wek= -github.com/milvus-io/milvus-storage/go v0.0.0-20231008080526-5593abc6e20e h1:p7Gzp5fAePnDd+ikHek4u9OhAewKazBVI720uP6QtgY= -github.com/milvus-io/milvus-storage/go v0.0.0-20231008080526-5593abc6e20e/go.mod h1:GPETMcTZq1gLY1WA6Na5kiNAKnq8SEMMiVKUZrM3sho= -github.com/milvus-io/milvus-storage/go v0.0.0-20231008084415-18e12222e1e0 h1:tRuzrCHnFXxYpcGd4VebQoqTmty69A8IrR8ZYgMM9Dk= -github.com/milvus-io/milvus-storage/go v0.0.0-20231008084415-18e12222e1e0/go.mod 
h1:GPETMcTZq1gLY1WA6Na5kiNAKnq8SEMMiVKUZrM3sho= -github.com/milvus-io/milvus-storage/go v0.0.0-20231008085610-f9cad594aa32 h1:cfUa5LlfXguXubqS+6Lv1UaV0Qs9jYHkC0VWBQ/d6Pw= -github.com/milvus-io/milvus-storage/go v0.0.0-20231008085610-f9cad594aa32/go.mod h1:GPETMcTZq1gLY1WA6Na5kiNAKnq8SEMMiVKUZrM3sho= -github.com/milvus-io/milvus-storage/go v0.0.0-20231008092056-1a44517beb7d h1:3rjVfWdW/NvwBHWa6H6c3XJQKW4t1VODvybxZfIe/sM= -github.com/milvus-io/milvus-storage/go v0.0.0-20231008092056-1a44517beb7d/go.mod h1:GPETMcTZq1gLY1WA6Na5kiNAKnq8SEMMiVKUZrM3sho= -github.com/milvus-io/milvus-storage/go v0.0.0-20231009032726-c040a793ebf4 h1:oXPeTLfI7BEdBJHKXdnKx7ACv8jW7H0Uo0Mo4h57d04= -github.com/milvus-io/milvus-storage/go v0.0.0-20231009032726-c040a793ebf4/go.mod h1:GPETMcTZq1gLY1WA6Na5kiNAKnq8SEMMiVKUZrM3sho= +github.com/milvus-io/milvus-proto/go-api/v2 v2.3.2-0.20231019101159-a0a6f5e7eff8 h1:GoGErEOhdWjwSfQilXso3eINqb11yEBDLtoBMNdlve0= +github.com/milvus-io/milvus-proto/go-api/v2 v2.3.2-0.20231019101159-a0a6f5e7eff8/go.mod h1:1OIl0v5PQeNxIJhCvY+K55CBUOYDZevw9g9380u1Wek= github.com/milvus-io/milvus-storage/go v0.0.0-20231017063757-b4720fe2ec8f h1:NGI0nposnZ2S0K1W/pEWRNzDkaLraaI10aMgKCMnVYs= github.com/milvus-io/milvus-storage/go v0.0.0-20231017063757-b4720fe2ec8f/go.mod h1:GPETMcTZq1gLY1WA6Na5kiNAKnq8SEMMiVKUZrM3sho= github.com/milvus-io/pulsar-client-go v0.6.10 h1:eqpJjU+/QX0iIhEo3nhOqMNXL+TyInAs1IAHZCrCM/A= @@ -655,6 +656,7 @@ github.com/onsi/gomega v1.19.0/go.mod h1:LY+I3pBVzYsTBU1AnDwOSxaYi9WoWiqgwooUqq9 github.com/opencontainers/runtime-spec v1.0.2 h1:UfAcuLBJB9Coz72x1hgl8O5RVzTdNiaglX6v2DM6FI0= github.com/opencontainers/runtime-spec v1.0.2/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= +github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= github.com/opentracing/opentracing-go v1.2.0/go.mod 
h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= github.com/panjf2000/ants/v2 v2.7.2 h1:2NUt9BaZFO5kQzrieOmK/wdb/tQ/K+QHaxN8sOgD63U= github.com/panjf2000/ants/v2 v2.7.2/go.mod h1:KIBmYG9QQX5U2qzFP/yQJaq/nSb6rahS9iEHkrCMgM8= @@ -671,16 +673,27 @@ github.com/pierrec/lz4 v2.5.2+incompatible h1:WCjObylUIOlKy/+7Abdn34TLIkXiA4UWUM github.com/pierrec/lz4 v2.5.2+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pierrec/lz4/v4 v4.1.18 h1:xaKrnTkyoqfh1YItXl56+6KJNVYWlEEPuAQW9xsplYQ= github.com/pierrec/lz4/v4 v4.1.18/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= +github.com/pingcap/errors v0.11.0/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c h1:xpW9bvK+HuuTmyFqUwr+jcCvpVkK7sumiz+ko5H9eq4= github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c/go.mod h1:X2r9ueLEUZgtx2cIogM0v4Zj5uvvzhuuiu7Pn8HzMPg= +github.com/pingcap/failpoint v0.0.0-20210918120811-547c13e3eb00 h1:C3N3itkduZXDZFh4N3vQ5HEtld3S+Y+StULhWVvumU0= +github.com/pingcap/failpoint v0.0.0-20210918120811-547c13e3eb00/go.mod h1:4qGtCB0QK0wBzKtFEGDhxXnSnbQApw1gc9siScUl8ew= +github.com/pingcap/goleveldb v0.0.0-20191226122134-f82aafb29989 h1:surzm05a8C9dN8dIUmo4Be2+pMRb6f55i+UIYrluu2E= +github.com/pingcap/goleveldb v0.0.0-20191226122134-f82aafb29989/go.mod h1:O17XtbryoCJhkKGbT62+L2OlrniwqiGLSqrmdHCMzZw= +github.com/pingcap/kvproto v0.0.0-20221026112947-f8d61344b172/go.mod h1:OYtxs0786qojVTmkVeufx93xe+jUgm56GUYRIKnmaGI= +github.com/pingcap/kvproto v0.0.0-20221129023506-621ec37aac7a h1:LzIZsQpXQlj8yF7+yvyOg680OaPq7bmPuDuszgXfHsw= +github.com/pingcap/kvproto v0.0.0-20221129023506-621ec37aac7a/go.mod h1:OYtxs0786qojVTmkVeufx93xe+jUgm56GUYRIKnmaGI= +github.com/pingcap/log v1.1.1-0.20221015072633-39906604fb81 h1:URLoJ61DmmY++Sa/yyPEQHG2s/ZBeV1FbIswHEMrdoY= +github.com/pingcap/log v1.1.1-0.20221015072633-39906604fb81/go.mod 
h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4= +github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU= +github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/sftp v1.10.1/go.mod h1:lYOWFsE0bwd1+KfKJaKeuokY15vzFx25BLbzYYoAxZI= -github.com/pmezard/go-difflib v0.0.0-20151028094244-d8ed2627bdf0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -692,6 +705,7 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso= github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= +github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= github.com/prometheus/client_golang v1.11.1/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= github.com/prometheus/client_golang v1.14.0 h1:nJdhIvne2eSX/XRAFV9PcvFFRbrjbcTUj0VP62TMhnw= github.com/prometheus/client_golang v1.14.0/go.mod 
h1:8vpkKitgIVNcqrRBWh1C4TIUQgYNtG/XQE4E/Zae36Y= @@ -718,8 +732,8 @@ github.com/prometheus/procfs v0.9.0/go.mod h1:+pB4zwohETzFnmlpe6yd2lSc+0/46IYZRB github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= github.com/quasilyte/go-ruleguard/dsl v0.3.22 h1:wd8zkOhSNr+I+8Qeciml08ivDt1pSXe60+5DqOpCjPE= github.com/quasilyte/go-ruleguard/dsl v0.3.22/go.mod h1:KeCP03KrjuSO0H1kTuZQCWlQPulDV6YMIXmpQss17rU= -github.com/rivo/tview v0.0.0-20200219210816-cd38d7432498/go.mod h1:6lkG1x+13OShEf0EaOCaTQYyB7d5nSbb181KtjlS+84= -github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/clock v0.0.0-20190514195947-2896927a307a/go.mod h1:4r5QyqhjIWCcK8DO4KMclc5Iknq5qVBAlbYYzAbUScQ= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= @@ -739,13 +753,13 @@ github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= github.com/samber/lo v1.27.0 h1:GOyDWxsblvqYobqsmUuMddPa2/mMzkKyojlXol4+LaQ= github.com/samber/lo v1.27.0/go.mod h1:it33p9UtPMS7z72fP4gw/EIfQB2eI8ke7GR2wc6+Rhg= -github.com/sanity-io/litter v1.2.0/go.mod h1:JF6pZUFgu2Q0sBZ+HSV35P8TVPI1TTzEwyu9FXAw2W4= github.com/santhosh-tekuri/jsonschema/v5 v5.0.0/go.mod h1:FKdcjfQW6rpZSnxxUvEA5H/cDPdvJ/SZJQLWWXWGrZ0= github.com/sbinet/npyio v0.6.0 h1:IyqqQIzRjDym9xnIXsToCKei/qCzxDP+Y74KoMlMgXo= github.com/sbinet/npyio v0.6.0/go.mod h1:/q3BNr6dJOy+t6h7RZchTJ0nwRJO52mivaem29WE1j8= github.com/schollz/closestmatch v2.1.0+incompatible/go.mod 
h1:RtP1ddjLong6gTkbtmuhtR2uUrrJOpYzYRvbcPAid+g= github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= +github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/shirou/gopsutil/v3 v3.22.9 h1:yibtJhIVEMcdw+tCTbOPiF1VcsuDeTE4utJ8Dm4c5eA= github.com/shirou/gopsutil/v3 v3.22.9/go.mod h1:bBYl1kjgEJpWpxeHmLI+dVHWtyAwfcmSBLDsp2TNT8A= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= @@ -784,6 +798,8 @@ github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DM github.com/spf13/viper v1.7.0/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5qpdg= github.com/spf13/viper v1.8.1 h1:Kq1fyeebqsBfbjZj4EL7gj2IO0mMaiyjYUWcUsl2O44= github.com/spf13/viper v1.8.1/go.mod h1:o0Pch8wJ9BVSWGQMbra6iw0oQ5oktSIBaujf1rJH9Ns= +github.com/stathat/consistent v1.0.0 h1:ZFJ1QTRn8npNBKW065raSZ8xfOqhpb8vLOkfp4CcL/U= +github.com/stathat/consistent v1.0.0/go.mod h1:uajTPbgSygZBJ+V+0mY7meZ8i0XAcZs7AQ6V121XSxw= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= @@ -791,7 +807,6 @@ github.com/stretchr/objx v0.3.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoH github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/testify v0.0.0-20161117074351-18a02ba4a312/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify 
v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.1-0.20190311161405-34c6fa2dc709/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= @@ -809,12 +824,18 @@ github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl github.com/subosito/gotenv v1.2.0 h1:Slr1R9HxAlEKefgq5jn9U+DnETlIUa6HfgEzj0g5d7s= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/thoas/go-funk v0.9.1 h1:O549iLZqPpTUQ10ykd26sZhzD+rmR5pWhuElrhbC20M= +github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a h1:J/YdBZ46WKpXsxsW93SG+q0F8KI+yFrcIDT4c/RNoc4= +github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a/go.mod h1:h4xBhSNtOeEosLJ4P7JyKXX7Cabg7AVkWCK5gV2vOrM= github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tikv/client-go/v2 v2.0.4 h1:cPtMXTExqjzk8L40qhrgB/mXiBXKP5LRU0vwjtI2Xxo= +github.com/tikv/client-go/v2 v2.0.4/go.mod h1:v52O5zDtv2BBus4lm5yrSQhxGW4Z4RaXWfg0U1Kuyqo= +github.com/tikv/pd/client v0.0.0-20221031025758-80f0d8ca4d07 h1:ckPpxKcl75mO2N6a4cJXiZH43hvcHPpqc9dh1TmH1nc= +github.com/tikv/pd/client v0.0.0-20221031025758-80f0d8ca4d07/go.mod h1:CipBxPfxPUME+BImx9MUYXCnAVLS3VJUr3mnSJwh40A= github.com/tklauser/go-sysconf v0.3.10 h1:IJ1AZGZRWbY8T5Vfk04D9WOA5WSejdflXxP03OUqALw= github.com/tklauser/go-sysconf v0.3.10/go.mod h1:C8XykCvCb+Gn0oNCWPIlcb0RuglQTYaQ2hGm7jmxEFk= github.com/tklauser/numcpus v0.4.0 h1:E53Dm1HjH1/R2/aoCtXtPgzmElmn51aOkhCFSuZq//o= @@ -824,6 +845,8 @@ github.com/tmc/grpc-websocket-proxy 
v0.0.0-20201229170055-e5319fda7802 h1:uruHq4 github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/twmb/murmur3 v1.1.3 h1:D83U0XYKcHRYwYIpBKf3Pks91Z0Byda/9SJ8B6EMRcA= +github.com/twmb/murmur3 v1.1.3/go.mod h1:Qq/R7NUyOfr65zD+6Q5IHKsJLwP7exErjN6lyyq3OSQ= github.com/uber/jaeger-client-go v2.30.0+incompatible h1:D6wyKGCecFaSRUpo8lCVbaOOb6ThwMmTEbhRwtKR97o= github.com/uber/jaeger-client-go v2.30.0+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk= github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc= @@ -920,17 +943,23 @@ go.opentelemetry.io/proto/otlp v0.19.0/go.mod h1:H7XAot3MsfNsj7EXtrA2q5xSNQ10UqI go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/automaxprocs v1.5.2 h1:2LxUOGiR3O6tw8ui5sZa2LAaHnsviZdVOUZw4fvbnME= go.uber.org/automaxprocs v1.5.2/go.mod h1:eRbA25aqJrxAbsLO0xy5jVwPt7FQnRgjW+efnwa1WM0= +go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A= +go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= 
+go.uber.org/multierr v1.7.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.17.0/go.mod h1:MXVU+bhUf/A7Xi2HNOnopQOrmycQ5Ih87HtOu4q5SSo= +go.uber.org/zap v1.19.0/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI= +go.uber.org/zap v1.20.0/go.mod h1:wjWOCqI0f2ZZrJF/UufIOkiC8ii6tm1iqIsLo76RfJw= go.uber.org/zap v1.24.0 h1:FiJd5l1UOLj0wCgbSE0rwwXHzEdAZS6hiiSnxJN/D60= go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= @@ -950,8 +979,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA= -golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= +golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= +golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -1022,6 +1051,7 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL 
golang.org/x/net v0.0.0-20190628185345-da137c7871d7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190724013045-ca1201d0de80/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20191002035440-2ec189313ef0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -1052,8 +1082,8 @@ golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20211008194852-3b03d305991f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= -golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50= -golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= +golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= +golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -1101,7 +1131,6 @@ golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod 
h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190626150813-e07cf5db2756/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -1152,6 +1181,7 @@ golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210819135213-f52c844e1c1c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -1164,12 +1194,12 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= 
-golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c= -golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o= +golang.org/x/term v0.13.0 h1:bb+I9cTfFazGW51MZqBVmZy7+JEJMouUHTUSKVQLBek= +golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -1179,8 +1209,8 @@ golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4= -golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod 
h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -1212,6 +1242,7 @@ golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20190927191325-030b2cf1153e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191112195655-aa38f8e97acc/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191113191852-77e3bb0ad9e7/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191115202509-3a792d9c32b2/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= @@ -1250,6 +1281,7 @@ golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= golang.org/x/tools v0.1.2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.3/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.11.0 h1:EMCa6U9S2LtZXLAMoWiR/R8dAQFRqbAitmbJ2UKhoi8= golang.org/x/tools v0.11.0/go.mod h1:anzJrxPjNtfgiYQYirP2CPGzGLxrH2u2QBhn6Bf3qY8= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -1305,6 +1337,7 @@ google.golang.org/genproto v0.0.0-20190502173448-54afdca5d873/go.mod h1:VzzqZJRn google.golang.org/genproto v0.0.0-20190801165951-fa694d86fc64/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto 
v0.0.0-20190911173649-1774047e7e51/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8= +google.golang.org/genproto v0.0.0-20190927181202-20e1ac93f88c/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8= google.golang.org/genproto v0.0.0-20191108220845-16a3f7862a1a/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= google.golang.org/genproto v0.0.0-20191115194625-c23dd37a84c9/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= google.golang.org/genproto v0.0.0-20191216164720-4f79533eabd1/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= @@ -1354,6 +1387,7 @@ google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZi google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.24.0/go.mod h1:XDChyiUovWa60DnaeDeZmSW86xtLtjtZbwvSiRnRtcA= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= @@ -1374,6 +1408,7 @@ google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQ google.golang.org/grpc v1.40.0/go.mod h1:ogyxbiOoUXAkP+4+xa6PZSE9DZgIHtSpzjDTB9KAK34= google.golang.org/grpc v1.41.0/go.mod h1:U3l9uK9J0sini8mHphKoXyaqDA/8VyGnDee1zzIUK6k= google.golang.org/grpc v1.42.0/go.mod h1:k+4IHHFw41K8+bbowsex27ge2rCb65oeWqe4jJ590SU= +google.golang.org/grpc v1.43.0/go.mod h1:k+4IHHFw41K8+bbowsex27ge2rCb65oeWqe4jJ590SU= google.golang.org/grpc v1.46.0/go.mod h1:vN9eftEi1UMyUsIF80+uQXhHjbXYbm0uXoFCACuMGWk= google.golang.org/grpc v1.57.0 h1:kfzNeI/klCGD2YPMUlaGNT3pxvYfga7smW3Vth8Zsiw= google.golang.org/grpc v1.57.0/go.mod h1:Sd+9RMTACXwmub0zcNY2c4arhtrbBYD1AUHI/dt16Mo= diff --git 
a/internal/allocator/cached_allocator.go b/internal/allocator/cached_allocator.go index 6308070038b7f..fe63b2f3c3eee 100644 --- a/internal/allocator/cached_allocator.go +++ b/internal/allocator/cached_allocator.go @@ -23,8 +23,9 @@ import ( "time" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus/pkg/log" "go.uber.org/zap" + + "github.com/milvus-io/milvus/pkg/log" ) const ( diff --git a/internal/allocator/global_id_allocator_test.go b/internal/allocator/global_id_allocator_test.go index f752ffb03d43b..74d856313eac9 100644 --- a/internal/allocator/global_id_allocator_test.go +++ b/internal/allocator/global_id_allocator_test.go @@ -17,7 +17,6 @@ package allocator import ( - "io/ioutil" "net/url" "os" "testing" @@ -37,7 +36,7 @@ var Params paramtable.ComponentParam var embedEtcdServer *embed.Etcd func startEmbedEtcdServer() (*embed.Etcd, error) { - dir, err := ioutil.TempDir(os.TempDir(), "milvus_ut") + dir, err := os.MkdirTemp(os.TempDir(), "milvus_ut") if err != nil { return nil, err } diff --git a/internal/allocator/id_allocator.go b/internal/allocator/id_allocator.go index a71e7214a15b0..6ea7f8fca3885 100644 --- a/internal/allocator/id_allocator.go +++ b/internal/allocator/id_allocator.go @@ -22,6 +22,7 @@ import ( "time" "github.com/cockroachdb/errors" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/pkg/util/commonpbutil" @@ -83,7 +84,6 @@ func (ia *IDAllocator) gatherReqIDCount() uint32 { } func (ia *IDAllocator) syncID() (bool, error) { - need := ia.gatherReqIDCount() if need < ia.countPerRPC { need = ia.countPerRPC diff --git a/internal/allocator/id_allocator_test.go b/internal/allocator/id_allocator_test.go index b56c4e5f18ad1..3d83f42944bc3 100644 --- a/internal/allocator/id_allocator_test.go +++ b/internal/allocator/id_allocator_test.go @@ -20,23 +20,21 @@ import ( "context" "testing" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - 
"github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/grpc" + + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/pkg/util/merr" ) -type mockIDAllocator struct { -} +type mockIDAllocator struct{} -func (tso *mockIDAllocator) AllocID(ctx context.Context, req *rootcoordpb.AllocIDRequest) (*rootcoordpb.AllocIDResponse, error) { +func (tso *mockIDAllocator) AllocID(ctx context.Context, req *rootcoordpb.AllocIDRequest, opts ...grpc.CallOption) (*rootcoordpb.AllocIDResponse, error) { return &rootcoordpb.AllocIDResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, - ID: int64(1), - Count: req.Count, + Status: merr.Success(), + ID: int64(1), + Count: req.Count, }, nil } diff --git a/internal/allocator/remote_interface.go b/internal/allocator/remote_interface.go index 6a9cddd572e2a..2f70ef55cd9ab 100644 --- a/internal/allocator/remote_interface.go +++ b/internal/allocator/remote_interface.go @@ -19,9 +19,11 @@ package allocator import ( "context" + "google.golang.org/grpc" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" ) type remoteInterface interface { - AllocID(ctx context.Context, req *rootcoordpb.AllocIDRequest) (*rootcoordpb.AllocIDResponse, error) + AllocID(ctx context.Context, req *rootcoordpb.AllocIDRequest, opts ...grpc.CallOption) (*rootcoordpb.AllocIDResponse, error) } diff --git a/internal/core/CMakeLists.txt b/internal/core/CMakeLists.txt index a4dc18bd758a3..f5bd368bfac11 100644 --- a/internal/core/CMakeLists.txt +++ b/internal/core/CMakeLists.txt @@ -29,7 +29,7 @@ if ( MILVUS_GPU_VERSION ) add_definitions(-DMILVUS_GPU_VERSION) endif () -if ( USE_DYNAMIC_SIMD ) +if ( USE_DYNAMIC_SIMD ) add_definitions(-DUSE_DYNAMIC_SIMD) endif() @@ -151,19 +151,6 @@ if ( APPLE ) ) endif () -# Set SIMD to CMAKE_CXX_FLAGS -if (OPEN_SIMD) - message(STATUS "open simd function, 
CPU_ARCH:${CPU_ARCH}") - if (${CPU_ARCH} STREQUAL "avx") - #set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ftree-vectorize -mavx2 -mfma -mavx -mf16c ") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx2 -mfma -mavx -mf16c ") - elseif (${CPU_ARCH} STREQUAL "sse") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.2 ") - elseif (${CPU_ARCH} STREQUAL "arm64") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mcpu=apple-m1+crc ") - endif() -endif () - # **************************** Coding style check tools **************************** find_package( ClangTools ) set( BUILD_SUPPORT_DIR "${CMAKE_SOURCE_DIR}/build-support" ) diff --git a/internal/core/build.sh b/internal/core/build.sh index ec918b78a67f2..051a42e5f5e4a 100755 --- a/internal/core/build.sh +++ b/internal/core/build.sh @@ -37,8 +37,9 @@ WITH_PROMETHEUS="ON" CUDA_ARCH="DEFAULT" CUSTOM_THIRDPARTY_PATH="" BUILD_DISK_ANN="OFF" +INDEX_ENGINE="knowhere" -while getopts "p:t:s:f:o:ulrcghzme" arg; do +while getopts "p:t:s:f:o:ulrcghzmex" arg; do case $arg in f) CUSTOM_THIRDPARTY_PATH=$OPTARG @@ -82,6 +83,9 @@ while getopts "p:t:s:f:o:ulrcghzme" arg; do n) BUILD_DISK_ANN="OFF" ;; + x) + INDEX_ENGINE=$OPTARG + ;; h) # help echo " parameter: @@ -154,6 +158,7 @@ CMAKE_CMD="cmake \ -DCUSTOM_THIRDPARTY_DOWNLOAD_PATH=${CUSTOM_THIRDPARTY_PATH} \ -DKNOWHERE_GPU_VERSION=${SUPPORT_GPU} \ -DBUILD_DISK_ANN=${BUILD_DISK_ANN} \ +-DINDEX_ENGINE=${INDEX_ENGINE} \ ${SCRIPTS_DIR}" echo ${CMAKE_CMD} ${CMAKE_CMD} diff --git a/internal/core/conanfile.py b/internal/core/conanfile.py index e93db20872986..9bcea02577b93 100644 --- a/internal/core/conanfile.py +++ b/internal/core/conanfile.py @@ -13,7 +13,7 @@ class MilvusConan(ConanFile): "snappy/1.1.9", "lzo/2.10", "arrow/12.0.1", - "openssl/1.1.1q", + "openssl/3.1.2", "s2n/1.3.31@milvus/dev", "aws-c-common/0.8.2@milvus/dev", "aws-c-compression/0.2.15@milvus/dev", diff --git a/internal/core/src/CMakeLists.txt b/internal/core/src/CMakeLists.txt index 70bc827fde22b..bdebcedf4483f 100644 --- 
a/internal/core/src/CMakeLists.txt +++ b/internal/core/src/CMakeLists.txt @@ -23,10 +23,7 @@ endif() include_directories(${MILVUS_ENGINE_SRC}) include_directories(${MILVUS_THIRDPARTY_SRC}) -set(FOUND_OPENBLAS "unknown") add_subdirectory( pb ) -add_subdirectory( exceptions ) -add_subdirectory( utils ) add_subdirectory( log ) add_subdirectory( config ) add_subdirectory( common ) diff --git a/internal/core/src/common/Array.h b/internal/core/src/common/Array.h new file mode 100644 index 0000000000000..4cf311f0ac548 --- /dev/null +++ b/internal/core/src/common/Array.h @@ -0,0 +1,645 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include + +#include +#include +#include + +#include "FieldMeta.h" +#include "Types.h" + +namespace milvus { + +class Array { + public: + Array() = default; + + ~Array() { + delete[] data_; + } + + explicit Array(const ScalarArray& field_data) { + switch (field_data.data_case()) { + case ScalarArray::kBoolData: { + element_type_ = DataType::BOOL; + length_ = field_data.bool_data().data().size(); + auto data = new bool[length_]; + size_ = length_; + offsets_.reserve(length_); + for (int i = 0; i < length_; ++i) { + data[i] = field_data.bool_data().data(i); + offsets_.push_back(sizeof(bool) * i); + } + data_ = reinterpret_cast(data); + break; + } + case ScalarArray::kIntData: { + element_type_ = DataType::INT32; + length_ = field_data.int_data().data().size(); + size_ = length_ * sizeof(int32_t); + data_ = new char[size_]; + offsets_.reserve(length_); + for (int i = 0; i < length_; ++i) { + reinterpret_cast(data_)[i] = + field_data.int_data().data(i); + offsets_.push_back(sizeof(int32_t) * i); + } + break; + } + case ScalarArray::kLongData: { + element_type_ = DataType::INT64; + length_ = field_data.long_data().data().size(); + size_ = length_ * sizeof(int64_t); + data_ = new char[size_]; + offsets_.reserve(length_); + for (int i = 0; i < length_; ++i) { + reinterpret_cast(data_)[i] = + field_data.long_data().data(i); + offsets_.push_back(sizeof(int64_t) * i); + } + break; + } + case ScalarArray::kFloatData: { + element_type_ = DataType::FLOAT; + length_ = field_data.float_data().data().size(); + size_ = length_ * sizeof(float); + data_ = new char[size_]; + offsets_.reserve(length_); + for (int i = 0; i < length_; ++i) { + reinterpret_cast(data_)[i] = + field_data.float_data().data(i); + offsets_.push_back(sizeof(float) * i); + } + break; + } + case ScalarArray::kDoubleData: { + element_type_ = DataType::DOUBLE; + length_ = field_data.double_data().data().size(); + size_ = length_ * sizeof(double); + data_ = new char[size_]; + 
offsets_.reserve(length_); + for (int i = 0; i < length_; ++i) { + reinterpret_cast(data_)[i] = + field_data.double_data().data(i); + offsets_.push_back(sizeof(double) * i); + } + break; + } + case ScalarArray::kStringData: { + element_type_ = DataType::STRING; + length_ = field_data.string_data().data().size(); + offsets_.reserve(length_); + for (int i = 0; i < length_; ++i) { + offsets_.push_back(size_); + size_ += field_data.string_data().data(i).size(); + } + + data_ = new char[size_]; + for (int i = 0; i < length_; ++i) { + std::copy_n(field_data.string_data().data(i).data(), + field_data.string_data().data(i).size(), + data_ + offsets_[i]); + } + break; + } + default: { + // empty array + } + } + } + + Array(char* data, + size_t size, + DataType element_type, + std::vector&& element_offsets) + : length_(element_offsets.size()), + size_(size), + offsets_(std::move(element_offsets)), + element_type_(element_type) { + delete[] data_; + data_ = new char[size]; + std::copy(data, data + size, data_); + } + + Array(const Array& array) noexcept + : length_{array.length_}, + size_{array.size_}, + element_type_{array.element_type_} { + delete[] data_; + data_ = new char[array.size_]; + std::copy(array.data_, array.data_ + array.size_, data_); + offsets_ = array.offsets_; + } + + Array& + operator=(const Array& array) { + delete[] data_; + + data_ = new char[array.size_]; + std::copy(array.data_, array.data_ + array.size_, data_); + length_ = array.length_; + size_ = array.size_; + offsets_ = array.offsets_; + element_type_ = array.element_type_; + return *this; + } + + bool + operator==(const Array& arr) const { + if (element_type_ != arr.element_type_) { + return false; + } + if (length_ != arr.length_) { + return false; + } + if (length_ == 0) { + return true; + } + switch (element_type_) { + case DataType::INT64: { + for (int i = 0; i < length_; ++i) { + if (get_data(i) != arr.get_data(i)) { + return false; + } + } + return true; + } + case DataType::BOOL: { + for 
(int i = 0; i < length_; ++i) { + if (get_data(i) != arr.get_data(i)) { + return false; + } + } + return true; + } + case DataType::DOUBLE: { + for (int i = 0; i < length_; ++i) { + if (get_data(i) != arr.get_data(i)) { + return false; + } + } + return true; + } + case DataType::FLOAT: { + for (int i = 0; i < length_; ++i) { + if (get_data(i) != arr.get_data(i)) { + return false; + } + } + return true; + } + case DataType::INT32: + case DataType::INT16: + case DataType::INT8: { + for (int i = 0; i < length_; ++i) { + if (get_data(i) != arr.get_data(i)) { + return false; + } + } + return true; + } + case DataType::STRING: + case DataType::VARCHAR: { + for (int i = 0; i < length_; ++i) { + if (get_data(i) != + arr.get_data(i)) { + return false; + } + } + return true; + } + default: + PanicInfo(Unsupported, "unsupported element type for array"); + } + } + + template + T + get_data(const int index) const { + AssertInfo( + index >= 0 && index < length_, + fmt::format( + "index out of range, index={}, length={}", index, length_)); + size_t element_length = (index == length_ - 1) + ? 
size_ - offsets_.back() + : offsets_[index + 1] - offsets_[index]; + if constexpr (std::is_same_v || + std::is_same_v) { + return T(data_ + offsets_[index], element_length); + } + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { + switch (element_type_) { + case DataType::INT8: + case DataType::INT16: + case DataType::INT32: + return static_cast( + reinterpret_cast(data_)[index]); + case DataType::INT64: + return static_cast( + reinterpret_cast(data_)[index]); + case DataType::FLOAT: + return static_cast( + reinterpret_cast(data_)[index]); + case DataType::DOUBLE: + return static_cast( + reinterpret_cast(data_)[index]); + default: + PanicInfo(Unsupported, + "unsupported element type for array"); + } + } + return reinterpret_cast(data_)[index]; + } + + const std::vector& + get_offsets() const { + return offsets_; + } + + ScalarArray + output_data() const { + ScalarArray data_array; + switch (element_type_) { + case DataType::BOOL: { + for (int j = 0; j < length_; ++j) { + auto element = get_data(j); + data_array.mutable_bool_data()->add_data(element); + } + break; + } + case DataType::INT8: + case DataType::INT16: + case DataType::INT32: { + for (int j = 0; j < length_; ++j) { + auto element = get_data(j); + data_array.mutable_int_data()->add_data(element); + } + break; + } + case DataType::INT64: { + for (int j = 0; j < length_; ++j) { + auto element = get_data(j); + data_array.mutable_long_data()->add_data(element); + } + break; + } + case DataType::STRING: + case DataType::VARCHAR: { + for (int j = 0; j < length_; ++j) { + auto element = get_data(j); + data_array.mutable_string_data()->add_data(element); + } + break; + } + case DataType::FLOAT: { + for (int j = 0; j < length_; ++j) { + auto element = get_data(j); + data_array.mutable_float_data()->add_data(element); + } + break; + } + case DataType::DOUBLE: { + for (int j = 0; j < length_; ++j) { + auto element = get_data(j); + 
data_array.mutable_double_data()->add_data(element); + } + break; + } + default: { + // empty array + } + } + return data_array; + } + + int + length() const { + return length_; + } + + size_t + byte_size() const { + return size_; + } + + DataType + get_element_type() const { + return element_type_; + } + + const char* + data() const { + return data_; + } + + bool + is_same_array(const proto::plan::Array& arr2) const { + if (arr2.array_size() != length_) { + return false; + } + if (length_ == 0) { + return true; + } + if (!arr2.same_type()) { + return false; + } + switch (element_type_) { + case DataType::BOOL: { + for (int i = 0; i < length_; i++) { + auto val = get_data(i); + if (val != arr2.array(i).bool_val()) { + return false; + } + } + return true; + } + case DataType::INT8: + case DataType::INT16: + case DataType::INT32: { + for (int i = 0; i < length_; i++) { + auto val = get_data(i); + if (val != arr2.array(i).int64_val()) { + return false; + } + } + return true; + } + case DataType::INT64: { + for (int i = 0; i < length_; i++) { + auto val = get_data(i); + if (val != arr2.array(i).int64_val()) { + return false; + } + } + return true; + } + case DataType::FLOAT: { + for (int i = 0; i < length_; i++) { + auto val = get_data(i); + if (val != arr2.array(i).float_val()) { + return false; + } + } + return true; + } + case DataType::DOUBLE: { + for (int i = 0; i < length_; i++) { + auto val = get_data(i); + if (val != arr2.array(i).float_val()) { + return false; + } + } + return true; + } + case DataType::VARCHAR: + case DataType::STRING: { + for (int i = 0; i < length_; i++) { + auto val = get_data(i); + if (val != arr2.array(i).string_val()) { + return false; + } + } + return true; + } + default: + return false; + } + } + + private: + char* data_{nullptr}; + int length_ = 0; + int size_ = 0; + std::vector offsets_{}; + DataType element_type_ = DataType::NONE; +}; + +class ArrayView { + public: + ArrayView() = default; + + ArrayView(char* data, + size_t size, + 
DataType element_type, + std::vector&& element_offsets) + : size_(size), + element_type_(element_type), + offsets_(std::move(element_offsets)), + length_(element_offsets.size()) { + data_ = data; + } + + template + T + get_data(const int index) const { + AssertInfo( + index >= 0 && index < length_, + fmt::format( + "index out of range, index={}, length={}", index, length_)); + size_t element_length = (index == length_ - 1) + ? size_ - offsets_.back() + : offsets_[index + 1] - offsets_[index]; + if constexpr (std::is_same_v || + std::is_same_v) { + return T(data_ + offsets_[index], element_length); + } + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { + switch (element_type_) { + case DataType::INT8: + case DataType::INT16: + case DataType::INT32: + return static_cast( + reinterpret_cast(data_)[index]); + case DataType::INT64: + return static_cast( + reinterpret_cast(data_)[index]); + case DataType::FLOAT: + return static_cast( + reinterpret_cast(data_)[index]); + case DataType::DOUBLE: + return static_cast( + reinterpret_cast(data_)[index]); + default: + PanicInfo(Unsupported, + "unsupported element type for array"); + } + } + return reinterpret_cast(data_)[index]; + } + + ScalarArray + output_data() const { + ScalarArray data_array; + switch (element_type_) { + case DataType::BOOL: { + for (int j = 0; j < length_; ++j) { + auto element = get_data(j); + data_array.mutable_bool_data()->add_data(element); + } + break; + } + case DataType::INT8: + case DataType::INT16: + case DataType::INT32: { + for (int j = 0; j < length_; ++j) { + auto element = get_data(j); + data_array.mutable_int_data()->add_data(element); + } + break; + } + case DataType::INT64: { + for (int j = 0; j < length_; ++j) { + auto element = get_data(j); + data_array.mutable_long_data()->add_data(element); + } + break; + } + case DataType::STRING: + case DataType::VARCHAR: { + for (int j = 0; j < length_; ++j) { + auto element = get_data(j); + 
data_array.mutable_string_data()->add_data(element); + } + break; + } + case DataType::FLOAT: { + for (int j = 0; j < length_; ++j) { + auto element = get_data(j); + data_array.mutable_float_data()->add_data(element); + } + break; + } + case DataType::DOUBLE: { + for (int j = 0; j < length_; ++j) { + auto element = get_data(j); + data_array.mutable_double_data()->add_data(element); + } + break; + } + default: { + // empty array + } + } + return data_array; + } + + int + length() const { + return length_; + } + + size_t + byte_size() const { + return size_; + } + + DataType + get_element_type() const { + return element_type_; + } + + const void* + data() const { + return data_; + } + + bool + is_same_array(const proto::plan::Array& arr2) const { + if (arr2.array_size() != length_) { + return false; + } + if (!arr2.same_type()) { + return false; + } + switch (element_type_) { + case DataType::BOOL: { + for (int i = 0; i < length_; i++) { + auto val = get_data(i); + if (val != arr2.array(i).bool_val()) { + return false; + } + } + return true; + } + case DataType::INT8: + case DataType::INT16: + case DataType::INT32: { + for (int i = 0; i < length_; i++) { + auto val = get_data(i); + if (val != arr2.array(i).int64_val()) { + return false; + } + } + return true; + } + case DataType::INT64: { + for (int i = 0; i < length_; i++) { + auto val = get_data(i); + if (val != arr2.array(i).int64_val()) { + return false; + } + } + return true; + } + case DataType::FLOAT: { + for (int i = 0; i < length_; i++) { + auto val = get_data(i); + if (val != arr2.array(i).float_val()) { + return false; + } + } + return true; + } + case DataType::DOUBLE: { + for (int i = 0; i < length_; i++) { + auto val = get_data(i); + if (val != arr2.array(i).float_val()) { + return false; + } + } + return true; + } + case DataType::VARCHAR: + case DataType::STRING: { + for (int i = 0; i < length_; i++) { + auto val = get_data(i); + if (val != arr2.array(i).string_val()) { + return false; + } + } + 
return true; + } + default: + return length_ == 0; + } + } + + private: + char* data_{nullptr}; + int length_ = 0; + int size_ = 0; + std::vector offsets_{}; + DataType element_type_ = DataType::NONE; +}; + +} // namespace milvus diff --git a/internal/core/src/common/BitsetView.h b/internal/core/src/common/BitsetView.h index 777c3ef6bb3d9..57ce86a97e912 100644 --- a/internal/core/src/common/BitsetView.h +++ b/internal/core/src/common/BitsetView.h @@ -22,7 +22,7 @@ #include #include "common/Types.h" -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" #include "knowhere/bitsetview.h" namespace milvus { diff --git a/internal/core/src/common/CMakeLists.txt b/internal/core/src/common/CMakeLists.txt index 8996aa734811f..7fdac692ed909 100644 --- a/internal/core/src/common/CMakeLists.txt +++ b/internal/core/src/common/CMakeLists.txt @@ -20,14 +20,16 @@ set(COMMON_SRC Common.cpp RangeSearchHelper.cpp Tracer.cpp - IndexMeta.cpp) + IndexMeta.cpp + EasyAssert.cpp +) add_library(milvus_common SHARED ${COMMON_SRC}) target_link_libraries(milvus_common milvus_config - milvus_utils milvus_log + milvus_proto yaml-cpp boost_bitset_ext simdjson diff --git a/internal/core/src/common/Consts.h b/internal/core/src/common/Consts.h index 50a4f2dc3202d..9de3b0e48b566 100644 --- a/internal/core/src/common/Consts.h +++ b/internal/core/src/common/Consts.h @@ -52,3 +52,5 @@ constexpr const char* RADIUS = knowhere::meta::RADIUS; constexpr const char* RANGE_FILTER = knowhere::meta::RANGE_FILTER; const int64_t DEFAULT_MAX_OUTPUT_SIZE = 67108864; // bytes, 64MB + +const int64_t DEFAULT_CHUNK_MANAGER_REQUEST_TIMEOUT_MS = 3000; diff --git a/internal/core/src/exceptions/EasyAssert.cpp b/internal/core/src/common/EasyAssert.cpp similarity index 83% rename from internal/core/src/exceptions/EasyAssert.cpp rename to internal/core/src/common/EasyAssert.cpp index 77d80ccd6d0c6..fceb87e35f37f 100644 --- a/internal/core/src/exceptions/EasyAssert.cpp +++ b/internal/core/src/common/EasyAssert.cpp 
@@ -16,7 +16,7 @@ #include #include "EasyAssert.h" -// #define BOOST_STACKTRACE_USE_BACKTRACE +#include "fmt/core.h" #include #include @@ -40,13 +40,16 @@ EasyAssertInfo(bool value, std::string_view filename, int lineno, std::string_view extra_info, - ErrorCodeEnum error_code) { + ErrorCode error_code) { // enable error code if (!value) { std::string info; - info += "Assert \"" + std::string(expr_str) + "\""; - info += " at " + std::string(filename) + ":" + std::to_string(lineno) + - "\n"; + if (!expr_str.empty()) { + info += fmt::format("Assert \"{}\" at {}:{}\n", + expr_str, + std::string(filename), + std::to_string(lineno)); + } if (!extra_info.empty()) { info += " => " + std::string(extra_info); } diff --git a/internal/core/src/exceptions/EasyAssert.h b/internal/core/src/common/EasyAssert.h similarity index 53% rename from internal/core/src/exceptions/EasyAssert.h rename to internal/core/src/common/EasyAssert.h index cabacdad5ceea..07a444c4a70b9 100644 --- a/internal/core/src/exceptions/EasyAssert.h +++ b/internal/core/src/common/EasyAssert.h @@ -18,14 +18,46 @@ #include #include #include -#include -#include +#include +#include #include #include "pb/common.pb.h" +#include "common/type_c.h" /* Paste this on the file if you want to debug. 
*/ namespace milvus { -using ErrorCodeEnum = proto::common::ErrorCode; +enum ErrorCode { + Success = 0, + UnexpectedError = 2001, + NotImplemented = 2002, + Unsupported = 2003, + IndexBuildError = 2004, + IndexAlreadyBuild = 2005, + ConfigInvalid = 2006, + DataTypeInvalid = 2007, + PathInvalid = 2009, + PathAlreadyExist = 2010, + PathNotExist = 2011, + FileOpenFailed = 2012, + FileCreateFailed = 2013, + FileReadFailed = 2014, + FileWriteFailed = 2015, + BucketInvalid = 2016, + ObjectNotExist = 2017, + S3Error = 2018, + RetrieveError = 2019, + FieldIDInvalid = 2020, + FieldAlreadyExist = 2021, + OpTypeInvalid = 2022, + DataIsEmpty = 2023, + DataFormatBroken = 2024, + JsonKeyInvalid = 2025, + MetricTypeInvalid = 2026, + FieldNotLoaded = 2027, + ExprInvalid = 2028, + UnistdError = 2030, + KnowhereError = 2100, +}; namespace impl { void EasyAssertInfo(bool value, @@ -33,25 +65,55 @@ EasyAssertInfo(bool value, std::string_view filename, int lineno, std::string_view extra_info, - ErrorCodeEnum error_code = ErrorCodeEnum::UnexpectedError); + ErrorCode error_code = ErrorCode::UnexpectedError); } // namespace impl class SegcoreError : public std::runtime_error { public: - SegcoreError(ErrorCodeEnum error_code, const std::string& error_msg) + static SegcoreError + success() { + return {ErrorCode::Success, ""}; + } + + SegcoreError(ErrorCode error_code, const std::string& error_msg) : std::runtime_error(error_msg), error_code_(error_code) { } - ErrorCodeEnum + ErrorCode get_error_code() { return error_code_; } + bool + ok() { + return error_code_ == ErrorCode::Success; + } + private: - ErrorCodeEnum error_code_; + ErrorCode error_code_; }; +inline CStatus +SuccessCStatus() { + return CStatus{Success, ""}; +} + +inline CStatus +FailureCStatus(int code, const std::string& msg) { + return CStatus{code, strdup(msg.data())}; +} + +inline CStatus +FailureCStatus(std::exception* ex) { + if (dynamic_cast(ex) != nullptr) { + auto segcore_error = dynamic_cast(ex); + return 
CStatus{static_cast(segcore_error->get_error_code()), + strdup(ex->what())}; + } + return CStatus{static_cast(UnexpectedError), strdup(ex->what())}; +} + } // namespace milvus #define AssertInfo(expr, info) \ @@ -65,15 +127,10 @@ class SegcoreError : public std::runtime_error { } while (0) #define Assert(expr) AssertInfo((expr), "") -#define PanicInfo(info) \ - do { \ - milvus::impl::EasyAssertInfo(false, (info), __FILE__, __LINE__, ""); \ - __builtin_unreachable(); \ - } while (0) -#define PanicCodeInfo(errcode, info) \ +#define PanicInfo(errcode, info) \ do { \ milvus::impl::EasyAssertInfo( \ - false, (info), __FILE__, __LINE__, "", errcode); \ + false, "", __FILE__, __LINE__, (info), errcode); \ __builtin_unreachable(); \ } while (0) diff --git a/internal/core/src/common/FieldMeta.h b/internal/core/src/common/FieldMeta.h index 35098dcf3a8a0..c6d1775595068 100644 --- a/internal/core/src/common/FieldMeta.h +++ b/internal/core/src/common/FieldMeta.h @@ -21,8 +21,7 @@ #include #include "common/Types.h" -#include "exceptions/EasyAssert.h" -#include "utils/Status.h" +#include "common/EasyAssert.h" namespace milvus { @@ -49,8 +48,12 @@ datatype_sizeof(DataType data_type, int dim = 1) { AssertInfo(dim % 8 == 0, "dim=" + std::to_string(dim)); return dim / 8; } + case DataType::VECTOR_FLOAT16: { + return sizeof(float16) * dim; + } default: { - throw std::invalid_argument("unsupported data type"); + throw SegcoreError(DataTypeInvalid, + fmt::format("invalid type is {}", data_type)); } } } @@ -59,6 +62,8 @@ datatype_sizeof(DataType data_type, int dim = 1) { inline std::string datatype_name(DataType data_type) { switch (data_type) { + case DataType::NONE: + return "none"; case DataType::BOOL: return "bool"; case DataType::INT8: @@ -73,6 +78,8 @@ datatype_name(DataType data_type) { return "float"; case DataType::DOUBLE: return "double"; + case DataType::STRING: + return "string"; case DataType::VARCHAR: return "varChar"; case DataType::ARRAY: @@ -84,10 +91,12 @@ 
datatype_name(DataType data_type) { case DataType::VECTOR_BINARY: { return "vector_binary"; } + case DataType::VECTOR_FLOAT16: { + return "vector_float16"; + } default: { - auto err_msg = - "Unsupported DataType(" + std::to_string((int)data_type) + ")"; - PanicInfo(err_msg); + PanicInfo(DataTypeInvalid, + fmt::format("Unsupported DataType({})", data_type)); } } } @@ -95,7 +104,8 @@ datatype_name(DataType data_type) { inline bool datatype_is_vector(DataType datatype) { return datatype == DataType::VECTOR_BINARY || - datatype == DataType::VECTOR_FLOAT; + datatype == DataType::VECTOR_FLOAT || + datatype == DataType::VECTOR_FLOAT16; } inline bool @@ -120,6 +130,16 @@ datatype_is_binary(DataType datatype) { } } +inline bool +datatype_is_json(DataType datatype) { + return datatype == DataType::JSON; +} + +inline bool +datatype_is_array(DataType datatype) { + return datatype == DataType::ARRAY; +} + inline bool datatype_is_variable(DataType datatype) { switch (datatype) { @@ -183,6 +203,14 @@ class FieldMeta { Assert(datatype_is_string(type_)); } + FieldMeta(const FieldName& name, + FieldId id, + DataType type, + DataType element_type) + : name_(name), id_(id), type_(type), element_type_(element_type) { + Assert(datatype_is_array(type_)); + } + FieldMeta(const FieldName& name, FieldId id, DataType type, @@ -231,6 +259,11 @@ class FieldMeta { return type_; } + DataType + get_element_type() const { + return element_type_; + } + bool is_vector() const { return datatype_is_vector(type_); @@ -267,6 +300,7 @@ class FieldMeta { FieldName name_; FieldId id_; DataType type_ = DataType::NONE; + DataType element_type_ = DataType::NONE; std::optional vector_info_; std::optional string_info_; }; diff --git a/internal/core/src/utils/File.h b/internal/core/src/common/File.h similarity index 96% rename from internal/core/src/utils/File.h rename to internal/core/src/common/File.h index 6ce8ae5722210..3622bcbf033c0 100644 --- a/internal/core/src/utils/File.h +++ 
b/internal/core/src/common/File.h @@ -12,7 +12,7 @@ #pragma once #include -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" #include "fmt/core.h" #include #include @@ -62,4 +62,4 @@ class File { } int fd_{-1}; }; -} // namespace milvus \ No newline at end of file +} // namespace milvus diff --git a/internal/core/src/common/IndexMeta.h b/internal/core/src/common/IndexMeta.h index 89ad9fd01c3b4..132c1fd34b170 100644 --- a/internal/core/src/common/IndexMeta.h +++ b/internal/core/src/common/IndexMeta.h @@ -21,6 +21,7 @@ #include "pb/common.pb.h" #include "pb/segcore.pb.h" +#include "knowhere/utils.h" #include "Types.h" namespace milvus { @@ -44,6 +45,11 @@ class FieldIndexMeta { return index_params_.at(knowhere::meta::INDEX_TYPE); } + bool + IsFlatIndex() const { + return knowhere::IsFlatIndex(GetIndexType()); + } + const std::map& GetIndexParams() const { return index_params_; diff --git a/internal/core/src/common/Json.h b/internal/core/src/common/Json.h index 91df2b94144b7..892652955a366 100644 --- a/internal/core/src/common/Json.h +++ b/internal/core/src/common/Json.h @@ -25,7 +25,7 @@ #include #include -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" #include "simdjson.h" #include "fmt/core.h" #include "simdjson/common_defs.h" diff --git a/internal/core/src/common/LoadInfo.h b/internal/core/src/common/LoadInfo.h index 554d70d5f5cbc..f58a002b237f6 100644 --- a/internal/core/src/common/LoadInfo.h +++ b/internal/core/src/common/LoadInfo.h @@ -28,6 +28,7 @@ struct FieldBinlogInfo { int64_t field_id; int64_t row_count = -1; + std::vector entries_nums; std::vector insert_files; }; diff --git a/internal/core/src/common/RangeSearchHelper.cpp b/internal/core/src/common/RangeSearchHelper.cpp index 36c68f35bd627..9e51dac1e6541 100644 --- a/internal/core/src/common/RangeSearchHelper.cpp +++ b/internal/core/src/common/RangeSearchHelper.cpp @@ -88,7 +88,7 @@ ReGenRangeSearchResult(DatasetPtr data_set, pq(cmp); auto capacity = std::min(lims[i 
+ 1] - lims[i], topk); - for (int j = lims[i]; j < lims[i + 1]; j++) { + for (size_t j = lims[i]; j < lims[i + 1]; j++) { auto curr = ResultPair(dist[j], id[j]); if (pq.size() < capacity) { pq.push(curr); diff --git a/internal/core/src/common/Schema.cpp b/internal/core/src/common/Schema.cpp index e35ff7368c2c4..7265ba10871ad 100644 --- a/internal/core/src/common/Schema.cpp +++ b/internal/core/src/common/Schema.cpp @@ -69,6 +69,9 @@ Schema::ParseFrom(const milvus::proto::schema::CollectionSchema& schema_proto) { auto max_len = boost::lexical_cast(type_map.at(MAX_LENGTH)); schema->AddField(name, field_id, data_type, max_len); + } else if (datatype_is_array(data_type)) { + schema->AddField( + name, field_id, data_type, DataType(child.element_type())); } else { schema->AddField(name, field_id, data_type); } diff --git a/internal/core/src/common/Schema.h b/internal/core/src/common/Schema.h index 6e2a933095968..71187f1004564 100644 --- a/internal/core/src/common/Schema.h +++ b/internal/core/src/common/Schema.h @@ -41,6 +41,16 @@ class Schema { return field_id; } + FieldId + AddDebugField(const std::string& name, + DataType data_type, + DataType element_type) { + auto field_id = FieldId(debug_id); + debug_id++; + this->AddField(FieldName(name), field_id, data_type, element_type); + return field_id; + } + // auto gen field_id for convenience FieldId AddDebugField(const std::string& name, @@ -62,6 +72,16 @@ class Schema { this->AddField(std::move(field_meta)); } + // array type + void + AddField(const FieldName& name, + const FieldId id, + DataType data_type, + DataType element_type) { + auto field_meta = FieldMeta(name, id, data_type, element_type); + this->AddField(std::move(field_meta)); + } + // string type void AddField(const FieldName& name, diff --git a/internal/core/src/common/SystemProperty.cpp b/internal/core/src/common/SystemProperty.cpp index 49f889a88c57a..ad42ea541a745 100644 --- a/internal/core/src/common/SystemProperty.cpp +++ 
b/internal/core/src/common/SystemProperty.cpp @@ -18,7 +18,7 @@ #include "SystemProperty.h" #include "Consts.h" -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" namespace milvus { class SystemPropertyImpl : public SystemProperty { diff --git a/internal/core/src/common/SystemProperty.h b/internal/core/src/common/SystemProperty.h index 373e4e42a3f4f..65800aea38dbf 100644 --- a/internal/core/src/common/SystemProperty.h +++ b/internal/core/src/common/SystemProperty.h @@ -18,8 +18,7 @@ #include -#include "Types.h" -#include "utils/Json.h" +#include "common/Types.h" namespace milvus { @@ -64,3 +63,23 @@ class SystemProperty { }; } // namespace milvus + +template <> +struct fmt::formatter : formatter { + auto + format(milvus::SystemFieldType c, format_context& ctx) const { + string_view name = "unknown"; + switch (c) { + case milvus::SystemFieldType::Invalid: + name = "Invalid"; + break; + case milvus::SystemFieldType::RowId: + name = "RowId"; + break; + case milvus::SystemFieldType::Timestamp: + name = "Timestamp"; + break; + } + return formatter::format(name, ctx); + } +}; diff --git a/internal/core/src/common/Tracer.cpp b/internal/core/src/common/Tracer.cpp index 69fc899fc8eed..21a4c637092f2 100644 --- a/internal/core/src/common/Tracer.cpp +++ b/internal/core/src/common/Tracer.cpp @@ -11,6 +11,8 @@ #include "log/Log.h" #include "Tracer.h" +#include + #include "opentelemetry/exporters/ostream/span_exporter_factory.h" #include "opentelemetry/exporters/jaeger/jaeger_exporter_factory.h" #include "opentelemetry/exporters/otlp/otlp_grpc_exporter_factory.h" @@ -103,7 +105,7 @@ thread_local std::shared_ptr local_span; void SetRootSpan(std::shared_ptr span) { if (enable_trace) { - local_span = span; + local_span = std::move(span); } } @@ -123,7 +125,7 @@ AddEvent(std::string event_label) { bool isEmptyID(const uint8_t* id, int length) { - for (size_t i = 0; i < length; i++) { + for (int i = 0; i < length; i++) { if (id[i] != 0) { return false; } diff --git 
a/internal/core/src/common/Types.h b/internal/core/src/common/Types.h index 0846b8cb536ed..2db86a0390002 100644 --- a/internal/core/src/common/Types.h +++ b/internal/core/src/common/Types.h @@ -35,6 +35,7 @@ #include #include +#include "fmt/core.h" #include "knowhere/binaryset.h" #include "knowhere/comp/index_param.h" #include "knowhere/dataset.h" @@ -51,6 +52,33 @@ using offset_t = int32_t; using date_t = int32_t; using distance_t = float; +union float16 { + unsigned short bits; + struct { + unsigned short mantissa : 10; + unsigned short exponent : 5; + unsigned short sign : 1; + } parts; + float16() { + } + float16(float f) { + unsigned int i = *(unsigned int*)&f; + unsigned int sign = (i >> 31) & 0x0001; + unsigned int exponent = ((i >> 23) & 0xff) - 127 + 15; + unsigned int mantissa = (i >> 13) & 0x3ff; + parts.sign = sign; + parts.exponent = exponent; + parts.mantissa = mantissa; + } + operator float() const { + unsigned int sign = parts.sign << 31; + unsigned int exponent = (parts.exponent - 15 + 127) << 23; + unsigned int mantissa = parts.mantissa << 13; + unsigned int bits = sign | exponent | mantissa; + return *(float*)&bits; + } +}; + enum class DataType { NONE = 0, BOOL = 1, @@ -69,6 +97,7 @@ enum class DataType { VECTOR_BINARY = 100, VECTOR_FLOAT = 101, + VECTOR_FLOAT16 = 102, }; using Timestamp = uint64_t; // TODO: use TiKV-like timestamp @@ -137,6 +166,7 @@ using BinarySet = knowhere::BinarySet; using Dataset = knowhere::DataSet; using DatasetPtr = knowhere::DataSetPtr; using MetricType = knowhere::MetricType; +using IndexVersion = knowhere::IndexVersion; // TODO :: type define milvus index type(vector index type and scalar index type) using IndexType = knowhere::IndexType; @@ -202,6 +232,65 @@ struct fmt::formatter : formatter { case milvus::DataType::VECTOR_FLOAT: name = "VECTOR_FLOAT"; break; + case milvus::DataType::VECTOR_FLOAT16: + name = "VECTOR_FLOAT16"; + break; + } + return formatter::format(name, ctx); + } +}; + +template <> +struct 
fmt::formatter : formatter { + auto + format(milvus::OpType c, format_context& ctx) const { + string_view name = "unknown"; + switch (c) { + case milvus::OpType::Invalid: + name = "Invalid"; + break; + case milvus::OpType::GreaterThan: + name = "GreaterThan"; + break; + case milvus::OpType::GreaterEqual: + name = "GreaterEqual"; + break; + case milvus::OpType::LessThan: + name = "LessThan"; + break; + case milvus::OpType::LessEqual: + name = "LessEqual"; + break; + case milvus::OpType::Equal: + name = "Equal"; + break; + case milvus::OpType::NotEqual: + name = "NotEqual"; + break; + case milvus::OpType::PrefixMatch: + name = "PrefixMatch"; + break; + case milvus::OpType::PostfixMatch: + name = "PostfixMatch"; + break; + case milvus::OpType::Match: + name = "Match"; + break; + case milvus::OpType::Range: + name = "Range"; + break; + case milvus::OpType::In: + name = "In"; + break; + case milvus::OpType::NotIn: + name = "NotIn"; + break; + case milvus::OpType::OpType_INT_MIN_SENTINEL_DO_NOT_USE_: + name = "OpType_INT_MIN_SENTINEL_DO_NOT_USE"; + break; + case milvus::OpType::OpType_INT_MAX_SENTINEL_DO_NOT_USE_: + name = "OpType_INT_MAX_SENTINEL_DO_NOT_USE"; + break; } return formatter::format(name, ctx); } diff --git a/internal/core/src/common/Utils.h b/internal/core/src/common/Utils.h index d90a72deba23b..2e8e245af0c8c 100644 --- a/internal/core/src/common/Utils.h +++ b/internal/core/src/common/Utils.h @@ -28,7 +28,7 @@ #include "common/FieldMeta.h" #include "common/LoadInfo.h" #include "common/Types.h" -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" #include "knowhere/dataset.h" #include "knowhere/expected.h" #include "simdjson.h" @@ -166,46 +166,7 @@ PositivelyRelated(const knowhere::MetricType& metric_type) { inline std::string KnowhereStatusString(knowhere::Status status) { - switch (status) { - case knowhere::Status::invalid_args: - return "invalid args"; - case knowhere::Status::invalid_param_in_json: - return "invalid param in json"; - case 
knowhere::Status::out_of_range_in_json: - return "out of range in json"; - case knowhere::Status::type_conflict_in_json: - return "type conflict in json"; - case knowhere::Status::invalid_metric_type: - return "invalid metric type"; - case knowhere::Status::empty_index: - return "empty index"; - case knowhere::Status::not_implemented: - return "not implemented"; - case knowhere::Status::index_not_trained: - return "index not trained"; - case knowhere::Status::index_already_trained: - return "index already trained"; - case knowhere::Status::faiss_inner_error: - return "faiss inner error"; - case knowhere::Status::hnsw_inner_error: - return "hnsw inner error"; - case knowhere::Status::malloc_error: - return "malloc error"; - case knowhere::Status::diskann_inner_error: - return "diskann inner error"; - case knowhere::Status::diskann_file_error: - return "diskann file error"; - case knowhere::Status::invalid_value_in_json: - return "invalid value in json"; - case knowhere::Status::arithmetic_overflow: - return "arithmetic overflow"; - case knowhere::Status::raft_inner_error: - return "raft inner error"; - case knowhere::Status::invalid_binary_set: - return "invalid binary set"; - default: - return "unexpected status"; - } + return knowhere::Status2String(status); } inline std::vector diff --git a/internal/core/src/common/VectorTrait.h b/internal/core/src/common/VectorTrait.h index bd8805bcad286..a6a899abf03fb 100644 --- a/internal/core/src/common/VectorTrait.h +++ b/internal/core/src/common/VectorTrait.h @@ -18,6 +18,7 @@ #include "Types.h" #include #include +#include "Array.h" namespace milvus { @@ -35,12 +36,20 @@ class BinaryVector : public VectorTrait { static constexpr auto metric_type = DataType::VECTOR_BINARY; }; +class Float16Vector : public VectorTrait { + public: + using embedded_type = float16; + static constexpr auto metric_type = DataType::VECTOR_FLOAT16; +}; + template inline constexpr int64_t element_sizeof(int64_t dim) { 
static_assert(std::is_base_of_v); if constexpr (std::is_same_v) { return dim * sizeof(float); + } else if constexpr (std::is_same_v) { + return dim * sizeof(float16); } else { return dim / 8; } @@ -52,7 +61,9 @@ constexpr bool IsVector = std::is_base_of_v; template constexpr bool IsScalar = std::is_fundamental_v || std::is_same_v || - std::is_same_v || std::is_same_v; + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v; template struct EmbeddedTypeImpl; @@ -64,8 +75,10 @@ struct EmbeddedTypeImpl>> { template struct EmbeddedTypeImpl>> { - using type = - std::conditional_t, float, uint8_t>; + using type = std::conditional_t< + std::is_same_v, + float, + std::conditional_t, float16, uint8_t>>; }; template diff --git a/internal/core/src/common/binary_set_c.cpp b/internal/core/src/common/binary_set_c.cpp index 8495ae73f6bd0..5de5fa0f18bb2 100644 --- a/internal/core/src/common/binary_set_c.cpp +++ b/internal/core/src/common/binary_set_c.cpp @@ -14,6 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include "common/EasyAssert.h" #include "knowhere/binaryset.h" #include "common/binary_set_c.h" @@ -23,12 +24,12 @@ NewBinarySet(CBinarySet* c_binary_set) { auto binary_set = std::make_unique(); *c_binary_set = binary_set.release(); auto status = CStatus(); - status.error_code = Success; + status.error_code = milvus::ErrorCode::Success; status.error_msg = ""; return status; } catch (std::exception& e) { auto status = CStatus(); - status.error_code = UnexpectedError; + status.error_code = milvus::ErrorCode::UnexpectedError; status.error_msg = strdup(e.what()); return status; } @@ -55,10 +56,10 @@ AppendIndexBinary(CBinarySet c_binary_set, std::shared_ptr data(dup); binary_set->Append(index_key, data, index_size); - status.error_code = Success; + status.error_code = milvus::ErrorCode::Success; status.error_msg = ""; } catch (std::exception& e) { - status.error_code = UnexpectedError; + status.error_code = milvus::ErrorCode::UnexpectedError; status.error_msg = strdup(e.what()); } return status; @@ -100,11 +101,11 @@ CopyBinarySetValue(void* data, const char* key, CBinarySet c_binary_set) { auto binary_set = (knowhere::BinarySet*)c_binary_set; try { auto binary = binary_set->GetByName(key); - status.error_code = Success; + status.error_code = milvus::ErrorCode::Success; status.error_msg = ""; memcpy((uint8_t*)data, binary->data.get(), binary->size); } catch (std::exception& e) { - status.error_code = UnexpectedError; + status.error_code = milvus::ErrorCode::UnexpectedError; status.error_msg = strdup(e.what()); } return status; diff --git a/internal/core/src/common/protobuf_utils.h b/internal/core/src/common/protobuf_utils.h index 387b168de9ef7..4a6502fc4f95d 100644 --- a/internal/core/src/common/protobuf_utils.h +++ b/internal/core/src/common/protobuf_utils.h @@ -21,7 +21,7 @@ #include #include "pb/schema.pb.h" -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" using std::string; @@ -38,4 +38,4 @@ RepeatedKeyValToMap( } return mapping; } -} 
//namespace milvus \ No newline at end of file +} //namespace milvus diff --git a/internal/core/src/common/type_c.h b/internal/core/src/common/type_c.h index 1ab6176e0045e..e432e899e6f69 100644 --- a/internal/core/src/common/type_c.h +++ b/internal/core/src/common/type_c.h @@ -31,12 +31,6 @@ enum SegmentType { typedef enum SegmentType SegmentType; -enum ErrorCode { - Success = 0, - UnexpectedError = 1, - IllegalArgument = 5, -}; - // pure C don't support that we use schemapb.DataType directly. // Note: the value of all enumerations must match the corresponding schemapb.DataType. // TODO: what if there are increments in schemapb.DataType. @@ -56,6 +50,7 @@ enum CDataType { BinaryVector = 100, FloatVector = 101, + Float16Vector = 102, }; typedef enum CDataType CDataType; @@ -83,12 +78,14 @@ typedef struct CStorageConfig { const char* access_key_value; const char* root_path; const char* storage_type; + const char* cloud_provider; const char* iam_endpoint; const char* log_level; const char* region; bool useSSL; bool useIAM; bool useVirtualHost; + int64_t requestTimeoutMs; } CStorageConfig; typedef struct CTraceConfig { @@ -107,4 +104,5 @@ typedef struct CTraceContext { } CTraceContext; #ifdef __cplusplus } + #endif diff --git a/internal/core/src/config/CMakeLists.txt b/internal/core/src/config/CMakeLists.txt index a167af69b34fc..36f3ccc6a785f 100644 --- a/internal/core/src/config/CMakeLists.txt +++ b/internal/core/src/config/CMakeLists.txt @@ -20,13 +20,8 @@ if ( EMBEDDED_MILVUS ) add_compile_definitions( EMBEDDED_MILVUS ) endif() -set(CONFIG_SRC - ConfigKnowhere.cpp - ) +set(CONFIG_SRC ConfigKnowhere.cpp) add_library(milvus_config STATIC ${CONFIG_SRC}) -target_link_libraries(milvus_config - milvus_exceptions - knowhere - ) +target_link_libraries(milvus_config knowhere) diff --git a/internal/core/src/config/ConfigKnowhere.cpp b/internal/core/src/config/ConfigKnowhere.cpp index d57681f6865c7..cfe235876221c 100644 --- a/internal/core/src/config/ConfigKnowhere.cpp +++ 
b/internal/core/src/config/ConfigKnowhere.cpp @@ -17,11 +17,11 @@ #include #include "ConfigKnowhere.h" -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" #include "glog/logging.h" #include "log/Log.h" -#include "knowhere/comp/thread_pool.h" #include "knowhere/comp/knowhere_config.h" +#include "knowhere/version.h" namespace milvus::config { @@ -40,10 +40,11 @@ KnowhereInitImpl(const char* conf_file) { #ifdef EMBEDDED_MILVUS // always disable all logs for embedded milvus google::SetCommandLineOption("minloglevel", "4"); -#endif +#else if (conf_file != nullptr) { gflags::SetCommandLineOption("flagfile", conf_file); } +#endif }; std::call_once(init_knowhere_once_, init); @@ -61,28 +62,39 @@ KnowhereSetSimdType(const char* value) { } else if (strcmp(value, "avx") == 0 || strcmp(value, "sse4_2") == 0) { simd_type = knowhere::KnowhereConfig::SimdType::SSE4_2; } else { - PanicInfo("invalid SIMD type: " + std::string(value)); + PanicInfo(ConfigInvalid, "invalid SIMD type: " + std::string(value)); } try { return knowhere::KnowhereConfig::SetSimdType(simd_type); } catch (std::exception& e) { LOG_SERVER_ERROR_ << e.what(); - PanicInfo(e.what()); + PanicInfo(ConfigInvalid, e.what()); } } void KnowhereInitBuildThreadPool(const uint32_t num_threads) { - knowhere::ThreadPool::InitGlobalBuildThreadPool(num_threads); + knowhere::KnowhereConfig::SetBuildThreadPoolSize(num_threads); } void KnowhereInitSearchThreadPool(const uint32_t num_threads) { - knowhere::ThreadPool::InitGlobalSearchThreadPool(num_threads); + knowhere::KnowhereConfig::SetSearchThreadPoolSize(num_threads); if (!knowhere::KnowhereConfig::SetAioContextPool(num_threads)) { - PanicInfo("Failed to set aio context pool with num_threads " + - std::to_string(num_threads)); + PanicInfo(ConfigInvalid, + "Failed to set aio context pool with num_threads " + + std::to_string(num_threads)); } } +int32_t +GetMinimalIndexVersion() { + return knowhere::Version::GetMinimalVersion().VersionNumber(); +} + +int32_t 
+GetCurrentIndexVersion() { + return knowhere::Version::GetCurrentVersion().VersionNumber(); +} + } // namespace milvus::config diff --git a/internal/core/src/config/ConfigKnowhere.h b/internal/core/src/config/ConfigKnowhere.h index d76a898946480..c7584f2e7d96c 100644 --- a/internal/core/src/config/ConfigKnowhere.h +++ b/internal/core/src/config/ConfigKnowhere.h @@ -30,4 +30,11 @@ KnowhereInitBuildThreadPool(const uint32_t); void KnowhereInitSearchThreadPool(const uint32_t); + +int32_t +GetMinimalIndexVersion(); + +int32_t +GetCurrentIndexVersion(); + } // namespace milvus::config diff --git a/internal/core/src/exceptions/CMakeLists.txt b/internal/core/src/exceptions/CMakeLists.txt deleted file mode 100644 index b92906ca8f2a1..0000000000000 --- a/internal/core/src/exceptions/CMakeLists.txt +++ /dev/null @@ -1,23 +0,0 @@ -# Licensed to the LF AI & Data foundation under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -set(exceptions_files - EasyAssert.cpp -) - -add_library(milvus_exceptions STATIC ${exceptions_files}) - -target_link_libraries(milvus_exceptions milvus_proto) diff --git a/internal/core/src/index/BoolIndex.h b/internal/core/src/index/BoolIndex.h index 3e18a127f0a18..fe9b3df66a8dc 100644 --- a/internal/core/src/index/BoolIndex.h +++ b/internal/core/src/index/BoolIndex.h @@ -25,7 +25,8 @@ namespace milvus::index { using BoolIndexPtr = std::shared_ptr>; inline BoolIndexPtr -CreateBoolIndex(storage::FileManagerImplPtr file_manager = nullptr) { - return std::make_unique>(file_manager); +CreateBoolIndex(const storage::FileManagerContext& file_manager_context = + storage::FileManagerContext()) { + return std::make_unique>(file_manager_context); } } // namespace milvus::index diff --git a/internal/core/src/index/CMakeLists.txt b/internal/core/src/index/CMakeLists.txt index c429b575ddfd8..1a5dbf200a8ee 100644 --- a/internal/core/src/index/CMakeLists.txt +++ b/internal/core/src/index/CMakeLists.txt @@ -14,31 +14,12 @@ set(INDEX_FILES Utils.cpp VectorMemIndex.cpp IndexFactory.cpp - VectorMemNMIndex.cpp + VectorDiskIndex.cpp ) -if ( BUILD_DISK_ANN STREQUAL "ON" ) - set(INDEX_FILES - ${INDEX_FILES} - VectorDiskIndex.cpp - ) -endif () - milvus_add_pkg_config("milvus_index") add_library(milvus_index SHARED ${INDEX_FILES}) -set(PLATFORM_LIBS ) -if ( LINUX OR APPLE ) - set(PLATFORM_LIBS marisa) -endif() -if (MSYS) - set(PLATFORM_LIBS -Wl,--allow-multiple-definition) -endif () - -target_link_libraries(milvus_index - milvus_storage - ${PLATFORM_LIBS} - milvus-storage - ) +target_link_libraries(milvus_index milvus_storage milvus-storage) install(TARGETS milvus_index DESTINATION "${CMAKE_INSTALL_LIBDIR}") diff --git a/internal/core/src/index/Exception.h b/internal/core/src/index/Exception.h deleted file mode 100644 index dde9a9e24bfd6..0000000000000 --- a/internal/core/src/index/Exception.h +++ /dev/null @@ -1,34 +0,0 @@ -// Licensed to the LF AI & Data foundation under one -// or 
more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include - -namespace milvus::index { - -class UnistdException : public std::runtime_error { - public: - explicit UnistdException(const std::string& msg) : std::runtime_error(msg) { - } - - virtual ~UnistdException() { - } -}; - -} // namespace milvus::index diff --git a/internal/core/src/index/Index.h b/internal/core/src/index/Index.h index d2accc10a2c5f..39d6011dbfb82 100644 --- a/internal/core/src/index/Index.h +++ b/internal/core/src/index/Index.h @@ -18,7 +18,7 @@ #include #include -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" #include "knowhere/comp/index_param.h" #include "knowhere/dataset.h" #include "common/Types.h" diff --git a/internal/core/src/index/IndexFactory-inl.h b/internal/core/src/index/IndexFactory-inl.h index 64e19d1e2b116..a03cc3e2d93a6 100644 --- a/internal/core/src/index/IndexFactory-inl.h +++ b/internal/core/src/index/IndexFactory-inl.h @@ -23,45 +23,35 @@ namespace milvus::index { template inline ScalarIndexPtr -IndexFactory::CreateScalarIndex(const IndexType& index_type, - storage::FileManagerImplPtr file_manager) { - return CreateScalarIndexSort(file_manager); +IndexFactory::CreateScalarIndex( + const IndexType& index_type, + const 
storage::FileManagerContext& file_manager_context) { + return CreateScalarIndexSort(file_manager_context); } template inline ScalarIndexPtr -IndexFactory::CreateScalarIndex(const IndexType& index_type, - storage::FileManagerImplPtr file_manager, - std::shared_ptr space) { - return CreateScalarIndexSort(file_manager, space); +IndexFactory::CreateScalarIndex( + const IndexType& index_type, + const storage::FileManagerContext& file_manager_context, + std::shared_ptr space) { + return CreateScalarIndexSort(file_manager_context, space); } -// template <> -// inline ScalarIndexPtr -// IndexFactory::CreateScalarIndex(const IndexType& index_type) { -// return CreateBoolIndex(); -//} - template <> inline ScalarIndexPtr -IndexFactory::CreateScalarIndex(const IndexType& index_type, - storage::FileManagerImplPtr file_manager) { -#if defined(__linux__) || defined(__APPLE__) - return CreateStringIndexMarisa(file_manager); -#else - throw std::runtime_error("unsupported platform"); -#endif +IndexFactory::CreateScalarIndex( + const IndexType& index_type, + const storage::FileManagerContext& file_manager_context) { + return CreateStringIndexMarisa(file_manager_context); } template <> inline ScalarIndexPtr -IndexFactory::CreateScalarIndex(const IndexType& index_type, - storage::FileManagerImplPtr file_manager, - std::shared_ptr space) { -#if defined(__linux__) || defined(__APPLE__) - return CreateStringIndexMarisa(file_manager, space); -#else - throw std::runtime_error("unsupported platform"); -#endif +IndexFactory::CreateScalarIndex( + const IndexType& index_type, + const storage::FileManagerContext& file_manager_context, + std::shared_ptr space) { + return CreateStringIndexMarisa(file_manager_context, space); } } // namespace milvus::index diff --git a/internal/core/src/index/IndexFactory.cpp b/internal/core/src/index/IndexFactory.cpp index 4757a06b7001e..501a3efb711ac 100644 --- a/internal/core/src/index/IndexFactory.cpp +++ b/internal/core/src/index/IndexFactory.cpp @@ -16,106 
+16,107 @@ #include "index/IndexFactory.h" #include "index/VectorMemIndex.h" -#include "index/VectorMemNMIndex.h" #include "index/Utils.h" #include "index/Meta.h" +#include "knowhere/utils.h" -#ifdef BUILD_DISK_ANN #include "index/VectorDiskIndex.h" -#endif namespace milvus::index { IndexBasePtr -IndexFactory::CreateIndex(const CreateIndexInfo& create_index_info, - storage::FileManagerImplPtr file_manager) { +IndexFactory::CreateIndex( + const CreateIndexInfo& create_index_info, + const storage::FileManagerContext& file_manager_context) { if (datatype_is_vector(create_index_info.field_type)) { - return CreateVectorIndex(create_index_info, file_manager); + return CreateVectorIndex(create_index_info, file_manager_context); } - return CreateScalarIndex(create_index_info, file_manager); + return CreateScalarIndex(create_index_info, file_manager_context); } IndexBasePtr -IndexFactory::CreateIndex(const CreateIndexInfo& create_index_info, - storage::FileManagerImplPtr file_manager, - std::shared_ptr space) { +IndexFactory::CreateIndex( + const CreateIndexInfo& create_index_info, + const storage::FileManagerContext& file_manager_context, + std::shared_ptr space) { if (datatype_is_vector(create_index_info.field_type)) { - return CreateVectorIndex(create_index_info, file_manager, space); + return CreateVectorIndex( + create_index_info, file_manager_context, space); } - return CreateScalarIndex(create_index_info, file_manager, space); + return CreateScalarIndex(create_index_info, file_manager_context, space); } IndexBasePtr -IndexFactory::CreateScalarIndex(const CreateIndexInfo& create_index_info, - storage::FileManagerImplPtr file_manager) { +IndexFactory::CreateScalarIndex( + const CreateIndexInfo& create_index_info, + const storage::FileManagerContext& file_manager_context) { auto data_type = create_index_info.field_type; auto index_type = create_index_info.index_type; switch (data_type) { // create scalar index case DataType::BOOL: - return CreateScalarIndex(index_type, 
file_manager); + return CreateScalarIndex(index_type, file_manager_context); case DataType::INT8: - return CreateScalarIndex(index_type, file_manager); + return CreateScalarIndex(index_type, file_manager_context); case DataType::INT16: - return CreateScalarIndex(index_type, file_manager); + return CreateScalarIndex(index_type, file_manager_context); case DataType::INT32: - return CreateScalarIndex(index_type, file_manager); + return CreateScalarIndex(index_type, file_manager_context); case DataType::INT64: - return CreateScalarIndex(index_type, file_manager); + return CreateScalarIndex(index_type, file_manager_context); case DataType::FLOAT: - return CreateScalarIndex(index_type, file_manager); + return CreateScalarIndex(index_type, file_manager_context); case DataType::DOUBLE: - return CreateScalarIndex(index_type, file_manager); + return CreateScalarIndex(index_type, file_manager_context); // create string index case DataType::STRING: case DataType::VARCHAR: - return CreateScalarIndex(index_type, file_manager); + return CreateScalarIndex(index_type, + file_manager_context); default: - throw std::invalid_argument( - std::string("invalid data type to build index: ") + - std::to_string(int(data_type))); + throw SegcoreError( + DataTypeInvalid, + fmt::format("invalid data type to build index: {}", data_type)); } } IndexBasePtr -IndexFactory::CreateVectorIndex(const CreateIndexInfo& create_index_info, - storage::FileManagerImplPtr file_manager) { - auto data_type = create_index_info.field_type; +IndexFactory::CreateVectorIndex( + const CreateIndexInfo& create_index_info, + const storage::FileManagerContext& file_manager_context) { auto index_type = create_index_info.index_type; auto metric_type = create_index_info.metric_type; - -#ifdef BUILD_DISK_ANN + auto version = create_index_info.index_engine_version; // create disk index - if (is_in_disk_list(index_type)) { +#ifdef BUILD_DISK_ANN + auto data_type = create_index_info.field_type; + if 
(knowhere::UseDiskLoad(index_type, version)) { switch (data_type) { case DataType::VECTOR_FLOAT: { return std::make_unique>( - index_type, metric_type, file_manager); + index_type, metric_type, version, file_manager_context); } default: - throw std::invalid_argument( - std::string("invalid data type to build disk index: ") + - std::to_string(int(data_type))); + throw SegcoreError( + DataTypeInvalid, + fmt::format("invalid data type to build disk index: {}", + data_type)); } } #endif - if (is_in_nm_list(index_type)) { - return std::make_unique( - index_type, metric_type, file_manager); - } // create mem index return std::make_unique( - index_type, metric_type, file_manager); + index_type, metric_type, version, file_manager_context); } IndexBasePtr IndexFactory::CreateScalarIndex(const CreateIndexInfo& create_index_info, - storage::FileManagerImplPtr file_manager, + const storage::FileManagerContext& file_manager, std::shared_ptr space) { auto data_type = create_index_info.field_type; auto index_type = create_index_info.index_type; @@ -150,35 +151,34 @@ IndexFactory::CreateScalarIndex(const CreateIndexInfo& create_index_info, } IndexBasePtr -IndexFactory::CreateVectorIndex(const CreateIndexInfo& create_index_info, - storage::FileManagerImplPtr file_manager, - std::shared_ptr space) { +IndexFactory::CreateVectorIndex( + const CreateIndexInfo& create_index_info, + const storage::FileManagerContext& file_manager_context, + std::shared_ptr space) { auto data_type = create_index_info.field_type; auto index_type = create_index_info.index_type; auto metric_type = create_index_info.metric_type; + auto version = create_index_info.index_engine_version; -#ifdef BUILD_DISK_ANN // create disk index - if (is_in_disk_list(index_type)) { +#ifdef BUILD_DISK_ANN + if (knowhere::UseDiskLoad(index_type, version)) { switch (data_type) { case DataType::VECTOR_FLOAT: { return std::make_unique>( - index_type, metric_type, space); + index_type, metric_type, version, file_manager_context); } 
default: - throw std::invalid_argument( - std::string("invalid data type to build disk index: ") + - std::to_string(int(data_type))); + throw SegcoreError( + DataTypeInvalid, + fmt::format("invalid data type to build disk index: {}", + data_type)); } } #endif - if (is_in_nm_list(index_type)) { - return std::make_unique( - create_index_info, file_manager, space); - } // create mem index return std::make_unique( - create_index_info, file_manager, space); + create_index_info, file_manager_context, space); } } // namespace milvus::index diff --git a/internal/core/src/index/IndexFactory.h b/internal/core/src/index/IndexFactory.h index 5581b007fb158..0754639422121 100644 --- a/internal/core/src/index/IndexFactory.h +++ b/internal/core/src/index/IndexFactory.h @@ -50,28 +50,29 @@ class IndexFactory { IndexBasePtr CreateIndex(const CreateIndexInfo& create_index_info, - storage::FileManagerImplPtr file_manager); + const storage::FileManagerContext& file_manager_context); IndexBasePtr CreateIndex(const CreateIndexInfo& create_index_info, - storage::FileManagerImplPtr file_manager, + const storage::FileManagerContext& file_manager_context, std::shared_ptr space); IndexBasePtr CreateVectorIndex(const CreateIndexInfo& create_index_info, - storage::FileManagerImplPtr file_manager); + const storage::FileManagerContext& file_manager_context); IndexBasePtr CreateScalarIndex(const CreateIndexInfo& create_index_info, - storage::FileManagerImplPtr file_manager = nullptr); + const storage::FileManagerContext& file_manager_context = + storage::FileManagerContext()); IndexBasePtr CreateVectorIndex(const CreateIndexInfo& create_index_info, - storage::FileManagerImplPtr file_manager, + const storage::FileManagerContext& file_manager_context, std::shared_ptr space); IndexBasePtr CreateScalarIndex(const CreateIndexInfo& create_index_info, - storage::FileManagerImplPtr file_manager, + const storage::FileManagerContext& file_manager_context, std::shared_ptr space); // IndexBasePtr @@ -80,12 
+81,12 @@ class IndexFactory { template ScalarIndexPtr CreateScalarIndex(const IndexType& index_type, - storage::FileManagerImplPtr file_manager = nullptr); + const storage::FileManagerContext& file_manager); template ScalarIndexPtr CreateScalarIndex(const IndexType& index_type, - storage::FileManagerImplPtr file_manager, + const storage::FileManagerContext& file_manager, std::shared_ptr space); }; diff --git a/internal/core/src/index/IndexInfo.h b/internal/core/src/index/IndexInfo.h index 4b4b6925addde..632cc3210cffc 100644 --- a/internal/core/src/index/IndexInfo.h +++ b/internal/core/src/index/IndexInfo.h @@ -26,6 +26,7 @@ struct CreateIndexInfo { MetricType metric_type; std::string field_name; int64_t dim; + IndexVersion index_engine_version; }; } // namespace milvus::index diff --git a/internal/core/src/index/Meta.h b/internal/core/src/index/Meta.h index 489a8afe50c3f..77024a13aa073 100644 --- a/internal/core/src/index/Meta.h +++ b/internal/core/src/index/Meta.h @@ -45,10 +45,13 @@ constexpr const char* FIELD_ID = "field_id"; constexpr const char* INDEX_BUILD_ID = "index_build_id"; constexpr const char* INDEX_ID = "index_id"; constexpr const char* INDEX_VERSION = "index_version"; +constexpr const char* INDEX_ENGINE_VERSION = "index_engine_version"; -// DiskAnn build params +// VecIndex file metas constexpr const char* DISK_ANN_PREFIX_PATH = "index_prefix"; constexpr const char* DISK_ANN_RAW_DATA_PATH = "data_path"; + +// DiskAnn build params constexpr const char* DISK_ANN_MAX_DEGREE = "max_degree"; constexpr const char* DISK_ANN_SEARCH_LIST_SIZE = "search_list_size"; constexpr const char* DISK_ANN_PQ_CODE_BUDGET = "pq_code_budget_gb"; diff --git a/internal/core/src/index/ScalarIndex-inl.h b/internal/core/src/index/ScalarIndex-inl.h index 5ddf60d2195d1..b59aadf27d47f 100644 --- a/internal/core/src/index/ScalarIndex-inl.h +++ b/internal/core/src/index/ScalarIndex-inl.h @@ -20,6 +20,7 @@ #include "index/Meta.h" #include "knowhere/dataset.h" +#include 
"common/Types.h" namespace milvus::index { template @@ -63,8 +64,9 @@ ScalarIndex::Query(const DatasetPtr& dataset) { case OpType::PrefixMatch: case OpType::PostfixMatch: default: - throw std::invalid_argument(std::string( - "unsupported operator type: " + std::to_string(op))); + throw SegcoreError( + OpTypeInvalid, + fmt::format("unsupported operator type: {}", op)); } } @@ -73,7 +75,6 @@ inline void ScalarIndex::BuildWithRawData(size_t n, const void* values, const Config& config) { - // TODO :: use arrow proto::schema::StringArray arr; auto ok = arr.ParseFromArray(values, n); Assert(ok); diff --git a/internal/core/src/index/ScalarIndex.h b/internal/core/src/index/ScalarIndex.h index ec67f79b80024..3fb4c27425615 100644 --- a/internal/core/src/index/ScalarIndex.h +++ b/internal/core/src/index/ScalarIndex.h @@ -22,8 +22,9 @@ #include #include "common/Types.h" -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" #include "index/Index.h" +#include "fmt/format.h" namespace milvus::index { @@ -38,7 +39,8 @@ class ScalarIndex : public IndexBase { void BuildWithDataset(const DatasetPtr& dataset, const Config& config = {}) override { - PanicInfo("scalar index don't support build index with dataset"); + PanicInfo(Unsupported, + "scalar index don't support build index with dataset"); }; public: diff --git a/internal/core/src/index/ScalarIndexSort-inl.h b/internal/core/src/index/ScalarIndexSort-inl.h index 1759692a245ff..d1b459dd0545e 100644 --- a/internal/core/src/index/ScalarIndexSort-inl.h +++ b/internal/core/src/index/ScalarIndexSort-inl.h @@ -25,28 +25,31 @@ #include "Meta.h" #include "common/Utils.h" #include "common/Slice.h" +#include "common/Types.h" #include "index/Utils.h" namespace milvus::index { template inline ScalarIndexSort::ScalarIndexSort( - storage::FileManagerImplPtr file_manager) + const storage::FileManagerContext& file_manager_context) : is_built_(false), data_() { - if (file_manager != nullptr) { - file_manager_ = 
std::dynamic_pointer_cast( - file_manager); + if (file_manager_context.Valid()) { + file_manager_ = + std::make_shared(file_manager_context); + AssertInfo(file_manager_ != nullptr, "create file manager failed!"); } } template inline ScalarIndexSort::ScalarIndexSort( - storage::FileManagerImplPtr file_manager, + const storage::FileManagerContext& file_manager_context, std::shared_ptr space) : is_built_(false), data_(), space_(space) { - if (file_manager != nullptr) { - file_manager_ = std::dynamic_pointer_cast( - file_manager); + if (file_manager_context.Valid()) { + file_manager_ = + std::make_shared(file_manager_context, space_); + AssertInfo(file_manager_ != nullptr, "create file manager failed!"); } } template @@ -71,9 +74,8 @@ ScalarIndexSort::Build(const Config& config) { total_num_rows += data->get_num_rows(); } if (total_num_rows == 0) { - // todo: throw an exception - throw std::invalid_argument( - "ScalarIndexSort cannot build null values!"); + throw SegcoreError(DataIsEmpty, + "ScalarIndexSort cannot build null values!"); } data_.reserve(total_num_rows); @@ -101,9 +103,8 @@ ScalarIndexSort::Build(size_t n, const T* values) { if (is_built_) return; if (n == 0) { - // todo: throw an exception - throw std::invalid_argument( - "ScalarIndexSort cannot build null values!"); + throw SegcoreError(DataIsEmpty, + "ScalarIndexSort cannot build null values!"); } data_.reserve(n); idx_to_offsets_.resize(n); @@ -226,14 +227,14 @@ ScalarIndexSort::LoadV2(const Config& config) { for (auto& file_name : index_files.value()) { auto res = space_->GetBlobByteSize(file_name); if (!res.ok()) { - PanicCodeInfo(ErrorCodeEnum::UnexpectedError, + PanicInfo(S3Error, "unable to read index blob"); } auto index_blob_data = std::shared_ptr(new uint8_t[res.value()]); auto status = space_->ReadBlob(file_name, index_blob_data.get()); if (!status.ok()) { - PanicCodeInfo(ErrorCodeEnum::UnexpectedError, + PanicInfo(S3Error, "unable to read index blob"); } auto raw_index_blob = @@ -322,8 
+323,8 @@ ScalarIndexSort::Range(const T value, const OpType op) { data_.begin(), data_.end(), IndexStructure(value)); break; default: - throw std::invalid_argument(std::string("Invalid OperatorType: ") + - std::to_string((int)op) + "!"); + throw SegcoreError(OpTypeInvalid, + fmt::format("Invalid OperatorType: {}", op)); } for (; lb < ub; ++lb) { bitset[lb->idx_] = true; diff --git a/internal/core/src/index/ScalarIndexSort.h b/internal/core/src/index/ScalarIndexSort.h index 48c0bcb0f0b48..10f32d910b372 100644 --- a/internal/core/src/index/ScalarIndexSort.h +++ b/internal/core/src/index/ScalarIndexSort.h @@ -34,9 +34,10 @@ template class ScalarIndexSort : public ScalarIndex { public: explicit ScalarIndexSort( - storage::FileManagerImplPtr file_manager = nullptr); + const storage::FileManagerContext& file_manager_context = + storage::FileManagerContext()); - explicit ScalarIndexSort(storage::FileManagerImplPtr file_manager, + explicit ScalarIndexSort(const storage::FileManagerContext& file_manager_context, std::shared_ptr space); BinarySet @@ -126,13 +127,14 @@ using ScalarIndexSortPtr = std::unique_ptr>; namespace milvus::index { template inline ScalarIndexSortPtr -CreateScalarIndexSort(storage::FileManagerImplPtr file_manager = nullptr) { - return std::make_unique>(file_manager); +CreateScalarIndexSort(const storage::FileManagerContext& file_manager_context = + storage::FileManagerContext()) { + return std::make_unique>(file_manager_context); } template inline ScalarIndexSortPtr -CreateScalarIndexSort(storage::FileManagerImplPtr file_manager, +CreateScalarIndexSort(const storage::FileManagerContext& file_manager_context, std::shared_ptr space) { - return std::make_unique>(file_manager, space); + return std::make_unique>(file_manager_context, space); } } // namespace milvus::index diff --git a/internal/core/src/index/StringIndexMarisa.cpp b/internal/core/src/index/StringIndexMarisa.cpp index ae5c2d6708216..f24080e45bb10 100644 --- 
a/internal/core/src/index/StringIndexMarisa.cpp +++ b/internal/core/src/index/StringIndexMarisa.cpp @@ -23,10 +23,10 @@ #include #include "common/Types.h" +#include "common/EasyAssert.h" #include "index/StringIndexMarisa.h" #include "index/Utils.h" #include "index/Index.h" -#include "index/Exception.h" #include "common/Utils.h" #include "common/Slice.h" #include "storage/Util.h" @@ -34,22 +34,21 @@ namespace milvus::index { -#if defined(__linux__) || defined(__APPLE__) - -StringIndexMarisa::StringIndexMarisa(storage::FileManagerImplPtr file_manager) { - if (file_manager != nullptr) { - file_manager_ = std::dynamic_pointer_cast( - file_manager); +StringIndexMarisa::StringIndexMarisa( + const storage::FileManagerContext& file_manager_context) { + if (file_manager_context.Valid()) { + file_manager_ = + std::make_shared(file_manager_context); } } StringIndexMarisa::StringIndexMarisa( - storage::FileManagerImplPtr file_manager, + const storage::FileManagerContext& file_manager_context, std::shared_ptr space) : space_(space) { - if (file_manager != nullptr) { - file_manager_ = std::dynamic_pointer_cast( - file_manager); + if (file_manager_context.Valid()) { + file_manager_ = std::make_shared( + file_manager_context, space_); } } @@ -72,13 +71,13 @@ StringIndexMarisa::BuildV2(const Config& config) { AssertInfo(field_name.has_value(), "field name can not be empty"); auto res = space_->ScanData(); if (!res.ok()) { - PanicInfo("failed to create scan iterator"); + PanicInfo(S3Error, "failed to create scan iterator"); } auto reader = res.value(); std::vector field_datas; for (auto rec = reader->Next(); rec != nullptr; rec = reader->Next()) { if (!rec.ok()) { - PanicInfo("failed to read data"); + PanicInfo(DataFormatBroken, "failed to read data"); } auto data = rec.ValueUnsafe(); auto total_num_rows = data->num_rows(); @@ -122,7 +121,7 @@ StringIndexMarisa::BuildV2(const Config& config) { void StringIndexMarisa::Build(const Config& config) { if (built_) { - throw 
std::runtime_error("index has been built"); + throw SegcoreError(IndexAlreadyBuild, "index has been built"); } auto insert_files = @@ -135,9 +134,9 @@ StringIndexMarisa::Build(const Config& config) { // fill key set. marisa::Keyset keyset; - for (auto data : field_datas) { + for (const auto& data : field_datas) { auto slice_num = data->get_num_rows(); - for (size_t i = 0; i < slice_num; ++i) { + for (int64_t i = 0; i < slice_num; ++i) { keyset.push_back( (*static_cast(data->RawValue(i))).c_str()); } @@ -148,9 +147,9 @@ StringIndexMarisa::Build(const Config& config) { // fill str_ids_ str_ids_.resize(total_num_rows); int64_t offset = 0; - for (auto data : field_datas) { + for (const auto& data : field_datas) { auto slice_num = data->get_num_rows(); - for (size_t i = 0; i < slice_num; ++i) { + for (int64_t i = 0; i < slice_num; ++i) { auto str_id = lookup(*static_cast(data->RawValue(i))); AssertInfo(valid_str_id(str_id), "invalid marisa key"); @@ -167,7 +166,7 @@ StringIndexMarisa::Build(const Config& config) { void StringIndexMarisa::Build(size_t n, const std::string* values) { if (built_) { - throw std::runtime_error("index has been built"); + throw SegcoreError(IndexAlreadyBuild, "index has been built"); } marisa::Keyset keyset; @@ -262,8 +261,9 @@ StringIndexMarisa::LoadWithoutAssemble(const BinarySet& set, if (status != len) { close(fd); remove(file.c_str()); - throw UnistdException("write index to fd error, errorCode is " + - std::to_string(status)); + throw SegcoreError( + ErrorCode::UnistdError, + "write index to fd error, errorCode is " + std::to_string(status)); } lseek(fd, 0, SEEK_SET); @@ -315,15 +315,13 @@ StringIndexMarisa::LoadV2(const Config& config) { for (auto& file_name : index_files.value()) { auto res = space_->GetBlobByteSize(file_name); if (!res.ok()) { - PanicCodeInfo(ErrorCodeEnum::UnexpectedError, - "unable to read index blob"); + PanicInfo(DataFormatBroken, "unable to read index blob"); } auto index_blob_data = std::shared_ptr(new 
uint8_t[res.value()]); auto status = space_->ReadBlob(file_name, index_blob_data.get()); if (!status.ok()) { - PanicCodeInfo(ErrorCodeEnum::UnexpectedError, - "unable to read index blob"); + PanicInfo(DataFormatBroken, "unable to read index blob"); } auto raw_index_blob = storage::DeserializeFileData(index_blob_data, res.value()); @@ -398,9 +396,9 @@ StringIndexMarisa::Range(std::string value, OpType op) { set = raw_data.compare(value) >= 0; break; default: - throw std::invalid_argument( - std::string("Invalid OperatorType: ") + - std::to_string((int)op) + "!"); + throw SegcoreError(OpTypeInvalid, + fmt::format("Invalid OperatorType: {}", + static_cast(op))); } if (set) { bitset[offset] = true; @@ -511,6 +509,4 @@ StringIndexMarisa::Reverse_Lookup(size_t offset) const { return std::string(agent.key().ptr(), agent.key().length()); } -#endif - } // namespace milvus::index diff --git a/internal/core/src/index/StringIndexMarisa.h b/internal/core/src/index/StringIndexMarisa.h index 8a6e340ed9c80..acf2a3c549166 100644 --- a/internal/core/src/index/StringIndexMarisa.h +++ b/internal/core/src/index/StringIndexMarisa.h @@ -16,8 +16,6 @@ #pragma once -#if defined(__linux__) || defined(__APPLE__) - #include #include "index/StringIndex.h" #include @@ -32,10 +30,12 @@ namespace milvus::index { class StringIndexMarisa : public StringIndex { public: explicit StringIndexMarisa( - storage::FileManagerImplPtr file_manager = nullptr); + const storage::FileManagerContext& file_manager_context = + storage::FileManagerContext()); - explicit StringIndexMarisa(storage::FileManagerImplPtr file_manager, - std::shared_ptr space); + explicit StringIndexMarisa( + const storage::FileManagerContext& file_manager_context, + std::shared_ptr space); int64_t Size() override; @@ -123,15 +123,15 @@ class StringIndexMarisa : public StringIndex { using StringIndexMarisaPtr = std::unique_ptr; inline StringIndexPtr -CreateStringIndexMarisa(storage::FileManagerImplPtr file_manager = nullptr) { - return 
std::make_unique(file_manager); +CreateStringIndexMarisa( + const storage::FileManagerContext& file_manager_context = + storage::FileManagerContext()) { + return std::make_unique(file_manager_context); } inline StringIndexPtr -CreateStringIndexMarisa(storage::FileManagerImplPtr file_manager, +CreateStringIndexMarisa(const storage::FileManagerContext& file_manager_context, std::shared_ptr space) { - return std::make_unique(file_manager, space); + return std::make_unique(file_manager_context, space); } } // namespace milvus::index - -#endif diff --git a/internal/core/src/index/Utils.cpp b/internal/core/src/index/Utils.cpp index 87f5c7fc0e6d7..8193241a96010 100644 --- a/internal/core/src/index/Utils.cpp +++ b/internal/core/src/index/Utils.cpp @@ -26,16 +26,15 @@ #include #include "index/Utils.h" -#include "index/Exception.h" #include "index/Meta.h" #include #include -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" #include "knowhere/comp/index_param.h" #include "common/Slice.h" #include "storage/FieldData.h" #include "storage/Util.h" -#include "utils/File.h" +#include "common/File.h" namespace milvus::index { @@ -121,6 +120,15 @@ GetIndexTypeFromConfig(const Config& config) { return index_type.value(); } +IndexVersion +GetIndexEngineVersionFromConfig(const Config& config) { + auto index_engine_version = + GetValueFromConfig(config, INDEX_ENGINE_VERSION); + AssertInfo(index_engine_version.has_value(), + "index_engine not exist in config"); + return (std::stoi(index_engine_version.value())); +} + // TODO :: too ugly storage::FieldDataMeta GetFieldDataMetaFromConfig(const Config& config) { @@ -286,9 +294,10 @@ ReadDataFromFD(int fd, void* buf, size_t size, size_t chunk_size) { const size_t count = (size < chunk_size) ? 
size : chunk_size; const ssize_t size_read = read(fd, buf, count); if (size_read != count) { - throw UnistdException( + throw SegcoreError( + ErrorCode::UnistdError, "read data from fd error, returned read size is " + - std::to_string(size_read)); + std::to_string(size_read)); } buf = static_cast(buf) + size_read; diff --git a/internal/core/src/index/Utils.h b/internal/core/src/index/Utils.h index 02c57471978c2..adc0b34595e7a 100644 --- a/internal/core/src/index/Utils.h +++ b/internal/core/src/index/Utils.h @@ -100,6 +100,9 @@ GetMetricTypeFromConfig(const Config& config); std::string GetIndexTypeFromConfig(const Config& config); +IndexVersion +GetIndexEngineVersionFromConfig(const Config& config); + storage::FieldDataMeta GetFieldDataMetaFromConfig(const Config& config); diff --git a/internal/core/src/index/VectorDiskIndex.cpp b/internal/core/src/index/VectorDiskIndex.cpp index 416d215a4a892..7e89cb44e86d2 100644 --- a/internal/core/src/index/VectorDiskIndex.cpp +++ b/internal/core/src/index/VectorDiskIndex.cpp @@ -27,9 +27,6 @@ namespace milvus::index { -#define BUILD_DISK_ANN -#ifdef BUILD_DISK_ANN - #define kSearchListMaxValue1 200 // used if tok <= 20 #define kSearchListMaxValue2 65535 // used for topk > 20 #define kPrepareDim 100 @@ -39,10 +36,12 @@ template VectorDiskAnnIndex::VectorDiskAnnIndex( const IndexType& index_type, const MetricType& metric_type, - storage::FileManagerImplPtr file_manager) + const IndexVersion& version, + const storage::FileManagerContext& file_manager_context) : VectorIndex(index_type, metric_type) { file_manager_ = - std::dynamic_pointer_cast(file_manager); + std::make_shared(file_manager_context); + AssertInfo(file_manager_ != nullptr, "create file manager failed!"); auto local_chunk_manager = storage::LocalChunkManagerSingleton::GetInstance().GetChunkManager(); auto local_index_path_prefix = file_manager_->GetLocalIndexObjectPrefix(); @@ -53,12 +52,12 @@ VectorDiskAnnIndex::VectorDiskAnnIndex( if 
(local_chunk_manager->Exist(local_index_path_prefix)) { local_chunk_manager->RemoveDir(local_index_path_prefix); } - + CheckCompatible(version); local_chunk_manager->CreateDir(local_index_path_prefix); auto diskann_index_pack = - knowhere::Pack(std::shared_ptr(file_manager)); - index_ = knowhere::IndexFactory::Instance().Create(GetIndexType(), - diskann_index_pack); + knowhere::Pack(std::shared_ptr(file_manager_)); + index_ = knowhere::IndexFactory::Instance().Create( + GetIndexType(), version, diskann_index_pack); } template @@ -81,9 +80,8 @@ VectorDiskAnnIndex::Load(const Config& config) { auto stat = index_.Deserialize(knowhere::BinarySet(), load_config); if (stat != knowhere::Status::success) - PanicCodeInfo( - ErrorCodeEnum::UnexpectedError, - "failed to Deserialize index, " + KnowhereStatusString(stat)); + PanicInfo(ErrorCode::UnexpectedError, + "failed to Deserialize index, " + KnowhereStatusString(stat)); SetDim(index_.Dim()); } @@ -91,8 +89,9 @@ VectorDiskAnnIndex::Load(const Config& config) { template BinarySet VectorDiskAnnIndex::Upload(const Config& config) { - auto remote_paths_to_size = file_manager_->GetRemotePathsToFileSize(); BinarySet ret; + index_.Serialize(ret); + auto remote_paths_to_size = file_manager_->GetRemotePathsToFileSize(); for (auto& file : remote_paths_to_size) { ret.Append(file.first, nullptr, file.second); } @@ -125,11 +124,15 @@ VectorDiskAnnIndex::Build(const Config& config) { auto local_index_path_prefix = file_manager_->GetLocalIndexObjectPrefix(); build_config[DISK_ANN_PREFIX_PATH] = local_index_path_prefix; - auto num_threads = GetValueFromConfig( - build_config, DISK_ANN_BUILD_THREAD_NUM); - AssertInfo(num_threads.has_value(), - "param " + std::string(DISK_ANN_BUILD_THREAD_NUM) + "is empty"); - build_config[DISK_ANN_THREADS_NUM] = std::atoi(num_threads.value().c_str()); + if (GetIndexType() == knowhere::IndexEnum::INDEX_DISKANN) { + auto num_threads = GetValueFromConfig( + build_config, DISK_ANN_BUILD_THREAD_NUM); + 
AssertInfo( + num_threads.has_value(), + "param " + std::string(DISK_ANN_BUILD_THREAD_NUM) + "is empty"); + build_config[DISK_ANN_THREADS_NUM] = + std::atoi(num_threads.value().c_str()); + } knowhere::DataSet* ds_ptr = nullptr; build_config.erase("insert_files"); index_.Build(*ds_ptr, build_config); @@ -157,12 +160,15 @@ VectorDiskAnnIndex::BuildWithDataset(const DatasetPtr& dataset, auto local_index_path_prefix = file_manager_->GetLocalIndexObjectPrefix(); build_config[DISK_ANN_PREFIX_PATH] = local_index_path_prefix; - auto num_threads = GetValueFromConfig( - build_config, DISK_ANN_BUILD_THREAD_NUM); - AssertInfo(num_threads.has_value(), - "param " + std::string(DISK_ANN_BUILD_THREAD_NUM) + "is empty"); - build_config[DISK_ANN_THREADS_NUM] = std::atoi(num_threads.value().c_str()); - + if (GetIndexType() == knowhere::IndexEnum::INDEX_DISKANN) { + auto num_threads = GetValueFromConfig( + build_config, DISK_ANN_BUILD_THREAD_NUM); + AssertInfo( + num_threads.has_value(), + "param " + std::string(DISK_ANN_BUILD_THREAD_NUM) + "is empty"); + build_config[DISK_ANN_THREADS_NUM] = + std::atoi(num_threads.value().c_str()); + } if (!local_chunk_manager->Exist(local_data_path)) { local_chunk_manager->CreateFile(local_data_path); } @@ -183,8 +189,8 @@ VectorDiskAnnIndex::BuildWithDataset(const DatasetPtr& dataset, knowhere::DataSet* ds_ptr = nullptr; auto stat = index_.Build(*ds_ptr, build_config); if (stat != knowhere::Status::success) - PanicCodeInfo(ErrorCodeEnum::BuildIndexError, - "failed to build index, " + KnowhereStatusString(stat)); + PanicInfo(ErrorCode::IndexBuildError, + "failed to build index, " + KnowhereStatusString(stat)); local_chunk_manager->RemoveDir( storage::GetSegmentRawDataPathPrefix(local_chunk_manager, segment_id)); @@ -210,20 +216,21 @@ VectorDiskAnnIndex::Query(const DatasetPtr dataset, // set search list size auto search_list_size = GetValueFromConfig( search_info.search_params_, DISK_ANN_QUERY_LIST); - if (search_list_size.has_value()) { - 
search_config[DISK_ANN_SEARCH_LIST_SIZE] = search_list_size.value(); - } - // set beamwidth - search_config[DISK_ANN_QUERY_BEAMWIDTH] = int(search_beamwidth_); + if (GetIndexType() == knowhere::IndexEnum::INDEX_DISKANN) { + if (search_list_size.has_value()) { + search_config[DISK_ANN_SEARCH_LIST_SIZE] = search_list_size.value(); + } + // set beamwidth + search_config[DISK_ANN_QUERY_BEAMWIDTH] = int(search_beamwidth_); + // set json reset field, will be removed later + search_config[DISK_ANN_PQ_CODE_BUDGET] = 0.0; + } // set index prefix, will be removed later auto local_index_path_prefix = file_manager_->GetLocalIndexObjectPrefix(); search_config[DISK_ANN_PREFIX_PATH] = local_index_path_prefix; - // set json reset field, will be removed later - search_config[DISK_ANN_PQ_CODE_BUDGET] = 0.0; - auto final = [&] { auto radius = GetValueFromConfig(search_info.search_params_, RADIUS); @@ -240,20 +247,20 @@ VectorDiskAnnIndex::Query(const DatasetPtr dataset, auto res = index_.RangeSearch(*dataset, search_config, bitset); if (!res.has_value()) { - PanicCodeInfo(ErrorCodeEnum::UnexpectedError, - fmt::format("failed to range search: {}: {}", - KnowhereStatusString(res.error()), - res.what())); + PanicInfo(ErrorCode::UnexpectedError, + fmt::format("failed to range search: {}: {}", + KnowhereStatusString(res.error()), + res.what())); } return ReGenRangeSearchResult( res.value(), topk, num_queries, GetMetricType()); } else { auto res = index_.Search(*dataset, search_config, bitset); if (!res.has_value()) { - PanicCodeInfo(ErrorCodeEnum::UnexpectedError, - fmt::format("failed to search: {}: {}", - KnowhereStatusString(res.error()), - res.what())); + PanicInfo(ErrorCode::UnexpectedError, + fmt::format("failed to search: {}: {}", + KnowhereStatusString(res.error()), + res.what())); } return res.value(); } @@ -295,10 +302,10 @@ std::vector VectorDiskAnnIndex::GetVector(const DatasetPtr dataset) const { auto res = index_.GetVectorByIds(*dataset); if (!res.has_value()) { - 
PanicCodeInfo(ErrorCodeEnum::UnexpectedError, - fmt::format("failed to get vector: {}: {}", - KnowhereStatusString(res.error()), - res.what())); + PanicInfo(ErrorCode::UnexpectedError, + fmt::format("failed to get vector: {}: {}", + KnowhereStatusString(res.error()), + res.what())); } auto index_type = GetIndexType(); auto tensor = res.value()->GetTensor(); @@ -336,22 +343,26 @@ VectorDiskAnnIndex::update_load_json(const Config& config) { auto local_index_path_prefix = file_manager_->GetLocalIndexObjectPrefix(); load_config[DISK_ANN_PREFIX_PATH] = local_index_path_prefix; - // set base info - load_config[DISK_ANN_PREPARE_WARM_UP] = false; - load_config[DISK_ANN_PREPARE_USE_BFS_CACHE] = false; - - // set threads number - auto num_threads = - GetValueFromConfig(load_config, DISK_ANN_LOAD_THREAD_NUM); - AssertInfo(num_threads.has_value(), - "param " + std::string(DISK_ANN_LOAD_THREAD_NUM) + "is empty"); - load_config[DISK_ANN_THREADS_NUM] = std::atoi(num_threads.value().c_str()); - - // update search_beamwidth - auto beamwidth = - GetValueFromConfig(load_config, DISK_ANN_QUERY_BEAMWIDTH); - if (beamwidth.has_value()) { - search_beamwidth_ = std::atoi(beamwidth.value().c_str()); + if (GetIndexType() == knowhere::IndexEnum::INDEX_DISKANN) { + // set base info + load_config[DISK_ANN_PREPARE_WARM_UP] = false; + load_config[DISK_ANN_PREPARE_USE_BFS_CACHE] = false; + + // set threads number + auto num_threads = GetValueFromConfig( + load_config, DISK_ANN_LOAD_THREAD_NUM); + AssertInfo( + num_threads.has_value(), + "param " + std::string(DISK_ANN_LOAD_THREAD_NUM) + "is empty"); + load_config[DISK_ANN_THREADS_NUM] = + std::atoi(num_threads.value().c_str()); + + // update search_beamwidth + auto beamwidth = GetValueFromConfig( + load_config, DISK_ANN_QUERY_BEAMWIDTH); + if (beamwidth.has_value()) { + search_beamwidth_ = std::atoi(beamwidth.value().c_str()); + } } return load_config; @@ -359,6 +370,4 @@ VectorDiskAnnIndex::update_load_json(const Config& config) { template class 
VectorDiskAnnIndex; -#endif - } // namespace milvus::index diff --git a/internal/core/src/index/VectorDiskIndex.h b/internal/core/src/index/VectorDiskIndex.h index 0e7f2ee91de13..1ff0ddd8176f5 100644 --- a/internal/core/src/index/VectorDiskIndex.h +++ b/internal/core/src/index/VectorDiskIndex.h @@ -25,19 +25,21 @@ namespace milvus::index { -#define BUILD_DISK_ANN -#ifdef BUILD_DISK_ANN - template class VectorDiskAnnIndex : public VectorIndex { public: - explicit VectorDiskAnnIndex(const IndexType& index_type, - const MetricType& metric_type, - storage::FileManagerImplPtr file_manager); + explicit VectorDiskAnnIndex( + const IndexType& index_type, + const MetricType& metric_type, + const IndexVersion& version, + const storage::FileManagerContext& file_manager_context = + storage::FileManagerContext()); + BinarySet Serialize(const Config& config) override { // deprecated - auto remote_paths_to_size = file_manager_->GetRemotePathsToFileSize(); BinarySet binary_set; + index_.Serialize(binary_set); + auto remote_paths_to_size = file_manager_->GetRemotePathsToFileSize(); for (auto& file : remote_paths_to_size) { binary_set.Append(file.first, nullptr, file.second); } @@ -96,6 +98,5 @@ class VectorDiskAnnIndex : public VectorIndex { template using VectorDiskAnnIndexPtr = std::unique_ptr>; -#endif } // namespace milvus::index diff --git a/internal/core/src/index/VectorIndex.h b/internal/core/src/index/VectorIndex.h index 4d14dc1e68328..6c906c02f655a 100644 --- a/internal/core/src/index/VectorIndex.h +++ b/internal/core/src/index/VectorIndex.h @@ -22,12 +22,14 @@ #include #include +#include "Utils.h" #include "knowhere/factory.h" #include "index/Index.h" #include "common/Types.h" #include "common/BitsetView.h" #include "common/QueryResult.h" #include "common/QueryInfo.h" +#include "knowhere/version.h" namespace milvus::index { @@ -43,12 +45,13 @@ class VectorIndex : public IndexBase { BuildWithRawData(size_t n, const void* values, const Config& config = {}) override { - 
PanicInfo("vector index don't support build index with raw data"); + PanicInfo(Unsupported, + "vector index don't support build index with raw data"); }; virtual void AddWithDataset(const DatasetPtr& dataset, const Config& config) { - PanicInfo("vector index don't support add with dataset"); + PanicInfo(Unsupported, "vector index don't support add with dataset"); } virtual std::unique_ptr @@ -86,6 +89,18 @@ class VectorIndex : public IndexBase { CleanLocalData() { } + void + CheckCompatible(const IndexVersion& version) { + std::string err_msg = + "version not support : " + std::to_string(version) + + " , knowhere current version " + + std::to_string( + knowhere::Version::GetCurrentVersion().VersionNumber()); + AssertInfo( + knowhere::Version::VersionSupport(knowhere::Version(version)), + err_msg); + } + private: MetricType metric_type_; int64_t dim_; diff --git a/internal/core/src/index/VectorMemIndex.cpp b/internal/core/src/index/VectorMemIndex.cpp index 7e0d3eb8acbb8..f46b0d4439f02 100644 --- a/internal/core/src/index/VectorMemIndex.cpp +++ b/internal/core/src/index/VectorMemIndex.cpp @@ -31,7 +31,7 @@ #include "index/IndexInfo.h" #include "index/Meta.h" #include "index/Utils.h" -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" #include "config/ConfigKnowhere.h" #include "knowhere/factory.h" #include "knowhere/comp/time_recorder.h" @@ -47,27 +47,31 @@ #include "storage/MemFileManagerImpl.h" #include "storage/ThreadPools.h" #include "storage/Util.h" -#include "utils/File.h" +#include "common/File.h" #include "common/Tracer.h" #include "storage/space.h" namespace milvus::index { -VectorMemIndex::VectorMemIndex(const IndexType& index_type, - const MetricType& metric_type, - storage::FileManagerImplPtr file_manager) +VectorMemIndex::VectorMemIndex( + const IndexType& index_type, + const MetricType& metric_type, + const IndexVersion& version, + const storage::FileManagerContext& file_manager_context) : VectorIndex(index_type, metric_type) { 
AssertInfo(!is_unsupported(index_type, metric_type), index_type + " doesn't support metric: " + metric_type); - if (file_manager != nullptr) { - file_manager_ = std::dynamic_pointer_cast( - file_manager); + if (file_manager_context.Valid()) { + file_manager_ = + std::make_shared(file_manager_context); + AssertInfo(file_manager_ != nullptr, "create file manager failed!"); } - index_ = knowhere::IndexFactory::Instance().Create(GetIndexType()); + CheckCompatible(version); + index_ = knowhere::IndexFactory::Instance().Create(GetIndexType(), version); } VectorMemIndex::VectorMemIndex(const CreateIndexInfo& create_index_info, - storage::FileManagerImplPtr file_manager, + const storage::FileManagerContext& file_manager_context, std::shared_ptr space) : VectorIndex(create_index_info.index_type, create_index_info.metric_type), space_(space), @@ -76,11 +80,14 @@ VectorMemIndex::VectorMemIndex(const CreateIndexInfo& create_index_info, create_index_info.metric_type), create_index_info.index_type + " doesn't support metric: " + create_index_info.metric_type); - if (file_manager != nullptr) { - file_manager_ = std::dynamic_pointer_cast( - file_manager); + if (file_manager_context.Valid()) { + file_manager_ = + std::make_shared(file_manager_context); + AssertInfo(file_manager_ != nullptr, "create file manager failed!"); } - index_ = knowhere::IndexFactory::Instance().Create(GetIndexType()); + auto version = create_index_info.index_engine_version; + CheckCompatible(version); + index_ = knowhere::IndexFactory::Instance().Create(GetIndexType(), version); } BinarySet @@ -116,9 +123,8 @@ VectorMemIndex::Serialize(const Config& config) { knowhere::BinarySet ret; auto stat = index_.Serialize(ret); if (stat != knowhere::Status::success) - PanicCodeInfo( - ErrorCodeEnum::UnexpectedError, - "failed to serialize index, " + KnowhereStatusString(stat)); + PanicInfo(ErrorCode::UnexpectedError, + "failed to serialize index, " + KnowhereStatusString(stat)); Disassemble(ret); return ret; @@ 
-127,11 +133,10 @@ VectorMemIndex::Serialize(const Config& config) { void VectorMemIndex::LoadWithoutAssemble(const BinarySet& binary_set, const Config& config) { - auto stat = index_.Deserialize(binary_set); + auto stat = index_.Deserialize(binary_set, config); if (stat != knowhere::Status::success) - PanicCodeInfo( - ErrorCodeEnum::UnexpectedError, - "failed to Deserialize index, " + KnowhereStatusString(stat)); + PanicInfo(ErrorCode::UnexpectedError, + "failed to Deserialize index, " + KnowhereStatusString(stat)); SetDim(index_.Dim()); } @@ -160,7 +165,7 @@ VectorMemIndex::LoadV2(const Config& config) { std::map index_datas{}; if (!res.ok() && !res.status().IsFileNotFound()) { - PanicCodeInfo(ErrorCodeEnum::UnexpectedError, "failed to read blob"); + PanicInfo(DataFormatBroken, "failed to read blob"); } bool slice_meta_exist = res.ok(); @@ -168,14 +173,14 @@ VectorMemIndex::LoadV2(const Config& config) { -> std::unique_ptr { auto res = space_->GetBlobByteSize(file_name); if (!res.ok()) { - PanicCodeInfo(ErrorCodeEnum::UnexpectedError, + PanicInfo(DataFormatBroken, "unable to read index blob"); } auto index_blob_data = std::shared_ptr(new uint8_t[res.value()]); auto status = space_->ReadBlob(file_name, index_blob_data.get()); if (!status.ok()) { - PanicCodeInfo(ErrorCodeEnum::UnexpectedError, + PanicInfo(DataFormatBroken, "unable to read index blob"); } return storage::DeserializeFileData(index_blob_data, res.value()); @@ -187,7 +192,7 @@ VectorMemIndex::LoadV2(const Config& config) { auto status = space_->ReadBlob(INDEX_FILE_SLICE_META, slice_meta_data.get()); if (!status.ok()) { - PanicCodeInfo(ErrorCodeEnum::UnexpectedError, + PanicInfo(DataFormatBroken, "unable to read slice meta"); } auto raw_slice_meta = @@ -360,8 +365,8 @@ VectorMemIndex::BuildWithDataset(const DatasetPtr& dataset, knowhere::TimeRecorder rc("BuildWithoutIds", 1); auto stat = index_.Build(*dataset, index_config); if (stat != knowhere::Status::success) - 
PanicCodeInfo(ErrorCodeEnum::BuildIndexError, - "failed to build index, " + KnowhereStatusString(stat)); + PanicInfo(ErrorCode::IndexBuildError, + "failed to build index, " + KnowhereStatusString(stat)); rc.ElapseFromBegin("Done"); SetDim(index_.Dim()); } @@ -374,7 +379,7 @@ VectorMemIndex::BuildV2(const Config& config) { auto dim = create_index_info_.dim; auto res = space_->ScanData(); if (!res.ok()) { - PanicInfo(fmt::format("failed to create scan iterator: {}", + PanicInfo(IndexBuildError, fmt::format("failed to create scan iterator: {}", res.status().ToString())); } @@ -382,7 +387,7 @@ VectorMemIndex::BuildV2(const Config& config) { std::vector field_datas; for (auto rec : *reader) { if (!rec.ok()) { - PanicInfo(fmt::format("failed to read data: {}", + PanicInfo(IndexBuildError,fmt::format("failed to read data: {}", rec.status().ToString())); } auto data = rec.ValueUnsafe(); @@ -472,8 +477,8 @@ VectorMemIndex::AddWithDataset(const DatasetPtr& dataset, knowhere::TimeRecorder rc("AddWithDataset", 1); auto stat = index_.Add(*dataset, index_config); if (stat != knowhere::Status::success) - PanicCodeInfo(ErrorCodeEnum::BuildIndexError, - "failed to append index, " + KnowhereStatusString(stat)); + PanicInfo(ErrorCode::IndexBuildError, + "failed to append index, " + KnowhereStatusString(stat)); rc.ElapseFromBegin("Done"); } @@ -502,10 +507,10 @@ VectorMemIndex::Query(const DatasetPtr dataset, auto res = index_.RangeSearch(*dataset, search_conf, bitset); milvus::tracer::AddEvent("finish_knowhere_index_range_search"); if (!res.has_value()) { - PanicCodeInfo(ErrorCodeEnum::UnexpectedError, - fmt::format("failed to range search: {}: {}", - KnowhereStatusString(res.error()), - res.what())); + PanicInfo(ErrorCode::UnexpectedError, + fmt::format("failed to range search: {}: {}", + KnowhereStatusString(res.error()), + res.what())); } auto result = ReGenRangeSearchResult( res.value(), topk, num_queries, GetMetricType()); @@ -516,10 +521,10 @@ VectorMemIndex::Query(const 
DatasetPtr dataset, auto res = index_.Search(*dataset, search_conf, bitset); milvus::tracer::AddEvent("finish_knowhere_index_search"); if (!res.has_value()) { - PanicCodeInfo(ErrorCodeEnum::UnexpectedError, - fmt::format("failed to search: {}: {}", - KnowhereStatusString(res.error()), - res.what())); + PanicInfo(ErrorCode::UnexpectedError, + fmt::format("failed to search: {}: {}", + KnowhereStatusString(res.error()), + res.what())); } return res.value(); } @@ -558,9 +563,8 @@ std::vector VectorMemIndex::GetVector(const DatasetPtr dataset) const { auto res = index_.GetVectorByIds(*dataset); if (!res.has_value()) { - PanicCodeInfo( - ErrorCodeEnum::UnexpectedError, - "failed to get vector, " + KnowhereStatusString(res.error())); + PanicInfo(ErrorCode::UnexpectedError, + "failed to get vector, " + KnowhereStatusString(res.error())); } auto index_type = GetIndexType(); auto tensor = res.value()->GetTensor(); @@ -678,9 +682,9 @@ VectorMemIndex::LoadFromFile(const Config& config) { conf[kEnableMmap] = true; auto stat = index_.DeserializeFromFile(filepath.value(), conf); if (stat != knowhere::Status::success) { - PanicCodeInfo(ErrorCodeEnum::UnexpectedError, - fmt::format("failed to Deserialize index: {}", - KnowhereStatusString(stat))); + PanicInfo(ErrorCode::UnexpectedError, + fmt::format("failed to Deserialize index: {}", + KnowhereStatusString(stat))); } auto dim = index_.Dim(); @@ -717,7 +721,7 @@ VectorMemIndex::LoadFromFileV2(const Config& config) { auto res = space_->GetBlobByteSize(std::string(INDEX_FILE_SLICE_META)); if (!res.ok() && !res.status().IsFileNotFound()) { - PanicCodeInfo(ErrorCodeEnum::UnexpectedError, "failed to read blob"); + PanicInfo(DataFormatBroken, "failed to read blob"); } bool slice_meta_exist = res.ok(); @@ -725,14 +729,14 @@ VectorMemIndex::LoadFromFileV2(const Config& config) { -> std::unique_ptr { auto res = space_->GetBlobByteSize(file_name); if (!res.ok()) { - PanicCodeInfo(ErrorCodeEnum::UnexpectedError, + PanicInfo(DataFormatBroken, 
"unable to read index blob"); } auto index_blob_data = std::shared_ptr(new uint8_t[res.value()]); auto status = space_->ReadBlob(file_name, index_blob_data.get()); if (!status.ok()) { - PanicCodeInfo(ErrorCodeEnum::UnexpectedError, + PanicInfo(DataFormatBroken, "unable to read index blob"); } return storage::DeserializeFileData(index_blob_data, res.value()); @@ -744,7 +748,7 @@ VectorMemIndex::LoadFromFileV2(const Config& config) { auto status = space_->ReadBlob(INDEX_FILE_SLICE_META, slice_meta_data.get()); if (!status.ok()) { - PanicCodeInfo(ErrorCodeEnum::UnexpectedError, + PanicInfo(DataFormatBroken, "unable to read slice meta"); } auto raw_slice_meta = @@ -783,7 +787,7 @@ VectorMemIndex::LoadFromFileV2(const Config& config) { conf[kEnableMmap] = true; auto stat = index_.DeserializeFromFile(filepath.value(), conf); if (stat != knowhere::Status::success) { - PanicCodeInfo(ErrorCodeEnum::UnexpectedError, + PanicInfo(DataFormatBroken, fmt::format("failed to Deserialize index: {}", KnowhereStatusString(stat))); } diff --git a/internal/core/src/index/VectorMemIndex.h b/internal/core/src/index/VectorMemIndex.h index 0ce6e25619c58..f064f050ab9b9 100644 --- a/internal/core/src/index/VectorMemIndex.h +++ b/internal/core/src/index/VectorMemIndex.h @@ -32,12 +32,15 @@ namespace milvus::index { class VectorMemIndex : public VectorIndex { public: - explicit VectorMemIndex(const IndexType& index_type, - const MetricType& metric_type, - storage::FileManagerImplPtr file_manager = nullptr); + explicit VectorMemIndex( + const IndexType& index_type, + const MetricType& metric_type, + const IndexVersion& version, + const storage::FileManagerContext& file_manager_context = + storage::FileManagerContext()); explicit VectorMemIndex(const CreateIndexInfo& create_index_info, - storage::FileManagerImplPtr file_manager, + const storage::FileManagerContext& file_manager, std::shared_ptr space); BinarySet Serialize(const Config& config) override; diff --git 
a/internal/core/src/index/VectorMemNMIndex.cpp b/internal/core/src/index/VectorMemNMIndex.cpp deleted file mode 100644 index b18de825efd68..0000000000000 --- a/internal/core/src/index/VectorMemNMIndex.cpp +++ /dev/null @@ -1,133 +0,0 @@ -// Licensed to the LF AI & Data foundation under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "common/Slice.h" -#include "common/Utils.h" -#include "common/BitsetView.h" -#include "index/VectorMemNMIndex.h" -#include "log/Log.h" - -#include "knowhere/factory.h" -#include "knowhere/comp/time_recorder.h" -#define RAW_DATA "RAW_DATA" -#include "common/Tracer.h" - -namespace milvus::index { - -BinarySet -VectorMemNMIndex::Serialize(const Config& config) { - knowhere::BinarySet ret; - auto stat = index_.Serialize(ret); - if (stat != knowhere::Status::success) - PanicCodeInfo( - ErrorCodeEnum::UnexpectedError, - "failed to serialize index, " + KnowhereStatusString(stat)); - - auto deleter = [&](uint8_t*) {}; // avoid repeated deconstruction - auto raw_data = std::shared_ptr( - static_cast(raw_data_.data()), deleter); - ret.Append(RAW_DATA, raw_data, raw_data_.size()); - Disassemble(ret); - - return ret; -} - -void -VectorMemNMIndex::BuildWithDataset(const DatasetPtr& dataset, - const Config& config) { - VectorMemIndex::BuildWithDataset(dataset, config); - knowhere::TimeRecorder rc("store_raw_data", 1); - store_raw_data(dataset); - rc.ElapseFromBegin("Done"); -} - -void -VectorMemNMIndex::LoadWithoutAssemble(const BinarySet& binary_set, - const Config& config) { - VectorMemIndex::LoadWithoutAssemble(binary_set, config); - if (binary_set.Contains(RAW_DATA)) { - std::call_once(raw_data_loaded_, [&]() { - LOG_SEGCORE_INFO_ << "NM index load raw data done!"; - }); - } -} - -void -VectorMemNMIndex::AddWithDataset(const DatasetPtr& /*dataset*/, - const Config& /*config*/) { -} - -void -VectorMemNMIndex::Load(const BinarySet& binary_set, const Config& config) { - VectorMemIndex::Load(binary_set, config); - if (binary_set.Contains(RAW_DATA)) { - std::call_once(raw_data_loaded_, [&]() { - LOG_SEGCORE_INFO_ << "NM index load raw data done!"; - }); - } -} - -std::unique_ptr -VectorMemNMIndex::Query(const DatasetPtr dataset, - const SearchInfo& search_info, - const BitsetView& bitset) { - auto load_raw_data_closure = [&]() { LoadRawData(); }; // hide this pointer 
- // load -> query, raw data has been loaded - // build -> query, this case just for test, should load raw data before query - std::call_once(raw_data_loaded_, load_raw_data_closure); - return VectorMemIndex::Query(dataset, search_info, bitset); -} - -void -VectorMemNMIndex::store_raw_data(const DatasetPtr& dataset) { - auto index_type = GetIndexType(); - auto tensor = dataset->GetTensor(); - auto row_num = dataset->GetRows(); - auto dim = dataset->GetDim(); - int64_t data_size; - if (is_in_bin_list(index_type)) { - data_size = dim / 8 * row_num; - } else { - data_size = dim * row_num * sizeof(float); - } - raw_data_.resize(data_size); - memcpy(raw_data_.data(), tensor, data_size); -} - -void -VectorMemNMIndex::LoadRawData() { - knowhere::BinarySet bs; - auto stat = index_.Serialize(bs); - if (stat != knowhere::Status::success) - PanicCodeInfo( - ErrorCodeEnum::UnexpectedError, - "failed to Serialize index, " + KnowhereStatusString(stat)); - - auto bptr = std::make_shared(); - auto deleter = [&](uint8_t*) {}; // avoid repeated deconstruction - bptr->data = std::shared_ptr( - static_cast(raw_data_.data()), deleter); - bptr->size = raw_data_.size(); - bs.Append(RAW_DATA, bptr); - stat = index_.Deserialize(bs); - if (stat != knowhere::Status::success) - PanicCodeInfo( - ErrorCodeEnum::UnexpectedError, - "failed to Deserialize index, " + KnowhereStatusString(stat)); - milvus::tracer::AddEvent("VectorMemNMIndex_Loaded_RawData"); -} - -} // namespace milvus::index diff --git a/internal/core/src/index/VectorMemNMIndex.h b/internal/core/src/index/VectorMemNMIndex.h deleted file mode 100644 index 02f4c6d79f22e..0000000000000 --- a/internal/core/src/index/VectorMemNMIndex.h +++ /dev/null @@ -1,82 +0,0 @@ -// Licensed to the LF AI & Data foundation under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. 
The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include -#include - -#include "index/IndexInfo.h" -#include "index/Utils.h" -#include "index/VectorMemIndex.h" - -namespace milvus::index { - -class VectorMemNMIndex : public VectorMemIndex { - public: - explicit VectorMemNMIndex( - const IndexType& index_type, - const MetricType& metric_type, - storage::FileManagerImplPtr file_manager = nullptr) - : VectorMemIndex(index_type, metric_type, file_manager) { - AssertInfo(is_in_nm_list(index_type), "not valid nm index type"); - } - - explicit VectorMemNMIndex(const CreateIndexInfo& create_index_info, - storage::FileManagerImplPtr file_manager, - std::shared_ptr space) - : VectorMemIndex(create_index_info, file_manager, space) { - AssertInfo(is_in_nm_list(create_index_info.index_type), - "not valid nm index type"); - } - BinarySet - Serialize(const Config& config) override; - - void - BuildWithDataset(const DatasetPtr& dataset, - const Config& config = {}) override; - - void - AddWithDataset(const DatasetPtr& dataset, const Config& config) override; - - void - Load(const BinarySet& binary_set, const Config& config = {}) override; - - std::unique_ptr - Query(const DatasetPtr dataset, - const SearchInfo& search_info, - const BitsetView& bitset) override; - - void - LoadWithoutAssemble(const BinarySet& binary_set, - const Config& config) override; - - private: - void - store_raw_data(const DatasetPtr& 
dataset); - - void - LoadRawData(); - - private: - std::vector raw_data_; - std::once_flag raw_data_loaded_; -}; - -using VectorMemNMIndexPtr = std::unique_ptr; -} // namespace milvus::index diff --git a/internal/core/src/indexbuilder/IndexFactory.h b/internal/core/src/indexbuilder/IndexFactory.h index e9c1826a92961..c2b4da058aa88 100644 --- a/internal/core/src/indexbuilder/IndexFactory.h +++ b/internal/core/src/indexbuilder/IndexFactory.h @@ -16,6 +16,7 @@ #include #include +#include "common/EasyAssert.h" #include "indexbuilder/IndexCreatorBase.h" #include "indexbuilder/ScalarIndexCreator.h" #include "indexbuilder/VecIndexCreator.h" @@ -45,7 +46,7 @@ class IndexFactory { IndexCreatorBasePtr CreateIndex(DataType type, Config& config, - storage::FileManagerImplPtr file_manager) { + const storage::FileManagerContext& context) { auto invalid_dtype_msg = std::string("invalid data type: ") + std::to_string(int(type)); @@ -59,14 +60,15 @@ class IndexFactory { case DataType::DOUBLE: case DataType::VARCHAR: case DataType::STRING: - return CreateScalarIndex(type, config, file_manager); + return CreateScalarIndex(type, config, context); case DataType::VECTOR_FLOAT: case DataType::VECTOR_BINARY: - return std::make_unique( - type, config, file_manager); + return std::make_unique(type, config, context); default: - throw std::invalid_argument(invalid_dtype_msg); + throw SegcoreError( + DataTypeInvalid, + fmt::format("invalid type is {}", invalid_dtype_msg)); } } diff --git a/internal/core/src/indexbuilder/ScalarIndexCreator.cpp b/internal/core/src/indexbuilder/ScalarIndexCreator.cpp index 6a1b506d197ad..4b0f4f913a36f 100644 --- a/internal/core/src/indexbuilder/ScalarIndexCreator.cpp +++ b/internal/core/src/indexbuilder/ScalarIndexCreator.cpp @@ -20,15 +20,16 @@ namespace milvus::indexbuilder { -ScalarIndexCreator::ScalarIndexCreator(DataType dtype, - Config& config, - storage::FileManagerImplPtr file_manager) +ScalarIndexCreator::ScalarIndexCreator( + DataType dtype, + Config& 
config, + const storage::FileManagerContext& file_manager_context) : dtype_(dtype), config_(config) { milvus::index::CreateIndexInfo index_info; index_info.field_type = dtype_; index_info.index_type = index_type(); - index_ = index::IndexFactory::GetInstance().CreateIndex(index_info, - file_manager); + index_ = index::IndexFactory::GetInstance().CreateIndex( + index_info, file_manager_context); } ScalarIndexCreator::ScalarIndexCreator( diff --git a/internal/core/src/indexbuilder/ScalarIndexCreator.h b/internal/core/src/indexbuilder/ScalarIndexCreator.h index cc27dd847aa81..4e4d1fec2fbb6 100644 --- a/internal/core/src/indexbuilder/ScalarIndexCreator.h +++ b/internal/core/src/indexbuilder/ScalarIndexCreator.h @@ -25,7 +25,7 @@ class ScalarIndexCreator : public IndexCreatorBase { public: ScalarIndexCreator(DataType data_type, Config& config, - storage::FileManagerImplPtr file_manager); + const storage::FileManagerContext& file_manager_context); ScalarIndexCreator(DataType data_type, Config& config, @@ -64,8 +64,9 @@ using ScalarIndexCreatorPtr = std::unique_ptr; inline ScalarIndexCreatorPtr CreateScalarIndex(DataType dtype, Config& config, - storage::FileManagerImplPtr file_manager) { - return std::make_unique(dtype, config, file_manager); + const storage::FileManagerContext& file_manager_context) { + return std::make_unique( + dtype, config, file_manager_context); } inline ScalarIndexCreatorPtr diff --git a/internal/core/src/indexbuilder/VecIndexCreator.cpp b/internal/core/src/indexbuilder/VecIndexCreator.cpp index 492372225e836..d24c85d57cf08 100644 --- a/internal/core/src/indexbuilder/VecIndexCreator.cpp +++ b/internal/core/src/indexbuilder/VecIndexCreator.cpp @@ -11,7 +11,7 @@ #include -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" #include "indexbuilder/VecIndexCreator.h" #include "index/Utils.h" #include "index/IndexFactory.h" @@ -19,27 +19,30 @@ namespace milvus::indexbuilder { -VecIndexCreator::VecIndexCreator(DataType data_type, - Config& 
config, - storage::FileManagerImplPtr file_manager) - : VecIndexCreator(data_type, "", config, file_manager, nullptr) { +VecIndexCreator::VecIndexCreator( + DataType data_type, + Config& config, + const storage::FileManagerContext& file_manager_context) + : VecIndexCreator(data_type, "", config, file_manager_context, nullptr) { } -VecIndexCreator::VecIndexCreator(DataType data_type, - const std::string& field_name, - Config& config, - storage::FileManagerImplPtr file_manager, - std::shared_ptr space) +VecIndexCreator::VecIndexCreator( + DataType data_type, + std::string field_name, + Config& config, + const storage::FileManagerContext& file_manager_context, + std::shared_ptr space) : data_type_(data_type), config_(config), space_(space) { index::CreateIndexInfo index_info; index_info.field_type = data_type_; index_info.index_type = index::GetIndexTypeFromConfig(config_); index_info.metric_type = index::GetMetricTypeFromConfig(config_); index_info.field_name = field_name; - index_info.dim = index::GetDimFromConfig(config); + index_info.index_engine_version = + index::GetIndexEngineVersionFromConfig(config_); index_ = index::IndexFactory::GetInstance().CreateIndex( - index_info, file_manager, space_); + index_info, file_manager_context, space_); AssertInfo(index_ != nullptr, "[VecIndexCreator]Index is null after create index"); } diff --git a/internal/core/src/indexbuilder/VecIndexCreator.h b/internal/core/src/indexbuilder/VecIndexCreator.h index 184273799b96e..fdb1ac263002f 100644 --- a/internal/core/src/indexbuilder/VecIndexCreator.h +++ b/internal/core/src/indexbuilder/VecIndexCreator.h @@ -27,14 +27,16 @@ namespace milvus::indexbuilder { // TODO: better to distinguish binary vec & float vec. 
class VecIndexCreator : public IndexCreatorBase { public: - explicit VecIndexCreator(DataType data_type, - Config& config, - storage::FileManagerImplPtr file_manager); + explicit VecIndexCreator( + DataType data_type, + Config& config, + const storage::FileManagerContext& file_manager_context = + storage::FileManagerContext()); VecIndexCreator(DataType data_type, const std::string& field_name, Config& config, - storage::FileManagerImplPtr file_manager, + const storage::FileManagerContext& file_manager_context, std::shared_ptr space); void Build(const milvus::DatasetPtr& dataset) override; diff --git a/internal/core/src/indexbuilder/index_c.cpp b/internal/core/src/indexbuilder/index_c.cpp index 993e79327cfbb..d8600cbc51553 100644 --- a/internal/core/src/indexbuilder/index_c.cpp +++ b/internal/core/src/indexbuilder/index_c.cpp @@ -21,7 +21,7 @@ #include #endif -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" #include "indexbuilder/VecIndexCreator.h" #include "indexbuilder/index_c.h" #include "indexbuilder/IndexFactory.h" @@ -32,7 +32,9 @@ #include "pb/index_cgo_msg.pb.h" #include "storage/Util.h" #include "storage/space.h" +#include "index/Meta.h" +using namespace milvus; CStatus CreateIndex(enum CDataType dtype, const char* serialized_type_params, @@ -58,9 +60,14 @@ CreateIndex(enum CDataType dtype, config[param.key()] = param.value(); } + config[milvus::index::INDEX_ENGINE_VERSION] = std::to_string( + knowhere::Version::GetCurrentVersion().VersionNumber()); + auto& index_factory = milvus::indexbuilder::IndexFactory::GetInstance(); auto index = - index_factory.CreateIndex(milvus::DataType(dtype), config, nullptr); + index_factory.CreateIndex(milvus::DataType(dtype), + config, + milvus::storage::FileManagerContext()); *res_index = index.release(); status.error_code = Success; @@ -90,6 +97,12 @@ CreateIndexV2(CIndex* res_index, CBuildIndexInfo c_build_index_info) { AssertInfo(index_type.has_value(), "index type is empty"); index_info.index_type = 
index_type.value(); + auto engine_version = build_index_info->index_engine_version; + + index_info.index_engine_version = engine_version; + config[milvus::index::INDEX_ENGINE_VERSION] = + std::to_string(engine_version); + // get metric type if (milvus::datatype_is_vector(field_type)) { auto metric_type = milvus::index::GetValueFromConfig( @@ -111,14 +124,13 @@ CreateIndexV2(CIndex* res_index, CBuildIndexInfo c_build_index_info) { build_index_info->index_version}; auto chunk_manager = milvus::storage::CreateChunkManager( build_index_info->storage_config); - auto file_manager = milvus::storage::CreateFileManager( - index_info.index_type, field_meta, index_meta, chunk_manager); - AssertInfo(file_manager != nullptr, "create file manager failed!"); + + milvus::storage::FileManagerContext fileManagerContext( + field_meta, index_meta, chunk_manager); auto index = milvus::indexbuilder::IndexFactory::GetInstance().CreateIndex( - build_index_info->field_type, config, file_manager); - + build_index_info->field_type, config, fileManagerContext); index->Build(); *res_index = index.release(); auto status = CStatus(); @@ -386,12 +398,15 @@ NewBuildIndexInfo(CBuildIndexInfo* c_build_index_info, storage_config.root_path = std::string(c_storage_config.root_path); storage_config.storage_type = std::string(c_storage_config.storage_type); + storage_config.cloud_provider = + std::string(c_storage_config.cloud_provider); storage_config.iam_endpoint = std::string(c_storage_config.iam_endpoint); storage_config.useSSL = c_storage_config.useSSL; storage_config.useIAM = c_storage_config.useIAM; storage_config.region = c_storage_config.region; storage_config.useVirtualHost = c_storage_config.useVirtualHost; + storage_config.requestTimeoutMs = c_storage_config.requestTimeoutMs; *c_build_index_info = build_index_info.release(); auto status = CStatus(); @@ -563,10 +578,23 @@ AppendInsertFilePath(CBuildIndexInfo c_build_index_info, status.error_msg = ""; return status; } catch (std::exception& e) 
{ + return milvus::FailureCStatus(&e); + } +} + +CStatus +AppendIndexEngineVersionToBuildInfo(CBuildIndexInfo c_load_index_info, + int32_t index_engine_version) { + try { + auto build_index_info = (BuildIndexInfo*)c_load_index_info; + build_index_info->index_engine_version = index_engine_version; + auto status = CStatus(); - status.error_code = UnexpectedError; - status.error_msg = strdup(e.what()); + status.error_code = Success; + status.error_msg = ""; return status; + } catch (std::exception& e) { + return milvus::FailureCStatus(&e); } } diff --git a/internal/core/src/indexbuilder/index_c.h b/internal/core/src/indexbuilder/index_c.h index 623c52948c5d8..bd7fce66d91b9 100644 --- a/internal/core/src/indexbuilder/index_c.h +++ b/internal/core/src/indexbuilder/index_c.h @@ -88,6 +88,10 @@ AppendIndexMetaInfo(CBuildIndexInfo c_build_index_info, CStatus AppendInsertFilePath(CBuildIndexInfo c_build_index_info, const char* file_path); +CStatus +AppendIndexEngineVersionToBuildInfo(CBuildIndexInfo c_load_index_info, + int32_t c_index_engine_version); + CStatus CreateIndexV2(CIndex* res_index, CBuildIndexInfo c_build_index_info); diff --git a/internal/core/src/indexbuilder/types.h b/internal/core/src/indexbuilder/types.h index 725fd15fe20f2..5f5ce89fbae95 100644 --- a/internal/core/src/indexbuilder/types.h +++ b/internal/core/src/indexbuilder/types.h @@ -38,4 +38,5 @@ struct BuildIndexInfo { int64_t data_store_version; std::string index_store_path; int64_t dim; -}; \ No newline at end of file + int32_t index_engine_version; +}; diff --git a/internal/core/src/log/Log.cpp b/internal/core/src/log/Log.cpp index 7598284f8aabb..4ec482e67696e 100644 --- a/internal/core/src/log/Log.cpp +++ b/internal/core/src/log/Log.cpp @@ -47,7 +47,7 @@ LogOut(const char* pattern, ...) 
{ vsnprintf(str_p.get(), len, pattern, vl); // NOLINT va_end(vl); - return std::string(str_p.get()); + return {str_p.get()}; } void @@ -83,18 +83,6 @@ get_now_timestamp() { #ifndef WIN32 -int64_t -get_system_boottime() { - FILE* uptime = fopen("/proc/uptime", "r"); - float since_sys_boot, _; - auto ret = fscanf(uptime, "%f %f", &since_sys_boot, &_); - fclose(uptime); - if (ret != 2) { - throw std::runtime_error("read /proc/uptime failed."); - } - return static_cast(since_sys_boot); -} - int64_t get_thread_starttime() { #ifdef __APPLE__ @@ -133,34 +121,11 @@ get_thread_starttime() { return val / sysconf(_SC_CLK_TCK); } -int64_t -get_thread_start_timestamp() { - try { - return get_now_timestamp() - get_system_boottime() + - get_thread_starttime(); - } catch (...) { - return 0; - } -} - #else #define WINDOWS_TICK 10000000 #define SEC_TO_UNIX_EPOCH 11644473600LL -int64_t -get_thread_start_timestamp() { - FILETIME dummy; - FILETIME ret; - - if (GetThreadTimes(GetCurrentThread(), &ret, &dummy, &dummy, &dummy)) { - auto ticks = Int64ShllMod32(ret.dwHighDateTime, 32) | ret.dwLowDateTime; - auto thread_started = ticks / WINDOWS_TICK - SEC_TO_UNIX_EPOCH; - return get_now_timestamp() - thread_started; - } - return 0; -} - #endif // } // namespace milvus diff --git a/internal/core/src/log/Log.h b/internal/core/src/log/Log.h index 6d217f8308e35..171c542264e06 100644 --- a/internal/core/src/log/Log.h +++ b/internal/core/src/log/Log.h @@ -37,7 +37,6 @@ #define VAR_CLIENT_TAG (context->client_tag()) #define VAR_CLIENT_IPPORT (context->client_ipport()) #define VAR_THREAD_ID (gettid()) -#define VAR_THREAD_START_TIMESTAMP (get_thread_start_timestamp()) #define VAR_COMMAND_TAG (context->command_tag()) ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -93,7 +92,4 @@ SetThreadName(const std::string_view name); std::string GetThreadName(); -int64_t -get_thread_start_timestamp(); - // } // namespace milvus diff --git 
a/internal/core/src/mmap/Column.h b/internal/core/src/mmap/Column.h index bd1e8d5f0b378..7c58e1fe05370 100644 --- a/internal/core/src/mmap/Column.h +++ b/internal/core/src/mmap/Column.h @@ -23,12 +23,13 @@ #include "common/FieldMeta.h" #include "common/Span.h" -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" +#include "common/File.h" #include "fmt/format.h" #include "log/Log.h" #include "mmap/Utils.h" #include "storage/FieldData.h" -#include "utils/File.h" +#include "common/Array.h" namespace milvus { @@ -79,6 +80,32 @@ class ColumnBase { mmap_flags, file.Descriptor(), 0)); + AssertInfo(data_ != MAP_FAILED, + fmt::format("failed to create file-backed map, err: {}", + strerror(errno))); + madvise(data_, cap_size_ + padding_, MADV_WILLNEED); + } + + // mmap mode ctor + ColumnBase(const File& file, + size_t size, + int dim, + const DataType& data_type) + : type_size_(datatype_sizeof(data_type, dim)), + num_rows_(size / datatype_sizeof(data_type, dim)), + size_(size), + cap_size_(size) { + padding_ = data_type == DataType::JSON ? simdjson::SIMDJSON_PADDING : 0; + + data_ = static_cast(mmap(nullptr, + cap_size_ + padding_, + PROT_READ, + mmap_flags, + file.Descriptor(), + 0)); + AssertInfo(data_ != MAP_FAILED, + fmt::format("failed to create file-backed map, err: {}", + strerror(errno))); } virtual ~ColumnBase() { @@ -115,6 +142,11 @@ class ColumnBase { return num_rows_; }; + const size_t + ByteSize() const { + return cap_size_ + padding_; + } + // The capacity of the column, // DO NOT call this for variable length column. 
size_t @@ -203,6 +235,11 @@ class Column : public ColumnBase { : ColumnBase(file, size, field_meta) { } + // mmap mode ctor + Column(const File& file, size_t size, int dim, DataType data_type) + : ColumnBase(file, size, dim, data_type) { + } + Column(Column&& column) noexcept : ColumnBase(std::move(column)) { } @@ -292,4 +329,90 @@ class VariableColumn : public ColumnBase { // Compatible with current Span type std::vector views_{}; }; + +class ArrayColumn : public ColumnBase { + public: + // memory mode ctor + ArrayColumn(size_t num_rows, const FieldMeta& field_meta) + : ColumnBase(num_rows, field_meta), + element_type_(field_meta.get_element_type()) { + } + + // mmap mode ctor + ArrayColumn(const File& file, size_t size, const FieldMeta& field_meta) + : ColumnBase(file, size, field_meta), + element_type_(field_meta.get_element_type()) { + } + + ArrayColumn(ArrayColumn&& column) noexcept + : ColumnBase(std::move(column)), + indices_(std::move(column.indices_)), + views_(std::move(column.views_)), + element_type_(column.element_type_) { + } + + ~ArrayColumn() override = default; + + SpanBase + Span() const override { + return SpanBase(views_.data(), views_.size(), sizeof(ArrayView)); + } + + [[nodiscard]] const std::vector& + Views() const { + return views_; + } + + ArrayView + operator[](const int i) const { + return views_[i]; + } + + ScalarArray + RawAt(const int i) const { + return views_[i].output_data(); + } + + void + Append(const Array& array) { + indices_.emplace_back(size_); + element_indices_.emplace_back(array.get_offsets()); + ColumnBase::Append(static_cast(array.data()), + array.byte_size()); + } + + void + Seal(std::vector&& indices = {}, + std::vector>&& element_indices = {}) { + if (!indices.empty()) { + indices_ = std::move(indices); + element_indices_ = std::move(element_indices); + } + ConstructViews(); + } + + protected: + void + ConstructViews() { + views_.reserve(indices_.size()); + for (size_t i = 0; i < indices_.size() - 1; i++) { + 
views_.emplace_back(data_ + indices_[i], + indices_[i + 1] - indices_[i], + element_type_, + std::move(element_indices_[i])); + } + views_.emplace_back(data_ + indices_.back(), + size_ - indices_.back(), + element_type_, + std::move(element_indices_[indices_.size() - 1])); + element_indices_.clear(); + } + + private: + std::vector indices_{}; + std::vector> element_indices_{}; + // Compatible with current Span type + std::vector views_{}; + DataType element_type_; +}; } // namespace milvus diff --git a/internal/core/src/mmap/Utils.h b/internal/core/src/mmap/Utils.h index 3f01fe7b4d0d6..e3b718e766a3f 100644 --- a/internal/core/src/mmap/Utils.h +++ b/internal/core/src/mmap/Utils.h @@ -27,7 +27,7 @@ #include "common/FieldMeta.h" #include "mmap/Types.h" #include "storage/Util.h" -#include "utils/File.h" +#include "common/File.h" namespace milvus { @@ -66,8 +66,8 @@ FillField(DataType data_type, const storage::FieldDataPtr data, void* dst) { break; } default: - PanicInfo(fmt::format("not supported data type {}", - datatype_name(data_type))); + PanicInfo(DataTypeInvalid, + fmt::format("not supported data type {}", data_type)); } } else { memcpy(dst, data->Data(), data->Size()); @@ -80,7 +80,8 @@ FillField(DataType data_type, const storage::FieldDataPtr data, void* dst) { inline size_t WriteFieldData(File& file, DataType data_type, - const storage::FieldDataPtr& data) { + const storage::FieldDataPtr& data, + std::vector>& element_indices) { size_t total_written{0}; if (datatype_is_variable(data_type)) { switch (data_type) { @@ -110,8 +111,22 @@ WriteFieldData(File& file, } break; } + case DataType::ARRAY: { + for (size_t i = 0; i < data->get_num_rows(); ++i) { + auto array = static_cast(data->RawValue(i)); + ssize_t written = + file.Write(array->data(), array->byte_size()); + if (written < array->byte_size()) { + break; + } + element_indices.emplace_back(array->get_offsets()); + total_written += written; + } + break; + } default: - PanicInfo(fmt::format("not supported 
data type {}", + PanicInfo(DataTypeInvalid, + fmt::format("not supported data type {}", datatype_name(data_type))); } } else { diff --git a/internal/core/src/pb/common.pb.cc b/internal/core/src/pb/common.pb.cc index d71ef09b617c4..e994af71fbcfd 100644 --- a/internal/core/src/pb/common.pb.cc +++ b/internal/core/src/pb/common.pb.cc @@ -121,9 +121,22 @@ struct AddressDefaultTypeInternal { }; }; PROTOBUF_ATTRIBUTE_NO_DESTROY PROTOBUF_CONSTINIT PROTOBUF_ATTRIBUTE_INIT_PRIORITY1 AddressDefaultTypeInternal _Address_default_instance_; +PROTOBUF_CONSTEXPR MsgBase_PropertiesEntry_DoNotUse::MsgBase_PropertiesEntry_DoNotUse( + ::_pbi::ConstantInitialized) {} +struct MsgBase_PropertiesEntry_DoNotUseDefaultTypeInternal { + PROTOBUF_CONSTEXPR MsgBase_PropertiesEntry_DoNotUseDefaultTypeInternal() + : _instance(::_pbi::ConstantInitialized{}) {} + ~MsgBase_PropertiesEntry_DoNotUseDefaultTypeInternal() {} + union { + MsgBase_PropertiesEntry_DoNotUse _instance; + }; +}; +PROTOBUF_ATTRIBUTE_NO_DESTROY PROTOBUF_CONSTINIT PROTOBUF_ATTRIBUTE_INIT_PRIORITY1 MsgBase_PropertiesEntry_DoNotUseDefaultTypeInternal _MsgBase_PropertiesEntry_DoNotUse_default_instance_; PROTOBUF_CONSTEXPR MsgBase::MsgBase( ::_pbi::ConstantInitialized): _impl_{ - /*decltype(_impl_.msgid_)*/int64_t{0} + /*decltype(_impl_.properties_)*/{::_pbi::ConstantInitialized()} + , /*decltype(_impl_.replicateinfo_)*/nullptr + , /*decltype(_impl_.msgid_)*/int64_t{0} , /*decltype(_impl_.timestamp_)*/uint64_t{0u} , /*decltype(_impl_.sourceid_)*/int64_t{0} , /*decltype(_impl_.targetid_)*/int64_t{0} @@ -138,6 +151,20 @@ struct MsgBaseDefaultTypeInternal { }; }; PROTOBUF_ATTRIBUTE_NO_DESTROY PROTOBUF_CONSTINIT PROTOBUF_ATTRIBUTE_INIT_PRIORITY1 MsgBaseDefaultTypeInternal _MsgBase_default_instance_; +PROTOBUF_CONSTEXPR ReplicateInfo::ReplicateInfo( + ::_pbi::ConstantInitialized): _impl_{ + /*decltype(_impl_.msgtimestamp_)*/uint64_t{0u} + , /*decltype(_impl_.isreplicate_)*/false + , /*decltype(_impl_._cached_size_)*/{}} {} +struct 
ReplicateInfoDefaultTypeInternal { + PROTOBUF_CONSTEXPR ReplicateInfoDefaultTypeInternal() + : _instance(::_pbi::ConstantInitialized{}) {} + ~ReplicateInfoDefaultTypeInternal() {} + union { + ReplicateInfo _instance; + }; +}; +PROTOBUF_ATTRIBUTE_NO_DESTROY PROTOBUF_CONSTINIT PROTOBUF_ATTRIBUTE_INIT_PRIORITY1 ReplicateInfoDefaultTypeInternal _ReplicateInfo_default_instance_; PROTOBUF_CONSTEXPR MsgHeader::MsgHeader( ::_pbi::ConstantInitialized): _impl_{ /*decltype(_impl_.base_)*/nullptr @@ -256,7 +283,7 @@ PROTOBUF_ATTRIBUTE_NO_DESTROY PROTOBUF_CONSTINIT PROTOBUF_ATTRIBUTE_INIT_PRIORIT } // namespace common } // namespace proto } // namespace milvus -static ::_pb::Metadata file_level_metadata_common_2eproto[16]; +static ::_pb::Metadata file_level_metadata_common_2eproto[18]; static const ::_pb::EnumDescriptor* file_level_enum_descriptors_common_2eproto[13]; static constexpr ::_pb::ServiceDescriptor const** file_level_service_descriptors_common_2eproto = nullptr; @@ -317,6 +344,16 @@ const uint32_t TableStruct_common_2eproto::offsets[] PROTOBUF_SECTION_VARIABLE(p ~0u, // no _inlined_string_donated_ PROTOBUF_FIELD_OFFSET(::milvus::proto::common::Address, _impl_.ip_), PROTOBUF_FIELD_OFFSET(::milvus::proto::common::Address, _impl_.port_), + PROTOBUF_FIELD_OFFSET(::milvus::proto::common::MsgBase_PropertiesEntry_DoNotUse, _has_bits_), + PROTOBUF_FIELD_OFFSET(::milvus::proto::common::MsgBase_PropertiesEntry_DoNotUse, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + ~0u, // no _inlined_string_donated_ + PROTOBUF_FIELD_OFFSET(::milvus::proto::common::MsgBase_PropertiesEntry_DoNotUse, key_), + PROTOBUF_FIELD_OFFSET(::milvus::proto::common::MsgBase_PropertiesEntry_DoNotUse, value_), + 0, + 1, ~0u, // no _has_bits_ PROTOBUF_FIELD_OFFSET(::milvus::proto::common::MsgBase, _internal_metadata_), ~0u, // no _extensions_ @@ -328,6 +365,16 @@ const uint32_t TableStruct_common_2eproto::offsets[] PROTOBUF_SECTION_VARIABLE(p 
PROTOBUF_FIELD_OFFSET(::milvus::proto::common::MsgBase, _impl_.timestamp_), PROTOBUF_FIELD_OFFSET(::milvus::proto::common::MsgBase, _impl_.sourceid_), PROTOBUF_FIELD_OFFSET(::milvus::proto::common::MsgBase, _impl_.targetid_), + PROTOBUF_FIELD_OFFSET(::milvus::proto::common::MsgBase, _impl_.properties_), + PROTOBUF_FIELD_OFFSET(::milvus::proto::common::MsgBase, _impl_.replicateinfo_), + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::milvus::proto::common::ReplicateInfo, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + ~0u, // no _inlined_string_donated_ + PROTOBUF_FIELD_OFFSET(::milvus::proto::common::ReplicateInfo, _impl_.isreplicate_), + PROTOBUF_FIELD_OFFSET(::milvus::proto::common::ReplicateInfo, _impl_.msgtimestamp_), ~0u, // no _has_bits_ PROTOBUF_FIELD_OFFSET(::milvus::proto::common::MsgHeader, _internal_metadata_), ~0u, // no _extensions_ @@ -414,15 +461,17 @@ static const ::_pbi::MigrationSchema schemas[] PROTOBUF_SECTION_VARIABLE(protode { 32, -1, -1, sizeof(::milvus::proto::common::PlaceholderValue)}, { 41, -1, -1, sizeof(::milvus::proto::common::PlaceholderGroup)}, { 48, -1, -1, sizeof(::milvus::proto::common::Address)}, - { 56, -1, -1, sizeof(::milvus::proto::common::MsgBase)}, - { 67, -1, -1, sizeof(::milvus::proto::common::MsgHeader)}, - { 74, -1, -1, sizeof(::milvus::proto::common::DMLMsgHeader)}, - { 82, -1, -1, sizeof(::milvus::proto::common::PrivilegeExt)}, - { 92, -1, -1, sizeof(::milvus::proto::common::SegmentStats)}, - { 100, 108, -1, sizeof(::milvus::proto::common::ClientInfo_ReservedEntry_DoNotUse)}, - { 110, -1, -1, sizeof(::milvus::proto::common::ClientInfo)}, - { 122, 130, -1, sizeof(::milvus::proto::common::ServerInfo_ReservedEntry_DoNotUse)}, - { 132, -1, -1, sizeof(::milvus::proto::common::ServerInfo)}, + { 56, 64, -1, sizeof(::milvus::proto::common::MsgBase_PropertiesEntry_DoNotUse)}, + { 66, -1, -1, sizeof(::milvus::proto::common::MsgBase)}, + { 79, -1, -1, 
sizeof(::milvus::proto::common::ReplicateInfo)}, + { 87, -1, -1, sizeof(::milvus::proto::common::MsgHeader)}, + { 94, -1, -1, sizeof(::milvus::proto::common::DMLMsgHeader)}, + { 102, -1, -1, sizeof(::milvus::proto::common::PrivilegeExt)}, + { 112, -1, -1, sizeof(::milvus::proto::common::SegmentStats)}, + { 120, 128, -1, sizeof(::milvus::proto::common::ClientInfo_ReservedEntry_DoNotUse)}, + { 130, -1, -1, sizeof(::milvus::proto::common::ClientInfo)}, + { 142, 150, -1, sizeof(::milvus::proto::common::ServerInfo_ReservedEntry_DoNotUse)}, + { 152, -1, -1, sizeof(::milvus::proto::common::ServerInfo)}, }; static const ::_pb::Message* const file_default_instances[] = { @@ -433,7 +482,9 @@ static const ::_pb::Message* const file_default_instances[] = { &::milvus::proto::common::_PlaceholderValue_default_instance_._instance, &::milvus::proto::common::_PlaceholderGroup_default_instance_._instance, &::milvus::proto::common::_Address_default_instance_._instance, + &::milvus::proto::common::_MsgBase_PropertiesEntry_DoNotUse_default_instance_._instance, &::milvus::proto::common::_MsgBase_default_instance_._instance, + &::milvus::proto::common::_ReplicateInfo_default_instance_._instance, &::milvus::proto::common::_MsgHeader_default_instance_._instance, &::milvus::proto::common::_DMLMsgHeader_default_instance_._instance, &::milvus::proto::common::_PrivilegeExt_default_instance_._instance, @@ -457,181 +508,187 @@ const char descriptor_table_protodef_common_2eproto[] PROTOBUF_SECTION_VARIABLE( "alues\030\003 \003(\014\"O\n\020PlaceholderGroup\022;\n\014place" "holders\030\001 \003(\0132%.milvus.proto.common.Plac" "eholderValue\"#\n\007Address\022\n\n\002ip\030\001 \001(\t\022\014\n\004p" - "ort\030\002 \001(\003\"\177\n\007MsgBase\022.\n\010msg_type\030\001 \001(\0162\034" - ".milvus.proto.common.MsgType\022\r\n\005msgID\030\002 " - "\001(\003\022\021\n\ttimestamp\030\003 \001(\004\022\020\n\010sourceID\030\004 \001(\003" - "\022\020\n\010targetID\030\005 
\001(\003\"7\n\tMsgHeader\022*\n\004base\030" - "\001 \001(\0132\034.milvus.proto.common.MsgBase\"M\n\014D" - "MLMsgHeader\022*\n\004base\030\001 \001(\0132\034.milvus.proto" - ".common.MsgBase\022\021\n\tshardName\030\002 \001(\t\"\273\001\n\014P" - "rivilegeExt\0224\n\013object_type\030\001 \001(\0162\037.milvu" - "s.proto.common.ObjectType\022>\n\020object_priv" - "ilege\030\002 \001(\0162$.milvus.proto.common.Object" - "Privilege\022\031\n\021object_name_index\030\003 \001(\005\022\032\n\022" - "object_name_indexs\030\004 \001(\005\"2\n\014SegmentStats" - "\022\021\n\tSegmentID\030\001 \001(\003\022\017\n\007NumRows\030\002 \001(\003\"\325\001\n" - "\nClientInfo\022\020\n\010sdk_type\030\001 \001(\t\022\023\n\013sdk_ver" - "sion\030\002 \001(\t\022\022\n\nlocal_time\030\003 \001(\t\022\014\n\004user\030\004" - " \001(\t\022\014\n\004host\030\005 \001(\t\022\?\n\010reserved\030\006 \003(\0132-.m" - "ilvus.proto.common.ClientInfo.ReservedEn" - "try\032/\n\rReservedEntry\022\013\n\003key\030\001 \001(\t\022\r\n\005val" - "ue\030\002 \001(\t:\0028\001\"\343\001\n\nServerInfo\022\022\n\nbuild_tag" - "s\030\001 \001(\t\022\022\n\nbuild_time\030\002 \001(\t\022\022\n\ngit_commi" - "t\030\003 \001(\t\022\022\n\ngo_version\030\004 \001(\t\022\023\n\013deploy_mo" - "de\030\005 \001(\t\022\?\n\010reserved\030\006 \003(\0132-.milvus.prot" - "o.common.ServerInfo.ReservedEntry\032/\n\rRes" - "ervedEntry\022\013\n\003key\030\001 \001(\t\022\r\n\005value\030\002 \001(\t:\002" - "8\001*\303\n\n\tErrorCode\022\013\n\007Success\020\000\022\023\n\017Unexpec" - "tedError\020\001\022\021\n\rConnectFailed\020\002\022\024\n\020Permiss" - "ionDenied\020\003\022\027\n\023CollectionNotExists\020\004\022\023\n\017" - "IllegalArgument\020\005\022\024\n\020IllegalDimension\020\007\022" - "\024\n\020IllegalIndexType\020\010\022\031\n\025IllegalCollecti" - "onName\020\t\022\017\n\013IllegalTOPK\020\n\022\024\n\020IllegalRowR" - 
"ecord\020\013\022\023\n\017IllegalVectorID\020\014\022\027\n\023IllegalS" - "earchResult\020\r\022\020\n\014FileNotFound\020\016\022\016\n\nMetaF" - "ailed\020\017\022\017\n\013CacheFailed\020\020\022\026\n\022CannotCreate" - "Folder\020\021\022\024\n\020CannotCreateFile\020\022\022\026\n\022Cannot" - "DeleteFolder\020\023\022\024\n\020CannotDeleteFile\020\024\022\023\n\017" - "BuildIndexError\020\025\022\020\n\014IllegalNLIST\020\026\022\025\n\021I" - "llegalMetricType\020\027\022\017\n\013OutOfMemory\020\030\022\021\n\rI" - "ndexNotExist\020\031\022\023\n\017EmptyCollection\020\032\022\033\n\027U" - "pdateImportTaskFailure\020\033\022\032\n\026CollectionNa" - "meNotFound\020\034\022\033\n\027CreateCredentialFailure\020" - "\035\022\033\n\027UpdateCredentialFailure\020\036\022\033\n\027Delete" - "CredentialFailure\020\037\022\030\n\024GetCredentialFail" - "ure\020 \022\030\n\024ListCredUsersFailure\020!\022\022\n\016GetUs" - "erFailure\020\"\022\025\n\021CreateRoleFailure\020#\022\023\n\017Dr" - "opRoleFailure\020$\022\032\n\026OperateUserRoleFailur" - "e\020%\022\025\n\021SelectRoleFailure\020&\022\025\n\021SelectUser" - "Failure\020\'\022\031\n\025SelectResourceFailure\020(\022\033\n\027" - "OperatePrivilegeFailure\020)\022\026\n\022SelectGrant" - "Failure\020*\022!\n\035RefreshPolicyInfoCacheFailu" - "re\020+\022\025\n\021ListPolicyFailure\020,\022\022\n\016NotShardL" - "eader\020-\022\026\n\022NoReplicaAvailable\020.\022\023\n\017Segme" - "ntNotFound\020/\022\r\n\tForceDeny\0200\022\r\n\tRateLimit" - "\0201\022\022\n\016NodeIDNotMatch\0202\022\024\n\020UpsertAutoIDTr" - "ue\0203\022\034\n\030InsufficientMemoryToLoad\0204\022\030\n\024Me" - "moryQuotaExhausted\0205\022\026\n\022DiskQuotaExhaust" - "ed\0206\022\025\n\021TimeTickLongDelay\0207\022\021\n\rNotReadyS" - "erve\0208\022\033\n\027NotReadyCoordActivating\0209\022\017\n\013D" - "ataCoordNA\020d\022\022\n\rDDRequestRace\020\350\007*c\n\nInde" - 
"xState\022\022\n\016IndexStateNone\020\000\022\014\n\010Unissued\020\001" - "\022\016\n\nInProgress\020\002\022\014\n\010Finished\020\003\022\n\n\006Failed" - "\020\004\022\t\n\005Retry\020\005*\202\001\n\014SegmentState\022\024\n\020Segmen" - "tStateNone\020\000\022\014\n\010NotExist\020\001\022\013\n\007Growing\020\002\022" - "\n\n\006Sealed\020\003\022\013\n\007Flushed\020\004\022\014\n\010Flushing\020\005\022\013" - "\n\007Dropped\020\006\022\r\n\tImporting\020\007*i\n\017Placeholde" - "rType\022\010\n\004None\020\000\022\020\n\014BinaryVector\020d\022\017\n\013Flo" - "atVector\020e\022\021\n\rFloat16Vector\020f\022\t\n\005Int64\020\005" - "\022\013\n\007VarChar\020\025*\264\020\n\007MsgType\022\r\n\tUndefined\020\000" - "\022\024\n\020CreateCollection\020d\022\022\n\016DropCollection" - "\020e\022\021\n\rHasCollection\020f\022\026\n\022DescribeCollect" - "ion\020g\022\023\n\017ShowCollections\020h\022\024\n\020GetSystemC" - "onfigs\020i\022\022\n\016LoadCollection\020j\022\025\n\021ReleaseC" - "ollection\020k\022\017\n\013CreateAlias\020l\022\r\n\tDropAlia" - "s\020m\022\016\n\nAlterAlias\020n\022\023\n\017AlterCollection\020o" - "\022\024\n\020RenameCollection\020p\022\021\n\rDescribeAlias\020" - "q\022\017\n\013ListAliases\020r\022\024\n\017CreatePartition\020\310\001" - "\022\022\n\rDropPartition\020\311\001\022\021\n\014HasPartition\020\312\001\022" - "\026\n\021DescribePartition\020\313\001\022\023\n\016ShowPartition" - "s\020\314\001\022\023\n\016LoadPartitions\020\315\001\022\026\n\021ReleasePart" - "itions\020\316\001\022\021\n\014ShowSegments\020\372\001\022\024\n\017Describe" - "Segment\020\373\001\022\021\n\014LoadSegments\020\374\001\022\024\n\017Release" - "Segments\020\375\001\022\024\n\017HandoffSegments\020\376\001\022\030\n\023Loa" - "dBalanceSegments\020\377\001\022\025\n\020DescribeSegments\020" - "\200\002\022\034\n\027FederListIndexedSegment\020\201\002\022\"\n\035Fede" - "rDescribeSegmentIndexData\020\202\002\022\020\n\013CreateIn" - 
"dex\020\254\002\022\022\n\rDescribeIndex\020\255\002\022\016\n\tDropIndex\020" - "\256\002\022\027\n\022GetIndexStatistics\020\257\002\022\013\n\006Insert\020\220\003" - "\022\013\n\006Delete\020\221\003\022\n\n\005Flush\020\222\003\022\027\n\022ResendSegme" - "ntStats\020\223\003\022\013\n\006Upsert\020\224\003\022\013\n\006Search\020\364\003\022\021\n\014" - "SearchResult\020\365\003\022\022\n\rGetIndexState\020\366\003\022\032\n\025G" - "etIndexBuildProgress\020\367\003\022\034\n\027GetCollection" - "Statistics\020\370\003\022\033\n\026GetPartitionStatistics\020" - "\371\003\022\r\n\010Retrieve\020\372\003\022\023\n\016RetrieveResult\020\373\003\022\024" - "\n\017WatchDmChannels\020\374\003\022\025\n\020RemoveDmChannels" - "\020\375\003\022\027\n\022WatchQueryChannels\020\376\003\022\030\n\023RemoveQu" - "eryChannels\020\377\003\022\035\n\030SealedSegmentsChangeIn" - "fo\020\200\004\022\027\n\022WatchDeltaChannels\020\201\004\022\024\n\017GetSha" - "rdLeaders\020\202\004\022\020\n\013GetReplicas\020\203\004\022\023\n\016UnsubD" - "mChannel\020\204\004\022\024\n\017GetDistribution\020\205\004\022\025\n\020Syn" - "cDistribution\020\206\004\022\020\n\013SegmentInfo\020\330\004\022\017\n\nSy" - "stemInfo\020\331\004\022\024\n\017GetRecoveryInfo\020\332\004\022\024\n\017Get" - "SegmentState\020\333\004\022\r\n\010TimeTick\020\260\t\022\023\n\016QueryN" - "odeStats\020\261\t\022\016\n\tLoadIndex\020\262\t\022\016\n\tRequestID" - "\020\263\t\022\017\n\nRequestTSO\020\264\t\022\024\n\017AllocateSegment\020" - "\265\t\022\026\n\021SegmentStatistics\020\266\t\022\025\n\020SegmentFlu" - "shDone\020\267\t\022\017\n\nDataNodeTt\020\270\t\022\014\n\007Connect\020\271\t" - "\022\024\n\017ListClientInfos\020\272\t\022\023\n\016AllocTimestamp" - "\020\273\t\022\025\n\020CreateCredential\020\334\013\022\022\n\rGetCredent" - "ial\020\335\013\022\025\n\020DeleteCredential\020\336\013\022\025\n\020UpdateC" - 
"redential\020\337\013\022\026\n\021ListCredUsernames\020\340\013\022\017\n\n" - "CreateRole\020\300\014\022\r\n\010DropRole\020\301\014\022\024\n\017OperateU" - "serRole\020\302\014\022\017\n\nSelectRole\020\303\014\022\017\n\nSelectUse" - "r\020\304\014\022\023\n\016SelectResource\020\305\014\022\025\n\020OperatePriv" - "ilege\020\306\014\022\020\n\013SelectGrant\020\307\014\022\033\n\026RefreshPol" - "icyInfoCache\020\310\014\022\017\n\nListPolicy\020\311\014\022\030\n\023Crea" - "teResourceGroup\020\244\r\022\026\n\021DropResourceGroup\020" - "\245\r\022\027\n\022ListResourceGroups\020\246\r\022\032\n\025DescribeR" - "esourceGroup\020\247\r\022\021\n\014TransferNode\020\250\r\022\024\n\017Tr" - "ansferReplica\020\251\r\022\023\n\016CreateDatabase\020\211\016\022\021\n" - "\014DropDatabase\020\212\016\022\022\n\rListDatabases\020\213\016*\"\n\007" - "DslType\022\007\n\003Dsl\020\000\022\016\n\nBoolExprV1\020\001*B\n\017Comp" - "actionState\022\021\n\rUndefiedState\020\000\022\r\n\tExecut" - "ing\020\001\022\r\n\tCompleted\020\002*X\n\020ConsistencyLevel" - "\022\n\n\006Strong\020\000\022\013\n\007Session\020\001\022\013\n\007Bounded\020\002\022\016" - "\n\nEventually\020\003\022\016\n\nCustomized\020\004*\236\001\n\013Impor" - "tState\022\021\n\rImportPending\020\000\022\020\n\014ImportFaile" - "d\020\001\022\021\n\rImportStarted\020\002\022\023\n\017ImportPersiste" - "d\020\005\022\021\n\rImportFlushed\020\010\022\023\n\017ImportComplete" - "d\020\006\022\032\n\026ImportFailedAndCleaned\020\007*2\n\nObjec" - "tType\022\016\n\nCollection\020\000\022\n\n\006Global\020\001\022\010\n\004Use" - "r\020\002*\241\010\n\017ObjectPrivilege\022\020\n\014PrivilegeAll\020" - "\000\022\035\n\031PrivilegeCreateCollection\020\001\022\033\n\027Priv" - "ilegeDropCollection\020\002\022\037\n\033PrivilegeDescri" - "beCollection\020\003\022\034\n\030PrivilegeShowCollectio" - "ns\020\004\022\021\n\rPrivilegeLoad\020\005\022\024\n\020PrivilegeRele" - 
"ase\020\006\022\027\n\023PrivilegeCompaction\020\007\022\023\n\017Privil" - "egeInsert\020\010\022\023\n\017PrivilegeDelete\020\t\022\032\n\026Priv" - "ilegeGetStatistics\020\n\022\030\n\024PrivilegeCreateI" - "ndex\020\013\022\030\n\024PrivilegeIndexDetail\020\014\022\026\n\022Priv" - "ilegeDropIndex\020\r\022\023\n\017PrivilegeSearch\020\016\022\022\n" - "\016PrivilegeFlush\020\017\022\022\n\016PrivilegeQuery\020\020\022\030\n" - "\024PrivilegeLoadBalance\020\021\022\023\n\017PrivilegeImpo" - "rt\020\022\022\034\n\030PrivilegeCreateOwnership\020\023\022\027\n\023Pr" - "ivilegeUpdateUser\020\024\022\032\n\026PrivilegeDropOwne" - "rship\020\025\022\034\n\030PrivilegeSelectOwnership\020\026\022\034\n" - "\030PrivilegeManageOwnership\020\027\022\027\n\023Privilege" - "SelectUser\020\030\022\023\n\017PrivilegeUpsert\020\031\022 \n\034Pri" - "vilegeCreateResourceGroup\020\032\022\036\n\032Privilege" - "DropResourceGroup\020\033\022\"\n\036PrivilegeDescribe" - "ResourceGroup\020\034\022\037\n\033PrivilegeListResource" - "Groups\020\035\022\031\n\025PrivilegeTransferNode\020\036\022\034\n\030P" - "rivilegeTransferReplica\020\037\022\037\n\033PrivilegeGe" - "tLoadingProgress\020 \022\031\n\025PrivilegeGetLoadSt" - "ate\020!\022\035\n\031PrivilegeRenameCollection\020\"\022\033\n\027" - "PrivilegeCreateDatabase\020#\022\031\n\025PrivilegeDr" - "opDatabase\020$\022\032\n\026PrivilegeListDatabases\020%" - "\022\025\n\021PrivilegeFlushAll\020&*S\n\tStateCode\022\020\n\014" - "Initializing\020\000\022\013\n\007Healthy\020\001\022\014\n\010Abnormal\020" - "\002\022\013\n\007StandBy\020\003\022\014\n\010Stopping\020\004*c\n\tLoadStat" - "e\022\025\n\021LoadStateNotExist\020\000\022\024\n\020LoadStateNot" - "Load\020\001\022\024\n\020LoadStateLoading\020\002\022\023\n\017LoadStat" - "eLoaded\020\003:^\n\021privilege_ext_obj\022\037.google." - "protobuf.MessageOptions\030\351\007 \001(\0132!.milvus." 
- "proto.common.PrivilegeExtBm\n\016io.milvus.g" - "rpcB\013CommonProtoP\001Z4github.com/milvus-io" - "/milvus-proto/go-api/v2/commonpb\240\001\001\252\002\022Mi" - "lvus.Client.Grpcb\006proto3" + "ort\030\002 \001(\003\"\257\002\n\007MsgBase\022.\n\010msg_type\030\001 \001(\0162" + "\034.milvus.proto.common.MsgType\022\r\n\005msgID\030\002" + " \001(\003\022\021\n\ttimestamp\030\003 \001(\004\022\020\n\010sourceID\030\004 \001(" + "\003\022\020\n\010targetID\030\005 \001(\003\022@\n\nproperties\030\006 \003(\0132" + ",.milvus.proto.common.MsgBase.Properties" + "Entry\0229\n\rreplicateInfo\030\007 \001(\0132\".milvus.pr" + "oto.common.ReplicateInfo\0321\n\017PropertiesEn" + "try\022\013\n\003key\030\001 \001(\t\022\r\n\005value\030\002 \001(\t:\0028\001\":\n\rR" + "eplicateInfo\022\023\n\013isReplicate\030\001 \001(\010\022\024\n\014msg" + "Timestamp\030\002 \001(\004\"7\n\tMsgHeader\022*\n\004base\030\001 \001" + "(\0132\034.milvus.proto.common.MsgBase\"M\n\014DMLM" + "sgHeader\022*\n\004base\030\001 \001(\0132\034.milvus.proto.co" + "mmon.MsgBase\022\021\n\tshardName\030\002 \001(\t\"\273\001\n\014Priv" + "ilegeExt\0224\n\013object_type\030\001 \001(\0162\037.milvus.p" + "roto.common.ObjectType\022>\n\020object_privile" + "ge\030\002 \001(\0162$.milvus.proto.common.ObjectPri" + "vilege\022\031\n\021object_name_index\030\003 \001(\005\022\032\n\022obj" + "ect_name_indexs\030\004 \001(\005\"2\n\014SegmentStats\022\021\n" + "\tSegmentID\030\001 \001(\003\022\017\n\007NumRows\030\002 \001(\003\"\325\001\n\nCl" + "ientInfo\022\020\n\010sdk_type\030\001 \001(\t\022\023\n\013sdk_versio" + "n\030\002 \001(\t\022\022\n\nlocal_time\030\003 \001(\t\022\014\n\004user\030\004 \001(" + "\t\022\014\n\004host\030\005 \001(\t\022\?\n\010reserved\030\006 \003(\0132-.milv" + "us.proto.common.ClientInfo.ReservedEntry" + "\032/\n\rReservedEntry\022\013\n\003key\030\001 \001(\t\022\r\n\005value\030" + "\002 \001(\t:\0028\001\"\343\001\n\nServerInfo\022\022\n\nbuild_tags\030\001" + " 
\001(\t\022\022\n\nbuild_time\030\002 \001(\t\022\022\n\ngit_commit\030\003" + " \001(\t\022\022\n\ngo_version\030\004 \001(\t\022\023\n\013deploy_mode\030" + "\005 \001(\t\022\?\n\010reserved\030\006 \003(\0132-.milvus.proto.c" + "ommon.ServerInfo.ReservedEntry\032/\n\rReserv" + "edEntry\022\013\n\003key\030\001 \001(\t\022\r\n\005value\030\002 \001(\t:\0028\001*" + "\303\n\n\tErrorCode\022\013\n\007Success\020\000\022\023\n\017Unexpected" + "Error\020\001\022\021\n\rConnectFailed\020\002\022\024\n\020Permission" + "Denied\020\003\022\027\n\023CollectionNotExists\020\004\022\023\n\017Ill" + "egalArgument\020\005\022\024\n\020IllegalDimension\020\007\022\024\n\020" + "IllegalIndexType\020\010\022\031\n\025IllegalCollectionN" + "ame\020\t\022\017\n\013IllegalTOPK\020\n\022\024\n\020IllegalRowReco" + "rd\020\013\022\023\n\017IllegalVectorID\020\014\022\027\n\023IllegalSear" + "chResult\020\r\022\020\n\014FileNotFound\020\016\022\016\n\nMetaFail" + "ed\020\017\022\017\n\013CacheFailed\020\020\022\026\n\022CannotCreateFol" + "der\020\021\022\024\n\020CannotCreateFile\020\022\022\026\n\022CannotDel" + "eteFolder\020\023\022\024\n\020CannotDeleteFile\020\024\022\023\n\017Bui" + "ldIndexError\020\025\022\020\n\014IllegalNLIST\020\026\022\025\n\021Ille" + "galMetricType\020\027\022\017\n\013OutOfMemory\020\030\022\021\n\rInde" + "xNotExist\020\031\022\023\n\017EmptyCollection\020\032\022\033\n\027Upda" + "teImportTaskFailure\020\033\022\032\n\026CollectionNameN" + "otFound\020\034\022\033\n\027CreateCredentialFailure\020\035\022\033" + "\n\027UpdateCredentialFailure\020\036\022\033\n\027DeleteCre" + "dentialFailure\020\037\022\030\n\024GetCredentialFailure" + "\020 \022\030\n\024ListCredUsersFailure\020!\022\022\n\016GetUserF" + "ailure\020\"\022\025\n\021CreateRoleFailure\020#\022\023\n\017DropR" + "oleFailure\020$\022\032\n\026OperateUserRoleFailure\020%" + "\022\025\n\021SelectRoleFailure\020&\022\025\n\021SelectUserFai" + 
"lure\020\'\022\031\n\025SelectResourceFailure\020(\022\033\n\027Ope" + "ratePrivilegeFailure\020)\022\026\n\022SelectGrantFai" + "lure\020*\022!\n\035RefreshPolicyInfoCacheFailure\020" + "+\022\025\n\021ListPolicyFailure\020,\022\022\n\016NotShardLead" + "er\020-\022\026\n\022NoReplicaAvailable\020.\022\023\n\017SegmentN" + "otFound\020/\022\r\n\tForceDeny\0200\022\r\n\tRateLimit\0201\022" + "\022\n\016NodeIDNotMatch\0202\022\024\n\020UpsertAutoIDTrue\020" + "3\022\034\n\030InsufficientMemoryToLoad\0204\022\030\n\024Memor" + "yQuotaExhausted\0205\022\026\n\022DiskQuotaExhausted\020" + "6\022\025\n\021TimeTickLongDelay\0207\022\021\n\rNotReadyServ" + "e\0208\022\033\n\027NotReadyCoordActivating\0209\022\017\n\013Data" + "CoordNA\020d\022\022\n\rDDRequestRace\020\350\007*c\n\nIndexSt" + "ate\022\022\n\016IndexStateNone\020\000\022\014\n\010Unissued\020\001\022\016\n" + "\nInProgress\020\002\022\014\n\010Finished\020\003\022\n\n\006Failed\020\004\022" + "\t\n\005Retry\020\005*\202\001\n\014SegmentState\022\024\n\020SegmentSt" + "ateNone\020\000\022\014\n\010NotExist\020\001\022\013\n\007Growing\020\002\022\n\n\006" + "Sealed\020\003\022\013\n\007Flushed\020\004\022\014\n\010Flushing\020\005\022\013\n\007D" + "ropped\020\006\022\r\n\tImporting\020\007*i\n\017PlaceholderTy" + "pe\022\010\n\004None\020\000\022\020\n\014BinaryVector\020d\022\017\n\013FloatV" + "ector\020e\022\021\n\rFloat16Vector\020f\022\t\n\005Int64\020\005\022\013\n" + "\007VarChar\020\025*\264\020\n\007MsgType\022\r\n\tUndefined\020\000\022\024\n" + "\020CreateCollection\020d\022\022\n\016DropCollection\020e\022" + "\021\n\rHasCollection\020f\022\026\n\022DescribeCollection" + "\020g\022\023\n\017ShowCollections\020h\022\024\n\020GetSystemConf" + "igs\020i\022\022\n\016LoadCollection\020j\022\025\n\021ReleaseColl" + "ection\020k\022\017\n\013CreateAlias\020l\022\r\n\tDropAlias\020m" + "\022\016\n\nAlterAlias\020n\022\023\n\017AlterCollection\020o\022\024\n" + 
"\020RenameCollection\020p\022\021\n\rDescribeAlias\020q\022\017" + "\n\013ListAliases\020r\022\024\n\017CreatePartition\020\310\001\022\022\n" + "\rDropPartition\020\311\001\022\021\n\014HasPartition\020\312\001\022\026\n\021" + "DescribePartition\020\313\001\022\023\n\016ShowPartitions\020\314" + "\001\022\023\n\016LoadPartitions\020\315\001\022\026\n\021ReleasePartiti" + "ons\020\316\001\022\021\n\014ShowSegments\020\372\001\022\024\n\017DescribeSeg" + "ment\020\373\001\022\021\n\014LoadSegments\020\374\001\022\024\n\017ReleaseSeg" + "ments\020\375\001\022\024\n\017HandoffSegments\020\376\001\022\030\n\023LoadBa" + "lanceSegments\020\377\001\022\025\n\020DescribeSegments\020\200\002\022" + "\034\n\027FederListIndexedSegment\020\201\002\022\"\n\035FederDe" + "scribeSegmentIndexData\020\202\002\022\020\n\013CreateIndex" + "\020\254\002\022\022\n\rDescribeIndex\020\255\002\022\016\n\tDropIndex\020\256\002\022" + "\027\n\022GetIndexStatistics\020\257\002\022\013\n\006Insert\020\220\003\022\013\n" + "\006Delete\020\221\003\022\n\n\005Flush\020\222\003\022\027\n\022ResendSegmentS" + "tats\020\223\003\022\013\n\006Upsert\020\224\003\022\013\n\006Search\020\364\003\022\021\n\014Sea" + "rchResult\020\365\003\022\022\n\rGetIndexState\020\366\003\022\032\n\025GetI" + "ndexBuildProgress\020\367\003\022\034\n\027GetCollectionSta" + "tistics\020\370\003\022\033\n\026GetPartitionStatistics\020\371\003\022" + "\r\n\010Retrieve\020\372\003\022\023\n\016RetrieveResult\020\373\003\022\024\n\017W" + "atchDmChannels\020\374\003\022\025\n\020RemoveDmChannels\020\375\003" + "\022\027\n\022WatchQueryChannels\020\376\003\022\030\n\023RemoveQuery" + "Channels\020\377\003\022\035\n\030SealedSegmentsChangeInfo\020" + "\200\004\022\027\n\022WatchDeltaChannels\020\201\004\022\024\n\017GetShardL" + "eaders\020\202\004\022\020\n\013GetReplicas\020\203\004\022\023\n\016UnsubDmCh" + "annel\020\204\004\022\024\n\017GetDistribution\020\205\004\022\025\n\020SyncDi" + 
"stribution\020\206\004\022\020\n\013SegmentInfo\020\330\004\022\017\n\nSyste" + "mInfo\020\331\004\022\024\n\017GetRecoveryInfo\020\332\004\022\024\n\017GetSeg" + "mentState\020\333\004\022\r\n\010TimeTick\020\260\t\022\023\n\016QueryNode" + "Stats\020\261\t\022\016\n\tLoadIndex\020\262\t\022\016\n\tRequestID\020\263\t" + "\022\017\n\nRequestTSO\020\264\t\022\024\n\017AllocateSegment\020\265\t\022" + "\026\n\021SegmentStatistics\020\266\t\022\025\n\020SegmentFlushD" + "one\020\267\t\022\017\n\nDataNodeTt\020\270\t\022\014\n\007Connect\020\271\t\022\024\n" + "\017ListClientInfos\020\272\t\022\023\n\016AllocTimestamp\020\273\t" + "\022\025\n\020CreateCredential\020\334\013\022\022\n\rGetCredential" + "\020\335\013\022\025\n\020DeleteCredential\020\336\013\022\025\n\020UpdateCred" + "ential\020\337\013\022\026\n\021ListCredUsernames\020\340\013\022\017\n\nCre" + "ateRole\020\300\014\022\r\n\010DropRole\020\301\014\022\024\n\017OperateUser" + "Role\020\302\014\022\017\n\nSelectRole\020\303\014\022\017\n\nSelectUser\020\304" + "\014\022\023\n\016SelectResource\020\305\014\022\025\n\020OperatePrivile" + "ge\020\306\014\022\020\n\013SelectGrant\020\307\014\022\033\n\026RefreshPolicy" + "InfoCache\020\310\014\022\017\n\nListPolicy\020\311\014\022\030\n\023CreateR" + "esourceGroup\020\244\r\022\026\n\021DropResourceGroup\020\245\r\022" + "\027\n\022ListResourceGroups\020\246\r\022\032\n\025DescribeReso" + "urceGroup\020\247\r\022\021\n\014TransferNode\020\250\r\022\024\n\017Trans" + "ferReplica\020\251\r\022\023\n\016CreateDatabase\020\211\016\022\021\n\014Dr" + "opDatabase\020\212\016\022\022\n\rListDatabases\020\213\016*\"\n\007Dsl" + "Type\022\007\n\003Dsl\020\000\022\016\n\nBoolExprV1\020\001*B\n\017Compact" + "ionState\022\021\n\rUndefiedState\020\000\022\r\n\tExecuting" + "\020\001\022\r\n\tCompleted\020\002*X\n\020ConsistencyLevel\022\n\n" + "\006Strong\020\000\022\013\n\007Session\020\001\022\013\n\007Bounded\020\002\022\016\n\nE" + 
"ventually\020\003\022\016\n\nCustomized\020\004*\236\001\n\013ImportSt" + "ate\022\021\n\rImportPending\020\000\022\020\n\014ImportFailed\020\001" + "\022\021\n\rImportStarted\020\002\022\023\n\017ImportPersisted\020\005" + "\022\021\n\rImportFlushed\020\010\022\023\n\017ImportCompleted\020\006" + "\022\032\n\026ImportFailedAndCleaned\020\007*2\n\nObjectTy" + "pe\022\016\n\nCollection\020\000\022\n\n\006Global\020\001\022\010\n\004User\020\002" + "*\241\010\n\017ObjectPrivilege\022\020\n\014PrivilegeAll\020\000\022\035" + "\n\031PrivilegeCreateCollection\020\001\022\033\n\027Privile" + "geDropCollection\020\002\022\037\n\033PrivilegeDescribeC" + "ollection\020\003\022\034\n\030PrivilegeShowCollections\020" + "\004\022\021\n\rPrivilegeLoad\020\005\022\024\n\020PrivilegeRelease" + "\020\006\022\027\n\023PrivilegeCompaction\020\007\022\023\n\017Privilege" + "Insert\020\010\022\023\n\017PrivilegeDelete\020\t\022\032\n\026Privile" + "geGetStatistics\020\n\022\030\n\024PrivilegeCreateInde" + "x\020\013\022\030\n\024PrivilegeIndexDetail\020\014\022\026\n\022Privile" + "geDropIndex\020\r\022\023\n\017PrivilegeSearch\020\016\022\022\n\016Pr" + "ivilegeFlush\020\017\022\022\n\016PrivilegeQuery\020\020\022\030\n\024Pr" + "ivilegeLoadBalance\020\021\022\023\n\017PrivilegeImport\020" + "\022\022\034\n\030PrivilegeCreateOwnership\020\023\022\027\n\023Privi" + "legeUpdateUser\020\024\022\032\n\026PrivilegeDropOwnersh" + "ip\020\025\022\034\n\030PrivilegeSelectOwnership\020\026\022\034\n\030Pr" + "ivilegeManageOwnership\020\027\022\027\n\023PrivilegeSel" + "ectUser\020\030\022\023\n\017PrivilegeUpsert\020\031\022 \n\034Privil" + "egeCreateResourceGroup\020\032\022\036\n\032PrivilegeDro" + "pResourceGroup\020\033\022\"\n\036PrivilegeDescribeRes" + "ourceGroup\020\034\022\037\n\033PrivilegeListResourceGro" + "ups\020\035\022\031\n\025PrivilegeTransferNode\020\036\022\034\n\030Priv" + "ilegeTransferReplica\020\037\022\037\n\033PrivilegeGetLo" + "adingProgress\020 
\022\031\n\025PrivilegeGetLoadState" + "\020!\022\035\n\031PrivilegeRenameCollection\020\"\022\033\n\027Pri" + "vilegeCreateDatabase\020#\022\031\n\025PrivilegeDropD" + "atabase\020$\022\032\n\026PrivilegeListDatabases\020%\022\025\n" + "\021PrivilegeFlushAll\020&*S\n\tStateCode\022\020\n\014Ini" + "tializing\020\000\022\013\n\007Healthy\020\001\022\014\n\010Abnormal\020\002\022\013" + "\n\007StandBy\020\003\022\014\n\010Stopping\020\004*c\n\tLoadState\022\025" + "\n\021LoadStateNotExist\020\000\022\024\n\020LoadStateNotLoa" + "d\020\001\022\024\n\020LoadStateLoading\020\002\022\023\n\017LoadStateLo" + "aded\020\003:^\n\021privilege_ext_obj\022\037.google.pro" + "tobuf.MessageOptions\030\351\007 \001(\0132!.milvus.pro" + "to.common.PrivilegeExtBm\n\016io.milvus.grpc" + "B\013CommonProtoP\001Z4github.com/milvus-io/mi" + "lvus-proto/go-api/v2/commonpb\240\001\001\252\002\022Milvu" + "s.Client.Grpcb\006proto3" ; static const ::_pbi::DescriptorTable* const descriptor_table_common_2eproto_deps[1] = { &::descriptor_table_google_2fprotobuf_2fdescriptor_2eproto, }; static ::_pbi::once_flag descriptor_table_common_2eproto_once; const ::_pbi::DescriptorTable descriptor_table_common_2eproto = { - false, false, 7104, descriptor_table_protodef_common_2eproto, + false, false, 7341, descriptor_table_protodef_common_2eproto, "common.proto", - &descriptor_table_common_2eproto_once, descriptor_table_common_2eproto_deps, 1, 16, + &descriptor_table_common_2eproto_once, descriptor_table_common_2eproto_deps, 1, 18, schemas, file_default_instances, TableStruct_common_2eproto::offsets, file_level_metadata_common_2eproto, file_level_enum_descriptors_common_2eproto, file_level_service_descriptors_common_2eproto, @@ -2698,21 +2755,45 @@ ::PROTOBUF_NAMESPACE_ID::Metadata Address::GetMetadata() const { // =================================================================== +MsgBase_PropertiesEntry_DoNotUse::MsgBase_PropertiesEntry_DoNotUse() {} 
+MsgBase_PropertiesEntry_DoNotUse::MsgBase_PropertiesEntry_DoNotUse(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : SuperType(arena) {} +void MsgBase_PropertiesEntry_DoNotUse::MergeFrom(const MsgBase_PropertiesEntry_DoNotUse& other) { + MergeFromInternal(other); +} +::PROTOBUF_NAMESPACE_ID::Metadata MsgBase_PropertiesEntry_DoNotUse::GetMetadata() const { + return ::_pbi::AssignDescriptors( + &descriptor_table_common_2eproto_getter, &descriptor_table_common_2eproto_once, + file_level_metadata_common_2eproto[7]); +} + +// =================================================================== + class MsgBase::_Internal { public: + static const ::milvus::proto::common::ReplicateInfo& replicateinfo(const MsgBase* msg); }; +const ::milvus::proto::common::ReplicateInfo& +MsgBase::_Internal::replicateinfo(const MsgBase* msg) { + return *msg->_impl_.replicateinfo_; +} MsgBase::MsgBase(::PROTOBUF_NAMESPACE_ID::Arena* arena, bool is_message_owned) : ::PROTOBUF_NAMESPACE_ID::Message(arena, is_message_owned) { SharedCtor(arena, is_message_owned); + if (arena != nullptr && !is_message_owned) { + arena->OwnCustomDestructor(this, &MsgBase::ArenaDtor); + } // @@protoc_insertion_point(arena_constructor:milvus.proto.common.MsgBase) } MsgBase::MsgBase(const MsgBase& from) : ::PROTOBUF_NAMESPACE_ID::Message() { MsgBase* const _this = this; (void)_this; new (&_impl_) Impl_{ - decltype(_impl_.msgid_){} + /*decltype(_impl_.properties_)*/{} + , decltype(_impl_.replicateinfo_){nullptr} + , decltype(_impl_.msgid_){} , decltype(_impl_.timestamp_){} , decltype(_impl_.sourceid_){} , decltype(_impl_.targetid_){} @@ -2720,6 +2801,10 @@ MsgBase::MsgBase(const MsgBase& from) , /*decltype(_impl_._cached_size_)*/{}}; _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + _this->_impl_.properties_.MergeFrom(from._impl_.properties_); + if (from._internal_has_replicateinfo()) { + _this->_impl_.replicateinfo_ = new 
::milvus::proto::common::ReplicateInfo(*from._impl_.replicateinfo_); + } ::memcpy(&_impl_.msgid_, &from._impl_.msgid_, static_cast(reinterpret_cast(&_impl_.msg_type_) - reinterpret_cast(&_impl_.msgid_)) + sizeof(_impl_.msg_type_)); @@ -2731,7 +2816,9 @@ inline void MsgBase::SharedCtor( (void)arena; (void)is_message_owned; new (&_impl_) Impl_{ - decltype(_impl_.msgid_){int64_t{0}} + /*decltype(_impl_.properties_)*/{::_pbi::ArenaInitialized(), arena} + , decltype(_impl_.replicateinfo_){nullptr} + , decltype(_impl_.msgid_){int64_t{0}} , decltype(_impl_.timestamp_){uint64_t{0u}} , decltype(_impl_.sourceid_){int64_t{0}} , decltype(_impl_.targetid_){int64_t{0}} @@ -2744,6 +2831,7 @@ MsgBase::~MsgBase() { // @@protoc_insertion_point(destructor:milvus.proto.common.MsgBase) if (auto *arena = _internal_metadata_.DeleteReturnArena<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>()) { (void)arena; + ArenaDtor(this); return; } SharedDtor(); @@ -2751,8 +2839,15 @@ MsgBase::~MsgBase() { inline void MsgBase::SharedDtor() { GOOGLE_DCHECK(GetArenaForAllocation() == nullptr); + _impl_.properties_.Destruct(); + _impl_.properties_.~MapField(); + if (this != internal_default_instance()) delete _impl_.replicateinfo_; } +void MsgBase::ArenaDtor(void* object) { + MsgBase* _this = reinterpret_cast< MsgBase* >(object); + _this->_impl_.properties_.Destruct(); +} void MsgBase::SetCachedSize(int size) const { _impl_._cached_size_.Set(size); } @@ -2763,6 +2858,11 @@ void MsgBase::Clear() { // Prevent compiler warnings about cached_has_bits being unused (void) cached_has_bits; + _impl_.properties_.Clear(); + if (GetArenaForAllocation() == nullptr && _impl_.replicateinfo_ != nullptr) { + delete _impl_.replicateinfo_; + } + _impl_.replicateinfo_ = nullptr; ::memset(&_impl_.msgid_, 0, static_cast( reinterpret_cast(&_impl_.msg_type_) - reinterpret_cast(&_impl_.msgid_)) + sizeof(_impl_.msg_type_)); @@ -2816,6 +2916,27 @@ const char* MsgBase::_InternalParse(const char* ptr, ::_pbi::ParseContext* ctx) } else 
goto handle_unusual; continue; + // map properties = 6; + case 6: + if (PROTOBUF_PREDICT_TRUE(static_cast(tag) == 50)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(&_impl_.properties_, ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<50>(ptr)); + } else + goto handle_unusual; + continue; + // .milvus.proto.common.ReplicateInfo replicateInfo = 7; + case 7: + if (PROTOBUF_PREDICT_TRUE(static_cast(tag) == 58)) { + ptr = ctx->ParseMessage(_internal_mutable_replicateinfo(), ptr); + CHK_(ptr); + } else + goto handle_unusual; + continue; default: goto handle_unusual; } // switch @@ -2876,6 +2997,43 @@ uint8_t* MsgBase::_InternalSerialize( target = ::_pbi::WireFormatLite::WriteInt64ToArray(5, this->_internal_targetid(), target); } + // map properties = 6; + if (!this->_internal_properties().empty()) { + using MapType = ::_pb::Map; + using WireHelper = MsgBase_PropertiesEntry_DoNotUse::Funcs; + const auto& map_field = this->_internal_properties(); + auto check_utf8 = [](const MapType::value_type& entry) { + (void)entry; + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::VerifyUtf8String( + entry.first.data(), static_cast(entry.first.length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::SERIALIZE, + "milvus.proto.common.MsgBase.PropertiesEntry.key"); + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::VerifyUtf8String( + entry.second.data(), static_cast(entry.second.length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::SERIALIZE, + "milvus.proto.common.MsgBase.PropertiesEntry.value"); + }; + + if (stream->IsSerializationDeterministic() && map_field.size() > 1) { + for (const auto& entry : ::_pbi::MapSorterPtr(map_field)) { + target = WireHelper::InternalSerialize(6, entry.first, entry.second, target, stream); + check_utf8(entry); + } + } else { + for (const auto& entry : map_field) { + target = WireHelper::InternalSerialize(6, entry.first, entry.second, target, stream); + 
check_utf8(entry); + } + } + } + + // .milvus.proto.common.ReplicateInfo replicateInfo = 7; + if (this->_internal_has_replicateinfo()) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(7, _Internal::replicateinfo(this), + _Internal::replicateinfo(this).GetCachedSize(), target, stream); + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { target = ::_pbi::WireFormat::InternalSerializeUnknownFieldsToArray( _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); @@ -2892,6 +3050,22 @@ size_t MsgBase::ByteSizeLong() const { // Prevent compiler warnings about cached_has_bits being unused (void) cached_has_bits; + // map properties = 6; + total_size += 1 * + ::PROTOBUF_NAMESPACE_ID::internal::FromIntSize(this->_internal_properties_size()); + for (::PROTOBUF_NAMESPACE_ID::Map< std::string, std::string >::const_iterator + it = this->_internal_properties().begin(); + it != this->_internal_properties().end(); ++it) { + total_size += MsgBase_PropertiesEntry_DoNotUse::Funcs::ByteSizeLong(it->first, it->second); + } + + // .milvus.proto.common.ReplicateInfo replicateInfo = 7; + if (this->_internal_has_replicateinfo()) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *_impl_.replicateinfo_); + } + // int64 msgID = 2; if (this->_internal_msgid() != 0) { total_size += ::_pbi::WireFormatLite::Int64SizePlusOne(this->_internal_msgid()); @@ -2936,6 +3110,11 @@ void MsgBase::MergeImpl(::PROTOBUF_NAMESPACE_ID::Message& to_msg, const ::PROTOB uint32_t cached_has_bits = 0; (void) cached_has_bits; + _this->_impl_.properties_.MergeFrom(from._impl_.properties_); + if (from._internal_has_replicateinfo()) { + _this->_internal_mutable_replicateinfo()->::milvus::proto::common::ReplicateInfo::MergeFrom( + from._internal_replicateinfo()); + } if (from._internal_msgid() != 0) { 
_this->_internal_set_msgid(from._internal_msgid()); } @@ -2968,18 +3147,230 @@ bool MsgBase::IsInitialized() const { void MsgBase::InternalSwap(MsgBase* other) { using std::swap; _internal_metadata_.InternalSwap(&other->_internal_metadata_); + _impl_.properties_.InternalSwap(&other->_impl_.properties_); ::PROTOBUF_NAMESPACE_ID::internal::memswap< PROTOBUF_FIELD_OFFSET(MsgBase, _impl_.msg_type_) + sizeof(MsgBase::_impl_.msg_type_) - - PROTOBUF_FIELD_OFFSET(MsgBase, _impl_.msgid_)>( - reinterpret_cast(&_impl_.msgid_), - reinterpret_cast(&other->_impl_.msgid_)); + - PROTOBUF_FIELD_OFFSET(MsgBase, _impl_.replicateinfo_)>( + reinterpret_cast(&_impl_.replicateinfo_), + reinterpret_cast(&other->_impl_.replicateinfo_)); } ::PROTOBUF_NAMESPACE_ID::Metadata MsgBase::GetMetadata() const { return ::_pbi::AssignDescriptors( &descriptor_table_common_2eproto_getter, &descriptor_table_common_2eproto_once, - file_level_metadata_common_2eproto[7]); + file_level_metadata_common_2eproto[8]); +} + +// =================================================================== + +class ReplicateInfo::_Internal { + public: +}; + +ReplicateInfo::ReplicateInfo(::PROTOBUF_NAMESPACE_ID::Arena* arena, + bool is_message_owned) + : ::PROTOBUF_NAMESPACE_ID::Message(arena, is_message_owned) { + SharedCtor(arena, is_message_owned); + // @@protoc_insertion_point(arena_constructor:milvus.proto.common.ReplicateInfo) +} +ReplicateInfo::ReplicateInfo(const ReplicateInfo& from) + : ::PROTOBUF_NAMESPACE_ID::Message() { + ReplicateInfo* const _this = this; (void)_this; + new (&_impl_) Impl_{ + decltype(_impl_.msgtimestamp_){} + , decltype(_impl_.isreplicate_){} + , /*decltype(_impl_._cached_size_)*/{}}; + + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::memcpy(&_impl_.msgtimestamp_, &from._impl_.msgtimestamp_, + static_cast(reinterpret_cast(&_impl_.isreplicate_) - + reinterpret_cast(&_impl_.msgtimestamp_)) + sizeof(_impl_.isreplicate_)); + // 
@@protoc_insertion_point(copy_constructor:milvus.proto.common.ReplicateInfo) +} + +inline void ReplicateInfo::SharedCtor( + ::_pb::Arena* arena, bool is_message_owned) { + (void)arena; + (void)is_message_owned; + new (&_impl_) Impl_{ + decltype(_impl_.msgtimestamp_){uint64_t{0u}} + , decltype(_impl_.isreplicate_){false} + , /*decltype(_impl_._cached_size_)*/{} + }; +} + +ReplicateInfo::~ReplicateInfo() { + // @@protoc_insertion_point(destructor:milvus.proto.common.ReplicateInfo) + if (auto *arena = _internal_metadata_.DeleteReturnArena<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>()) { + (void)arena; + return; + } + SharedDtor(); +} + +inline void ReplicateInfo::SharedDtor() { + GOOGLE_DCHECK(GetArenaForAllocation() == nullptr); +} + +void ReplicateInfo::SetCachedSize(int size) const { + _impl_._cached_size_.Set(size); +} + +void ReplicateInfo::Clear() { +// @@protoc_insertion_point(message_clear_start:milvus.proto.common.ReplicateInfo) + uint32_t cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + ::memset(&_impl_.msgtimestamp_, 0, static_cast( + reinterpret_cast(&_impl_.isreplicate_) - + reinterpret_cast(&_impl_.msgtimestamp_)) + sizeof(_impl_.isreplicate_)); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* ReplicateInfo::_InternalParse(const char* ptr, ::_pbi::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + while (!ctx->Done(&ptr)) { + uint32_t tag; + ptr = ::_pbi::ReadTag(ptr, &tag); + switch (tag >> 3) { + // bool isReplicate = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast(tag) == 8)) { + _impl_.isreplicate_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else + goto handle_unusual; + continue; + // uint64 msgTimestamp = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast(tag) == 16)) { + _impl_.msgtimestamp_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else 
+ goto handle_unusual; + continue; + default: + goto handle_unusual; + } // switch + handle_unusual: + if ((tag == 0) || ((tag & 7) == 4)) { + CHK_(ptr); + ctx->SetLastTag(tag); + goto message_done; + } + ptr = UnknownFieldParse( + tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + } // while +message_done: + return ptr; +failure: + ptr = nullptr; + goto message_done; +#undef CHK_ +} + +uint8_t* ReplicateInfo::_InternalSerialize( + uint8_t* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:milvus.proto.common.ReplicateInfo) + uint32_t cached_has_bits = 0; + (void) cached_has_bits; + + // bool isReplicate = 1; + if (this->_internal_isreplicate() != 0) { + target = stream->EnsureSpace(target); + target = ::_pbi::WireFormatLite::WriteBoolToArray(1, this->_internal_isreplicate(), target); + } + + // uint64 msgTimestamp = 2; + if (this->_internal_msgtimestamp() != 0) { + target = stream->EnsureSpace(target); + target = ::_pbi::WireFormatLite::WriteUInt64ToArray(2, this->_internal_msgtimestamp(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::_pbi::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:milvus.proto.common.ReplicateInfo) + return target; +} + +size_t ReplicateInfo::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:milvus.proto.common.ReplicateInfo) + size_t total_size = 0; + + uint32_t cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // uint64 msgTimestamp = 2; + if (this->_internal_msgtimestamp() != 0) { + total_size += 
::_pbi::WireFormatLite::UInt64SizePlusOne(this->_internal_msgtimestamp()); + } + + // bool isReplicate = 1; + if (this->_internal_isreplicate() != 0) { + total_size += 1 + 1; + } + + return MaybeComputeUnknownFieldsSize(total_size, &_impl_._cached_size_); +} + +const ::PROTOBUF_NAMESPACE_ID::Message::ClassData ReplicateInfo::_class_data_ = { + ::PROTOBUF_NAMESPACE_ID::Message::CopyWithSourceCheck, + ReplicateInfo::MergeImpl +}; +const ::PROTOBUF_NAMESPACE_ID::Message::ClassData*ReplicateInfo::GetClassData() const { return &_class_data_; } + + +void ReplicateInfo::MergeImpl(::PROTOBUF_NAMESPACE_ID::Message& to_msg, const ::PROTOBUF_NAMESPACE_ID::Message& from_msg) { + auto* const _this = static_cast(&to_msg); + auto& from = static_cast(from_msg); + // @@protoc_insertion_point(class_specific_merge_from_start:milvus.proto.common.ReplicateInfo) + GOOGLE_DCHECK_NE(&from, _this); + uint32_t cached_has_bits = 0; + (void) cached_has_bits; + + if (from._internal_msgtimestamp() != 0) { + _this->_internal_set_msgtimestamp(from._internal_msgtimestamp()); + } + if (from._internal_isreplicate() != 0) { + _this->_internal_set_isreplicate(from._internal_isreplicate()); + } + _this->_internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); +} + +void ReplicateInfo::CopyFrom(const ReplicateInfo& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:milvus.proto.common.ReplicateInfo) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool ReplicateInfo::IsInitialized() const { + return true; +} + +void ReplicateInfo::InternalSwap(ReplicateInfo* other) { + using std::swap; + _internal_metadata_.InternalSwap(&other->_internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::internal::memswap< + PROTOBUF_FIELD_OFFSET(ReplicateInfo, _impl_.isreplicate_) + + sizeof(ReplicateInfo::_impl_.isreplicate_) + - PROTOBUF_FIELD_OFFSET(ReplicateInfo, _impl_.msgtimestamp_)>( + reinterpret_cast(&_impl_.msgtimestamp_), + 
reinterpret_cast(&other->_impl_.msgtimestamp_)); +} + +::PROTOBUF_NAMESPACE_ID::Metadata ReplicateInfo::GetMetadata() const { + return ::_pbi::AssignDescriptors( + &descriptor_table_common_2eproto_getter, &descriptor_table_common_2eproto_once, + file_level_metadata_common_2eproto[9]); } // =================================================================== @@ -3172,7 +3563,7 @@ void MsgHeader::InternalSwap(MsgHeader* other) { ::PROTOBUF_NAMESPACE_ID::Metadata MsgHeader::GetMetadata() const { return ::_pbi::AssignDescriptors( &descriptor_table_common_2eproto_getter, &descriptor_table_common_2eproto_once, - file_level_metadata_common_2eproto[8]); + file_level_metadata_common_2eproto[10]); } // =================================================================== @@ -3417,7 +3808,7 @@ void DMLMsgHeader::InternalSwap(DMLMsgHeader* other) { ::PROTOBUF_NAMESPACE_ID::Metadata DMLMsgHeader::GetMetadata() const { return ::_pbi::AssignDescriptors( &descriptor_table_common_2eproto_getter, &descriptor_table_common_2eproto_once, - file_level_metadata_common_2eproto[9]); + file_level_metadata_common_2eproto[11]); } // =================================================================== @@ -3682,7 +4073,7 @@ void PrivilegeExt::InternalSwap(PrivilegeExt* other) { ::PROTOBUF_NAMESPACE_ID::Metadata PrivilegeExt::GetMetadata() const { return ::_pbi::AssignDescriptors( &descriptor_table_common_2eproto_getter, &descriptor_table_common_2eproto_once, - file_level_metadata_common_2eproto[10]); + file_level_metadata_common_2eproto[12]); } // =================================================================== @@ -3893,7 +4284,7 @@ void SegmentStats::InternalSwap(SegmentStats* other) { ::PROTOBUF_NAMESPACE_ID::Metadata SegmentStats::GetMetadata() const { return ::_pbi::AssignDescriptors( &descriptor_table_common_2eproto_getter, &descriptor_table_common_2eproto_once, - file_level_metadata_common_2eproto[11]); + file_level_metadata_common_2eproto[13]); } // 
=================================================================== @@ -3907,7 +4298,7 @@ void ClientInfo_ReservedEntry_DoNotUse::MergeFrom(const ClientInfo_ReservedEntry ::PROTOBUF_NAMESPACE_ID::Metadata ClientInfo_ReservedEntry_DoNotUse::GetMetadata() const { return ::_pbi::AssignDescriptors( &descriptor_table_common_2eproto_getter, &descriptor_table_common_2eproto_once, - file_level_metadata_common_2eproto[12]); + file_level_metadata_common_2eproto[14]); } // =================================================================== @@ -4378,7 +4769,7 @@ void ClientInfo::InternalSwap(ClientInfo* other) { ::PROTOBUF_NAMESPACE_ID::Metadata ClientInfo::GetMetadata() const { return ::_pbi::AssignDescriptors( &descriptor_table_common_2eproto_getter, &descriptor_table_common_2eproto_once, - file_level_metadata_common_2eproto[13]); + file_level_metadata_common_2eproto[15]); } // =================================================================== @@ -4392,7 +4783,7 @@ void ServerInfo_ReservedEntry_DoNotUse::MergeFrom(const ServerInfo_ReservedEntry ::PROTOBUF_NAMESPACE_ID::Metadata ServerInfo_ReservedEntry_DoNotUse::GetMetadata() const { return ::_pbi::AssignDescriptors( &descriptor_table_common_2eproto_getter, &descriptor_table_common_2eproto_once, - file_level_metadata_common_2eproto[14]); + file_level_metadata_common_2eproto[16]); } // =================================================================== @@ -4863,7 +5254,7 @@ void ServerInfo::InternalSwap(ServerInfo* other) { ::PROTOBUF_NAMESPACE_ID::Metadata ServerInfo::GetMetadata() const { return ::_pbi::AssignDescriptors( &descriptor_table_common_2eproto_getter, &descriptor_table_common_2eproto_once, - file_level_metadata_common_2eproto[15]); + file_level_metadata_common_2eproto[17]); } PROTOBUF_ATTRIBUTE_INIT_PRIORITY2 ::PROTOBUF_NAMESPACE_ID::internal::ExtensionIdentifier< ::PROTOBUF_NAMESPACE_ID::MessageOptions, ::PROTOBUF_NAMESPACE_ID::internal::MessageTypeTraits< ::milvus::proto::common::PrivilegeExt >, 11, false> @@ 
-4902,10 +5293,18 @@ template<> PROTOBUF_NOINLINE ::milvus::proto::common::Address* Arena::CreateMaybeMessage< ::milvus::proto::common::Address >(Arena* arena) { return Arena::CreateMessageInternal< ::milvus::proto::common::Address >(arena); } +template<> PROTOBUF_NOINLINE ::milvus::proto::common::MsgBase_PropertiesEntry_DoNotUse* +Arena::CreateMaybeMessage< ::milvus::proto::common::MsgBase_PropertiesEntry_DoNotUse >(Arena* arena) { + return Arena::CreateMessageInternal< ::milvus::proto::common::MsgBase_PropertiesEntry_DoNotUse >(arena); +} template<> PROTOBUF_NOINLINE ::milvus::proto::common::MsgBase* Arena::CreateMaybeMessage< ::milvus::proto::common::MsgBase >(Arena* arena) { return Arena::CreateMessageInternal< ::milvus::proto::common::MsgBase >(arena); } +template<> PROTOBUF_NOINLINE ::milvus::proto::common::ReplicateInfo* +Arena::CreateMaybeMessage< ::milvus::proto::common::ReplicateInfo >(Arena* arena) { + return Arena::CreateMessageInternal< ::milvus::proto::common::ReplicateInfo >(arena); +} template<> PROTOBUF_NOINLINE ::milvus::proto::common::MsgHeader* Arena::CreateMaybeMessage< ::milvus::proto::common::MsgHeader >(Arena* arena) { return Arena::CreateMessageInternal< ::milvus::proto::common::MsgHeader >(arena); diff --git a/internal/core/src/pb/common.pb.h b/internal/core/src/pb/common.pb.h index 4e164b872ff3b..690bcf51a8afd 100644 --- a/internal/core/src/pb/common.pb.h +++ b/internal/core/src/pb/common.pb.h @@ -76,6 +76,9 @@ extern KeyValuePairDefaultTypeInternal _KeyValuePair_default_instance_; class MsgBase; struct MsgBaseDefaultTypeInternal; extern MsgBaseDefaultTypeInternal _MsgBase_default_instance_; +class MsgBase_PropertiesEntry_DoNotUse; +struct MsgBase_PropertiesEntry_DoNotUseDefaultTypeInternal; +extern MsgBase_PropertiesEntry_DoNotUseDefaultTypeInternal _MsgBase_PropertiesEntry_DoNotUse_default_instance_; class MsgHeader; struct MsgHeaderDefaultTypeInternal; extern MsgHeaderDefaultTypeInternal _MsgHeader_default_instance_; @@ -88,6 +91,9 @@ 
extern PlaceholderValueDefaultTypeInternal _PlaceholderValue_default_instance_; class PrivilegeExt; struct PrivilegeExtDefaultTypeInternal; extern PrivilegeExtDefaultTypeInternal _PrivilegeExt_default_instance_; +class ReplicateInfo; +struct ReplicateInfoDefaultTypeInternal; +extern ReplicateInfoDefaultTypeInternal _ReplicateInfo_default_instance_; class SegmentStats; struct SegmentStatsDefaultTypeInternal; extern SegmentStatsDefaultTypeInternal _SegmentStats_default_instance_; @@ -112,10 +118,12 @@ template<> ::milvus::proto::common::DMLMsgHeader* Arena::CreateMaybeMessage<::mi template<> ::milvus::proto::common::KeyDataPair* Arena::CreateMaybeMessage<::milvus::proto::common::KeyDataPair>(Arena*); template<> ::milvus::proto::common::KeyValuePair* Arena::CreateMaybeMessage<::milvus::proto::common::KeyValuePair>(Arena*); template<> ::milvus::proto::common::MsgBase* Arena::CreateMaybeMessage<::milvus::proto::common::MsgBase>(Arena*); +template<> ::milvus::proto::common::MsgBase_PropertiesEntry_DoNotUse* Arena::CreateMaybeMessage<::milvus::proto::common::MsgBase_PropertiesEntry_DoNotUse>(Arena*); template<> ::milvus::proto::common::MsgHeader* Arena::CreateMaybeMessage<::milvus::proto::common::MsgHeader>(Arena*); template<> ::milvus::proto::common::PlaceholderGroup* Arena::CreateMaybeMessage<::milvus::proto::common::PlaceholderGroup>(Arena*); template<> ::milvus::proto::common::PlaceholderValue* Arena::CreateMaybeMessage<::milvus::proto::common::PlaceholderValue>(Arena*); template<> ::milvus::proto::common::PrivilegeExt* Arena::CreateMaybeMessage<::milvus::proto::common::PrivilegeExt>(Arena*); +template<> ::milvus::proto::common::ReplicateInfo* Arena::CreateMaybeMessage<::milvus::proto::common::ReplicateInfo>(Arena*); template<> ::milvus::proto::common::SegmentStats* Arena::CreateMaybeMessage<::milvus::proto::common::SegmentStats>(Arena*); template<> ::milvus::proto::common::ServerInfo* Arena::CreateMaybeMessage<::milvus::proto::common::ServerInfo>(Arena*); template<> 
::milvus::proto::common::ServerInfo_ReservedEntry_DoNotUse* Arena::CreateMaybeMessage<::milvus::proto::common::ServerInfo_ReservedEntry_DoNotUse>(Arena*); @@ -1850,6 +1858,34 @@ class Address final : }; // ------------------------------------------------------------------- +class MsgBase_PropertiesEntry_DoNotUse : public ::PROTOBUF_NAMESPACE_ID::internal::MapEntry { +public: + typedef ::PROTOBUF_NAMESPACE_ID::internal::MapEntry SuperType; + MsgBase_PropertiesEntry_DoNotUse(); + explicit PROTOBUF_CONSTEXPR MsgBase_PropertiesEntry_DoNotUse( + ::PROTOBUF_NAMESPACE_ID::internal::ConstantInitialized); + explicit MsgBase_PropertiesEntry_DoNotUse(::PROTOBUF_NAMESPACE_ID::Arena* arena); + void MergeFrom(const MsgBase_PropertiesEntry_DoNotUse& other); + static const MsgBase_PropertiesEntry_DoNotUse* internal_default_instance() { return reinterpret_cast(&_MsgBase_PropertiesEntry_DoNotUse_default_instance_); } + static bool ValidateKey(std::string* s) { + return ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::VerifyUtf8String(s->data(), static_cast(s->size()), ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::PARSE, "milvus.proto.common.MsgBase.PropertiesEntry.key"); + } + static bool ValidateValue(std::string* s) { + return ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::VerifyUtf8String(s->data(), static_cast(s->size()), ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::PARSE, "milvus.proto.common.MsgBase.PropertiesEntry.value"); + } + using ::PROTOBUF_NAMESPACE_ID::Message::MergeFrom; + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + friend struct ::TableStruct_common_2eproto; +}; + +// ------------------------------------------------------------------- + class MsgBase final : public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:milvus.proto.common.MsgBase) */ { public: @@ -1898,7 +1934,7 @@ class MsgBase final : &_MsgBase_default_instance_); } static constexpr int kIndexInFileMessages = - 7; + 8; friend void 
swap(MsgBase& a, MsgBase& b) { a.Swap(&b); @@ -1959,6 +1995,8 @@ class MsgBase final : protected: explicit MsgBase(::PROTOBUF_NAMESPACE_ID::Arena* arena, bool is_message_owned = false); + private: + static void ArenaDtor(void* object); public: static const ClassData _class_data_; @@ -1968,15 +2006,53 @@ class MsgBase final : // nested types ---------------------------------------------------- + // accessors ------------------------------------------------------- enum : int { + kPropertiesFieldNumber = 6, + kReplicateInfoFieldNumber = 7, kMsgIDFieldNumber = 2, kTimestampFieldNumber = 3, kSourceIDFieldNumber = 4, kTargetIDFieldNumber = 5, kMsgTypeFieldNumber = 1, }; + // map properties = 6; + int properties_size() const; + private: + int _internal_properties_size() const; + public: + void clear_properties(); + private: + const ::PROTOBUF_NAMESPACE_ID::Map< std::string, std::string >& + _internal_properties() const; + ::PROTOBUF_NAMESPACE_ID::Map< std::string, std::string >* + _internal_mutable_properties(); + public: + const ::PROTOBUF_NAMESPACE_ID::Map< std::string, std::string >& + properties() const; + ::PROTOBUF_NAMESPACE_ID::Map< std::string, std::string >* + mutable_properties(); + + // .milvus.proto.common.ReplicateInfo replicateInfo = 7; + bool has_replicateinfo() const; + private: + bool _internal_has_replicateinfo() const; + public: + void clear_replicateinfo(); + const ::milvus::proto::common::ReplicateInfo& replicateinfo() const; + PROTOBUF_NODISCARD ::milvus::proto::common::ReplicateInfo* release_replicateinfo(); + ::milvus::proto::common::ReplicateInfo* mutable_replicateinfo(); + void set_allocated_replicateinfo(::milvus::proto::common::ReplicateInfo* replicateinfo); + private: + const ::milvus::proto::common::ReplicateInfo& _internal_replicateinfo() const; + ::milvus::proto::common::ReplicateInfo* _internal_mutable_replicateinfo(); + public: + void unsafe_arena_set_allocated_replicateinfo( + ::milvus::proto::common::ReplicateInfo* replicateinfo); + 
::milvus::proto::common::ReplicateInfo* unsafe_arena_release_replicateinfo(); + // int64 msgID = 2; void clear_msgid(); int64_t msgid() const; @@ -2030,6 +2106,12 @@ class MsgBase final : typedef void InternalArenaConstructable_; typedef void DestructorSkippable_; struct Impl_ { + ::PROTOBUF_NAMESPACE_ID::internal::MapField< + MsgBase_PropertiesEntry_DoNotUse, + std::string, std::string, + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::TYPE_STRING, + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::TYPE_STRING> properties_; + ::milvus::proto::common::ReplicateInfo* replicateinfo_; int64_t msgid_; uint64_t timestamp_; int64_t sourceid_; @@ -2042,6 +2124,165 @@ class MsgBase final : }; // ------------------------------------------------------------------- +class ReplicateInfo final : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:milvus.proto.common.ReplicateInfo) */ { + public: + inline ReplicateInfo() : ReplicateInfo(nullptr) {} + ~ReplicateInfo() override; + explicit PROTOBUF_CONSTEXPR ReplicateInfo(::PROTOBUF_NAMESPACE_ID::internal::ConstantInitialized); + + ReplicateInfo(const ReplicateInfo& from); + ReplicateInfo(ReplicateInfo&& from) noexcept + : ReplicateInfo() { + *this = ::std::move(from); + } + + inline ReplicateInfo& operator=(const ReplicateInfo& from) { + CopyFrom(from); + return *this; + } + inline ReplicateInfo& operator=(ReplicateInfo&& from) noexcept { + if (this == &from) return *this; + if (GetOwningArena() == from.GetOwningArena() + #ifdef PROTOBUF_FORCE_COPY_IN_MOVE + && GetOwningArena() != nullptr + #endif // !PROTOBUF_FORCE_COPY_IN_MOVE + ) { + InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return default_instance().GetMetadata().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* 
GetReflection() { + return default_instance().GetMetadata().reflection; + } + static const ReplicateInfo& default_instance() { + return *internal_default_instance(); + } + static inline const ReplicateInfo* internal_default_instance() { + return reinterpret_cast( + &_ReplicateInfo_default_instance_); + } + static constexpr int kIndexInFileMessages = + 9; + + friend void swap(ReplicateInfo& a, ReplicateInfo& b) { + a.Swap(&b); + } + inline void Swap(ReplicateInfo* other) { + if (other == this) return; + #ifdef PROTOBUF_FORCE_COPY_IN_SWAP + if (GetOwningArena() != nullptr && + GetOwningArena() == other->GetOwningArena()) { + #else // PROTOBUF_FORCE_COPY_IN_SWAP + if (GetOwningArena() == other->GetOwningArena()) { + #endif // !PROTOBUF_FORCE_COPY_IN_SWAP + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(ReplicateInfo* other) { + if (other == this) return; + GOOGLE_DCHECK(GetOwningArena() == other->GetOwningArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + ReplicateInfo* New(::PROTOBUF_NAMESPACE_ID::Arena* arena = nullptr) const final { + return CreateMaybeMessage(arena); + } + using ::PROTOBUF_NAMESPACE_ID::Message::CopyFrom; + void CopyFrom(const ReplicateInfo& from); + using ::PROTOBUF_NAMESPACE_ID::Message::MergeFrom; + void MergeFrom( const ReplicateInfo& from) { + ReplicateInfo::MergeImpl(*this, from); + } + private: + static void MergeImpl(::PROTOBUF_NAMESPACE_ID::Message& to_msg, const ::PROTOBUF_NAMESPACE_ID::Message& from_msg); + public: + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + uint8_t* _InternalSerialize( + uint8_t* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const 
final { return _impl_._cached_size_.Get(); } + + private: + void SharedCtor(::PROTOBUF_NAMESPACE_ID::Arena* arena, bool is_message_owned); + void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(ReplicateInfo* other); + + private: + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "milvus.proto.common.ReplicateInfo"; + } + protected: + explicit ReplicateInfo(::PROTOBUF_NAMESPACE_ID::Arena* arena, + bool is_message_owned = false); + public: + + static const ClassData _class_data_; + const ::PROTOBUF_NAMESPACE_ID::Message::ClassData*GetClassData() const final; + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kMsgTimestampFieldNumber = 2, + kIsReplicateFieldNumber = 1, + }; + // uint64 msgTimestamp = 2; + void clear_msgtimestamp(); + uint64_t msgtimestamp() const; + void set_msgtimestamp(uint64_t value); + private: + uint64_t _internal_msgtimestamp() const; + void _internal_set_msgtimestamp(uint64_t value); + public: + + // bool isReplicate = 1; + void clear_isreplicate(); + bool isreplicate() const; + void set_isreplicate(bool value); + private: + bool _internal_isreplicate() const; + void _internal_set_isreplicate(bool value); + public: + + // @@protoc_insertion_point(class_scope:milvus.proto.common.ReplicateInfo) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + struct Impl_ { + uint64_t msgtimestamp_; + bool isreplicate_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + }; + union { Impl_ _impl_; }; + friend struct ::TableStruct_common_2eproto; +}; +// 
------------------------------------------------------------------- + class MsgHeader final : public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:milvus.proto.common.MsgHeader) */ { public: @@ -2090,7 +2331,7 @@ class MsgHeader final : &_MsgHeader_default_instance_); } static constexpr int kIndexInFileMessages = - 8; + 10; friend void swap(MsgHeader& a, MsgHeader& b) { a.Swap(&b); @@ -2247,7 +2488,7 @@ class DMLMsgHeader final : &_DMLMsgHeader_default_instance_); } static constexpr int kIndexInFileMessages = - 9; + 11; friend void swap(DMLMsgHeader& a, DMLMsgHeader& b) { a.Swap(&b); @@ -2420,7 +2661,7 @@ class PrivilegeExt final : &_PrivilegeExt_default_instance_); } static constexpr int kIndexInFileMessages = - 10; + 12; friend void swap(PrivilegeExt& a, PrivilegeExt& b) { a.Swap(&b); @@ -2601,7 +2842,7 @@ class SegmentStats final : &_SegmentStats_default_instance_); } static constexpr int kIndexInFileMessages = - 11; + 13; friend void swap(SegmentStats& a, SegmentStats& b) { a.Swap(&b); @@ -2788,7 +3029,7 @@ class ClientInfo final : &_ClientInfo_default_instance_); } static constexpr int kIndexInFileMessages = - 13; + 15; friend void swap(ClientInfo& a, ClientInfo& b) { a.Swap(&b); @@ -3059,7 +3300,7 @@ class ServerInfo final : &_ServerInfo_default_instance_); } static constexpr int kIndexInFileMessages = - 15; + 17; friend void swap(ServerInfo& a, ServerInfo& b) { a.Swap(&b); @@ -3888,6 +4129,8 @@ inline void Address::set_port(int64_t value) { // ------------------------------------------------------------------- +// ------------------------------------------------------------------- + // MsgBase // .milvus.proto.common.MsgType msg_type = 1; @@ -3990,6 +4233,169 @@ inline void MsgBase::set_targetid(int64_t value) { // @@protoc_insertion_point(field_set:milvus.proto.common.MsgBase.targetID) } +// map properties = 6; +inline int MsgBase::_internal_properties_size() const { + return _impl_.properties_.size(); +} +inline int 
MsgBase::properties_size() const { + return _internal_properties_size(); +} +inline void MsgBase::clear_properties() { + _impl_.properties_.Clear(); +} +inline const ::PROTOBUF_NAMESPACE_ID::Map< std::string, std::string >& +MsgBase::_internal_properties() const { + return _impl_.properties_.GetMap(); +} +inline const ::PROTOBUF_NAMESPACE_ID::Map< std::string, std::string >& +MsgBase::properties() const { + // @@protoc_insertion_point(field_map:milvus.proto.common.MsgBase.properties) + return _internal_properties(); +} +inline ::PROTOBUF_NAMESPACE_ID::Map< std::string, std::string >* +MsgBase::_internal_mutable_properties() { + return _impl_.properties_.MutableMap(); +} +inline ::PROTOBUF_NAMESPACE_ID::Map< std::string, std::string >* +MsgBase::mutable_properties() { + // @@protoc_insertion_point(field_mutable_map:milvus.proto.common.MsgBase.properties) + return _internal_mutable_properties(); +} + +// .milvus.proto.common.ReplicateInfo replicateInfo = 7; +inline bool MsgBase::_internal_has_replicateinfo() const { + return this != internal_default_instance() && _impl_.replicateinfo_ != nullptr; +} +inline bool MsgBase::has_replicateinfo() const { + return _internal_has_replicateinfo(); +} +inline void MsgBase::clear_replicateinfo() { + if (GetArenaForAllocation() == nullptr && _impl_.replicateinfo_ != nullptr) { + delete _impl_.replicateinfo_; + } + _impl_.replicateinfo_ = nullptr; +} +inline const ::milvus::proto::common::ReplicateInfo& MsgBase::_internal_replicateinfo() const { + const ::milvus::proto::common::ReplicateInfo* p = _impl_.replicateinfo_; + return p != nullptr ? 
*p : reinterpret_cast( + ::milvus::proto::common::_ReplicateInfo_default_instance_); +} +inline const ::milvus::proto::common::ReplicateInfo& MsgBase::replicateinfo() const { + // @@protoc_insertion_point(field_get:milvus.proto.common.MsgBase.replicateInfo) + return _internal_replicateinfo(); +} +inline void MsgBase::unsafe_arena_set_allocated_replicateinfo( + ::milvus::proto::common::ReplicateInfo* replicateinfo) { + if (GetArenaForAllocation() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(_impl_.replicateinfo_); + } + _impl_.replicateinfo_ = replicateinfo; + if (replicateinfo) { + + } else { + + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:milvus.proto.common.MsgBase.replicateInfo) +} +inline ::milvus::proto::common::ReplicateInfo* MsgBase::release_replicateinfo() { + + ::milvus::proto::common::ReplicateInfo* temp = _impl_.replicateinfo_; + _impl_.replicateinfo_ = nullptr; +#ifdef PROTOBUF_FORCE_COPY_IN_RELEASE + auto* old = reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(temp); + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + if (GetArenaForAllocation() == nullptr) { delete old; } +#else // PROTOBUF_FORCE_COPY_IN_RELEASE + if (GetArenaForAllocation() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } +#endif // !PROTOBUF_FORCE_COPY_IN_RELEASE + return temp; +} +inline ::milvus::proto::common::ReplicateInfo* MsgBase::unsafe_arena_release_replicateinfo() { + // @@protoc_insertion_point(field_release:milvus.proto.common.MsgBase.replicateInfo) + + ::milvus::proto::common::ReplicateInfo* temp = _impl_.replicateinfo_; + _impl_.replicateinfo_ = nullptr; + return temp; +} +inline ::milvus::proto::common::ReplicateInfo* MsgBase::_internal_mutable_replicateinfo() { + + if (_impl_.replicateinfo_ == nullptr) { + auto* p = CreateMaybeMessage<::milvus::proto::common::ReplicateInfo>(GetArenaForAllocation()); + _impl_.replicateinfo_ = p; + } + return 
_impl_.replicateinfo_; +} +inline ::milvus::proto::common::ReplicateInfo* MsgBase::mutable_replicateinfo() { + ::milvus::proto::common::ReplicateInfo* _msg = _internal_mutable_replicateinfo(); + // @@protoc_insertion_point(field_mutable:milvus.proto.common.MsgBase.replicateInfo) + return _msg; +} +inline void MsgBase::set_allocated_replicateinfo(::milvus::proto::common::ReplicateInfo* replicateinfo) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArenaForAllocation(); + if (message_arena == nullptr) { + delete _impl_.replicateinfo_; + } + if (replicateinfo) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::InternalGetOwningArena(replicateinfo); + if (message_arena != submessage_arena) { + replicateinfo = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, replicateinfo, submessage_arena); + } + + } else { + + } + _impl_.replicateinfo_ = replicateinfo; + // @@protoc_insertion_point(field_set_allocated:milvus.proto.common.MsgBase.replicateInfo) +} + +// ------------------------------------------------------------------- + +// ReplicateInfo + +// bool isReplicate = 1; +inline void ReplicateInfo::clear_isreplicate() { + _impl_.isreplicate_ = false; +} +inline bool ReplicateInfo::_internal_isreplicate() const { + return _impl_.isreplicate_; +} +inline bool ReplicateInfo::isreplicate() const { + // @@protoc_insertion_point(field_get:milvus.proto.common.ReplicateInfo.isReplicate) + return _internal_isreplicate(); +} +inline void ReplicateInfo::_internal_set_isreplicate(bool value) { + + _impl_.isreplicate_ = value; +} +inline void ReplicateInfo::set_isreplicate(bool value) { + _internal_set_isreplicate(value); + // @@protoc_insertion_point(field_set:milvus.proto.common.ReplicateInfo.isReplicate) +} + +// uint64 msgTimestamp = 2; +inline void ReplicateInfo::clear_msgtimestamp() { + _impl_.msgtimestamp_ = uint64_t{0u}; +} +inline uint64_t ReplicateInfo::_internal_msgtimestamp() const { + return 
_impl_.msgtimestamp_; +} +inline uint64_t ReplicateInfo::msgtimestamp() const { + // @@protoc_insertion_point(field_get:milvus.proto.common.ReplicateInfo.msgTimestamp) + return _internal_msgtimestamp(); +} +inline void ReplicateInfo::_internal_set_msgtimestamp(uint64_t value) { + + _impl_.msgtimestamp_ = value; +} +inline void ReplicateInfo::set_msgtimestamp(uint64_t value) { + _internal_set_msgtimestamp(value); + // @@protoc_insertion_point(field_set:milvus.proto.common.ReplicateInfo.msgTimestamp) +} + // ------------------------------------------------------------------- // MsgHeader @@ -4959,6 +5365,10 @@ ServerInfo::mutable_reserved() { // ------------------------------------------------------------------- +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + // @@protoc_insertion_point(namespace_scope) diff --git a/internal/core/src/pb/plan.pb.cc b/internal/core/src/pb/plan.pb.cc index 592b90fa6a6af..977e3998ec686 100644 --- a/internal/core/src/pb/plan.pb.cc +++ b/internal/core/src/pb/plan.pb.cc @@ -41,6 +41,7 @@ PROTOBUF_CONSTEXPR Array::Array( ::_pbi::ConstantInitialized): _impl_{ /*decltype(_impl_.array_)*/{} , /*decltype(_impl_.same_type_)*/false + , /*decltype(_impl_.element_type_)*/0 , /*decltype(_impl_._cached_size_)*/{}} {} struct ArrayDefaultTypeInternal { PROTOBUF_CONSTEXPR ArrayDefaultTypeInternal() @@ -75,6 +76,7 @@ PROTOBUF_CONSTEXPR ColumnInfo::ColumnInfo( , /*decltype(_impl_.is_primary_key_)*/false , /*decltype(_impl_.is_autoid_)*/false , /*decltype(_impl_.is_partition_key_)*/false + , /*decltype(_impl_.element_type_)*/0 , /*decltype(_impl_._cached_size_)*/{}} {} struct ColumnInfoDefaultTypeInternal { PROTOBUF_CONSTEXPR ColumnInfoDefaultTypeInternal() @@ -309,7 +311,7 @@ PROTOBUF_CONSTEXPR VectorANNS::VectorANNS( , /*decltype(_impl_.predicates_)*/nullptr , /*decltype(_impl_.query_info_)*/nullptr , /*decltype(_impl_.field_id_)*/int64_t{0} - , 
/*decltype(_impl_.is_binary_)*/false + , /*decltype(_impl_.vector_type_)*/0 , /*decltype(_impl_._cached_size_)*/{}} {} struct VectorANNSDefaultTypeInternal { PROTOBUF_CONSTEXPR VectorANNSDefaultTypeInternal() @@ -355,7 +357,7 @@ PROTOBUF_ATTRIBUTE_NO_DESTROY PROTOBUF_CONSTINIT PROTOBUF_ATTRIBUTE_INIT_PRIORIT } // namespace proto } // namespace milvus static ::_pb::Metadata file_level_metadata_plan_2eproto[22]; -static const ::_pb::EnumDescriptor* file_level_enum_descriptors_plan_2eproto[5]; +static const ::_pb::EnumDescriptor* file_level_enum_descriptors_plan_2eproto[6]; static constexpr ::_pb::ServiceDescriptor const** file_level_service_descriptors_plan_2eproto = nullptr; const uint32_t TableStruct_plan_2eproto::offsets[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { @@ -379,6 +381,7 @@ const uint32_t TableStruct_plan_2eproto::offsets[] PROTOBUF_SECTION_VARIABLE(pro ~0u, // no _inlined_string_donated_ PROTOBUF_FIELD_OFFSET(::milvus::proto::plan::Array, _impl_.array_), PROTOBUF_FIELD_OFFSET(::milvus::proto::plan::Array, _impl_.same_type_), + PROTOBUF_FIELD_OFFSET(::milvus::proto::plan::Array, _impl_.element_type_), ~0u, // no _has_bits_ PROTOBUF_FIELD_OFFSET(::milvus::proto::plan::QueryInfo, _internal_metadata_), ~0u, // no _extensions_ @@ -401,6 +404,7 @@ const uint32_t TableStruct_plan_2eproto::offsets[] PROTOBUF_SECTION_VARIABLE(pro PROTOBUF_FIELD_OFFSET(::milvus::proto::plan::ColumnInfo, _impl_.is_autoid_), PROTOBUF_FIELD_OFFSET(::milvus::proto::plan::ColumnInfo, _impl_.nested_path_), PROTOBUF_FIELD_OFFSET(::milvus::proto::plan::ColumnInfo, _impl_.is_partition_key_), + PROTOBUF_FIELD_OFFSET(::milvus::proto::plan::ColumnInfo, _impl_.element_type_), ~0u, // no _has_bits_ PROTOBUF_FIELD_OFFSET(::milvus::proto::plan::ColumnExpr, _internal_metadata_), ~0u, // no _extensions_ @@ -548,7 +552,7 @@ const uint32_t TableStruct_plan_2eproto::offsets[] PROTOBUF_SECTION_VARIABLE(pro ~0u, // no _oneof_case_ ~0u, // no _weak_field_map_ ~0u, // no _inlined_string_donated_ - 
PROTOBUF_FIELD_OFFSET(::milvus::proto::plan::VectorANNS, _impl_.is_binary_), + PROTOBUF_FIELD_OFFSET(::milvus::proto::plan::VectorANNS, _impl_.vector_type_), PROTOBUF_FIELD_OFFSET(::milvus::proto::plan::VectorANNS, _impl_.field_id_), PROTOBUF_FIELD_OFFSET(::milvus::proto::plan::VectorANNS, _impl_.predicates_), PROTOBUF_FIELD_OFFSET(::milvus::proto::plan::VectorANNS, _impl_.query_info_), @@ -577,26 +581,26 @@ const uint32_t TableStruct_plan_2eproto::offsets[] PROTOBUF_SECTION_VARIABLE(pro static const ::_pbi::MigrationSchema schemas[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { { 0, -1, -1, sizeof(::milvus::proto::plan::GenericValue)}, { 12, -1, -1, sizeof(::milvus::proto::plan::Array)}, - { 20, -1, -1, sizeof(::milvus::proto::plan::QueryInfo)}, - { 30, -1, -1, sizeof(::milvus::proto::plan::ColumnInfo)}, - { 42, -1, -1, sizeof(::milvus::proto::plan::ColumnExpr)}, - { 49, -1, -1, sizeof(::milvus::proto::plan::ExistsExpr)}, - { 56, -1, -1, sizeof(::milvus::proto::plan::ValueExpr)}, - { 63, -1, -1, sizeof(::milvus::proto::plan::UnaryRangeExpr)}, - { 72, -1, -1, sizeof(::milvus::proto::plan::BinaryRangeExpr)}, - { 83, -1, -1, sizeof(::milvus::proto::plan::CompareExpr)}, - { 92, -1, -1, sizeof(::milvus::proto::plan::TermExpr)}, - { 101, -1, -1, sizeof(::milvus::proto::plan::JSONContainsExpr)}, - { 111, -1, -1, sizeof(::milvus::proto::plan::UnaryExpr)}, - { 119, -1, -1, sizeof(::milvus::proto::plan::BinaryExpr)}, - { 128, -1, -1, sizeof(::milvus::proto::plan::BinaryArithOp)}, - { 137, -1, -1, sizeof(::milvus::proto::plan::BinaryArithExpr)}, - { 146, -1, -1, sizeof(::milvus::proto::plan::BinaryArithOpEvalRangeExpr)}, - { 157, -1, -1, sizeof(::milvus::proto::plan::AlwaysTrueExpr)}, - { 163, -1, -1, sizeof(::milvus::proto::plan::Expr)}, - { 183, -1, -1, sizeof(::milvus::proto::plan::VectorANNS)}, - { 194, -1, -1, sizeof(::milvus::proto::plan::QueryPlanNode)}, - { 203, -1, -1, sizeof(::milvus::proto::plan::PlanNode)}, + { 21, -1, -1, 
sizeof(::milvus::proto::plan::QueryInfo)}, + { 31, -1, -1, sizeof(::milvus::proto::plan::ColumnInfo)}, + { 44, -1, -1, sizeof(::milvus::proto::plan::ColumnExpr)}, + { 51, -1, -1, sizeof(::milvus::proto::plan::ExistsExpr)}, + { 58, -1, -1, sizeof(::milvus::proto::plan::ValueExpr)}, + { 65, -1, -1, sizeof(::milvus::proto::plan::UnaryRangeExpr)}, + { 74, -1, -1, sizeof(::milvus::proto::plan::BinaryRangeExpr)}, + { 85, -1, -1, sizeof(::milvus::proto::plan::CompareExpr)}, + { 94, -1, -1, sizeof(::milvus::proto::plan::TermExpr)}, + { 103, -1, -1, sizeof(::milvus::proto::plan::JSONContainsExpr)}, + { 113, -1, -1, sizeof(::milvus::proto::plan::UnaryExpr)}, + { 121, -1, -1, sizeof(::milvus::proto::plan::BinaryExpr)}, + { 130, -1, -1, sizeof(::milvus::proto::plan::BinaryArithOp)}, + { 139, -1, -1, sizeof(::milvus::proto::plan::BinaryArithExpr)}, + { 148, -1, -1, sizeof(::milvus::proto::plan::BinaryArithOpEvalRangeExpr)}, + { 159, -1, -1, sizeof(::milvus::proto::plan::AlwaysTrueExpr)}, + { 165, -1, -1, sizeof(::milvus::proto::plan::Expr)}, + { 185, -1, -1, sizeof(::milvus::proto::plan::VectorANNS)}, + { 196, -1, -1, sizeof(::milvus::proto::plan::QueryPlanNode)}, + { 205, -1, -1, sizeof(::milvus::proto::plan::PlanNode)}, }; static const ::_pb::Message* const file_default_instances[] = { @@ -630,114 +634,120 @@ const char descriptor_table_protodef_plan_2eproto[] PROTOBUF_SECTION_VARIABLE(pr "H\000\022\023\n\tint64_val\030\002 \001(\003H\000\022\023\n\tfloat_val\030\003 \001" "(\001H\000\022\024\n\nstring_val\030\004 \001(\tH\000\022-\n\tarray_val\030" "\005 \001(\0132\030.milvus.proto.plan.ArrayH\000B\005\n\003val" - "\"J\n\005Array\022.\n\005array\030\001 \003(\0132\037.milvus.proto." 
- "plan.GenericValue\022\021\n\tsame_type\030\002 \001(\010\"\\\n\t" - "QueryInfo\022\014\n\004topk\030\001 \001(\003\022\023\n\013metric_type\030\003" - " \001(\t\022\025\n\rsearch_params\030\004 \001(\t\022\025\n\rround_dec" - "imal\030\005 \001(\003\"\252\001\n\nColumnInfo\022\020\n\010field_id\030\001 " - "\001(\003\0220\n\tdata_type\030\002 \001(\0162\035.milvus.proto.sc" - "hema.DataType\022\026\n\016is_primary_key\030\003 \001(\010\022\021\n" - "\tis_autoID\030\004 \001(\010\022\023\n\013nested_path\030\005 \003(\t\022\030\n" - "\020is_partition_key\030\006 \001(\010\"9\n\nColumnExpr\022+\n" - "\004info\030\001 \001(\0132\035.milvus.proto.plan.ColumnIn" - "fo\"9\n\nExistsExpr\022+\n\004info\030\001 \001(\0132\035.milvus." - "proto.plan.ColumnInfo\";\n\tValueExpr\022.\n\005va" - "lue\030\001 \001(\0132\037.milvus.proto.plan.GenericVal" - "ue\"\233\001\n\016UnaryRangeExpr\0222\n\013column_info\030\001 \001" - "(\0132\035.milvus.proto.plan.ColumnInfo\022%\n\002op\030" - "\002 \001(\0162\031.milvus.proto.plan.OpType\022.\n\005valu" - "e\030\003 \001(\0132\037.milvus.proto.plan.GenericValue" - "\"\343\001\n\017BinaryRangeExpr\0222\n\013column_info\030\001 \001(" - "\0132\035.milvus.proto.plan.ColumnInfo\022\027\n\017lowe" - "r_inclusive\030\002 \001(\010\022\027\n\017upper_inclusive\030\003 \001" - "(\010\0224\n\013lower_value\030\004 \001(\0132\037.milvus.proto.p" - "lan.GenericValue\0224\n\013upper_value\030\005 \001(\0132\037." 
- "milvus.proto.plan.GenericValue\"\247\001\n\013Compa" - "reExpr\0227\n\020left_column_info\030\001 \001(\0132\035.milvu" - "s.proto.plan.ColumnInfo\0228\n\021right_column_" - "info\030\002 \001(\0132\035.milvus.proto.plan.ColumnInf" - "o\022%\n\002op\030\003 \001(\0162\031.milvus.proto.plan.OpType" - "\"\204\001\n\010TermExpr\0222\n\013column_info\030\001 \001(\0132\035.mil" - "vus.proto.plan.ColumnInfo\022/\n\006values\030\002 \003(" - "\0132\037.milvus.proto.plan.GenericValue\022\023\n\013is" - "_in_field\030\003 \001(\010\"\224\002\n\020JSONContainsExpr\0222\n\013" - "column_info\030\001 \001(\0132\035.milvus.proto.plan.Co" - "lumnInfo\0221\n\010elements\030\002 \003(\0132\037.milvus.prot" - "o.plan.GenericValue\0226\n\002op\030\003 \001(\0162*.milvus" - ".proto.plan.JSONContainsExpr.JSONOp\022\032\n\022e" - "lements_same_type\030\004 \001(\010\"E\n\006JSONOp\022\013\n\007Inv" - "alid\020\000\022\014\n\010Contains\020\001\022\017\n\013ContainsAll\020\002\022\017\n" - "\013ContainsAny\020\003\"\206\001\n\tUnaryExpr\0220\n\002op\030\001 \001(\016" - "2$.milvus.proto.plan.UnaryExpr.UnaryOp\022&" - "\n\005child\030\002 \001(\0132\027.milvus.proto.plan.Expr\"\037" - "\n\007UnaryOp\022\013\n\007Invalid\020\000\022\007\n\003Not\020\001\"\307\001\n\nBina" - "ryExpr\0222\n\002op\030\001 \001(\0162&.milvus.proto.plan.B" - "inaryExpr.BinaryOp\022%\n\004left\030\002 \001(\0132\027.milvu" - "s.proto.plan.Expr\022&\n\005right\030\003 \001(\0132\027.milvu" - "s.proto.plan.Expr\"6\n\010BinaryOp\022\013\n\007Invalid" - "\020\000\022\016\n\nLogicalAnd\020\001\022\r\n\tLogicalOr\020\002\"\255\001\n\rBi" - "naryArithOp\0222\n\013column_info\030\001 \001(\0132\035.milvu" - "s.proto.plan.ColumnInfo\0220\n\010arith_op\030\002 \001(" - "\0162\036.milvus.proto.plan.ArithOpType\0226\n\rrig" - "ht_operand\030\003 \001(\0132\037.milvus.proto.plan.Gen" - "ericValue\"\214\001\n\017BinaryArithExpr\022%\n\004left\030\001 " - "\001(\0132\027.milvus.proto.plan.Expr\022&\n\005right\030\002 " - 
"\001(\0132\027.milvus.proto.plan.Expr\022*\n\002op\030\003 \001(\016" - "2\036.milvus.proto.plan.ArithOpType\"\221\002\n\032Bin" - "aryArithOpEvalRangeExpr\0222\n\013column_info\030\001" - " \001(\0132\035.milvus.proto.plan.ColumnInfo\0220\n\010a" - "rith_op\030\002 \001(\0162\036.milvus.proto.plan.ArithO" - "pType\0226\n\rright_operand\030\003 \001(\0132\037.milvus.pr" - "oto.plan.GenericValue\022%\n\002op\030\004 \001(\0162\031.milv" - "us.proto.plan.OpType\022.\n\005value\030\005 \001(\0132\037.mi" - "lvus.proto.plan.GenericValue\"\020\n\016AlwaysTr" - "ueExpr\"\237\006\n\004Expr\0220\n\tterm_expr\030\001 \001(\0132\033.mil" - "vus.proto.plan.TermExprH\000\0222\n\nunary_expr\030" - "\002 \001(\0132\034.milvus.proto.plan.UnaryExprH\000\0224\n" - "\013binary_expr\030\003 \001(\0132\035.milvus.proto.plan.B" - "inaryExprH\000\0226\n\014compare_expr\030\004 \001(\0132\036.milv" - "us.proto.plan.CompareExprH\000\022=\n\020unary_ran" - "ge_expr\030\005 \001(\0132!.milvus.proto.plan.UnaryR" - "angeExprH\000\022\?\n\021binary_range_expr\030\006 \001(\0132\"." - "milvus.proto.plan.BinaryRangeExprH\000\022X\n\037b" - "inary_arith_op_eval_range_expr\030\007 \001(\0132-.m" - "ilvus.proto.plan.BinaryArithOpEvalRangeE" - "xprH\000\022\?\n\021binary_arith_expr\030\010 \001(\0132\".milvu" - "s.proto.plan.BinaryArithExprH\000\0222\n\nvalue_" - "expr\030\t \001(\0132\034.milvus.proto.plan.ValueExpr" - "H\000\0224\n\013column_expr\030\n \001(\0132\035.milvus.proto.p" - "lan.ColumnExprH\000\0224\n\013exists_expr\030\013 \001(\0132\035." 
- "milvus.proto.plan.ExistsExprH\000\022=\n\020always" - "_true_expr\030\014 \001(\0132!.milvus.proto.plan.Alw" - "aysTrueExprH\000\022A\n\022json_contains_expr\030\r \001(" - "\0132#.milvus.proto.plan.JSONContainsExprH\000" - "B\006\n\004expr\"\251\001\n\nVectorANNS\022\021\n\tis_binary\030\001 \001" - "(\010\022\020\n\010field_id\030\002 \001(\003\022+\n\npredicates\030\003 \001(\013" - "2\027.milvus.proto.plan.Expr\0220\n\nquery_info\030" - "\004 \001(\0132\034.milvus.proto.plan.QueryInfo\022\027\n\017p" - "laceholder_tag\030\005 \001(\t\"]\n\rQueryPlanNode\022+\n" - "\npredicates\030\001 \001(\0132\027.milvus.proto.plan.Ex" - "pr\022\020\n\010is_count\030\002 \001(\010\022\r\n\005limit\030\003 \001(\003\"\304\001\n\010" - "PlanNode\0224\n\013vector_anns\030\001 \001(\0132\035.milvus.p" - "roto.plan.VectorANNSH\000\022-\n\npredicates\030\002 \001" - "(\0132\027.milvus.proto.plan.ExprH\000\0221\n\005query\030\004" - " \001(\0132 .milvus.proto.plan.QueryPlanNodeH\000" - "\022\030\n\020output_field_ids\030\003 \003(\003B\006\n\004node*\272\001\n\006O" - "pType\022\013\n\007Invalid\020\000\022\017\n\013GreaterThan\020\001\022\020\n\014G" - "reaterEqual\020\002\022\014\n\010LessThan\020\003\022\r\n\tLessEqual" - "\020\004\022\t\n\005Equal\020\005\022\014\n\010NotEqual\020\006\022\017\n\013PrefixMat" - "ch\020\007\022\020\n\014PostfixMatch\020\010\022\t\n\005Match\020\t\022\t\n\005Ran" - "ge\020\n\022\006\n\002In\020\013\022\t\n\005NotIn\020\014*G\n\013ArithOpType\022\013" - "\n\007Unknown\020\000\022\007\n\003Add\020\001\022\007\n\003Sub\020\002\022\007\n\003Mul\020\003\022\007" - "\n\003Div\020\004\022\007\n\003Mod\020\005B3Z1github.com/milvus-io" - "/milvus/internal/proto/planpbb\006proto3" + "\"\177\n\005Array\022.\n\005array\030\001 \003(\0132\037.milvus.proto." 
+ "plan.GenericValue\022\021\n\tsame_type\030\002 \001(\010\0223\n\014" + "element_type\030\003 \001(\0162\035.milvus.proto.schema" + ".DataType\"\\\n\tQueryInfo\022\014\n\004topk\030\001 \001(\003\022\023\n\013" + "metric_type\030\003 \001(\t\022\025\n\rsearch_params\030\004 \001(\t" + "\022\025\n\rround_decimal\030\005 \001(\003\"\337\001\n\nColumnInfo\022\020" + "\n\010field_id\030\001 \001(\003\0220\n\tdata_type\030\002 \001(\0162\035.mi" + "lvus.proto.schema.DataType\022\026\n\016is_primary" + "_key\030\003 \001(\010\022\021\n\tis_autoID\030\004 \001(\010\022\023\n\013nested_" + "path\030\005 \003(\t\022\030\n\020is_partition_key\030\006 \001(\010\0223\n\014" + "element_type\030\007 \001(\0162\035.milvus.proto.schema" + ".DataType\"9\n\nColumnExpr\022+\n\004info\030\001 \001(\0132\035." + "milvus.proto.plan.ColumnInfo\"9\n\nExistsEx" + "pr\022+\n\004info\030\001 \001(\0132\035.milvus.proto.plan.Col" + "umnInfo\";\n\tValueExpr\022.\n\005value\030\001 \001(\0132\037.mi" + "lvus.proto.plan.GenericValue\"\233\001\n\016UnaryRa" + "ngeExpr\0222\n\013column_info\030\001 \001(\0132\035.milvus.pr" + "oto.plan.ColumnInfo\022%\n\002op\030\002 \001(\0162\031.milvus" + ".proto.plan.OpType\022.\n\005value\030\003 \001(\0132\037.milv" + "us.proto.plan.GenericValue\"\343\001\n\017BinaryRan" + "geExpr\0222\n\013column_info\030\001 \001(\0132\035.milvus.pro" + "to.plan.ColumnInfo\022\027\n\017lower_inclusive\030\002 " + "\001(\010\022\027\n\017upper_inclusive\030\003 \001(\010\0224\n\013lower_va" + "lue\030\004 \001(\0132\037.milvus.proto.plan.GenericVal" + "ue\0224\n\013upper_value\030\005 \001(\0132\037.milvus.proto.p" + "lan.GenericValue\"\247\001\n\013CompareExpr\0227\n\020left" + "_column_info\030\001 \001(\0132\035.milvus.proto.plan.C" + "olumnInfo\0228\n\021right_column_info\030\002 \001(\0132\035.m" + "ilvus.proto.plan.ColumnInfo\022%\n\002op\030\003 \001(\0162" + "\031.milvus.proto.plan.OpType\"\204\001\n\010TermExpr\022" + "2\n\013column_info\030\001 
\001(\0132\035.milvus.proto.plan" + ".ColumnInfo\022/\n\006values\030\002 \003(\0132\037.milvus.pro" + "to.plan.GenericValue\022\023\n\013is_in_field\030\003 \001(" + "\010\"\224\002\n\020JSONContainsExpr\0222\n\013column_info\030\001 " + "\001(\0132\035.milvus.proto.plan.ColumnInfo\0221\n\010el" + "ements\030\002 \003(\0132\037.milvus.proto.plan.Generic" + "Value\0226\n\002op\030\003 \001(\0162*.milvus.proto.plan.JS" + "ONContainsExpr.JSONOp\022\032\n\022elements_same_t" + "ype\030\004 \001(\010\"E\n\006JSONOp\022\013\n\007Invalid\020\000\022\014\n\010Cont" + "ains\020\001\022\017\n\013ContainsAll\020\002\022\017\n\013ContainsAny\020\003" + "\"\206\001\n\tUnaryExpr\0220\n\002op\030\001 \001(\0162$.milvus.prot" + "o.plan.UnaryExpr.UnaryOp\022&\n\005child\030\002 \001(\0132" + "\027.milvus.proto.plan.Expr\"\037\n\007UnaryOp\022\013\n\007I" + "nvalid\020\000\022\007\n\003Not\020\001\"\307\001\n\nBinaryExpr\0222\n\002op\030\001" + " \001(\0162&.milvus.proto.plan.BinaryExpr.Bina" + "ryOp\022%\n\004left\030\002 \001(\0132\027.milvus.proto.plan.E" + "xpr\022&\n\005right\030\003 \001(\0132\027.milvus.proto.plan.E" + "xpr\"6\n\010BinaryOp\022\013\n\007Invalid\020\000\022\016\n\nLogicalA" + "nd\020\001\022\r\n\tLogicalOr\020\002\"\255\001\n\rBinaryArithOp\0222\n" + "\013column_info\030\001 \001(\0132\035.milvus.proto.plan.C" + "olumnInfo\0220\n\010arith_op\030\002 \001(\0162\036.milvus.pro" + "to.plan.ArithOpType\0226\n\rright_operand\030\003 \001" + "(\0132\037.milvus.proto.plan.GenericValue\"\214\001\n\017" + "BinaryArithExpr\022%\n\004left\030\001 \001(\0132\027.milvus.p" + "roto.plan.Expr\022&\n\005right\030\002 \001(\0132\027.milvus.p" + "roto.plan.Expr\022*\n\002op\030\003 \001(\0162\036.milvus.prot" + "o.plan.ArithOpType\"\221\002\n\032BinaryArithOpEval" + "RangeExpr\0222\n\013column_info\030\001 \001(\0132\035.milvus." 
+ "proto.plan.ColumnInfo\0220\n\010arith_op\030\002 \001(\0162" + "\036.milvus.proto.plan.ArithOpType\0226\n\rright" + "_operand\030\003 \001(\0132\037.milvus.proto.plan.Gener" + "icValue\022%\n\002op\030\004 \001(\0162\031.milvus.proto.plan." + "OpType\022.\n\005value\030\005 \001(\0132\037.milvus.proto.pla" + "n.GenericValue\"\020\n\016AlwaysTrueExpr\"\237\006\n\004Exp" + "r\0220\n\tterm_expr\030\001 \001(\0132\033.milvus.proto.plan" + ".TermExprH\000\0222\n\nunary_expr\030\002 \001(\0132\034.milvus" + ".proto.plan.UnaryExprH\000\0224\n\013binary_expr\030\003" + " \001(\0132\035.milvus.proto.plan.BinaryExprH\000\0226\n" + "\014compare_expr\030\004 \001(\0132\036.milvus.proto.plan." + "CompareExprH\000\022=\n\020unary_range_expr\030\005 \001(\0132" + "!.milvus.proto.plan.UnaryRangeExprH\000\022\?\n\021" + "binary_range_expr\030\006 \001(\0132\".milvus.proto.p" + "lan.BinaryRangeExprH\000\022X\n\037binary_arith_op" + "_eval_range_expr\030\007 \001(\0132-.milvus.proto.pl" + "an.BinaryArithOpEvalRangeExprH\000\022\?\n\021binar" + "y_arith_expr\030\010 \001(\0132\".milvus.proto.plan.B" + "inaryArithExprH\000\0222\n\nvalue_expr\030\t \001(\0132\034.m" + "ilvus.proto.plan.ValueExprH\000\0224\n\013column_e" + "xpr\030\n \001(\0132\035.milvus.proto.plan.ColumnExpr" + "H\000\0224\n\013exists_expr\030\013 \001(\0132\035.milvus.proto.p" + "lan.ExistsExprH\000\022=\n\020always_true_expr\030\014 \001" + "(\0132!.milvus.proto.plan.AlwaysTrueExprH\000\022" + "A\n\022json_contains_expr\030\r \001(\0132#.milvus.pro" + "to.plan.JSONContainsExprH\000B\006\n\004expr\"\312\001\n\nV" + "ectorANNS\0222\n\013vector_type\030\001 \001(\0162\035.milvus." + "proto.plan.VectorType\022\020\n\010field_id\030\002 \001(\003\022" + "+\n\npredicates\030\003 \001(\0132\027.milvus.proto.plan." + "Expr\0220\n\nquery_info\030\004 \001(\0132\034.milvus.proto." 
+ "plan.QueryInfo\022\027\n\017placeholder_tag\030\005 \001(\t\"" + "]\n\rQueryPlanNode\022+\n\npredicates\030\001 \001(\0132\027.m" + "ilvus.proto.plan.Expr\022\020\n\010is_count\030\002 \001(\010\022" + "\r\n\005limit\030\003 \001(\003\"\304\001\n\010PlanNode\0224\n\013vector_an" + "ns\030\001 \001(\0132\035.milvus.proto.plan.VectorANNSH" + "\000\022-\n\npredicates\030\002 \001(\0132\027.milvus.proto.pla" + "n.ExprH\000\0221\n\005query\030\004 \001(\0132 .milvus.proto.p" + "lan.QueryPlanNodeH\000\022\030\n\020output_field_ids\030" + "\003 \003(\003B\006\n\004node*\272\001\n\006OpType\022\013\n\007Invalid\020\000\022\017\n" + "\013GreaterThan\020\001\022\020\n\014GreaterEqual\020\002\022\014\n\010Less" + "Than\020\003\022\r\n\tLessEqual\020\004\022\t\n\005Equal\020\005\022\014\n\010NotE" + "qual\020\006\022\017\n\013PrefixMatch\020\007\022\020\n\014PostfixMatch\020" + "\010\022\t\n\005Match\020\t\022\t\n\005Range\020\n\022\006\n\002In\020\013\022\t\n\005NotIn" + "\020\014*X\n\013ArithOpType\022\013\n\007Unknown\020\000\022\007\n\003Add\020\001\022" + "\007\n\003Sub\020\002\022\007\n\003Mul\020\003\022\007\n\003Div\020\004\022\007\n\003Mod\020\005\022\017\n\013A" + "rrayLength\020\006*B\n\nVectorType\022\020\n\014BinaryVect" + "or\020\000\022\017\n\013FloatVector\020\001\022\021\n\rFloat16Vector\020\002" + "B3Z1github.com/milvus-io/milvus/internal" + "/proto/planpbb\006proto3" ; static const ::_pbi::DescriptorTable* const descriptor_table_plan_2eproto_deps[1] = { &::descriptor_table_schema_2eproto, }; static ::_pbi::once_flag descriptor_table_plan_2eproto_once; const ::_pbi::DescriptorTable descriptor_table_plan_2eproto = { - false, false, 4237, descriptor_table_protodef_plan_2eproto, + false, false, 4461, descriptor_table_protodef_plan_2eproto, "plan.proto", &descriptor_table_plan_2eproto_once, descriptor_table_plan_2eproto_deps, 1, 22, schemas, file_default_instances, TableStruct_plan_2eproto::offsets, @@ -859,6 +869,22 @@ bool ArithOpType_IsValid(int value) { 
case 3: case 4: case 5: + case 6: + return true; + default: + return false; + } +} + +const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* VectorType_descriptor() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&descriptor_table_plan_2eproto); + return file_level_enum_descriptors_plan_2eproto[5]; +} +bool VectorType_IsValid(int value) { + switch (value) { + case 0: + case 1: + case 2: return true; default: return false; @@ -1263,10 +1289,13 @@ Array::Array(const Array& from) new (&_impl_) Impl_{ decltype(_impl_.array_){from._impl_.array_} , decltype(_impl_.same_type_){} + , decltype(_impl_.element_type_){} , /*decltype(_impl_._cached_size_)*/{}}; _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); - _this->_impl_.same_type_ = from._impl_.same_type_; + ::memcpy(&_impl_.same_type_, &from._impl_.same_type_, + static_cast(reinterpret_cast(&_impl_.element_type_) - + reinterpret_cast(&_impl_.same_type_)) + sizeof(_impl_.element_type_)); // @@protoc_insertion_point(copy_constructor:milvus.proto.plan.Array) } @@ -1277,6 +1306,7 @@ inline void Array::SharedCtor( new (&_impl_) Impl_{ decltype(_impl_.array_){arena} , decltype(_impl_.same_type_){false} + , decltype(_impl_.element_type_){0} , /*decltype(_impl_._cached_size_)*/{} }; } @@ -1306,7 +1336,9 @@ void Array::Clear() { (void) cached_has_bits; _impl_.array_.Clear(); - _impl_.same_type_ = false; + ::memset(&_impl_.same_type_, 0, static_cast( + reinterpret_cast(&_impl_.element_type_) - + reinterpret_cast(&_impl_.same_type_)) + sizeof(_impl_.element_type_)); _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); } @@ -1337,6 +1369,15 @@ const char* Array::_InternalParse(const char* ptr, ::_pbi::ParseContext* ctx) { } else goto handle_unusual; continue; + // .milvus.proto.schema.DataType element_type = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast(tag) == 24)) { + uint64_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + 
_internal_set_element_type(static_cast<::milvus::proto::schema::DataType>(val)); + } else + goto handle_unusual; + continue; default: goto handle_unusual; } // switch @@ -1380,6 +1421,13 @@ uint8_t* Array::_InternalSerialize( target = ::_pbi::WireFormatLite::WriteBoolToArray(2, this->_internal_same_type(), target); } + // .milvus.proto.schema.DataType element_type = 3; + if (this->_internal_element_type() != 0) { + target = stream->EnsureSpace(target); + target = ::_pbi::WireFormatLite::WriteEnumToArray( + 3, this->_internal_element_type(), target); + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { target = ::_pbi::WireFormat::InternalSerializeUnknownFieldsToArray( _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); @@ -1408,6 +1456,12 @@ size_t Array::ByteSizeLong() const { total_size += 1 + 1; } + // .milvus.proto.schema.DataType element_type = 3; + if (this->_internal_element_type() != 0) { + total_size += 1 + + ::_pbi::WireFormatLite::EnumSize(this->_internal_element_type()); + } + return MaybeComputeUnknownFieldsSize(total_size, &_impl_._cached_size_); } @@ -1430,6 +1484,9 @@ void Array::MergeImpl(::PROTOBUF_NAMESPACE_ID::Message& to_msg, const ::PROTOBUF if (from._internal_same_type() != 0) { _this->_internal_set_same_type(from._internal_same_type()); } + if (from._internal_element_type() != 0) { + _this->_internal_set_element_type(from._internal_element_type()); + } _this->_internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); } @@ -1448,7 +1505,12 @@ void Array::InternalSwap(Array* other) { using std::swap; _internal_metadata_.InternalSwap(&other->_internal_metadata_); _impl_.array_.InternalSwap(&other->_impl_.array_); - swap(_impl_.same_type_, other->_impl_.same_type_); + ::PROTOBUF_NAMESPACE_ID::internal::memswap< + PROTOBUF_FIELD_OFFSET(Array, _impl_.element_type_) + + 
sizeof(Array::_impl_.element_type_) + - PROTOBUF_FIELD_OFFSET(Array, _impl_.same_type_)>( + reinterpret_cast(&_impl_.same_type_), + reinterpret_cast(&other->_impl_.same_type_)); } ::PROTOBUF_NAMESPACE_ID::Metadata Array::GetMetadata() const { @@ -1792,12 +1854,13 @@ ColumnInfo::ColumnInfo(const ColumnInfo& from) , decltype(_impl_.is_primary_key_){} , decltype(_impl_.is_autoid_){} , decltype(_impl_.is_partition_key_){} + , decltype(_impl_.element_type_){} , /*decltype(_impl_._cached_size_)*/{}}; _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); ::memcpy(&_impl_.field_id_, &from._impl_.field_id_, - static_cast(reinterpret_cast(&_impl_.is_partition_key_) - - reinterpret_cast(&_impl_.field_id_)) + sizeof(_impl_.is_partition_key_)); + static_cast(reinterpret_cast(&_impl_.element_type_) - + reinterpret_cast(&_impl_.field_id_)) + sizeof(_impl_.element_type_)); // @@protoc_insertion_point(copy_constructor:milvus.proto.plan.ColumnInfo) } @@ -1812,6 +1875,7 @@ inline void ColumnInfo::SharedCtor( , decltype(_impl_.is_primary_key_){false} , decltype(_impl_.is_autoid_){false} , decltype(_impl_.is_partition_key_){false} + , decltype(_impl_.element_type_){0} , /*decltype(_impl_._cached_size_)*/{} }; } @@ -1842,8 +1906,8 @@ void ColumnInfo::Clear() { _impl_.nested_path_.Clear(); ::memset(&_impl_.field_id_, 0, static_cast( - reinterpret_cast(&_impl_.is_partition_key_) - - reinterpret_cast(&_impl_.field_id_)) + sizeof(_impl_.is_partition_key_)); + reinterpret_cast(&_impl_.element_type_) - + reinterpret_cast(&_impl_.field_id_)) + sizeof(_impl_.element_type_)); _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); } @@ -1909,6 +1973,15 @@ const char* ColumnInfo::_InternalParse(const char* ptr, ::_pbi::ParseContext* ct } else goto handle_unusual; continue; + // .milvus.proto.schema.DataType element_type = 7; + case 7: + if (PROTOBUF_PREDICT_TRUE(static_cast(tag) == 56)) { + uint64_t val = 
::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + _internal_set_element_type(static_cast<::milvus::proto::schema::DataType>(val)); + } else + goto handle_unusual; + continue; default: goto handle_unusual; } // switch @@ -1979,6 +2052,13 @@ uint8_t* ColumnInfo::_InternalSerialize( target = ::_pbi::WireFormatLite::WriteBoolToArray(6, this->_internal_is_partition_key(), target); } + // .milvus.proto.schema.DataType element_type = 7; + if (this->_internal_element_type() != 0) { + target = stream->EnsureSpace(target); + target = ::_pbi::WireFormatLite::WriteEnumToArray( + 7, this->_internal_element_type(), target); + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { target = ::_pbi::WireFormat::InternalSerializeUnknownFieldsToArray( _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); @@ -2029,6 +2109,12 @@ size_t ColumnInfo::ByteSizeLong() const { total_size += 1 + 1; } + // .milvus.proto.schema.DataType element_type = 7; + if (this->_internal_element_type() != 0) { + total_size += 1 + + ::_pbi::WireFormatLite::EnumSize(this->_internal_element_type()); + } + return MaybeComputeUnknownFieldsSize(total_size, &_impl_._cached_size_); } @@ -2063,6 +2149,9 @@ void ColumnInfo::MergeImpl(::PROTOBUF_NAMESPACE_ID::Message& to_msg, const ::PRO if (from._internal_is_partition_key() != 0) { _this->_internal_set_is_partition_key(from._internal_is_partition_key()); } + if (from._internal_element_type() != 0) { + _this->_internal_set_element_type(from._internal_element_type()); + } _this->_internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); } @@ -2082,8 +2171,8 @@ void ColumnInfo::InternalSwap(ColumnInfo* other) { _internal_metadata_.InternalSwap(&other->_internal_metadata_); _impl_.nested_path_.InternalSwap(&other->_impl_.nested_path_); ::PROTOBUF_NAMESPACE_ID::internal::memswap< - 
PROTOBUF_FIELD_OFFSET(ColumnInfo, _impl_.is_partition_key_) - + sizeof(ColumnInfo::_impl_.is_partition_key_) + PROTOBUF_FIELD_OFFSET(ColumnInfo, _impl_.element_type_) + + sizeof(ColumnInfo::_impl_.element_type_) - PROTOBUF_FIELD_OFFSET(ColumnInfo, _impl_.field_id_)>( reinterpret_cast(&_impl_.field_id_), reinterpret_cast(&other->_impl_.field_id_)); @@ -6475,7 +6564,7 @@ VectorANNS::VectorANNS(const VectorANNS& from) , decltype(_impl_.predicates_){nullptr} , decltype(_impl_.query_info_){nullptr} , decltype(_impl_.field_id_){} - , decltype(_impl_.is_binary_){} + , decltype(_impl_.vector_type_){} , /*decltype(_impl_._cached_size_)*/{}}; _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); @@ -6494,8 +6583,8 @@ VectorANNS::VectorANNS(const VectorANNS& from) _this->_impl_.query_info_ = new ::milvus::proto::plan::QueryInfo(*from._impl_.query_info_); } ::memcpy(&_impl_.field_id_, &from._impl_.field_id_, - static_cast(reinterpret_cast(&_impl_.is_binary_) - - reinterpret_cast(&_impl_.field_id_)) + sizeof(_impl_.is_binary_)); + static_cast(reinterpret_cast(&_impl_.vector_type_) - + reinterpret_cast(&_impl_.field_id_)) + sizeof(_impl_.vector_type_)); // @@protoc_insertion_point(copy_constructor:milvus.proto.plan.VectorANNS) } @@ -6508,7 +6597,7 @@ inline void VectorANNS::SharedCtor( , decltype(_impl_.predicates_){nullptr} , decltype(_impl_.query_info_){nullptr} , decltype(_impl_.field_id_){int64_t{0}} - , decltype(_impl_.is_binary_){false} + , decltype(_impl_.vector_type_){0} , /*decltype(_impl_._cached_size_)*/{} }; _impl_.placeholder_tag_.InitDefault(); @@ -6553,8 +6642,8 @@ void VectorANNS::Clear() { } _impl_.query_info_ = nullptr; ::memset(&_impl_.field_id_, 0, static_cast( - reinterpret_cast(&_impl_.is_binary_) - - reinterpret_cast(&_impl_.field_id_)) + sizeof(_impl_.is_binary_)); + reinterpret_cast(&_impl_.vector_type_) - + reinterpret_cast(&_impl_.field_id_)) + sizeof(_impl_.vector_type_)); 
_internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); } @@ -6564,11 +6653,12 @@ const char* VectorANNS::_InternalParse(const char* ptr, ::_pbi::ParseContext* ct uint32_t tag; ptr = ::_pbi::ReadTag(ptr, &tag); switch (tag >> 3) { - // bool is_binary = 1; + // .milvus.proto.plan.VectorType vector_type = 1; case 1: if (PROTOBUF_PREDICT_TRUE(static_cast(tag) == 8)) { - _impl_.is_binary_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + uint64_t val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); CHK_(ptr); + _internal_set_vector_type(static_cast<::milvus::proto::plan::VectorType>(val)); } else goto handle_unusual; continue; @@ -6635,10 +6725,11 @@ uint8_t* VectorANNS::_InternalSerialize( uint32_t cached_has_bits = 0; (void) cached_has_bits; - // bool is_binary = 1; - if (this->_internal_is_binary() != 0) { + // .milvus.proto.plan.VectorType vector_type = 1; + if (this->_internal_vector_type() != 0) { target = stream->EnsureSpace(target); - target = ::_pbi::WireFormatLite::WriteBoolToArray(1, this->_internal_is_binary(), target); + target = ::_pbi::WireFormatLite::WriteEnumToArray( + 1, this->_internal_vector_type(), target); } // int64 field_id = 2; @@ -6713,9 +6804,10 @@ size_t VectorANNS::ByteSizeLong() const { total_size += ::_pbi::WireFormatLite::Int64SizePlusOne(this->_internal_field_id()); } - // bool is_binary = 1; - if (this->_internal_is_binary() != 0) { - total_size += 1 + 1; + // .milvus.proto.plan.VectorType vector_type = 1; + if (this->_internal_vector_type() != 0) { + total_size += 1 + + ::_pbi::WireFormatLite::EnumSize(this->_internal_vector_type()); } return MaybeComputeUnknownFieldsSize(total_size, &_impl_._cached_size_); @@ -6750,8 +6842,8 @@ void VectorANNS::MergeImpl(::PROTOBUF_NAMESPACE_ID::Message& to_msg, const ::PRO if (from._internal_field_id() != 0) { _this->_internal_set_field_id(from._internal_field_id()); } - if (from._internal_is_binary() != 0) { - 
_this->_internal_set_is_binary(from._internal_is_binary()); + if (from._internal_vector_type() != 0) { + _this->_internal_set_vector_type(from._internal_vector_type()); } _this->_internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); } @@ -6777,8 +6869,8 @@ void VectorANNS::InternalSwap(VectorANNS* other) { &other->_impl_.placeholder_tag_, rhs_arena ); ::PROTOBUF_NAMESPACE_ID::internal::memswap< - PROTOBUF_FIELD_OFFSET(VectorANNS, _impl_.is_binary_) - + sizeof(VectorANNS::_impl_.is_binary_) + PROTOBUF_FIELD_OFFSET(VectorANNS, _impl_.vector_type_) + + sizeof(VectorANNS::_impl_.vector_type_) - PROTOBUF_FIELD_OFFSET(VectorANNS, _impl_.predicates_)>( reinterpret_cast(&_impl_.predicates_), reinterpret_cast(&other->_impl_.predicates_)); diff --git a/internal/core/src/pb/plan.pb.h b/internal/core/src/pb/plan.pb.h index 4f1f181f8d2e2..70e1f3e6ce00f 100644 --- a/internal/core/src/pb/plan.pb.h +++ b/internal/core/src/pb/plan.pb.h @@ -268,12 +268,13 @@ enum ArithOpType : int { Mul = 3, Div = 4, Mod = 5, + ArrayLength = 6, ArithOpType_INT_MIN_SENTINEL_DO_NOT_USE_ = std::numeric_limits::min(), ArithOpType_INT_MAX_SENTINEL_DO_NOT_USE_ = std::numeric_limits::max() }; bool ArithOpType_IsValid(int value); constexpr ArithOpType ArithOpType_MIN = Unknown; -constexpr ArithOpType ArithOpType_MAX = Mod; +constexpr ArithOpType ArithOpType_MAX = ArrayLength; constexpr int ArithOpType_ARRAYSIZE = ArithOpType_MAX + 1; const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* ArithOpType_descriptor(); @@ -290,6 +291,32 @@ inline bool ArithOpType_Parse( return ::PROTOBUF_NAMESPACE_ID::internal::ParseNamedEnum( ArithOpType_descriptor(), name, value); } +enum VectorType : int { + BinaryVector = 0, + FloatVector = 1, + Float16Vector = 2, + VectorType_INT_MIN_SENTINEL_DO_NOT_USE_ = std::numeric_limits::min(), + VectorType_INT_MAX_SENTINEL_DO_NOT_USE_ = std::numeric_limits::max() +}; +bool VectorType_IsValid(int value); +constexpr VectorType VectorType_MIN = 
BinaryVector; +constexpr VectorType VectorType_MAX = Float16Vector; +constexpr int VectorType_ARRAYSIZE = VectorType_MAX + 1; + +const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* VectorType_descriptor(); +template +inline const std::string& VectorType_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function VectorType_Name."); + return ::PROTOBUF_NAMESPACE_ID::internal::NameOfEnum( + VectorType_descriptor(), enum_t_value); +} +inline bool VectorType_Parse( + ::PROTOBUF_NAMESPACE_ID::ConstStringParam name, VectorType* value) { + return ::PROTOBUF_NAMESPACE_ID::internal::ParseNamedEnum( + VectorType_descriptor(), name, value); +} // =================================================================== class GenericValue final : @@ -662,6 +689,7 @@ class Array final : enum : int { kArrayFieldNumber = 1, kSameTypeFieldNumber = 2, + kElementTypeFieldNumber = 3, }; // repeated .milvus.proto.plan.GenericValue array = 1; int array_size() const; @@ -690,6 +718,15 @@ class Array final : void _internal_set_same_type(bool value); public: + // .milvus.proto.schema.DataType element_type = 3; + void clear_element_type(); + ::milvus::proto::schema::DataType element_type() const; + void set_element_type(::milvus::proto::schema::DataType value); + private: + ::milvus::proto::schema::DataType _internal_element_type() const; + void _internal_set_element_type(::milvus::proto::schema::DataType value); + public: + // @@protoc_insertion_point(class_scope:milvus.proto.plan.Array) private: class _Internal; @@ -700,6 +737,7 @@ class Array final : struct Impl_ { ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::milvus::proto::plan::GenericValue > array_; bool same_type_; + int element_type_; mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; }; union { Impl_ _impl_; }; @@ -1025,6 +1063,7 @@ class ColumnInfo final : kIsPrimaryKeyFieldNumber = 3, kIsAutoIDFieldNumber = 4, kIsPartitionKeyFieldNumber = 6, + 
kElementTypeFieldNumber = 7, }; // repeated string nested_path = 5; int nested_path_size() const; @@ -1095,6 +1134,15 @@ class ColumnInfo final : void _internal_set_is_partition_key(bool value); public: + // .milvus.proto.schema.DataType element_type = 7; + void clear_element_type(); + ::milvus::proto::schema::DataType element_type() const; + void set_element_type(::milvus::proto::schema::DataType value); + private: + ::milvus::proto::schema::DataType _internal_element_type() const; + void _internal_set_element_type(::milvus::proto::schema::DataType value); + public: + // @@protoc_insertion_point(class_scope:milvus.proto.plan.ColumnInfo) private: class _Internal; @@ -1109,6 +1157,7 @@ class ColumnInfo final : bool is_primary_key_; bool is_autoid_; bool is_partition_key_; + int element_type_; mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; }; union { Impl_ _impl_; }; @@ -4297,7 +4346,7 @@ class VectorANNS final : kPredicatesFieldNumber = 3, kQueryInfoFieldNumber = 4, kFieldIdFieldNumber = 2, - kIsBinaryFieldNumber = 1, + kVectorTypeFieldNumber = 1, }; // string placeholder_tag = 5; void clear_placeholder_tag(); @@ -4358,13 +4407,13 @@ class VectorANNS final : void _internal_set_field_id(int64_t value); public: - // bool is_binary = 1; - void clear_is_binary(); - bool is_binary() const; - void set_is_binary(bool value); + // .milvus.proto.plan.VectorType vector_type = 1; + void clear_vector_type(); + ::milvus::proto::plan::VectorType vector_type() const; + void set_vector_type(::milvus::proto::plan::VectorType value); private: - bool _internal_is_binary() const; - void _internal_set_is_binary(bool value); + ::milvus::proto::plan::VectorType _internal_vector_type() const; + void _internal_set_vector_type(::milvus::proto::plan::VectorType value); public: // @@protoc_insertion_point(class_scope:milvus.proto.plan.VectorANNS) @@ -4379,7 +4428,7 @@ class VectorANNS final : ::milvus::proto::plan::Expr* predicates_; ::milvus::proto::plan::QueryInfo* 
query_info_; int64_t field_id_; - bool is_binary_; + int vector_type_; mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; }; union { Impl_ _impl_; }; @@ -5156,6 +5205,26 @@ inline void Array::set_same_type(bool value) { // @@protoc_insertion_point(field_set:milvus.proto.plan.Array.same_type) } +// .milvus.proto.schema.DataType element_type = 3; +inline void Array::clear_element_type() { + _impl_.element_type_ = 0; +} +inline ::milvus::proto::schema::DataType Array::_internal_element_type() const { + return static_cast< ::milvus::proto::schema::DataType >(_impl_.element_type_); +} +inline ::milvus::proto::schema::DataType Array::element_type() const { + // @@protoc_insertion_point(field_get:milvus.proto.plan.Array.element_type) + return _internal_element_type(); +} +inline void Array::_internal_set_element_type(::milvus::proto::schema::DataType value) { + + _impl_.element_type_ = value; +} +inline void Array::set_element_type(::milvus::proto::schema::DataType value) { + _internal_set_element_type(value); + // @@protoc_insertion_point(field_set:milvus.proto.plan.Array.element_type) +} + // ------------------------------------------------------------------- // QueryInfo @@ -5479,6 +5548,26 @@ inline void ColumnInfo::set_is_partition_key(bool value) { // @@protoc_insertion_point(field_set:milvus.proto.plan.ColumnInfo.is_partition_key) } +// .milvus.proto.schema.DataType element_type = 7; +inline void ColumnInfo::clear_element_type() { + _impl_.element_type_ = 0; +} +inline ::milvus::proto::schema::DataType ColumnInfo::_internal_element_type() const { + return static_cast< ::milvus::proto::schema::DataType >(_impl_.element_type_); +} +inline ::milvus::proto::schema::DataType ColumnInfo::element_type() const { + // @@protoc_insertion_point(field_get:milvus.proto.plan.ColumnInfo.element_type) + return _internal_element_type(); +} +inline void ColumnInfo::_internal_set_element_type(::milvus::proto::schema::DataType value) { + + _impl_.element_type_ = 
value; +} +inline void ColumnInfo::set_element_type(::milvus::proto::schema::DataType value) { + _internal_set_element_type(value); + // @@protoc_insertion_point(field_set:milvus.proto.plan.ColumnInfo.element_type) +} + // ------------------------------------------------------------------- // ColumnExpr @@ -8834,24 +8923,24 @@ inline Expr::ExprCase Expr::expr_case() const { // VectorANNS -// bool is_binary = 1; -inline void VectorANNS::clear_is_binary() { - _impl_.is_binary_ = false; +// .milvus.proto.plan.VectorType vector_type = 1; +inline void VectorANNS::clear_vector_type() { + _impl_.vector_type_ = 0; } -inline bool VectorANNS::_internal_is_binary() const { - return _impl_.is_binary_; +inline ::milvus::proto::plan::VectorType VectorANNS::_internal_vector_type() const { + return static_cast< ::milvus::proto::plan::VectorType >(_impl_.vector_type_); } -inline bool VectorANNS::is_binary() const { - // @@protoc_insertion_point(field_get:milvus.proto.plan.VectorANNS.is_binary) - return _internal_is_binary(); +inline ::milvus::proto::plan::VectorType VectorANNS::vector_type() const { + // @@protoc_insertion_point(field_get:milvus.proto.plan.VectorANNS.vector_type) + return _internal_vector_type(); } -inline void VectorANNS::_internal_set_is_binary(bool value) { +inline void VectorANNS::_internal_set_vector_type(::milvus::proto::plan::VectorType value) { - _impl_.is_binary_ = value; + _impl_.vector_type_ = value; } -inline void VectorANNS::set_is_binary(bool value) { - _internal_set_is_binary(value); - // @@protoc_insertion_point(field_set:milvus.proto.plan.VectorANNS.is_binary) +inline void VectorANNS::set_vector_type(::milvus::proto::plan::VectorType value) { + _internal_set_vector_type(value); + // @@protoc_insertion_point(field_set:milvus.proto.plan.VectorANNS.vector_type) } // int64 field_id = 2; @@ -9599,6 +9688,11 @@ template <> inline const EnumDescriptor* GetEnumDescriptor< ::milvus::proto::plan::ArithOpType>() { return 
::milvus::proto::plan::ArithOpType_descriptor(); } +template <> struct is_proto_enum< ::milvus::proto::plan::VectorType> : ::std::true_type {}; +template <> +inline const EnumDescriptor* GetEnumDescriptor< ::milvus::proto::plan::VectorType>() { + return ::milvus::proto::plan::VectorType_descriptor(); +} PROTOBUF_NAMESPACE_CLOSE diff --git a/internal/core/src/query/CMakeLists.txt b/internal/core/src/query/CMakeLists.txt index c476e476a95d0..2674cbd5204a8 100644 --- a/internal/core/src/query/CMakeLists.txt +++ b/internal/core/src/query/CMakeLists.txt @@ -10,7 +10,6 @@ # or implied. See the License for the specific language governing permissions and limitations under the License set(MILVUS_QUERY_SRCS - deprecated/BinaryQuery.cpp generated/PlanNode.cpp generated/Expr.cpp visitors/ShowPlanNodeVisitor.cpp @@ -32,6 +31,6 @@ set(MILVUS_QUERY_SRCS add_library(milvus_query ${MILVUS_QUERY_SRCS}) if(USE_DYNAMIC_SIMD) target_link_libraries(milvus_query milvus_index milvus_simd) -else() +else() target_link_libraries(milvus_query milvus_index) endif() diff --git a/internal/core/src/query/Expr.h b/internal/core/src/query/Expr.h index 839e10a3e036a..93a52f0076c3c 100644 --- a/internal/core/src/query/Expr.h +++ b/internal/core/src/query/Expr.h @@ -310,3 +310,50 @@ IsTermExpr(Expr* expr) { } } // namespace milvus::query + +template <> +struct fmt::formatter + : formatter { + auto + format(milvus::query::LogicalUnaryExpr::OpType c, + format_context& ctx) const { + string_view name = "unknown"; + switch (c) { + case milvus::query::LogicalUnaryExpr::OpType::Invalid: + name = "Invalid"; + break; + case milvus::query::LogicalUnaryExpr::OpType::LogicalNot: + name = "LogicalNot"; + break; + } + return formatter::format(name, ctx); + } +}; + +template <> +struct fmt::formatter + : formatter { + auto + format(milvus::query::LogicalBinaryExpr::OpType c, + format_context& ctx) const { + string_view name = "unknown"; + switch (c) { + case milvus::query::LogicalBinaryExpr::OpType::Invalid: + name 
= "Invalid"; + break; + case milvus::query::LogicalBinaryExpr::OpType::LogicalAnd: + name = "LogicalAdd"; + break; + case milvus::query::LogicalBinaryExpr::OpType::LogicalOr: + name = "LogicalOr"; + break; + case milvus::query::LogicalBinaryExpr::OpType::LogicalXor: + name = "LogicalXor"; + break; + case milvus::query::LogicalBinaryExpr::OpType::LogicalMinus: + name = "LogicalMinus"; + break; + } + return formatter::format(name, ctx); + } +}; diff --git a/internal/core/src/query/PlanImpl.h b/internal/core/src/query/PlanImpl.h index 6fe94a82bcf01..d015387f63d22 100644 --- a/internal/core/src/query/PlanImpl.h +++ b/internal/core/src/query/PlanImpl.h @@ -19,8 +19,8 @@ #include "Plan.h" #include "PlanNode.h" -#include "exceptions/EasyAssert.h" -#include "utils/Json.h" +#include "common/EasyAssert.h" +#include "common/Json.h" #include "common/Consts.h" namespace milvus::query { diff --git a/internal/core/src/query/PlanNode.h b/internal/core/src/query/PlanNode.h index 40b62e6e3d7b6..18f7af49e5fe0 100644 --- a/internal/core/src/query/PlanNode.h +++ b/internal/core/src/query/PlanNode.h @@ -52,6 +52,12 @@ struct BinaryVectorANNS : VectorPlanNode { accept(PlanNodeVisitor&) override; }; +struct Float16VectorANNS : VectorPlanNode { + public: + void + accept(PlanNodeVisitor&) override; +}; + struct RetrievePlanNode : PlanNode { public: void diff --git a/internal/core/src/query/PlanProto.cpp b/internal/core/src/query/PlanProto.cpp index b7b8078418a9e..021fece0ce08f 100644 --- a/internal/core/src/query/PlanProto.cpp +++ b/internal/core/src/query/PlanProto.cpp @@ -18,7 +18,7 @@ #include "ExprImpl.h" #include "common/VectorTrait.h" -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" #include "generated/ExtractInfoExprVisitor.h" #include "generated/ExtractInfoPlanNodeVisitor.h" #include "pb/plan.pb.h" @@ -87,6 +87,9 @@ ExtractUnaryRangeExprImpl(FieldId field_id, } else if constexpr (std::is_same_v) { Assert(value_proto.val_case() == planpb::GenericValue::kStringVal); 
return static_cast(value_proto.string_val()); + } else if constexpr (std::is_same_v) { + Assert(value_proto.val_case() == planpb::GenericValue::kArrayVal); + return static_cast(value_proto.array_val()); } else { static_assert(always_false); } @@ -151,6 +154,15 @@ ExtractBinaryArithOpEvalRangeExprImpl( static_assert(always_false); } }; + if (expr_proto.arith_op() == proto::plan::ArrayLength) { + return std::make_unique>( + expr_proto.column_info(), + expr_proto.value().val_case(), + expr_proto.arith_op(), + 0, + expr_proto.op(), + getValue(expr_proto.value())); + } return std::make_unique>( expr_proto.column_info(), expr_proto.value().val_case(), @@ -182,11 +194,16 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) { search_info.metric_type_ = query_info_proto.metric_type(); search_info.topk_ = query_info_proto.topk(); search_info.round_decimal_ = query_info_proto.round_decimal(); - search_info.search_params_ = json::parse(query_info_proto.search_params()); + search_info.search_params_ = + nlohmann::json::parse(query_info_proto.search_params()); auto plan_node = [&]() -> std::unique_ptr { - if (anns_proto.is_binary()) { + if (anns_proto.vector_type() == + milvus::proto::plan::VectorType::BinaryVector) { return std::make_unique(); + } else if (anns_proto.vector_type() == + milvus::proto::plan::VectorType::Float16Vector) { + return std::make_unique(); } else { return std::make_unique(); } @@ -302,7 +319,8 @@ ProtoParser::ParseUnaryRangeExpr(const proto::plan::UnaryRangeExpr& expr_pb) { return ExtractUnaryRangeExprImpl( field_id, data_type, expr_pb); } - case DataType::JSON: { + case DataType::JSON: + case DataType::ARRAY: { switch (expr_pb.value().val_case()) { case proto::plan::GenericValue::ValCase::kBoolVal: return ExtractUnaryRangeExprImpl( @@ -316,14 +334,18 @@ ProtoParser::ParseUnaryRangeExpr(const proto::plan::UnaryRangeExpr& expr_pb) { case proto::plan::GenericValue::ValCase::kStringVal: return ExtractUnaryRangeExprImpl( field_id, 
data_type, expr_pb); + case proto::plan::GenericValue::ValCase::kArrayVal: + return ExtractUnaryRangeExprImpl( + field_id, data_type, expr_pb); default: PanicInfo( + DataTypeInvalid, fmt::format("unknown data type: {} in expression", expr_pb.value().val_case())); } } default: { - PanicInfo("unsupported data type"); + PanicInfo(DataTypeInvalid, "unsupported data type"); } } }(); @@ -380,12 +402,34 @@ ProtoParser::ParseBinaryRangeExpr(const proto::plan::BinaryRangeExpr& expr_pb) { return ExtractBinaryRangeExprImpl( field_id, data_type, expr_pb); default: - PanicInfo("unknown data type in expression"); + PanicInfo( + DataTypeInvalid, + fmt::format("unknown data type in expression {}", + data_type)); + } + } + case DataType::ARRAY: { + switch (expr_pb.lower_value().val_case()) { + case proto::plan::GenericValue::ValCase::kInt64Val: + return ExtractBinaryRangeExprImpl( + field_id, data_type, expr_pb); + case proto::plan::GenericValue::ValCase::kFloatVal: + return ExtractBinaryRangeExprImpl( + field_id, data_type, expr_pb); + case proto::plan::GenericValue::ValCase::kStringVal: + return ExtractBinaryRangeExprImpl( + field_id, data_type, expr_pb); + default: + PanicInfo( + DataTypeInvalid, + fmt::format("unknown data type in expression {}", + data_type)); } } default: { - PanicInfo("unsupported data type"); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type {}", data_type)); } } }(); @@ -477,12 +521,39 @@ ProtoParser::ParseTermExpr(const proto::plan::TermExpr& expr_pb) { field_id, data_type, expr_pb); default: PanicInfo( + DataTypeInvalid, + fmt::format("unknown data type: {} in expression", + expr_pb.values()[0].val_case())); + } + } + case DataType::ARRAY: { + if (expr_pb.values().size() == 0) { + return ExtractTermExprImpl( + field_id, data_type, expr_pb); + } + switch (expr_pb.values()[0].val_case()) { + case proto::plan::GenericValue::ValCase::kBoolVal: + return ExtractTermExprImpl( + field_id, data_type, expr_pb); + case 
proto::plan::GenericValue::ValCase::kFloatVal: + return ExtractTermExprImpl( + field_id, data_type, expr_pb); + case proto::plan::GenericValue::ValCase::kInt64Val: + return ExtractTermExprImpl( + field_id, data_type, expr_pb); + case proto::plan::GenericValue::ValCase::kStringVal: + return ExtractTermExprImpl( + field_id, data_type, expr_pb); + default: + PanicInfo( + DataTypeInvalid, fmt::format("unknown data type: {} in expression", expr_pb.values()[0].val_case())); } } default: { - PanicInfo("unsupported data type"); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type {}", data_type)); } } }(); @@ -541,13 +612,30 @@ ProtoParser::ParseBinaryArithOpEvalRangeExpr( return ExtractBinaryArithOpEvalRangeExprImpl( field_id, data_type, expr_pb); default: - PanicInfo(fmt::format( - "unsupported data type {} in expression", - expr_pb.value().val_case())); + PanicInfo(DataTypeInvalid, + fmt::format( + "unsupported data type {} in expression", + expr_pb.value().val_case())); + } + } + case DataType::ARRAY: { + switch (expr_pb.value().val_case()) { + case proto::plan::GenericValue::ValCase::kInt64Val: + return ExtractBinaryArithOpEvalRangeExprImpl( + field_id, data_type, expr_pb); + case proto::plan::GenericValue::ValCase::kFloatVal: + return ExtractBinaryArithOpEvalRangeExprImpl( + field_id, data_type, expr_pb); + default: + PanicInfo(DataTypeInvalid, + fmt::format( + "unsupported data type {} in expression", + expr_pb.value().val_case())); } } default: { - PanicInfo("unsupported data type"); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type {}", data_type)); } } }(); @@ -572,7 +660,8 @@ ProtoParser::ParseExistExpr(const proto::plan::ExistsExpr& expr_pb) { return ExtractExistsExprImpl(expr_pb); } default: { - PanicInfo("unsupported data type"); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type {}", data_type)); } } }(); @@ -640,7 +729,7 @@ ProtoParser::ParseJsonContainsExpr( // auto& field_meta = schema[field_offset]; auto result 
= [&]() -> ExprPtr { if (expr_pb.elements_size() == 0) { - PanicInfo("no elements in expression"); + PanicInfo(DataIsEmpty, "no elements in expression"); } if (expr_pb.elements_same_type()) { switch (expr_pb.elements(0).val_case()) { @@ -656,7 +745,9 @@ ProtoParser::ParseJsonContainsExpr( return ExtractJsonContainsExprImpl( expr_pb); default: - PanicInfo("unsupported data type"); + PanicInfo( + DataTypeInvalid, + fmt::format("unsupported data type {}", data_type)); } } return ExtractJsonContainsExprImpl(expr_pb); @@ -702,7 +793,8 @@ ProtoParser::ParseExpr(const proto::plan::Expr& expr_pb) { default: { std::string s; google::protobuf::TextFormat::PrintToString(expr_pb, &s); - PanicInfo(std::string("unsupported expr proto node: ") + s); + PanicInfo(ExprInvalid, + fmt::format("unsupported expr proto node: {}", s)); } } } diff --git a/internal/core/src/query/PlanProto.h b/internal/core/src/query/PlanProto.h index b6797488b6bbe..806ff62d604fa 100644 --- a/internal/core/src/query/PlanProto.h +++ b/internal/core/src/query/PlanProto.h @@ -98,6 +98,9 @@ struct fmt::formatter case milvus::proto::plan::GenericValue::ValCase::kStringVal: name = "kStringVal"; break; + case milvus::proto::plan::GenericValue::ValCase::kArrayVal: + name = "kArrayVal"; + break; case milvus::proto::plan::GenericValue::ValCase::VAL_NOT_SET: name = "VAL_NOT_SET"; break; diff --git a/internal/core/src/query/Relational.h b/internal/core/src/query/Relational.h index e96ec60b0b0d2..1839221db65ce 100644 --- a/internal/core/src/query/Relational.h +++ b/internal/core/src/query/Relational.h @@ -16,7 +16,7 @@ #include "common/Utils.h" #include "common/VectorTrait.h" -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" #include "query/Expr.h" #include "query/Utils.h" @@ -30,13 +30,13 @@ RelationalImpl(const T& t, const U& u, FundamentalTag, FundamentalTag) { template bool RelationalImpl(const T& t, const U& u, FundamentalTag, StringTag) { - PanicInfo("incompitible data type"); + 
PanicInfo(DataTypeInvalid, "incompitible data type"); } template bool RelationalImpl(const T& t, const U& u, StringTag, FundamentalTag) { - PanicInfo("incompitible data type"); + PanicInfo(DataTypeInvalid, "incompitible data type"); } template @@ -59,7 +59,7 @@ struct Relational { template bool operator()(const T&...) const { - PanicInfo("incompatible operands"); + PanicInfo(OpTypeInvalid, "incompatible operands"); } }; diff --git a/internal/core/src/query/ScalarIndex.h b/internal/core/src/query/ScalarIndex.h index 33794da454eed..b3ea232cc165a 100644 --- a/internal/core/src/query/ScalarIndex.h +++ b/internal/core/src/query/ScalarIndex.h @@ -58,7 +58,8 @@ generate_scalar_index(SpanBase data, DataType data_type) { case DataType::VARCHAR: return generate_scalar_index(Span(data)); default: - PanicInfo("unsupported type"); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported type {}", data_type)); } } diff --git a/internal/core/src/query/SearchBruteForce.cpp b/internal/core/src/query/SearchBruteForce.cpp index b7ae70d86b95a..4e6ea4bd64087 100644 --- a/internal/core/src/query/SearchBruteForce.cpp +++ b/internal/core/src/query/SearchBruteForce.cpp @@ -13,6 +13,7 @@ #include #include "common/Consts.h" +#include "common/EasyAssert.h" #include "common/RangeSearchHelper.h" #include "common/Utils.h" #include "common/Tracer.h" @@ -30,7 +31,8 @@ CheckBruteForceSearchParam(const FieldMeta& field, AssertInfo(datatype_is_vector(data_type), "[BruteForceSearch] Data type isn't vector type"); - bool is_float_data_type = (data_type == DataType::VECTOR_FLOAT); + bool is_float_data_type = (data_type == DataType::VECTOR_FLOAT || + data_type == DataType::VECTOR_FLOAT16); bool is_float_metric_type = IsFloatMetricType(metric_type); AssertInfo(is_float_data_type == is_float_metric_type, "[BruteForceSearch] Data type and metric type miss-match"); @@ -41,69 +43,94 @@ BruteForceSearch(const dataset::SearchDataset& dataset, const void* chunk_data_raw, int64_t chunk_rows, const knowhere::Json& 
conf, - const BitsetView& bitset) { + const BitsetView& bitset, + DataType data_type) { SubSearchResult sub_result(dataset.num_queries, dataset.topk, dataset.metric_type, dataset.round_decimal); - try { - auto nq = dataset.num_queries; - auto dim = dataset.dim; - auto topk = dataset.topk; + auto nq = dataset.num_queries; + auto dim = dataset.dim; + auto topk = dataset.topk; - auto base_dataset = - knowhere::GenDataSet(chunk_rows, dim, chunk_data_raw); - auto query_dataset = knowhere::GenDataSet(nq, dim, dataset.query_data); - auto config = knowhere::Json{ - {knowhere::meta::METRIC_TYPE, dataset.metric_type}, - {knowhere::meta::DIM, dim}, - {knowhere::meta::TOPK, topk}, - }; + auto base_dataset = knowhere::GenDataSet(chunk_rows, dim, chunk_data_raw); + auto query_dataset = knowhere::GenDataSet(nq, dim, dataset.query_data); - sub_result.mutable_seg_offsets().resize(nq * topk); - sub_result.mutable_distances().resize(nq * topk); + if (data_type == DataType::VECTOR_FLOAT16) { + // Todo: Temporarily use cast to float32 to achieve, need to optimize + // first, First, transfer the cast to knowhere part + // second, knowhere partially supports float16 and removes the forced conversion to float32 + auto xb = base_dataset->GetTensor(); + std::vector float_xb(base_dataset->GetRows() * + base_dataset->GetDim()); - if (conf.contains(RADIUS)) { - config[RADIUS] = conf[RADIUS].get(); - if (conf.contains(RANGE_FILTER)) { - config[RANGE_FILTER] = conf[RANGE_FILTER].get(); - CheckRangeSearchParam( - config[RADIUS], config[RANGE_FILTER], dataset.metric_type); - } - auto res = knowhere::BruteForce::RangeSearch( - base_dataset, query_dataset, config, bitset); - milvus::tracer::AddEvent("knowhere_finish_BruteForce_RangeSearch"); - if (!res.has_value()) { - PanicCodeInfo(ErrorCodeEnum::UnexpectedError, - fmt::format("failed to range search: {}: {}", - KnowhereStatusString(res.error()), - res.what())); - } - auto result = ReGenRangeSearchResult( - res.value(), topk, nq, 
dataset.metric_type); - milvus::tracer::AddEvent("ReGenRangeSearchResult"); - std::copy_n( - GetDatasetIDs(result), nq * topk, sub_result.get_seg_offsets()); - std::copy_n(GetDatasetDistance(result), - nq * topk, - sub_result.get_distances()); - } else { - auto stat = knowhere::BruteForce::SearchWithBuf( - base_dataset, - query_dataset, - sub_result.mutable_seg_offsets().data(), - sub_result.mutable_distances().data(), - config, - bitset); - milvus::tracer::AddEvent( - "knowhere_finish_BruteForce_SearchWithBuf"); - if (stat != knowhere::Status::success) { - throw std::invalid_argument("invalid metric type, " + - KnowhereStatusString(stat)); - } + auto xq = query_dataset->GetTensor(); + std::vector float_xq(query_dataset->GetRows() * + query_dataset->GetDim()); + + auto fp16_xb = static_cast(xb); + for (int i = 0; i < base_dataset->GetRows() * base_dataset->GetDim(); + i++) { + float_xb[i] = (float)fp16_xb[i]; + } + + auto fp16_xq = static_cast(xq); + for (int i = 0; i < query_dataset->GetRows() * query_dataset->GetDim(); + i++) { + float_xq[i] = (float)fp16_xq[i]; + } + void* void_ptr_xb = static_cast(float_xb.data()); + void* void_ptr_xq = static_cast(float_xq.data()); + base_dataset = knowhere::GenDataSet(chunk_rows, dim, void_ptr_xb); + query_dataset = knowhere::GenDataSet(nq, dim, void_ptr_xq); + } + + auto config = knowhere::Json{ + {knowhere::meta::METRIC_TYPE, dataset.metric_type}, + {knowhere::meta::DIM, dim}, + {knowhere::meta::TOPK, topk}, + }; + + sub_result.mutable_seg_offsets().resize(nq * topk); + sub_result.mutable_distances().resize(nq * topk); + + if (conf.contains(RADIUS)) { + config[RADIUS] = conf[RADIUS].get(); + if (conf.contains(RANGE_FILTER)) { + config[RANGE_FILTER] = conf[RANGE_FILTER].get(); + CheckRangeSearchParam( + config[RADIUS], config[RANGE_FILTER], dataset.metric_type); + } + auto res = knowhere::BruteForce::RangeSearch( + base_dataset, query_dataset, config, bitset); + 
milvus::tracer::AddEvent("knowhere_finish_BruteForce_RangeSearch"); + if (!res.has_value()) { + PanicInfo(KnowhereError, + fmt::format("failed to range search: {}: {}", + KnowhereStatusString(res.error()), + res.what())); + } + auto result = + ReGenRangeSearchResult(res.value(), topk, nq, dataset.metric_type); + milvus::tracer::AddEvent("ReGenRangeSearchResult"); + std::copy_n( + GetDatasetIDs(result), nq * topk, sub_result.get_seg_offsets()); + std::copy_n( + GetDatasetDistance(result), nq * topk, sub_result.get_distances()); + } else { + auto stat = knowhere::BruteForce::SearchWithBuf( + base_dataset, + query_dataset, + sub_result.mutable_seg_offsets().data(), + sub_result.mutable_distances().data(), + config, + bitset); + milvus::tracer::AddEvent("knowhere_finish_BruteForce_SearchWithBuf"); + if (stat != knowhere::Status::success) { + throw SegcoreError( + KnowhereError, + "invalid metric type, " + KnowhereStatusString(stat)); } - } catch (std::exception& e) { - PanicInfo(e.what()); } sub_result.round_values(); return sub_result; diff --git a/internal/core/src/query/SearchBruteForce.h b/internal/core/src/query/SearchBruteForce.h index 122ec762ab204..882b0955960b4 100644 --- a/internal/core/src/query/SearchBruteForce.h +++ b/internal/core/src/query/SearchBruteForce.h @@ -28,6 +28,7 @@ BruteForceSearch(const dataset::SearchDataset& dataset, const void* chunk_data_raw, int64_t chunk_rows, const knowhere::Json& conf, - const BitsetView& bitset); + const BitsetView& bitset, + DataType data_type = DataType::VECTOR_FLOAT); } // namespace milvus::query diff --git a/internal/core/src/query/SearchOnGrowing.cpp b/internal/core/src/query/SearchOnGrowing.cpp index c44e37d1182f9..ebdbe3db6fe6e 100644 --- a/internal/core/src/query/SearchOnGrowing.cpp +++ b/internal/core/src/query/SearchOnGrowing.cpp @@ -123,7 +123,9 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, chunk_data, size_per_chunk, info.search_params_, - sub_view); + sub_view, + data_type); + // convert 
chunk uid to segment uid for (auto& x : sub_qr.mutable_seg_offsets()) { if (x != -1) { diff --git a/internal/core/src/query/SearchOnSealed.cpp b/internal/core/src/query/SearchOnSealed.cpp index b4b5e2c809dee..3ce0d2e3f1698 100644 --- a/internal/core/src/query/SearchOnSealed.cpp +++ b/internal/core/src/query/SearchOnSealed.cpp @@ -84,9 +84,14 @@ SearchOnSealed(const Schema& schema, field.get_dim(), query_data}; + auto data_type = field.get_data_type(); CheckBruteForceSearchParam(field, search_info); - auto sub_qr = BruteForceSearch( - dataset, vec_data, row_count, search_info.search_params_, bitset); + auto sub_qr = BruteForceSearch(dataset, + vec_data, + row_count, + search_info.search_params_, + bitset, + data_type); result.distances_ = std::move(sub_qr.mutable_distances()); result.seg_offsets_ = std::move(sub_qr.mutable_seg_offsets()); diff --git a/internal/core/src/query/SubSearchResult.cpp b/internal/core/src/query/SubSearchResult.cpp index e25b1090ff463..d9e34b0b76c0d 100644 --- a/internal/core/src/query/SubSearchResult.cpp +++ b/internal/core/src/query/SubSearchResult.cpp @@ -11,7 +11,7 @@ #include -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" #include "query/SubSearchResult.h" namespace milvus::query { diff --git a/internal/core/src/query/Utils.h b/internal/core/src/query/Utils.h index 30a5d2bad4eb8..8e7ba5170cd04 100644 --- a/internal/core/src/query/Utils.h +++ b/internal/core/src/query/Utils.h @@ -22,7 +22,7 @@ namespace milvus::query { template inline bool Match(const T& x, const U& y, OpType op) { - PanicInfo("not supported"); + PanicInfo(NotImplemented, "not supported"); } template <> @@ -34,7 +34,7 @@ Match(const std::string& str, const std::string& val, OpType op) { case OpType::PostfixMatch: return PostfixMatch(str, val); default: - PanicInfo("not supported"); + PanicInfo(OpTypeInvalid, "not supported"); } } @@ -49,7 +49,7 @@ Match(const std::string_view& str, case OpType::PostfixMatch: return PostfixMatch(str, val); default: - 
PanicInfo("not supported"); + PanicInfo(OpTypeInvalid, "not supported"); } } diff --git a/internal/core/src/query/deprecated/BinaryQuery.cpp b/internal/core/src/query/deprecated/BinaryQuery.cpp deleted file mode 100644 index 832eb801c26e2..0000000000000 --- a/internal/core/src/query/deprecated/BinaryQuery.cpp +++ /dev/null @@ -1,324 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. - -#include -#include -#include -#include -#include -#include - -#include "BinaryQuery.h" - -namespace milvus { -namespace query_old { - -BinaryQueryPtr -ConstructBinTree(std::vector queries, - QueryRelation relation, - uint64_t idx) { - if (idx == queries.size()) { - return nullptr; - } else if (idx == queries.size() - 1) { - return queries[idx]->getBinaryQuery(); - } else { - BinaryQueryPtr bquery = std::make_shared(); - bquery->relation = relation; - bquery->left_query = std::make_shared(); - bquery->right_query = std::make_shared(); - bquery->left_query->bin = queries[idx]->getBinaryQuery(); - ++idx; - bquery->right_query->bin = ConstructBinTree(queries, relation, idx); - return bquery; - } -} - -Status -ConstructLeafBinTree(std::vector leaf_queries, - BinaryQueryPtr binary_query, - uint64_t idx) { - if (idx == leaf_queries.size()) { - return Status::OK(); - } - binary_query->left_query = std::make_shared(); - binary_query->right_query = std::make_shared(); - if (leaf_queries.size() == leaf_queries.size() - 1) { - 
binary_query->left_query->leaf = leaf_queries[idx]; - return Status::OK(); - } else if (idx == leaf_queries.size() - 2) { - binary_query->left_query->leaf = leaf_queries[idx]; - ++idx; - binary_query->right_query->leaf = leaf_queries[idx]; - return Status::OK(); - } else { - binary_query->left_query->bin->relation = binary_query->relation; - binary_query->right_query->leaf = leaf_queries[idx]; - ++idx; - return ConstructLeafBinTree( - leaf_queries, binary_query->left_query->bin, idx); - } -} - -Status -GenBinaryQuery(BooleanQueryPtr query, BinaryQueryPtr& binary_query) { - if (query->getBooleanQueries().size() == 0) { - if (binary_query->relation == QueryRelation::AND || - binary_query->relation == QueryRelation::OR) { - // Put VectorQuery to the end of leaf queries - auto query_size = query->getLeafQueries().size(); - for (uint64_t i = 0; i < query_size; ++i) { - if (query->getLeafQueries()[i]->vector_placeholder.size() > 0) { - std::swap(query->getLeafQueries()[i], - query->getLeafQueries()[0]); - break; - } - } - return ConstructLeafBinTree( - query->getLeafQueries(), binary_query, 0); - } else { - switch (query->getOccur()) { - case Occur::MUST: { - binary_query->relation = QueryRelation::AND; - return GenBinaryQuery(query, binary_query); - } - case Occur::MUST_NOT: - binary_query->is_not = true; - case Occur::SHOULD: { - binary_query->relation = QueryRelation::OR; - return GenBinaryQuery(query, binary_query); - } - default: - return Status::OK(); - } - } - } - - if (query->getBooleanQueries().size() == 1) { - auto bc = query->getBooleanQueries()[0]; - binary_query->left_query = std::make_shared(); - switch (bc->getOccur()) { - case Occur::MUST: { - binary_query->relation = QueryRelation::AND; - return GenBinaryQuery(bc, binary_query); - } - case Occur::MUST_NOT: - binary_query->is_not = true; - case Occur::SHOULD: { - binary_query->relation = QueryRelation::OR; - return GenBinaryQuery(bc, binary_query); - } - default: - return Status::OK(); - } - } - - // 
Construct binary query for every single boolean query - std::vector must_queries; - std::vector must_not_queries; - std::vector should_queries; - Status status; - for (auto& _query : query->getBooleanQueries()) { - status = GenBinaryQuery(_query, _query->getBinaryQuery()); - if (!status.ok()) { - return status; - } - if (_query->getOccur() == Occur::MUST) { - must_queries.emplace_back(_query); - } else if (_query->getOccur() == Occur::MUST_NOT) { - must_not_queries.emplace_back(_query); - } else { - should_queries.emplace_back(_query); - } - } - - // Construct binary query for combine boolean queries - BinaryQueryPtr must_bquery, should_bquery, must_not_bquery; - uint64_t bquery_num = 0; - if (must_queries.size() > 1) { - // Construct a must binary tree - must_bquery = ConstructBinTree(must_queries, QueryRelation::R1, 0); - ++bquery_num; - } else if (must_queries.size() == 1) { - must_bquery = must_queries[0]->getBinaryQuery(); - ++bquery_num; - } - - if (should_queries.size() > 1) { - // Construct a should binary tree - should_bquery = ConstructBinTree(should_queries, QueryRelation::R2, 0); - ++bquery_num; - } else if (should_queries.size() == 1) { - should_bquery = should_queries[0]->getBinaryQuery(); - ++bquery_num; - } - - if (must_not_queries.size() > 1) { - // Construct a must_not binary tree - must_not_bquery = - ConstructBinTree(must_not_queries, QueryRelation::R1, 0); - ++bquery_num; - } else if (must_not_queries.size() == 1) { - must_not_bquery = must_not_queries[0]->getBinaryQuery(); - ++bquery_num; - } - - binary_query->left_query = std::make_shared(); - binary_query->right_query = std::make_shared(); - BinaryQueryPtr must_should_query = std::make_shared(); - must_should_query->left_query = std::make_shared(); - must_should_query->right_query = std::make_shared(); - if (bquery_num == 3) { - must_should_query->relation = QueryRelation::R3; - must_should_query->left_query->bin = must_bquery; - must_should_query->right_query->bin = should_bquery; - 
binary_query->relation = QueryRelation::R1; - binary_query->left_query->bin = must_should_query; - binary_query->right_query->bin = must_not_bquery; - } else if (bquery_num == 2) { - if (must_bquery == nullptr) { - // should + must_not - binary_query->relation = QueryRelation::R3; - binary_query->left_query->bin = must_not_bquery; - binary_query->right_query->bin = should_bquery; - } else if (should_bquery == nullptr) { - // must + must_not - binary_query->relation = QueryRelation::R4; - binary_query->left_query->bin = must_bquery; - binary_query->right_query->bin = must_not_bquery; - } else { - // must + should - binary_query->relation = QueryRelation::R3; - binary_query->left_query->bin = must_bquery; - binary_query->right_query->bin = should_bquery; - } - } else { - if (must_bquery != nullptr) { - binary_query = must_bquery; - } else if (should_bquery != nullptr) { - binary_query = should_bquery; - } else { - binary_query = must_not_bquery; - } - } - - return Status::OK(); -} - -uint64_t -BinaryQueryHeight(BinaryQueryPtr& binary_query) { - if (binary_query == nullptr) { - return 1; - } - uint64_t left_height = 0, right_height = 0; - if (binary_query->left_query != nullptr) { - left_height = BinaryQueryHeight(binary_query->left_query->bin); - } - if (binary_query->right_query != nullptr) { - right_height = BinaryQueryHeight(binary_query->right_query->bin); - } - return left_height > right_height ? left_height + 1 : right_height + 1; -} - -/** - * rules: - * 1. The child node of 'should' and 'must_not' can only be 'term query' and 'range query'. - * 2. One layer cannot include bool query and leaf query. - * 3. The direct child node of 'bool' node cannot be 'should' node or 'must_not' node. - * 4. All filters are pre-filtered(Do structure query first, then use the result to do filtering for vector query). 
- * - */ - -Status -rule_1(BooleanQueryPtr& boolean_query, - std::stack& path_stack) { - auto status = Status::OK(); - if (boolean_query != nullptr) { - path_stack.push(boolean_query); - for (const auto& leaf_query : boolean_query->getLeafQueries()) { - if (!leaf_query->vector_placeholder.empty()) { - while (!path_stack.empty()) { - auto query = path_stack.top(); - if (query->getOccur() == Occur::SHOULD || - query->getOccur() == Occur::MUST_NOT) { - std::string msg = - "The child node of 'should' and 'must_not' can " - "only be 'term query' and 'range query'."; - return Status{SERVER_INVALID_DSL_PARAMETER, msg}; - } - path_stack.pop(); - } - } - } - for (auto query : boolean_query->getBooleanQueries()) { - status = rule_1(query, path_stack); - if (!status.ok()) { - return status; - } - } - } - return status; -} - -Status -rule_2(BooleanQueryPtr& boolean_query) { - auto status = Status::OK(); - if (boolean_query != nullptr) { - if (!boolean_query->getBooleanQueries().empty() && - !boolean_query->getLeafQueries().empty()) { - std::string msg = - "One layer cannot include bool query and leaf query."; - return Status{SERVER_INVALID_DSL_PARAMETER, msg}; - } else { - for (auto query : boolean_query->getBooleanQueries()) { - status = rule_2(query); - if (!status.ok()) { - return status; - } - } - } - } - return status; -} - -Status -ValidateBooleanQuery(BooleanQueryPtr& boolean_query) { - auto status = Status::OK(); - if (boolean_query != nullptr) { - for (auto& query : boolean_query->getBooleanQueries()) { - if (query->getOccur() == Occur::SHOULD || - query->getOccur() == Occur::MUST_NOT) { - std::string msg = - "The direct child node of 'bool' node cannot be 'should' " - "node or 'must_not' node."; - return Status{SERVER_INVALID_DSL_PARAMETER, msg}; - } - } - std::stack path_stack; - status = rule_1(boolean_query, path_stack); - if (!status.ok()) { - return status; - } - status = rule_2(boolean_query); - if (!status.ok()) { - return status; - } - } - return status; -} - 
-bool -ValidateBinaryQuery(BinaryQueryPtr& binary_query) { - uint64_t height = BinaryQueryHeight(binary_query); - return height > 1; -} - -} // namespace query_old -} // namespace milvus diff --git a/internal/core/src/query/deprecated/BinaryQuery.h b/internal/core/src/query/deprecated/BinaryQuery.h deleted file mode 100644 index d99c49a040a9b..0000000000000 --- a/internal/core/src/query/deprecated/BinaryQuery.h +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. - -#pragma once - -#include -#include - -#include "BooleanQuery.h" - -namespace milvus { -namespace query_old { - -BinaryQueryPtr -ConstructBinTree(std::vector clauses, - QueryRelation relation, - uint64_t idx); - -Status -ConstructLeafBinTree(std::vector leaf_clauses, - BinaryQueryPtr binary_query, - uint64_t idx); - -Status -GenBinaryQuery(BooleanQueryPtr clause, BinaryQueryPtr& binary_query); - -uint64_t -BinaryQueryHeight(BinaryQueryPtr& binary_query); - -Status -ValidateBooleanQuery(BooleanQueryPtr& boolean_query); - -bool -ValidateBinaryQuery(BinaryQueryPtr& binary_query); - -} // namespace query_old -} // namespace milvus diff --git a/internal/core/src/query/deprecated/BooleanQuery.h b/internal/core/src/query/deprecated/BooleanQuery.h deleted file mode 100644 index 10618d1294cab..0000000000000 --- a/internal/core/src/query/deprecated/BooleanQuery.h +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. 
All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. - -#pragma once - -#include -#include - -#include "GeneralQuery.h" -#include "utils/Status.h" - -namespace milvus { -namespace query_old { - -enum class Occur { - INVALID = 0, - MUST, - MUST_NOT, - SHOULD, -}; - -class BooleanQuery { - public: - BooleanQuery() { - } - - explicit BooleanQuery(Occur occur) : occur_(occur) { - } - - Occur - getOccur() { - return occur_; - } - - void - SetOccur(Occur occur) { - occur_ = occur; - } - - void - AddBooleanQuery(std::shared_ptr boolean_clause) { - boolean_clauses_.emplace_back(boolean_clause); - } - - void - AddLeafQuery(LeafQueryPtr leaf_query) { - leaf_queries_.emplace_back(leaf_query); - } - - void - SetLeafQuery(std::vector leaf_queries) { - leaf_queries_ = leaf_queries; - } - - std::vector> - getBooleanQueries() { - return boolean_clauses_; - } - - BinaryQueryPtr& - getBinaryQuery() { - return binary_query_; - } - - std::vector& - getLeafQueries() { - return leaf_queries_; - } - - private: - Occur occur_ = Occur::INVALID; - std::vector> boolean_clauses_; - std::vector leaf_queries_; - BinaryQueryPtr binary_query_ = std::make_shared(); -}; -using BooleanQueryPtr = std::shared_ptr; - -} // namespace query_old -} // namespace milvus diff --git a/internal/core/src/query/deprecated/GeneralQuery.h b/internal/core/src/query/deprecated/GeneralQuery.h deleted file mode 100644 index 55a371d5f3562..0000000000000 --- 
a/internal/core/src/query/deprecated/GeneralQuery.h +++ /dev/null @@ -1,148 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. - -#pragma once - -#include -#include -#include -#include -#include -#include - -#include "common/Types.h" -#include "utils/Json.h" - -namespace milvus { -namespace query_old { - -enum class CompareOperator { - LT = 0, - LTE, - EQ, - GT, - GTE, - NE, -}; - -enum class QueryRelation { - INVALID = 0, - R1, - R2, - R3, - R4, - AND, - OR, -}; - -struct QueryColumn { - std::string name; - std::string column_value; -}; - -struct TermQuery { - milvus::json json_obj; - // std::string field_name; - // std::vector field_value; - // float boost; -}; -using TermQueryPtr = std::shared_ptr; - -struct CompareExpr { - CompareOperator compare_operator; - std::string operand; -}; - -struct RangeQuery { - milvus::json json_obj; - // std::string field_name; - // std::vector compare_expr; - // float boost; -}; -using RangeQueryPtr = std::shared_ptr; - -struct VectorRecord { - size_t vector_count; - std::vector float_data; - std::vector binary_data; -}; - -struct VectorQuery { - std::string field_name; - milvus::json extra_params = {}; - int64_t topk; - int64_t nq; - std::string metric_type = ""; - float boost; - VectorRecord query_vector; -}; -using VectorQueryPtr = std::shared_ptr; - -struct LeafQuery; -using LeafQueryPtr = std::shared_ptr; - -struct BinaryQuery; -using BinaryQueryPtr = std::shared_ptr; 
- -struct GeneralQuery { - LeafQueryPtr leaf; - BinaryQueryPtr bin = std::make_shared(); -}; -using GeneralQueryPtr = std::shared_ptr; - -struct LeafQuery { - TermQueryPtr term_query; - RangeQueryPtr range_query; - std::string vector_placeholder; - float query_boost; -}; - -struct BinaryQuery { - GeneralQueryPtr left_query; - GeneralQueryPtr right_query; - QueryRelation relation; - float query_boost; - bool is_not = false; -}; - -struct Query { - GeneralQueryPtr root; - std::unordered_map vectors; - - std::string collection_id; - std::vector partitions; - std::vector field_names; - std::set index_fields; - std::unordered_map metric_types; - std::string index_type; -}; - -using QueryPtr = std::shared_ptr; - -} // namespace query_old - -namespace query { -struct QueryDeprecated { - int64_t num_queries; // - int topK; // topK of queries - std::string - field_name; // must be fakevec, whose data_type must be VEC_FLOAT(DIM) - std::vector query_raw_data; // must be size of num_queries * DIM -}; - -// std::unique_ptr CreateNaiveQueryPtr(int64_t num_queries, int topK, std::string& field_name, const float* -// raw_data) { -// return std:: -//} - -using QueryDeprecatedPtr = std::shared_ptr; -} // namespace query -} // namespace milvus diff --git a/internal/core/src/query/deprecated/ParserDeprecated.cpp b/internal/core/src/query/deprecated/ParserDeprecated.cpp deleted file mode 100644 index 149683f434dcf..0000000000000 --- a/internal/core/src/query/deprecated/ParserDeprecated.cpp +++ /dev/null @@ -1,295 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License - -#include -#include -#include "ParserDeprecated.h" - -namespace milvus::wtf { -using google::protobuf::RepeatedField; -using google::protobuf::RepeatedPtrField; -#if 1 -#if 0 -void -CopyRowRecords(const RepeatedPtrField& grpc_records, - const RepeatedField& grpc_id_array, - engine::VectorsData& vectors - ) { - // step 1: copy vector data - int64_t float_data_size = 0, binary_data_size = 0; - - for (auto& record : grpc_records) { - float_data_size += record.float_data_size(); - binary_data_size += record.binary_data().size(); - } - - std::vector float_array(float_data_size, 0.0f); - std::vector binary_array(binary_data_size, 0); - int64_t offset = 0; - if (float_data_size > 0) { - for (auto& record : grpc_records) { - memcpy(&float_array[offset], record.float_data().data(), record.float_data_size() * sizeof(float)); - offset += record.float_data_size(); - } - } else if (binary_data_size > 0) { - for (auto& record : grpc_records) { - memcpy(&binary_array[offset], record.binary_data().data(), record.binary_data().size()); - offset += record.binary_data().size(); - } - } - - // step 2: copy id array - std::vector id_array; - if (grpc_id_array.size() > 0) { - id_array.resize(grpc_id_array.size()); - memcpy(id_array.data(), grpc_id_array.data(), grpc_id_array.size() * sizeof(int64_t)); - } - - // step 3: construct vectors - vectors.vector_count_ = grpc_records.size(); - vectors.float_data_.swap(float_array); - vectors.binary_data_.swap(binary_array); - vectors.id_array_.swap(id_array); -} -#endif - -Status -ProcessLeafQueryJson(const milvus::json& query_json, - 
query_old::BooleanQueryPtr& query, - std::string& field_name) { -#if 1 - if (query_json.contains("term")) { - auto leaf_query = std::make_shared(); - auto term_query = std::make_shared(); - milvus::json json_obj = query_json["term"]; - JSON_NULL_CHECK(json_obj); - JSON_OBJECT_CHECK(json_obj); - term_query->json_obj = json_obj; - milvus::json::iterator json_it = json_obj.begin(); - field_name = json_it.key(); - leaf_query->term_query = term_query; - query->AddLeafQuery(leaf_query); - } else if (query_json.contains("range")) { - auto leaf_query = std::make_shared(); - auto range_query = std::make_shared(); - milvus::json json_obj = query_json["range"]; - JSON_NULL_CHECK(json_obj); - JSON_OBJECT_CHECK(json_obj); - range_query->json_obj = json_obj; - milvus::json::iterator json_it = json_obj.begin(); - field_name = json_it.key(); - - leaf_query->range_query = range_query; - query->AddLeafQuery(leaf_query); - } else if (query_json.contains("vector")) { - auto leaf_query = std::make_shared(); - auto vector_json = query_json["vector"]; - JSON_NULL_CHECK(vector_json); - - leaf_query->vector_placeholder = vector_json.get(); - query->AddLeafQuery(leaf_query); - } else { - return Status{SERVER_INVALID_ARGUMENT, "Leaf query get wrong key"}; - } -#endif - return Status::OK(); -} - -Status -ProcessBooleanQueryJson(const milvus::json& query_json, - query_old::BooleanQueryPtr& boolean_query, - query_old::QueryPtr& query_ptr) { -#if 1 - if (query_json.empty()) { - return Status{SERVER_INVALID_ARGUMENT, "BoolQuery is null"}; - } - for (auto& el : query_json.items()) { - if (el.key() == "must") { - boolean_query->SetOccur(query_old::Occur::MUST); - auto must_json = el.value(); - if (!must_json.is_array()) { - std::string msg = "Must json string is not an array"; - return Status{SERVER_INVALID_DSL_PARAMETER, msg}; - } - - for (auto& json : must_json) { - auto must_query = std::make_shared(); - if (json.contains("must") || json.contains("should") || - json.contains("must_not")) { - 
STATUS_CHECK( - ProcessBooleanQueryJson(json, must_query, query_ptr)); - boolean_query->AddBooleanQuery(must_query); - } else { - std::string field_name; - STATUS_CHECK( - ProcessLeafQueryJson(json, boolean_query, field_name)); - if (!field_name.empty()) { - query_ptr->index_fields.insert(field_name); - } - } - } - } else if (el.key() == "should") { - boolean_query->SetOccur(query_old::Occur::SHOULD); - auto should_json = el.value(); - if (!should_json.is_array()) { - std::string msg = "Should json string is not an array"; - return Status{SERVER_INVALID_DSL_PARAMETER, msg}; - } - - for (auto& json : should_json) { - auto should_query = std::make_shared(); - if (json.contains("must") || json.contains("should") || - json.contains("must_not")) { - STATUS_CHECK( - ProcessBooleanQueryJson(json, should_query, query_ptr)); - boolean_query->AddBooleanQuery(should_query); - } else { - std::string field_name; - STATUS_CHECK( - ProcessLeafQueryJson(json, boolean_query, field_name)); - if (!field_name.empty()) { - query_ptr->index_fields.insert(field_name); - } - } - } - } else if (el.key() == "must_not") { - boolean_query->SetOccur(query_old::Occur::MUST_NOT); - auto should_json = el.value(); - if (!should_json.is_array()) { - std::string msg = "Must_not json string is not an array"; - return Status{SERVER_INVALID_DSL_PARAMETER, msg}; - } - - for (auto& json : should_json) { - if (json.contains("must") || json.contains("should") || - json.contains("must_not")) { - auto must_not_query = - std::make_shared(); - STATUS_CHECK(ProcessBooleanQueryJson( - json, must_not_query, query_ptr)); - boolean_query->AddBooleanQuery(must_not_query); - } else { - std::string field_name; - STATUS_CHECK( - ProcessLeafQueryJson(json, boolean_query, field_name)); - if (!field_name.empty()) { - query_ptr->index_fields.insert(field_name); - } - } - } - } else { - std::string msg = - "BoolQuery json string does not include bool query"; - return Status{SERVER_INVALID_DSL_PARAMETER, msg}; - } - } 
-#endif - return Status::OK(); -} - -Status -DeserializeJsonToBoolQuery(const google::protobuf::RepeatedPtrField< - ::milvus::grpc::VectorParam>& vector_params, - const std::string_view dsl_string, - query_old::BooleanQueryPtr& boolean_query, - query_old::QueryPtr& query_ptr) { -#if 1 - try { - milvus::json dsl_json = Json::parse(dsl_string); - - if (dsl_json.empty()) { - return Status{SERVER_INVALID_ARGUMENT, "Query dsl is null"}; - } - auto status = Status::OK(); - if (vector_params.empty()) { - return Status(SERVER_INVALID_DSL_PARAMETER, - "DSL must include vector query"); - } - for (const auto& vector_param : vector_params) { - const std::string_view vector_string = vector_param.json(); - milvus::json vector_json = Json::parse(vector_string); - milvus::json::iterator it = vector_json.begin(); - std::string placeholder = it.key(); - - auto vector_query = std::make_shared(); - milvus::json::iterator vector_param_it = it.value().begin(); - if (vector_param_it != it.value().end()) { - const std::string_view field_name = vector_param_it.key(); - vector_query->field_name = field_name; - milvus::json param_json = vector_param_it.value(); - int64_t topk = param_json["topk"]; - // STATUS_CHECK(server::ValidateSearchTopk(topk)); - vector_query->topk = topk; - if (param_json.contains("metric_type")) { - std::string metric_type = param_json["metric_type"]; - vector_query->metric_type = metric_type; - query_ptr->metric_types.insert( - {field_name, param_json["metric_type"]}); - } - if (!vector_param_it.value()["params"].empty()) { - vector_query->extra_params = - vector_param_it.value()["params"]; - } - query_ptr->index_fields.insert(field_name); - } - - engine::VectorsData vector_data; - CopyRowRecords( - vector_param.row_record().records(), - google::protobuf::RepeatedField(), - vector_data); - vector_query->query_vector.vector_count = vector_data.vector_count_; - vector_query->query_vector.binary_data.swap( - vector_data.binary_data_); - 
vector_query->query_vector.float_data.swap(vector_data.float_data_); - - query_ptr->vectors.insert( - std::make_pair(placeholder, vector_query)); - } - if (dsl_json.contains("bool")) { - auto boolean_query_json = dsl_json["bool"]; - JSON_NULL_CHECK(boolean_query_json); - status = ProcessBooleanQueryJson( - boolean_query_json, boolean_query, query_ptr); - if (!status.ok()) { - return Status(SERVER_INVALID_DSL_PARAMETER, - "DSL does not include bool"); - } - } else { - return Status(SERVER_INVALID_DSL_PARAMETER, - "DSL does not include bool query"); - } - return Status::OK(); - } catch (std::exception& e) { - return Status(SERVER_INVALID_DSL_PARAMETER, e.what()); - } -#endif - return Status::OK(); -} - -#endif -query_old::QueryPtr -Transformer(proto::service::Query* request) { - query_old::BooleanQueryPtr boolean_query = - std::make_shared(); - query_old::QueryPtr query_ptr = std::make_shared(); -#if 0 - query_ptr->collection_id = request->collection_name(); - auto status = DeserializeJsonToBoolQuery(request->placeholders(), request->dsl(), boolean_query, query_ptr); - status = query_old::ValidateBooleanQuery(boolean_query); - query_old::GeneralQueryPtr general_query = std::make_shared(); - query_old::GenBinaryQuery(boolean_query, general_query->bin); - query_ptr->root = general_query; -#endif - return query_ptr; -} - -} // namespace milvus::wtf diff --git a/internal/core/src/query/deprecated/ParserDeprecated.h b/internal/core/src/query/deprecated/ParserDeprecated.h deleted file mode 100644 index adf079b56edd6..0000000000000 --- a/internal/core/src/query/deprecated/ParserDeprecated.h +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License - -#pragma once -#include "pb/milvus.pb.h" -#include "query/deprecated/BooleanQuery.h" -#include "query/deprecated/BinaryQuery.h" -#include "query/deprecated/GeneralQuery.h" - -namespace milvus::wtf { - -query_old::QueryPtr -Transformer(proto::milvus::SearchRequest* query); - -} // namespace milvus::wtf diff --git a/internal/core/src/query/deprecated/ValidationUtil.cpp b/internal/core/src/query/deprecated/ValidationUtil.cpp deleted file mode 100644 index 49ab41d9af947..0000000000000 --- a/internal/core/src/query/deprecated/ValidationUtil.cpp +++ /dev/null @@ -1,523 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. 
- -#include "ValidationUtil.h" -#include "config/ServerConfig.h" -#include "utils/Log.h" -#include "utils/StringHelpFunctions.h" - -#include -#include -#include -#include - -namespace milvus { -namespace server { - -namespace { - -Status -CheckParameterRange(const milvus::json& json_params, - const std::string_view param_name, - int64_t min, - int64_t max, - bool min_close = true, - bool max_closed = true) { - if (json_params.find(param_name) == json_params.end()) { - std::string msg = "Parameter list must contain: "; - msg += param_name; - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_ARGUMENT, msg); - } - - try { - int64_t value = json_params[param_name]; - bool min_err = min_close ? value < min : value <= min; - bool max_err = max_closed ? value > max : value >= max; - if (min_err || max_err) { - std::string msg = "Invalid " + param_name + - " value: " + std::to_string(value) + - ". Valid range is " + (min_close ? "[" : "(") + - std::to_string(min) + ", " + std::to_string(max) + - (max_closed ? 
"]" : ")"); - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_ARGUMENT, msg); - } - } catch (std::exception& e) { - std::string msg = "Invalid " + param_name + ": "; - msg += e.what(); - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_ARGUMENT, msg); - } - - return Status::OK(); -} - -Status -CheckParameterExistence(const milvus::json& json_params, - const std::string_view param_name) { - if (json_params.find(param_name) == json_params.end()) { - std::string msg = "Parameter list must contain: "; - msg += param_name; - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_ARGUMENT, msg); - } - - try { - int64_t value = json_params[param_name]; - if (value < 0) { - std::string msg = - "Invalid " + param_name + " value: " + std::to_string(value); - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_ARGUMENT, msg); - } - } catch (std::exception& e) { - std::string msg = "Invalid " + param_name + ": "; - msg += e.what(); - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_ARGUMENT, msg); - } - - return Status::OK(); -} - -} // namespace - -Status -ValidateCollectionName(const std::string_view collection_name) { - // Collection name shouldn't be empty. - if (collection_name.empty()) { - std::string msg = "Collection name should not be empty."; - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_COLLECTION_NAME, msg); - } - - std::string invalid_msg = - "Invalid collection name: " + collection_name + ". "; - // Collection name size shouldn't exceed engine::MAX_NAME_LENGTH. - if (collection_name.size() > engine::MAX_NAME_LENGTH) { - std::string msg = - invalid_msg + "The length of a collection name must be less than " + - std::to_string(engine::MAX_NAME_LENGTH) + " characters."; - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_COLLECTION_NAME, msg); - } - - // Collection name first character should be underscore or character. 
- char first_char = collection_name[0]; - if (first_char != '_' && std::isalpha(first_char) == 0) { - std::string msg = invalid_msg + - "The first character of a collection name must be an " - "underscore or letter."; - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_COLLECTION_NAME, msg); - } - - int64_t table_name_size = collection_name.size(); - for (int64_t i = 1; i < table_name_size; ++i) { - char name_char = collection_name[i]; - if (name_char != '_' && name_char != '$' && - std::isalnum(name_char) == 0) { - std::string msg = invalid_msg + - "Collection name can only contain numbers, " - "letters, and underscores."; - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_COLLECTION_NAME, msg); - } - } - - return Status::OK(); -} - -Status -ValidateFieldName(const std::string_view field_name) { - // Field name shouldn't be empty. - if (field_name.empty()) { - std::string msg = "Field name should not be empty."; - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_FIELD_NAME, msg); - } - - std::string invalid_msg = "Invalid field name: " + field_name + ". "; - // Field name size shouldn't exceed engine::MAX_NAME_LENGTH. - if (field_name.size() > engine::MAX_NAME_LENGTH) { - std::string msg = - invalid_msg + "The length of a field name must be less than " + - std::to_string(engine::MAX_NAME_LENGTH) + " characters."; - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_FIELD_NAME, msg); - } - - // Field name first character should be underscore or character. 
- char first_char = field_name[0]; - if (first_char != '_' && std::isalpha(first_char) == 0) { - std::string msg = invalid_msg + - "The first character of a field name must be an " - "underscore or letter."; - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_FIELD_NAME, msg); - } - - int64_t field_name_size = field_name.size(); - for (int64_t i = 1; i < field_name_size; ++i) { - char name_char = field_name[i]; - if (name_char != '_' && std::isalnum(name_char) == 0) { - std::string msg = invalid_msg + - "Field name cannot only contain numbers, " - "letters, and underscores."; - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_FIELD_NAME, msg); - } - } - - return Status::OK(); -} - -Status -ValidateVectorIndexType(std::string& index_type, bool is_binary) { - // Index name shouldn't be empty. - if (index_type.empty()) { - std::string msg = "Index type should not be empty."; - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_FIELD_NAME, msg); - } - - // string case insensitive - std::transform( - index_type.begin(), index_type.end(), index_type.begin(), ::toupper); - - static std::set s_vector_index_type = { - knowhere::IndexEnum::INVALID, - knowhere::IndexEnum::INDEX_FAISS_IDMAP, - knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, - knowhere::IndexEnum::INDEX_FAISS_IVFPQ, - knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, - knowhere::IndexEnum::INDEX_HNSW, - }; - - static std::set s_binary_index_types = { - knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, - knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, - }; - - std::set& index_types = - is_binary ? s_binary_index_types : s_vector_index_type; - if (index_types.find(index_type) == index_types.end()) { - std::string msg = "Invalid index type: " + index_type; - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_INDEX_TYPE, msg); - } - - return Status::OK(); -} - -Status -ValidateStructuredIndexType(std::string& index_type) { - // Index name shouldn't be empty. 
- if (index_type.empty()) { - std::string msg = "Index type should not be empty."; - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_FIELD_NAME, msg); - } - - // string case insensitive - std::transform( - index_type.begin(), index_type.end(), index_type.begin(), ::toupper); - - static std::set s_index_types = { - engine::DEFAULT_STRUCTURED_INDEX, - }; - - if (s_index_types.find(index_type) == s_index_types.end()) { - std::string msg = "Invalid index type: " + index_type; - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_INDEX_TYPE, msg); - } - - return Status::OK(); -} - -Status -ValidateDimension(int64_t dim, bool is_binary) { - if (dim <= 0 || dim > engine::MAX_DIMENSION) { - std::string msg = "Invalid dimension: " + std::to_string(dim) + - ". Should be in range 1 ~ " + - std::to_string(engine::MAX_DIMENSION) + "."; - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_VECTOR_DIMENSION, msg); - } - - if (is_binary && (dim % 8) != 0) { - std::string msg = "Invalid dimension: " + std::to_string(dim) + - ". 
Should be multiple of 8."; - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_VECTOR_DIMENSION, msg); - } - - return Status::OK(); -} - -Status -ValidateIndexParams(const milvus::json& index_params, - int64_t dimension, - const std::string_view index_type) { - if (engine::utils::IsFlatIndexType(index_type)) { - return Status::OK(); - } else if (index_type == knowhere::IndexEnum::INDEX_FAISS_IVFFLAT || - index_type == knowhere::IndexEnum::INDEX_FAISS_IVFSQ8 || - index_type == knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT) { - auto status = CheckParameterRange( - index_params, knowhere::IndexParams::nlist, 1, 65536); - if (!status.ok()) { - return status; - } - } else if (index_type == knowhere::IndexEnum::INDEX_FAISS_IVFPQ) { - auto status = CheckParameterRange( - index_params, knowhere::IndexParams::nlist, 1, 65536); - if (!status.ok()) { - return status; - } - - status = - CheckParameterExistence(index_params, knowhere::IndexParams::m); - if (!status.ok()) { - return status; - } - - // special check for 'm' parameter - int64_t m_value = index_params[knowhere::IndexParams::m]; - if (!milvus::knowhere::IVFPQConfAdapter::GetValidCPUM(dimension, - m_value)) { - std::string msg = "Invalid m, dimension can't not be divided by m "; - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_ARGUMENT, msg); - } - /*std::vector resset; - milvus::knowhere::IVFPQConfAdapter::GetValidMList(dimension, resset); - int64_t m_value = index_params[knowhere::IndexParams::m]; - if (resset.empty()) { - std::string msg = "Invalid collection dimension, unable to get reasonable values for 'm'"; - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_COLLECTION_DIMENSION, msg); - } - - auto iter = std::find(std::begin(resset), std::end(resset), m_value); - if (iter == std::end(resset)) { - std::string msg = - "Invalid " + std::string(knowhere::IndexParams::m) + ", must be one of the following values: "; - for (size_t i = 0; i < resset.size(); i++) { - if (i != 0) { - msg += ","; 
- } - msg += std::to_string(resset[i]); - } - - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_ARGUMENT, msg); - }*/ - } else if (index_type == knowhere::IndexEnum::INDEX_HNSW) { - auto status = - CheckParameterRange(index_params, knowhere::IndexParams::M, 4, 64); - if (!status.ok()) { - return status; - } - status = CheckParameterRange( - index_params, knowhere::IndexParams::efConstruction, 8, 512); - if (!status.ok()) { - return status; - } - } - - return Status::OK(); -} - -Status -ValidateSegmentRowCount(int64_t segment_row_count) { - int64_t min = config.engine.build_index_threshold(); - int max = engine::MAX_SEGMENT_ROW_COUNT; - if (segment_row_count < min || segment_row_count > max) { - std::string msg = - "Invalid segment row count: " + std::to_string(segment_row_count) + - ". " + "Should be in range " + std::to_string(min) + " ~ " + - std::to_string(max) + "."; - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_SEGMENT_ROW_COUNT, msg); - } - return Status::OK(); -} - -Status -ValidateIndexMetricType(const std::string_view metric_type, - const std::string_view index_type) { - if (engine::utils::IsFlatIndexType(index_type)) { - // pass - } else if (index_type == knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT) { - // binary - if (metric_type != knowhere::Metric::HAMMING && - metric_type != knowhere::Metric::JACCARD) { - std::string msg = "Index metric type " + metric_type + - " does not match index type " + index_type; - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_ARGUMENT, msg); - } - } else { - // float - if (metric_type != knowhere::Metric::L2 && - metric_type != knowhere::Metric::IP) { - std::string msg = "Index metric type " + metric_type + - " does not match index type " + index_type; - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_ARGUMENT, msg); - } - } - - return Status::OK(); -} - -Status -ValidateSearchMetricType(const std::string_view metric_type, bool is_binary) { - if (is_binary) { - // binary - if 
(metric_type == knowhere::Metric::L2 || - metric_type == knowhere::Metric::IP) { - std::string msg = - "Cannot search binary entities with index metric type " + - metric_type; - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_ARGUMENT, msg); - } - } else { - // float - if (metric_type == knowhere::Metric::HAMMING || - metric_type == knowhere::Metric::JACCARD) { - std::string msg = - "Cannot search float entities with index metric type " + - metric_type; - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_ARGUMENT, msg); - } - } - - return Status::OK(); -} - -Status -ValidateSearchTopk(int64_t top_k) { - if (top_k <= 0 || top_k > QUERY_MAX_TOPK) { - std::string msg = "Invalid topk: " + std::to_string(top_k) + ". " + - "The topk must be within the range of 1 ~ 16384."; - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_TOPK, msg); - } - - return Status::OK(); -} - -Status -ValidatePartitionTags(const std::vector& partition_tags) { - for (const std::string_view tag : partition_tags) { - // Partition nametag shouldn't be empty. - if (tag.empty()) { - std::string msg = "Partition tag should not be empty."; - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_PARTITION_TAG, msg); - } - - std::string invalid_msg = "Invalid partition tag: " + tag + ". "; - // Partition tag size shouldn't exceed 255. - if (tag.size() > engine::MAX_NAME_LENGTH) { - std::string msg = invalid_msg + - "The length of a partition tag must be less than " - "255 characters."; - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_PARTITION_TAG, msg); - } - - // Partition tag first character should be underscore or character. 
- char first_char = tag[0]; - if (first_char != '_' && std::isalnum(first_char) == 0) { - std::string msg = invalid_msg + - "The first character of a partition tag must be " - "an underscore or letter."; - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_PARTITION_TAG, msg); - } - - int64_t tag_size = tag.size(); - for (int64_t i = 1; i < tag_size; ++i) { - char name_char = tag[i]; - if (name_char != '_' && name_char != '$' && - std::isalnum(name_char) == 0) { - std::string msg = invalid_msg + - "Partition tag can only contain numbers, " - "letters, and underscores."; - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_PARTITION_TAG, msg); - } - } - -#if 0 - // trim side-blank of tag, only compare valid characters - // for example: " ab cd " is treated as "ab cd" - std::string valid_tag = tag; - StringHelpFunctions::TrimStringBlank(valid_tag); - if (valid_tag.empty()) { - std::string msg = "Invalid partition tag: " + valid_tag + ". " + "Partition tag should not be empty."; - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_PARTITION_TAG, msg); - } - - // max length of partition tag - if (valid_tag.length() > engine::MAX_NAME_LENGTH) { - std::string msg = "Invalid partition tag: " + valid_tag + ". 
" + - "Partition tag exceed max length: " + std::to_string(engine::MAX_NAME_LENGTH); - LOG_SERVER_ERROR_ << msg; - return Status(SERVER_INVALID_PARTITION_TAG, msg); - } -#endif - } - - return Status::OK(); -} - -Status -ValidateInsertDataSize(const InsertParam& insert_param) { - int64_t chunk_size = 0; - for (auto& pair : insert_param.fields_data_) { - for (auto& data : pair.second) { - chunk_size += data.second; - } - } - - if (chunk_size > engine::MAX_INSERT_DATA_SIZE) { - std::string msg = - "The amount of data inserted each time cannot exceed " + - std::to_string(engine::MAX_INSERT_DATA_SIZE / engine::MB) + " MB"; - return Status(SERVER_INVALID_ROWRECORD_ARRAY, msg); - } - - return Status::OK(); -} - -Status -ValidateCompactThreshold(double threshold) { - if (threshold > 1.0 || threshold < 0.0) { - std::string msg = - "Invalid compact threshold: " + std::to_string(threshold) + - ". Should be in range [0.0, 1.0]"; - return Status(SERVER_INVALID_ROWRECORD_ARRAY, msg); - } - - return Status::OK(); -} - -} // namespace server -} // namespace milvus diff --git a/internal/core/src/query/deprecated/ValidationUtil.h b/internal/core/src/query/deprecated/ValidationUtil.h deleted file mode 100644 index ae0ae09be4815..0000000000000 --- a/internal/core/src/query/deprecated/ValidationUtil.h +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. 
- -#pragma once - -#include "db/Types.h" -#include "server/delivery/request/Types.h" -#include "utils/Json.h" -#include "utils/Status.h" - -#include -#include - -namespace milvus { -namespace server { - -constexpr int64_t QUERY_MAX_TOPK = 16384; -constexpr int64_t GPU_QUERY_MAX_TOPK = 2048; -constexpr int64_t GPU_QUERY_MAX_NPROBE = 2048; - -extern Status -ValidateCollectionName(const std::string_view collection_name); - -extern Status -ValidateFieldName(const std::string_view field_name); - -extern Status -ValidateDimension(int64_t dimension, bool is_binary); - -extern Status -ValidateVectorIndexType(std::string& index_type, bool is_binary); - -extern Status -ValidateStructuredIndexType(std::string& index_type); - -extern Status -ValidateIndexParams(const milvus::json& index_params, - int64_t dimension, - const std::string_view index_type); - -extern Status -ValidateSegmentRowCount(int64_t segment_row_count); - -extern Status -ValidateIndexMetricType(const std::string_view metric_type, - const std::string_view index_type); - -extern Status -ValidateSearchMetricType(const std::string_view metric_type, bool is_binary); - -extern Status -ValidateSearchTopk(int64_t top_k); - -extern Status -ValidatePartitionTags(const std::vector& partition_tags); - -extern Status -ValidateInsertDataSize(const InsertParam& insert_param); - -extern Status -ValidateCompactThreshold(double threshold); -} // namespace server -} // namespace milvus diff --git a/internal/core/src/query/generated/ExecExprVisitor.h b/internal/core/src/query/generated/ExecExprVisitor.h index 69e3e2769ff9b..90872ee420cc6 100644 --- a/internal/core/src/query/generated/ExecExprVisitor.h +++ b/internal/core/src/query/generated/ExecExprVisitor.h @@ -114,11 +114,21 @@ class ExecExprVisitor : public ExprVisitor { auto ExecUnaryRangeVisitorDispatcherJson(UnaryRangeExpr& expr_raw) -> BitsetType; + template + auto + ExecUnaryRangeVisitorDispatcherArray(UnaryRangeExpr& expr_raw) + -> BitsetType; + template auto 
ExecBinaryArithOpEvalRangeVisitorDispatcherJson( BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType; + template + auto + ExecBinaryArithOpEvalRangeVisitorDispatcherArray( + BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType; + template auto ExecBinaryArithOpEvalRangeVisitorDispatcher( @@ -129,6 +139,11 @@ class ExecExprVisitor : public ExprVisitor { ExecBinaryRangeVisitorDispatcherJson(BinaryRangeExpr& expr_raw) -> BitsetType; + template + auto + ExecBinaryRangeVisitorDispatcherArray(BinaryRangeExpr& expr_raw) + -> BitsetType; + template auto ExecBinaryRangeVisitorDispatcher(BinaryRangeExpr& expr_raw) -> BitsetType; @@ -145,14 +160,26 @@ class ExecExprVisitor : public ExprVisitor { auto ExecTermJsonVariableInField(TermExpr& expr_raw) -> BitsetType; + template + auto + ExecTermArrayVariableInField(TermExpr& expr_raw) -> BitsetType; + template auto ExecTermJsonFieldInVariable(TermExpr& expr_raw) -> BitsetType; + template + auto + ExecTermArrayFieldInVariable(TermExpr& expr_raw) -> BitsetType; + template auto ExecTermVisitorImplTemplateJson(TermExpr& expr_raw) -> BitsetType; + template + auto + ExecTermVisitorImplTemplateArray(TermExpr& expr_raw) -> BitsetType; + template auto ExecCompareExprDispatcher(CompareExpr& expr, CmpFunc cmp_func) @@ -162,6 +189,10 @@ class ExecExprVisitor : public ExprVisitor { auto ExecJsonContains(JsonContainsExpr& expr_raw) -> BitsetType; + template + auto + ExecArrayContains(JsonContainsExpr& expr_raw) -> BitsetType; + auto ExecJsonContainsArray(JsonContainsExpr& expr_raw) -> BitsetType; @@ -172,6 +203,10 @@ class ExecExprVisitor : public ExprVisitor { auto ExecJsonContainsAll(JsonContainsExpr& expr_raw) -> BitsetType; + template + auto + ExecArrayContainsAll(JsonContainsExpr& expr_raw) -> BitsetType; + auto ExecJsonContainsAllArray(JsonContainsExpr& expr_raw) -> BitsetType; diff --git a/internal/core/src/query/generated/ExecPlanNodeVisitor.h b/internal/core/src/query/generated/ExecPlanNodeVisitor.h index 80fcc1c4252f3..cd1aa91ce17ed 
100644 --- a/internal/core/src/query/generated/ExecPlanNodeVisitor.h +++ b/internal/core/src/query/generated/ExecPlanNodeVisitor.h @@ -12,7 +12,7 @@ #pragma once // Generated File // DO NOT EDIT -#include "utils/Json.h" +#include "common/Json.h" #include "query/PlanImpl.h" #include "segcore/SegmentGrowing.h" #include @@ -27,6 +27,9 @@ class ExecPlanNodeVisitor : public PlanNodeVisitor { void visit(BinaryVectorANNS& node) override; + void + visit(Float16VectorANNS& node) override; + void visit(RetrievePlanNode& node) override; diff --git a/internal/core/src/query/generated/ExtractInfoPlanNodeVisitor.h b/internal/core/src/query/generated/ExtractInfoPlanNodeVisitor.h index afe5d8367d004..578077b85d092 100644 --- a/internal/core/src/query/generated/ExtractInfoPlanNodeVisitor.h +++ b/internal/core/src/query/generated/ExtractInfoPlanNodeVisitor.h @@ -24,6 +24,9 @@ class ExtractInfoPlanNodeVisitor : public PlanNodeVisitor { void visit(BinaryVectorANNS& node) override; + void + visit(Float16VectorANNS& node) override; + void visit(RetrievePlanNode& node) override; diff --git a/internal/core/src/query/generated/PlanNode.cpp b/internal/core/src/query/generated/PlanNode.cpp index 59a9448fb9ce9..f91f5b6c8404f 100644 --- a/internal/core/src/query/generated/PlanNode.cpp +++ b/internal/core/src/query/generated/PlanNode.cpp @@ -25,6 +25,11 @@ BinaryVectorANNS::accept(PlanNodeVisitor& visitor) { visitor.visit(*this); } +void +Float16VectorANNS::accept(PlanNodeVisitor& visitor) { + visitor.visit(*this); +} + void RetrievePlanNode::accept(PlanNodeVisitor& visitor) { visitor.visit(*this); diff --git a/internal/core/src/query/generated/PlanNodeVisitor.h b/internal/core/src/query/generated/PlanNodeVisitor.h index 3589535b2cde9..b41fba91e308a 100644 --- a/internal/core/src/query/generated/PlanNodeVisitor.h +++ b/internal/core/src/query/generated/PlanNodeVisitor.h @@ -25,6 +25,9 @@ class PlanNodeVisitor { virtual void visit(BinaryVectorANNS&) = 0; + virtual void + 
visit(Float16VectorANNS&) = 0; + virtual void visit(RetrievePlanNode&) = 0; }; diff --git a/internal/core/src/query/generated/ShowExprVisitor.h b/internal/core/src/query/generated/ShowExprVisitor.h index 9c5ca313b48fa..64532a00ddaa0 100644 --- a/internal/core/src/query/generated/ShowExprVisitor.h +++ b/internal/core/src/query/generated/ShowExprVisitor.h @@ -77,6 +77,6 @@ class ShowExprVisitor : public ExprVisitor { } private: - std::optional json_opt_; + std::optional json_opt_; }; } // namespace milvus::query diff --git a/internal/core/src/query/generated/ShowPlanNodeVisitor.h b/internal/core/src/query/generated/ShowPlanNodeVisitor.h index 9ff19f54fbc9e..4a8743763b734 100644 --- a/internal/core/src/query/generated/ShowPlanNodeVisitor.h +++ b/internal/core/src/query/generated/ShowPlanNodeVisitor.h @@ -12,8 +12,8 @@ #pragma once // Generated File // DO NOT EDIT -#include "exceptions/EasyAssert.h" -#include "utils/Json.h" +#include "common/EasyAssert.h" +#include "common/Json.h" #include #include @@ -28,6 +28,9 @@ class ShowPlanNodeVisitor : public PlanNodeVisitor { void visit(BinaryVectorANNS& node) override; + void + visit(Float16VectorANNS& node) override; + void visit(RetrievePlanNode& node) override; diff --git a/internal/core/src/query/generated/VerifyPlanNodeVisitor.h b/internal/core/src/query/generated/VerifyPlanNodeVisitor.h index a8810d001cb3f..6b9653d278797 100644 --- a/internal/core/src/query/generated/VerifyPlanNodeVisitor.h +++ b/internal/core/src/query/generated/VerifyPlanNodeVisitor.h @@ -12,7 +12,7 @@ #pragma once // Generated File // DO NOT EDIT -#include "utils/Json.h" +#include "common/Json.h" #include "query/PlanImpl.h" #include "segcore/SegmentGrowing.h" #include @@ -27,6 +27,9 @@ class VerifyPlanNodeVisitor : public PlanNodeVisitor { void visit(BinaryVectorANNS& node) override; + void + visit(Float16VectorANNS& node) override; + void visit(RetrievePlanNode& node) override; diff --git a/internal/core/src/query/visitors/ExecExprVisitor.cpp 
b/internal/core/src/query/visitors/ExecExprVisitor.cpp index 034e8de2a0206..06bf573f05a7d 100644 --- a/internal/core/src/query/visitors/ExecExprVisitor.cpp +++ b/internal/core/src/query/visitors/ExecExprVisitor.cpp @@ -27,7 +27,8 @@ #include "arrow/type_fwd.h" #include "common/Json.h" #include "common/Types.h" -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" +#include "fmt/core.h" #include "pb/plan.pb.h" #include "query/ExprImpl.h" #include "query/Relational.h" @@ -117,7 +118,8 @@ ExecExprVisitor::visit(LogicalUnaryExpr& expr) { break; } default: { - PanicInfo("Invalid Unary Op"); + PanicInfo(OpTypeInvalid, + fmt::format("Invalid Unary Op {}", expr.op_type_)); } } AssertInfo(res.size() == row_count_, @@ -164,7 +166,8 @@ ExecExprVisitor::visit(LogicalBinaryExpr& expr) { break; } default: { - PanicInfo("Invalid Binary Op"); + PanicInfo(OpTypeInvalid, + fmt::format("Invalid Binary Op {}", expr.op_type_)); } } AssertInfo(res.size() == row_count_, @@ -440,7 +443,8 @@ ExecExprVisitor::ExecUnaryRangeVisitorDispatcherImpl(UnaryRangeExpr& expr_raw) } // TODO: PostfixMatch default: { - PanicInfo("unsupported range node"); + PanicInfo(OpTypeInvalid, + fmt::format("unsupported range node {}", op)); } } } @@ -494,13 +498,73 @@ ExecExprVisitor::ExecUnaryRangeVisitorDispatcher(UnaryRangeExpr& expr_raw) } default: { - PanicInfo("unsupported range node"); + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported range node {}", expr.op_type_)); } } } return ExecUnaryRangeVisitorDispatcherImpl(expr_raw); } +template +bool +CompareTwoJsonArray(T arr1, const proto::plan::Array& arr2) { + int json_array_length = 0; + if constexpr (std::is_same_v< + T, + simdjson::simdjson_result>) { + json_array_length = arr1.count_elements(); + } + if constexpr (std::is_same_v>>) { + json_array_length = arr1.size(); + } + if (arr2.array_size() != json_array_length) { + return false; + } + int i = 0; + for (auto&& it : arr1) { + switch (arr2.array(i).val_case()) { + case 
proto::plan::GenericValue::kBoolVal: { + auto val = it.template get(); + if (val.error() || val.value() != arr2.array(i).bool_val()) { + return false; + } + break; + } + case proto::plan::GenericValue::kInt64Val: { + auto val = it.template get(); + if (val.error() || val.value() != arr2.array(i).int64_val()) { + return false; + } + break; + } + case proto::plan::GenericValue::kFloatVal: { + auto val = it.template get(); + if (val.error() || val.value() != arr2.array(i).float_val()) { + return false; + } + break; + } + case proto::plan::GenericValue::kStringVal: { + auto val = it.template get(); + if (val.error() || val.value() != arr2.array(i).string_val()) { + return false; + } + break; + } + default: + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type {}", + arr2.array(i).val_case())); + } + i++; + } + return true; +} + template auto ExecExprVisitor::ExecUnaryRangeVisitorDispatcherJson(UnaryRangeExpr& expr_raw) @@ -547,56 +611,209 @@ ExecExprVisitor::ExecUnaryRangeVisitorDispatcherJson(UnaryRangeExpr& expr_raw) switch (op) { case OpType::Equal: { auto elem_func = [&](const milvus::Json& json) { - UnaryRangeJSONCompare(x.value() == val); + if constexpr (std::is_same_v) { + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (array.error()) { + return false; + } + return CompareTwoJsonArray(array, val); + } else { + UnaryRangeJSONCompare(x.value() == val); + } }; return ExecRangeVisitorImpl( field_id, index_func, elem_func); } case OpType::NotEqual: { auto elem_func = [&](const milvus::Json& json) { - UnaryRangeJSONCompareNotEqual(x.value() != val); + if constexpr (std::is_same_v) { + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (array.error()) { + return false; + } + return !CompareTwoJsonArray(array, val); + } else { + UnaryRangeJSONCompareNotEqual(x.value() != val); + } }; return ExecRangeVisitorImpl( field_id, index_func, elem_func); } case OpType::GreaterEqual: { auto elem_func = 
[&](const milvus::Json& json) { - UnaryRangeJSONCompare(x.value() >= val); + if constexpr (std::is_same_v) { + return false; + } else { + UnaryRangeJSONCompare(x.value() >= val); + } }; return ExecRangeVisitorImpl( field_id, index_func, elem_func); } case OpType::GreaterThan: { auto elem_func = [&](const milvus::Json& json) { - UnaryRangeJSONCompare(x.value() > val); + if constexpr (std::is_same_v) { + return false; + } else { + UnaryRangeJSONCompare(x.value() > val); + } }; return ExecRangeVisitorImpl( field_id, index_func, elem_func); } case OpType::LessEqual: { auto elem_func = [&](const milvus::Json& json) { - UnaryRangeJSONCompare(x.value() <= val); + if constexpr (std::is_same_v) { + return false; + } else { + UnaryRangeJSONCompare(x.value() <= val); + } }; return ExecRangeVisitorImpl( field_id, index_func, elem_func); } case OpType::LessThan: { auto elem_func = [&](const milvus::Json& json) { - UnaryRangeJSONCompare(x.value() < val); + if constexpr (std::is_same_v) { + return false; + } else { + UnaryRangeJSONCompare(x.value() < val); + } }; return ExecRangeVisitorImpl( field_id, index_func, elem_func); } case OpType::PrefixMatch: { auto elem_func = [&](const milvus::Json& json) { - UnaryRangeJSONCompare(Match(ExprValueType(x.value()), val, op)); + if constexpr (std::is_same_v) { + return false; + } else { + UnaryRangeJSONCompare( + Match(ExprValueType(x.value()), val, op)); + } }; return ExecRangeVisitorImpl( field_id, index_func, elem_func); } // TODO: PostfixMatch default: { - PanicInfo("unsupported range node"); + PanicInfo(OpTypeInvalid, + fmt::format("unsupported range node {}", op)); + } + } +} + +template +auto +ExecExprVisitor::ExecUnaryRangeVisitorDispatcherArray(UnaryRangeExpr& expr_raw) + -> BitsetType { + using Index = index::ScalarIndex; + auto& expr = static_cast&>(expr_raw); + + auto op = expr.op_type_; + auto val = expr.value_; + auto field_id = expr.column_.field_id; + auto index_func = [=](Index* index) { return TargetBitmap{}; }; + int 
index = -1; + if (expr.column_.nested_path.size() > 0) { + index = std::stoi(expr.column_.nested_path[0]); + } + using GetType = + std::conditional_t, + std::string_view, + ExprValueType>; + + switch (op) { + case OpType::Equal: { + auto elem_func = [&](const milvus::ArrayView& array) { + if constexpr (std::is_same_v) { + return array.is_same_array(val); + } else { + auto array_data = array.template get_data(index); + return array_data == val; + } + }; + return ExecRangeVisitorImpl( + field_id, index_func, elem_func); + } + case OpType::NotEqual: { + auto elem_func = [&](const milvus::ArrayView& array) { + if constexpr (std::is_same_v) { + return !array.is_same_array(val); + } else { + auto array_data = array.template get_data(index); + return array_data != val; + } + }; + return ExecRangeVisitorImpl( + field_id, index_func, elem_func); + } + case OpType::GreaterEqual: { + auto elem_func = [&](const milvus::ArrayView& array) { + if constexpr (std::is_same_v) { + return false; + } else { + auto array_data = array.template get_data(index); + return array_data >= val; + } + }; + return ExecRangeVisitorImpl( + field_id, index_func, elem_func); + } + case OpType::GreaterThan: { + auto elem_func = [&](const milvus::ArrayView& array) { + if constexpr (std::is_same_v) { + return false; + } else { + auto array_data = array.template get_data(index); + return array_data > val; + } + }; + return ExecRangeVisitorImpl( + field_id, index_func, elem_func); + } + case OpType::LessEqual: { + auto elem_func = [&](const milvus::ArrayView& array) { + if constexpr (std::is_same_v) { + return false; + } else { + auto array_data = array.template get_data(index); + return array_data <= val; + } + }; + return ExecRangeVisitorImpl( + field_id, index_func, elem_func); + } + case OpType::LessThan: { + auto elem_func = [&](const milvus::ArrayView& array) { + if constexpr (std::is_same_v) { + return false; + } else { + auto array_data = array.template get_data(index); + return array_data < val; 
+ } + }; + return ExecRangeVisitorImpl( + field_id, index_func, elem_func); + } + case OpType::PrefixMatch: { + auto elem_func = [&](const milvus::ArrayView& array) { + if constexpr (std::is_same_v) { + return false; + } else { + auto array_data = array.template get_data(index); + return Match(array_data, val, op); + } + }; + return ExecRangeVisitorImpl( + field_id, index_func, elem_func); + } + // TODO: PostfixMatch + default: { + PanicInfo(OpTypeInvalid, + fmt::format("unsupported range node {}", op)); } } } @@ -687,7 +904,9 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcher( expr.column_.field_id, index_func, elem_func); } default: { - PanicInfo("unsupported arithmetic operation"); + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported arithmetic operation {}", op)); } } } @@ -754,12 +973,17 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcher( expr.column_.field_id, index_func, elem_func); } default: { - PanicInfo("unsupported arithmetic operation"); + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported arithmetic operation {}", op)); } } } default: { - PanicInfo("unsupported range node with arithmetic operation"); + PanicInfo( + OpTypeInvalid, + fmt::format( + "unsupported range node with arithmetic operation {}", op)); } } } @@ -873,8 +1097,27 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcherJson( return ExecDataRangeVisitorImpl( expr.column_.field_id, index_func, elem_func); } + case ArithOpType::ArrayLength: { + auto index_func = [val, right_operand](Index* index, + size_t offset) { + return false; + }; + auto elem_func = [&](const milvus::Json& json) { + int array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + return array_length == val; + }; + return ExecDataRangeVisitorImpl( + expr.column_.field_id, index_func, elem_func); + } default: { - PanicInfo("unsupported arithmetic operation"); + PanicInfo( + 
OpTypeInvalid, + fmt::format("unsupported arithmetic operation {}", op)); } } } @@ -941,13 +1184,228 @@ ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcherJson( return ExecDataRangeVisitorImpl( expr.column_.field_id, index_func, elem_func); } + case ArithOpType::ArrayLength: { + auto index_func = [val, right_operand](Index* index, + size_t offset) { + return false; + }; + auto elem_func = [&](const milvus::Json& json) { + int array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + return array_length != val; + }; + return ExecDataRangeVisitorImpl( + expr.column_.field_id, index_func, elem_func); + } default: { - PanicInfo("unsupported arithmetic operation"); + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported arithmetic operation {}", op)); } } } default: { - PanicInfo("unsupported range node with arithmetic operation"); + PanicInfo( + OpTypeInvalid, + fmt::format( + "unsupported range node with arithmetic operation {}", op)); + } + } +} // namespace milvus::query + +template +auto +ExecExprVisitor::ExecBinaryArithOpEvalRangeVisitorDispatcherArray( + BinaryArithOpEvalRangeExpr& expr_raw) -> BitsetType { + auto& expr = + static_cast&>(expr_raw); + using Index = index::ScalarIndex; + + auto arith_op = expr.arith_op_; + auto right_operand = expr.right_operand_; + auto op = expr.op_type_; + auto val = expr.value_; + int index = -1; + if (expr.column_.nested_path.size() > 0) { + index = std::stoi(expr.column_.nested_path[0]); + } + using GetType = + std::conditional_t, + std::string_view, + ExprValueType>; + + switch (op) { + case OpType::Equal: { + switch (arith_op) { + case ArithOpType::Add: { + auto index_func = [val, right_operand](Index* index, + size_t offset) { + return false; + }; + auto elem_func = [&](const milvus::ArrayView& array) { + auto value = array.get_data(index); + return value + right_operand == val; + }; + return 
ExecDataRangeVisitorImpl( + expr.column_.field_id, index_func, elem_func); + } + case ArithOpType::Sub: { + auto index_func = [val, right_operand](Index* index, + size_t offset) { + return false; + }; + auto elem_func = [&](const milvus::ArrayView& array) { + auto value = array.get_data(index); + return value - right_operand == val; + }; + return ExecDataRangeVisitorImpl( + expr.column_.field_id, index_func, elem_func); + } + case ArithOpType::Mul: { + auto index_func = [val, right_operand](Index* index, + size_t offset) { + return false; + }; + auto elem_func = [&](const milvus::ArrayView& array) { + auto value = array.get_data(index); + return value * right_operand == val; + }; + return ExecDataRangeVisitorImpl( + expr.column_.field_id, index_func, elem_func); + } + case ArithOpType::Div: { + auto index_func = [val, right_operand](Index* index, + size_t offset) { + return false; + }; + auto elem_func = [&](const milvus::ArrayView& array) { + auto value = array.get_data(index); + return value / right_operand == val; + }; + return ExecDataRangeVisitorImpl( + expr.column_.field_id, index_func, elem_func); + } + case ArithOpType::Mod: { + auto index_func = [val, right_operand](Index* index, + size_t offset) { + return false; + }; + auto elem_func = [&](const milvus::ArrayView& array) { + auto value = array.get_data(index); + return static_cast( + fmod(value, right_operand)) == val; + }; + return ExecDataRangeVisitorImpl( + expr.column_.field_id, index_func, elem_func); + } + case ArithOpType::ArrayLength: { + auto index_func = [val, right_operand](Index* index, + size_t offset) { + return false; + }; + auto elem_func = [&](const milvus::ArrayView& array) { + return array.length() == val; + }; + return ExecDataRangeVisitorImpl( + expr.column_.field_id, index_func, elem_func); + } + default: { + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported arithmetic operation {}", op)); + } + } + } + case OpType::NotEqual: { + switch (arith_op) { + case ArithOpType::Add: { + 
auto index_func = [val, right_operand](Index* index, + size_t offset) { + return false; + }; + auto elem_func = [&](const milvus::ArrayView& array) { + auto value = array.get_data(index); + return value + right_operand != val; + }; + return ExecDataRangeVisitorImpl( + expr.column_.field_id, index_func, elem_func); + } + case ArithOpType::Sub: { + auto index_func = [val, right_operand](Index* index, + size_t offset) { + return false; + }; + auto elem_func = [&](const milvus::ArrayView& array) { + auto value = array.get_data(index); + return value - right_operand != val; + }; + return ExecDataRangeVisitorImpl( + expr.column_.field_id, index_func, elem_func); + } + case ArithOpType::Mul: { + auto index_func = [val, right_operand](Index* index, + size_t offset) { + return false; + }; + auto elem_func = [&](const milvus::ArrayView& array) { + auto value = array.get_data(index); + return value * right_operand != val; + }; + return ExecDataRangeVisitorImpl( + expr.column_.field_id, index_func, elem_func); + } + case ArithOpType::Div: { + auto index_func = [val, right_operand](Index* index, + size_t offset) { + return false; + }; + auto elem_func = [&](const milvus::ArrayView& array) { + auto value = array.get_data(index); + return value / right_operand != val; + }; + return ExecDataRangeVisitorImpl( + expr.column_.field_id, index_func, elem_func); + } + case ArithOpType::Mod: { + auto index_func = [val, right_operand](Index* index, + size_t offset) { + return false; + }; + auto elem_func = [&](const milvus::ArrayView& array) { + auto value = array.get_data(index); + return static_cast( + fmod(value, right_operand)) != val; + }; + return ExecDataRangeVisitorImpl( + expr.column_.field_id, index_func, elem_func); + } + case ArithOpType::ArrayLength: { + auto index_func = [val, right_operand](Index* index, + size_t offset) { + return false; + }; + auto elem_func = [&](const milvus::ArrayView& array) { + return array.length() != val; + }; + return ExecDataRangeVisitorImpl( + 
expr.column_.field_id, index_func, elem_func); + } + default: { + PanicInfo( + OpTypeInvalid, + fmt::format("unsupported arithmetic operation {}", op)); + } + } + } + default: { + PanicInfo( + OpTypeInvalid, + fmt::format( + "unsupported range node with arithmetic operation {}", op)); } } } // namespace milvus::query @@ -1092,6 +1550,60 @@ ExecExprVisitor::ExecBinaryRangeVisitorDispatcherJson(BinaryRangeExpr& expr_raw) } } +template +auto +ExecExprVisitor::ExecBinaryRangeVisitorDispatcherArray( + BinaryRangeExpr& expr_raw) -> BitsetType { + using Index = index::ScalarIndex; + using GetType = + std::conditional_t, + std::string_view, + ExprValueType>; + + auto& expr = static_cast&>(expr_raw); + bool lower_inclusive = expr.lower_inclusive_; + bool upper_inclusive = expr.upper_inclusive_; + ExprValueType val1 = expr.lower_value_; + ExprValueType val2 = expr.upper_value_; + int index = -1; + if (expr.column_.nested_path.size() > 0) { + index = std::stoi(expr.column_.nested_path[0]); + } + + // no json index now + auto index_func = [=](Index* index) { return TargetBitmap{}; }; + + if (lower_inclusive && upper_inclusive) { + auto elem_func = [&](const milvus::ArrayView& array) { + auto value = array.get_data(index); + return val1 <= value && value <= val2; + }; + return ExecRangeVisitorImpl( + expr.column_.field_id, index_func, elem_func); + } else if (lower_inclusive && !upper_inclusive) { + auto elem_func = [&](const milvus::ArrayView& array) { + auto value = array.get_data(index); + return val1 <= value && value < val2; + }; + return ExecRangeVisitorImpl( + expr.column_.field_id, index_func, elem_func); + } else if (!lower_inclusive && upper_inclusive) { + auto elem_func = [&](const milvus::ArrayView& array) { + auto value = array.get_data(index); + return val1 < value && value <= val2; + }; + return ExecRangeVisitorImpl( + expr.column_.field_id, index_func, elem_func); + } else { + auto elem_func = [&](const milvus::ArrayView& array) { + auto value = 
array.get_data(index); + return val1 < value && value < val2; + }; + return ExecRangeVisitorImpl( + expr.column_.field_id, index_func, elem_func); + } +} + void ExecExprVisitor::visit(UnaryRangeExpr& expr) { auto& field_meta = segment_.get_schema()[expr.column_.field_id]; @@ -1150,14 +1662,47 @@ ExecExprVisitor::visit(UnaryRangeExpr& expr) { res = ExecUnaryRangeVisitorDispatcherJson(expr); break; + case proto::plan::GenericValue::ValCase::kArrayVal: + res = + ExecUnaryRangeVisitorDispatcherJson( + expr); + break; + default: + PanicInfo( + DataTypeInvalid, + fmt::format("unknown data type: {}", expr.val_case_)); + } + break; + } + case DataType::ARRAY: { + switch (expr.val_case_) { + case proto::plan::GenericValue::ValCase::kBoolVal: + res = ExecUnaryRangeVisitorDispatcherArray(expr); + break; + case proto::plan::GenericValue::ValCase::kInt64Val: + res = ExecUnaryRangeVisitorDispatcherArray(expr); + break; + case proto::plan::GenericValue::ValCase::kFloatVal: + res = ExecUnaryRangeVisitorDispatcherArray(expr); + break; + case proto::plan::GenericValue::ValCase::kStringVal: + res = + ExecUnaryRangeVisitorDispatcherArray(expr); + break; + case proto::plan::GenericValue::ValCase::kArrayVal: + res = ExecUnaryRangeVisitorDispatcherArray< + proto::plan::Array>(expr); + break; default: PanicInfo( + DataTypeInvalid, fmt::format("unknown data type: {}", expr.val_case_)); } break; } default: - PanicInfo(fmt::format("unsupported data type: {}", + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type: {}", expr.column_.data_type)); } AssertInfo(res.size() == row_count_, @@ -1216,6 +1761,28 @@ ExecExprVisitor::visit(BinaryArithOpEvalRangeExpr& expr) { } default: { PanicInfo( + DataTypeInvalid, + fmt::format("unsupported value type {} in expression", + expr.val_case_)); + } + } + break; + } + case DataType::ARRAY: { + switch (expr.val_case_) { + case proto::plan::GenericValue::ValCase::kInt64Val: { + res = ExecBinaryArithOpEvalRangeVisitorDispatcherArray< + 
int64_t>(expr); + break; + } + case proto::plan::GenericValue::ValCase::kFloatVal: { + res = ExecBinaryArithOpEvalRangeVisitorDispatcherArray< + double>(expr); + break; + } + default: { + PanicInfo( + DataTypeInvalid, fmt::format("unsupported value type {} in expression", expr.val_case_)); } @@ -1223,7 +1790,8 @@ ExecExprVisitor::visit(BinaryArithOpEvalRangeExpr& expr) { break; } default: - PanicInfo(fmt::format("unsupported data type: {}", + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type: {}", expr.column_.data_type)); } AssertInfo(res.size() == row_count_, @@ -1276,10 +1844,6 @@ ExecExprVisitor::visit(BinaryRangeExpr& expr) { } case DataType::JSON: { switch (expr.val_case_) { - case proto::plan::GenericValue::ValCase::kBoolVal: { - res = ExecBinaryRangeVisitorDispatcherJson(expr); - break; - } case proto::plan::GenericValue::ValCase::kInt64Val: { res = ExecBinaryRangeVisitorDispatcherJson(expr); break; @@ -1295,6 +1859,31 @@ ExecExprVisitor::visit(BinaryRangeExpr& expr) { } default: { PanicInfo( + DataTypeInvalid, + fmt::format("unsupported value type {} in expression", + expr.val_case_)); + } + } + break; + } + case DataType::ARRAY: { + switch (expr.val_case_) { + case proto::plan::GenericValue::ValCase::kInt64Val: { + res = ExecBinaryRangeVisitorDispatcherArray(expr); + break; + } + case proto::plan::GenericValue::ValCase::kFloatVal: { + res = ExecBinaryRangeVisitorDispatcherArray(expr); + break; + } + case proto::plan::GenericValue::ValCase::kStringVal: { + res = ExecBinaryRangeVisitorDispatcherArray( + expr); + break; + } + default: { + PanicInfo( + DataTypeInvalid, fmt::format("unsupported value type {} in expression", expr.val_case_)); } @@ -1302,7 +1891,8 @@ ExecExprVisitor::visit(BinaryRangeExpr& expr) { break; } default: - PanicInfo(fmt::format("unsupported data type: {}", + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type: {}", expr.column_.data_type)); } AssertInfo(res.size() == row_count_, @@ -1320,7 +1910,7 @@ struct 
relational { template bool operator()(T const&...) const { - PanicInfo("incompatible operands"); + PanicInfo(OpTypeInvalid, "incompatible operands"); } }; @@ -1393,7 +1983,10 @@ ExecExprVisitor::ExecCompareLeftType(const FieldId& left_field_id, left_raw_data, right_field_id, chunk_id, cmp_func); break; default: - PanicInfo("unsupported left datatype of compare expr"); + PanicInfo( + DataTypeInvalid, + fmt::format("unsupported right datatype {} of compare expr", + right_field_type)); } results.push_back(result); } @@ -1444,7 +2037,10 @@ ExecExprVisitor::ExecCompareExprDispatcherForNonIndexedSegment( expr.right_data_type_, cmp_func); default: - PanicInfo("unsupported right datatype of compare expr"); + PanicInfo( + DataTypeInvalid, + fmt::format("unsupported right datatype {} of compare expr", + expr.left_data_type_)); } } @@ -1649,7 +2245,8 @@ ExecExprVisitor::ExecCompareExprDispatcher(CompareExpr& expr, Op op) } } default: - PanicInfo(fmt::format("unsupported data type: {}", type)); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type {}", type)); } }; auto left = getChunkData( @@ -1717,7 +2314,8 @@ ExecExprVisitor::visit(CompareExpr& expr) { // case OpType::PostfixMatch: { // } default: { - PanicInfo("unsupported optype"); + PanicInfo(OpTypeInvalid, + fmt::format("unsupported optype {}", expr.op_type_)); } } AssertInfo(res.size() == row_count_, @@ -1761,7 +2359,9 @@ ExecExprVisitor::ExecTermVisitorImpl(TermExpr& expr_raw) -> BitsetType { break; } default: { - PanicInfo("unsupported type"); + PanicInfo( + DataTypeInvalid, + fmt::format("unsupported data type {}", expr.val_case_)); } } @@ -1773,7 +2373,7 @@ ExecExprVisitor::ExecTermVisitorImpl(TermExpr& expr_raw) -> BitsetType { bitset[_offset] = true; cached_offsets.push_back(_offset); } - // If enable plan_visitor pk index cache, pass offsets to it + // If enable plan_visitor pk index cache, pass offsets_ to it if (plan_visitor_ != nullptr) { plan_visitor_->SetExprUsePkIndex(true); 
plan_visitor_->SetExprCacheOffsets(std::move(cached_offsets)); @@ -1908,6 +2508,39 @@ ExecExprVisitor::ExecTermJsonFieldInVariable(TermExpr& expr_raw) -> BitsetType { expr.column_.field_id, index_func, elem_func); } +template +auto +ExecExprVisitor::ExecTermArrayFieldInVariable(TermExpr& expr_raw) + -> BitsetType { + using Index = index::ScalarIndex; + auto& expr = static_cast&>(expr_raw); + auto index_func = [](Index* index) { return TargetBitmap{}; }; + int index = -1; + if (expr.column_.nested_path.size() > 0) { + index = std::stoi(expr.column_.nested_path[0]); + } + std::unordered_set term_set(expr.terms_.begin(), + expr.terms_.end()); + using GetType = + std::conditional_t, + std::string_view, + ExprValueType>; + + if (term_set.empty()) { + auto elem_func = [=](const milvus::ArrayView& array) { return false; }; + return ExecRangeVisitorImpl( + expr.column_.field_id, index_func, elem_func); + } + + auto elem_func = [&term_set, &index](const milvus::ArrayView& array) { + auto value = array.get_data(index); + return term_set.find(ExprValueType(value)) != term_set.end(); + }; + + return ExecRangeVisitorImpl( + expr.column_.field_id, index_func, elem_func); +} + template auto ExecExprVisitor::ExecTermJsonVariableInField(TermExpr& expr_raw) -> BitsetType { @@ -1945,6 +2578,36 @@ ExecExprVisitor::ExecTermJsonVariableInField(TermExpr& expr_raw) -> BitsetType { expr.column_.field_id, index_func, elem_func); } +template +auto +ExecExprVisitor::ExecTermArrayVariableInField(TermExpr& expr_raw) + -> BitsetType { + using Index = index::ScalarIndex; + auto& expr = static_cast&>(expr_raw); + auto index_func = [](Index* index) { return TargetBitmap{}; }; + + AssertInfo(expr.terms_.size() == 1, + "element length in json array must be one"); + ExprValueType target_val = expr.terms_[0]; + + auto elem_func = [&target_val](const milvus::ArrayView& array) { + using GetType = + std::conditional_t, + std::string_view, + ExprValueType>; + for (int i = 0; i < array.length(); i++) { + 
auto val = array.template get_data(i); + if (val == target_val) { + return true; + } + } + return false; + }; + + return ExecRangeVisitorImpl( + expr.column_.field_id, index_func, elem_func); +} + template auto ExecExprVisitor::ExecTermVisitorImplTemplateJson(TermExpr& expr_raw) @@ -1956,6 +2619,17 @@ ExecExprVisitor::ExecTermVisitorImplTemplateJson(TermExpr& expr_raw) } } +template +auto +ExecExprVisitor::ExecTermVisitorImplTemplateArray(TermExpr& expr_raw) + -> BitsetType { + if (expr_raw.is_in_field_) { + return ExecTermArrayVariableInField(expr_raw); + } else { + return ExecTermArrayFieldInVariable(expr_raw); + } +} + void ExecExprVisitor::visit(TermExpr& expr) { auto& field_meta = segment_.get_schema()[expr.column_.field_id]; @@ -2017,14 +2691,40 @@ ExecExprVisitor::visit(TermExpr& expr) { case proto::plan::GenericValue::ValCase::VAL_NOT_SET: res = ExecTermVisitorImplTemplateJson(expr); break; + default: + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type {}", + expr.val_case_)); + } + break; + } + case DataType::ARRAY: { + switch (expr.val_case_) { + case proto::plan::GenericValue::ValCase::kBoolVal: + res = ExecTermVisitorImplTemplateArray(expr); + break; + case proto::plan::GenericValue::ValCase::kInt64Val: + res = ExecTermVisitorImplTemplateArray(expr); + break; + case proto::plan::GenericValue::ValCase::kFloatVal: + res = ExecTermVisitorImplTemplateArray(expr); + break; + case proto::plan::GenericValue::ValCase::kStringVal: + res = ExecTermVisitorImplTemplateArray(expr); + break; + case proto::plan::GenericValue::ValCase::VAL_NOT_SET: + res = ExecTermVisitorImplTemplateArray(expr); + break; default: PanicInfo( + Unsupported, fmt::format("unknown data type: {}", expr.val_case_)); } break; } default: - PanicInfo(fmt::format("unsupported data type: {}", + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type {}", expr.column_.data_type)); } AssertInfo(res.size() == row_count_, @@ -2052,7 +2752,8 @@ ExecExprVisitor::visit(ExistsExpr& 
expr) { break; } default: - PanicInfo(fmt::format("unsupported data type {}", + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type {}", expr.column_.data_type)); } AssertInfo(res.size() == row_count_, @@ -2067,63 +2768,6 @@ ExecExprVisitor::visit(AlwaysTrueExpr& expr) { bitset_opt_ = std::move(res); } -template -bool -compareTwoJsonArray(T arr1, const proto::plan::Array& arr2) { - int json_array_length = 0; - if constexpr (std::is_same_v< - T, - simdjson::simdjson_result>) { - json_array_length = arr1.count_elements(); - } - if constexpr (std::is_same_v>>) { - json_array_length = arr1.size(); - } - if (arr2.array_size() != json_array_length) { - return false; - } - int i = 0; - for (auto&& it : arr1) { - switch (arr2.array(i).val_case()) { - case proto::plan::GenericValue::kBoolVal: { - auto val = it.template get(); - if (val.error() || val.value() != arr2.array(i).bool_val()) { - return false; - } - break; - } - case proto::plan::GenericValue::kInt64Val: { - auto val = it.template get(); - if (val.error() || val.value() != arr2.array(i).int64_val()) { - return false; - } - break; - } - case proto::plan::GenericValue::kFloatVal: { - auto val = it.template get(); - if (val.error() || val.value() != arr2.array(i).float_val()) { - return false; - } - break; - } - case proto::plan::GenericValue::kStringVal: { - auto val = it.template get(); - if (val.error() || val.value() != arr2.array(i).string_val()) { - return false; - } - break; - } - default: - PanicInfo(fmt::format("unsupported data type {}", - arr2.array(i).val_case())); - } - i++; - } - return true; -} - template auto ExecExprVisitor::ExecJsonContains(JsonContainsExpr& expr_raw) -> BitsetType { @@ -2161,6 +2805,35 @@ ExecExprVisitor::ExecJsonContains(JsonContainsExpr& expr_raw) -> BitsetType { expr.column_.field_id, index_func, elem_func); } +template +auto +ExecExprVisitor::ExecArrayContains(JsonContainsExpr& expr_raw) -> BitsetType { + using Index = index::ScalarIndex; + auto& expr = 
static_cast&>(expr_raw); + AssertInfo(expr.column_.nested_path.size() == 0, + "[ExecArrayContains]nested path must be null"); + auto index_func = [](Index* index) { return TargetBitmap{}; }; + using GetType = + std::conditional_t, + std::string_view, + ExprValueType>; + std::unordered_set elements; + for (auto const& element : expr.elements_) { + elements.insert(element); + } + auto elem_func = [&elements](const milvus::ArrayView& array) { + for (int i = 0; i < array.length(); ++i) { + if (elements.count(array.template get_data(i)) > 0) { + return true; + } + } + return false; + }; + + return ExecRangeVisitorImpl( + expr.column_.field_id, index_func, elem_func); +} + auto ExecExprVisitor::ExecJsonContainsArray(JsonContainsExpr& expr_raw) -> BitsetType { @@ -2188,7 +2861,7 @@ ExecExprVisitor::ExecJsonContainsArray(JsonContainsExpr& expr_raw) json_array.emplace_back(e); } for (auto const& element : elements) { - if (compareTwoJsonArray(json_array, element)) { + if (CompareTwoJsonArray(json_array, element)) { return true; } } @@ -2264,13 +2937,14 @@ ExecExprVisitor::ExecJsonContainsWithDiffType(JsonContainsExpr& expr_raw) if (val.error()) { continue; } - if (compareTwoJsonArray(val, element.array_val())) { + if (CompareTwoJsonArray(val, element.array_val())) { return true; } break; } default: - PanicInfo(fmt::format("unsupported data type {}", + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type {}", element.val_case())); } } @@ -2324,6 +2998,41 @@ ExecExprVisitor::ExecJsonContainsAll(JsonContainsExpr& expr_raw) -> BitsetType { expr.column_.field_id, index_func, elem_func); } +template +auto +ExecExprVisitor::ExecArrayContainsAll(JsonContainsExpr& expr_raw) + -> BitsetType { + using Index = index::ScalarIndex; + auto& expr = static_cast&>(expr_raw); + AssertInfo(expr.column_.nested_path.size() == 0, + "[ExecArrayContains]nested path must be null"); + auto index_func = [](Index* index) { return TargetBitmap{}; }; + using GetType = + std::conditional_t, + 
std::string_view, + ExprValueType>; + + std::unordered_set elements; + for (auto const& element : expr.elements_) { + elements.insert(element); + } + // auto elements = expr.elements_; + auto elem_func = [&elements](const milvus::ArrayView& array) { + std::unordered_set tmp_elements(elements); + // Note: array can only be iterated once + for (int i = 0; i < array.length(); ++i) { + tmp_elements.erase(array.template get_data(i)); + if (tmp_elements.size() == 0) { + return true; + } + } + return tmp_elements.size() == 0; + }; + + return ExecRangeVisitorImpl( + expr.column_.field_id, index_func, elem_func); +} + auto ExecExprVisitor::ExecJsonContainsAllArray(JsonContainsExpr& expr_raw) -> BitsetType { @@ -2353,7 +3062,7 @@ ExecExprVisitor::ExecJsonContainsAllArray(JsonContainsExpr& expr_raw) json_array.emplace_back(e); } for (int index = 0; index < elements.size(); ++index) { - if (compareTwoJsonArray(json_array, elements[index])) { + if (CompareTwoJsonArray(json_array, elements[index])) { exist_elements_index.insert(index); } } @@ -2442,13 +3151,14 @@ ExecExprVisitor::ExecJsonContainsAllWithDiffType(JsonContainsExpr& expr_raw) if (val.error()) { continue; } - if (compareTwoJsonArray(val, element.array_val())) { + if (CompareTwoJsonArray(val, element.array_val())) { tmp_elements_index.erase(i); } break; } default: - PanicInfo(fmt::format("unsupported data type {}", + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type {}", element.val_case())); } if (tmp_elements_index.size() == 0) { @@ -2470,75 +3180,135 @@ void ExecExprVisitor::visit(JsonContainsExpr& expr) { auto& field_meta = segment_.get_schema()[expr.column_.field_id]; AssertInfo( - expr.column_.data_type == DataType::JSON, + expr.column_.data_type == DataType::JSON || + expr.column_.data_type == DataType::ARRAY, "[ExecExprVisitor]DataType of JsonContainsExpr isn't json data type"); BitsetType res; + auto data_type = expr.column_.data_type; switch (expr.op_) { case 
proto::plan::JSONContainsExpr_JSONOp_Contains: case proto::plan::JSONContainsExpr_JSONOp_ContainsAny: { - if (expr.same_type_) { + if (datatype_is_array(data_type)) { switch (expr.val_case_) { case proto::plan::GenericValue::kBoolVal: { - res = ExecJsonContains(expr); + res = ExecArrayContains(expr); break; } case proto::plan::GenericValue::kInt64Val: { - res = ExecJsonContains(expr); + res = ExecArrayContains(expr); break; } case proto::plan::GenericValue::kFloatVal: { - res = ExecJsonContains(expr); + res = ExecArrayContains(expr); break; } case proto::plan::GenericValue::kStringVal: { - res = ExecJsonContains(expr); - break; - } - case proto::plan::GenericValue::kArrayVal: { - res = ExecJsonContainsArray(expr); + res = ExecArrayContains(expr); break; } default: - PanicInfo(fmt::format("unsupported data type")); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type {}", + expr.val_case_)); + } + } else { + if (expr.same_type_) { + switch (expr.val_case_) { + case proto::plan::GenericValue::kBoolVal: { + res = ExecJsonContains(expr); + break; + } + case proto::plan::GenericValue::kInt64Val: { + res = ExecJsonContains(expr); + break; + } + case proto::plan::GenericValue::kFloatVal: { + res = ExecJsonContains(expr); + break; + } + case proto::plan::GenericValue::kStringVal: { + res = ExecJsonContains(expr); + break; + } + case proto::plan::GenericValue::kArrayVal: { + res = ExecJsonContainsArray(expr); + break; + } + default: + PanicInfo(Unsupported, + fmt::format("unsupported value type {}", + expr.val_case_)); + } + } else { + res = ExecJsonContainsWithDiffType(expr); } - break; } - res = ExecJsonContainsWithDiffType(expr); break; } case proto::plan::JSONContainsExpr_JSONOp_ContainsAll: { - if (expr.same_type_) { + if (datatype_is_array(data_type)) { switch (expr.val_case_) { case proto::plan::GenericValue::kBoolVal: { - res = ExecJsonContainsAll(expr); + res = ExecArrayContainsAll(expr); break; } case proto::plan::GenericValue::kInt64Val: { - res = 
ExecJsonContainsAll(expr); + res = ExecArrayContainsAll(expr); break; } case proto::plan::GenericValue::kFloatVal: { - res = ExecJsonContainsAll(expr); + res = ExecArrayContainsAll(expr); break; } case proto::plan::GenericValue::kStringVal: { - res = ExecJsonContainsAll(expr); - break; - } - case proto::plan::GenericValue::kArrayVal: { - res = ExecJsonContainsAllArray(expr); + res = ExecArrayContainsAll(expr); break; } default: - PanicInfo(fmt::format("unsupported data type")); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type {}", + expr.val_case_)); + } + } else { + if (expr.same_type_) { + switch (expr.val_case_) { + case proto::plan::GenericValue::kBoolVal: { + res = ExecJsonContainsAll(expr); + break; + } + case proto::plan::GenericValue::kInt64Val: { + res = ExecJsonContainsAll(expr); + break; + } + case proto::plan::GenericValue::kFloatVal: { + res = ExecJsonContainsAll(expr); + break; + } + case proto::plan::GenericValue::kStringVal: { + res = ExecJsonContainsAll(expr); + break; + } + case proto::plan::GenericValue::kArrayVal: { + res = ExecJsonContainsAllArray(expr); + break; + } + default: + PanicInfo( + Unsupported, + fmt::format( + "unsupported value type {} in expression", + expr.val_case_)); + } + } else { + res = ExecJsonContainsAllWithDiffType(expr); } - break; } - res = ExecJsonContainsAllWithDiffType(expr); break; } default: - PanicInfo(fmt::format("unsupported json contains type")); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported json contains type {}", + expr.val_case_)); } AssertInfo(res.size() == row_count_, "[ExecExprVisitor]Size of results not equal row count"); diff --git a/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp b/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp index ca8357177bffd..2b1018477fa90 100644 --- a/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp +++ b/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp @@ -17,7 +17,7 @@ #include "query/SubSearchResult.h" #include 
"query/generated/ExecExprVisitor.h" #include "segcore/SegmentGrowing.h" -#include "utils/Json.h" +#include "common/Json.h" #include "log/Log.h" namespace milvus::query { @@ -212,4 +212,9 @@ ExecPlanNodeVisitor::visit(BinaryVectorANNS& node) { VectorVisitorImpl(node); } +void +ExecPlanNodeVisitor::visit(Float16VectorANNS& node) { + VectorVisitorImpl(node); +} + } // namespace milvus::query diff --git a/internal/core/src/query/visitors/ExtractInfoPlanNodeVisitor.cpp b/internal/core/src/query/visitors/ExtractInfoPlanNodeVisitor.cpp index 9e383e2776d28..5bb6cd68242c9 100644 --- a/internal/core/src/query/visitors/ExtractInfoPlanNodeVisitor.cpp +++ b/internal/core/src/query/visitors/ExtractInfoPlanNodeVisitor.cpp @@ -47,6 +47,15 @@ ExtractInfoPlanNodeVisitor::visit(BinaryVectorANNS& node) { } } +void +ExtractInfoPlanNodeVisitor::visit(Float16VectorANNS& node) { + plan_info_.add_involved_field(node.search_info_.field_id_); + if (node.predicate_.has_value()) { + ExtractInfoExprVisitor expr_visitor(plan_info_); + node.predicate_.value()->accept(expr_visitor); + } +} + void ExtractInfoPlanNodeVisitor::visit(RetrievePlanNode& node) { // Assert(node.predicate_.has_value()); diff --git a/internal/core/src/query/visitors/ShowExprVisitor.cpp b/internal/core/src/query/visitors/ShowExprVisitor.cpp index 7e4081065debb..55c892240889c 100644 --- a/internal/core/src/query/visitors/ShowExprVisitor.cpp +++ b/internal/core/src/query/visitors/ShowExprVisitor.cpp @@ -14,6 +14,7 @@ #include "query/ExprImpl.h" #include "query/Plan.h" #include "query/generated/ShowExprVisitor.h" +#include "common/Types.h" namespace milvus::query { using Json = nlohmann::json; @@ -90,7 +91,8 @@ ShowExprVisitor::visit(LogicalBinaryExpr& expr) { case OpType::LogicalXor: return "LogicalXor"; default: - PanicInfo("unsupported op"); + PanicInfo(OpTypeInvalid, + fmt::format("unsupported operation {}", op)); } }(expr.op_type_); @@ -134,7 +136,9 @@ ShowExprVisitor::visit(TermExpr& expr) { case DataType::JSON: return 
TermExtract(expr); default: - PanicInfo("unsupported type"); + PanicInfo( + DataTypeInvalid, + fmt::format("unsupported type {}", expr.column_.data_type)); } }(); @@ -192,7 +196,9 @@ ShowExprVisitor::visit(UnaryRangeExpr& expr) { json_opt_ = UnaryRangeExtract(expr); return; default: - PanicInfo("unsupported type"); + PanicInfo( + DataTypeInvalid, + fmt::format("unsupported type {}", expr.column_.data_type)); } } @@ -247,7 +253,9 @@ ShowExprVisitor::visit(BinaryRangeExpr& expr) { json_opt_ = BinaryRangeExtract(expr); return; default: - PanicInfo("unsupported type"); + PanicInfo( + DataTypeInvalid, + fmt::format("unsupported type {}", expr.column_.data_type)); } } @@ -317,7 +325,9 @@ ShowExprVisitor::visit(BinaryArithOpEvalRangeExpr& expr) { json_opt_ = BinaryArithOpEvalRangeExtract(expr); return; default: - PanicInfo("unsupported type"); + PanicInfo( + DataTypeInvalid, + fmt::format("unsupported type {}", expr.column_.data_type)); } } diff --git a/internal/core/src/query/visitors/ShowPlanNodeVisitor.cpp b/internal/core/src/query/visitors/ShowPlanNodeVisitor.cpp index 48a472f0bfa7c..4325f41539b1b 100644 --- a/internal/core/src/query/visitors/ShowPlanNodeVisitor.cpp +++ b/internal/core/src/query/visitors/ShowPlanNodeVisitor.cpp @@ -11,10 +11,10 @@ #include -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" +#include "common/Json.h" #include "query/generated/ShowExprVisitor.h" #include "query/generated/ShowPlanNodeVisitor.h" -#include "utils/Json.h" namespace milvus::query { #if 0 @@ -96,6 +96,30 @@ ShowPlanNodeVisitor::visit(BinaryVectorANNS& node) { ret_ = json_body; } +void +ShowPlanNodeVisitor::visit(Float16VectorANNS& node) { + assert(!ret_); + auto& info = node.search_info_; + Json json_body{ + {"node_type", "Float16VectorANNS"}, // + {"metric_type", info.metric_type_}, // + {"field_id_", info.field_id_.get()}, // + {"topk", info.topk_}, // + {"search_params", info.search_params_}, // + {"placeholder_tag", node.placeholder_tag_}, // + }; + if 
(node.predicate_.has_value()) { + ShowExprVisitor expr_show; + AssertInfo(node.predicate_.value(), + "[ShowPlanNodeVisitor]Can't get value from node predict"); + json_body["predicate"] = + expr_show.call_child(node.predicate_->operator*()); + } else { + json_body["predicate"] = "None"; + } + ret_ = json_body; +} + void ShowPlanNodeVisitor::visit(RetrievePlanNode& node) { } diff --git a/internal/core/src/query/visitors/VerifyPlanNodeVisitor.cpp b/internal/core/src/query/visitors/VerifyPlanNodeVisitor.cpp index bd3fa0b6fb24f..06fca20799686 100644 --- a/internal/core/src/query/visitors/VerifyPlanNodeVisitor.cpp +++ b/internal/core/src/query/visitors/VerifyPlanNodeVisitor.cpp @@ -26,39 +26,6 @@ class VerifyPlanNodeVisitor : PlanNodeVisitor { }; } // namespace impl -static IndexType -InferIndexType(const Json& search_params) { - // ivf -> nprobe - // hnsw -> ef - static const std::map key_list = [] { - std::map list; - namespace ip = knowhere::indexparam; - namespace ie = knowhere::IndexEnum; - list.emplace(ip::NPROBE, ie::INDEX_FAISS_IVFFLAT); - list.emplace(ip::EF, ie::INDEX_HNSW); - return list; - }(); - auto dbg_str = search_params.dump(); - for (auto& kv : search_params.items()) { - std::string key = kv.key(); - if (key_list.count(key)) { - return key_list.at(key); - } - } - PanicCodeInfo(ErrorCodeEnum::IllegalArgument, "failed to infer index type"); -} - -static IndexType -InferBinaryIndexType(const Json& search_params) { - namespace ip = knowhere::indexparam; - namespace ie = knowhere::IndexEnum; - if (search_params.contains(ip::NPROBE)) { - return ie::INDEX_FAISS_BIN_IVFFLAT; - } else { - return ie::INDEX_FAISS_BIN_IDMAP; - } -} - void VerifyPlanNodeVisitor::visit(FloatVectorANNS&) { } @@ -67,6 +34,10 @@ void VerifyPlanNodeVisitor::visit(BinaryVectorANNS&) { } +void +VerifyPlanNodeVisitor::visit(Float16VectorANNS&) { +} + void VerifyPlanNodeVisitor::visit(RetrievePlanNode&) { } diff --git a/internal/core/src/segcore/ConcurrentVector.cpp 
b/internal/core/src/segcore/ConcurrentVector.cpp index 9ddbd1c085e4f..972a81dd057b0 100644 --- a/internal/core/src/segcore/ConcurrentVector.cpp +++ b/internal/core/src/segcore/ConcurrentVector.cpp @@ -30,8 +30,11 @@ VectorBase::set_data_raw(ssize_t element_offset, } else if (field_meta.get_data_type() == DataType::VECTOR_BINARY) { return set_data_raw( element_offset, VEC_FIELD_DATA(data, binary), element_count); + } else if (field_meta.get_data_type() == DataType::VECTOR_FLOAT16) { + return set_data_raw( + element_offset, VEC_FIELD_DATA(data, float16), element_count); } else { - PanicInfo("unsupported"); + PanicInfo(DataTypeInvalid, "unsupported"); } } @@ -84,8 +87,19 @@ VectorBase::set_data_raw(ssize_t element_offset, return set_data_raw(element_offset, data_raw.data(), element_count); } + case DataType::ARRAY: { + auto& array_data = FIELD_DATA(data, array); + std::vector data_raw{}; + data_raw.reserve(array_data.size()); + for (auto& array_bytes : array_data) { + data_raw.emplace_back(Array(array_bytes)); + } + + return set_data_raw(element_offset, data_raw.data(), element_count); + } default: { - PanicInfo(fmt::format("unsupported datatype {}", + PanicInfo(DataTypeInvalid, + fmt::format("unsupported datatype {}", field_meta.get_data_type())); } } diff --git a/internal/core/src/segcore/ConcurrentVector.h b/internal/core/src/segcore/ConcurrentVector.h index d4af000cc748f..22dc50e08b6a4 100644 --- a/internal/core/src/segcore/ConcurrentVector.h +++ b/internal/core/src/segcore/ConcurrentVector.h @@ -30,7 +30,7 @@ #include "common/Span.h" #include "common/Types.h" #include "common/Utils.h" -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" #include "storage/FieldData.h" namespace milvus::segcore { @@ -148,12 +148,14 @@ class ConcurrentVectorImpl : public VectorBase { ConcurrentVectorImpl& operator=(const ConcurrentVectorImpl&) = delete; - using TraitType = - std::conditional_t, - FloatVector, - BinaryVector>>; + using TraitType = std::conditional_t< + 
is_scalar, + Type, + std::conditional_t, + FloatVector, + std::conditional_t, + Float16Vector, + BinaryVector>>>; public: explicit ConcurrentVectorImpl(ssize_t dim, int64_t size_per_chunk) @@ -180,9 +182,8 @@ class ConcurrentVectorImpl : public VectorBase { return Span(chunk.data(), chunk.size()); } else if constexpr (std::is_same_v || // NOLINT std::is_same_v) { - // TODO: where should the braces be placed? // only for testing - PanicInfo("unimplemented"); + PanicInfo(NotImplemented, "unimplemented"); } else { static_assert( std::is_same_v); @@ -380,13 +381,20 @@ class ConcurrentVector : public ConcurrentVectorImpl { public: explicit ConcurrentVector(int64_t dim, int64_t size_per_chunk) - : binary_dim_(dim), ConcurrentVectorImpl(dim / 8, size_per_chunk) { + : ConcurrentVectorImpl(dim / 8, size_per_chunk) { AssertInfo(dim % 8 == 0, fmt::format("dim is not a multiple of 8, dim={}", dim)); } +}; - private: - int64_t binary_dim_; +template <> +class ConcurrentVector + : public ConcurrentVectorImpl { + public: + ConcurrentVector(int64_t dim, int64_t size_per_chunk) + : ConcurrentVectorImpl::ConcurrentVectorImpl( + dim, size_per_chunk) { + } }; } // namespace milvus::segcore diff --git a/internal/core/src/segcore/FieldIndexing.cpp b/internal/core/src/segcore/FieldIndexing.cpp index cf636d127d954..f86245f2e57f2 100644 --- a/internal/core/src/segcore/FieldIndexing.cpp +++ b/internal/core/src/segcore/FieldIndexing.cpp @@ -11,27 +11,32 @@ #include #include +#include "common/EasyAssert.h" +#include "fmt/format.h" #include "index/ScalarIndexSort.h" #include "index/StringIndexSort.h" #include "common/SystemProperty.h" #include "segcore/FieldIndexing.h" -#include "index/VectorMemNMIndex.h" +#include "index/VectorMemIndex.h" #include "IndexConfigGenerator.h" namespace milvus::segcore { +using std::unique_ptr; VectorFieldIndexing::VectorFieldIndexing(const FieldMeta& field_meta, const FieldIndexMeta& field_index_meta, int64_t segment_max_row_count, const SegcoreConfig& 
segcore_config) : FieldIndexing(field_meta, segcore_config), - config_(std::make_unique( - segment_max_row_count, field_index_meta, segcore_config)), build(false), - sync_with_index(false) { - index_ = std::make_unique(config_->GetIndexType(), - config_->GetMetricType()); + sync_with_index(false), + config_(std::make_unique( + segment_max_row_count, field_index_meta, segcore_config)) { + index_ = std::make_unique( + config_->GetIndexType(), + config_->GetMetricType(), + knowhere::Version::GetCurrentVersion().VersionNumber()); } void @@ -50,8 +55,10 @@ VectorFieldIndexing::BuildIndexRange(int64_t ack_beg, data_.grow_to_at_least(ack_end); for (int chunk_id = ack_beg; chunk_id < ack_end; chunk_id++) { const auto& chunk = source->get_chunk(chunk_id); - auto indexing = std::make_unique( - knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, knowhere::metric::L2); + auto indexing = std::make_unique( + knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, + knowhere::metric::L2, + knowhere::Version::GetCurrentVersion().VersionNumber()); auto dataset = knowhere::GenDataSet( source->get_size_per_chunk(), dim, chunk.data()); indexing->BuildWithDataset(dataset, conf); @@ -99,7 +106,7 @@ VectorFieldIndexing::AppendSegmentIndex(int64_t reserved_offset, int64_t vec_num = vector_id_end - vector_id_beg + 1; // for train index const void* data_addr; - std::unique_ptr vec_data; + unique_ptr vec_data; //all train data in one chunk if (chunk_id_beg == chunk_id_end) { data_addr = vec_base->get_chunk_data(chunk_id_beg); @@ -237,9 +244,15 @@ CreateIndex(const FieldMeta& field_meta, field_index_meta, segment_max_row_count, segcore_config); + } else if (field_meta.get_data_type() == DataType::VECTOR_FLOAT16) { + return std::make_unique(field_meta, + field_index_meta, + segment_max_row_count, + segcore_config); } else { - // TODO - PanicInfo("unsupported"); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported vector type in index: {}", + field_meta.get_data_type())); } } switch (field_meta.get_data_type()) { 
@@ -268,7 +281,9 @@ CreateIndex(const FieldMeta& field_meta, return std::make_unique>( field_meta, segcore_config); default: - PanicInfo("unsupported"); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported scalar type in index: {}", + field_meta.get_data_type())); } } diff --git a/internal/core/src/segcore/FieldIndexing.h b/internal/core/src/segcore/FieldIndexing.h index 4917db2fbf875..ee3038b6a654c 100644 --- a/internal/core/src/segcore/FieldIndexing.h +++ b/internal/core/src/segcore/FieldIndexing.h @@ -111,7 +111,8 @@ class ScalarFieldIndexing : public FieldIndexing { int64_t size, const VectorBase* vec_base, const void* data_source) override { - PanicInfo("scalar index don't support append segment index"); + PanicInfo(Unsupported, + "scalar index don't support append segment index"); } void @@ -119,7 +120,8 @@ class ScalarFieldIndexing : public FieldIndexing { int64_t count, int64_t element_size, void* output) override { - PanicInfo("scalar index don't support get data from index"); + PanicInfo(Unsupported, + "scalar index don't support get data from index"); } idx_t get_index_cursor() override { @@ -249,12 +251,17 @@ class IndexingRecord { //Small-Index disabled, create index for vector field only if (index_meta_->GetIndexMaxRowCount() > 0 && index_meta_->HasFiled(field_id)) { - field_indexings_.try_emplace( - field_id, - CreateIndex(field_meta, - index_meta_->GetFieldIndexMeta(field_id), - index_meta_->GetIndexMaxRowCount(), - segcore_config_)); + auto vec_filed_meta = + index_meta_->GetFieldIndexMeta(field_id); + //Disable growing index for flat + if (!vec_filed_meta.IsFlatIndex()) { + field_indexings_.try_emplace( + field_id, + CreateIndex(field_meta, + vec_filed_meta, + index_meta_->GetIndexMaxRowCount(), + segcore_config_)); + } } } } @@ -335,11 +342,11 @@ class IndexingRecord { bool HasRawData(FieldId fieldId) const { - if (is_in(fieldId)) { + if (is_in(fieldId) && SyncDataWithIndex(fieldId)) { const FieldIndexing& indexing = 
get_field_indexing(fieldId); return indexing.has_raw_data(); } - return false; + return true; } // concurrent diff --git a/internal/core/src/segcore/InsertRecord.h b/internal/core/src/segcore/InsertRecord.h index f14a3ac03dc24..b03a09e53e994 100644 --- a/internal/core/src/segcore/InsertRecord.h +++ b/internal/core/src/segcore/InsertRecord.h @@ -21,8 +21,10 @@ #include #include "TimestampIndex.h" +#include "common/EasyAssert.h" #include "common/Schema.h" #include "common/Types.h" +#include "fmt/format.h" #include "mmap/Column.h" #include "segcore/AckResponder.h" #include "segcore/ConcurrentVector.h" @@ -40,6 +42,9 @@ class OffsetMap { public: virtual ~OffsetMap() = default; + virtual bool + contain(const PkType& pk) const = 0; + virtual std::vector find(const PkType& pk) const = 0; @@ -63,6 +68,11 @@ class OffsetMap { template class OffsetOrderedMap : public OffsetMap { public: + bool + contain(const PkType& pk) const override { + return map_.find(std::get(pk)) != map_.end(); + } + std::vector find(const PkType& pk) const override { auto offset_vector = map_.find(std::get(pk)); @@ -78,6 +88,7 @@ class OffsetOrderedMap : public OffsetMap { void seal() override { PanicInfo( + NotImplemented, "OffsetOrderedMap used for growing segment could not be sealed."); } @@ -104,21 +115,23 @@ class OffsetOrderedMap : public OffsetMap { find_first_by_index(int64_t limit, const BitsetType& bitset, bool false_filtered_out) const { - std::vector seg_offsets; - seg_offsets.reserve(limit); int64_t hit_num = 0; // avoid counting the number everytime. 
- auto cnt = bitset.count(); + int64_t cnt = bitset.count(); if (!false_filtered_out) { cnt = bitset.size() - bitset.count(); } - for (auto it = map_.begin(); it != map_.end(); it++) { - if (hit_num >= limit || hit_num >= cnt) { - break; - } + limit = std::min(limit, cnt); + std::vector seg_offsets; + seg_offsets.reserve(limit); + for (auto it = map_.begin(); hit_num < limit && it != map_.end(); + it++) { for (auto seg_offset : it->second) { if (!(bitset[seg_offset] ^ false_filtered_out)) { seg_offsets.push_back(seg_offset); hit_num++; + if (hit_num >= limit) { + break; + } } } } @@ -133,6 +146,19 @@ class OffsetOrderedMap : public OffsetMap { template class OffsetOrderedArray : public OffsetMap { public: + bool + contain(const PkType& pk) const override { + const T& target = std::get(pk); + auto it = + std::lower_bound(array_.begin(), + array_.end(), + target, + [](const std::pair& elem, + const T& value) { return elem.first < value; }); + + return it != array_.end(); + } + std::vector find(const PkType& pk) const override { check_search(); @@ -155,8 +181,10 @@ class OffsetOrderedArray : public OffsetMap { void insert(const PkType& pk, int64_t offset) override { - if (is_sealed) - PanicInfo("OffsetOrderedArray could not insert after seal"); + if (is_sealed) { + PanicInfo(Unsupported, + "OffsetOrderedArray could not insert after seal"); + } array_.push_back(std::make_pair(std::get(pk), offset)); } @@ -191,17 +219,16 @@ class OffsetOrderedArray : public OffsetMap { find_first_by_index(int64_t limit, const BitsetType& bitset, bool false_filtered_out) const { - std::vector seg_offsets; - seg_offsets.reserve(limit); int64_t hit_num = 0; // avoid counting the number everytime. 
- auto cnt = bitset.count(); + int64_t cnt = bitset.count(); if (!false_filtered_out) { cnt = bitset.size() - bitset.count(); } - for (auto it = array_.begin(); it != array_.end(); it++) { - if (hit_num >= limit || hit_num >= cnt) { - break; - } + limit = std::min(limit, cnt); + std::vector seg_offsets; + seg_offsets.reserve(limit); + for (auto it = array_.begin(); hit_num < limit && it != array_.end(); + it++) { if (!(bitset[it->second] ^ false_filtered_out)) { seg_offsets.push_back(it->second); hit_num++; @@ -247,25 +274,29 @@ struct InsertRecord { pk_field_id.value() == field_id) { switch (field_meta.get_data_type()) { case DataType::INT64: { - if (is_sealed) + if (is_sealed) { pk2offset_ = std::make_unique>(); - else + } else { pk2offset_ = std::make_unique>(); + } break; } case DataType::VARCHAR: { - if (is_sealed) + if (is_sealed) { pk2offset_ = std::make_unique< OffsetOrderedArray>(); - else + } else { pk2offset_ = std::make_unique< OffsetOrderedMap>(); + } break; } default: { - PanicInfo("unsupported pk type"); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported pk type", + field_meta.get_data_type())); } } } @@ -279,8 +310,15 @@ struct InsertRecord { this->append_field_data( field_id, field_meta.get_dim(), size_per_chunk); continue; + } else if (field_meta.get_data_type() == + DataType::VECTOR_FLOAT16) { + this->append_field_data( + field_id, field_meta.get_dim(), size_per_chunk); + continue; } else { - PanicInfo("unsupported"); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported vector type", + field_meta.get_data_type())); } } switch (field_meta.get_data_type()) { @@ -321,18 +359,24 @@ struct InsertRecord { this->append_field_data(field_id, size_per_chunk); break; } - // case DataType::ARRAY: { - // this->append_field_data(field_id, - // size_per_chunk); - // break; - // } + case DataType::ARRAY: { + this->append_field_data(field_id, size_per_chunk); + break; + } default: { - PanicInfo("unsupported"); + PanicInfo(DataTypeInvalid, + 
fmt::format("unsupported scalar type", + field_meta.get_data_type())); } } } } + bool + contain(const PkType& pk) const { + return pk2offset_->contain(pk); + } + std::vector search_pk(const PkType& pk, Timestamp timestamp) const { std::shared_lock lck(shared_mutex_); @@ -372,7 +416,9 @@ struct InsertRecord { break; } default: { - PanicInfo("unsupported primary key data type"); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported primary key data type", + data_type)); } } } @@ -402,7 +448,9 @@ struct InsertRecord { break; } default: { - PanicInfo("unsupported primary key data type"); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported primary key data type", + data_type)); } } } diff --git a/internal/core/src/segcore/Reduce.cpp b/internal/core/src/segcore/Reduce.cpp index c5a48d9221ceb..e3781e93d93cb 100644 --- a/internal/core/src/segcore/Reduce.cpp +++ b/internal/core/src/segcore/Reduce.cpp @@ -19,7 +19,7 @@ #include "SegmentInterface.h" #include "Utils.h" -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" #include "pkVisitor.h" namespace milvus::segcore { @@ -124,6 +124,8 @@ ReduceHelper::FillPrimaryKey() { uint32_t valid_index = 0; for (auto& search_result : search_results_) { FilterInvalidSearchResult(search_result); + LOG_SEGCORE_DEBUG_ << "the size of search result" + << search_result->seg_offsets_.size(); if (search_result->get_total_result_count() > 0) { auto segment = static_cast(search_result->segment_); @@ -318,7 +320,8 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) { break; } default: { - PanicInfo("unsupported primary key type"); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported primary key type {}", pk_type)); } } @@ -365,7 +368,9 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) { break; } default: { - PanicInfo("unsupported primary key type"); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported primary key type {}", + pk_type)); } } @@ -391,6 +396,12 @@ ReduceHelper::GetSearchResultDataSlice(int 
slice_index) { auto& field_meta = plan_->schema_[field_id]; auto field_data = milvus::segcore::MergeDataArray(result_pairs, field_meta); + if (field_meta.get_data_type() == DataType::ARRAY) { + field_data->mutable_scalars() + ->mutable_array_data() + ->set_element_type( + proto::schema::DataType(field_meta.get_element_type())); + } search_result_data->mutable_fields_data()->AddAllocated( field_data.release()); } diff --git a/internal/core/src/segcore/Reduce.h b/internal/core/src/segcore/Reduce.h index eeeb8d9fdfa64..073beb07ef90f 100644 --- a/internal/core/src/segcore/Reduce.h +++ b/internal/core/src/segcore/Reduce.h @@ -18,7 +18,6 @@ #include #include -#include "utils/Status.h" #include "common/type_c.h" #include "common/QueryResult.h" #include "query/PlanImpl.h" @@ -84,15 +83,15 @@ class ReduceHelper { GetSearchResultDataSlice(int slice_index_); private: - std::vector slice_topKs_; + std::vector& search_results_; + milvus::query::Plan* plan_; + std::vector slice_nqs_; + std::vector slice_topKs_; int64_t total_nq_; int64_t num_segments_; int64_t num_slices_; - milvus::query::Plan* plan_; - std::vector& search_results_; - std::vector slice_nqs_prefix_sum_; // dim0: num_segments_; dim1: total_nq_; dim2: offset diff --git a/internal/core/src/segcore/ScalarIndex.cpp b/internal/core/src/segcore/ScalarIndex.cpp index bc988b035688d..c5aaacdd70f09 100644 --- a/internal/core/src/segcore/ScalarIndex.cpp +++ b/internal/core/src/segcore/ScalarIndex.cpp @@ -9,7 +9,7 @@ // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express // or implied. 
See the License for the specific language governing permissions and limitations under the License -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" #include "ScalarIndex.h" namespace milvus::segcore { diff --git a/internal/core/src/segcore/SealedIndexingRecord.h b/internal/core/src/segcore/SealedIndexingRecord.h index a17481f840eb6..5b16115f96bce 100644 --- a/internal/core/src/segcore/SealedIndexingRecord.h +++ b/internal/core/src/segcore/SealedIndexingRecord.h @@ -19,7 +19,7 @@ #include #include "common/Types.h" -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" #include "index/VectorIndex.h" namespace milvus::segcore { diff --git a/internal/core/src/segcore/SegcoreConfig.cpp b/internal/core/src/segcore/SegcoreConfig.cpp index 292de265ad540..478100e640a7a 100644 --- a/internal/core/src/segcore/SegcoreConfig.cpp +++ b/internal/core/src/segcore/SegcoreConfig.cpp @@ -11,7 +11,7 @@ #include "common/Schema.h" #include "SegcoreConfig.h" -#include "utils/Json.h" +#include "common/Json.h" #include "yaml-cpp/yaml.h" namespace milvus::segcore { @@ -37,7 +37,7 @@ apply_parser(const YAML::Node& node, Func func) { results.emplace_back(func(element)); } } else { - PanicInfo("node should be scalar or sequence"); + PanicInfo(ConfigInvalid, "node should be scalar or sequence"); } return results; } @@ -102,7 +102,7 @@ SegcoreConfig::parse_from(const std::string& config_path) { } catch (const std::exception& e) { std::string str = std::string("Invalid Yaml: ") + config_path + ", err: " + e.what(); - PanicInfo(str); + PanicInfo(ConfigInvalid, str); } } diff --git a/internal/core/src/segcore/SegcoreConfig.h b/internal/core/src/segcore/SegcoreConfig.h index dc2360a52610e..c7f3175119018 100644 --- a/internal/core/src/segcore/SegcoreConfig.h +++ b/internal/core/src/segcore/SegcoreConfig.h @@ -15,9 +15,9 @@ #include #include "common/Types.h" +#include "common/Json.h" #include "index/Utils.h" -#include "exceptions/EasyAssert.h" -#include "utils/Json.h" 
+#include "common/EasyAssert.h" namespace milvus::segcore { diff --git a/internal/core/src/segcore/SegmentGrowing.h b/internal/core/src/segcore/SegmentGrowing.h index 64a7435002711..f00b3fc3fca69 100644 --- a/internal/core/src/segcore/SegmentGrowing.h +++ b/internal/core/src/segcore/SegmentGrowing.h @@ -18,7 +18,6 @@ #include "common/Schema.h" #include "common/Types.h" #include "query/Plan.h" -#include "query/deprecated/GeneralQuery.h" #include "segcore/SegmentInterface.h" namespace milvus::segcore { diff --git a/internal/core/src/segcore/SegmentGrowingImpl.cpp b/internal/core/src/segcore/SegmentGrowingImpl.cpp index 49e2e082811cc..3c86d4ba85d73 100644 --- a/internal/core/src/segcore/SegmentGrowingImpl.cpp +++ b/internal/core/src/segcore/SegmentGrowingImpl.cpp @@ -19,10 +19,9 @@ #include #include "common/Consts.h" +#include "common/EasyAssert.h" #include "common/Types.h" #include "common/macro.h" -#include "exceptions/EasyAssert.h" -#include "fmt/core.h" #include "nlohmann/json.hpp" #include "query/PlanNode.h" #include "query/SearchOnSealed.h" @@ -307,7 +306,7 @@ SegmentGrowingImpl::LoadFieldDataV2(const LoadFieldDataInfo& infos) { insert_record_.ack_responder_.AddSegment(reserved_offset, reserved_offset + num_rows); } -Status +SegcoreError SegmentGrowingImpl::Delete(int64_t reserved_begin, int64_t size, const IdArray* ids, @@ -318,6 +317,15 @@ SegmentGrowingImpl::Delete(int64_t reserved_begin, std::vector pks(size); ParsePksFromIDs(pks, field_meta.get_data_type(), *ids); + // filter out the deletions that the primary key not exists + auto end = std::remove_if(pks.begin(), pks.end(), [&](const PkType& pk) { + return !insert_record_.contain(pk); + }); + size = end - pks.begin(); + if (size == 0) { + return SegcoreError::success(); + } + // step 1: sort timestamp std::vector> ordering(size); for (int i = 0; i < size; i++) { @@ -335,7 +343,7 @@ SegmentGrowingImpl::Delete(int64_t reserved_begin, // step 2: fill delete record deleted_record_.push(sort_pks, 
sort_timestamps.data()); - return Status::OK(); + return SegcoreError::success(); } int64_t @@ -412,7 +420,6 @@ std::unique_ptr SegmentGrowingImpl::bulk_subscript(FieldId field_id, const int64_t* seg_offsets, int64_t count) const { - // TODO: support more types auto vec_ptr = insert_record_.get_field_data_base(field_id); auto& field_meta = schema_->operator[](field_id); if (field_meta.is_vector()) { @@ -431,8 +438,15 @@ SegmentGrowingImpl::bulk_subscript(FieldId field_id, seg_offsets, count, output.data()); + } else if (field_meta.get_data_type() == DataType::VECTOR_FLOAT16) { + bulk_subscript_impl(field_id, + field_meta.get_sizeof(), + vec_ptr, + seg_offsets, + count, + output.data()); } else { - PanicInfo("logical error"); + PanicInfo(DataTypeInvalid, "logical error"); } return CreateVectorDataArrayFrom(output.data(), count, field_meta); } @@ -494,8 +508,16 @@ SegmentGrowingImpl::bulk_subscript(FieldId field_id, vec_ptr, seg_offsets, count, output.data()); return CreateScalarDataArrayFrom(output.data(), count, field_meta); } + case DataType::ARRAY: { + // element + FixedVector output(count); + bulk_subscript_impl(*vec_ptr, seg_offsets, count, output.data()); + return CreateScalarDataArrayFrom(output.data(), count, field_meta); + } default: { - PanicInfo("unsupported type"); + PanicInfo( + DataTypeInvalid, + fmt::format("unsupported type {}", field_meta.get_data_type())); } } } @@ -558,6 +580,23 @@ SegmentGrowingImpl::bulk_subscript_impl(const VectorBase* vec_raw, } } +void +SegmentGrowingImpl::bulk_subscript_impl(const VectorBase& vec_raw, + const int64_t* seg_offsets, + int64_t count, + void* output_raw) const { + auto vec_ptr = dynamic_cast*>(&vec_raw); + AssertInfo(vec_ptr, "Pointer of vec_raw is nullptr"); + auto& vec = *vec_ptr; + auto output = reinterpret_cast(output_raw); + for (int64_t i = 0; i < count; ++i) { + auto offset = seg_offsets[i]; + if (offset != INVALID_SEG_OFFSET) { + output[i] = vec[offset].output_data(); + } + } +} + void 
SegmentGrowingImpl::bulk_subscript(SystemFieldType system_type, const int64_t* seg_offsets, @@ -573,7 +612,7 @@ SegmentGrowingImpl::bulk_subscript(SystemFieldType system_type, &this->insert_record_.row_ids_, seg_offsets, count, output); break; default: - PanicInfo("unknown subscript fields"); + PanicInfo(DataTypeInvalid, "unknown subscript fields"); } } @@ -590,7 +629,8 @@ SegmentGrowingImpl::search_ids(const IdArray& id_array, auto res_id_arr = std::make_unique(); std::vector res_offsets; - for (auto pk : pks) { + res_offsets.reserve(pks.size()); + for (auto& pk : pks) { auto segOffsets = insert_record_.search_pk(pk, timestamp); for (auto offset : segOffsets) { switch (data_type) { @@ -601,11 +641,12 @@ SegmentGrowingImpl::search_ids(const IdArray& id_array, } case DataType::VARCHAR: { res_id_arr->mutable_str_id()->add_data( - std::get(pk)); + std::get(std::move(pk))); break; } default: { - PanicInfo("unsupported type"); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported type {}", data_type)); } } res_offsets.push_back(offset); diff --git a/internal/core/src/segcore/SegmentGrowingImpl.h b/internal/core/src/segcore/SegmentGrowingImpl.h index 27afdc28b1a56..b836645157932 100644 --- a/internal/core/src/segcore/SegmentGrowingImpl.h +++ b/internal/core/src/segcore/SegmentGrowingImpl.h @@ -28,11 +28,9 @@ #include "InsertRecord.h" #include "SealedIndexingRecord.h" #include "SegmentGrowing.h" - -#include "exceptions/EasyAssert.h" +#include "common/Types.h" +#include "common/EasyAssert.h" #include "query/PlanNode.h" -#include "query/deprecated/GeneralQuery.h" -#include "utils/Status.h" #include "common/IndexMeta.h" namespace milvus::segcore { @@ -49,8 +47,13 @@ class SegmentGrowingImpl : public SegmentGrowing { const Timestamp* timestamps, const InsertData* insert_data) override; + bool + Contain(const PkType& pk) const override { + return insert_record_.contain(pk); + } + // TODO: add id into delete log, possibly bitmap - Status + SegcoreError Delete(int64_t 
reserved_offset, int64_t size, const IdArray* pks, @@ -156,6 +159,13 @@ class SegmentGrowingImpl : public SegmentGrowing { int64_t count, void* output_raw) const; + // for scalar array vectors + void + bulk_subscript_impl(const VectorBase& vec_raw, + const int64_t* seg_offsets, + int64_t count, + void* output_raw) const; + template void bulk_subscript_impl(FieldId field_id, diff --git a/internal/core/src/segcore/SegmentInterface.cpp b/internal/core/src/segcore/SegmentInterface.cpp index cbf1de28490ec..bbc08206059e5 100644 --- a/internal/core/src/segcore/SegmentInterface.cpp +++ b/internal/core/src/segcore/SegmentInterface.cpp @@ -14,6 +14,7 @@ #include #include "Utils.h" +#include "common/EasyAssert.h" #include "common/SystemProperty.h" #include "common/Tracer.h" #include "common/Types.h" @@ -93,8 +94,9 @@ SegmentInternalInterface::Retrieve(const query::RetrievePlan* plan, output_data_size += get_field_avg_size(field_id) * result_rows; } if (output_data_size > limit_size) { - throw std::runtime_error("query results exceed the limit size " + - std::to_string(limit_size)); + throw SegcoreError( + RetrieveError, + fmt::format("query results exceed the limit size ", limit_size)); } if (plan->plan_node_->is_count_) { @@ -139,6 +141,10 @@ SegmentInternalInterface::Retrieve(const query::RetrievePlan* plan, auto col = bulk_subscript(field_id, retrieve_results.result_offsets_.data(), retrieve_results.result_offsets_.size()); + if (field_meta.get_data_type() == DataType::ARRAY) { + col->mutable_scalars()->mutable_array_data()->set_element_type( + proto::schema::DataType(field_meta.get_element_type())); + } auto col_data = col.release(); fields_data->AddAllocated(col_data); if (pk_field_id.has_value() && pk_field_id.value() == field_id) { @@ -159,7 +165,9 @@ SegmentInternalInterface::Retrieve(const query::RetrievePlan* plan, break; } default: { - PanicInfo("unsupported data type"); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported datatype {}", + 
field_meta.get_data_type())); } } } @@ -200,7 +208,7 @@ SegmentInternalInterface::get_field_avg_size(FieldId field_id) const { return sizeof(int64_t); } - throw std::runtime_error("unsupported system field id"); + throw SegcoreError(FieldIDInvalid, "unsupported system field id"); } auto schema = get_schema(); diff --git a/internal/core/src/segcore/SegmentInterface.h b/internal/core/src/segcore/SegmentInterface.h index a9477009c4a05..fbf25b81bccda 100644 --- a/internal/core/src/segcore/SegmentInterface.h +++ b/internal/core/src/segcore/SegmentInterface.h @@ -47,6 +47,9 @@ class SegmentInterface { virtual void FillTargetEntry(const query::Plan* plan, SearchResult& results) const = 0; + virtual bool + Contain(const PkType& pk) const = 0; + virtual std::unique_ptr Search(const query::Plan* Plan, const query::PlaceholderGroup* placeholder_group) const = 0; @@ -83,7 +86,7 @@ class SegmentInterface { // virtual int64_t // PreDelete(int64_t size) = 0; - virtual Status + virtual SegcoreError Delete(int64_t reserved_offset, int64_t size, const IdArray* pks, diff --git a/internal/core/src/segcore/SegmentSealed.h b/internal/core/src/segcore/SegmentSealed.h index 48337f3a2a6cf..3771646cbf4cd 100644 --- a/internal/core/src/segcore/SegmentSealed.h +++ b/internal/core/src/segcore/SegmentSealed.h @@ -13,6 +13,7 @@ #include #include +#include #include "common/LoadInfo.h" #include "pb/segcore.pb.h" @@ -37,6 +38,8 @@ class SegmentSealed : public SegmentInternalInterface { LoadFieldData(FieldId field_id, FieldDataInfo& data) = 0; virtual void MapFieldData(const FieldId field_id, FieldDataInfo& data) = 0; + virtual void + AddFieldDataInfoForSealed(const LoadFieldDataInfo& field_data_info) = 0; SegmentType type() const override { diff --git a/internal/core/src/segcore/SegmentSealedImpl.cpp b/internal/core/src/segcore/SegmentSealedImpl.cpp index 28716d0a23248..7278b86098e25 100644 --- a/internal/core/src/segcore/SegmentSealedImpl.cpp +++ b/internal/core/src/segcore/SegmentSealedImpl.cpp 
@@ -14,6 +14,7 @@ #include #include +#include #include #include #include @@ -26,7 +27,8 @@ #include "Types.h" #include "common/Json.h" #include "common/LoadInfo.h" -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" +#include "common/Array.h" #include "mmap/Column.h" #include "common/Consts.h" #include "common/FieldMeta.h" @@ -39,7 +41,8 @@ #include "storage/FieldData.h" #include "storage/Util.h" #include "storage/ThreadPools.h" -#include "utils/File.h" +#include "storage/ChunkCacheSingleton.h" +#include "common/File.h" #include "common/Tracer.h" namespace milvus::segcore { @@ -160,7 +163,9 @@ SegmentSealedImpl::LoadScalarIndex(const LoadIndexInfo& info) { break; } default: { - PanicInfo("unsupported primary key type"); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported primary key type {}", + field_meta.get_data_type())); } } } @@ -352,7 +357,25 @@ SegmentSealedImpl::LoadFieldData(FieldId field_id, FieldDataInfo& data) { column = std::move(var_column); break; } + case milvus::DataType::ARRAY: { + auto var_column = + std::make_shared(num_rows, field_meta); + storage::FieldDataPtr field_data; + while (data.channel->pop(field_data)) { + for (auto i = 0; i < field_data->get_num_rows(); i++) { + auto rawValue = field_data->RawValue(i); + auto array = + static_cast(rawValue); + var_column->Append(*array); + } + } + var_column->Seal(); + column = std::move(var_column); + break; + } default: { + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type", data_type)); } } @@ -411,10 +434,12 @@ SegmentSealedImpl::MapFieldData(const FieldId field_id, FieldDataInfo& data) { size_t total_written{0}; auto data_size = 0; std::vector indices{}; + std::vector> element_indices{}; storage::FieldDataPtr field_data; while (data.channel->pop(field_data)) { data_size += field_data->Size(); - auto written = WriteFieldData(file, data_type, field_data); + auto written = + WriteFieldData(file, data_type, field_data, element_indices); if (written != 
field_data->Size()) { break; } @@ -454,7 +479,17 @@ SegmentSealedImpl::MapFieldData(const FieldId field_id, FieldDataInfo& data) { column = std::move(var_column); break; } + case milvus::DataType::ARRAY: { + auto arr_column = std::make_shared( + file, total_written, field_meta); + arr_column->Seal(std::move(indices), + std::move(element_indices)); + column = std::move(arr_column); + break; + } default: { + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type {}", data_type)); } } } else { @@ -502,6 +537,13 @@ SegmentSealedImpl::LoadDeletedRecord(const LoadDeletedRecordInfo& info) { deleted_record_.push(pks, timestamps); } +void +SegmentSealedImpl::AddFieldDataInfoForSealed( + const LoadFieldDataInfo& field_data_info) { + // copy assignment + field_data_info_ = field_data_info; +} + // internal API: support scalar index only int64_t SegmentSealedImpl::num_chunk_index(FieldId field_id) const { @@ -642,12 +684,34 @@ SegmentSealedImpl::vector_search(SearchInfo& search_info, } } +std::tuple +SegmentSealedImpl::GetFieldDataPath(FieldId field_id, int64_t offset) const { + auto offset_in_binlog = offset; + auto data_path = std::string(); + auto it = field_data_info_.field_infos.find(field_id.get()); + AssertInfo(it != field_data_info_.field_infos.end(), + fmt::format("cannot find binlog file for field: {}, seg: {}", + field_id.get(), + id_)); + auto field_info = it->second; + + for (auto i = 0; i < field_info.insert_files.size(); i++) { + if (offset_in_binlog < field_info.entries_nums[i]) { + data_path = field_info.insert_files[i]; + break; + } else { + offset_in_binlog -= field_info.entries_nums[i]; + } + } + return {data_path, offset_in_binlog}; +} + std::unique_ptr SegmentSealedImpl::get_vector(FieldId field_id, const int64_t* ids, int64_t count) const { - auto& filed_meta = schema_->operator[](field_id); - AssertInfo(filed_meta.is_vector(), "vector field is not vector type"); + auto& field_meta = schema_->operator[](field_id); + 
AssertInfo(field_meta.is_vector(), "vector field is not vector type"); if (get_bit(index_ready_bitset_, field_id)) { AssertInfo(vector_indexings_.is_ready(field_id), @@ -661,10 +725,57 @@ SegmentSealedImpl::get_vector(FieldId field_id, auto has_raw_data = vec_index->HasRawData(); if (has_raw_data) { + // If index has raw data, get vector from memory. auto ids_ds = GenIdsDataset(count, ids); auto vector = vec_index->GetVector(ids_ds); return segcore::CreateVectorDataArrayFrom( - vector.data(), count, filed_meta); + vector.data(), count, field_meta); + } else { + // If index doesn't have raw data, get vector from chunk cache. + auto cc = + storage::ChunkCacheSingleton::GetInstance().GetChunkCache(); + + // group by data_path + auto id_to_data_path = + std::unordered_map>{}; + auto path_to_column = + std::unordered_map>{}; + for (auto i = 0; i < count; i++) { + const auto& tuple = GetFieldDataPath(field_id, ids[i]); + id_to_data_path.emplace(ids[i], tuple); + path_to_column.emplace(std::get<0>(tuple), nullptr); + } + + // read and prefetch + for (const auto& iter : path_to_column) { + auto data_path = iter.first; + const auto& column = cc->Read(data_path); + cc->Prefetch(data_path); + path_to_column[data_path] = column; + } + + // assign to data array + auto dim = field_meta.get_dim(); + auto row_bytes = field_meta.is_vector() ? 
dim * 4 : dim / 8; + auto buf = std::vector(count * row_bytes); + for (auto i = 0; i < count; i++) { + AssertInfo(id_to_data_path.count(ids[i]) != 0, "id not found"); + const auto& [data_path, offset_in_binlog] = + id_to_data_path.at(ids[i]); + AssertInfo(path_to_column.count(data_path) != 0, + "column not found"); + const auto& column = path_to_column.at(data_path); + AssertInfo( + offset_in_binlog * row_bytes < column->ByteSize(), + fmt::format("column idx out of range, idx: {}, size: {}", + offset_in_binlog * row_bytes, + column->ByteSize())); + auto vector = &column->Data()[offset_in_binlog * row_bytes]; + std::memcpy(buf.data() + i * row_bytes, vector, row_bytes); + } + return segcore::CreateVectorDataArrayFrom( + buf.data(), count, field_meta); } } @@ -717,9 +828,10 @@ SegmentSealedImpl::check_search(const query::Plan* plan) const { if (!is_system_field_ready()) { PanicInfo( + FieldNotLoaded, "failed to load row ID or timestamp, potential missing bin logs or " "empty segments. Segment ID = " + - std::to_string(this->id_)); + std::to_string(this->id_)); } auto& request_fields = plan->extra_info_opt_.value().involved_fields_; @@ -733,20 +845,34 @@ SegmentSealedImpl::check_search(const query::Plan* plan) const { auto field_id = FieldId(absent_fields.find_first() + START_USER_FIELDID); auto& field_meta = schema_->operator[](field_id); - PanicInfo("User Field(" + field_meta.get_name().get() + - ") is not loaded"); + PanicInfo( + FieldNotLoaded, + "User Field(" + field_meta.get_name().get() + ") is not loaded"); } } SegmentSealedImpl::SegmentSealedImpl(SchemaPtr schema, int64_t segment_id) - : schema_(schema), - insert_record_(*schema, MAX_ROW_COUNT), - field_data_ready_bitset_(schema->size()), + : field_data_ready_bitset_(schema->size()), index_ready_bitset_(schema->size()), scalar_indexings_(schema->size()), + insert_record_(*schema, MAX_ROW_COUNT), + schema_(schema), id_(segment_id) { } +SegmentSealedImpl::~SegmentSealedImpl() { + auto cc = 
storage::ChunkCacheSingleton::GetInstance().GetChunkCache(); + if (cc == nullptr) { + return; + } + // munmap and remove binlog from chunk cache + for (const auto& iter : field_data_info_.field_infos) { + for (const auto& binlog : iter.second.insert_files) { + cc->Remove(binlog); + } + } +} + void SegmentSealedImpl::bulk_subscript(SystemFieldType system_type, const int64_t* seg_offsets, @@ -775,7 +901,8 @@ SegmentSealedImpl::bulk_subscript(SystemFieldType system_type, output); break; default: - PanicInfo("unknown subscript fields"); + PanicInfo(DataTypeInvalid, + fmt::format("unknown subscript fields", system_type)); } } @@ -809,6 +936,21 @@ SegmentSealedImpl::bulk_subscript_impl(const ColumnBase* column, } } +void +SegmentSealedImpl::bulk_subscript_impl(const ColumnBase* column, + const int64_t* seg_offsets, + int64_t count, + void* dst_raw) { + auto field = reinterpret_cast(column); + auto dst = reinterpret_cast(dst_raw); + for (int64_t i = 0; i < count; ++i) { + auto offset = seg_offsets[i]; + if (offset != INVALID_SEG_OFFSET) { + dst[i] = std::move(field->RawAt(offset)); + } + } +} + // for vector void SegmentSealedImpl::bulk_subscript_impl(int64_t element_sizeof, @@ -883,8 +1025,17 @@ SegmentSealedImpl::bulk_subscript(FieldId field_id, output.data(), count, field_meta); } + case DataType::ARRAY: { + FixedVector output(count); + bulk_subscript_impl( + column.get(), seg_offsets, count, output.data()); + return CreateScalarDataArrayFrom( + output.data(), count, field_meta); + } + default: PanicInfo( + DataTypeInvalid, fmt::format("unsupported data type: {}", datatype_name(field_meta.get_data_type()))); } @@ -936,6 +1087,7 @@ SegmentSealedImpl::bulk_subscript(FieldId field_id, } case DataType::VECTOR_FLOAT: + case DataType::VECTOR_FLOAT16: case DataType::VECTOR_BINARY: { aligned_vector output(field_meta.get_sizeof() * count); bulk_subscript_impl(field_meta.get_sizeof(), @@ -947,7 +1099,9 @@ SegmentSealedImpl::bulk_subscript(FieldId field_id, } default: { - 
PanicInfo("unsupported"); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type {}", + field_meta.get_data_type())); } } } @@ -999,7 +1153,8 @@ SegmentSealedImpl::search_ids(const IdArray& id_array, auto res_id_arr = std::make_unique(); std::vector res_offsets; - for (auto pk : pks) { + res_offsets.reserve(pks.size()); + for (auto& pk : pks) { auto segOffsets = insert_record_.search_pk(pk, timestamp); for (auto offset : segOffsets) { switch (data_type) { @@ -1010,11 +1165,12 @@ SegmentSealedImpl::search_ids(const IdArray& id_array, } case DataType::VARCHAR: { res_id_arr->mutable_str_id()->add_data( - std::get(pk)); + std::get(std::move(pk))); break; } default: { - PanicInfo("unsupported type"); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported type {}", data_type)); } } res_offsets.push_back(offset); @@ -1023,7 +1179,7 @@ SegmentSealedImpl::search_ids(const IdArray& id_array, return {std::move(res_id_arr), std::move(res_offsets)}; } -Status +SegcoreError SegmentSealedImpl::Delete(int64_t reserved_offset, // deprecated int64_t size, const IdArray* ids, @@ -1034,6 +1190,15 @@ SegmentSealedImpl::Delete(int64_t reserved_offset, // deprecated std::vector pks(size); ParsePksFromIDs(pks, field_meta.get_data_type(), *ids); + // filter out the deletions that the primary key not exists + auto end = std::remove_if(pks.begin(), pks.end(), [&](const PkType& pk) { + return !insert_record_.contain(pk); + }); + size = end - pks.begin(); + if (size == 0) { + return SegcoreError::success(); + } + // step 1: sort timestamp std::vector> ordering(size); for (int i = 0; i < size; i++) { @@ -1050,7 +1215,7 @@ SegmentSealedImpl::Delete(int64_t reserved_offset, // deprecated } deleted_record_.push(sort_pks, sort_timestamps.data()); - return Status::OK(); + return SegcoreError::success(); } std::string @@ -1070,7 +1235,7 @@ SegmentSealedImpl::LoadSegmentMeta( slice_lengths.push_back(info.row_count()); } 
insert_record_.timestamp_index_.set_length_meta(std::move(slice_lengths)); - PanicInfo("unimplemented"); + PanicInfo(NotImplemented, "unimplemented"); } int64_t diff --git a/internal/core/src/segcore/SegmentSealedImpl.h b/internal/core/src/segcore/SegmentSealedImpl.h index 12e66341fccf3..3c945d1620c4b 100644 --- a/internal/core/src/segcore/SegmentSealedImpl.h +++ b/internal/core/src/segcore/SegmentSealedImpl.h @@ -28,6 +28,7 @@ #include "SealedIndexingRecord.h" #include "SegmentSealed.h" #include "TimestampIndex.h" +#include "common/EasyAssert.h" #include "mmap/Column.h" #include "index/ScalarIndex.h" #include "sys/mman.h" @@ -38,7 +39,7 @@ namespace milvus::segcore { class SegmentSealedImpl : public SegmentSealed { public: explicit SegmentSealedImpl(SchemaPtr schema, int64_t segment_id); - ~SegmentSealedImpl() override = default; + ~SegmentSealedImpl() override; void LoadIndex(const LoadIndexInfo& info) override; void @@ -59,10 +60,18 @@ class SegmentSealedImpl : public SegmentSealed { bool HasFieldData(FieldId field_id) const override; + bool + Contain(const PkType& pk) const override { + return insert_record_.contain(pk); + } + void LoadFieldData(FieldId field_id, FieldDataInfo& data) override; void MapFieldData(const FieldId field_id, FieldDataInfo& data) override; + void + AddFieldDataInfoForSealed( + const LoadFieldDataInfo& field_data_info) override; int64_t get_segment_id() const override { @@ -106,7 +115,7 @@ class SegmentSealedImpl : public SegmentSealed { std::string debug() const override; - Status + SegcoreError Delete(int64_t reserved_offset, int64_t size, const IdArray* pks, @@ -120,6 +129,13 @@ class SegmentSealedImpl : public SegmentSealed { limit, bitset, false_filtered_out); } + // Calculate: output[i] = Vec[seg_offset[i]] + // where Vec is determined from field_offset + std::unique_ptr + bulk_subscript(FieldId field_id, + const int64_t* seg_offsets, + int64_t count) const override; + protected: // blob and row_count SpanBase @@ -136,13 +152,6 @@ 
class SegmentSealedImpl : public SegmentSealed { int64_t count, void* output) const override; - // Calculate: output[i] = Vec[seg_offset[i]] - // where Vec is determined from field_offset - std::unique_ptr - bulk_subscript(FieldId field_id, - const int64_t* seg_offsets, - int64_t count) const override; - void check_search(const query::Plan* plan) const override; @@ -169,6 +178,12 @@ class SegmentSealedImpl : public SegmentSealed { int64_t count, void* dst_raw); + static void + bulk_subscript_impl(const ColumnBase* column, + const int64_t* seg_offsets, + int64_t count, + void* dst_raw); + static void bulk_subscript_impl(int64_t element_sizeof, const void* src_raw, @@ -218,6 +233,9 @@ class SegmentSealedImpl : public SegmentSealed { std::pair, std::vector> search_ids(const IdArray& id_array, Timestamp timestamp) const override; + std::tuple + GetFieldDataPath(FieldId field_id, int64_t offset) const; + void LoadVecIndex(const LoadIndexInfo& info); @@ -245,6 +263,8 @@ class SegmentSealedImpl : public SegmentSealed { // deleted pks mutable DeletedRecord deleted_record_; + LoadFieldDataInfo field_data_info_; + SchemaPtr schema_; int64_t id_; std::unordered_map> fields_; diff --git a/internal/core/src/segcore/Types.h b/internal/core/src/segcore/Types.h index 78af7e5a4c2f4..9df88e008e13a 100644 --- a/internal/core/src/segcore/Types.h +++ b/internal/core/src/segcore/Types.h @@ -45,6 +45,7 @@ struct LoadIndexInfo { std::string blob_name; std::string uri; int64_t storage_version; + IndexVersion index_engine_version; }; } // namespace milvus::segcore diff --git a/internal/core/src/segcore/Utils.cpp b/internal/core/src/segcore/Utils.cpp index 9419bcf521b37..69f87d7bf84de 100644 --- a/internal/core/src/segcore/Utils.cpp +++ b/internal/core/src/segcore/Utils.cpp @@ -27,7 +27,8 @@ namespace milvus::segcore { void ParsePksFromFieldData(std::vector& pks, const DataArray& data) { - switch (static_cast(data.type())) { + auto data_type = static_cast(data.type()); + switch (data_type) { 
case DataType::INT64: { auto source_data = reinterpret_cast( data.scalars().long_data().data().data()); @@ -40,7 +41,8 @@ ParsePksFromFieldData(std::vector& pks, const DataArray& data) { break; } default: { - PanicInfo("unsupported"); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported PK {}", data_type)); } } } @@ -69,7 +71,8 @@ ParsePksFromFieldData(DataType data_type, break; } default: { - PanicInfo("unsupported"); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported PK {}", data_type)); } } offset += row_count; @@ -93,7 +96,8 @@ ParsePksFromIDs(std::vector& pks, break; } default: { - PanicInfo("unsupported"); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported PK {}", data_type)); } } } @@ -108,7 +112,8 @@ GetSizeOfIdArray(const IdArray& data) { return data.str_id().data_size(); } - PanicInfo("unsupported id type"); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported id {}", data.descriptor()->name())); } int64_t @@ -136,8 +141,70 @@ GetRawDataSizeOfDataArray(const DataArray* data, } break; } + case DataType::ARRAY: { + auto& array_data = FIELD_DATA(data, array); + switch (field_meta.get_element_type()) { + case DataType::BOOL: { + for (auto& array_bytes : array_data) { + result += array_bytes.bool_data().data_size() * + sizeof(bool); + } + break; + } + case DataType::INT8: + case DataType::INT16: + case DataType::INT32: { + for (auto& array_bytes : array_data) { + result += array_bytes.int_data().data_size() * + sizeof(int); + } + break; + } + case DataType::INT64: { + for (auto& array_bytes : array_data) { + result += array_bytes.long_data().data_size() * + sizeof(int64_t); + } + break; + } + case DataType::FLOAT: { + for (auto& array_bytes : array_data) { + result += array_bytes.float_data().data_size() * + sizeof(float); + } + break; + } + case DataType::DOUBLE: { + for (auto& array_bytes : array_data) { + result += array_bytes.double_data().data_size() * + sizeof(double); + } + break; + } + case DataType::VARCHAR: + case DataType::STRING: { 
+ for (auto& array_bytes : array_data) { + auto element_num = + array_bytes.string_data().data_size(); + for (int i = 0; i < element_num; ++i) { + result += + array_bytes.string_data().data(i).size(); + } + } + break; + } + default: + PanicInfo( + DataTypeInvalid, + fmt::format("unsupported element type for array", + field_meta.get_element_type())); + } + + break; + } default: { PanicInfo( + DataTypeInvalid, fmt::format("unsupported variable datatype {}", data_type)); } } @@ -210,8 +277,17 @@ CreateScalarDataArray(int64_t count, const FieldMeta& field_meta) { } break; } + case DataType::ARRAY: { + auto obj = scalar_array->mutable_array_data(); + obj->mutable_data()->Reserve(count); + for (int i = 0; i < count; i++) { + *(obj->mutable_data()->Add()) = proto::schema::ScalarField(); + } + break; + } default: { - PanicInfo("unsupported datatype"); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported datatype {}", data_type)); } } @@ -244,8 +320,15 @@ CreateVectorDataArray(int64_t count, const FieldMeta& field_meta) { obj->resize(num_bytes); break; } + case DataType::VECTOR_FLOAT16: { + auto length = count * dim; + auto obj = vector_array->mutable_float16_vector(); + obj->resize(length * sizeof(float16)); + break; + } default: { - PanicInfo("unsupported datatype"); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported datatype {}", data_type)); } } return data_array; @@ -321,8 +404,17 @@ CreateScalarDataArrayFrom(const void* data_raw, } break; } + case DataType::ARRAY: { + auto data = reinterpret_cast(data_raw); + auto obj = scalar_array->mutable_array_data(); + for (auto i = 0; i < count; i++) { + *(obj->mutable_data()->Add()) = data[i]; + } + break; + } default: { - PanicInfo("unsupported datatype"); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported datatype {}", data_type)); } } @@ -359,8 +451,16 @@ CreateVectorDataArrayFrom(const void* data_raw, obj->assign(data, num_bytes); break; } + case DataType::VECTOR_FLOAT16: { + auto length = count * dim; + auto 
data = reinterpret_cast(data_raw); + auto obj = vector_array->mutable_float16_vector(); + obj->assign(data, length * sizeof(float16)); + break; + } default: { - PanicInfo("unsupported datatype"); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported datatype {}", data_type)); } } return data_array; @@ -414,7 +514,8 @@ MergeDataArray( auto obj = vector_array->mutable_binary_vector(); obj->assign(data + src_offset * num_bytes, num_bytes); } else { - PanicInfo("logical error"); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported datatype {}", data_type)); } continue; } @@ -465,8 +566,17 @@ MergeDataArray( *(obj->mutable_data()->Add()) = data[src_offset]; break; } + case DataType::ARRAY: { + auto& data = FIELD_DATA(src_field_data, array); + auto obj = scalar_array->mutable_array_data(); + obj->set_element_type( + proto::schema::DataType(field_meta.get_element_type())); + *(obj->mutable_data()->Add()) = data[src_offset]; + break; + } default: { - PanicInfo(fmt::format("unsupported data type {}", data_type)); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported datatype {}", data_type)); } } } @@ -577,7 +687,8 @@ ReverseDataFromIndex(const index::IndexBase* index, break; } default: { - PanicInfo("unsupported datatype"); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported datatype {}", data_type)); } } diff --git a/internal/core/src/segcore/load_field_data_c.cpp b/internal/core/src/segcore/load_field_data_c.cpp index ba8de5d543eda..983ef277f0256 100644 --- a/internal/core/src/segcore/load_field_data_c.cpp +++ b/internal/core/src/segcore/load_field_data_c.cpp @@ -14,7 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "common/CGoHelper.h" +#include "common/EasyAssert.h" #include "common/LoadInfo.h" #include "segcore/load_field_data_c.h" @@ -25,13 +25,13 @@ NewLoadFieldDataInfo(CLoadFieldDataInfo* c_load_field_data_info) { *c_load_field_data_info = load_field_data_info.release(); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } void DeleteLoadFieldDataInfo(CLoadFieldDataInfo c_load_field_data_info) { - auto info = (LoadFieldDataInfo*)c_load_field_data_info; + auto info = static_cast(c_load_field_data_info); delete info; } @@ -40,10 +40,12 @@ AppendLoadFieldInfo(CLoadFieldDataInfo c_load_field_data_info, int64_t field_id, int64_t row_count) { try { - auto load_field_data_info = (LoadFieldDataInfo*)c_load_field_data_info; + auto load_field_data_info = + static_cast(c_load_field_data_info); auto iter = load_field_data_info->field_infos.find(field_id); if (iter != load_field_data_info->field_infos.end()) { - throw std::runtime_error("append same field info multi times"); + throw milvus::SegcoreError(milvus::FieldAlreadyExist, + "append same field info multi times"); } FieldBinlogInfo binlog_info; binlog_info.field_id = field_id; @@ -51,34 +53,39 @@ AppendLoadFieldInfo(CLoadFieldDataInfo c_load_field_data_info, load_field_data_info->field_infos[field_id] = binlog_info; return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } CStatus AppendLoadFieldDataPath(CLoadFieldDataInfo c_load_field_data_info, int64_t field_id, + int64_t entries_num, const char* c_file_path) { try { - auto load_field_data_info = (LoadFieldDataInfo*)c_load_field_data_info; + auto load_field_data_info = + static_cast(c_load_field_data_info); auto iter = load_field_data_info->field_infos.find(field_id); - std::string file_path(c_file_path); if (iter == 
load_field_data_info->field_infos.end()) { - throw std::runtime_error("please append field info first"); + throw milvus::SegcoreError(milvus::FieldIDInvalid, + "please append field info first"); } - + std::string file_path(c_file_path); load_field_data_info->field_infos[field_id].insert_files.emplace_back( file_path); + load_field_data_info->field_infos[field_id].entries_nums.emplace_back( + entries_num); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } void AppendMMapDirPath(CLoadFieldDataInfo c_load_field_data_info, const char* c_dir_path) { - auto load_field_data_info = (LoadFieldDataInfo*)c_load_field_data_info; + auto load_field_data_info = + static_cast(c_load_field_data_info); load_field_data_info->mmap_dir_path = std::string(c_dir_path); } diff --git a/internal/core/src/segcore/load_field_data_c.h b/internal/core/src/segcore/load_field_data_c.h index 9a7af6a83dfd9..2327426256bd5 100644 --- a/internal/core/src/segcore/load_field_data_c.h +++ b/internal/core/src/segcore/load_field_data_c.h @@ -39,6 +39,7 @@ AppendLoadFieldInfo(CLoadFieldDataInfo c_load_field_data_info, CStatus AppendLoadFieldDataPath(CLoadFieldDataInfo c_load_field_data_info, int64_t field_id, + int64_t entries_num, const char* file_path); void diff --git a/internal/core/src/segcore/load_index_c.cpp b/internal/core/src/segcore/load_index_c.cpp index 0f787665b4c9d..4c93e3fe10b0d 100644 --- a/internal/core/src/segcore/load_index_c.cpp +++ b/internal/core/src/segcore/load_index_c.cpp @@ -12,12 +12,13 @@ #include "segcore/load_index_c.h" #include "common/FieldMeta.h" -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" #include "index/Index.h" #include "index/IndexFactory.h" #include "index/Meta.h" #include "index/Utils.h" #include "log/Log.h" +#include "storage/FileManager.h" #include "segcore/Types.h" #include "storage/Util.h" #include 
"storage/RemoteChunkManagerSingleton.h" @@ -31,12 +32,12 @@ NewLoadIndexInfo(CLoadIndexInfo* c_load_index_info) { *c_load_index_info = load_index_info.release(); auto status = CStatus(); - status.error_code = Success; + status.error_code = milvus::Success; status.error_msg = ""; return status; } catch (std::exception& e) { auto status = CStatus(); - status.error_code = UnexpectedError; + status.error_code = milvus::UnexpectedError; status.error_msg = strdup(e.what()); return status; } @@ -60,12 +61,12 @@ AppendIndexParam(CLoadIndexInfo c_load_index_info, load_index_info->index_params[index_key] = index_value; auto status = CStatus(); - status.error_code = Success; + status.error_code = milvus::Success; status.error_msg = ""; return status; } catch (std::exception& e) { auto status = CStatus(); - status.error_code = UnexpectedError; + status.error_code = milvus::UnexpectedError; status.error_msg = strdup(e.what()); return status; } @@ -90,12 +91,12 @@ AppendFieldInfo(CLoadIndexInfo c_load_index_info, load_index_info->mmap_dir_path = std::string(mmap_dir_path); auto status = CStatus(); - status.error_code = Success; + status.error_code = milvus::Success; status.error_msg = ""; return status; } catch (std::exception& e) { auto status = CStatus(); - status.error_code = UnexpectedError; + status.error_code = milvus::UnexpectedError; status.error_msg = strdup(e.what()); return status; } @@ -111,6 +112,7 @@ appendVecIndex(CLoadIndexInfo c_load_index_info, CBinarySet c_binary_set) { milvus::index::CreateIndexInfo index_info; index_info.field_type = load_index_info->field_type; + index_info.index_engine_version = load_index_info->index_engine_version; // get index type AssertInfo(index_params.find("index_type") != index_params.end(), @@ -135,28 +137,24 @@ appendVecIndex(CLoadIndexInfo c_load_index_info, CBinarySet c_binary_set) { auto remote_chunk_manager = milvus::storage::RemoteChunkManagerSingleton::GetInstance() .GetRemoteChunkManager(); - auto file_manager = - 
milvus::storage::CreateFileManager(index_info.index_type, - field_meta, - index_meta, - remote_chunk_manager); - AssertInfo(file_manager != nullptr, "create file manager failed!"); auto config = milvus::index::ParseConfigFromIndexParams( load_index_info->index_params); config["index_files"] = load_index_info->index_files; + milvus::storage::FileManagerContext fileManagerContext( + field_meta, index_meta, remote_chunk_manager); load_index_info->index = milvus::index::IndexFactory::GetInstance().CreateIndex( - index_info, file_manager); + index_info, fileManagerContext); load_index_info->index->Load(*binary_set, config); auto status = CStatus(); - status.error_code = Success; + status.error_code = milvus::Success; status.error_msg = ""; return status; } catch (std::exception& e) { auto status = CStatus(); - status.error_code = UnexpectedError; + status.error_code = milvus::UnexpectedError; status.error_msg = strdup(e.what()); return status; } @@ -180,16 +178,16 @@ appendScalarIndex(CLoadIndexInfo c_load_index_info, CBinarySet c_binary_set) { index_info.index_type = index_params["index_type"]; load_index_info->index = - milvus::index::IndexFactory::GetInstance().CreateIndex(index_info, - nullptr); + milvus::index::IndexFactory::GetInstance().CreateIndex( + index_info, milvus::storage::FileManagerContext()); load_index_info->index->Load(*binary_set); auto status = CStatus(); - status.error_code = Success; + status.error_code = milvus::Success; status.error_msg = ""; return status; } catch (std::exception& e) { auto status = CStatus(); - status.error_code = UnexpectedError; + status.error_code = milvus::UnexpectedError; status.error_msg = strdup(e.what()); return status; } @@ -213,8 +211,11 @@ AppendIndexV2(CLoadIndexInfo c_load_index_info) { auto& index_params = load_index_info->index_params; auto field_type = load_index_info->field_type; + auto engine_version = load_index_info->index_engine_version; + milvus::index::CreateIndexInfo index_info; index_info.field_type = 
load_index_info->field_type; + index_info.index_engine_version = engine_version; // get index type AssertInfo(index_params.find("index_type") != index_params.end(), @@ -241,20 +242,16 @@ AppendIndexV2(CLoadIndexInfo c_load_index_info) { auto remote_chunk_manager = milvus::storage::RemoteChunkManagerSingleton::GetInstance() .GetRemoteChunkManager(); - auto file_manager = - milvus::storage::CreateFileManager(index_info.index_type, - field_meta, - index_meta, - remote_chunk_manager); - AssertInfo(file_manager != nullptr, "create file manager failed!"); auto config = milvus::index::ParseConfigFromIndexParams( load_index_info->index_params); config["index_files"] = load_index_info->index_files; + milvus::storage::FileManagerContext fileManagerContext( + field_meta, index_meta, remote_chunk_manager); load_index_info->index = milvus::index::IndexFactory::GetInstance().CreateIndex( - index_info, file_manager); + index_info, fileManagerContext); if (!load_index_info->mmap_dir_path.empty() && load_index_info->index->IsMmapSupported()) { @@ -269,12 +266,12 @@ AppendIndexV2(CLoadIndexInfo c_load_index_info) { load_index_info->index->Load(config); auto status = CStatus(); - status.error_code = Success; + status.error_code = milvus::Success; status.error_msg = ""; return status; } catch (std::exception& e) { auto status = CStatus(); - status.error_code = UnexpectedError; + status.error_code = milvus::UnexpectedError; status.error_msg = strdup(e.what()); return status; } @@ -348,12 +345,12 @@ AppendIndexFilePath(CLoadIndexInfo c_load_index_info, const char* c_file_path) { load_index_info->index_files.emplace_back(index_file_path); auto status = CStatus(); - status.error_code = Success; + status.error_code = milvus::Success; status.error_msg = ""; return status; } catch (std::exception& e) { auto status = CStatus(); - status.error_code = UnexpectedError; + status.error_code = milvus::UnexpectedError; status.error_msg = strdup(e.what()); return status; } @@ -372,12 +369,32 @@ 
AppendIndexInfo(CLoadIndexInfo c_load_index_info, load_index_info->index_version = version; auto status = CStatus(); - status.error_code = Success; + status.error_code = milvus::Success; status.error_msg = ""; return status; } catch (std::exception& e) { auto status = CStatus(); - status.error_code = UnexpectedError; + status.error_code = milvus::UnexpectedError; + status.error_msg = strdup(e.what()); + return status; + } +} + +CStatus +AppendIndexEngineVersionToLoadInfo(CLoadIndexInfo c_load_index_info, + int32_t index_engine_version) { + try { + auto load_index_info = + (milvus::segcore::LoadIndexInfo*)c_load_index_info; + load_index_info->index_engine_version = index_engine_version; + + auto status = CStatus(); + status.error_code = milvus::Success; + status.error_msg = ""; + return status; + } catch (std::exception& e) { + auto status = CStatus(); + status.error_code = milvus::UnexpectedError; status.error_msg = strdup(e.what()); return status; } @@ -397,12 +414,12 @@ CleanLoadedIndex(CLoadIndexInfo c_load_index_info) { load_index_info->index_version); local_chunk_manager->RemoveDir(index_file_path_prefix); auto status = CStatus(); - status.error_code = Success; + status.error_code = milvus::Success; status.error_msg = ""; return status; } catch (std::exception& e) { auto status = CStatus(); - status.error_code = UnexpectedError; + status.error_code = milvus::UnexpectedError; status.error_msg = strdup(e.what()); return status; } diff --git a/internal/core/src/segcore/load_index_c.h b/internal/core/src/segcore/load_index_c.h index c831c1055d989..d9729377827b5 100644 --- a/internal/core/src/segcore/load_index_c.h +++ b/internal/core/src/segcore/load_index_c.h @@ -61,6 +61,10 @@ AppendIndexV2(CLoadIndexInfo c_load_index_info); CStatus AppendIndexV3(CLoadIndexInfo c_load_index_info); +CStatus +AppendIndexEngineVersionToLoadInfo(CLoadIndexInfo c_load_index_info, + int32_t index_engine_version); + CStatus CleanLoadedIndex(CLoadIndexInfo c_load_index_info); diff --git 
a/internal/core/src/segcore/pkVisitor.h b/internal/core/src/segcore/pkVisitor.h index 65a713c134676..d7fef1fb081f6 100644 --- a/internal/core/src/segcore/pkVisitor.h +++ b/internal/core/src/segcore/pkVisitor.h @@ -11,6 +11,7 @@ #pragma once #include +#include "common/EasyAssert.h" namespace milvus::segcore { @@ -18,7 +19,7 @@ struct Int64PKVisitor { template int64_t operator()(T t) const { - PanicInfo("invalid int64 pk value"); + PanicInfo(Unsupported, "invalid int64 pk value"); } }; @@ -32,7 +33,7 @@ struct StrPKVisitor { template std::string operator()(T t) const { - PanicInfo("invalid string pk value"); + PanicInfo(Unsupported, "invalid string pk value"); } }; diff --git a/internal/core/src/segcore/plan_c.cpp b/internal/core/src/segcore/plan_c.cpp index 8a7892d898d20..72c3d463c8c00 100644 --- a/internal/core/src/segcore/plan_c.cpp +++ b/internal/core/src/segcore/plan_c.cpp @@ -9,7 +9,6 @@ // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express // or implied. 
See the License for the specific language governing permissions and limitations under the License -#include "common/CGoHelper.h" #include "pb/segcore.pb.h" #include "query/Plan.h" #include "segcore/Collection.h" @@ -28,7 +27,7 @@ CreateSearchPlanByExpr(CCollection c_col, *col->get_schema(), serialized_expr_plan, size); auto status = CStatus(); - status.error_code = Success; + status.error_code = milvus::Success; status.error_msg = ""; auto plan = (CSearchPlan)res.release(); *res_plan = plan; @@ -41,7 +40,7 @@ CreateSearchPlanByExpr(CCollection c_col, return status; } catch (std::exception& e) { auto status = CStatus(); - status.error_code = UnexpectedError; + status.error_code = milvus::UnexpectedError; status.error_msg = strdup(e.what()); *res_plan = nullptr; return status; @@ -60,14 +59,14 @@ ParsePlaceholderGroup(CSearchPlan c_plan, plan, (const uint8_t*)(placeholder_group_blob), blob_size); auto status = CStatus(); - status.error_code = Success; + status.error_code = milvus::Success; status.error_msg = ""; auto group = (CPlaceholderGroup)res.release(); *res_placeholder_group = group; return status; } catch (std::exception& e) { auto status = CStatus(); - status.error_code = UnexpectedError; + status.error_code = milvus::UnexpectedError; status.error_msg = strdup(e.what()); *res_placeholder_group = nullptr; return status; @@ -77,13 +76,13 @@ ParsePlaceholderGroup(CSearchPlan c_plan, int64_t GetNumOfQueries(CPlaceholderGroup placeholder_group) { auto res = milvus::query::GetNumOfQueries( - (milvus::query::PlaceholderGroup*)placeholder_group); + static_cast(placeholder_group)); return res; } int64_t GetTopK(CSearchPlan plan) { - auto res = milvus::query::GetTopK((milvus::query::Plan*)plan); + auto res = milvus::query::GetTopK(static_cast(plan)); return res; } @@ -94,7 +93,7 @@ GetFieldID(CSearchPlan plan, int64_t* field_id) { *field_id = milvus::query::GetFieldID(p); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return 
milvus::FailureCStatus(UnexpectedError, strdup(e.what())); + return milvus::FailureCStatus(&e); } } @@ -116,14 +115,14 @@ SetMetricType(CSearchPlan plan, const char* metric_type) { void DeleteSearchPlan(CSearchPlan cPlan) { - auto plan = (milvus::query::Plan*)cPlan; + auto plan = static_cast(cPlan); delete plan; } void DeletePlaceholderGroup(CPlaceholderGroup cPlaceholder_group) { auto placeHolder_group = - (milvus::query::PlaceholderGroup*)cPlaceholder_group; + static_cast(cPlaceholder_group); delete placeHolder_group; } @@ -132,14 +131,14 @@ CreateRetrievePlanByExpr(CCollection c_col, const void* serialized_expr_plan, const int64_t size, CRetrievePlan* res_plan) { - auto col = (milvus::segcore::Collection*)c_col; + auto col = static_cast(c_col); try { auto res = milvus::query::CreateRetrievePlanByExpr( *col->get_schema(), serialized_expr_plan, size); auto status = CStatus(); - status.error_code = Success; + status.error_code = milvus::Success; status.error_msg = ""; auto plan = (CRetrievePlan)res.release(); *res_plan = plan; @@ -152,7 +151,7 @@ CreateRetrievePlanByExpr(CCollection c_col, return status; } catch (std::exception& e) { auto status = CStatus(); - status.error_code = UnexpectedError; + status.error_code = milvus::UnexpectedError; status.error_msg = strdup(e.what()); *res_plan = nullptr; return status; @@ -161,6 +160,6 @@ CreateRetrievePlanByExpr(CCollection c_col, void DeleteRetrievePlan(CRetrievePlan c_plan) { - auto plan = (milvus::query::RetrievePlan*)c_plan; + auto plan = static_cast(c_plan); delete plan; } diff --git a/internal/core/src/segcore/reduce_c.cpp b/internal/core/src/segcore/reduce_c.cpp index b683fd902156d..c833af25533d3 100644 --- a/internal/core/src/segcore/reduce_c.cpp +++ b/internal/core/src/segcore/reduce_c.cpp @@ -11,9 +11,8 @@ #include #include "Reduce.h" -#include "common/CGoHelper.h" #include "common/QueryResult.h" -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" #include "query/Plan.h" #include 
"segcore/reduce_c.h" #include "segcore/Utils.h" @@ -46,7 +45,7 @@ ReduceSearchResultsAndFillData(CSearchResultDataBlobs* cSearchResultDataBlobs, *cSearchResultDataBlobs = reduce_helper.GetSearchResultDataBlobs(); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } @@ -68,7 +67,7 @@ GetSearchResultDataBlob(CProto* searchResultDataBlob, } catch (std::exception& e) { searchResultDataBlob->proto_blob = nullptr; searchResultDataBlob->proto_size = 0; - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } diff --git a/internal/core/src/segcore/segcore_init_c.cpp b/internal/core/src/segcore/segcore_init_c.cpp index 87a02bcac12ec..5ebf53f651f01 100644 --- a/internal/core/src/segcore/segcore_init_c.cpp +++ b/internal/core/src/segcore/segcore_init_c.cpp @@ -15,6 +15,9 @@ #include "segcore/segcore_init_c.h" namespace milvus::segcore { + +std::once_flag close_glog_once; + extern "C" void SegcoreInit(const char* conf_file) { milvus::config::KnowhereInitImpl(conf_file); @@ -72,9 +75,21 @@ SegcoreSetSimdType(const char* value) { extern "C" void SegcoreCloseGlog() { - if (google::IsGoogleLoggingInitialized()) { - google::ShutdownGoogleLogging(); - } + std::call_once(close_glog_once, [&]() { + if (google::IsGoogleLoggingInitialized()) { + google::ShutdownGoogleLogging(); + } + }); +} + +extern "C" int32_t +GetCurrentIndexVersion() { + return milvus::config::GetCurrentIndexVersion(); +} + +extern "C" int32_t +GetMinimalIndexVersion() { + return milvus::config::GetMinimalIndexVersion(); } } // namespace milvus::segcore diff --git a/internal/core/src/segcore/segcore_init_c.h b/internal/core/src/segcore/segcore_init_c.h index 24a8abdab38e2..ace8e9e723f75 100644 --- a/internal/core/src/segcore/segcore_init_c.h +++ b/internal/core/src/segcore/segcore_init_c.h @@ -43,6 +43,12 @@ SegcoreSetKnowhereSearchThreadPoolNum(const uint32_t 
num_threads); void SegcoreCloseGlog(); +int32_t +GetCurrentIndexVersion(); + +int32_t +GetMinimalIndexVersion(); + #ifdef __cplusplus } #endif diff --git a/internal/core/src/segcore/segment_c.cpp b/internal/core/src/segcore/segment_c.cpp index d06a333fcb2c2..5a0e14303b5c6 100644 --- a/internal/core/src/segcore/segment_c.cpp +++ b/internal/core/src/segcore/segment_c.cpp @@ -12,7 +12,6 @@ #include "segcore/segment_c.h" #include -#include "common/CGoHelper.h" #include "common/LoadInfo.h" #include "common/Types.h" #include "common/Tracer.h" @@ -22,6 +21,7 @@ #include "segcore/Collection.h" #include "segcore/SegmentGrowingImpl.h" #include "segcore/SegmentSealedImpl.h" +#include "segcore/Utils.h" #include "storage/FieldData.h" #include "storage/Util.h" #include "mmap/Types.h" @@ -93,7 +93,7 @@ Search(CSegmentInterface c_segment, milvus::tracer::CloseRootSpan(); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } @@ -130,7 +130,7 @@ Retrieve(CSegmentInterface c_segment, span->End(); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } @@ -192,7 +192,7 @@ Insert(CSegmentInterface c_segment, reserved_offset, size, row_ids, timestamps, insert_data.get()); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } @@ -203,7 +203,7 @@ PreInsert(CSegmentInterface c_segment, int64_t size, int64_t* offset) { *offset = segment->PreInsert(size); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } @@ -223,7 +223,7 @@ Delete(CSegmentInterface c_segment, segment->Delete(reserved_offset, size, pks.get(), timestamps); return milvus::SuccessCStatus(); } 
catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } @@ -239,7 +239,7 @@ LoadFieldData(CSegmentInterface c_segment, segment->LoadFieldData(*load_info); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } @@ -293,7 +293,7 @@ LoadFieldRawData(CSegmentInterface c_segment, segment->LoadFieldData(milvus::FieldId(field_id), field_data_info); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } @@ -314,7 +314,7 @@ LoadDeletedRecord(CSegmentInterface c_segment, segment_interface->LoadDeletedRecord(load_info); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } @@ -332,7 +332,7 @@ UpdateSealedSegmentIndex(CSegmentInterface c_segment, segment->LoadIndex(*load_index_info); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } @@ -349,7 +349,7 @@ UpdateFieldRawDataSize(CSegmentInterface c_segment, milvus::FieldId(field_id), num_rows, field_data_size); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } @@ -364,7 +364,7 @@ DropFieldData(CSegmentInterface c_segment, int64_t field_id) { segment->DropFieldData(milvus::FieldId(field_id)); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } @@ -379,6 +379,24 @@ DropSealedSegmentIndex(CSegmentInterface c_segment, int64_t field_id) { segment->DropIndex(milvus::FieldId(field_id)); return 
milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); + } +} + +CStatus +AddFieldDataInfoForSealed(CSegmentInterface c_segment, + CLoadFieldDataInfo c_load_field_data_info) { + try { + auto segment_interface = + reinterpret_cast(c_segment); + auto segment = + dynamic_cast(segment_interface); + AssertInfo(segment != nullptr, "segment conversion failed"); + auto load_info = + static_cast(c_load_field_data_info); + segment->AddFieldDataInfoForSealed(*load_info); + return milvus::SuccessCStatus(); + } catch (std::exception& e) { + return milvus::FailureCStatus(milvus::UnexpectedError, e.what()); } } diff --git a/internal/core/src/segcore/segment_c.h b/internal/core/src/segcore/segment_c.h index a76eaa720af1c..8c22fc69ee20d 100644 --- a/internal/core/src/segcore/segment_c.h +++ b/internal/core/src/segcore/segment_c.h @@ -119,7 +119,17 @@ DropFieldData(CSegmentInterface c_segment, int64_t field_id); CStatus DropSealedSegmentIndex(CSegmentInterface c_segment, int64_t field_id); +CStatus +AddFieldDataInfoForSealed(CSegmentInterface c_segment, + CLoadFieldDataInfo c_load_field_data_info); + ////////////////////////////// interfaces for SegmentInterface ////////////////////////////// +CStatus +ExistPk(CSegmentInterface c_segment, + const uint8_t* raw_ids, + const uint64_t size, + bool* results); + CStatus Delete(CSegmentInterface c_segment, int64_t reserved_offset, diff --git a/internal/core/src/storage/AzureChunkManager.cpp b/internal/core/src/storage/AzureChunkManager.cpp new file mode 100644 index 0000000000000..ff100f32734ae --- /dev/null +++ b/internal/core/src/storage/AzureChunkManager.cpp @@ -0,0 +1,155 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "common/EasyAssert.h" +#include "storage/AzureChunkManager.h" + +namespace milvus { +namespace storage { + +AzureChunkManager::AzureChunkManager(const StorageConfig& storage_config) + : default_bucket_name_(storage_config.bucket_name), + path_prefix_(storage_config.root_path) { + client_ = std::make_shared( + storage_config.access_key_id, + storage_config.access_key_value, + storage_config.address, + storage_config.useIAM); +} + +AzureChunkManager::~AzureChunkManager() { +} + +uint64_t +AzureChunkManager::Size(const std::string& filepath) { + return GetObjectSize(default_bucket_name_, filepath); +} + +bool +AzureChunkManager::Exist(const std::string& filepath) { + return ObjectExists(default_bucket_name_, filepath); +} + +void +AzureChunkManager::Remove(const std::string& filepath) { + DeleteObject(default_bucket_name_, filepath); +} + +std::vector +AzureChunkManager::ListWithPrefix(const std::string& filepath) { + return ListObjects(default_bucket_name_.c_str(), filepath.c_str()); +} + +uint64_t +AzureChunkManager::Read(const std::string& filepath, void* buf, uint64_t size) { + if (!ObjectExists(default_bucket_name_, filepath)) { + std::stringstream err_msg; + err_msg << "object('" << default_bucket_name_ << "', '" << filepath + << "') not exists"; + throw SegcoreError(ObjectNotExist, err_msg.str()); + } + return GetObjectBuffer(default_bucket_name_, 
filepath, buf, size); +} + +void +AzureChunkManager::Write(const std::string& filepath, + void* buf, + uint64_t size) { + PutObjectBuffer(default_bucket_name_, filepath, buf, size); +} + +bool +AzureChunkManager::BucketExists(const std::string& bucket_name) { + return client_->BucketExists(bucket_name); +} + +std::vector +AzureChunkManager::ListBuckets() { + return client_->ListBuckets(); +} + +bool +AzureChunkManager::CreateBucket(const std::string& bucket_name) { + try { + client_->CreateBucket(bucket_name); + } catch (std::exception& e) { + throw SegcoreError(BucketInvalid, e.what()); + } + return true; +} + +bool +AzureChunkManager::DeleteBucket(const std::string& bucket_name) { + try { + client_->DeleteBucket(bucket_name); + } catch (std::exception& e) { + throw SegcoreError(BucketInvalid, e.what()); + } + return true; +} + +bool +AzureChunkManager::ObjectExists(const std::string& bucket_name, + const std::string& object_name) { + return client_->ObjectExists(bucket_name, object_name); +} + +int64_t +AzureChunkManager::GetObjectSize(const std::string& bucket_name, + const std::string& object_name) { + try { + return client_->GetObjectSize(bucket_name, object_name); + } catch (std::exception& e) { + throw SegcoreError(ObjectNotExist, e.what()); + } +} + +bool +AzureChunkManager::DeleteObject(const std::string& bucket_name, + const std::string& object_name) { + try { + client_->DeleteObject(bucket_name, object_name); + } catch (std::exception& e) { + throw SegcoreError(ObjectNotExist, e.what()); + } + return true; +} + +bool +AzureChunkManager::PutObjectBuffer(const std::string& bucket_name, + const std::string& object_name, + void* buf, + uint64_t size) { + return client_->PutObjectBuffer(bucket_name, object_name, buf, size); +} + +uint64_t +AzureChunkManager::GetObjectBuffer(const std::string& bucket_name, + const std::string& object_name, + void* buf, + uint64_t size) { + return client_->GetObjectBuffer(bucket_name, object_name, buf, size); +} + +std::vector 
+AzureChunkManager::ListObjects(const char* bucket_name, const char* prefix) { + return client_->ListObjects(bucket_name, prefix); +} + +} // namespace storage +} // namespace milvus diff --git a/internal/core/src/storage/AzureChunkManager.h b/internal/core/src/storage/AzureChunkManager.h new file mode 100644 index 0000000000000..dc4c6ab5e44d0 --- /dev/null +++ b/internal/core/src/storage/AzureChunkManager.h @@ -0,0 +1,144 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// export CPLUS_INCLUDE_PATH=/opt/homebrew/Cellar/boost/1.81.0_1/include/ + +#pragma once + +#include +#include +#include +#include +#include "storage/azure-blob-storage/AzureBlobChunkManager.h" +#include "storage/ChunkManager.h" +#include "storage/Types.h" + +namespace milvus { +namespace storage { + +/** + * @brief This AzureChunkManager is responsible for read and write file in blob. 
+ */ +class AzureChunkManager : public ChunkManager { + public: + explicit AzureChunkManager(const StorageConfig& storage_config); + + AzureChunkManager(const AzureChunkManager&); + AzureChunkManager& + operator=(const AzureChunkManager&); + + public: + virtual ~AzureChunkManager(); + + virtual bool + Exist(const std::string& filepath); + + virtual uint64_t + Size(const std::string& filepath); + + virtual uint64_t + Read(const std::string& filepath, + uint64_t offset, + void* buf, + uint64_t len) { + throw SegcoreError(NotImplemented, + GetName() + "Read with offset not implement"); + } + + virtual void + Write(const std::string& filepath, + uint64_t offset, + void* buf, + uint64_t len) { + throw SegcoreError(NotImplemented, + GetName() + "Write with offset not implement"); + } + + virtual uint64_t + Read(const std::string& filepath, void* buf, uint64_t len); + + virtual void + Write(const std::string& filepath, void* buf, uint64_t len); + + virtual std::vector + ListWithPrefix(const std::string& filepath); + + virtual void + Remove(const std::string& filepath); + + virtual std::string + GetName() const { + return "AzureChunkManager"; + } + + virtual std::string + GetRootPath() const { + return path_prefix_; + } + + inline std::string + GetBucketName() { + return default_bucket_name_; + } + + inline void + SetBucketName(const std::string& bucket_name) { + default_bucket_name_ = bucket_name; + } + + bool + BucketExists(const std::string& bucket_name); + + bool + CreateBucket(const std::string& bucket_name); + + bool + DeleteBucket(const std::string& bucket_name); + + std::vector + ListBuckets(); + + public: + bool + ObjectExists(const std::string& bucket_name, + const std::string& object_name); + int64_t + GetObjectSize(const std::string& bucket_name, + const std::string& object_name); + bool + DeleteObject(const std::string& bucket_name, + const std::string& object_name); + bool + PutObjectBuffer(const std::string& bucket_name, + const std::string& object_name, + 
void* buf, + uint64_t size); + uint64_t + GetObjectBuffer(const std::string& bucket_name, + const std::string& object_name, + void* buf, + uint64_t size); + std::vector + ListObjects(const char* bucket_name, const char* prefix = nullptr); + + private: + std::shared_ptr client_; + std::string default_bucket_name_; + std::string path_prefix_; +}; + +using AzureChunkManagerPtr = std::unique_ptr; + +} // namespace storage +} // namespace milvus diff --git a/internal/core/src/storage/BinlogReader.cpp b/internal/core/src/storage/BinlogReader.cpp index fae1a65a9044f..975ee92adee49 100644 --- a/internal/core/src/storage/BinlogReader.cpp +++ b/internal/core/src/storage/BinlogReader.cpp @@ -15,32 +15,34 @@ // limitations under the License. #include "storage/BinlogReader.h" +#include "common/EasyAssert.h" namespace milvus::storage { -Status +milvus::SegcoreError BinlogReader::Read(int64_t nbytes, void* out) { auto remain = size_ - tell_; if (nbytes > remain) { - return Status(SERVER_UNEXPECTED_ERROR, "out range of binlog data"); + return SegcoreError(milvus::UnexpectedError, + "out range of binlog data"); } std::memcpy(out, data_.get() + tell_, nbytes); tell_ += nbytes; - return Status(SERVER_SUCCESS, ""); + return SegcoreError(milvus::Success, ""); } -std::pair> +std::pair> BinlogReader::Read(int64_t nbytes) { auto remain = size_ - tell_; if (nbytes > remain) { return std::make_pair( - Status(SERVER_UNEXPECTED_ERROR, "out range of binlog data"), + SegcoreError(milvus::UnexpectedError, "out range of binlog data"), nullptr); } auto deleter = [&](uint8_t*) {}; // avoid repeated deconstruction auto res = std::shared_ptr(data_.get() + tell_, deleter); tell_ += nbytes; - return std::make_pair(Status(SERVER_SUCCESS, ""), res); + return std::make_pair(SegcoreError(milvus::Success, ""), res); } } // namespace milvus::storage diff --git a/internal/core/src/storage/BinlogReader.h b/internal/core/src/storage/BinlogReader.h index 480362a6afe9a..467407ac80075 100644 --- 
a/internal/core/src/storage/BinlogReader.h +++ b/internal/core/src/storage/BinlogReader.h @@ -19,8 +19,7 @@ #include #include -#include "utils/Status.h" -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" namespace milvus::storage { @@ -31,10 +30,10 @@ class BinlogReader { : data_(binlog_data), size_(length), tell_(0) { } - Status + SegcoreError Read(int64_t nbytes, void* out); - std::pair> + std::pair> Read(int64_t nbytes); int64_t @@ -43,9 +42,9 @@ class BinlogReader { } private: + std::shared_ptr data_; int64_t size_; int64_t tell_; - std::shared_ptr data_; }; using BinlogReaderPtr = std::shared_ptr; diff --git a/internal/core/src/storage/CMakeLists.txt b/internal/core/src/storage/CMakeLists.txt index 74f801e4bd385..e6fe641e3db3e 100644 --- a/internal/core/src/storage/CMakeLists.txt +++ b/internal/core/src/storage/CMakeLists.txt @@ -22,7 +22,18 @@ endif() milvus_add_pkg_config("milvus_storage") +if (DEFINED AZURE_BUILD_DIR) + add_definitions(-DAZURE_BUILD_DIR) + include_directories(azure-blob-storage) + include_directories("${AZURE_BUILD_DIR}/vcpkg_installed/${VCPKG_TARGET_TRIPLET}/include") + set(STORAGE_FILES + ${STORAGE_FILES} + AzureChunkManager.cpp + ) +endif() + set(STORAGE_FILES + ${STORAGE_FILES} parquet_c.cpp PayloadStream.cpp DataCodec.cpp @@ -31,29 +42,38 @@ set(STORAGE_FILES PayloadReader.cpp PayloadWriter.cpp BinlogReader.cpp - IndexData.cpp + IndexData.cpp InsertData.cpp Event.cpp ThreadPool.cpp storage_c.cpp MinioChunkManager.cpp + ChunkManagers.cpp AliyunSTSClient.cpp AliyunCredentialsProvider.cpp MemFileManagerImpl.cpp LocalChunkManager.cpp - DiskFileManagerImpl.cpp ThreadPools.cpp) + DiskFileManagerImpl.cpp + ThreadPools.cpp + ChunkCache.cpp) add_library(milvus_storage SHARED ${STORAGE_FILES}) -find_package(Boost REQUIRED COMPONENTS filesystem) - -# message(FATAL_ERROR "${CONAN_LIBS}") -target_link_libraries(milvus_storage PUBLIC - milvus-storage - milvus_common - Boost::filesystem - pthread - ${CONAN_LIBS} - ) +if (DEFINED 
AZURE_BUILD_DIR) + target_link_libraries(milvus_storage PUBLIC + "-L${AZURE_BUILD_DIR} -lblob-chunk-manager" + milvus_common + milvus-storage + pthread + ${CONAN_LIBS} + ) +else () + target_link_libraries(milvus_storage PUBLIC + milvus_common + milvus-storage + pthread + ${CONAN_LIBS} + ) +endif() install(TARGETS milvus_storage DESTINATION "${CMAKE_INSTALL_LIBDIR}") diff --git a/internal/core/src/storage/ChunkCache.cpp b/internal/core/src/storage/ChunkCache.cpp new file mode 100644 index 0000000000000..0a0454c4830ff --- /dev/null +++ b/internal/core/src/storage/ChunkCache.cpp @@ -0,0 +1,108 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "ChunkCache.h" + +namespace milvus::storage { + +std::shared_ptr +ChunkCache::Read(const std::string& filepath) { + auto path = std::filesystem::path(path_prefix_) / filepath; + + ColumnTable::const_accessor ca; + if (columns_.find(ca, path)) { + return ca->second; + } + ca.release(); + + auto object_data = + GetObjectData(cm_.get(), std::vector{filepath}); + AssertInfo(object_data.size() == 1, "GetObjectData failed"); + auto field_data = object_data[0]; + + auto column = Mmap(path, field_data); + columns_.emplace(path, column); + return column; +} + +void +ChunkCache::Remove(const std::string& filepath) { + auto path = std::filesystem::path(path_prefix_) / filepath; + columns_.erase(path); +} + +void +ChunkCache::Prefetch(const std::string& filepath) { + auto path = std::filesystem::path(path_prefix_) / filepath; + ColumnTable::const_accessor ca; + if (!columns_.find(ca, path)) { + return; + } + auto column = ca->second; + auto ok = + madvise(reinterpret_cast(const_cast(column->Data())), + column->ByteSize(), + read_ahead_policy_); + AssertInfo(ok == 0, + fmt::format("failed to madvise to the data file {}, err: {}", + path.c_str(), + strerror(errno))); +} + +std::shared_ptr +ChunkCache::Mmap(const std::filesystem::path& path, + const FieldDataPtr& field_data) { + std::unique_lock lck(mutex_); + auto dir = path.parent_path(); + std::filesystem::create_directories(dir); + + auto dim = field_data->get_dim(); + auto data_type = field_data->get_data_type(); + + auto file = File::Open(path.string(), O_CREAT | O_TRUNC | O_RDWR); + + // write the field data to disk + auto data_size = field_data->Size(); + // unused + std::vector> element_indices{}; + auto written = WriteFieldData(file, data_type, field_data, element_indices); + AssertInfo(written == data_size, + fmt::format("failed to write data file {}, written " + "{} but total {}, err: {}", + path.c_str(), + written, + data_size, + strerror(errno))); + + std::shared_ptr column{}; + + if 
(datatype_is_variable(data_type)) { + AssertInfo(false, "TODO: unimplemented for variable data type"); + } else { + column = std::make_shared(file, data_size, dim, data_type); + } + + // unlink + auto ok = unlink(path.c_str()); + AssertInfo(ok == 0, + fmt::format("failed to unlink mmap data file {}, err: {}", + path.c_str(), + strerror(errno))); + + return column; +} + +} // namespace milvus::storage diff --git a/internal/core/src/storage/ChunkCache.h b/internal/core/src/storage/ChunkCache.h new file mode 100644 index 0000000000000..9d842b8e556ec --- /dev/null +++ b/internal/core/src/storage/ChunkCache.h @@ -0,0 +1,74 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include "mmap/Column.h" + +namespace milvus::storage { + +extern std::map ReadAheadPolicy_Map; + +class ChunkCache { + public: + explicit ChunkCache(std::string path, + const std::string& read_ahead_policy, + ChunkManagerPtr cm) + : path_prefix_(std::move(path)), cm_(cm) { + auto iter = ReadAheadPolicy_Map.find(read_ahead_policy); + AssertInfo(iter != ReadAheadPolicy_Map.end(), + fmt::format("unrecognized read ahead policy: {}, " + "should be one of `normal, random, sequential, " + "willneed, dontneed`", + read_ahead_policy)); + read_ahead_policy_ = iter->second; + LOG_SEGCORE_INFO_ << "Init ChunkCache with prefix: " << path_prefix_ + << ", read_ahead_policy: " << read_ahead_policy; + } + + ~ChunkCache() = default; + + public: + std::shared_ptr + Read(const std::string& filepath); + + void + Remove(const std::string& filepath); + + void + Prefetch(const std::string& filepath); + + private: + std::shared_ptr + Mmap(const std::filesystem::path& path, const FieldDataPtr& field_data); + + private: + using ColumnTable = + oneapi::tbb::concurrent_hash_map>; + + private: + mutable std::mutex mutex_; + int read_ahead_policy_; + std::string path_prefix_; + ChunkManagerPtr cm_; + ColumnTable columns_; +}; + +using ChunkCachePtr = std::shared_ptr; + +} // namespace milvus::storage diff --git a/internal/core/src/storage/ChunkCacheSingleton.h b/internal/core/src/storage/ChunkCacheSingleton.h new file mode 100644 index 0000000000000..c1abfb7379621 --- /dev/null +++ b/internal/core/src/storage/ChunkCacheSingleton.h @@ -0,0 +1,60 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "ChunkCache.h" +#include "RemoteChunkManagerSingleton.h" + +namespace milvus::storage { + +class ChunkCacheSingleton { + private: + ChunkCacheSingleton() { + } + + public: + ChunkCacheSingleton(const ChunkCacheSingleton&) = delete; + ChunkCacheSingleton& + operator=(const ChunkCacheSingleton&) = delete; + + static ChunkCacheSingleton& + GetInstance() { + static ChunkCacheSingleton instance; + return instance; + } + + void + Init(std::string root_path, std::string read_ahead_policy) { + if (cc_ == nullptr) { + auto rcm = RemoteChunkManagerSingleton::GetInstance() + .GetRemoteChunkManager(); + cc_ = std::make_shared( + std::move(root_path), std::move(read_ahead_policy), rcm); + } + } + + ChunkCachePtr + GetChunkCache() { + return cc_; + } + + private: + ChunkCachePtr cc_ = nullptr; +}; + +} // namespace milvus::storage \ No newline at end of file diff --git a/internal/core/src/storage/ChunkManager.h b/internal/core/src/storage/ChunkManager.h index c80fef4c68a1c..bec5addd2a547 100644 --- a/internal/core/src/storage/ChunkManager.h +++ b/internal/core/src/storage/ChunkManager.h @@ -124,10 +124,11 @@ class ChunkManager { using ChunkManagerPtr = std::shared_ptr; -enum ChunkManagerType : int8_t { - None_CM = 0, +enum class ChunkManagerType : int8_t { + None = 0, Local = 1, Minio = 2, + Remote = 3, }; extern std::map ChunkManagerType_Map; diff --git a/internal/core/src/storage/ChunkManagers.cpp b/internal/core/src/storage/ChunkManagers.cpp new file mode 100644 index 0000000000000..ea097e47d166c --- /dev/null +++ 
b/internal/core/src/storage/ChunkManagers.cpp @@ -0,0 +1,175 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "storage/MinioChunkManager.h" +#include "storage/AliyunSTSClient.h" +#include "storage/AliyunCredentialsProvider.h" +#include "common/Consts.h" +#include "common/EasyAssert.h" +#include "log/Log.h" +#include "signal.h" + +namespace milvus::storage { + +Aws::String +ConvertToAwsString(const std::string& str) { + return Aws::String(str.c_str(), str.size()); +} + +Aws::Client::ClientConfiguration +generateConfig(const StorageConfig& storage_config) { + // The ClientConfiguration default constructor will take a long time. 
+ // For more details, please refer to https://github.com/aws/aws-sdk-cpp/issues/1440 + static Aws::Client::ClientConfiguration g_config; + Aws::Client::ClientConfiguration config = g_config; + config.endpointOverride = ConvertToAwsString(storage_config.address); + + if (storage_config.useSSL) { + config.scheme = Aws::Http::Scheme::HTTPS; + config.verifySSL = true; + } else { + config.scheme = Aws::Http::Scheme::HTTP; + config.verifySSL = false; + } + + if (!storage_config.region.empty()) { + config.region = ConvertToAwsString(storage_config.region); + } + + config.requestTimeoutMs = storage_config.requestTimeoutMs == 0 + ? DEFAULT_CHUNK_MANAGER_REQUEST_TIMEOUT_MS + : storage_config.requestTimeoutMs; + + return config; +} + +AwsChunkManager::AwsChunkManager(const StorageConfig& storage_config) { + default_bucket_name_ = storage_config.bucket_name; + remote_root_path_ = storage_config.root_path; + + InitSDKAPIDefault(storage_config.log_level); + + Aws::Client::ClientConfiguration config = generateConfig(storage_config); + if (storage_config.useIAM) { + auto provider = + std::make_shared(); + auto aws_credentials = provider->GetAWSCredentials(); + AssertInfo(!aws_credentials.GetAWSAccessKeyId().empty(), + "if use iam, access key id should not be empty"); + AssertInfo(!aws_credentials.GetAWSSecretKey().empty(), + "if use iam, secret key should not be empty"); + AssertInfo(!aws_credentials.GetSessionToken().empty(), + "if use iam, token should not be empty"); + + client_ = std::make_shared( + provider, + config, + Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never, + storage_config.useVirtualHost); + } else { + BuildAccessKeyClient(storage_config, config); + } + + LOG_SEGCORE_INFO_ << "init AwsChunkManager with parameter[endpoint: '" + << storage_config.address << "', default_bucket_name:'" + << storage_config.bucket_name << "', root_path:'" + << storage_config.root_path << "', use_secure:'" + << std::boolalpha << storage_config.useSSL << "']"; +} + 
+GcpChunkManager::GcpChunkManager(const StorageConfig& storage_config) { + default_bucket_name_ = storage_config.bucket_name; + remote_root_path_ = storage_config.root_path; + + if (storage_config.useIAM) { + sdk_options_.httpOptions.httpClientFactory_create_fn = []() { + auto credentials = std::make_shared< + google::cloud::oauth2_internal::GOOGLE_CLOUD_CPP_NS:: + ComputeEngineCredentials>(); + return Aws::MakeShared( + GOOGLE_CLIENT_FACTORY_ALLOCATION_TAG, credentials); + }; + } + + InitSDKAPIDefault(storage_config.log_level); + + Aws::Client::ClientConfiguration config = generateConfig(storage_config); + if (storage_config.useIAM) { + // Using S3 client instead of google client because of compatible protocol + client_ = std::make_shared( + config, + Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never, + storage_config.useVirtualHost); + } else { + BuildAccessKeyClient(storage_config, config); + } + + LOG_SEGCORE_INFO_ << "init GcpChunkManager with parameter[endpoint: '" + << storage_config.address << "', default_bucket_name:'" + << storage_config.bucket_name << "', root_path:'" + << storage_config.root_path << "', use_secure:'" + << std::boolalpha << storage_config.useSSL << "']"; +} + +AliyunChunkManager::AliyunChunkManager(const StorageConfig& storage_config) { + default_bucket_name_ = storage_config.bucket_name; + remote_root_path_ = storage_config.root_path; + + InitSDKAPIDefault(storage_config.log_level); + + Aws::Client::ClientConfiguration config = generateConfig(storage_config); + if (storage_config.useIAM) { + auto aliyun_provider = Aws::MakeShared< + Aws::Auth::AliyunSTSAssumeRoleWebIdentityCredentialsProvider>( + "AliyunSTSAssumeRoleWebIdentityCredentialsProvider"); + auto aliyun_credentials = aliyun_provider->GetAWSCredentials(); + AssertInfo(!aliyun_credentials.GetAWSAccessKeyId().empty(), + "if use iam, access key id should not be empty"); + AssertInfo(!aliyun_credentials.GetAWSSecretKey().empty(), + "if use iam, secret key should not be 
empty"); + AssertInfo(!aliyun_credentials.GetSessionToken().empty(), + "if use iam, token should not be empty"); + client_ = std::make_shared( + aliyun_provider, + config, + Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never, + storage_config.useVirtualHost); + } else { + BuildAccessKeyClient(storage_config, config); + } + + LOG_SEGCORE_INFO_ << "init AliyunChunkManager with parameter[endpoint: '" + << storage_config.address << "', default_bucket_name:'" + << storage_config.bucket_name << "', root_path:'" + << storage_config.root_path << "', use_secure:'" + << std::boolalpha << storage_config.useSSL << "']"; +} + +} // namespace milvus::storage diff --git a/internal/core/src/storage/DataCodec.cpp b/internal/core/src/storage/DataCodec.cpp index 923e8ef57eb15..2e37f7bf732bc 100644 --- a/internal/core/src/storage/DataCodec.cpp +++ b/internal/core/src/storage/DataCodec.cpp @@ -20,7 +20,7 @@ #include "storage/InsertData.h" #include "storage/IndexData.h" #include "storage/BinlogReader.h" -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" #include "common/Consts.h" namespace milvus::storage { @@ -86,14 +86,16 @@ DeserializeRemoteFileData(BinlogReaderPtr reader) { return index_data; } default: - PanicInfo("unsupported event type"); + PanicInfo( + DataFormatBroken, + fmt::format("unsupported event type {}", header.event_type_)); } } // For now, no file header in file data std::unique_ptr DeserializeLocalFileData(BinlogReaderPtr reader) { - PanicInfo("not supported"); + PanicInfo(NotImplemented, "not supported"); } std::unique_ptr @@ -109,7 +111,8 @@ DeserializeFileData(const std::shared_ptr input_data, return DeserializeLocalFileData(binlog_reader); } default: - PanicInfo("unsupported medium type"); + PanicInfo(DataFormatBroken, + fmt::format("unsupported medium type {}", medium_type)); } } diff --git a/internal/core/src/storage/DiskFileManagerImpl.cpp b/internal/core/src/storage/DiskFileManagerImpl.cpp index 9c9934bcb3268..083ee13fcb03c 100644 --- 
a/internal/core/src/storage/DiskFileManagerImpl.cpp +++ b/internal/core/src/storage/DiskFileManagerImpl.cpp @@ -26,7 +26,6 @@ #include "storage/DiskFileManagerImpl.h" #include "storage/FileManager.h" #include "storage/LocalChunkManagerSingleton.h" -#include "storage/Exception.h" #include "storage/IndexData.h" #include "storage/Util.h" #include "storage/ThreadPools.h" @@ -34,17 +33,17 @@ namespace milvus::storage { DiskFileManagerImpl::DiskFileManagerImpl( - const FieldDataMeta& field_mata, - IndexMeta index_meta, + const FileManagerContext& fileManagerContext, std::shared_ptr space) - : FileManagerImpl(field_mata, index_meta), space_(space) { + : FileManagerImpl(fileManagerContext.fieldDataMeta, fileManagerContext.indexMeta), space_(space) { + rcm_ = fileManagerContext.chunkManagerPtr; } -DiskFileManagerImpl::DiskFileManagerImpl(const FieldDataMeta& field_mata, - IndexMeta index_meta, - ChunkManagerPtr remote_chunk_manager) - : FileManagerImpl(field_mata, index_meta) { - rcm_ = remote_chunk_manager; +DiskFileManagerImpl::DiskFileManagerImpl( + const FileManagerContext& fileManagerContext) + : FileManagerImpl(fileManagerContext.fieldDataMeta, + fileManagerContext.indexMeta) { + rcm_ = fileManagerContext.chunkManagerPtr; } DiskFileManagerImpl::~DiskFileManagerImpl() { @@ -447,10 +446,6 @@ DiskFileManagerImpl::IsExisted(const std::string& file) noexcept { LocalChunkManagerSingleton::GetInstance().GetChunkManager(); try { isExist = local_chunk_manager->Exist(file); - } catch (LocalChunkManagerException& e) { - // LOG_SEGCORE_DEBUG_ << "LocalChunkManagerException:" - // << e.what(); - return std::nullopt; } catch (std::exception& e) { // LOG_SEGCORE_DEBUG_ << "Exception:" << e.what(); return std::nullopt; diff --git a/internal/core/src/storage/DiskFileManagerImpl.h b/internal/core/src/storage/DiskFileManagerImpl.h index ec3197d814a75..600cf76479741 100644 --- a/internal/core/src/storage/DiskFileManagerImpl.h +++ b/internal/core/src/storage/DiskFileManagerImpl.h @@ 
-33,12 +33,9 @@ namespace milvus::storage { class DiskFileManagerImpl : public FileManagerImpl { public: - explicit DiskFileManagerImpl(const FieldDataMeta& field_mata, - IndexMeta index_meta, - ChunkManagerPtr remote_chunk_manager); + explicit DiskFileManagerImpl(const FileManagerContext& fileManagerContext); - explicit DiskFileManagerImpl(const FieldDataMeta& field_mata, - IndexMeta index_meta, + explicit DiskFileManagerImpl(const FileManagerContext& fileManagerContext, std::shared_ptr space); virtual ~DiskFileManagerImpl(); diff --git a/internal/core/src/storage/Event.cpp b/internal/core/src/storage/Event.cpp index 708481778cc35..55ff73ced2765 100644 --- a/internal/core/src/storage/Event.cpp +++ b/internal/core/src/storage/Event.cpp @@ -15,12 +15,15 @@ // limitations under the License. #include "storage/Event.h" +#include "fmt/format.h" +#include "nlohmann/json.hpp" #include "storage/PayloadReader.h" #include "storage/PayloadWriter.h" -#include "exceptions/EasyAssert.h" -#include "utils/Json.h" +#include "common/EasyAssert.h" +#include "common/Json.h" #include "common/Consts.h" #include "common/FieldMeta.h" +#include "common/Array.h" namespace milvus::storage { @@ -45,8 +48,8 @@ GetEventHeaderSize(EventHeader& header) { } int -GetEventFixPartSize(EventType EventTypeCode) { - switch (EventTypeCode) { +GetEventFixPartSize(EventType event_type) { + switch (event_type) { case EventType::DescriptorEvent: { DescriptorEventData data; return GetFixPartSize(data); @@ -62,7 +65,8 @@ GetEventFixPartSize(EventType EventTypeCode) { return GetFixPartSize(data); } default: - PanicInfo("unsupported event type"); + PanicInfo(DataFormatBroken, + fmt::format("unsupported event type {}", event_type)); } } @@ -138,8 +142,8 @@ DescriptorEventDataFixPart::Serialize() { DescriptorEventData::DescriptorEventData(BinlogReaderPtr reader) { fix_part = DescriptorEventDataFixPart(reader); - for (auto i = int8_t(EventType::DescriptorEvent); - i < int8_t(EventType::EventTypeEnd); + for (auto i 
= static_cast(EventType::DescriptorEvent); + i < static_cast(EventType::EventTypeEnd); i++) { post_header_lengths.push_back(GetEventFixPartSize(EventType(i))); } @@ -152,8 +156,8 @@ DescriptorEventData::DescriptorEventData(BinlogReaderPtr reader) { ast = reader->Read(extra_length, extra_bytes.data()); assert(ast.ok()); - milvus::json json = - milvus::json::parse(extra_bytes.begin(), extra_bytes.end()); + nlohmann::json json = + nlohmann::json::parse(extra_bytes.begin(), extra_bytes.end()); if (json.contains(ORIGIN_SIZE_KEY)) { extras[ORIGIN_SIZE_KEY] = json[ORIGIN_SIZE_KEY]; } @@ -165,7 +169,7 @@ DescriptorEventData::DescriptorEventData(BinlogReaderPtr reader) { std::vector DescriptorEventData::Serialize() { auto fix_part_data = fix_part.Serialize(); - milvus::json extras_json; + nlohmann::json extras_json; for (auto v : extras) { extras_json.emplace(v.first, v.second); } @@ -229,7 +233,19 @@ BaseEventData::Serialize() { } break; } - case DataType::ARRAY: + case DataType::ARRAY: { + for (size_t offset = 0; offset < field_data->get_num_rows(); + ++offset) { + auto array = + static_cast(field_data->RawValue(offset)); + auto array_string = array->output_data().SerializeAsString(); + + payload_writer->add_one_binary_payload( + reinterpret_cast(array_string.c_str()), + array_string.size()); + } + break; + } case DataType::JSON: { for (size_t offset = 0; offset < field_data->get_num_rows(); ++offset) { diff --git a/internal/core/src/storage/Exception.h b/internal/core/src/storage/Exception.h deleted file mode 100644 index 781850cc86ef0..0000000000000 --- a/internal/core/src/storage/Exception.h +++ /dev/null @@ -1,179 +0,0 @@ -// Licensed to the LF AI & Data foundation under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. 
The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include - -namespace milvus::storage { - -class NotImplementedException : public std::exception { - public: - explicit NotImplementedException(const std::string& msg) - : std::exception(), exception_message_(msg) { - } - const char* - what() const noexcept { - return exception_message_.c_str(); - } - virtual ~NotImplementedException() { - } - - private: - std::string exception_message_; -}; - -class NotSupportedDataTypeException : public std::exception { - public: - explicit NotSupportedDataTypeException(const std::string& msg) - : std::exception(), exception_message_(msg) { - } - const char* - what() const noexcept { - return exception_message_.c_str(); - } - virtual ~NotSupportedDataTypeException() { - } - - private: - std::string exception_message_; -}; - -class LocalChunkManagerException : public std::runtime_error { - public: - explicit LocalChunkManagerException(const std::string& msg) - : std::runtime_error(msg) { - } - virtual ~LocalChunkManagerException() { - } -}; - -class InvalidPathException : public LocalChunkManagerException { - public: - explicit InvalidPathException(const std::string& msg) - : LocalChunkManagerException(msg) { - } - virtual ~InvalidPathException() { - } -}; - -class OpenFileException : public LocalChunkManagerException { - public: - explicit OpenFileException(const std::string& msg) - : LocalChunkManagerException(msg) 
{ - } - virtual ~OpenFileException() { - } -}; - -class CreateFileException : public LocalChunkManagerException { - public: - explicit CreateFileException(const std::string& msg) - : LocalChunkManagerException(msg) { - } - virtual ~CreateFileException() { - } -}; - -class ReadFileException : public LocalChunkManagerException { - public: - explicit ReadFileException(const std::string& msg) - : LocalChunkManagerException(msg) { - } - virtual ~ReadFileException() { - } -}; - -class WriteFileException : public LocalChunkManagerException { - public: - explicit WriteFileException(const std::string& msg) - : LocalChunkManagerException(msg) { - } - virtual ~WriteFileException() { - } -}; - -class PathAlreadyExistException : public LocalChunkManagerException { - public: - explicit PathAlreadyExistException(const std::string& msg) - : LocalChunkManagerException(msg) { - } - virtual ~PathAlreadyExistException() { - } -}; - -class DirNotExistException : public LocalChunkManagerException { - public: - explicit DirNotExistException(const std::string& msg) - : LocalChunkManagerException(msg) { - } - virtual ~DirNotExistException() { - } -}; - -class MinioException : public std::runtime_error { - public: - explicit MinioException(const std::string& msg) : std::runtime_error(msg) { - } - virtual ~MinioException() { - } -}; - -class InvalidBucketNameException : public MinioException { - public: - explicit InvalidBucketNameException(const std::string& msg) - : MinioException(msg) { - } - virtual ~InvalidBucketNameException() { - } -}; - -class ObjectNotExistException : public MinioException { - public: - explicit ObjectNotExistException(const std::string& msg) - : MinioException(msg) { - } - virtual ~ObjectNotExistException() { - } -}; -class S3ErrorException : public MinioException { - public: - explicit S3ErrorException(const std::string& msg) : MinioException(msg) { - } - virtual ~S3ErrorException() { - } -}; - -class DiskANNFileManagerException : public std::runtime_error { - 
public: - explicit DiskANNFileManagerException(const std::string& msg) - : std::runtime_error(msg) { - } - virtual ~DiskANNFileManagerException() { - } -}; - -class ArrowException : public std::runtime_error { - public: - explicit ArrowException(const std::string& msg) : std::runtime_error(msg) { - } - virtual ~ArrowException() { - } -}; - -} // namespace milvus::storage diff --git a/internal/core/src/storage/FieldData.cpp b/internal/core/src/storage/FieldData.cpp index 2d7cf85a8977d..fe1770e6feeed 100644 --- a/internal/core/src/storage/FieldData.cpp +++ b/internal/core/src/storage/FieldData.cpp @@ -16,8 +16,11 @@ #include "storage/FieldData.h" #include "arrow/array/array_binary.h" +#include "common/EasyAssert.h" #include "common/Json.h" #include "simdjson/padded_string.h" +#include "common/Array.h" +#include "FieldDataInterface.h" namespace milvus::storage { @@ -130,7 +133,19 @@ FieldDataImpl::FillFieldData( } return FillFieldData(values.data(), element_count); } + case DataType::ARRAY: { + auto array_array = + std::dynamic_pointer_cast(array); + std::vector values(element_count); + for (size_t index = 0; index < element_count; ++index) { + ScalarArray field_data; + field_data.ParseFromString(array_array->GetString(index)); + values[index] = Array(field_data); + } + return FillFieldData(values.data(), element_count); + } case DataType::VECTOR_FLOAT: + case DataType::VECTOR_FLOAT16: case DataType::VECTOR_BINARY: { auto array_info = GetDataInfoFromArray::FillFieldData( return FillFieldData(array_info.first, array_info.second); } default: { - throw NotSupportedDataTypeException(GetName() + "::FillFieldData" + - " not support data type " + - datatype_name(data_type_)); + throw SegcoreError(DataTypeInvalid, + GetName() + "::FillFieldData" + + " not support data type " + + datatype_name(data_type_)); } } } @@ -157,9 +173,11 @@ template class FieldDataImpl; template class FieldDataImpl; template class FieldDataImpl; template class FieldDataImpl; +template class 
FieldDataImpl; // vector data template class FieldDataImpl; template class FieldDataImpl; +template class FieldDataImpl; -} // namespace milvus::storage \ No newline at end of file +} // namespace milvus::storage diff --git a/internal/core/src/storage/FieldData.h b/internal/core/src/storage/FieldData.h index 92d47b3dbdc64..0a30006ab1bd1 100644 --- a/internal/core/src/storage/FieldData.h +++ b/internal/core/src/storage/FieldData.h @@ -54,6 +54,15 @@ class FieldData : public FieldDataJsonImpl { } }; +template <> +class FieldData : public FieldDataArrayImpl { + public: + static_assert(IsScalar || std::is_same_v); + explicit FieldData(DataType data_type, int64_t buffered_num_rows = 0) + : FieldDataArrayImpl(data_type, buffered_num_rows) { + } +}; + template <> class FieldData : public FieldDataImpl { public: @@ -85,6 +94,17 @@ class FieldData : public FieldDataImpl { int64_t binary_dim_; }; +template <> +class FieldData : public FieldDataImpl { + public: + explicit FieldData(int64_t dim, + DataType data_type, + int64_t buffered_num_rows = 0) + : FieldDataImpl::FieldDataImpl( + dim, data_type, buffered_num_rows) { + } +}; + using FieldDataPtr = std::shared_ptr; using FieldDataChannel = Channel; using FieldDataChannelPtr = std::shared_ptr; diff --git a/internal/core/src/storage/FieldDataInterface.h b/internal/core/src/storage/FieldDataInterface.h index 3165972ab843a..b2a490271f22b 100644 --- a/internal/core/src/storage/FieldDataInterface.h +++ b/internal/core/src/storage/FieldDataInterface.h @@ -30,8 +30,8 @@ #include "common/FieldMeta.h" #include "common/Utils.h" #include "common/VectorTrait.h" -#include "exceptions/EasyAssert.h" -#include "storage/Exception.h" +#include "common/EasyAssert.h" +#include "common/Array.h" namespace milvus::storage { @@ -104,8 +104,8 @@ class FieldDataImpl : public FieldDataBase { DataType data_type, int64_t buffered_num_rows = 0) : FieldDataBase(data_type), - dim_(is_scalar ? 
1 : dim), - num_rows_(buffered_num_rows) { + num_rows_(buffered_num_rows), + dim_(is_scalar ? 1 : dim) { field_data_.resize(num_rows_ * dim_); } @@ -306,4 +306,30 @@ class FieldDataJsonImpl : public FieldDataImpl { } }; +class FieldDataArrayImpl : public FieldDataImpl { + public: + explicit FieldDataArrayImpl(DataType data_type, int64_t total_num_rows = 0) + : FieldDataImpl(1, data_type, total_num_rows) { + } + + int64_t + Size() const { + int64_t data_size = 0; + for (size_t offset = 0; offset < length(); ++offset) { + data_size += field_data_[offset].byte_size(); + } + + return data_size; + } + + int64_t + Size(ssize_t offset) const { + AssertInfo(offset < get_num_rows(), + "field data subscript out of range"); + AssertInfo(offset < length(), + "subscript position don't has valid value"); + return field_data_[offset].byte_size(); + } +}; + } // namespace milvus::storage diff --git a/internal/core/src/storage/FileManager.h b/internal/core/src/storage/FileManager.h index 5daa6cd555f4e..13493d31fbf50 100644 --- a/internal/core/src/storage/FileManager.h +++ b/internal/core/src/storage/FileManager.h @@ -28,28 +28,36 @@ namespace milvus::storage { +struct FileManagerContext { + FileManagerContext() : chunkManagerPtr(nullptr) { + } + FileManagerContext(const FieldDataMeta& fieldDataMeta, + const IndexMeta& indexMeta, + const ChunkManagerPtr& chunkManagerPtr) + : fieldDataMeta(fieldDataMeta), + indexMeta(indexMeta), + chunkManagerPtr(chunkManagerPtr) { + } + bool + Valid() const { + return chunkManagerPtr != nullptr; + } + + FieldDataMeta fieldDataMeta; + IndexMeta indexMeta; + ChunkManagerPtr chunkManagerPtr; +}; + #define FILEMANAGER_TRY try { -#define FILEMANAGER_CATCH \ - } \ - catch (LocalChunkManagerException & e) { \ - LOG_SEGCORE_ERROR_ << "LocalChunkManagerException:" << e.what(); \ - return false; \ - } \ - catch (MinioException & e) { \ - LOG_SEGCORE_ERROR_ << "milvus::storage::MinioException:" << e.what(); \ - return false; \ - } \ - catch 
(DiskANNFileManagerException & e) { \ - LOG_SEGCORE_ERROR_ << "milvus::storage::DiskANNFileManagerException:" \ - << e.what(); \ - return false; \ - } \ - catch (ArrowException & e) { \ - LOG_SEGCORE_ERROR_ << "milvus::storage::ArrowException:" << e.what(); \ - return false; \ - } \ - catch (std::exception & e) { \ - LOG_SEGCORE_ERROR_ << "Exception:" << e.what(); \ +#define FILEMANAGER_CATCH \ + } \ + catch (SegcoreError & e) { \ + LOG_SEGCORE_ERROR_ << "SegcoreError: code " << e.get_error_code() \ + << ", " << e.what(); \ + return false; \ + } \ + catch (std::exception & e) { \ + LOG_SEGCORE_ERROR_ << "Exception:" << e.what(); \ return false; #define FILEMANAGER_END } diff --git a/internal/core/src/storage/IndexData.cpp b/internal/core/src/storage/IndexData.cpp index fb448a362cacb..6309fd44f7dee 100644 --- a/internal/core/src/storage/IndexData.cpp +++ b/internal/core/src/storage/IndexData.cpp @@ -15,7 +15,7 @@ // limitations under the License. #include "storage/IndexData.h" -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" #include "common/Consts.h" #include "storage/Event.h" @@ -41,7 +41,8 @@ IndexData::Serialize(StorageType medium) { case StorageType::LocalDisk: return serialize_to_local_file(); default: - PanicInfo("unsupported medium type"); + PanicInfo(DataFormatBroken, + fmt::format("unsupported medium type {}", medium)); } } diff --git a/internal/core/src/storage/InsertData.cpp b/internal/core/src/storage/InsertData.cpp index 91cbf093b12e5..514d98d56aac6 100644 --- a/internal/core/src/storage/InsertData.cpp +++ b/internal/core/src/storage/InsertData.cpp @@ -17,7 +17,7 @@ #include "storage/InsertData.h" #include "storage/Event.h" #include "storage/Util.h" -#include "utils/Json.h" +#include "common/Json.h" #include "common/FieldMeta.h" #include "common/Consts.h" @@ -37,7 +37,8 @@ InsertData::Serialize(StorageType medium) { case StorageType::LocalDisk: return serialize_to_local_file(); default: - PanicInfo("unsupported medium type"); + 
PanicInfo(DataFormatBroken, + fmt::format("unsupported medium type {}", medium)); } } diff --git a/internal/core/src/storage/LocalChunkManager.cpp b/internal/core/src/storage/LocalChunkManager.cpp index 03d5cfbbd8423..7baca5e6c0948 100644 --- a/internal/core/src/storage/LocalChunkManager.cpp +++ b/internal/core/src/storage/LocalChunkManager.cpp @@ -21,13 +21,13 @@ #include #include -#include "Exception.h" +#include "common/EasyAssert.h" -#define THROWLOCALERROR(FUNCTION) \ +#define THROWLOCALERROR(code, FUNCTION) \ do { \ std::stringstream err_msg; \ err_msg << "Error:" << #FUNCTION << ":" << err.message(); \ - throw LocalChunkManagerException(err_msg.str()); \ + throw SegcoreError(code, err_msg.str()); \ } while (0) namespace milvus::storage { @@ -38,7 +38,7 @@ LocalChunkManager::Exist(const std::string& filepath) { boost::system::error_code err; bool isExist = boost::filesystem::exists(absPath, err); if (err && err.value() != boost::system::errc::no_such_file_or_directory) { - THROWLOCALERROR(Exist); + THROWLOCALERROR(FileReadFailed, Exist); } return isExist; } @@ -48,12 +48,13 @@ LocalChunkManager::Size(const std::string& filepath) { boost::filesystem::path absPath(filepath); if (!Exist(filepath)) { - throw InvalidPathException("invalid local path:" + absPath.string()); + throw SegcoreError(PathNotExist, + "invalid local path:" + absPath.string()); } boost::system::error_code err; int64_t size = boost::filesystem::file_size(absPath, err); if (err) { - THROWLOCALERROR(FileSize); + THROWLOCALERROR(FileReadFailed, FileSize); } return size; } @@ -64,7 +65,7 @@ LocalChunkManager::Remove(const std::string& filepath) { boost::system::error_code err; boost::filesystem::remove(absPath, err); if (err) { - THROWLOCALERROR(Remove); + THROWLOCALERROR(FileWriteFailed, Remove); } } @@ -84,7 +85,7 @@ LocalChunkManager::Read(const std::string& filepath, std::stringstream err_msg; err_msg << "Error: open local file '" << filepath << " failed, " << strerror(errno); - throw 
OpenFileException(err_msg.str()); + throw SegcoreError(FileOpenFailed, err_msg.str()); } infile.seekg(offset, std::ios::beg); @@ -93,7 +94,7 @@ LocalChunkManager::Read(const std::string& filepath, std::stringstream err_msg; err_msg << "Error: read local file '" << filepath << " failed, " << strerror(errno); - throw ReadFileException(err_msg.str()); + throw SegcoreError(FileReadFailed, err_msg.str()); } } return infile.gcount(); @@ -114,13 +115,13 @@ LocalChunkManager::Write(const std::string& absPathStr, std::stringstream err_msg; err_msg << "Error: open local file '" << absPathStr << " failed, " << strerror(errno); - throw OpenFileException(err_msg.str()); + throw SegcoreError(FileOpenFailed, err_msg.str()); } if (!outfile.write(reinterpret_cast(buf), size)) { std::stringstream err_msg; err_msg << "Error: write local file '" << absPathStr << " failed, " << strerror(errno); - throw WriteFileException(err_msg.str()); + throw SegcoreError(FileWriteFailed, err_msg.str()); } } @@ -142,7 +143,7 @@ LocalChunkManager::Write(const std::string& absPathStr, std::stringstream err_msg; err_msg << "Error: open local file '" << absPathStr << " failed, " << strerror(errno); - throw OpenFileException(err_msg.str()); + throw SegcoreError(FileOpenFailed, err_msg.str()); } outfile.seekp(offset, std::ios::beg); @@ -150,14 +151,14 @@ LocalChunkManager::Write(const std::string& absPathStr, std::stringstream err_msg; err_msg << "Error: write local file '" << absPathStr << " failed, " << strerror(errno); - throw WriteFileException(err_msg.str()); + throw SegcoreError(FileWriteFailed, err_msg.str()); } } std::vector LocalChunkManager::ListWithPrefix(const std::string& filepath) { - throw NotImplementedException(GetName() + "::ListWithPrefix" + - " not implement now"); + throw SegcoreError(NotImplemented, + GetName() + "::ListWithPrefix" + " not implement now"); } bool @@ -173,7 +174,7 @@ LocalChunkManager::CreateFile(const std::string& filepath) { std::stringstream err_msg; err_msg << 
"Error: create new local file '" << absPathStr << " failed, " << strerror(errno); - throw CreateFileException(err_msg.str()); + throw SegcoreError(FileCreateFailed, err_msg.str()); } file.close(); return true; @@ -185,7 +186,7 @@ LocalChunkManager::DirExist(const std::string& dir) { boost::system::error_code err; bool isExist = boost::filesystem::exists(dirPath, err); if (err && err.value() != boost::system::errc::no_such_file_or_directory) { - THROWLOCALERROR(DirExist); + THROWLOCALERROR(FileReadFailed, DirExist); } return isExist; } @@ -194,12 +195,12 @@ void LocalChunkManager::CreateDir(const std::string& dir) { bool isExist = DirExist(dir); if (isExist) { - throw PathAlreadyExistException("dir:" + dir + " already exists"); + throw SegcoreError(PathAlreadyExist, "dir:" + dir + " already exists"); } boost::filesystem::path dirPath(dir); auto create_success = boost::filesystem::create_directories(dirPath); if (!create_success) { - throw CreateFileException("create dir failed" + dir); + throw SegcoreError(FileCreateFailed, "create dir failed" + dir); } } @@ -209,7 +210,7 @@ LocalChunkManager::RemoveDir(const std::string& dir) { boost::system::error_code err; boost::filesystem::remove_all(dirPath, err); if (err) { - THROWLOCALERROR(RemoveDir); + THROWLOCALERROR(FileCreateFailed, RemoveDir); } } @@ -218,7 +219,7 @@ LocalChunkManager::GetSizeOfDir(const std::string& dir) { boost::filesystem::path dirPath(dir); bool is_dir = boost::filesystem::is_directory(dirPath); if (!is_dir) { - throw DirNotExistException("dir:" + dir + " not exists"); + throw SegcoreError(PathNotExist, "dir:" + dir + " not exists"); } using boost::filesystem::directory_entry; diff --git a/internal/core/src/storage/LocalChunkManagerSingleton.h b/internal/core/src/storage/LocalChunkManagerSingleton.h index f393a109e00cb..2715796d7b771 100644 --- a/internal/core/src/storage/LocalChunkManagerSingleton.h +++ b/internal/core/src/storage/LocalChunkManagerSingleton.h @@ -43,25 +43,17 @@ class 
LocalChunkManagerSingleton { void Init(std::string root_path) { - std::unique_lock lck(mutex_); if (lcm_ == nullptr) { lcm_ = std::make_shared(root_path); } } - void - Release() { - std::unique_lock lck(mutex_); - lcm_ = nullptr; - } - LocalChunkManagerSPtr GetChunkManager() { return lcm_; } private: - mutable std::shared_mutex mutex_; LocalChunkManagerSPtr lcm_ = nullptr; }; diff --git a/internal/core/src/storage/MemFileManagerImpl.cpp b/internal/core/src/storage/MemFileManagerImpl.cpp index 2ef232ef1951c..2fe3af7b87d16 100644 --- a/internal/core/src/storage/MemFileManagerImpl.cpp +++ b/internal/core/src/storage/MemFileManagerImpl.cpp @@ -26,17 +26,18 @@ namespace milvus::storage { MemFileManagerImpl::MemFileManagerImpl( - const FieldDataMeta& field_mata, - IndexMeta index_meta, + const FileManagerContext& fileManagerContext, std::shared_ptr space) - : FileManagerImpl(field_mata, index_meta), space_(space) { + : FileManagerImpl(fileManagerContext.fieldDataMeta, + fileManagerContext.indexMeta), space_(space) { + rcm_ = fileManagerContext.chunkManagerPtr; } -MemFileManagerImpl::MemFileManagerImpl(const FieldDataMeta& field_mata, - IndexMeta index_meta, - ChunkManagerPtr remote_chunk_manager) - : FileManagerImpl(field_mata, index_meta) { - rcm_ = remote_chunk_manager; +MemFileManagerImpl::MemFileManagerImpl( + const FileManagerContext& fileManagerContext) + : FileManagerImpl(fileManagerContext.fieldDataMeta, + fileManagerContext.indexMeta) { + rcm_ = fileManagerContext.chunkManagerPtr; } bool @@ -221,4 +222,4 @@ MemFileManagerImpl::RemoveFile(const std::string& filename) noexcept { return false; } -} // namespace milvus::storage \ No newline at end of file +} // namespace milvus::storage diff --git a/internal/core/src/storage/MemFileManagerImpl.h b/internal/core/src/storage/MemFileManagerImpl.h index bebdb2eca10b1..3541e754f1bc6 100644 --- a/internal/core/src/storage/MemFileManagerImpl.h +++ b/internal/core/src/storage/MemFileManagerImpl.h @@ -31,12 +31,9 @@ 
namespace milvus::storage { class MemFileManagerImpl : public FileManagerImpl { public: - explicit MemFileManagerImpl(const FieldDataMeta& field_mata, - IndexMeta index_meta, - ChunkManagerPtr remote_chunk_manager); + explicit MemFileManagerImpl(const FileManagerContext& fileManagerContext); - MemFileManagerImpl(const FieldDataMeta& field_mata, - IndexMeta index_meta, + MemFileManagerImpl(const FileManagerContext& fileManagerContext, std::shared_ptr space); virtual bool @@ -82,4 +79,4 @@ class MemFileManagerImpl : public FileManagerImpl { using MemFileManagerImplPtr = std::shared_ptr; -} // namespace milvus::storage \ No newline at end of file +} // namespace milvus::storage diff --git a/internal/core/src/storage/MinioChunkManager.cpp b/internal/core/src/storage/MinioChunkManager.cpp index 430bfbbb1f5dc..badddb5fa7977 100644 --- a/internal/core/src/storage/MinioChunkManager.cpp +++ b/internal/core/src/storage/MinioChunkManager.cpp @@ -31,22 +31,11 @@ #include "storage/MinioChunkManager.h" #include "storage/AliyunSTSClient.h" #include "storage/AliyunCredentialsProvider.h" -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" #include "log/Log.h" #include "signal.h" +#include "common/Consts.h" -#define THROWS3ERROR(FUNCTION) \ - do { \ - auto& err = outcome.GetError(); \ - std::stringstream err_msg; \ - err_msg << "Error:" << #FUNCTION \ - << "[errcode:" << int(err.GetResponseCode()) \ - << ", exception:" << err.GetExceptionName() \ - << ", errmessage:" << err.GetMessage() << "]"; \ - throw S3ErrorException(err_msg.str()); \ - } while (0) - -#define S3NoSuchBucket "NoSuchBucket" namespace milvus::storage { std::atomic MinioChunkManager::init_count_(0); @@ -150,6 +139,49 @@ MinioChunkManager::InitSDKAPI(RemoteStorageType type, } } +void +MinioChunkManager::InitSDKAPIDefault(const std::string& log_level_str) { + std::scoped_lock lock{client_mutex_}; + const size_t initCount = init_count_++; + if (initCount == 0) { + // 
sdk_options_.httpOptions.installSigPipeHandler = true; + struct sigaction psa; + memset(&psa, 0, sizeof psa); + psa.sa_handler = SwallowHandler; + psa.sa_flags = psa.sa_flags | SA_ONSTACK; + sigaction(SIGPIPE, &psa, 0); + // block multiple SIGPIPE concurrently processing + sigemptyset(&psa.sa_mask); + sigaddset(&psa.sa_mask, SIGPIPE); + sigaction(SIGPIPE, &psa, 0); + LOG_SEGCORE_INFO_ << "init aws with log level:" << log_level_str; + auto get_aws_log_level = [](const std::string& level_str) { + Aws::Utils::Logging::LogLevel level = + Aws::Utils::Logging::LogLevel::Off; + if (level_str == "fatal") { + level = Aws::Utils::Logging::LogLevel::Fatal; + } else if (level_str == "error") { + level = Aws::Utils::Logging::LogLevel::Error; + } else if (level_str == "warn") { + level = Aws::Utils::Logging::LogLevel::Warn; + } else if (level_str == "info") { + level = Aws::Utils::Logging::LogLevel::Info; + } else if (level_str == "debug") { + level = Aws::Utils::Logging::LogLevel::Debug; + } else if (level_str == "trace") { + level = Aws::Utils::Logging::LogLevel::Trace; + } + return level; + }; + auto log_level = get_aws_log_level(log_level_str); + sdk_options_.loggingOptions.logLevel = log_level; + sdk_options_.loggingOptions.logger_create_fn = [log_level]() { + return std::make_shared(log_level); + }; + Aws::InitAPI(sdk_options_); + } +} + void MinioChunkManager::ShutdownSDKAPI() { std::scoped_lock lock{client_mutex_}; @@ -270,6 +302,10 @@ MinioChunkManager::MinioChunkManager(const StorageConfig& storage_config) config.verifySSL = false; } + config.requestTimeoutMs = storage_config.requestTimeoutMs == 0 + ? 
DEFAULT_CHUNK_MANAGER_REQUEST_TIMEOUT_MS + : storage_config.requestTimeoutMs; + if (!storage_config.region.empty()) { config.region = ConvertToAwsString(storage_config.region); } @@ -284,7 +320,8 @@ MinioChunkManager::MinioChunkManager(const StorageConfig& storage_config) LOG_SEGCORE_INFO_ << "init MinioChunkManager with parameter[endpoint: '" << storage_config.address << "', default_bucket_name:'" - << storage_config.bucket_name << "', use_secure:'" + << storage_config.bucket_name << "', root_path:'" + << storage_config.root_path << "', use_secure:'" << std::boolalpha << storage_config.useSSL << "']"; } @@ -315,12 +352,6 @@ MinioChunkManager::ListWithPrefix(const std::string& filepath) { uint64_t MinioChunkManager::Read(const std::string& filepath, void* buf, uint64_t size) { - if (!ObjectExists(default_bucket_name_, filepath)) { - std::stringstream err_msg; - err_msg << "object('" << default_bucket_name_ << "', " << filepath - << "') not exists"; - throw ObjectNotExistException(err_msg.str()); - } return GetObjectBuffer(default_bucket_name_, filepath, buf, size); } @@ -333,28 +364,19 @@ MinioChunkManager::Write(const std::string& filepath, bool MinioChunkManager::BucketExists(const std::string& bucket_name) { - // auto outcome = client_->ListBuckets(); - - // if (!outcome.IsSuccess()) { - // THROWS3ERROR(BucketExists); - // } - // for (auto&& b : outcome.GetResult().GetBuckets()) { - // if (ConvertFromAwsString(b.GetName()) == bucket_name) { - // return true; - // } - // } Aws::S3::Model::HeadBucketRequest request; request.SetBucket(bucket_name.c_str()); auto outcome = client_->HeadBucket(request); if (!outcome.IsSuccess()) { - auto error = outcome.GetError(); - if (!error.GetExceptionName().empty()) { - std::stringstream err_msg; - err_msg << "Error: BucketExists: " - << error.GetExceptionName() + " - " + error.GetMessage(); - throw S3ErrorException(err_msg.str()); + const auto& err = outcome.GetError(); + auto error_type = err.GetErrorType(); + // only throw if 
the error is not nosuchbucket + // if bucket not exist, HeadBucket return errorType RESOURCE_NOT_FOUND + if (error_type != Aws::S3::S3Errors::NO_SUCH_BUCKET && + error_type != Aws::S3::S3Errors::RESOURCE_NOT_FOUND) { + ThrowS3Error("BucketExists", err, "params, bucket={}", bucket_name); } return false; } @@ -367,7 +389,8 @@ MinioChunkManager::ListBuckets() { auto outcome = client_->ListBuckets(); if (!outcome.IsSuccess()) { - THROWS3ERROR(CreateBucket); + const auto& err = outcome.GetError(); + ThrowS3Error("ListBuckets", err, "params"); } for (auto&& b : outcome.GetResult().GetBuckets()) { buckets.emplace_back(b.GetName().c_str()); @@ -382,10 +405,13 @@ MinioChunkManager::CreateBucket(const std::string& bucket_name) { auto outcome = client_->CreateBucket(request); - if (!outcome.IsSuccess() && - Aws::S3::S3Errors(outcome.GetError().GetErrorType()) != + if (!outcome.IsSuccess()) { + const auto& err = outcome.GetError(); + if (err.GetErrorType() != Aws::S3::S3Errors::BUCKET_ALREADY_OWNED_BY_YOU) { - THROWS3ERROR(CreateBucket); + ThrowS3Error("CreateBucket", err, "params, bucket={}", bucket_name); + } + return false; } return true; } @@ -398,9 +424,11 @@ MinioChunkManager::DeleteBucket(const std::string& bucket_name) { auto outcome = client_->DeleteBucket(request); if (!outcome.IsSuccess()) { - auto err = outcome.GetError(); - if (err.GetExceptionName() != S3NoSuchBucket) { - THROWS3ERROR(DeleteBucket); + const auto& err = outcome.GetError(); + auto error_type = err.GetErrorType(); + if (error_type != Aws::S3::S3Errors::NO_SUCH_BUCKET && + error_type != Aws::S3::S3Errors::RESOURCE_NOT_FOUND) { + ThrowS3Error("DeleteBucket", err, "params, bucket={}", bucket_name); } return false; } @@ -417,11 +445,13 @@ MinioChunkManager::ObjectExists(const std::string& bucket_name, auto outcome = client_->HeadObject(request); if (!outcome.IsSuccess()) { - auto& err = outcome.GetError(); - if (!err.GetExceptionName().empty()) { - std::stringstream err_msg; - err_msg << "Error: 
ObjectExists: " << err.GetMessage(); - throw S3ErrorException(err_msg.str()); + const auto& err = outcome.GetError(); + if (!IsNotFound(err.GetErrorType())) { + ThrowS3Error("ObjectExists", + err, + "params, bucket={}, object={}", + bucket_name, + object_name); } return false; } @@ -437,7 +467,12 @@ MinioChunkManager::GetObjectSize(const std::string& bucket_name, auto outcome = client_->HeadObject(request); if (!outcome.IsSuccess()) { - THROWS3ERROR(GetObjectSize); + const auto& err = outcome.GetError(); + ThrowS3Error("GetObjectSize", + err, + "params, bucket={}, object={}", + bucket_name, + object_name); } return outcome.GetResult().GetContentLength(); } @@ -452,11 +487,15 @@ MinioChunkManager::DeleteObject(const std::string& bucket_name, auto outcome = client_->DeleteObject(request); if (!outcome.IsSuccess()) { - // auto err = outcome.GetError(); - // std::stringstream err_msg; - // err_msg << "Error: DeleteObject:" << err.GetMessage(); - // throw S3ErrorException(err_msg.str()); - THROWS3ERROR(DeleteObject); + const auto& err = outcome.GetError(); + if (!IsNotFound(err.GetErrorType())) { + ThrowS3Error("DeleteObject", + err, + "params, bucket={}, object={}", + bucket_name, + object_name); + } + return false; } return true; } @@ -479,7 +518,12 @@ MinioChunkManager::PutObjectBuffer(const std::string& bucket_name, auto outcome = client_->PutObject(request); if (!outcome.IsSuccess()) { - THROWS3ERROR(PutObjectBuffer); + const auto& err = outcome.GetError(); + ThrowS3Error("PutObjectBuffer", + err, + "params, bucket={}, object={}", + bucket_name, + object_name); } return true; } @@ -547,24 +591,35 @@ MinioChunkManager::GetObjectBuffer(const std::string& bucket_name, auto outcome = client_->GetObject(request); if (!outcome.IsSuccess()) { - THROWS3ERROR(GetObjectBuffer); + const auto& err = outcome.GetError(); + ThrowS3Error("GetObjectBuffer", + err, + "params, bucket={}, object={}", + bucket_name, + object_name); } return size; } std::vector 
-MinioChunkManager::ListObjects(const char* bucket_name, const char* prefix) { +MinioChunkManager::ListObjects(const std::string& bucket_name, + const std::string& prefix) { std::vector objects_vec; Aws::S3::Model::ListObjectsRequest request; request.WithBucket(bucket_name); - if (prefix != nullptr) { + if (prefix != "") { request.SetPrefix(prefix); } auto outcome = client_->ListObjects(request); if (!outcome.IsSuccess()) { - THROWS3ERROR(ListObjects); + const auto& err = outcome.GetError(); + ThrowS3Error("ListObjects", + err, + "params, bucket={}, prefix={}", + bucket_name, + prefix); } auto objects = outcome.GetResult().GetContents(); for (auto& obj : objects) { diff --git a/internal/core/src/storage/MinioChunkManager.h b/internal/core/src/storage/MinioChunkManager.h index 2ff1e9d8bdedd..740a426260914 100644 --- a/internal/core/src/storage/MinioChunkManager.h +++ b/internal/core/src/storage/MinioChunkManager.h @@ -36,15 +36,37 @@ #include #include #include +#include +#include "common/EasyAssert.h" #include "storage/ChunkManager.h" -#include "storage/Exception.h" #include "storage/Types.h" namespace milvus::storage { enum class RemoteStorageType { S3 = 0, GOOGLE_CLOUD = 1, ALIYUN_CLOUD = 2 }; +template + +static SegcoreError +ThrowS3Error(const std::string& func, + const Aws::S3::S3Error& err, + const std::string& fmtString, + Args&&... 
args) { + std::ostringstream oss; + const auto& message = fmt::format(fmtString, std::forward(args)...); + oss << "Error in " << func << "[errcode:" << int(err.GetResponseCode()) + << ", exception:" << err.GetExceptionName() + << ", errmessage:" << err.GetMessage() << ", params:" << message << "]"; + throw SegcoreError(S3Error, oss.str()); +} + +static bool +IsNotFound(const Aws::S3::S3Errors& s3err) { + return (s3err == Aws::S3::S3Errors::NO_SUCH_KEY || + s3err == Aws::S3::S3Errors::RESOURCE_NOT_FOUND); +} + /** * @brief user defined aws logger, redirect aws log to segcore log */ @@ -69,6 +91,8 @@ class AwsLogger : public Aws::Utils::Logging::FormattedLogSystem { */ class MinioChunkManager : public ChunkManager { public: + MinioChunkManager() { + } explicit MinioChunkManager(const StorageConfig& storage_config); MinioChunkManager(const MinioChunkManager&); @@ -89,8 +113,8 @@ class MinioChunkManager : public ChunkManager { uint64_t offset, void* buf, uint64_t len) { - throw NotImplementedException(GetName() + - "Read with offset not implement"); + throw SegcoreError(NotImplemented, + GetName() + "Read with offset not implement"); } virtual void @@ -98,8 +122,8 @@ class MinioChunkManager : public ChunkManager { uint64_t offset, void* buf, uint64_t len) { - throw NotImplementedException(GetName() + - "Write with offset not implement"); + throw SegcoreError(NotImplemented, + GetName() + "Write with offset not implement"); } virtual uint64_t @@ -166,8 +190,12 @@ class MinioChunkManager : public ChunkManager { const std::string& object_name, void* buf, uint64_t size); + std::vector - ListObjects(const char* bucket_name, const char* prefix = nullptr); + ListObjects(const std::string& bucket_name, const std::string& prefix = ""); + + void + InitSDKAPIDefault(const std::string& log_level); void InitSDKAPI(RemoteStorageType type, bool useIAM, @@ -184,7 +212,7 @@ class MinioChunkManager : public ChunkManager { BuildGoogleCloudClient(const StorageConfig& storage_config, const 
Aws::Client::ClientConfiguration& config); - private: + protected: void BuildAccessKeyClient(const StorageConfig& storage_config, const Aws::Client::ClientConfiguration& config); @@ -197,6 +225,33 @@ class MinioChunkManager : public ChunkManager { std::string remote_root_path_; }; +class AwsChunkManager : public MinioChunkManager { + public: + explicit AwsChunkManager(const StorageConfig& storage_config); + virtual std::string + GetName() const { + return "AwsChunkManager"; + } +}; + +class GcpChunkManager : public MinioChunkManager { + public: + explicit GcpChunkManager(const StorageConfig& storage_config); + virtual std::string + GetName() const { + return "GcpChunkManager"; + } +}; + +class AliyunChunkManager : public MinioChunkManager { + public: + explicit AliyunChunkManager(const StorageConfig& storage_config); + virtual std::string + GetName() const { + return "AliyunChunkManager"; + } +}; + using MinioChunkManagerPtr = std::unique_ptr; static const char* GOOGLE_CLIENT_FACTORY_ALLOCATION_TAG = @@ -242,9 +297,10 @@ class GoogleHttpClientFactory : public Aws::Http::HttpClientFactory { request->SetResponseStreamFactory(streamFactory); auto auth_header = credentials_->AuthorizationHeader(); if (!auth_header.ok()) { - throw std::runtime_error( - "get authorization failed, errcode:" + - StatusCodeToString(auth_header.status().code())); + throw SegcoreError( + S3Error, + fmt::format("get authorization failed, errcode: {}", + StatusCodeToString(auth_header.status().code()))); } request->SetHeaderValue(auth_header->first.c_str(), auth_header->second.c_str()); diff --git a/internal/core/src/storage/PayloadReader.cpp b/internal/core/src/storage/PayloadReader.cpp index 8f9b19031996f..54e39cb636966 100644 --- a/internal/core/src/storage/PayloadReader.cpp +++ b/internal/core/src/storage/PayloadReader.cpp @@ -15,7 +15,7 @@ // limitations under the License. 
#include "storage/PayloadReader.h" -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" #include "storage/Util.h" #include "parquet/column_reader.h" #include "arrow/io/api.h" diff --git a/internal/core/src/storage/PayloadStream.cpp b/internal/core/src/storage/PayloadStream.cpp index d7ec9f303c0eb..93ccdef858592 100644 --- a/internal/core/src/storage/PayloadStream.cpp +++ b/internal/core/src/storage/PayloadStream.cpp @@ -17,7 +17,7 @@ #include "arrow/api.h" #include "storage/PayloadStream.h" -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" namespace milvus::storage { diff --git a/internal/core/src/storage/PayloadWriter.cpp b/internal/core/src/storage/PayloadWriter.cpp index 7d9def154a7d1..54c47ed81ea68 100644 --- a/internal/core/src/storage/PayloadWriter.cpp +++ b/internal/core/src/storage/PayloadWriter.cpp @@ -15,7 +15,7 @@ // limitations under the License. #include "storage/PayloadWriter.h" -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" #include "common/FieldMeta.h" #include "storage/Util.h" diff --git a/internal/core/src/storage/RemoteChunkManagerSingleton.h b/internal/core/src/storage/RemoteChunkManagerSingleton.h index 7c72cbf8c1ad0..75a070497c417 100644 --- a/internal/core/src/storage/RemoteChunkManagerSingleton.h +++ b/internal/core/src/storage/RemoteChunkManagerSingleton.h @@ -41,7 +41,6 @@ class RemoteChunkManagerSingleton { void Init(const StorageConfig& storage_config) { - std::unique_lock lck(mutex_); if (rcm_ == nullptr) { rcm_ = CreateChunkManager(storage_config); } @@ -49,8 +48,6 @@ class RemoteChunkManagerSingleton { void Release() { - std::unique_lock lck(mutex_); - rcm_ = nullptr; } ChunkManagerPtr @@ -59,7 +56,6 @@ class RemoteChunkManagerSingleton { } private: - mutable std::shared_mutex mutex_; ChunkManagerPtr rcm_ = nullptr; }; diff --git a/internal/core/src/storage/ThreadPools.cpp b/internal/core/src/storage/ThreadPools.cpp index ec3a019702ad6..e2253fafe2464 100644 --- 
a/internal/core/src/storage/ThreadPools.cpp +++ b/internal/core/src/storage/ThreadPools.cpp @@ -1,3 +1,14 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + // // Created by zilliz on 2023/7/31. // diff --git a/internal/core/src/storage/Types.h b/internal/core/src/storage/Types.h index 51a828de66a70..660eb2d6639e1 100644 --- a/internal/core/src/storage/Types.h +++ b/internal/core/src/storage/Types.h @@ -91,12 +91,70 @@ struct StorageConfig { std::string access_key_value = "minioadmin"; std::string root_path = "files"; std::string storage_type = "minio"; + std::string cloud_provider = "aws"; std::string iam_endpoint = ""; - std::string log_level = "error"; + std::string log_level = "warn"; std::string region = ""; bool useSSL = false; bool useIAM = false; bool useVirtualHost = false; + int64_t requestTimeoutMs = 3000; }; } // namespace milvus::storage + +template <> +struct fmt::formatter : formatter { + auto + format(milvus::storage::EventType c, format_context& ctx) const { + string_view name = "unknown"; + switch (c) { + case milvus::storage::EventType::DescriptorEvent: + name = "DescriptorEvent"; + break; + case milvus::storage::EventType::InsertEvent: + name = "InsertEvent"; + break; + case milvus::storage::EventType::DeleteEvent: + name = "DeleteEvent"; + break; + case milvus::storage::EventType::CreateCollectionEvent: + name = "CreateCollectionEvent"; + break; + case 
milvus::storage::EventType::DropCollectionEvent: + name = "DropCollectionEvent"; + break; + case milvus::storage::EventType::CreatePartitionEvent: + name = "CreatePartitionEvent"; + break; + case milvus::storage::EventType::DropPartitionEvent: + name = "DropPartitionEvent"; + break; + case milvus::storage::EventType::IndexFileEvent: + name = "IndexFileEvent"; + break; + case milvus::storage::EventType::EventTypeEnd: + name = "EventTypeEnd"; + break; + } + return formatter::format(name, ctx); + } +}; + +template <> +struct fmt::formatter : formatter { + auto + format(milvus::storage::StorageType c, format_context& ctx) const { + switch (c) { + case milvus::storage::StorageType::None: + return formatter::format("None", ctx); + case milvus::storage::StorageType::Memory: + return formatter::format("Memory", ctx); + case milvus::storage::StorageType::LocalDisk: + return formatter::format("LocalDisk", ctx); + case milvus::storage::StorageType::Remote: + return formatter::format("Remote", ctx); + } + return formatter::format("unknown", ctx); + } +}; diff --git a/internal/core/src/storage/Util.cpp b/internal/core/src/storage/Util.cpp index f4b146f74decc..7348ffc950537 100644 --- a/internal/core/src/storage/Util.cpp +++ b/internal/core/src/storage/Util.cpp @@ -1,3 +1,4 @@ + // Licensed to the LF AI & Data foundation under one // or more contributor license agreements. 
See the NOTICE file // distributed with this work for additional information @@ -18,10 +19,14 @@ #include #include "arrow/array/builder_binary.h" #include "arrow/type_fwd.h" -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" #include "common/Consts.h" -#include "fmt/core.h" +#include "fmt/format.h" +#ifdef AZURE_BUILD_DIR +#include "storage/AzureChunkManager.h" +#endif #include "storage/FieldData.h" +#include "storage/InsertData.h" #include "storage/FieldDataInterface.h" #include "storage/ThreadPools.h" #include "storage/LocalChunkManager.h" @@ -33,7 +38,30 @@ namespace milvus::storage { std::map ChunkManagerType_Map = { - {"local", ChunkManagerType::Local}, {"minio", ChunkManagerType::Minio}}; + {"local", ChunkManagerType::Local}, + {"minio", ChunkManagerType::Minio}, + {"remote", ChunkManagerType::Remote}}; + +enum class CloudProviderType : int8_t { + UNKNOWN = 0, + AWS = 1, + GCP = 2, + ALIYUN = 3, + AZURE = 4, +}; + +std::map CloudProviderType_Map = { + {"aws", CloudProviderType::AWS}, + {"gcp", CloudProviderType::GCP}, + {"aliyun", CloudProviderType::ALIYUN}, + {"azure", CloudProviderType::AZURE}}; + +std::map ReadAheadPolicy_Map = { + {"normal", MADV_NORMAL}, + {"random", MADV_RANDOM}, + {"sequential", MADV_SEQUENTIAL}, + {"willneed", MADV_WILLNEED}, + {"dontneed", MADV_DONTNEED}}; StorageType ReadMediumType(BinlogReaderPtr reader) { @@ -123,13 +151,15 @@ AddPayloadToArrowBuilder(std::shared_ptr builder, builder, double_data, length); break; } + case DataType::VECTOR_FLOAT16: case DataType::VECTOR_BINARY: case DataType::VECTOR_FLOAT: { add_vector_payload(builder, const_cast(raw_data), length); break; } default: { - PanicInfo("unsupported data type"); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type {}", data_type)); } } } @@ -199,7 +229,9 @@ CreateArrowBuilder(DataType data_type) { return std::make_shared(); } default: { - PanicInfo("unsupported numeric data type"); + PanicInfo( + DataTypeInvalid, + fmt::format("unsupported 
numeric data type {}", data_type)); } } } @@ -217,8 +249,15 @@ CreateArrowBuilder(DataType data_type, int dim) { return std::make_shared( arrow::fixed_size_binary(dim / 8)); } + case DataType::VECTOR_FLOAT16: { + AssertInfo(dim > 0, "invalid dim value"); + return std::make_shared( + arrow::fixed_size_binary(dim * sizeof(float16))); + } default: { - PanicInfo("unsupported vector data type"); + PanicInfo( + DataTypeInvalid, + fmt::format("unsupported vector data type {}", data_type)); } } } @@ -256,7 +295,9 @@ CreateArrowSchema(DataType data_type) { return arrow::schema({arrow::field("val", arrow::binary())}); } default: { - PanicInfo("unsupported numeric data type"); + PanicInfo( + DataTypeInvalid, + fmt::format("unsupported numeric data type {}", data_type)); } } } @@ -274,8 +315,15 @@ CreateArrowSchema(DataType data_type, int dim) { return arrow::schema( {arrow::field("val", arrow::fixed_size_binary(dim / 8))}); } + case DataType::VECTOR_FLOAT16: { + AssertInfo(dim > 0, "invalid dim value"); + return arrow::schema({arrow::field( + "val", arrow::fixed_size_binary(dim * sizeof(float16)))}); + } default: { - PanicInfo("unsupported vector data type"); + PanicInfo( + DataTypeInvalid, + fmt::format("unsupported vector data type {}", data_type)); } } } @@ -290,8 +338,12 @@ GetDimensionFromFileMetaData(const parquet::ColumnDescriptor* schema, case DataType::VECTOR_BINARY: { return schema->type_length() * 8; } + case DataType::VECTOR_FLOAT16: { + return schema->type_length() / sizeof(float16); + } default: - PanicInfo("unsupported data type"); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type {}", data_type)); } } @@ -316,7 +368,8 @@ GetDimensionFromArrowArray(std::shared_ptr data, return array->byte_width() * 8; } default: - PanicInfo("unsupported data type"); + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type {}", data_type)); } } @@ -398,6 +451,25 @@ EncodeAndUploadIndexSlice2(std::shared_ptr space, return 
std::make_pair(std::move(object_key), serialized_index_size); } +std::pair +EncodeAndUploadFieldSlice(ChunkManager* chunk_manager, + uint8_t* buf, + int64_t element_count, + FieldDataMeta field_data_meta, + const FieldMeta& field_meta, + std::string object_key) { + auto field_data = + CreateFieldData(field_meta.get_data_type(), field_meta.get_dim(), 0); + field_data->FillFieldData(buf, element_count); + auto insertData = std::make_shared(field_data); + insertData->SetFieldDataMeta(field_data_meta); + auto serialized_index_data = insertData->serialize_to_remote_file(); + auto serialized_index_size = serialized_index_data.size(); + chunk_manager->Write( + object_key, serialized_index_data.data(), serialized_index_size); + return std::make_pair(std::move(object_key), serialized_index_size); +} + // /** // * Returns the current resident set size (physical memory use) measured // * in bytes, or zero if the value cannot be determined on this OS. @@ -466,13 +538,13 @@ GetObjectData(std::shared_ptr space, [&](const std::string& file) -> std::vector { auto res = space->ScanData(); if (!res.ok()) { - PanicInfo("failed to create scan iterator"); + PanicInfo(DataFormatBroken,"failed to create scan iterator"); } auto reader = res.value(); std::vector datas; for (auto rec = reader->Next(); rec != nullptr; rec = reader->Next()) { if (!rec.ok()) { - PanicInfo("failed to read data"); + PanicInfo(DataFormatBroken, "failed to read data"); } auto data = rec.ValueUnsafe(); auto total_num_rows = data->num_rows(); @@ -609,36 +681,36 @@ CreateChunkManager(const StorageConfig& storage_config) { case ChunkManagerType::Minio: { return std::make_shared(storage_config); } - default: { - PanicInfo("unsupported"); + case ChunkManagerType::Remote: { + auto cloud_provider_type = + CloudProviderType_Map[storage_config.cloud_provider]; + switch (cloud_provider_type) { + case CloudProviderType::AWS: { + return std::make_shared(storage_config); + } + case CloudProviderType::GCP: { + return 
std::make_shared(storage_config); + } + case CloudProviderType::ALIYUN: { + return std::make_shared(storage_config); + } +#ifdef AZURE_BUILD_DIR + case CloudProviderType::AZURE: { + return std::make_shared(storage_config); + } +#endif + default: { + return std::make_shared(storage_config); + } + } } - } -} - -FileManagerImplPtr -CreateFileManager(IndexType index_type, - const FieldDataMeta& field_meta, - const IndexMeta& index_meta, - ChunkManagerPtr cm) { - if (is_in_disk_list(index_type)) { - return std::make_shared( - field_meta, index_meta, cm); - } - return std::make_shared(field_meta, index_meta, cm); -} - -FileManagerImplPtr -CreateFileManager(IndexType index_type, - const FieldDataMeta& field_meta, - const IndexMeta& index_meta, - std::shared_ptr space) { - if (is_in_disk_list(index_type)) { - return std::make_shared( - field_meta, index_meta, space); + default: { + PanicInfo(ConfigInvalid, + fmt::format("unsupported storage_config.storage_type {}", + fmt::underlying(storage_type))); + } } - - return std::make_shared(field_meta, index_meta, space); } FieldDataPtr @@ -664,14 +736,20 @@ CreateFieldData(const DataType& type, int64_t dim, int64_t total_num_rows) { total_num_rows); case DataType::JSON: return std::make_shared>(type, total_num_rows); + case DataType::ARRAY: + return std::make_shared>(type, total_num_rows); case DataType::VECTOR_FLOAT: return std::make_shared>( dim, type, total_num_rows); case DataType::VECTOR_BINARY: return std::make_shared>( dim, type, total_num_rows); + case DataType::VECTOR_FLOAT16: + return std::make_shared>( + dim, type, total_num_rows); default: - throw NotSupportedDataTypeException( + throw SegcoreError( + DataTypeInvalid, "CreateFieldData not support data type " + datatype_name(type)); } } diff --git a/internal/core/src/storage/Util.h b/internal/core/src/storage/Util.h index 81031230de939..ca2231ab846b9 100644 --- a/internal/core/src/storage/Util.h +++ b/internal/core/src/storage/Util.h @@ -103,6 +103,13 @@ 
EncodeAndUploadIndexSlice2(std::shared_ptr space, IndexMeta index_meta, FieldDataMeta field_meta, std::string object_key); +std::pair +EncodeAndUploadFieldSlice(ChunkManager* chunk_manager, + uint8_t* buf, + int64_t element_count, + FieldDataMeta field_data_meta, + const FieldMeta& field_meta, + std::string object_key); std::vector GetObjectData(ChunkManager* remote_chunk_manager, @@ -143,18 +150,6 @@ ReleaseArrowUnused(); ChunkManagerPtr CreateChunkManager(const StorageConfig& storage_config); -FileManagerImplPtr -CreateFileManager(IndexType index_type, - const FieldDataMeta& field_meta, - const IndexMeta& index_meta, - ChunkManagerPtr cm); - -FileManagerImplPtr -CreateFileManager(IndexType index_type, - const FieldDataMeta& field_meta, - const IndexMeta& index_meta, - std::shared_ptr space); - FieldDataPtr CreateFieldData(const DataType& type, int64_t dim = 1, diff --git a/internal/core/src/storage/azure-blob-storage/AzureBlobChunkManager.cpp b/internal/core/src/storage/azure-blob-storage/AzureBlobChunkManager.cpp new file mode 100644 index 0000000000000..10b6ef1da32b7 --- /dev/null +++ b/internal/core/src/storage/azure-blob-storage/AzureBlobChunkManager.cpp @@ -0,0 +1,244 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include "AzureBlobChunkManager.h" + +namespace azure { + +std::string +GetTenantId() { + return std::getenv("AZURE_TENANT_ID"); +} +std::string +GetClientId() { + return std::getenv("AZURE_CLIENT_ID"); +} +std::string +GetTokenFilePath() { + return std::getenv("AZURE_FEDERATED_TOKEN_FILE"); +} +std::string +GetConnectionString(const std::string& access_key_id, + const std::string& access_key_value, + const std::string& address) { + char const* tmp = getenv("AZURE_STORAGE_CONNECTION_STRING"); + if (tmp != NULL) { + std::string envConnectionString(tmp); + if (!envConnectionString.empty()) { + return envConnectionString; + } + } + return "DefaultEndpointsProtocol=https;AccountName=" + access_key_id + + ";AccountKey=" + access_key_value + ";EndpointSuffix=" + address; +} + +AzureBlobChunkManager::AzureBlobChunkManager( + const std::string& access_key_id, + const std::string& access_key_value, + const std::string& address, + bool useIAM) { + if (useIAM) { + auto workloadIdentityCredential = + std::make_shared( + GetTenantId(), GetClientId(), GetTokenFilePath()); + client_ = std::make_shared( + "https://" + access_key_id + ".blob." 
+ address + "/", + workloadIdentityCredential); + } else { + client_ = std::make_shared( + Azure::Storage::Blobs::BlobServiceClient:: + CreateFromConnectionString(GetConnectionString( + access_key_id, access_key_value, address))); + } +} + +AzureBlobChunkManager::~AzureBlobChunkManager() { +} + +bool +AzureBlobChunkManager::BucketExists(const std::string& bucket_name) { + std::vector buckets; + for (auto containerPage = client_->ListBlobContainers(); + containerPage.HasPage(); + containerPage.MoveToNextPage()) { + for (auto& container : containerPage.BlobContainers) { + if (container.Name == bucket_name) { + return true; + } + } + } + return false; +} + +std::vector +AzureBlobChunkManager::ListBuckets() { + std::vector buckets; + for (auto containerPage = client_->ListBlobContainers(); + containerPage.HasPage(); + containerPage.MoveToNextPage()) { + for (auto& container : containerPage.BlobContainers) { + buckets.emplace_back(container.Name); + } + } + return buckets; +} + +void +AzureBlobChunkManager::CreateBucket(const std::string& bucket_name) { + client_->GetBlobContainerClient(bucket_name).Create(); +} + +void +AzureBlobChunkManager::DeleteBucket(const std::string& bucket_name) { + client_->GetBlobContainerClient(bucket_name).Delete(); +} + +bool +AzureBlobChunkManager::ObjectExists(const std::string& bucket_name, + const std::string& object_name) { + for (auto blobPage = + client_->GetBlobContainerClient(bucket_name).ListBlobs(); + blobPage.HasPage(); + blobPage.MoveToNextPage()) { + for (auto& blob : blobPage.Blobs) { + if (blob.Name == object_name) { + return true; + } + } + } + return false; +} + +int64_t +AzureBlobChunkManager::GetObjectSize(const std::string& bucket_name, + const std::string& object_name) { + for (auto blobPage = + client_->GetBlobContainerClient(bucket_name).ListBlobs(); + blobPage.HasPage(); + blobPage.MoveToNextPage()) { + for (auto& blob : blobPage.Blobs) { + if (blob.Name == object_name) { + return blob.BlobSize; + } + } + } + 
std::stringstream err_msg; + err_msg << "object('" << bucket_name << "', '" << object_name + << "') not exists"; + throw std::runtime_error(err_msg.str()); +} + +void +AzureBlobChunkManager::DeleteObject(const std::string& bucket_name, + const std::string& object_name) { + client_->GetBlobContainerClient(bucket_name) + .GetBlockBlobClient(object_name) + .Delete(); +} + +bool +AzureBlobChunkManager::PutObjectBuffer(const std::string& bucket_name, + const std::string& object_name, + void* buf, + uint64_t size) { + std::vector str(static_cast(buf), + static_cast(buf) + size); + client_->GetBlobContainerClient(bucket_name) + .GetBlockBlobClient(object_name) + .UploadFrom(str.data(), str.size()); + return true; +} + +uint64_t +AzureBlobChunkManager::GetObjectBuffer(const std::string& bucket_name, + const std::string& object_name, + void* buf, + uint64_t size) { + Azure::Storage::Blobs::DownloadBlobOptions downloadOptions; + downloadOptions.Range = Azure::Core::Http::HttpRange(); + downloadOptions.Range.Value().Offset = 0; + downloadOptions.Range.Value().Length = size; + auto downloadResponse = client_->GetBlobContainerClient(bucket_name) + .GetBlockBlobClient(object_name) + .Download(downloadOptions); + std::vector str = + downloadResponse.Value.BodyStream->ReadToEnd(); + memcpy(static_cast(buf), &str[0], str.size() * sizeof(str[0])); + return str.size(); +} + +std::vector +AzureBlobChunkManager::ListObjects(const char* bucket_name, + const char* prefix) { + std::vector objects_vec; + for (auto blobPage = + client_->GetBlobContainerClient(bucket_name).ListBlobs(); + blobPage.HasPage(); + blobPage.MoveToNextPage()) { + for (auto& blob : blobPage.Blobs) { + if (blob.Name.rfind(prefix, 0) == 0) { + objects_vec.emplace_back(blob.Name); + } + } + } + return objects_vec; +} + +} // namespace azure + +int +main() { + const char* containerName = "default"; + const char* blobName = "sample-blob"; + using namespace azure; + AzureBlobChunkManager chunkManager = 
AzureBlobChunkManager("", "", ""); + std::vector buckets = chunkManager.ListBuckets(); + for (const auto& bucket : buckets) { + std::cout << bucket << std::endl; + } + std::vector objects = + chunkManager.ListObjects(containerName, blobName); + for (const auto& object : objects) { + std::cout << object << std::endl; + } + std::cout << chunkManager.GetObjectSize(containerName, blobName) + << std::endl; + std::cout << chunkManager.ObjectExists(containerName, blobName) + << std::endl; + std::cout << chunkManager.ObjectExists(containerName, "blobName") + << std::endl; + std::cout << chunkManager.BucketExists(containerName) << std::endl; + char buffer[1024 * 1024]; + chunkManager.GetObjectBuffer(containerName, blobName, buffer, 1024 * 1024); + std::cout << buffer << std::endl; + + char msg[12]; + memcpy(msg, "Azure hello!", 12); + if (!chunkManager.ObjectExists(containerName, "blobName")) { + chunkManager.PutObjectBuffer(containerName, "blobName", msg, 12); + } + char buffer0[1024 * 1024]; + chunkManager.GetObjectBuffer( + containerName, "blobName", buffer0, 1024 * 1024); + std::cout << buffer0 << std::endl; + chunkManager.DeleteObject(containerName, "blobName"); + chunkManager.CreateBucket("sample-container1"); + chunkManager.DeleteBucket("sample-container1"); + exit(EXIT_SUCCESS); +} \ No newline at end of file diff --git a/internal/core/src/storage/azure-blob-storage/AzureBlobChunkManager.h b/internal/core/src/storage/azure-blob-storage/AzureBlobChunkManager.h new file mode 100644 index 0000000000000..3ff19ad9f0197 --- /dev/null +++ b/internal/core/src/storage/azure-blob-storage/AzureBlobChunkManager.h @@ -0,0 +1,78 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// export CPLUS_INCLUDE_PATH=/opt/homebrew/Cellar/boost/1.81.0_1/include/ + +#pragma once + +#include +#include +#include +#include +#include +#include "azure/storage/common/storage_exception.hpp" + +namespace azure { +/** + * @brief This AzureBlobChunkManager is responsible for read and write file in blob. + */ +class AzureBlobChunkManager { + public: + explicit AzureBlobChunkManager(const std::string& access_key_id, + const std::string& access_key_value, + const std::string& address, + bool useIAM = false); + + AzureBlobChunkManager(const AzureBlobChunkManager&); + AzureBlobChunkManager& + operator=(const AzureBlobChunkManager&); + + public: + virtual ~AzureBlobChunkManager(); + + bool + BucketExists(const std::string& bucket_name); + void + CreateBucket(const std::string& bucket_name); + void + DeleteBucket(const std::string& bucket_name); + std::vector + ListBuckets(); + bool + ObjectExists(const std::string& bucket_name, + const std::string& object_name); + int64_t + GetObjectSize(const std::string& bucket_name, + const std::string& object_name); + void + DeleteObject(const std::string& bucket_name, + const std::string& object_name); + bool + PutObjectBuffer(const std::string& bucket_name, + const std::string& object_name, + void* buf, + uint64_t size); + uint64_t + GetObjectBuffer(const std::string& bucket_name, + const std::string& object_name, + void* buf, + uint64_t size); + std::vector + ListObjects(const char* 
bucket_name, const char* prefix = nullptr); + + private: + std::shared_ptr client_; +}; + +} // namespace azure diff --git a/internal/core/src/storage/azure-blob-storage/CMakeLists.txt b/internal/core/src/storage/azure-blob-storage/CMakeLists.txt new file mode 100644 index 0000000000000..91c2cc34719f6 --- /dev/null +++ b/internal/core/src/storage/azure-blob-storage/CMakeLists.txt @@ -0,0 +1,29 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# SPDX-License-Identifier: MIT + +cmake_minimum_required (VERSION 3.12) +set(CMAKE_CXX_STANDARD 17) +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake-modules") +message("${CMAKE_CURRENT_SOURCE_DIR}") +include(AzureVcpkg) +az_vcpkg_integrate() + +project(azure-blob-storage) + +find_program(NUGET_EXE NAMES nuget) + +if(NOT NUGET_EXE) + message(FATAL "CMake could not find the nuget command line tool. Please install it from https://www.nuget.org/downloads!") +else() + exec_program(${NUGET_EXE} + ARGS install "Microsoft.Attestation.Client" -Version 0.1.181 -ExcludeVersion -OutputDirectory ${CMAKE_BINARY_DIR}/packages) +endif() + +find_package(azure-storage-blobs-cpp CONFIG REQUIRED) +find_package(azure-identity-cpp CONFIG REQUIRED) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter -Wno-return-type -Wno-pedantic") +add_library(blob-chunk-manager SHARED AzureBlobChunkManager.cpp) +target_link_libraries(blob-chunk-manager PRIVATE Azure::azure-identity Azure::azure-storage-blobs) + +install(TARGETS blob-chunk-manager DESTINATION "${CMAKE_INSTALL_LIBDIR}") + diff --git a/internal/core/src/storage/azure-blob-storage/cmake-modules/AzureVcpkg.cmake b/internal/core/src/storage/azure-blob-storage/cmake-modules/AzureVcpkg.cmake new file mode 100644 index 0000000000000..c49a433e59a45 --- /dev/null +++ b/internal/core/src/storage/azure-blob-storage/cmake-modules/AzureVcpkg.cmake @@ -0,0 +1,169 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +# We need to know an absolute path to our repo root to do things like referencing ./LICENSE.txt file. +set(AZ_ROOT_DIR "${CMAKE_CURRENT_LIST_DIR}/..") + +macro(az_vcpkg_integrate) + message("Vcpkg integrate step.") + # AUTO CMAKE_TOOLCHAIN_FILE: + # User can call `cmake -DCMAKE_TOOLCHAIN_FILE="path_to_the_toolchain"` as the most specific scenario. + # As the last alternative (default case), Azure SDK will automatically clone VCPKG folder and set toolchain from there. + if(NOT DEFINED CMAKE_TOOLCHAIN_FILE) + message("CMAKE_TOOLCHAIN_FILE is not defined. Define it for the user.") + # Set AZURE_SDK_DISABLE_AUTO_VCPKG env var to avoid Azure SDK from cloning and setting VCPKG automatically + # This option delegate package's dependencies installation to user. + if(NOT DEFINED ENV{AZURE_SDK_DISABLE_AUTO_VCPKG}) + message("AZURE_SDK_DISABLE_AUTO_VCPKG is not defined. Fetch a local copy of vcpkg.") + # GET VCPKG FROM SOURCE + # User can set env var AZURE_SDK_VCPKG_COMMIT to pick the VCPKG commit to fetch + set(VCPKG_COMMIT_STRING 71d875654e32ee216b0b7e0dc684e589dffa1b1c) # default SDK tested commit + if(DEFINED ENV{AZURE_SDK_VCPKG_COMMIT}) + message("AZURE_SDK_VCPKG_COMMIT is defined. Using that instead of the default.") + set(VCPKG_COMMIT_STRING "$ENV{AZURE_SDK_VCPKG_COMMIT}") # default SDK tested commit + endif() + message("Vcpkg commit string used: ${VCPKG_COMMIT_STRING}") + include(FetchContent) + FetchContent_Declare( + vcpkg + GIT_REPOSITORY https://github.com/milvus-io/vcpkg.git + GIT_TAG ${VCPKG_COMMIT_STRING} + ) + FetchContent_GetProperties(vcpkg) + # make sure to pull vcpkg only once. 
+ if(NOT vcpkg_POPULATED) + FetchContent_Populate(vcpkg) + endif() + # use the vcpkg source path + set(CMAKE_TOOLCHAIN_FILE "${vcpkg_SOURCE_DIR}/scripts/buildsystems/vcpkg.cmake" CACHE STRING "") + endif() + endif() + + # enable triplet customization + if(DEFINED ENV{VCPKG_DEFAULT_TRIPLET} AND NOT DEFINED VCPKG_TARGET_TRIPLET) + set(VCPKG_TARGET_TRIPLET "$ENV{VCPKG_DEFAULT_TRIPLET}" CACHE STRING "") + endif() +endmacro() + +macro(az_vcpkg_portfile_prep targetName fileName contentToRemove) + # with sdk//vcpkg/ + file(READ "${CMAKE_CURRENT_SOURCE_DIR}/vcpkg/${fileName}" fileContents) + + # Windows -> Unix line endings + string(FIND fileContents "\r\n" crLfPos) + + if (crLfPos GREATER -1) + string(REPLACE "\r\n" "\n" fileContents ${fileContents}) + endif() + + # remove comment header + string(REPLACE "${contentToRemove}" "" fileContents ${fileContents}) + + # undo Windows -> Unix line endings (if applicable) + if (crLfPos GREATER -1) + string(REPLACE "\n" "\r\n" fileContents ${fileContents}) + endif() + unset(crLfPos) + + # output to an intermediate location + file (WRITE "${CMAKE_BINARY_DIR}/vcpkg_prep/${targetName}/${fileName}" ${fileContents}) + unset(fileContents) + + # Produce the files to help with the vcpkg release. + # Go to the /out/build//vcpkg directory, and copy (merge) "ports" folder to the vcpkg repo. + # Then, update the portfile.cmake file SHA512 from "1" to the actual hash (a good way to do it is to uninstall a package, + # clean vcpkg/downloads, vcpkg/buildtrees, run "vcpkg install ", and get the SHA from the error message). 
+ configure_file( + "${CMAKE_BINARY_DIR}/vcpkg_prep/${targetName}/${fileName}" + "${CMAKE_BINARY_DIR}/vcpkg/ports/${targetName}-cpp/${fileName}" + @ONLY + ) +endmacro() + +macro(az_vcpkg_export targetName macroNamePart dllImportExportHeaderPath) + foreach(vcpkgFile "vcpkg.json" "portfile.cmake") + az_vcpkg_portfile_prep( + "${targetName}" + "${vcpkgFile}" + "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n" + ) + endforeach() + + # Standard names for folders such as "bin", "lib", "include". We could hardcode, but some other libs use it too (curl). + include(GNUInstallDirs) + + # When installing, copy our "inc" directory (headers) to "include" directory at the install location. + install(DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/inc/azure/" DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/azure") + + # Copy license as "copyright" (vcpkg dictates naming and location). + install(FILES "${AZ_ROOT_DIR}/LICENSE.txt" DESTINATION "${CMAKE_INSTALL_DATAROOTDIR}/${targetName}-cpp" RENAME "copyright") + + # Indicate where to install targets. Mirrors what other ports do. 
+ install( + TARGETS "${targetName}" + EXPORT "${targetName}-cppTargets" + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} # DLLs (if produced by build) go to "/bin" + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} # static .lib files + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} # .lib files for DLL build + INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} # headers + ) + + # If building a Windows DLL, patch the dll_import_export.hpp + if(WIN32 AND BUILD_SHARED_LIBS) + add_compile_definitions(AZ_${macroNamePart}_BEING_BUILT) + target_compile_definitions(${targetName} PUBLIC AZ_${macroNamePart}_DLL) + + set(AZ_${macroNamePart}_DLL_INSTALLED_AS_PACKAGE "*/ + 1 /*") + configure_file( + "${CMAKE_CURRENT_SOURCE_DIR}/inc/${dllImportExportHeaderPath}" + "${CMAKE_BINARY_DIR}/${CMAKE_INSTALL_INCLUDEDIR}/${dllImportExportHeaderPath}" + @ONLY + ) + unset(AZ_${macroNamePart}_DLL_INSTALLED_AS_PACKAGE) + + get_filename_component(dllImportExportHeaderDir ${dllImportExportHeaderPath} DIRECTORY) + install( + FILES "${CMAKE_BINARY_DIR}/${CMAKE_INSTALL_INCLUDEDIR}/${dllImportExportHeaderPath}" + DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/${dllImportExportHeaderDir}" + ) + unset(dllImportExportHeaderDir) + endif() + + # Export the targets file itself. + install( + EXPORT "${targetName}-cppTargets" + DESTINATION "${CMAKE_INSTALL_DATAROOTDIR}/${targetName}-cpp" + NAMESPACE Azure:: # Not the C++ namespace, but a namespace in terms of cmake. + FILE "${targetName}-cppTargets.cmake" + ) + + # configure_package_config_file(), write_basic_package_version_file() + include(CMakePackageConfigHelpers) + + # Produce package config file. + configure_package_config_file( + "${CMAKE_CURRENT_SOURCE_DIR}/vcpkg/Config.cmake.in" + "${targetName}-cppConfig.cmake" + INSTALL_DESTINATION "${CMAKE_INSTALL_DATAROOTDIR}/${targetName}-cpp" + PATH_VARS + CMAKE_INSTALL_LIBDIR) + + # Produce version file. 
+ write_basic_package_version_file( + "${targetName}-cppConfigVersion.cmake" + VERSION ${AZ_LIBRARY_VERSION} # the version that we extracted from package_version.hpp + COMPATIBILITY SameMajorVersion + ) + + # Install package config and version files. + install( + FILES + "${CMAKE_CURRENT_BINARY_DIR}/${targetName}-cppConfig.cmake" + "${CMAKE_CURRENT_BINARY_DIR}/${targetName}-cppConfigVersion.cmake" + DESTINATION + "${CMAKE_INSTALL_DATAROOTDIR}/${targetName}-cpp" # to shares/ + ) + + # Export all the installs above as package. + export(PACKAGE "${targetName}-cpp") +endmacro() diff --git a/internal/core/src/storage/azure-blob-storage/vcpkg.json b/internal/core/src/storage/azure-blob-storage/vcpkg.json new file mode 100644 index 0000000000000..ac0d797d5d9e5 --- /dev/null +++ b/internal/core/src/storage/azure-blob-storage/vcpkg.json @@ -0,0 +1,8 @@ +{ + "name": "azure-blob-storage", + "version-string": "1.0.0", + "dependencies": [ + "azure-identity-cpp", + "azure-storage-blobs-cpp" + ] +} diff --git a/internal/core/src/storage/parquet_c.cpp b/internal/core/src/storage/parquet_c.cpp index f4935f2ad98f2..caa7ca50575e7 100644 --- a/internal/core/src/storage/parquet_c.cpp +++ b/internal/core/src/storage/parquet_c.cpp @@ -16,11 +16,11 @@ #include +#include "common/EasyAssert.h" #include "storage/parquet_c.h" #include "storage/PayloadReader.h" #include "storage/PayloadWriter.h" #include "storage/FieldData.h" -#include "common/CGoHelper.h" #include "storage/Util.h" using Payload = milvus::storage::Payload; @@ -50,7 +50,7 @@ AddValuesToPayload(CPayloadWriter payloadWriter, const Payload& info) { p->add_payload(info); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } @@ -117,7 +117,7 @@ AddOneStringToPayload(CPayloadWriter payloadWriter, char* cstr, int str_size) { p->add_one_string_payload(cstr, str_size); return milvus::SuccessCStatus(); } catch (std::exception& e) 
{ - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } @@ -128,7 +128,7 @@ AddOneArrayToPayload(CPayloadWriter payloadWriter, uint8_t* data, int length) { p->add_one_binary_payload(data, length); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } @@ -139,7 +139,7 @@ AddOneJSONToPayload(CPayloadWriter payloadWriter, uint8_t* data, int length) { p->add_one_binary_payload(data, length); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } @@ -155,7 +155,7 @@ AddBinaryVectorToPayload(CPayloadWriter payloadWriter, p->add_payload(raw_data_info); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } @@ -173,7 +173,7 @@ AddFloatVectorToPayload(CPayloadWriter payloadWriter, p->add_payload(raw_data_info); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } @@ -184,7 +184,7 @@ FinishPayloadWriter(CPayloadWriter payloadWriter) { p->finish(); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } @@ -240,7 +240,7 @@ NewPayloadReader(int columnType, break; } default: { - return milvus::FailureCStatus(UnexpectedError, + return milvus::FailureCStatus(milvus::DataTypeInvalid, "unsupported data type"); } } @@ -250,7 +250,7 @@ NewPayloadReader(int columnType, *c_reader = (CPayloadReader)(p.release()); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } @@ -262,7 +262,7 @@ 
GetBoolFromPayload(CPayloadReader payloadReader, int idx, bool* value) { *value = *reinterpret_cast(field_data->RawValue(idx)); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } @@ -276,7 +276,7 @@ GetInt8FromPayload(CPayloadReader payloadReader, int8_t** values, int* length) { reinterpret_cast(const_cast(field_data->Data())); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } @@ -292,7 +292,7 @@ GetInt16FromPayload(CPayloadReader payloadReader, reinterpret_cast(const_cast(field_data->Data())); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } @@ -308,7 +308,7 @@ GetInt32FromPayload(CPayloadReader payloadReader, reinterpret_cast(const_cast(field_data->Data())); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } @@ -324,7 +324,7 @@ GetInt64FromPayload(CPayloadReader payloadReader, reinterpret_cast(const_cast(field_data->Data())); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } @@ -338,7 +338,7 @@ GetFloatFromPayload(CPayloadReader payloadReader, float** values, int* length) { reinterpret_cast(const_cast(field_data->Data())); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } @@ -354,7 +354,7 @@ GetDoubleFromPayload(CPayloadReader payloadReader, reinterpret_cast(const_cast(field_data->Data())); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return 
milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } @@ -371,7 +371,7 @@ GetOneStringFromPayload(CPayloadReader payloadReader, *str_size = field_data->Size(idx); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } @@ -388,7 +388,7 @@ GetBinaryVectorFromPayload(CPayloadReader payloadReader, *length = field_data->get_num_rows(); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } @@ -405,7 +405,7 @@ GetFloatVectorFromPayload(CPayloadReader payloadReader, *length = field_data->get_num_rows(); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } @@ -427,6 +427,6 @@ ReleasePayloadReader(CPayloadReader payloadReader) { milvus::storage::ReleaseArrowUnused(); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } diff --git a/internal/core/src/storage/storage_c.cpp b/internal/core/src/storage/storage_c.cpp index 020c54159a876..fb7d20879b5bc 100644 --- a/internal/core/src/storage/storage_c.cpp +++ b/internal/core/src/storage/storage_c.cpp @@ -15,9 +15,9 @@ // limitations under the License. 
#include "storage/storage_c.h" -#include "common/CGoHelper.h" #include "storage/RemoteChunkManagerSingleton.h" #include "storage/LocalChunkManagerSingleton.h" +#include "storage/ChunkCacheSingleton.h" CStatus GetLocalUsedSize(const char* c_dir, int64_t* size) { @@ -33,7 +33,7 @@ GetLocalUsedSize(const char* c_dir, int64_t* size) { } return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } @@ -45,7 +45,7 @@ InitLocalChunkManagerSingleton(const char* c_path) { return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); } } @@ -62,6 +62,8 @@ InitRemoteChunkManagerSingleton(CStorageConfig c_storage_config) { storage_config.root_path = std::string(c_storage_config.root_path); storage_config.storage_type = std::string(c_storage_config.storage_type); + storage_config.cloud_provider = + std::string(c_storage_config.cloud_provider); storage_config.iam_endpoint = std::string(c_storage_config.iam_endpoint); storage_config.log_level = std::string(c_storage_config.log_level); @@ -69,12 +71,24 @@ InitRemoteChunkManagerSingleton(CStorageConfig c_storage_config) { storage_config.useIAM = c_storage_config.useIAM; storage_config.useVirtualHost = c_storage_config.useVirtualHost; storage_config.region = c_storage_config.region; + storage_config.requestTimeoutMs = c_storage_config.requestTimeoutMs; milvus::storage::RemoteChunkManagerSingleton::GetInstance().Init( storage_config); return milvus::SuccessCStatus(); } catch (std::exception& e) { - return milvus::FailureCStatus(UnexpectedError, e.what()); + return milvus::FailureCStatus(&e); + } +} + +CStatus +InitChunkCacheSingleton(const char* c_dir_path, const char* read_ahead_policy) { + try { + milvus::storage::ChunkCacheSingleton::GetInstance().Init( + c_dir_path, read_ahead_policy); + return milvus::SuccessCStatus(); + } catch 
(std::exception& e) { + return milvus::FailureCStatus(&e); } } diff --git a/internal/core/src/storage/storage_c.h b/internal/core/src/storage/storage_c.h index 8418694ddfbf7..a10b38c3c21a7 100644 --- a/internal/core/src/storage/storage_c.h +++ b/internal/core/src/storage/storage_c.h @@ -30,6 +30,9 @@ InitLocalChunkManagerSingleton(const char* path); CStatus InitRemoteChunkManagerSingleton(CStorageConfig c_storage_config); +CStatus +InitChunkCacheSingleton(const char* c_dir_path, const char* read_ahead_policy); + void CleanRemoteChunkManagerSingleton(); diff --git a/internal/core/src/utils/CMakeLists.txt b/internal/core/src/utils/CMakeLists.txt deleted file mode 100644 index 83934e00f0692..0000000000000 --- a/internal/core/src/utils/CMakeLists.txt +++ /dev/null @@ -1,17 +0,0 @@ -#------------------------------------------------------------------------------- -# Copyright (C) 2019-2020 Zilliz. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software distributed under the License -# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing permissions and limitations under the License. -#------------------------------------------------------------------------------- - -aux_source_directory( ${MILVUS_ENGINE_SRC}/utils UTILS_FILES ) - -add_library( milvus_utils STATIC ${UTILS_FILES} ) - diff --git a/internal/core/src/utils/Error.h b/internal/core/src/utils/Error.h deleted file mode 100644 index 229b1edbede06..0000000000000 --- a/internal/core/src/utils/Error.h +++ /dev/null @@ -1,146 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. - -#pragma once - -#include -#include -#include - -namespace milvus { - -using ErrorCode = int32_t; - -constexpr ErrorCode SERVER_SUCCESS = 0; -constexpr ErrorCode SERVER_ERROR_CODE_BASE = 30000; - -constexpr ErrorCode -ToServerErrorCode(const ErrorCode error_code) { - return SERVER_ERROR_CODE_BASE + error_code; -} - -constexpr ErrorCode DB_SUCCESS = 0; -constexpr ErrorCode DB_ERROR_CODE_BASE = 40000; - -constexpr ErrorCode -ToDbErrorCode(const ErrorCode error_code) { - return DB_ERROR_CODE_BASE + error_code; -} - -constexpr ErrorCode KNOWHERE_SUCCESS = 0; -constexpr ErrorCode KNOWHERE_ERROR_CODE_BASE = 50000; - -constexpr ErrorCode -ToKnowhereErrorCode(const ErrorCode error_code) { - return KNOWHERE_ERROR_CODE_BASE + error_code; -} - -constexpr ErrorCode WAL_SUCCESS = 0; -constexpr ErrorCode WAL_ERROR_CODE_BASE = 60000; - -constexpr ErrorCode -ToWalErrorCode(const ErrorCode error_code) { - return WAL_ERROR_CODE_BASE + error_code; -} - -constexpr ErrorCode SS_SUCCESS = 0; -constexpr ErrorCode SS_ERROR_CODE_BASE = 70000; - -constexpr ErrorCode -ToSSErrorCode(const ErrorCode error_code) { - return SS_ERROR_CODE_BASE + error_code; -} - -// server error code -constexpr ErrorCode SERVER_UNEXPECTED_ERROR = ToServerErrorCode(1); -constexpr ErrorCode SERVER_UNSUPPORTED_ERROR = ToServerErrorCode(2); -constexpr ErrorCode SERVER_NULL_POINTER = ToServerErrorCode(3); -constexpr ErrorCode SERVER_INVALID_ARGUMENT = ToServerErrorCode(4); 
-constexpr ErrorCode SERVER_FILE_NOT_FOUND = ToServerErrorCode(5); -constexpr ErrorCode SERVER_NOT_IMPLEMENT = ToServerErrorCode(6); -constexpr ErrorCode SERVER_CANNOT_CREATE_FOLDER = ToServerErrorCode(8); -constexpr ErrorCode SERVER_CANNOT_CREATE_FILE = ToServerErrorCode(9); -constexpr ErrorCode SERVER_CANNOT_DELETE_FOLDER = ToServerErrorCode(10); -constexpr ErrorCode SERVER_CANNOT_DELETE_FILE = ToServerErrorCode(11); -constexpr ErrorCode SERVER_BUILD_INDEX_ERROR = ToServerErrorCode(12); -constexpr ErrorCode SERVER_CANNOT_OPEN_FILE = ToServerErrorCode(13); -constexpr ErrorCode SERVER_FILE_MAGIC_BYTES_ERROR = ToServerErrorCode(14); -constexpr ErrorCode SERVER_FILE_SUM_BYTES_ERROR = ToServerErrorCode(15); -constexpr ErrorCode SERVER_CANNOT_READ_FILE = ToServerErrorCode(16); - -constexpr ErrorCode SERVER_COLLECTION_NOT_EXIST = ToServerErrorCode(100); -constexpr ErrorCode SERVER_INVALID_COLLECTION_NAME = ToServerErrorCode(101); -constexpr ErrorCode SERVER_INVALID_COLLECTION_DIMENSION = - ToServerErrorCode(102); -constexpr ErrorCode SERVER_INVALID_VECTOR_DIMENSION = ToServerErrorCode(104); -constexpr ErrorCode SERVER_INVALID_INDEX_TYPE = ToServerErrorCode(105); -constexpr ErrorCode SERVER_INVALID_ROWRECORD = ToServerErrorCode(106); -constexpr ErrorCode SERVER_INVALID_ROWRECORD_ARRAY = ToServerErrorCode(107); -constexpr ErrorCode SERVER_INVALID_TOPK = ToServerErrorCode(108); -constexpr ErrorCode SERVER_ILLEGAL_VECTOR_ID = ToServerErrorCode(109); -constexpr ErrorCode SERVER_ILLEGAL_SEARCH_RESULT = ToServerErrorCode(110); -constexpr ErrorCode SERVER_CACHE_FULL = ToServerErrorCode(111); -constexpr ErrorCode SERVER_WRITE_ERROR = ToServerErrorCode(112); -constexpr ErrorCode SERVER_INVALID_NPROBE = ToServerErrorCode(113); -constexpr ErrorCode SERVER_INVALID_INDEX_NLIST = ToServerErrorCode(114); -constexpr ErrorCode SERVER_INVALID_INDEX_METRIC_TYPE = ToServerErrorCode(115); -constexpr ErrorCode SERVER_INVALID_SEGMENT_ROW_COUNT = ToServerErrorCode(116); -constexpr ErrorCode 
SERVER_OUT_OF_MEMORY = ToServerErrorCode(117); -constexpr ErrorCode SERVER_INVALID_PARTITION_TAG = ToServerErrorCode(118); -constexpr ErrorCode SERVER_INVALID_BINARY_QUERY = ToServerErrorCode(119); -constexpr ErrorCode SERVER_INVALID_DSL_PARAMETER = ToServerErrorCode(120); -constexpr ErrorCode SERVER_INVALID_FIELD_NAME = ToServerErrorCode(121); -constexpr ErrorCode SERVER_INVALID_FIELD_NUM = ToServerErrorCode(122); - -// db error code -constexpr ErrorCode DB_META_TRANSACTION_FAILED = ToDbErrorCode(1); -constexpr ErrorCode DB_ERROR = ToDbErrorCode(2); -constexpr ErrorCode DB_NOT_FOUND = ToDbErrorCode(3); -constexpr ErrorCode DB_ALREADY_EXIST = ToDbErrorCode(4); -constexpr ErrorCode DB_INVALID_PATH = ToDbErrorCode(5); -constexpr ErrorCode DB_INCOMPATIB_META = ToDbErrorCode(6); -constexpr ErrorCode DB_INVALID_META_URI = ToDbErrorCode(7); -constexpr ErrorCode DB_EMPTY_COLLECTION = ToDbErrorCode(8); -constexpr ErrorCode DB_BLOOM_FILTER_ERROR = ToDbErrorCode(9); -constexpr ErrorCode DB_PARTITION_NOT_FOUND = ToDbErrorCode(10); -constexpr ErrorCode DB_OUT_OF_STORAGE = ToDbErrorCode(11); -constexpr ErrorCode DB_META_QUERY_FAILED = ToDbErrorCode(12); -constexpr ErrorCode DB_FILE_NOT_FOUND = ToDbErrorCode(13); -constexpr ErrorCode DB_PERMISSION_ERROR = ToDbErrorCode(14); - -// knowhere error code -constexpr ErrorCode KNOWHERE_ERROR = ToKnowhereErrorCode(1); -constexpr ErrorCode KNOWHERE_INVALID_ARGUMENT = ToKnowhereErrorCode(2); -constexpr ErrorCode KNOWHERE_UNEXPECTED_ERROR = ToKnowhereErrorCode(3); -constexpr ErrorCode KNOWHERE_NO_SPACE = ToKnowhereErrorCode(4); - -// knowhere error code -constexpr ErrorCode WAL_ERROR = ToWalErrorCode(1); -constexpr ErrorCode WAL_META_ERROR = ToWalErrorCode(2); -constexpr ErrorCode WAL_FILE_ERROR = ToWalErrorCode(3); -constexpr ErrorCode WAL_PATH_ERROR = ToWalErrorCode(4); - -// Snapshot error code -constexpr ErrorCode SS_ERROR = ToSSErrorCode(1); -constexpr ErrorCode SS_STALE_ERROR = ToSSErrorCode(2); -constexpr ErrorCode 
SS_NOT_FOUND_ERROR = ToSSErrorCode(3); -constexpr ErrorCode SS_INVALID_CONTEXT_ERROR = ToSSErrorCode(4); -constexpr ErrorCode SS_DUPLICATED_ERROR = ToSSErrorCode(5); -constexpr ErrorCode SS_NOT_ACTIVE_ERROR = ToSSErrorCode(6); -constexpr ErrorCode SS_CONSTRAINT_CHECK_ERROR = ToSSErrorCode(7); -constexpr ErrorCode SS_INVALID_ARGUMENT_ERROR = ToSSErrorCode(8); -constexpr ErrorCode SS_OPERATION_PENDING = ToSSErrorCode(9); -constexpr ErrorCode SS_TIMEOUT = ToSSErrorCode(10); -constexpr ErrorCode SS_NOT_COMMITTED = ToSSErrorCode(11); -constexpr ErrorCode SS_COLLECTION_DROPPED = ToSSErrorCode(12); -constexpr ErrorCode SS_EMPTY_HOLDER = ToSSErrorCode(13); - -} // namespace milvus diff --git a/internal/core/src/utils/Json.h b/internal/core/src/utils/Json.h deleted file mode 100644 index e4056f7e38568..0000000000000 --- a/internal/core/src/utils/Json.h +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. 
- -#pragma once - -#include "nlohmann/json.hpp" - -namespace milvus { - -using json = nlohmann::json; - -#define JSON_NULL_CHECK(json) \ - do { \ - if (json.empty()) { \ - return Status{SERVER_INVALID_ARGUMENT, "Json is null"}; \ - } \ - } while (false) - -#define JSON_OBJECT_CHECK(json) \ - do { \ - if (!json.is_object()) { \ - return Status{SERVER_INVALID_ARGUMENT, \ - "Json is not a json object"}; \ - } \ - } while (false) - -} // namespace milvus diff --git a/internal/core/src/utils/Status.cpp b/internal/core/src/utils/Status.cpp deleted file mode 100644 index 332fe0d987d44..0000000000000 --- a/internal/core/src/utils/Status.cpp +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. 
- -#include "utils/Status.h" -#include "memory" - -#include - -namespace milvus { - -constexpr int CODE_WIDTH = sizeof(StatusCode); - -Status::Status(StatusCode code, const std::string_view msg) { - // 4 bytes store code - // 4 bytes store message length - // the left bytes store message string - auto length = static_cast(msg.size()); - // auto result = new char[length + sizeof(length) + CODE_WIDTH]; - state_.resize(length + sizeof(length) + CODE_WIDTH); - std::memcpy(state_.data(), &code, CODE_WIDTH); - std::memcpy(state_.data() + CODE_WIDTH, &length, sizeof(length)); - memcpy(state_.data() + sizeof(length) + CODE_WIDTH, msg.data(), length); -} - -Status::~Status() { -} - -Status::Status(const Status& s) { - CopyFrom(s); -} - -Status::Status(Status&& s) noexcept { - MoveFrom(s); -} - -Status& -Status::operator=(const Status& s) { - CopyFrom(s); - return *this; -} - -Status& -Status::operator=(Status&& s) noexcept { - MoveFrom(s); - return *this; -} - -void -Status::CopyFrom(const Status& s) { - state_.clear(); - if (s.state_.empty()) { - return; - } - - uint32_t length = 0; - memcpy(&length, s.state_.data() + CODE_WIDTH, sizeof(length)); - int buff_len = length + sizeof(length) + CODE_WIDTH; - state_.resize(buff_len); - memcpy(state_.data(), s.state_.data(), buff_len); -} - -void -Status::MoveFrom(Status& s) { - state_ = s.state_; - s.state_.clear(); -} - -std::string -Status::message() const { - if (state_.empty()) { - return "OK"; - } - - std::string msg; - uint32_t length = 0; - memcpy(&length, state_.data() + CODE_WIDTH, sizeof(length)); - if (length > 0) { - msg.append(state_.data() + sizeof(length) + CODE_WIDTH, length); - } - - return msg; -} - -std::string -Status::ToString() const { - if (state_.empty()) { - return "OK"; - } - - std::string result; - switch (code()) { - case DB_SUCCESS: - result = "OK "; - break; - case DB_ERROR: - result = "Error: "; - break; - case DB_META_TRANSACTION_FAILED: - result = "Database error: "; - break; - case DB_NOT_FOUND: 
- result = "Not found: "; - break; - case DB_ALREADY_EXIST: - result = "Already exist: "; - break; - case DB_INVALID_PATH: - result = "Invalid path: "; - break; - default: - result = "Error code(" + std::to_string(code()) + "): "; - break; - } - - result += message(); - return result; -} - -} // namespace milvus diff --git a/internal/core/src/utils/Status.h b/internal/core/src/utils/Status.h deleted file mode 100644 index 73bc395861e43..0000000000000 --- a/internal/core/src/utils/Status.h +++ /dev/null @@ -1,79 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. - -#pragma once - -#include "Error.h" - -#include - -namespace milvus { - -class Status; -#define STATUS_CHECK(func) \ - do { \ - Status s = func; \ - if (!s.ok()) { \ - return s; \ - } \ - } while (false) - -using StatusCode = ErrorCode; - -class Status { - public: - Status(StatusCode code, const std::string_view msg); - Status() = default; - virtual ~Status(); - - Status(const Status& s); - - Status(Status&& s) noexcept; - - Status& - operator=(const Status& s); - - Status& - operator=(Status&& s) noexcept; - - static Status - OK() { - return Status(); - } - - bool - ok() const { - return state_.empty() || code() == 0; - } - - StatusCode - code() const { - return (state_.empty()) ? 
0 : *(StatusCode*)(state_.data()); - } - - std::string - message() const; - - std::string - ToString() const; - - private: - inline void - CopyFrom(const Status& s); - - inline void - MoveFrom(Status& s); - - private: - std::string state_; -}; // Status - -} // namespace milvus diff --git a/internal/core/thirdparty/CMakeLists.txt b/internal/core/thirdparty/CMakeLists.txt index 2ab2ec53d7286..c0f5e93960692 100644 --- a/internal/core/thirdparty/CMakeLists.txt +++ b/internal/core/thirdparty/CMakeLists.txt @@ -26,15 +26,7 @@ include(FetchContent) set(FETCHCONTENT_BASE_DIR ${MILVUS_BINARY_DIR}/3rdparty_download) set(FETCHCONTENT_QUIET OFF) -if(CUSTOM_THIRDPARTY_DOWNLOAD_PATH) - set(THIRDPARTY_DOWNLOAD_PATH ${CUSTOM_THIRDPARTY_DOWNLOAD_PATH}) -else() - set(THIRDPARTY_DOWNLOAD_PATH ${CMAKE_BINARY_DIR}/3rdparty_download/download) -endif() -message(STATUS "Thirdparty downloaded file path: ${THIRDPARTY_DOWNLOAD_PATH}") -# ---------------------------------------------------------------------- # Find pthreads - set(THREADS_PREFER_PTHREAD_FLAG ON) find_package(Threads REQUIRED) diff --git a/internal/core/thirdparty/knowhere/CMakeLists.txt b/internal/core/thirdparty/knowhere/CMakeLists.txt index 2612db8bad26c..cc0729a8407a5 100644 --- a/internal/core/thirdparty/knowhere/CMakeLists.txt +++ b/internal/core/thirdparty/knowhere/CMakeLists.txt @@ -11,7 +11,15 @@ # or implied. See the License for the specific language governing permissions and limitations under the License. 
#------------------------------------------------------------------------------- -set( KNOWHERE_VERSION e9cf4de ) +# Update KNOWHERE_VERSION for the first occurrence +set( KNOWHERE_VERSION f4c1757 ) +set( GIT_REPOSITORY "https://github.com/zilliztech/knowhere.git") +if ( INDEX_ENGINE STREQUAL "cardinal" ) + set( KNOWHERE_VERSION main ) + set( GIT_REPOSITORY "https://github.com/zilliztech/knowhere-cloud.git") +endif() +message(STATUS "Knowhere repo: ${GIT_REPOSITORY}") +message(STATUS "Knowhere version: ${KNOWHERE_VERSION}") message(STATUS "Building knowhere-${KNOWHERE_SOURCE_VER} from source") message(STATUS ${CMAKE_BUILD_TYPE}) @@ -26,11 +34,10 @@ if ( MILVUS_GPU_VERSION STREQUAL "ON" ) set(WITH_RAFT ON CACHE BOOL "" FORCE ) endif () - set( CMAKE_PREFIX_PATH ${CONAN_BOOST_ROOT} ) FetchContent_Declare( knowhere - GIT_REPOSITORY "https://github.com/zilliztech/knowhere.git" + GIT_REPOSITORY ${GIT_REPOSITORY} GIT_TAG ${KNOWHERE_VERSION} SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/knowhere-src BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/knowhere-build diff --git a/internal/core/thirdparty/nlohmann/json.hpp b/internal/core/thirdparty/nlohmann/json.hpp deleted file mode 100644 index 08856522ca35f..0000000000000 --- a/internal/core/thirdparty/nlohmann/json.hpp +++ /dev/null @@ -1,25447 +0,0 @@ -/* - __ _____ _____ _____ - __| | __| | | | JSON for Modern C++ -| | |__ | | | | | | version 3.9.1 -|_____|_____|_____|_|___| https://github.com/nlohmann/json - -Licensed under the MIT License . -SPDX-License-Identifier: MIT -Copyright (c) 2013-2019 Niels Lohmann . 
- -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. 
-*/ - -#ifndef INCLUDE_NLOHMANN_JSON_HPP_ -#define INCLUDE_NLOHMANN_JSON_HPP_ - -#define NLOHMANN_JSON_VERSION_MAJOR 3 -#define NLOHMANN_JSON_VERSION_MINOR 9 -#define NLOHMANN_JSON_VERSION_PATCH 1 - -#include // all_of, find, for_each -#include // nullptr_t, ptrdiff_t, size_t -#include // hash, less -#include // initializer_list -#include // istream, ostream -#include // random_access_iterator_tag -#include // unique_ptr -#include // accumulate -#include // string, stoi, to_string -#include // declval, forward, move, pair, swap -#include // vector - -// #include - - -#include - -// #include - - -#include // transform -#include // array -#include // forward_list -#include // inserter, front_inserter, end -#include // map -#include // string -#include // tuple, make_tuple -#include // is_arithmetic, is_same, is_enum, underlying_type, is_convertible -#include // unordered_map -#include // pair, declval -#include // valarray - -// #include - - -#include // exception -#include // runtime_error -#include // to_string - -// #include - - -#include // size_t - -namespace nlohmann -{ -namespace detail -{ -/// struct to capture the start position of the current token -struct position_t -{ - /// the total number of characters read - std::size_t chars_read_total = 0; - /// the number of characters read in the current line - std::size_t chars_read_current_line = 0; - /// the number of lines read - std::size_t lines_read = 0; - - /// conversion to size_t to preserve SAX interface - constexpr operator size_t() const - { - return chars_read_total; - } -}; - -} // namespace detail -} // namespace nlohmann - -// #include - - -#include // pair -// #include -/* Hedley - https://nemequ.github.io/hedley - * Created by Evan Nemerson - * - * To the extent possible under law, the author(s) have dedicated all - * copyright and related and neighboring rights to this software to - * the public domain worldwide. This software is distributed without - * any warranty. - * - * For details, see . 
- * SPDX-License-Identifier: CC0-1.0 - */ - -#if !defined(JSON_HEDLEY_VERSION) || (JSON_HEDLEY_VERSION < 13) -#if defined(JSON_HEDLEY_VERSION) - #undef JSON_HEDLEY_VERSION -#endif -#define JSON_HEDLEY_VERSION 13 - -#if defined(JSON_HEDLEY_STRINGIFY_EX) - #undef JSON_HEDLEY_STRINGIFY_EX -#endif -#define JSON_HEDLEY_STRINGIFY_EX(x) #x - -#if defined(JSON_HEDLEY_STRINGIFY) - #undef JSON_HEDLEY_STRINGIFY -#endif -#define JSON_HEDLEY_STRINGIFY(x) JSON_HEDLEY_STRINGIFY_EX(x) - -#if defined(JSON_HEDLEY_CONCAT_EX) - #undef JSON_HEDLEY_CONCAT_EX -#endif -#define JSON_HEDLEY_CONCAT_EX(a,b) a##b - -#if defined(JSON_HEDLEY_CONCAT) - #undef JSON_HEDLEY_CONCAT -#endif -#define JSON_HEDLEY_CONCAT(a,b) JSON_HEDLEY_CONCAT_EX(a,b) - -#if defined(JSON_HEDLEY_CONCAT3_EX) - #undef JSON_HEDLEY_CONCAT3_EX -#endif -#define JSON_HEDLEY_CONCAT3_EX(a,b,c) a##b##c - -#if defined(JSON_HEDLEY_CONCAT3) - #undef JSON_HEDLEY_CONCAT3 -#endif -#define JSON_HEDLEY_CONCAT3(a,b,c) JSON_HEDLEY_CONCAT3_EX(a,b,c) - -#if defined(JSON_HEDLEY_VERSION_ENCODE) - #undef JSON_HEDLEY_VERSION_ENCODE -#endif -#define JSON_HEDLEY_VERSION_ENCODE(major,minor,revision) (((major) * 1000000) + ((minor) * 1000) + (revision)) - -#if defined(JSON_HEDLEY_VERSION_DECODE_MAJOR) - #undef JSON_HEDLEY_VERSION_DECODE_MAJOR -#endif -#define JSON_HEDLEY_VERSION_DECODE_MAJOR(version) ((version) / 1000000) - -#if defined(JSON_HEDLEY_VERSION_DECODE_MINOR) - #undef JSON_HEDLEY_VERSION_DECODE_MINOR -#endif -#define JSON_HEDLEY_VERSION_DECODE_MINOR(version) (((version) % 1000000) / 1000) - -#if defined(JSON_HEDLEY_VERSION_DECODE_REVISION) - #undef JSON_HEDLEY_VERSION_DECODE_REVISION -#endif -#define JSON_HEDLEY_VERSION_DECODE_REVISION(version) ((version) % 1000) - -#if defined(JSON_HEDLEY_GNUC_VERSION) - #undef JSON_HEDLEY_GNUC_VERSION -#endif -#if defined(__GNUC__) && defined(__GNUC_PATCHLEVEL__) - #define JSON_HEDLEY_GNUC_VERSION JSON_HEDLEY_VERSION_ENCODE(__GNUC__, __GNUC_MINOR__, __GNUC_PATCHLEVEL__) -#elif defined(__GNUC__) - #define 
JSON_HEDLEY_GNUC_VERSION JSON_HEDLEY_VERSION_ENCODE(__GNUC__, __GNUC_MINOR__, 0) -#endif - -#if defined(JSON_HEDLEY_GNUC_VERSION_CHECK) - #undef JSON_HEDLEY_GNUC_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_GNUC_VERSION) - #define JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_GNUC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_MSVC_VERSION) - #undef JSON_HEDLEY_MSVC_VERSION -#endif -#if defined(_MSC_FULL_VER) && (_MSC_FULL_VER >= 140000000) - #define JSON_HEDLEY_MSVC_VERSION JSON_HEDLEY_VERSION_ENCODE(_MSC_FULL_VER / 10000000, (_MSC_FULL_VER % 10000000) / 100000, (_MSC_FULL_VER % 100000) / 100) -#elif defined(_MSC_FULL_VER) - #define JSON_HEDLEY_MSVC_VERSION JSON_HEDLEY_VERSION_ENCODE(_MSC_FULL_VER / 1000000, (_MSC_FULL_VER % 1000000) / 10000, (_MSC_FULL_VER % 10000) / 10) -#elif defined(_MSC_VER) - #define JSON_HEDLEY_MSVC_VERSION JSON_HEDLEY_VERSION_ENCODE(_MSC_VER / 100, _MSC_VER % 100, 0) -#endif - -#if defined(JSON_HEDLEY_MSVC_VERSION_CHECK) - #undef JSON_HEDLEY_MSVC_VERSION_CHECK -#endif -#if !defined(_MSC_VER) - #define JSON_HEDLEY_MSVC_VERSION_CHECK(major,minor,patch) (0) -#elif defined(_MSC_VER) && (_MSC_VER >= 1400) - #define JSON_HEDLEY_MSVC_VERSION_CHECK(major,minor,patch) (_MSC_FULL_VER >= ((major * 10000000) + (minor * 100000) + (patch))) -#elif defined(_MSC_VER) && (_MSC_VER >= 1200) - #define JSON_HEDLEY_MSVC_VERSION_CHECK(major,minor,patch) (_MSC_FULL_VER >= ((major * 1000000) + (minor * 10000) + (patch))) -#else - #define JSON_HEDLEY_MSVC_VERSION_CHECK(major,minor,patch) (_MSC_VER >= ((major * 100) + (minor))) -#endif - -#if defined(JSON_HEDLEY_INTEL_VERSION) - #undef JSON_HEDLEY_INTEL_VERSION -#endif -#if defined(__INTEL_COMPILER) && defined(__INTEL_COMPILER_UPDATE) - #define JSON_HEDLEY_INTEL_VERSION JSON_HEDLEY_VERSION_ENCODE(__INTEL_COMPILER / 100, __INTEL_COMPILER % 100, __INTEL_COMPILER_UPDATE) -#elif 
defined(__INTEL_COMPILER) - #define JSON_HEDLEY_INTEL_VERSION JSON_HEDLEY_VERSION_ENCODE(__INTEL_COMPILER / 100, __INTEL_COMPILER % 100, 0) -#endif - -#if defined(JSON_HEDLEY_INTEL_VERSION_CHECK) - #undef JSON_HEDLEY_INTEL_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_INTEL_VERSION) - #define JSON_HEDLEY_INTEL_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_INTEL_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_INTEL_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_PGI_VERSION) - #undef JSON_HEDLEY_PGI_VERSION -#endif -#if defined(__PGI) && defined(__PGIC__) && defined(__PGIC_MINOR__) && defined(__PGIC_PATCHLEVEL__) - #define JSON_HEDLEY_PGI_VERSION JSON_HEDLEY_VERSION_ENCODE(__PGIC__, __PGIC_MINOR__, __PGIC_PATCHLEVEL__) -#endif - -#if defined(JSON_HEDLEY_PGI_VERSION_CHECK) - #undef JSON_HEDLEY_PGI_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_PGI_VERSION) - #define JSON_HEDLEY_PGI_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_PGI_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_PGI_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_SUNPRO_VERSION) - #undef JSON_HEDLEY_SUNPRO_VERSION -#endif -#if defined(__SUNPRO_C) && (__SUNPRO_C > 0x1000) - #define JSON_HEDLEY_SUNPRO_VERSION JSON_HEDLEY_VERSION_ENCODE((((__SUNPRO_C >> 16) & 0xf) * 10) + ((__SUNPRO_C >> 12) & 0xf), (((__SUNPRO_C >> 8) & 0xf) * 10) + ((__SUNPRO_C >> 4) & 0xf), (__SUNPRO_C & 0xf) * 10) -#elif defined(__SUNPRO_C) - #define JSON_HEDLEY_SUNPRO_VERSION JSON_HEDLEY_VERSION_ENCODE((__SUNPRO_C >> 8) & 0xf, (__SUNPRO_C >> 4) & 0xf, (__SUNPRO_C) & 0xf) -#elif defined(__SUNPRO_CC) && (__SUNPRO_CC > 0x1000) - #define JSON_HEDLEY_SUNPRO_VERSION JSON_HEDLEY_VERSION_ENCODE((((__SUNPRO_CC >> 16) & 0xf) * 10) + ((__SUNPRO_CC >> 12) & 0xf), (((__SUNPRO_CC >> 8) & 0xf) * 10) + ((__SUNPRO_CC >> 4) & 0xf), (__SUNPRO_CC & 0xf) * 10) -#elif defined(__SUNPRO_CC) - #define JSON_HEDLEY_SUNPRO_VERSION 
JSON_HEDLEY_VERSION_ENCODE((__SUNPRO_CC >> 8) & 0xf, (__SUNPRO_CC >> 4) & 0xf, (__SUNPRO_CC) & 0xf) -#endif - -#if defined(JSON_HEDLEY_SUNPRO_VERSION_CHECK) - #undef JSON_HEDLEY_SUNPRO_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_SUNPRO_VERSION) - #define JSON_HEDLEY_SUNPRO_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_SUNPRO_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_SUNPRO_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_EMSCRIPTEN_VERSION) - #undef JSON_HEDLEY_EMSCRIPTEN_VERSION -#endif -#if defined(__EMSCRIPTEN__) - #define JSON_HEDLEY_EMSCRIPTEN_VERSION JSON_HEDLEY_VERSION_ENCODE(__EMSCRIPTEN_major__, __EMSCRIPTEN_minor__, __EMSCRIPTEN_tiny__) -#endif - -#if defined(JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK) - #undef JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_EMSCRIPTEN_VERSION) - #define JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_EMSCRIPTEN_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_ARM_VERSION) - #undef JSON_HEDLEY_ARM_VERSION -#endif -#if defined(__CC_ARM) && defined(__ARMCOMPILER_VERSION) - #define JSON_HEDLEY_ARM_VERSION JSON_HEDLEY_VERSION_ENCODE(__ARMCOMPILER_VERSION / 1000000, (__ARMCOMPILER_VERSION % 1000000) / 10000, (__ARMCOMPILER_VERSION % 10000) / 100) -#elif defined(__CC_ARM) && defined(__ARMCC_VERSION) - #define JSON_HEDLEY_ARM_VERSION JSON_HEDLEY_VERSION_ENCODE(__ARMCC_VERSION / 1000000, (__ARMCC_VERSION % 1000000) / 10000, (__ARMCC_VERSION % 10000) / 100) -#endif - -#if defined(JSON_HEDLEY_ARM_VERSION_CHECK) - #undef JSON_HEDLEY_ARM_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_ARM_VERSION) - #define JSON_HEDLEY_ARM_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_ARM_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_ARM_VERSION_CHECK(major,minor,patch) (0) 
-#endif - -#if defined(JSON_HEDLEY_IBM_VERSION) - #undef JSON_HEDLEY_IBM_VERSION -#endif -#if defined(__ibmxl__) - #define JSON_HEDLEY_IBM_VERSION JSON_HEDLEY_VERSION_ENCODE(__ibmxl_version__, __ibmxl_release__, __ibmxl_modification__) -#elif defined(__xlC__) && defined(__xlC_ver__) - #define JSON_HEDLEY_IBM_VERSION JSON_HEDLEY_VERSION_ENCODE(__xlC__ >> 8, __xlC__ & 0xff, (__xlC_ver__ >> 8) & 0xff) -#elif defined(__xlC__) - #define JSON_HEDLEY_IBM_VERSION JSON_HEDLEY_VERSION_ENCODE(__xlC__ >> 8, __xlC__ & 0xff, 0) -#endif - -#if defined(JSON_HEDLEY_IBM_VERSION_CHECK) - #undef JSON_HEDLEY_IBM_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_IBM_VERSION) - #define JSON_HEDLEY_IBM_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_IBM_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_IBM_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_TI_VERSION) - #undef JSON_HEDLEY_TI_VERSION -#endif -#if \ - defined(__TI_COMPILER_VERSION__) && \ - ( \ - defined(__TMS470__) || defined(__TI_ARM__) || \ - defined(__MSP430__) || \ - defined(__TMS320C2000__) \ - ) -#if (__TI_COMPILER_VERSION__ >= 16000000) - #define JSON_HEDLEY_TI_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) -#endif -#endif - -#if defined(JSON_HEDLEY_TI_VERSION_CHECK) - #undef JSON_HEDLEY_TI_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_TI_VERSION) - #define JSON_HEDLEY_TI_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_TI_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_TI_CL2000_VERSION) - #undef JSON_HEDLEY_TI_CL2000_VERSION -#endif -#if defined(__TI_COMPILER_VERSION__) && defined(__TMS320C2000__) - #define JSON_HEDLEY_TI_CL2000_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, 
(__TI_COMPILER_VERSION__ % 1000)) -#endif - -#if defined(JSON_HEDLEY_TI_CL2000_VERSION_CHECK) - #undef JSON_HEDLEY_TI_CL2000_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_TI_CL2000_VERSION) - #define JSON_HEDLEY_TI_CL2000_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CL2000_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_TI_CL2000_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_TI_CL430_VERSION) - #undef JSON_HEDLEY_TI_CL430_VERSION -#endif -#if defined(__TI_COMPILER_VERSION__) && defined(__MSP430__) - #define JSON_HEDLEY_TI_CL430_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) -#endif - -#if defined(JSON_HEDLEY_TI_CL430_VERSION_CHECK) - #undef JSON_HEDLEY_TI_CL430_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_TI_CL430_VERSION) - #define JSON_HEDLEY_TI_CL430_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CL430_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_TI_CL430_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_TI_ARMCL_VERSION) - #undef JSON_HEDLEY_TI_ARMCL_VERSION -#endif -#if defined(__TI_COMPILER_VERSION__) && (defined(__TMS470__) || defined(__TI_ARM__)) - #define JSON_HEDLEY_TI_ARMCL_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) -#endif - -#if defined(JSON_HEDLEY_TI_ARMCL_VERSION_CHECK) - #undef JSON_HEDLEY_TI_ARMCL_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_TI_ARMCL_VERSION) - #define JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_ARMCL_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_TI_CL6X_VERSION) - #undef JSON_HEDLEY_TI_CL6X_VERSION -#endif -#if 
defined(__TI_COMPILER_VERSION__) && defined(__TMS320C6X__) - #define JSON_HEDLEY_TI_CL6X_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) -#endif - -#if defined(JSON_HEDLEY_TI_CL6X_VERSION_CHECK) - #undef JSON_HEDLEY_TI_CL6X_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_TI_CL6X_VERSION) - #define JSON_HEDLEY_TI_CL6X_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CL6X_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_TI_CL6X_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_TI_CL7X_VERSION) - #undef JSON_HEDLEY_TI_CL7X_VERSION -#endif -#if defined(__TI_COMPILER_VERSION__) && defined(__C7000__) - #define JSON_HEDLEY_TI_CL7X_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) -#endif - -#if defined(JSON_HEDLEY_TI_CL7X_VERSION_CHECK) - #undef JSON_HEDLEY_TI_CL7X_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_TI_CL7X_VERSION) - #define JSON_HEDLEY_TI_CL7X_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CL7X_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_TI_CL7X_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_TI_CLPRU_VERSION) - #undef JSON_HEDLEY_TI_CLPRU_VERSION -#endif -#if defined(__TI_COMPILER_VERSION__) && defined(__PRU__) - #define JSON_HEDLEY_TI_CLPRU_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) -#endif - -#if defined(JSON_HEDLEY_TI_CLPRU_VERSION_CHECK) - #undef JSON_HEDLEY_TI_CLPRU_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_TI_CLPRU_VERSION) - #define JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CLPRU_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define 
JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_CRAY_VERSION) - #undef JSON_HEDLEY_CRAY_VERSION -#endif -#if defined(_CRAYC) - #if defined(_RELEASE_PATCHLEVEL) - #define JSON_HEDLEY_CRAY_VERSION JSON_HEDLEY_VERSION_ENCODE(_RELEASE_MAJOR, _RELEASE_MINOR, _RELEASE_PATCHLEVEL) - #else - #define JSON_HEDLEY_CRAY_VERSION JSON_HEDLEY_VERSION_ENCODE(_RELEASE_MAJOR, _RELEASE_MINOR, 0) - #endif -#endif - -#if defined(JSON_HEDLEY_CRAY_VERSION_CHECK) - #undef JSON_HEDLEY_CRAY_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_CRAY_VERSION) - #define JSON_HEDLEY_CRAY_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_CRAY_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_CRAY_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_IAR_VERSION) - #undef JSON_HEDLEY_IAR_VERSION -#endif -#if defined(__IAR_SYSTEMS_ICC__) - #if __VER__ > 1000 - #define JSON_HEDLEY_IAR_VERSION JSON_HEDLEY_VERSION_ENCODE((__VER__ / 1000000), ((__VER__ / 1000) % 1000), (__VER__ % 1000)) - #else - #define JSON_HEDLEY_IAR_VERSION JSON_HEDLEY_VERSION_ENCODE(VER / 100, __VER__ % 100, 0) - #endif -#endif - -#if defined(JSON_HEDLEY_IAR_VERSION_CHECK) - #undef JSON_HEDLEY_IAR_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_IAR_VERSION) - #define JSON_HEDLEY_IAR_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_IAR_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_IAR_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_TINYC_VERSION) - #undef JSON_HEDLEY_TINYC_VERSION -#endif -#if defined(__TINYC__) - #define JSON_HEDLEY_TINYC_VERSION JSON_HEDLEY_VERSION_ENCODE(__TINYC__ / 1000, (__TINYC__ / 100) % 10, __TINYC__ % 100) -#endif - -#if defined(JSON_HEDLEY_TINYC_VERSION_CHECK) - #undef JSON_HEDLEY_TINYC_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_TINYC_VERSION) - #define JSON_HEDLEY_TINYC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TINYC_VERSION >= 
JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_TINYC_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_DMC_VERSION) - #undef JSON_HEDLEY_DMC_VERSION -#endif -#if defined(__DMC__) - #define JSON_HEDLEY_DMC_VERSION JSON_HEDLEY_VERSION_ENCODE(__DMC__ >> 8, (__DMC__ >> 4) & 0xf, __DMC__ & 0xf) -#endif - -#if defined(JSON_HEDLEY_DMC_VERSION_CHECK) - #undef JSON_HEDLEY_DMC_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_DMC_VERSION) - #define JSON_HEDLEY_DMC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_DMC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_DMC_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_COMPCERT_VERSION) - #undef JSON_HEDLEY_COMPCERT_VERSION -#endif -#if defined(__COMPCERT_VERSION__) - #define JSON_HEDLEY_COMPCERT_VERSION JSON_HEDLEY_VERSION_ENCODE(__COMPCERT_VERSION__ / 10000, (__COMPCERT_VERSION__ / 100) % 100, __COMPCERT_VERSION__ % 100) -#endif - -#if defined(JSON_HEDLEY_COMPCERT_VERSION_CHECK) - #undef JSON_HEDLEY_COMPCERT_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_COMPCERT_VERSION) - #define JSON_HEDLEY_COMPCERT_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_COMPCERT_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_COMPCERT_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_PELLES_VERSION) - #undef JSON_HEDLEY_PELLES_VERSION -#endif -#if defined(__POCC__) - #define JSON_HEDLEY_PELLES_VERSION JSON_HEDLEY_VERSION_ENCODE(__POCC__ / 100, __POCC__ % 100, 0) -#endif - -#if defined(JSON_HEDLEY_PELLES_VERSION_CHECK) - #undef JSON_HEDLEY_PELLES_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_PELLES_VERSION) - #define JSON_HEDLEY_PELLES_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_PELLES_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_PELLES_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_GCC_VERSION) - #undef 
JSON_HEDLEY_GCC_VERSION -#endif -#if \ - defined(JSON_HEDLEY_GNUC_VERSION) && \ - !defined(__clang__) && \ - !defined(JSON_HEDLEY_INTEL_VERSION) && \ - !defined(JSON_HEDLEY_PGI_VERSION) && \ - !defined(JSON_HEDLEY_ARM_VERSION) && \ - !defined(JSON_HEDLEY_TI_VERSION) && \ - !defined(JSON_HEDLEY_TI_ARMCL_VERSION) && \ - !defined(JSON_HEDLEY_TI_CL430_VERSION) && \ - !defined(JSON_HEDLEY_TI_CL2000_VERSION) && \ - !defined(JSON_HEDLEY_TI_CL6X_VERSION) && \ - !defined(JSON_HEDLEY_TI_CL7X_VERSION) && \ - !defined(JSON_HEDLEY_TI_CLPRU_VERSION) && \ - !defined(__COMPCERT__) - #define JSON_HEDLEY_GCC_VERSION JSON_HEDLEY_GNUC_VERSION -#endif - -#if defined(JSON_HEDLEY_GCC_VERSION_CHECK) - #undef JSON_HEDLEY_GCC_VERSION_CHECK -#endif -#if defined(JSON_HEDLEY_GCC_VERSION) - #define JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_GCC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) -#else - #define JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) (0) -#endif - -#if defined(JSON_HEDLEY_HAS_ATTRIBUTE) - #undef JSON_HEDLEY_HAS_ATTRIBUTE -#endif -#if defined(__has_attribute) - #define JSON_HEDLEY_HAS_ATTRIBUTE(attribute) __has_attribute(attribute) -#else - #define JSON_HEDLEY_HAS_ATTRIBUTE(attribute) (0) -#endif - -#if defined(JSON_HEDLEY_GNUC_HAS_ATTRIBUTE) - #undef JSON_HEDLEY_GNUC_HAS_ATTRIBUTE -#endif -#if defined(__has_attribute) - #define JSON_HEDLEY_GNUC_HAS_ATTRIBUTE(attribute,major,minor,patch) __has_attribute(attribute) -#else - #define JSON_HEDLEY_GNUC_HAS_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) -#endif - -#if defined(JSON_HEDLEY_GCC_HAS_ATTRIBUTE) - #undef JSON_HEDLEY_GCC_HAS_ATTRIBUTE -#endif -#if defined(__has_attribute) - #define JSON_HEDLEY_GCC_HAS_ATTRIBUTE(attribute,major,minor,patch) __has_attribute(attribute) -#else - #define JSON_HEDLEY_GCC_HAS_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) -#endif - -#if defined(JSON_HEDLEY_HAS_CPP_ATTRIBUTE) - 
#undef JSON_HEDLEY_HAS_CPP_ATTRIBUTE -#endif -#if \ - defined(__has_cpp_attribute) && \ - defined(__cplusplus) && \ - (!defined(JSON_HEDLEY_SUNPRO_VERSION) || JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,15,0)) - #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE(attribute) __has_cpp_attribute(attribute) -#else - #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE(attribute) (0) -#endif - -#if defined(JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS) - #undef JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS -#endif -#if !defined(__cplusplus) || !defined(__has_cpp_attribute) - #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS(ns,attribute) (0) -#elif \ - !defined(JSON_HEDLEY_PGI_VERSION) && \ - !defined(JSON_HEDLEY_IAR_VERSION) && \ - (!defined(JSON_HEDLEY_SUNPRO_VERSION) || JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,15,0)) && \ - (!defined(JSON_HEDLEY_MSVC_VERSION) || JSON_HEDLEY_MSVC_VERSION_CHECK(19,20,0)) - #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS(ns,attribute) JSON_HEDLEY_HAS_CPP_ATTRIBUTE(ns::attribute) -#else - #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS(ns,attribute) (0) -#endif - -#if defined(JSON_HEDLEY_GNUC_HAS_CPP_ATTRIBUTE) - #undef JSON_HEDLEY_GNUC_HAS_CPP_ATTRIBUTE -#endif -#if defined(__has_cpp_attribute) && defined(__cplusplus) - #define JSON_HEDLEY_GNUC_HAS_CPP_ATTRIBUTE(attribute,major,minor,patch) __has_cpp_attribute(attribute) -#else - #define JSON_HEDLEY_GNUC_HAS_CPP_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) -#endif - -#if defined(JSON_HEDLEY_GCC_HAS_CPP_ATTRIBUTE) - #undef JSON_HEDLEY_GCC_HAS_CPP_ATTRIBUTE -#endif -#if defined(__has_cpp_attribute) && defined(__cplusplus) - #define JSON_HEDLEY_GCC_HAS_CPP_ATTRIBUTE(attribute,major,minor,patch) __has_cpp_attribute(attribute) -#else - #define JSON_HEDLEY_GCC_HAS_CPP_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) -#endif - -#if defined(JSON_HEDLEY_HAS_BUILTIN) - #undef JSON_HEDLEY_HAS_BUILTIN -#endif -#if defined(__has_builtin) - #define JSON_HEDLEY_HAS_BUILTIN(builtin) __has_builtin(builtin) 
-#else - #define JSON_HEDLEY_HAS_BUILTIN(builtin) (0) -#endif - -#if defined(JSON_HEDLEY_GNUC_HAS_BUILTIN) - #undef JSON_HEDLEY_GNUC_HAS_BUILTIN -#endif -#if defined(__has_builtin) - #define JSON_HEDLEY_GNUC_HAS_BUILTIN(builtin,major,minor,patch) __has_builtin(builtin) -#else - #define JSON_HEDLEY_GNUC_HAS_BUILTIN(builtin,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) -#endif - -#if defined(JSON_HEDLEY_GCC_HAS_BUILTIN) - #undef JSON_HEDLEY_GCC_HAS_BUILTIN -#endif -#if defined(__has_builtin) - #define JSON_HEDLEY_GCC_HAS_BUILTIN(builtin,major,minor,patch) __has_builtin(builtin) -#else - #define JSON_HEDLEY_GCC_HAS_BUILTIN(builtin,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) -#endif - -#if defined(JSON_HEDLEY_HAS_FEATURE) - #undef JSON_HEDLEY_HAS_FEATURE -#endif -#if defined(__has_feature) - #define JSON_HEDLEY_HAS_FEATURE(feature) __has_feature(feature) -#else - #define JSON_HEDLEY_HAS_FEATURE(feature) (0) -#endif - -#if defined(JSON_HEDLEY_GNUC_HAS_FEATURE) - #undef JSON_HEDLEY_GNUC_HAS_FEATURE -#endif -#if defined(__has_feature) - #define JSON_HEDLEY_GNUC_HAS_FEATURE(feature,major,minor,patch) __has_feature(feature) -#else - #define JSON_HEDLEY_GNUC_HAS_FEATURE(feature,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) -#endif - -#if defined(JSON_HEDLEY_GCC_HAS_FEATURE) - #undef JSON_HEDLEY_GCC_HAS_FEATURE -#endif -#if defined(__has_feature) - #define JSON_HEDLEY_GCC_HAS_FEATURE(feature,major,minor,patch) __has_feature(feature) -#else - #define JSON_HEDLEY_GCC_HAS_FEATURE(feature,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) -#endif - -#if defined(JSON_HEDLEY_HAS_EXTENSION) - #undef JSON_HEDLEY_HAS_EXTENSION -#endif -#if defined(__has_extension) - #define JSON_HEDLEY_HAS_EXTENSION(extension) __has_extension(extension) -#else - #define JSON_HEDLEY_HAS_EXTENSION(extension) (0) -#endif - -#if defined(JSON_HEDLEY_GNUC_HAS_EXTENSION) - #undef JSON_HEDLEY_GNUC_HAS_EXTENSION -#endif -#if 
defined(__has_extension) - #define JSON_HEDLEY_GNUC_HAS_EXTENSION(extension,major,minor,patch) __has_extension(extension) -#else - #define JSON_HEDLEY_GNUC_HAS_EXTENSION(extension,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) -#endif - -#if defined(JSON_HEDLEY_GCC_HAS_EXTENSION) - #undef JSON_HEDLEY_GCC_HAS_EXTENSION -#endif -#if defined(__has_extension) - #define JSON_HEDLEY_GCC_HAS_EXTENSION(extension,major,minor,patch) __has_extension(extension) -#else - #define JSON_HEDLEY_GCC_HAS_EXTENSION(extension,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) -#endif - -#if defined(JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE) - #undef JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE -#endif -#if defined(__has_declspec_attribute) - #define JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE(attribute) __has_declspec_attribute(attribute) -#else - #define JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE(attribute) (0) -#endif - -#if defined(JSON_HEDLEY_GNUC_HAS_DECLSPEC_ATTRIBUTE) - #undef JSON_HEDLEY_GNUC_HAS_DECLSPEC_ATTRIBUTE -#endif -#if defined(__has_declspec_attribute) - #define JSON_HEDLEY_GNUC_HAS_DECLSPEC_ATTRIBUTE(attribute,major,minor,patch) __has_declspec_attribute(attribute) -#else - #define JSON_HEDLEY_GNUC_HAS_DECLSPEC_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) -#endif - -#if defined(JSON_HEDLEY_GCC_HAS_DECLSPEC_ATTRIBUTE) - #undef JSON_HEDLEY_GCC_HAS_DECLSPEC_ATTRIBUTE -#endif -#if defined(__has_declspec_attribute) - #define JSON_HEDLEY_GCC_HAS_DECLSPEC_ATTRIBUTE(attribute,major,minor,patch) __has_declspec_attribute(attribute) -#else - #define JSON_HEDLEY_GCC_HAS_DECLSPEC_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) -#endif - -#if defined(JSON_HEDLEY_HAS_WARNING) - #undef JSON_HEDLEY_HAS_WARNING -#endif -#if defined(__has_warning) - #define JSON_HEDLEY_HAS_WARNING(warning) __has_warning(warning) -#else - #define JSON_HEDLEY_HAS_WARNING(warning) (0) -#endif - -#if 
defined(JSON_HEDLEY_GNUC_HAS_WARNING) - #undef JSON_HEDLEY_GNUC_HAS_WARNING -#endif -#if defined(__has_warning) - #define JSON_HEDLEY_GNUC_HAS_WARNING(warning,major,minor,patch) __has_warning(warning) -#else - #define JSON_HEDLEY_GNUC_HAS_WARNING(warning,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) -#endif - -#if defined(JSON_HEDLEY_GCC_HAS_WARNING) - #undef JSON_HEDLEY_GCC_HAS_WARNING -#endif -#if defined(__has_warning) - #define JSON_HEDLEY_GCC_HAS_WARNING(warning,major,minor,patch) __has_warning(warning) -#else - #define JSON_HEDLEY_GCC_HAS_WARNING(warning,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) -#endif - -/* JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_ is for - HEDLEY INTERNAL USE ONLY. API subject to change without notice. */ -#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_) - #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_ -#endif -#if defined(__cplusplus) -# if JSON_HEDLEY_HAS_WARNING("-Wc++98-compat") -# if JSON_HEDLEY_HAS_WARNING("-Wc++17-extensions") -# define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(xpr) \ - JSON_HEDLEY_DIAGNOSTIC_PUSH \ - _Pragma("clang diagnostic ignored \"-Wc++98-compat\"") \ - _Pragma("clang diagnostic ignored \"-Wc++17-extensions\"") \ - xpr \ - JSON_HEDLEY_DIAGNOSTIC_POP -# else -# define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(xpr) \ - JSON_HEDLEY_DIAGNOSTIC_PUSH \ - _Pragma("clang diagnostic ignored \"-Wc++98-compat\"") \ - xpr \ - JSON_HEDLEY_DIAGNOSTIC_POP -# endif -# endif -#endif -#if !defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(x) x -#endif - -#if defined(JSON_HEDLEY_CONST_CAST) - #undef JSON_HEDLEY_CONST_CAST -#endif -#if defined(__cplusplus) -# define JSON_HEDLEY_CONST_CAST(T, expr) (const_cast(expr)) -#elif \ - JSON_HEDLEY_HAS_WARNING("-Wcast-qual") || \ - JSON_HEDLEY_GCC_VERSION_CHECK(4,6,0) || \ - JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) -# define 
JSON_HEDLEY_CONST_CAST(T, expr) (__extension__ ({ \ - JSON_HEDLEY_DIAGNOSTIC_PUSH \ - JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL \ - ((T) (expr)); \ - JSON_HEDLEY_DIAGNOSTIC_POP \ - })) -#else -# define JSON_HEDLEY_CONST_CAST(T, expr) ((T) (expr)) -#endif - -#if defined(JSON_HEDLEY_REINTERPRET_CAST) - #undef JSON_HEDLEY_REINTERPRET_CAST -#endif -#if defined(__cplusplus) - #define JSON_HEDLEY_REINTERPRET_CAST(T, expr) (reinterpret_cast(expr)) -#else - #define JSON_HEDLEY_REINTERPRET_CAST(T, expr) ((T) (expr)) -#endif - -#if defined(JSON_HEDLEY_STATIC_CAST) - #undef JSON_HEDLEY_STATIC_CAST -#endif -#if defined(__cplusplus) - #define JSON_HEDLEY_STATIC_CAST(T, expr) (static_cast(expr)) -#else - #define JSON_HEDLEY_STATIC_CAST(T, expr) ((T) (expr)) -#endif - -#if defined(JSON_HEDLEY_CPP_CAST) - #undef JSON_HEDLEY_CPP_CAST -#endif -#if defined(__cplusplus) -# if JSON_HEDLEY_HAS_WARNING("-Wold-style-cast") -# define JSON_HEDLEY_CPP_CAST(T, expr) \ - JSON_HEDLEY_DIAGNOSTIC_PUSH \ - _Pragma("clang diagnostic ignored \"-Wold-style-cast\"") \ - ((T) (expr)) \ - JSON_HEDLEY_DIAGNOSTIC_POP -# elif JSON_HEDLEY_IAR_VERSION_CHECK(8,3,0) -# define JSON_HEDLEY_CPP_CAST(T, expr) \ - JSON_HEDLEY_DIAGNOSTIC_PUSH \ - _Pragma("diag_suppress=Pe137") \ - JSON_HEDLEY_DIAGNOSTIC_POP \ -# else -# define JSON_HEDLEY_CPP_CAST(T, expr) ((T) (expr)) -# endif -#else -# define JSON_HEDLEY_CPP_CAST(T, expr) (expr) -#endif - -#if \ - (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L)) || \ - defined(__clang__) || \ - JSON_HEDLEY_GCC_VERSION_CHECK(3,0,0) || \ - JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ - JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) || \ - JSON_HEDLEY_PGI_VERSION_CHECK(18,4,0) || \ - JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ - JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ - JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,7,0) || \ - JSON_HEDLEY_TI_CL430_VERSION_CHECK(2,0,1) || \ - JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,1,0) || \ - JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,0,0) || \ - 
JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ - JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ - JSON_HEDLEY_CRAY_VERSION_CHECK(5,0,0) || \ - JSON_HEDLEY_TINYC_VERSION_CHECK(0,9,17) || \ - JSON_HEDLEY_SUNPRO_VERSION_CHECK(8,0,0) || \ - (JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) && defined(__C99_PRAGMA_OPERATOR)) - #define JSON_HEDLEY_PRAGMA(value) _Pragma(#value) -#elif JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) - #define JSON_HEDLEY_PRAGMA(value) __pragma(value) -#else - #define JSON_HEDLEY_PRAGMA(value) -#endif - -#if defined(JSON_HEDLEY_DIAGNOSTIC_PUSH) - #undef JSON_HEDLEY_DIAGNOSTIC_PUSH -#endif -#if defined(JSON_HEDLEY_DIAGNOSTIC_POP) - #undef JSON_HEDLEY_DIAGNOSTIC_POP -#endif -#if defined(__clang__) - #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("clang diagnostic push") - #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("clang diagnostic pop") -#elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) - #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("warning(push)") - #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("warning(pop)") -#elif JSON_HEDLEY_GCC_VERSION_CHECK(4,6,0) - #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("GCC diagnostic push") - #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("GCC diagnostic pop") -#elif JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) - #define JSON_HEDLEY_DIAGNOSTIC_PUSH __pragma(warning(push)) - #define JSON_HEDLEY_DIAGNOSTIC_POP __pragma(warning(pop)) -#elif JSON_HEDLEY_ARM_VERSION_CHECK(5,6,0) - #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("push") - #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("pop") -#elif \ - JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ - JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ - JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,4,0) || \ - JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,1,0) || \ - JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ - JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) - #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("diag_push") - #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("diag_pop") -#elif JSON_HEDLEY_PELLES_VERSION_CHECK(2,90,0) - #define 
JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("warning(push)") - #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("warning(pop)") -#else - #define JSON_HEDLEY_DIAGNOSTIC_PUSH - #define JSON_HEDLEY_DIAGNOSTIC_POP -#endif - -#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED) - #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED -#endif -#if JSON_HEDLEY_HAS_WARNING("-Wdeprecated-declarations") - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("clang diagnostic ignored \"-Wdeprecated-declarations\"") -#elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("warning(disable:1478 1786)") -#elif JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress 1215,1444") -#elif JSON_HEDLEY_GCC_VERSION_CHECK(4,3,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("GCC diagnostic ignored \"-Wdeprecated-declarations\"") -#elif JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED __pragma(warning(disable:4996)) -#elif \ - JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ - (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ - (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ - (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ - (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ - JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ - JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress 1291,1718") -#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,13,0) && !defined(__cplusplus) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED 
_Pragma("error_messages(off,E_DEPRECATED_ATT,E_DEPRECATED_ATT_MESS)") -#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,13,0) && defined(__cplusplus) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("error_messages(off,symdeprecated,symdeprecated2)") -#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress=Pe1444,Pe1215") -#elif JSON_HEDLEY_PELLES_VERSION_CHECK(2,90,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("warn(disable:2241)") -#else - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED -#endif - -#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS) - #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS -#endif -#if JSON_HEDLEY_HAS_WARNING("-Wunknown-pragmas") - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("clang diagnostic ignored \"-Wunknown-pragmas\"") -#elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("warning(disable:161)") -#elif JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress 1675") -#elif JSON_HEDLEY_GCC_VERSION_CHECK(4,3,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("GCC diagnostic ignored \"-Wunknown-pragmas\"") -#elif JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS __pragma(warning(disable:4068)) -#elif \ - JSON_HEDLEY_TI_VERSION_CHECK(16,9,0) || \ - JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,0,0) || \ - JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ - JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,3,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress 163") -#elif JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,0,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress 163") -#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress=Pe161") -#else - 
#define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS -#endif - -#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES) - #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES -#endif -#if JSON_HEDLEY_HAS_WARNING("-Wunknown-attributes") - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("clang diagnostic ignored \"-Wunknown-attributes\"") -#elif JSON_HEDLEY_GCC_VERSION_CHECK(4,6,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("GCC diagnostic ignored \"-Wdeprecated-declarations\"") -#elif JSON_HEDLEY_INTEL_VERSION_CHECK(17,0,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("warning(disable:1292)") -#elif JSON_HEDLEY_MSVC_VERSION_CHECK(19,0,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES __pragma(warning(disable:5030)) -#elif JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress 1097") -#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,14,0) && defined(__cplusplus) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("error_messages(off,attrskipunsup)") -#elif \ - JSON_HEDLEY_TI_VERSION_CHECK(18,1,0) || \ - JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,3,0) || \ - JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress 1173") -#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress=Pe1097") -#else - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES -#endif - -#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL) - #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL -#endif -#if JSON_HEDLEY_HAS_WARNING("-Wcast-qual") - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL _Pragma("clang diagnostic ignored \"-Wcast-qual\"") -#elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL 
_Pragma("warning(disable:2203 2331)") -#elif JSON_HEDLEY_GCC_VERSION_CHECK(3,0,0) - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL _Pragma("GCC diagnostic ignored \"-Wcast-qual\"") -#else - #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL -#endif - -#if defined(JSON_HEDLEY_DEPRECATED) - #undef JSON_HEDLEY_DEPRECATED -#endif -#if defined(JSON_HEDLEY_DEPRECATED_FOR) - #undef JSON_HEDLEY_DEPRECATED_FOR -#endif -#if JSON_HEDLEY_MSVC_VERSION_CHECK(14,0,0) - #define JSON_HEDLEY_DEPRECATED(since) __declspec(deprecated("Since " # since)) - #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) __declspec(deprecated("Since " #since "; use " #replacement)) -#elif defined(__cplusplus) && (__cplusplus >= 201402L) - #define JSON_HEDLEY_DEPRECATED(since) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[deprecated("Since " #since)]]) - #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[deprecated("Since " #since "; use " #replacement)]]) -#elif \ - JSON_HEDLEY_HAS_EXTENSION(attribute_deprecated_with_message) || \ - JSON_HEDLEY_GCC_VERSION_CHECK(4,5,0) || \ - JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ - JSON_HEDLEY_ARM_VERSION_CHECK(5,6,0) || \ - JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,13,0) || \ - JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) || \ - JSON_HEDLEY_TI_VERSION_CHECK(18,1,0) || \ - JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(18,1,0) || \ - JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,3,0) || \ - JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ - JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,3,0) - #define JSON_HEDLEY_DEPRECATED(since) __attribute__((__deprecated__("Since " #since))) - #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) __attribute__((__deprecated__("Since " #since "; use " #replacement))) -#elif \ - JSON_HEDLEY_HAS_ATTRIBUTE(deprecated) || \ - JSON_HEDLEY_GCC_VERSION_CHECK(3,1,0) || \ - JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ - JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ - (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && 
defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ - (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ - (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ - (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ - JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ - JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) - #define JSON_HEDLEY_DEPRECATED(since) __attribute__((__deprecated__)) - #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) __attribute__((__deprecated__)) -#elif \ - JSON_HEDLEY_MSVC_VERSION_CHECK(13,10,0) || \ - JSON_HEDLEY_PELLES_VERSION_CHECK(6,50,0) - #define JSON_HEDLEY_DEPRECATED(since) __declspec(deprecated) - #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) __declspec(deprecated) -#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) - #define JSON_HEDLEY_DEPRECATED(since) _Pragma("deprecated") - #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) _Pragma("deprecated") -#else - #define JSON_HEDLEY_DEPRECATED(since) - #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) -#endif - -#if defined(JSON_HEDLEY_UNAVAILABLE) - #undef JSON_HEDLEY_UNAVAILABLE -#endif -#if \ - JSON_HEDLEY_HAS_ATTRIBUTE(warning) || \ - JSON_HEDLEY_GCC_VERSION_CHECK(4,3,0) || \ - JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) - #define JSON_HEDLEY_UNAVAILABLE(available_since) __attribute__((__warning__("Not available until " #available_since))) -#else - #define JSON_HEDLEY_UNAVAILABLE(available_since) -#endif - -#if defined(JSON_HEDLEY_WARN_UNUSED_RESULT) - #undef JSON_HEDLEY_WARN_UNUSED_RESULT -#endif -#if defined(JSON_HEDLEY_WARN_UNUSED_RESULT_MSG) - #undef JSON_HEDLEY_WARN_UNUSED_RESULT_MSG -#endif -#if (JSON_HEDLEY_HAS_CPP_ATTRIBUTE(nodiscard) >= 201907L) - #define 
JSON_HEDLEY_WARN_UNUSED_RESULT JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[nodiscard]]) - #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[nodiscard(msg)]]) -#elif JSON_HEDLEY_HAS_CPP_ATTRIBUTE(nodiscard) - #define JSON_HEDLEY_WARN_UNUSED_RESULT JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[nodiscard]]) - #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[nodiscard]]) -#elif \ - JSON_HEDLEY_HAS_ATTRIBUTE(warn_unused_result) || \ - JSON_HEDLEY_GCC_VERSION_CHECK(3,4,0) || \ - JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ - JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ - (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ - (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ - (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ - (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ - JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ - JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ - (JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,15,0) && defined(__cplusplus)) || \ - JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) - #define JSON_HEDLEY_WARN_UNUSED_RESULT __attribute__((__warn_unused_result__)) - #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) __attribute__((__warn_unused_result__)) -#elif defined(_Check_return_) /* SAL */ - #define JSON_HEDLEY_WARN_UNUSED_RESULT _Check_return_ - #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) _Check_return_ -#else - #define JSON_HEDLEY_WARN_UNUSED_RESULT - #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) -#endif - -#if defined(JSON_HEDLEY_SENTINEL) - #undef JSON_HEDLEY_SENTINEL -#endif -#if \ - 
JSON_HEDLEY_HAS_ATTRIBUTE(sentinel) || \ - JSON_HEDLEY_GCC_VERSION_CHECK(4,0,0) || \ - JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ - JSON_HEDLEY_ARM_VERSION_CHECK(5,4,0) - #define JSON_HEDLEY_SENTINEL(position) __attribute__((__sentinel__(position))) -#else - #define JSON_HEDLEY_SENTINEL(position) -#endif - -#if defined(JSON_HEDLEY_NO_RETURN) - #undef JSON_HEDLEY_NO_RETURN -#endif -#if JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) - #define JSON_HEDLEY_NO_RETURN __noreturn -#elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) - #define JSON_HEDLEY_NO_RETURN __attribute__((__noreturn__)) -#elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L - #define JSON_HEDLEY_NO_RETURN _Noreturn -#elif defined(__cplusplus) && (__cplusplus >= 201103L) - #define JSON_HEDLEY_NO_RETURN JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[noreturn]]) -#elif \ - JSON_HEDLEY_HAS_ATTRIBUTE(noreturn) || \ - JSON_HEDLEY_GCC_VERSION_CHECK(3,2,0) || \ - JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ - JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ - JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ - JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ - (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ - (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ - (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ - (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ - JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ - JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) - #define JSON_HEDLEY_NO_RETURN __attribute__((__noreturn__)) -#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0) - #define JSON_HEDLEY_NO_RETURN _Pragma("does_not_return") -#elif JSON_HEDLEY_MSVC_VERSION_CHECK(13,10,0) - #define 
JSON_HEDLEY_NO_RETURN __declspec(noreturn) -#elif JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,0,0) && defined(__cplusplus) - #define JSON_HEDLEY_NO_RETURN _Pragma("FUNC_NEVER_RETURNS;") -#elif JSON_HEDLEY_COMPCERT_VERSION_CHECK(3,2,0) - #define JSON_HEDLEY_NO_RETURN __attribute((noreturn)) -#elif JSON_HEDLEY_PELLES_VERSION_CHECK(9,0,0) - #define JSON_HEDLEY_NO_RETURN __declspec(noreturn) -#else - #define JSON_HEDLEY_NO_RETURN -#endif - -#if defined(JSON_HEDLEY_NO_ESCAPE) - #undef JSON_HEDLEY_NO_ESCAPE -#endif -#if JSON_HEDLEY_HAS_ATTRIBUTE(noescape) - #define JSON_HEDLEY_NO_ESCAPE __attribute__((__noescape__)) -#else - #define JSON_HEDLEY_NO_ESCAPE -#endif - -#if defined(JSON_HEDLEY_UNREACHABLE) - #undef JSON_HEDLEY_UNREACHABLE -#endif -#if defined(JSON_HEDLEY_UNREACHABLE_RETURN) - #undef JSON_HEDLEY_UNREACHABLE_RETURN -#endif -#if defined(JSON_HEDLEY_ASSUME) - #undef JSON_HEDLEY_ASSUME -#endif -#if \ - JSON_HEDLEY_MSVC_VERSION_CHECK(13,10,0) || \ - JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) - #define JSON_HEDLEY_ASSUME(expr) __assume(expr) -#elif JSON_HEDLEY_HAS_BUILTIN(__builtin_assume) - #define JSON_HEDLEY_ASSUME(expr) __builtin_assume(expr) -#elif \ - JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,2,0) || \ - JSON_HEDLEY_TI_CL6X_VERSION_CHECK(4,0,0) - #if defined(__cplusplus) - #define JSON_HEDLEY_ASSUME(expr) std::_nassert(expr) - #else - #define JSON_HEDLEY_ASSUME(expr) _nassert(expr) - #endif -#endif -#if \ - (JSON_HEDLEY_HAS_BUILTIN(__builtin_unreachable) && (!defined(JSON_HEDLEY_ARM_VERSION))) || \ - JSON_HEDLEY_GCC_VERSION_CHECK(4,5,0) || \ - JSON_HEDLEY_PGI_VERSION_CHECK(18,10,0) || \ - JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ - JSON_HEDLEY_IBM_VERSION_CHECK(13,1,5) - #define JSON_HEDLEY_UNREACHABLE() __builtin_unreachable() -#elif defined(JSON_HEDLEY_ASSUME) - #define JSON_HEDLEY_UNREACHABLE() JSON_HEDLEY_ASSUME(0) -#endif -#if !defined(JSON_HEDLEY_ASSUME) - #if defined(JSON_HEDLEY_UNREACHABLE) - #define JSON_HEDLEY_ASSUME(expr) JSON_HEDLEY_STATIC_CAST(void, ((expr) ? 
1 : (JSON_HEDLEY_UNREACHABLE(), 1))) - #else - #define JSON_HEDLEY_ASSUME(expr) JSON_HEDLEY_STATIC_CAST(void, expr) - #endif -#endif -#if defined(JSON_HEDLEY_UNREACHABLE) - #if \ - JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,2,0) || \ - JSON_HEDLEY_TI_CL6X_VERSION_CHECK(4,0,0) - #define JSON_HEDLEY_UNREACHABLE_RETURN(value) return (JSON_HEDLEY_STATIC_CAST(void, JSON_HEDLEY_ASSUME(0)), (value)) - #else - #define JSON_HEDLEY_UNREACHABLE_RETURN(value) JSON_HEDLEY_UNREACHABLE() - #endif -#else - #define JSON_HEDLEY_UNREACHABLE_RETURN(value) return (value) -#endif -#if !defined(JSON_HEDLEY_UNREACHABLE) - #define JSON_HEDLEY_UNREACHABLE() JSON_HEDLEY_ASSUME(0) -#endif - -JSON_HEDLEY_DIAGNOSTIC_PUSH -#if JSON_HEDLEY_HAS_WARNING("-Wpedantic") - #pragma clang diagnostic ignored "-Wpedantic" -#endif -#if JSON_HEDLEY_HAS_WARNING("-Wc++98-compat-pedantic") && defined(__cplusplus) - #pragma clang diagnostic ignored "-Wc++98-compat-pedantic" -#endif -#if JSON_HEDLEY_GCC_HAS_WARNING("-Wvariadic-macros",4,0,0) - #if defined(__clang__) - #pragma clang diagnostic ignored "-Wvariadic-macros" - #elif defined(JSON_HEDLEY_GCC_VERSION) - #pragma GCC diagnostic ignored "-Wvariadic-macros" - #endif -#endif -#if defined(JSON_HEDLEY_NON_NULL) - #undef JSON_HEDLEY_NON_NULL -#endif -#if \ - JSON_HEDLEY_HAS_ATTRIBUTE(nonnull) || \ - JSON_HEDLEY_GCC_VERSION_CHECK(3,3,0) || \ - JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ - JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) - #define JSON_HEDLEY_NON_NULL(...) __attribute__((__nonnull__(__VA_ARGS__))) -#else - #define JSON_HEDLEY_NON_NULL(...) 
-#endif -JSON_HEDLEY_DIAGNOSTIC_POP - -#if defined(JSON_HEDLEY_PRINTF_FORMAT) - #undef JSON_HEDLEY_PRINTF_FORMAT -#endif -#if defined(__MINGW32__) && JSON_HEDLEY_GCC_HAS_ATTRIBUTE(format,4,4,0) && !defined(__USE_MINGW_ANSI_STDIO) - #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) __attribute__((__format__(ms_printf, string_idx, first_to_check))) -#elif defined(__MINGW32__) && JSON_HEDLEY_GCC_HAS_ATTRIBUTE(format,4,4,0) && defined(__USE_MINGW_ANSI_STDIO) - #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) __attribute__((__format__(gnu_printf, string_idx, first_to_check))) -#elif \ - JSON_HEDLEY_HAS_ATTRIBUTE(format) || \ - JSON_HEDLEY_GCC_VERSION_CHECK(3,1,0) || \ - JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ - JSON_HEDLEY_ARM_VERSION_CHECK(5,6,0) || \ - JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ - JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ - (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ - (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ - (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ - (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ - JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ - JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) - #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) __attribute__((__format__(__printf__, string_idx, first_to_check))) -#elif JSON_HEDLEY_PELLES_VERSION_CHECK(6,0,0) - #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) __declspec(vaformat(printf,string_idx,first_to_check)) -#else - #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) -#endif - -#if defined(JSON_HEDLEY_CONSTEXPR) - #undef JSON_HEDLEY_CONSTEXPR -#endif -#if 
defined(__cplusplus) - #if __cplusplus >= 201103L - #define JSON_HEDLEY_CONSTEXPR JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(constexpr) - #endif -#endif -#if !defined(JSON_HEDLEY_CONSTEXPR) - #define JSON_HEDLEY_CONSTEXPR -#endif - -#if defined(JSON_HEDLEY_PREDICT) - #undef JSON_HEDLEY_PREDICT -#endif -#if defined(JSON_HEDLEY_LIKELY) - #undef JSON_HEDLEY_LIKELY -#endif -#if defined(JSON_HEDLEY_UNLIKELY) - #undef JSON_HEDLEY_UNLIKELY -#endif -#if defined(JSON_HEDLEY_UNPREDICTABLE) - #undef JSON_HEDLEY_UNPREDICTABLE -#endif -#if JSON_HEDLEY_HAS_BUILTIN(__builtin_unpredictable) - #define JSON_HEDLEY_UNPREDICTABLE(expr) __builtin_unpredictable((expr)) -#endif -#if \ - JSON_HEDLEY_HAS_BUILTIN(__builtin_expect_with_probability) || \ - JSON_HEDLEY_GCC_VERSION_CHECK(9,0,0) -# define JSON_HEDLEY_PREDICT(expr, value, probability) __builtin_expect_with_probability( (expr), (value), (probability)) -# define JSON_HEDLEY_PREDICT_TRUE(expr, probability) __builtin_expect_with_probability(!!(expr), 1 , (probability)) -# define JSON_HEDLEY_PREDICT_FALSE(expr, probability) __builtin_expect_with_probability(!!(expr), 0 , (probability)) -# define JSON_HEDLEY_LIKELY(expr) __builtin_expect (!!(expr), 1 ) -# define JSON_HEDLEY_UNLIKELY(expr) __builtin_expect (!!(expr), 0 ) -#elif \ - JSON_HEDLEY_HAS_BUILTIN(__builtin_expect) || \ - JSON_HEDLEY_GCC_VERSION_CHECK(3,0,0) || \ - JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ - (JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,15,0) && defined(__cplusplus)) || \ - JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ - JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ - JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ - JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,7,0) || \ - JSON_HEDLEY_TI_CL430_VERSION_CHECK(3,1,0) || \ - JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,1,0) || \ - JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,1,0) || \ - JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ - JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ - JSON_HEDLEY_TINYC_VERSION_CHECK(0,9,27) || \ - 
JSON_HEDLEY_CRAY_VERSION_CHECK(8,1,0) -# define JSON_HEDLEY_PREDICT(expr, expected, probability) \ - (((probability) >= 0.9) ? __builtin_expect((expr), (expected)) : (JSON_HEDLEY_STATIC_CAST(void, expected), (expr))) -# define JSON_HEDLEY_PREDICT_TRUE(expr, probability) \ - (__extension__ ({ \ - double hedley_probability_ = (probability); \ - ((hedley_probability_ >= 0.9) ? __builtin_expect(!!(expr), 1) : ((hedley_probability_ <= 0.1) ? __builtin_expect(!!(expr), 0) : !!(expr))); \ - })) -# define JSON_HEDLEY_PREDICT_FALSE(expr, probability) \ - (__extension__ ({ \ - double hedley_probability_ = (probability); \ - ((hedley_probability_ >= 0.9) ? __builtin_expect(!!(expr), 0) : ((hedley_probability_ <= 0.1) ? __builtin_expect(!!(expr), 1) : !!(expr))); \ - })) -# define JSON_HEDLEY_LIKELY(expr) __builtin_expect(!!(expr), 1) -# define JSON_HEDLEY_UNLIKELY(expr) __builtin_expect(!!(expr), 0) -#else -# define JSON_HEDLEY_PREDICT(expr, expected, probability) (JSON_HEDLEY_STATIC_CAST(void, expected), (expr)) -# define JSON_HEDLEY_PREDICT_TRUE(expr, probability) (!!(expr)) -# define JSON_HEDLEY_PREDICT_FALSE(expr, probability) (!!(expr)) -# define JSON_HEDLEY_LIKELY(expr) (!!(expr)) -# define JSON_HEDLEY_UNLIKELY(expr) (!!(expr)) -#endif -#if !defined(JSON_HEDLEY_UNPREDICTABLE) - #define JSON_HEDLEY_UNPREDICTABLE(expr) JSON_HEDLEY_PREDICT(expr, 1, 0.5) -#endif - -#if defined(JSON_HEDLEY_MALLOC) - #undef JSON_HEDLEY_MALLOC -#endif -#if \ - JSON_HEDLEY_HAS_ATTRIBUTE(malloc) || \ - JSON_HEDLEY_GCC_VERSION_CHECK(3,1,0) || \ - JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ - JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ - JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ - JSON_HEDLEY_IBM_VERSION_CHECK(12,1,0) || \ - JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ - (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ - (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || 
\ - JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ - (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ - (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ - JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ - JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) - #define JSON_HEDLEY_MALLOC __attribute__((__malloc__)) -#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0) - #define JSON_HEDLEY_MALLOC _Pragma("returns_new_memory") -#elif JSON_HEDLEY_MSVC_VERSION_CHECK(14, 0, 0) - #define JSON_HEDLEY_MALLOC __declspec(restrict) -#else - #define JSON_HEDLEY_MALLOC -#endif - -#if defined(JSON_HEDLEY_PURE) - #undef JSON_HEDLEY_PURE -#endif -#if \ - JSON_HEDLEY_HAS_ATTRIBUTE(pure) || \ - JSON_HEDLEY_GCC_VERSION_CHECK(2,96,0) || \ - JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ - JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ - JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ - JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ - JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ - (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ - (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ - (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ - (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ - JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ - JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ - JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) -# define JSON_HEDLEY_PURE __attribute__((__pure__)) -#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0) -# define JSON_HEDLEY_PURE _Pragma("does_not_write_global_data") -#elif 
defined(__cplusplus) && \ - ( \ - JSON_HEDLEY_TI_CL430_VERSION_CHECK(2,0,1) || \ - JSON_HEDLEY_TI_CL6X_VERSION_CHECK(4,0,0) || \ - JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) \ - ) -# define JSON_HEDLEY_PURE _Pragma("FUNC_IS_PURE;") -#else -# define JSON_HEDLEY_PURE -#endif - -#if defined(JSON_HEDLEY_CONST) - #undef JSON_HEDLEY_CONST -#endif -#if \ - JSON_HEDLEY_HAS_ATTRIBUTE(const) || \ - JSON_HEDLEY_GCC_VERSION_CHECK(2,5,0) || \ - JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ - JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ - JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ - JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ - JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ - (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ - (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ - (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ - (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ - JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ - JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ - JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) - #define JSON_HEDLEY_CONST __attribute__((__const__)) -#elif \ - JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0) - #define JSON_HEDLEY_CONST _Pragma("no_side_effect") -#else - #define JSON_HEDLEY_CONST JSON_HEDLEY_PURE -#endif - -#if defined(JSON_HEDLEY_RESTRICT) - #undef JSON_HEDLEY_RESTRICT -#endif -#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) && !defined(__cplusplus) - #define JSON_HEDLEY_RESTRICT restrict -#elif \ - JSON_HEDLEY_GCC_VERSION_CHECK(3,1,0) || \ - JSON_HEDLEY_MSVC_VERSION_CHECK(14,0,0) || \ - JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ - JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ - 
JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ - JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) || \ - JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ - JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,2,4) || \ - JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,1,0) || \ - JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ - (JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,14,0) && defined(__cplusplus)) || \ - JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) || \ - defined(__clang__) - #define JSON_HEDLEY_RESTRICT __restrict -#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,3,0) && !defined(__cplusplus) - #define JSON_HEDLEY_RESTRICT _Restrict -#else - #define JSON_HEDLEY_RESTRICT -#endif - -#if defined(JSON_HEDLEY_INLINE) - #undef JSON_HEDLEY_INLINE -#endif -#if \ - (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L)) || \ - (defined(__cplusplus) && (__cplusplus >= 199711L)) - #define JSON_HEDLEY_INLINE inline -#elif \ - defined(JSON_HEDLEY_GCC_VERSION) || \ - JSON_HEDLEY_ARM_VERSION_CHECK(6,2,0) - #define JSON_HEDLEY_INLINE __inline__ -#elif \ - JSON_HEDLEY_MSVC_VERSION_CHECK(12,0,0) || \ - JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ - JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,1,0) || \ - JSON_HEDLEY_TI_CL430_VERSION_CHECK(3,1,0) || \ - JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,2,0) || \ - JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,0,0) || \ - JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ - JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) - #define JSON_HEDLEY_INLINE __inline -#else - #define JSON_HEDLEY_INLINE -#endif - -#if defined(JSON_HEDLEY_ALWAYS_INLINE) - #undef JSON_HEDLEY_ALWAYS_INLINE -#endif -#if \ - JSON_HEDLEY_HAS_ATTRIBUTE(always_inline) || \ - JSON_HEDLEY_GCC_VERSION_CHECK(4,0,0) || \ - JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ - JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ - JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ - JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ - JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ - (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - 
JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ - (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ - (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ - (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ - JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ - JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) -# define JSON_HEDLEY_ALWAYS_INLINE __attribute__((__always_inline__)) JSON_HEDLEY_INLINE -#elif JSON_HEDLEY_MSVC_VERSION_CHECK(12,0,0) -# define JSON_HEDLEY_ALWAYS_INLINE __forceinline -#elif defined(__cplusplus) && \ - ( \ - JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ - JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ - JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ - JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,1,0) || \ - JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ - JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) \ - ) -# define JSON_HEDLEY_ALWAYS_INLINE _Pragma("FUNC_ALWAYS_INLINE;") -#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) -# define JSON_HEDLEY_ALWAYS_INLINE _Pragma("inline=forced") -#else -# define JSON_HEDLEY_ALWAYS_INLINE JSON_HEDLEY_INLINE -#endif - -#if defined(JSON_HEDLEY_NEVER_INLINE) - #undef JSON_HEDLEY_NEVER_INLINE -#endif -#if \ - JSON_HEDLEY_HAS_ATTRIBUTE(noinline) || \ - JSON_HEDLEY_GCC_VERSION_CHECK(4,0,0) || \ - JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ - JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ - JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ - JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ - JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ - (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ - (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - 
JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ - (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ - (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ - JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ - JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) - #define JSON_HEDLEY_NEVER_INLINE __attribute__((__noinline__)) -#elif JSON_HEDLEY_MSVC_VERSION_CHECK(13,10,0) - #define JSON_HEDLEY_NEVER_INLINE __declspec(noinline) -#elif JSON_HEDLEY_PGI_VERSION_CHECK(10,2,0) - #define JSON_HEDLEY_NEVER_INLINE _Pragma("noinline") -#elif JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,0,0) && defined(__cplusplus) - #define JSON_HEDLEY_NEVER_INLINE _Pragma("FUNC_CANNOT_INLINE;") -#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) - #define JSON_HEDLEY_NEVER_INLINE _Pragma("inline=never") -#elif JSON_HEDLEY_COMPCERT_VERSION_CHECK(3,2,0) - #define JSON_HEDLEY_NEVER_INLINE __attribute((noinline)) -#elif JSON_HEDLEY_PELLES_VERSION_CHECK(9,0,0) - #define JSON_HEDLEY_NEVER_INLINE __declspec(noinline) -#else - #define JSON_HEDLEY_NEVER_INLINE -#endif - -#if defined(JSON_HEDLEY_PRIVATE) - #undef JSON_HEDLEY_PRIVATE -#endif -#if defined(JSON_HEDLEY_PUBLIC) - #undef JSON_HEDLEY_PUBLIC -#endif -#if defined(JSON_HEDLEY_IMPORT) - #undef JSON_HEDLEY_IMPORT -#endif -#if defined(_WIN32) || defined(__CYGWIN__) -# define JSON_HEDLEY_PRIVATE -# define JSON_HEDLEY_PUBLIC __declspec(dllexport) -# define JSON_HEDLEY_IMPORT __declspec(dllimport) -#else -# if \ - JSON_HEDLEY_HAS_ATTRIBUTE(visibility) || \ - JSON_HEDLEY_GCC_VERSION_CHECK(3,3,0) || \ - JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ - JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ - JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ - JSON_HEDLEY_IBM_VERSION_CHECK(13,1,0) || \ - ( \ - defined(__TI_EABI__) && \ - ( \ - (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ - 
JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) \ - ) \ - ) -# define JSON_HEDLEY_PRIVATE __attribute__((__visibility__("hidden"))) -# define JSON_HEDLEY_PUBLIC __attribute__((__visibility__("default"))) -# else -# define JSON_HEDLEY_PRIVATE -# define JSON_HEDLEY_PUBLIC -# endif -# define JSON_HEDLEY_IMPORT extern -#endif - -#if defined(JSON_HEDLEY_NO_THROW) - #undef JSON_HEDLEY_NO_THROW -#endif -#if \ - JSON_HEDLEY_HAS_ATTRIBUTE(nothrow) || \ - JSON_HEDLEY_GCC_VERSION_CHECK(3,3,0) || \ - JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) - #define JSON_HEDLEY_NO_THROW __attribute__((__nothrow__)) -#elif \ - JSON_HEDLEY_MSVC_VERSION_CHECK(13,1,0) || \ - JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) - #define JSON_HEDLEY_NO_THROW __declspec(nothrow) -#else - #define JSON_HEDLEY_NO_THROW -#endif - -#if defined(JSON_HEDLEY_FALL_THROUGH) - #undef JSON_HEDLEY_FALL_THROUGH -#endif -#if \ - JSON_HEDLEY_HAS_ATTRIBUTE(fallthrough) || \ - JSON_HEDLEY_GCC_VERSION_CHECK(7,0,0) - #define JSON_HEDLEY_FALL_THROUGH __attribute__((__fallthrough__)) -#elif JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS(clang,fallthrough) - #define JSON_HEDLEY_FALL_THROUGH JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[clang::fallthrough]]) -#elif JSON_HEDLEY_HAS_CPP_ATTRIBUTE(fallthrough) - #define JSON_HEDLEY_FALL_THROUGH JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[fallthrough]]) -#elif defined(__fallthrough) /* SAL */ - #define JSON_HEDLEY_FALL_THROUGH __fallthrough -#else - #define JSON_HEDLEY_FALL_THROUGH -#endif - -#if defined(JSON_HEDLEY_RETURNS_NON_NULL) - #undef JSON_HEDLEY_RETURNS_NON_NULL -#endif -#if \ - JSON_HEDLEY_HAS_ATTRIBUTE(returns_nonnull) || \ - JSON_HEDLEY_GCC_VERSION_CHECK(4,9,0) - #define JSON_HEDLEY_RETURNS_NON_NULL __attribute__((__returns_nonnull__)) -#elif defined(_Ret_notnull_) /* SAL */ - #define JSON_HEDLEY_RETURNS_NON_NULL _Ret_notnull_ -#else - #define JSON_HEDLEY_RETURNS_NON_NULL -#endif - -#if defined(JSON_HEDLEY_ARRAY_PARAM) - #undef JSON_HEDLEY_ARRAY_PARAM -#endif -#if \ - 
defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) && \ - !defined(__STDC_NO_VLA__) && \ - !defined(__cplusplus) && \ - !defined(JSON_HEDLEY_PGI_VERSION) && \ - !defined(JSON_HEDLEY_TINYC_VERSION) - #define JSON_HEDLEY_ARRAY_PARAM(name) (name) -#else - #define JSON_HEDLEY_ARRAY_PARAM(name) -#endif - -#if defined(JSON_HEDLEY_IS_CONSTANT) - #undef JSON_HEDLEY_IS_CONSTANT -#endif -#if defined(JSON_HEDLEY_REQUIRE_CONSTEXPR) - #undef JSON_HEDLEY_REQUIRE_CONSTEXPR -#endif -/* JSON_HEDLEY_IS_CONSTEXPR_ is for - HEDLEY INTERNAL USE ONLY. API subject to change without notice. */ -#if defined(JSON_HEDLEY_IS_CONSTEXPR_) - #undef JSON_HEDLEY_IS_CONSTEXPR_ -#endif -#if \ - JSON_HEDLEY_HAS_BUILTIN(__builtin_constant_p) || \ - JSON_HEDLEY_GCC_VERSION_CHECK(3,4,0) || \ - JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ - JSON_HEDLEY_TINYC_VERSION_CHECK(0,9,19) || \ - JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ - JSON_HEDLEY_IBM_VERSION_CHECK(13,1,0) || \ - JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,1,0) || \ - (JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0) && !defined(__cplusplus)) || \ - JSON_HEDLEY_CRAY_VERSION_CHECK(8,1,0) - #define JSON_HEDLEY_IS_CONSTANT(expr) __builtin_constant_p(expr) -#endif -#if !defined(__cplusplus) -# if \ - JSON_HEDLEY_HAS_BUILTIN(__builtin_types_compatible_p) || \ - JSON_HEDLEY_GCC_VERSION_CHECK(3,4,0) || \ - JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ - JSON_HEDLEY_IBM_VERSION_CHECK(13,1,0) || \ - JSON_HEDLEY_CRAY_VERSION_CHECK(8,1,0) || \ - JSON_HEDLEY_ARM_VERSION_CHECK(5,4,0) || \ - JSON_HEDLEY_TINYC_VERSION_CHECK(0,9,24) -#if defined(__INTPTR_TYPE__) - #define JSON_HEDLEY_IS_CONSTEXPR_(expr) __builtin_types_compatible_p(__typeof__((1 ? (void*) ((__INTPTR_TYPE__) ((expr) * 0)) : (int*) 0)), int*) -#else - #include - #define JSON_HEDLEY_IS_CONSTEXPR_(expr) __builtin_types_compatible_p(__typeof__((1 ? 
(void*) ((intptr_t) ((expr) * 0)) : (int*) 0)), int*) -#endif -# elif \ - ( \ - defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L) && \ - !defined(JSON_HEDLEY_SUNPRO_VERSION) && \ - !defined(JSON_HEDLEY_PGI_VERSION) && \ - !defined(JSON_HEDLEY_IAR_VERSION)) || \ - JSON_HEDLEY_HAS_EXTENSION(c_generic_selections) || \ - JSON_HEDLEY_GCC_VERSION_CHECK(4,9,0) || \ - JSON_HEDLEY_INTEL_VERSION_CHECK(17,0,0) || \ - JSON_HEDLEY_IBM_VERSION_CHECK(12,1,0) || \ - JSON_HEDLEY_ARM_VERSION_CHECK(5,3,0) -#if defined(__INTPTR_TYPE__) - #define JSON_HEDLEY_IS_CONSTEXPR_(expr) _Generic((1 ? (void*) ((__INTPTR_TYPE__) ((expr) * 0)) : (int*) 0), int*: 1, void*: 0) -#else - #include - #define JSON_HEDLEY_IS_CONSTEXPR_(expr) _Generic((1 ? (void*) ((intptr_t) * 0) : (int*) 0), int*: 1, void*: 0) -#endif -# elif \ - defined(JSON_HEDLEY_GCC_VERSION) || \ - defined(JSON_HEDLEY_INTEL_VERSION) || \ - defined(JSON_HEDLEY_TINYC_VERSION) || \ - defined(JSON_HEDLEY_TI_ARMCL_VERSION) || \ - JSON_HEDLEY_TI_CL430_VERSION_CHECK(18,12,0) || \ - defined(JSON_HEDLEY_TI_CL2000_VERSION) || \ - defined(JSON_HEDLEY_TI_CL6X_VERSION) || \ - defined(JSON_HEDLEY_TI_CL7X_VERSION) || \ - defined(JSON_HEDLEY_TI_CLPRU_VERSION) || \ - defined(__clang__) -# define JSON_HEDLEY_IS_CONSTEXPR_(expr) ( \ - sizeof(void) != \ - sizeof(*( \ - 1 ? \ - ((void*) ((expr) * 0L) ) : \ -((struct { char v[sizeof(void) * 2]; } *) 1) \ - ) \ - ) \ - ) -# endif -#endif -#if defined(JSON_HEDLEY_IS_CONSTEXPR_) - #if !defined(JSON_HEDLEY_IS_CONSTANT) - #define JSON_HEDLEY_IS_CONSTANT(expr) JSON_HEDLEY_IS_CONSTEXPR_(expr) - #endif - #define JSON_HEDLEY_REQUIRE_CONSTEXPR(expr) (JSON_HEDLEY_IS_CONSTEXPR_(expr) ? 
(expr) : (-1)) -#else - #if !defined(JSON_HEDLEY_IS_CONSTANT) - #define JSON_HEDLEY_IS_CONSTANT(expr) (0) - #endif - #define JSON_HEDLEY_REQUIRE_CONSTEXPR(expr) (expr) -#endif - -#if defined(JSON_HEDLEY_BEGIN_C_DECLS) - #undef JSON_HEDLEY_BEGIN_C_DECLS -#endif -#if defined(JSON_HEDLEY_END_C_DECLS) - #undef JSON_HEDLEY_END_C_DECLS -#endif -#if defined(JSON_HEDLEY_C_DECL) - #undef JSON_HEDLEY_C_DECL -#endif -#if defined(__cplusplus) - #define JSON_HEDLEY_BEGIN_C_DECLS extern "C" { - #define JSON_HEDLEY_END_C_DECLS } - #define JSON_HEDLEY_C_DECL extern "C" -#else - #define JSON_HEDLEY_BEGIN_C_DECLS - #define JSON_HEDLEY_END_C_DECLS - #define JSON_HEDLEY_C_DECL -#endif - -#if defined(JSON_HEDLEY_STATIC_ASSERT) - #undef JSON_HEDLEY_STATIC_ASSERT -#endif -#if \ - !defined(__cplusplus) && ( \ - (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L)) || \ - JSON_HEDLEY_HAS_FEATURE(c_static_assert) || \ - JSON_HEDLEY_GCC_VERSION_CHECK(6,0,0) || \ - JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ - defined(_Static_assert) \ - ) -# define JSON_HEDLEY_STATIC_ASSERT(expr, message) _Static_assert(expr, message) -#elif \ - (defined(__cplusplus) && (__cplusplus >= 201103L)) || \ - JSON_HEDLEY_MSVC_VERSION_CHECK(16,0,0) -# define JSON_HEDLEY_STATIC_ASSERT(expr, message) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(static_assert(expr, message)) -#else -# define JSON_HEDLEY_STATIC_ASSERT(expr, message) -#endif - -#if defined(JSON_HEDLEY_NULL) - #undef JSON_HEDLEY_NULL -#endif -#if defined(__cplusplus) - #if __cplusplus >= 201103L - #define JSON_HEDLEY_NULL JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(nullptr) - #elif defined(NULL) - #define JSON_HEDLEY_NULL NULL - #else - #define JSON_HEDLEY_NULL JSON_HEDLEY_STATIC_CAST(void*, 0) - #endif -#elif defined(NULL) - #define JSON_HEDLEY_NULL NULL -#else - #define JSON_HEDLEY_NULL ((void*) 0) -#endif - -#if defined(JSON_HEDLEY_MESSAGE) - #undef JSON_HEDLEY_MESSAGE -#endif -#if JSON_HEDLEY_HAS_WARNING("-Wunknown-pragmas") -# 
define JSON_HEDLEY_MESSAGE(msg) \ - JSON_HEDLEY_DIAGNOSTIC_PUSH \ - JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS \ - JSON_HEDLEY_PRAGMA(message msg) \ - JSON_HEDLEY_DIAGNOSTIC_POP -#elif \ - JSON_HEDLEY_GCC_VERSION_CHECK(4,4,0) || \ - JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) -# define JSON_HEDLEY_MESSAGE(msg) JSON_HEDLEY_PRAGMA(message msg) -#elif JSON_HEDLEY_CRAY_VERSION_CHECK(5,0,0) -# define JSON_HEDLEY_MESSAGE(msg) JSON_HEDLEY_PRAGMA(_CRI message msg) -#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) -# define JSON_HEDLEY_MESSAGE(msg) JSON_HEDLEY_PRAGMA(message(msg)) -#elif JSON_HEDLEY_PELLES_VERSION_CHECK(2,0,0) -# define JSON_HEDLEY_MESSAGE(msg) JSON_HEDLEY_PRAGMA(message(msg)) -#else -# define JSON_HEDLEY_MESSAGE(msg) -#endif - -#if defined(JSON_HEDLEY_WARNING) - #undef JSON_HEDLEY_WARNING -#endif -#if JSON_HEDLEY_HAS_WARNING("-Wunknown-pragmas") -# define JSON_HEDLEY_WARNING(msg) \ - JSON_HEDLEY_DIAGNOSTIC_PUSH \ - JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS \ - JSON_HEDLEY_PRAGMA(clang warning msg) \ - JSON_HEDLEY_DIAGNOSTIC_POP -#elif \ - JSON_HEDLEY_GCC_VERSION_CHECK(4,8,0) || \ - JSON_HEDLEY_PGI_VERSION_CHECK(18,4,0) || \ - JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) -# define JSON_HEDLEY_WARNING(msg) JSON_HEDLEY_PRAGMA(GCC warning msg) -#elif JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) -# define JSON_HEDLEY_WARNING(msg) JSON_HEDLEY_PRAGMA(message(msg)) -#else -# define JSON_HEDLEY_WARNING(msg) JSON_HEDLEY_MESSAGE(msg) -#endif - -#if defined(JSON_HEDLEY_REQUIRE) - #undef JSON_HEDLEY_REQUIRE -#endif -#if defined(JSON_HEDLEY_REQUIRE_MSG) - #undef JSON_HEDLEY_REQUIRE_MSG -#endif -#if JSON_HEDLEY_HAS_ATTRIBUTE(diagnose_if) -# if JSON_HEDLEY_HAS_WARNING("-Wgcc-compat") -# define JSON_HEDLEY_REQUIRE(expr) \ - JSON_HEDLEY_DIAGNOSTIC_PUSH \ - _Pragma("clang diagnostic ignored \"-Wgcc-compat\"") \ - __attribute__((diagnose_if(!(expr), #expr, "error"))) \ - JSON_HEDLEY_DIAGNOSTIC_POP -# define JSON_HEDLEY_REQUIRE_MSG(expr,msg) \ - JSON_HEDLEY_DIAGNOSTIC_PUSH \ - 
_Pragma("clang diagnostic ignored \"-Wgcc-compat\"") \ - __attribute__((diagnose_if(!(expr), msg, "error"))) \ - JSON_HEDLEY_DIAGNOSTIC_POP -# else -# define JSON_HEDLEY_REQUIRE(expr) __attribute__((diagnose_if(!(expr), #expr, "error"))) -# define JSON_HEDLEY_REQUIRE_MSG(expr,msg) __attribute__((diagnose_if(!(expr), msg, "error"))) -# endif -#else -# define JSON_HEDLEY_REQUIRE(expr) -# define JSON_HEDLEY_REQUIRE_MSG(expr,msg) -#endif - -#if defined(JSON_HEDLEY_FLAGS) - #undef JSON_HEDLEY_FLAGS -#endif -#if JSON_HEDLEY_HAS_ATTRIBUTE(flag_enum) - #define JSON_HEDLEY_FLAGS __attribute__((__flag_enum__)) -#endif - -#if defined(JSON_HEDLEY_FLAGS_CAST) - #undef JSON_HEDLEY_FLAGS_CAST -#endif -#if JSON_HEDLEY_INTEL_VERSION_CHECK(19,0,0) -# define JSON_HEDLEY_FLAGS_CAST(T, expr) (__extension__ ({ \ - JSON_HEDLEY_DIAGNOSTIC_PUSH \ - _Pragma("warning(disable:188)") \ - ((T) (expr)); \ - JSON_HEDLEY_DIAGNOSTIC_POP \ - })) -#else -# define JSON_HEDLEY_FLAGS_CAST(T, expr) JSON_HEDLEY_STATIC_CAST(T, expr) -#endif - -#if defined(JSON_HEDLEY_EMPTY_BASES) - #undef JSON_HEDLEY_EMPTY_BASES -#endif -#if JSON_HEDLEY_MSVC_VERSION_CHECK(19,0,23918) && !JSON_HEDLEY_MSVC_VERSION_CHECK(20,0,0) - #define JSON_HEDLEY_EMPTY_BASES __declspec(empty_bases) -#else - #define JSON_HEDLEY_EMPTY_BASES -#endif - -/* Remaining macros are deprecated. 
*/ - -#if defined(JSON_HEDLEY_GCC_NOT_CLANG_VERSION_CHECK) - #undef JSON_HEDLEY_GCC_NOT_CLANG_VERSION_CHECK -#endif -#if defined(__clang__) - #define JSON_HEDLEY_GCC_NOT_CLANG_VERSION_CHECK(major,minor,patch) (0) -#else - #define JSON_HEDLEY_GCC_NOT_CLANG_VERSION_CHECK(major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) -#endif - -#if defined(JSON_HEDLEY_CLANG_HAS_ATTRIBUTE) - #undef JSON_HEDLEY_CLANG_HAS_ATTRIBUTE -#endif -#define JSON_HEDLEY_CLANG_HAS_ATTRIBUTE(attribute) JSON_HEDLEY_HAS_ATTRIBUTE(attribute) - -#if defined(JSON_HEDLEY_CLANG_HAS_CPP_ATTRIBUTE) - #undef JSON_HEDLEY_CLANG_HAS_CPP_ATTRIBUTE -#endif -#define JSON_HEDLEY_CLANG_HAS_CPP_ATTRIBUTE(attribute) JSON_HEDLEY_HAS_CPP_ATTRIBUTE(attribute) - -#if defined(JSON_HEDLEY_CLANG_HAS_BUILTIN) - #undef JSON_HEDLEY_CLANG_HAS_BUILTIN -#endif -#define JSON_HEDLEY_CLANG_HAS_BUILTIN(builtin) JSON_HEDLEY_HAS_BUILTIN(builtin) - -#if defined(JSON_HEDLEY_CLANG_HAS_FEATURE) - #undef JSON_HEDLEY_CLANG_HAS_FEATURE -#endif -#define JSON_HEDLEY_CLANG_HAS_FEATURE(feature) JSON_HEDLEY_HAS_FEATURE(feature) - -#if defined(JSON_HEDLEY_CLANG_HAS_EXTENSION) - #undef JSON_HEDLEY_CLANG_HAS_EXTENSION -#endif -#define JSON_HEDLEY_CLANG_HAS_EXTENSION(extension) JSON_HEDLEY_HAS_EXTENSION(extension) - -#if defined(JSON_HEDLEY_CLANG_HAS_DECLSPEC_DECLSPEC_ATTRIBUTE) - #undef JSON_HEDLEY_CLANG_HAS_DECLSPEC_DECLSPEC_ATTRIBUTE -#endif -#define JSON_HEDLEY_CLANG_HAS_DECLSPEC_ATTRIBUTE(attribute) JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE(attribute) - -#if defined(JSON_HEDLEY_CLANG_HAS_WARNING) - #undef JSON_HEDLEY_CLANG_HAS_WARNING -#endif -#define JSON_HEDLEY_CLANG_HAS_WARNING(warning) JSON_HEDLEY_HAS_WARNING(warning) - -#endif /* !defined(JSON_HEDLEY_VERSION) || (JSON_HEDLEY_VERSION < X) */ - - -// This file contains all internal macro definitions -// You MUST include macro_unscope.hpp at the end of json.hpp to undef all of them - -// exclude unsupported compilers -#if !defined(JSON_SKIP_UNSUPPORTED_COMPILER_CHECK) - #if 
defined(__clang__) - #if (__clang_major__ * 10000 + __clang_minor__ * 100 + __clang_patchlevel__) < 30400 - #error "unsupported Clang version - see https://github.com/nlohmann/json#supported-compilers" - #endif - #elif defined(__GNUC__) && !(defined(__ICC) || defined(__INTEL_COMPILER)) - #if (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) < 40800 - #error "unsupported GCC version - see https://github.com/nlohmann/json#supported-compilers" - #endif - #endif -#endif - -// C++ language standard detection -#if (defined(__cplusplus) && __cplusplus >= 202002L) || (defined(_MSVC_LANG) && _MSVC_LANG >= 202002L) - #define JSON_HAS_CPP_20 - #define JSON_HAS_CPP_17 - #define JSON_HAS_CPP_14 -#elif (defined(__cplusplus) && __cplusplus >= 201703L) || (defined(_HAS_CXX17) && _HAS_CXX17 == 1) // fix for issue #464 - #define JSON_HAS_CPP_17 - #define JSON_HAS_CPP_14 -#elif (defined(__cplusplus) && __cplusplus >= 201402L) || (defined(_HAS_CXX14) && _HAS_CXX14 == 1) - #define JSON_HAS_CPP_14 -#endif - -// disable float-equal warnings on GCC/clang -#if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__) - #pragma GCC diagnostic push - #pragma GCC diagnostic ignored "-Wfloat-equal" -#endif - -// disable documentation warnings on clang -#if defined(__clang__) - #pragma GCC diagnostic push - #pragma GCC diagnostic ignored "-Wdocumentation" -#endif - -// allow to disable exceptions -#if (defined(__cpp_exceptions) || defined(__EXCEPTIONS) || defined(_CPPUNWIND)) && !defined(JSON_NOEXCEPTION) - #define JSON_THROW(exception) throw exception - #define JSON_TRY try - #define JSON_CATCH(exception) catch(exception) - #define JSON_INTERNAL_CATCH(exception) catch(exception) -#else - #include - #define JSON_THROW(exception) std::abort() - #define JSON_TRY if(true) - #define JSON_CATCH(exception) if(false) - #define JSON_INTERNAL_CATCH(exception) if(false) -#endif - -// override exception macros -#if defined(JSON_THROW_USER) - #undef JSON_THROW - #define JSON_THROW 
JSON_THROW_USER -#endif -#if defined(JSON_TRY_USER) - #undef JSON_TRY - #define JSON_TRY JSON_TRY_USER -#endif -#if defined(JSON_CATCH_USER) - #undef JSON_CATCH - #define JSON_CATCH JSON_CATCH_USER - #undef JSON_INTERNAL_CATCH - #define JSON_INTERNAL_CATCH JSON_CATCH_USER -#endif -#if defined(JSON_INTERNAL_CATCH_USER) - #undef JSON_INTERNAL_CATCH - #define JSON_INTERNAL_CATCH JSON_INTERNAL_CATCH_USER -#endif - -// allow to override assert -#if !defined(JSON_ASSERT) - #include // assert - #define JSON_ASSERT(x) assert(x) -#endif - -/*! -@brief macro to briefly define a mapping between an enum and JSON -@def NLOHMANN_JSON_SERIALIZE_ENUM -@since version 3.4.0 -*/ -#define NLOHMANN_JSON_SERIALIZE_ENUM(ENUM_TYPE, ...) \ - template \ - inline void to_json(BasicJsonType& j, const ENUM_TYPE& e) \ - { \ - static_assert(std::is_enum::value, #ENUM_TYPE " must be an enum!"); \ - static const std::pair m[] = __VA_ARGS__; \ - auto it = std::find_if(std::begin(m), std::end(m), \ - [e](const std::pair& ej_pair) -> bool \ - { \ - return ej_pair.first == e; \ - }); \ - j = ((it != std::end(m)) ? it : std::begin(m))->second; \ - } \ - template \ - inline void from_json(const BasicJsonType& j, ENUM_TYPE& e) \ - { \ - static_assert(std::is_enum::value, #ENUM_TYPE " must be an enum!"); \ - static const std::pair m[] = __VA_ARGS__; \ - auto it = std::find_if(std::begin(m), std::end(m), \ - [&j](const std::pair& ej_pair) -> bool \ - { \ - return ej_pair.second == j; \ - }); \ - e = ((it != std::end(m)) ? it : std::begin(m))->first; \ - } - -// Ugly macros to avoid uglier copy-paste when specializing basic_json. They -// may be removed in the future once the class is split. 
- -#define NLOHMANN_BASIC_JSON_TPL_DECLARATION \ - template class ObjectType, \ - template class ArrayType, \ - class StringType, class BooleanType, class NumberIntegerType, \ - class NumberUnsignedType, class NumberFloatType, \ - template class AllocatorType, \ - template class JSONSerializer, \ - class BinaryType> - -#define NLOHMANN_BASIC_JSON_TPL \ - basic_json - -// Macros to simplify conversion from/to types - -#define NLOHMANN_JSON_EXPAND( x ) x -#define NLOHMANN_JSON_GET_MACRO(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, _61, _62, _63, _64, NAME,...) NAME -#define NLOHMANN_JSON_PASTE(...) NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_GET_MACRO(__VA_ARGS__, \ - NLOHMANN_JSON_PASTE64, \ - NLOHMANN_JSON_PASTE63, \ - NLOHMANN_JSON_PASTE62, \ - NLOHMANN_JSON_PASTE61, \ - NLOHMANN_JSON_PASTE60, \ - NLOHMANN_JSON_PASTE59, \ - NLOHMANN_JSON_PASTE58, \ - NLOHMANN_JSON_PASTE57, \ - NLOHMANN_JSON_PASTE56, \ - NLOHMANN_JSON_PASTE55, \ - NLOHMANN_JSON_PASTE54, \ - NLOHMANN_JSON_PASTE53, \ - NLOHMANN_JSON_PASTE52, \ - NLOHMANN_JSON_PASTE51, \ - NLOHMANN_JSON_PASTE50, \ - NLOHMANN_JSON_PASTE49, \ - NLOHMANN_JSON_PASTE48, \ - NLOHMANN_JSON_PASTE47, \ - NLOHMANN_JSON_PASTE46, \ - NLOHMANN_JSON_PASTE45, \ - NLOHMANN_JSON_PASTE44, \ - NLOHMANN_JSON_PASTE43, \ - NLOHMANN_JSON_PASTE42, \ - NLOHMANN_JSON_PASTE41, \ - NLOHMANN_JSON_PASTE40, \ - NLOHMANN_JSON_PASTE39, \ - NLOHMANN_JSON_PASTE38, \ - NLOHMANN_JSON_PASTE37, \ - NLOHMANN_JSON_PASTE36, \ - NLOHMANN_JSON_PASTE35, \ - NLOHMANN_JSON_PASTE34, \ - NLOHMANN_JSON_PASTE33, \ - NLOHMANN_JSON_PASTE32, \ - NLOHMANN_JSON_PASTE31, \ - NLOHMANN_JSON_PASTE30, \ - NLOHMANN_JSON_PASTE29, \ - NLOHMANN_JSON_PASTE28, \ - NLOHMANN_JSON_PASTE27, \ - NLOHMANN_JSON_PASTE26, \ - NLOHMANN_JSON_PASTE25, \ - 
NLOHMANN_JSON_PASTE24, \ - NLOHMANN_JSON_PASTE23, \ - NLOHMANN_JSON_PASTE22, \ - NLOHMANN_JSON_PASTE21, \ - NLOHMANN_JSON_PASTE20, \ - NLOHMANN_JSON_PASTE19, \ - NLOHMANN_JSON_PASTE18, \ - NLOHMANN_JSON_PASTE17, \ - NLOHMANN_JSON_PASTE16, \ - NLOHMANN_JSON_PASTE15, \ - NLOHMANN_JSON_PASTE14, \ - NLOHMANN_JSON_PASTE13, \ - NLOHMANN_JSON_PASTE12, \ - NLOHMANN_JSON_PASTE11, \ - NLOHMANN_JSON_PASTE10, \ - NLOHMANN_JSON_PASTE9, \ - NLOHMANN_JSON_PASTE8, \ - NLOHMANN_JSON_PASTE7, \ - NLOHMANN_JSON_PASTE6, \ - NLOHMANN_JSON_PASTE5, \ - NLOHMANN_JSON_PASTE4, \ - NLOHMANN_JSON_PASTE3, \ - NLOHMANN_JSON_PASTE2, \ - NLOHMANN_JSON_PASTE1)(__VA_ARGS__)) -#define NLOHMANN_JSON_PASTE2(func, v1) func(v1) -#define NLOHMANN_JSON_PASTE3(func, v1, v2) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE2(func, v2) -#define NLOHMANN_JSON_PASTE4(func, v1, v2, v3) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE3(func, v2, v3) -#define NLOHMANN_JSON_PASTE5(func, v1, v2, v3, v4) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE4(func, v2, v3, v4) -#define NLOHMANN_JSON_PASTE6(func, v1, v2, v3, v4, v5) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE5(func, v2, v3, v4, v5) -#define NLOHMANN_JSON_PASTE7(func, v1, v2, v3, v4, v5, v6) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE6(func, v2, v3, v4, v5, v6) -#define NLOHMANN_JSON_PASTE8(func, v1, v2, v3, v4, v5, v6, v7) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE7(func, v2, v3, v4, v5, v6, v7) -#define NLOHMANN_JSON_PASTE9(func, v1, v2, v3, v4, v5, v6, v7, v8) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE8(func, v2, v3, v4, v5, v6, v7, v8) -#define NLOHMANN_JSON_PASTE10(func, v1, v2, v3, v4, v5, v6, v7, v8, v9) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE9(func, v2, v3, v4, v5, v6, v7, v8, v9) -#define NLOHMANN_JSON_PASTE11(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE10(func, v2, v3, v4, v5, v6, v7, v8, v9, v10) -#define NLOHMANN_JSON_PASTE12(func, v1, v2, v3, v4, 
v5, v6, v7, v8, v9, v10, v11) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE11(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11) -#define NLOHMANN_JSON_PASTE13(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE12(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12) -#define NLOHMANN_JSON_PASTE14(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE13(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13) -#define NLOHMANN_JSON_PASTE15(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE14(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14) -#define NLOHMANN_JSON_PASTE16(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE15(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15) -#define NLOHMANN_JSON_PASTE17(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE16(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16) -#define NLOHMANN_JSON_PASTE18(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE17(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17) -#define NLOHMANN_JSON_PASTE19(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE18(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18) -#define NLOHMANN_JSON_PASTE20(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE19(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19) -#define NLOHMANN_JSON_PASTE21(func, 
v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE20(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20) -#define NLOHMANN_JSON_PASTE22(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE21(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21) -#define NLOHMANN_JSON_PASTE23(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE22(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22) -#define NLOHMANN_JSON_PASTE24(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE23(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23) -#define NLOHMANN_JSON_PASTE25(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE24(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24) -#define NLOHMANN_JSON_PASTE26(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE25(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25) -#define NLOHMANN_JSON_PASTE27(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE26(func, v2, v3, v4, v5, v6, v7, 
v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26) -#define NLOHMANN_JSON_PASTE28(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE27(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27) -#define NLOHMANN_JSON_PASTE29(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE28(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28) -#define NLOHMANN_JSON_PASTE30(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE29(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29) -#define NLOHMANN_JSON_PASTE31(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE30(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30) -#define NLOHMANN_JSON_PASTE32(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE31(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31) -#define NLOHMANN_JSON_PASTE33(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, 
v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE32(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32) -#define NLOHMANN_JSON_PASTE34(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE33(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33) -#define NLOHMANN_JSON_PASTE35(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE34(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34) -#define NLOHMANN_JSON_PASTE36(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE35(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35) -#define NLOHMANN_JSON_PASTE37(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE36(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36) -#define 
NLOHMANN_JSON_PASTE38(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE37(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37) -#define NLOHMANN_JSON_PASTE39(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE38(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38) -#define NLOHMANN_JSON_PASTE40(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE39(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39) -#define NLOHMANN_JSON_PASTE41(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE40(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40) -#define NLOHMANN_JSON_PASTE42(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, 
v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE41(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41) -#define NLOHMANN_JSON_PASTE43(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE42(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42) -#define NLOHMANN_JSON_PASTE44(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE43(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43) -#define NLOHMANN_JSON_PASTE45(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE44(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44) -#define NLOHMANN_JSON_PASTE46(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, 
v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE45(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45) -#define NLOHMANN_JSON_PASTE47(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE46(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46) -#define NLOHMANN_JSON_PASTE48(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE47(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47) -#define NLOHMANN_JSON_PASTE49(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE48(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48) -#define 
NLOHMANN_JSON_PASTE50(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE49(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49) -#define NLOHMANN_JSON_PASTE51(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE50(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50) -#define NLOHMANN_JSON_PASTE52(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE51(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51) -#define NLOHMANN_JSON_PASTE53(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, 
v52) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE52(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52) -#define NLOHMANN_JSON_PASTE54(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE53(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53) -#define NLOHMANN_JSON_PASTE55(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE54(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54) -#define NLOHMANN_JSON_PASTE56(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE55(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, 
v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55) -#define NLOHMANN_JSON_PASTE57(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE56(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56) -#define NLOHMANN_JSON_PASTE58(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE57(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57) -#define NLOHMANN_JSON_PASTE59(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE58(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, 
v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58) -#define NLOHMANN_JSON_PASTE60(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE59(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59) -#define NLOHMANN_JSON_PASTE61(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE60(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60) -#define NLOHMANN_JSON_PASTE62(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE61(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, 
v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61) -#define NLOHMANN_JSON_PASTE63(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE62(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62) -#define NLOHMANN_JSON_PASTE64(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62, v63) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE63(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62, v63) - -#define NLOHMANN_JSON_TO(v1) nlohmann_json_j[#v1] = nlohmann_json_t.v1; -#define NLOHMANN_JSON_FROM(v1) nlohmann_json_j.at(#v1).get_to(nlohmann_json_t.v1); - -/*! -@brief macro -@def NLOHMANN_DEFINE_TYPE_INTRUSIVE -@since version 3.9.0 -*/ -#define NLOHMANN_DEFINE_TYPE_INTRUSIVE(Type, ...) 
\ - friend void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \ - friend void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM, __VA_ARGS__)) } - -/*! -@brief macro -@def NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE -@since version 3.9.0 -*/ -#define NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(Type, ...) \ - inline void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \ - inline void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM, __VA_ARGS__)) } - -#ifndef JSON_USE_IMPLICIT_CONVERSIONS - #define JSON_USE_IMPLICIT_CONVERSIONS 1 -#endif - -#if JSON_USE_IMPLICIT_CONVERSIONS - #define JSON_EXPLICIT -#else - #define JSON_EXPLICIT explicit -#endif - - -namespace nlohmann -{ -namespace detail -{ -//////////////// -// exceptions // -//////////////// - -/*! -@brief general exception of the @ref basic_json class - -This class is an extension of `std::exception` objects with a member @a id for -exception ids. It is used as the base class for all exceptions thrown by the -@ref basic_json class. This class can hence be used as "wildcard" to catch -exceptions. - -Subclasses: -- @ref parse_error for exceptions indicating a parse error -- @ref invalid_iterator for exceptions indicating errors with iterators -- @ref type_error for exceptions indicating executing a member function with - a wrong type -- @ref out_of_range for exceptions indicating access out of the defined range -- @ref other_error for exceptions indicating other library errors - -@internal -@note To have nothrow-copy-constructible exceptions, we internally use - `std::runtime_error` which can cope with arbitrary-length error messages. 
- Intermediate strings are built with static functions and then passed to - the actual constructor. -@endinternal - -@liveexample{The following code shows how arbitrary library exceptions can be -caught.,exception} - -@since version 3.0.0 -*/ -class exception : public std::exception -{ - public: - /// returns the explanatory string - JSON_HEDLEY_RETURNS_NON_NULL - const char* what() const noexcept override - { - return m.what(); - } - - /// the id of the exception - const int id; - - protected: - JSON_HEDLEY_NON_NULL(3) - exception(int id_, const char* what_arg) : id(id_), m(what_arg) {} - - static std::string name(const std::string& ename, int id_) - { - return "[json.exception." + ename + "." + std::to_string(id_) + "] "; - } - - private: - /// an exception object as storage for error messages - std::runtime_error m; -}; - -/*! -@brief exception indicating a parse error - -This exception is thrown by the library when a parse error occurs. Parse errors -can occur during the deserialization of JSON text, CBOR, MessagePack, as well -as when using JSON Patch. - -Member @a byte holds the byte index of the last read character in the input -file. - -Exceptions have ids 1xx. - -name / id | example message | description ------------------------------- | --------------- | ------------------------- -json.exception.parse_error.101 | parse error at 2: unexpected end of input; expected string literal | This error indicates a syntax error while deserializing a JSON text. The error message describes that an unexpected token (character) was encountered, and the member @a byte indicates the error position. -json.exception.parse_error.102 | parse error at 14: missing or wrong low surrogate | JSON uses the `\uxxxx` format to describe Unicode characters. Code points above above 0xFFFF are split into two `\uxxxx` entries ("surrogate pairs"). This error indicates that the surrogate pair is incomplete or contains an invalid code point. 
-json.exception.parse_error.103 | parse error: code points above 0x10FFFF are invalid | Unicode supports code points up to 0x10FFFF. Code points above 0x10FFFF are invalid. -json.exception.parse_error.104 | parse error: JSON patch must be an array of objects | [RFC 6902](https://tools.ietf.org/html/rfc6902) requires a JSON Patch document to be a JSON document that represents an array of objects. -json.exception.parse_error.105 | parse error: operation must have string member 'op' | An operation of a JSON Patch document must contain exactly one "op" member, whose value indicates the operation to perform. Its value must be one of "add", "remove", "replace", "move", "copy", or "test"; other values are errors. -json.exception.parse_error.106 | parse error: array index '01' must not begin with '0' | An array index in a JSON Pointer ([RFC 6901](https://tools.ietf.org/html/rfc6901)) may be `0` or any number without a leading `0`. -json.exception.parse_error.107 | parse error: JSON pointer must be empty or begin with '/' - was: 'foo' | A JSON Pointer must be a Unicode string containing a sequence of zero or more reference tokens, each prefixed by a `/` character. -json.exception.parse_error.108 | parse error: escape character '~' must be followed with '0' or '1' | In a JSON Pointer, only `~0` and `~1` are valid escape sequences. -json.exception.parse_error.109 | parse error: array index 'one' is not a number | A JSON Pointer array index must be a number. -json.exception.parse_error.110 | parse error at 1: cannot read 2 bytes from vector | When parsing CBOR or MessagePack, the byte vector ends before the complete value has been read. -json.exception.parse_error.112 | parse error at 1: error reading CBOR; last byte: 0xF8 | Not all types of CBOR or MessagePack are supported. This exception occurs if an unsupported byte was read. 
-json.exception.parse_error.113 | parse error at 2: expected a CBOR string; last byte: 0x98 | While parsing a map key, a value that is not a string has been read. -json.exception.parse_error.114 | parse error: Unsupported BSON record type 0x0F | The parsing of the corresponding BSON record type is not implemented (yet). -json.exception.parse_error.115 | parse error at byte 5: syntax error while parsing UBJSON high-precision number: invalid number text: 1A | A UBJSON high-precision number could not be parsed. - -@note For an input with n bytes, 1 is the index of the first character and n+1 - is the index of the terminating null byte or the end of file. This also - holds true when reading a byte vector (CBOR or MessagePack). - -@liveexample{The following code shows how a `parse_error` exception can be -caught.,parse_error} - -@sa - @ref exception for the base class of the library exceptions -@sa - @ref invalid_iterator for exceptions indicating errors with iterators -@sa - @ref type_error for exceptions indicating executing a member function with - a wrong type -@sa - @ref out_of_range for exceptions indicating access out of the defined range -@sa - @ref other_error for exceptions indicating other library errors - -@since version 3.0.0 -*/ -class parse_error : public exception -{ - public: - /*! 
- @brief create a parse error exception - @param[in] id_ the id of the exception - @param[in] pos the position where the error occurred (or with - chars_read_total=0 if the position cannot be - determined) - @param[in] what_arg the explanatory string - @return parse_error object - */ - static parse_error create(int id_, const position_t& pos, const std::string& what_arg) - { - std::string w = exception::name("parse_error", id_) + "parse error" + - position_string(pos) + ": " + what_arg; - return parse_error(id_, pos.chars_read_total, w.c_str()); - } - - static parse_error create(int id_, std::size_t byte_, const std::string& what_arg) - { - std::string w = exception::name("parse_error", id_) + "parse error" + - (byte_ != 0 ? (" at byte " + std::to_string(byte_)) : "") + - ": " + what_arg; - return parse_error(id_, byte_, w.c_str()); - } - - /*! - @brief byte index of the parse error - - The byte index of the last read character in the input file. - - @note For an input with n bytes, 1 is the index of the first character and - n+1 is the index of the terminating null byte or the end of file. - This also holds true when reading a byte vector (CBOR or MessagePack). - */ - const std::size_t byte; - - private: - parse_error(int id_, std::size_t byte_, const char* what_arg) - : exception(id_, what_arg), byte(byte_) {} - - static std::string position_string(const position_t& pos) - { - return " at line " + std::to_string(pos.lines_read + 1) + - ", column " + std::to_string(pos.chars_read_current_line); - } -}; - -/*! -@brief exception indicating errors with iterators - -This exception is thrown if iterators passed to a library function do not match -the expected semantics. - -Exceptions have ids 2xx. 
- -name / id | example message | description ------------------------------------ | --------------- | ------------------------- -json.exception.invalid_iterator.201 | iterators are not compatible | The iterators passed to constructor @ref basic_json(InputIT first, InputIT last) are not compatible, meaning they do not belong to the same container. Therefore, the range (@a first, @a last) is invalid. -json.exception.invalid_iterator.202 | iterator does not fit current value | In an erase or insert function, the passed iterator @a pos does not belong to the JSON value for which the function was called. It hence does not define a valid position for the deletion/insertion. -json.exception.invalid_iterator.203 | iterators do not fit current value | Either iterator passed to function @ref erase(IteratorType first, IteratorType last) does not belong to the JSON value from which values shall be erased. It hence does not define a valid range to delete values from. -json.exception.invalid_iterator.204 | iterators out of range | When an iterator range for a primitive type (number, boolean, or string) is passed to a constructor or an erase function, this range has to be exactly (@ref begin(), @ref end()), because this is the only way the single stored value is expressed. All other ranges are invalid. -json.exception.invalid_iterator.205 | iterator out of range | When an iterator for a primitive type (number, boolean, or string) is passed to an erase function, the iterator has to be the @ref begin() iterator, because it is the only way to address the stored value. All other iterators are invalid. -json.exception.invalid_iterator.206 | cannot construct with iterators from null | The iterators passed to constructor @ref basic_json(InputIT first, InputIT last) belong to a JSON null value and hence to not define a valid range. 
-json.exception.invalid_iterator.207 | cannot use key() for non-object iterators | The key() member function can only be used on iterators belonging to a JSON object, because other types do not have a concept of a key. -json.exception.invalid_iterator.208 | cannot use operator[] for object iterators | The operator[] to specify a concrete offset cannot be used on iterators belonging to a JSON object, because JSON objects are unordered. -json.exception.invalid_iterator.209 | cannot use offsets with object iterators | The offset operators (+, -, +=, -=) cannot be used on iterators belonging to a JSON object, because JSON objects are unordered. -json.exception.invalid_iterator.210 | iterators do not fit | The iterator range passed to the insert function are not compatible, meaning they do not belong to the same container. Therefore, the range (@a first, @a last) is invalid. -json.exception.invalid_iterator.211 | passed iterators may not belong to container | The iterator range passed to the insert function must not be a subrange of the container to insert to. -json.exception.invalid_iterator.212 | cannot compare iterators of different containers | When two iterators are compared, they must belong to the same container. -json.exception.invalid_iterator.213 | cannot compare order of object iterators | The order of object iterators cannot be compared, because JSON objects are unordered. -json.exception.invalid_iterator.214 | cannot get value | Cannot get value for iterator: Either the iterator belongs to a null value or it is an iterator to a primitive type (number, boolean, or string), but the iterator is different to @ref begin(). 
- -@liveexample{The following code shows how an `invalid_iterator` exception can be -caught.,invalid_iterator} - -@sa - @ref exception for the base class of the library exceptions -@sa - @ref parse_error for exceptions indicating a parse error -@sa - @ref type_error for exceptions indicating executing a member function with - a wrong type -@sa - @ref out_of_range for exceptions indicating access out of the defined range -@sa - @ref other_error for exceptions indicating other library errors - -@since version 3.0.0 -*/ -class invalid_iterator : public exception -{ - public: - static invalid_iterator create(int id_, const std::string& what_arg) - { - std::string w = exception::name("invalid_iterator", id_) + what_arg; - return invalid_iterator(id_, w.c_str()); - } - - private: - JSON_HEDLEY_NON_NULL(3) - invalid_iterator(int id_, const char* what_arg) - : exception(id_, what_arg) {} -}; - -/*! -@brief exception indicating executing a member function with a wrong type - -This exception is thrown in case of a type error; that is, a library function is -executed on a JSON value whose type does not match the expected semantics. - -Exceptions have ids 3xx. - -name / id | example message | description ------------------------------ | --------------- | ------------------------- -json.exception.type_error.301 | cannot create object from initializer list | To create an object from an initializer list, the initializer list must consist only of a list of pairs whose first element is a string. When this constraint is violated, an array is created instead. -json.exception.type_error.302 | type must be object, but is array | During implicit or explicit value conversion, the JSON type must be compatible to the target type. For instance, a JSON string can only be converted into string types, but not into numbers or boolean types. 
-json.exception.type_error.303 | incompatible ReferenceType for get_ref, actual type is object | To retrieve a reference to a value stored in a @ref basic_json object with @ref get_ref, the type of the reference must match the value type. For instance, for a JSON array, the @a ReferenceType must be @ref array_t &. -json.exception.type_error.304 | cannot use at() with string | The @ref at() member functions can only be executed for certain JSON types. -json.exception.type_error.305 | cannot use operator[] with string | The @ref operator[] member functions can only be executed for certain JSON types. -json.exception.type_error.306 | cannot use value() with string | The @ref value() member functions can only be executed for certain JSON types. -json.exception.type_error.307 | cannot use erase() with string | The @ref erase() member functions can only be executed for certain JSON types. -json.exception.type_error.308 | cannot use push_back() with string | The @ref push_back() and @ref operator+= member functions can only be executed for certain JSON types. -json.exception.type_error.309 | cannot use insert() with | The @ref insert() member functions can only be executed for certain JSON types. -json.exception.type_error.310 | cannot use swap() with number | The @ref swap() member functions can only be executed for certain JSON types. -json.exception.type_error.311 | cannot use emplace_back() with string | The @ref emplace_back() member function can only be executed for certain JSON types. -json.exception.type_error.312 | cannot use update() with string | The @ref update() member functions can only be executed for certain JSON types. -json.exception.type_error.313 | invalid value to unflatten | The @ref unflatten function converts an object whose keys are JSON Pointers back into an arbitrary nested JSON value. The JSON Pointers must not overlap, because then the resulting value would not be well defined. 
-json.exception.type_error.314 | only objects can be unflattened | The @ref unflatten function only works for an object whose keys are JSON Pointers. -json.exception.type_error.315 | values in object must be primitive | The @ref unflatten function only works for an object whose keys are JSON Pointers and whose values are primitive. -json.exception.type_error.316 | invalid UTF-8 byte at index 10: 0x7E | The @ref dump function only works with UTF-8 encoded strings; that is, if you assign a `std::string` to a JSON value, make sure it is UTF-8 encoded. | -json.exception.type_error.317 | JSON value cannot be serialized to requested format | The dynamic type of the object cannot be represented in the requested serialization format (e.g. a raw `true` or `null` JSON object cannot be serialized to BSON) | - -@liveexample{The following code shows how a `type_error` exception can be -caught.,type_error} - -@sa - @ref exception for the base class of the library exceptions -@sa - @ref parse_error for exceptions indicating a parse error -@sa - @ref invalid_iterator for exceptions indicating errors with iterators -@sa - @ref out_of_range for exceptions indicating access out of the defined range -@sa - @ref other_error for exceptions indicating other library errors - -@since version 3.0.0 -*/ -class type_error : public exception -{ - public: - static type_error create(int id_, const std::string& what_arg) - { - std::string w = exception::name("type_error", id_) + what_arg; - return type_error(id_, w.c_str()); - } - - private: - JSON_HEDLEY_NON_NULL(3) - type_error(int id_, const char* what_arg) : exception(id_, what_arg) {} -}; - -/*! -@brief exception indicating access out of the defined range - -This exception is thrown in case a library function is called on an input -parameter that exceeds the expected range, for instance in case of array -indices or nonexisting object keys. - -Exceptions have ids 4xx. 
- -name / id | example message | description -------------------------------- | --------------- | ------------------------- -json.exception.out_of_range.401 | array index 3 is out of range | The provided array index @a i is larger than @a size-1. -json.exception.out_of_range.402 | array index '-' (3) is out of range | The special array index `-` in a JSON Pointer never describes a valid element of the array, but the index past the end. That is, it can only be used to add elements at this position, but not to read it. -json.exception.out_of_range.403 | key 'foo' not found | The provided key was not found in the JSON object. -json.exception.out_of_range.404 | unresolved reference token 'foo' | A reference token in a JSON Pointer could not be resolved. -json.exception.out_of_range.405 | JSON pointer has no parent | The JSON Patch operations 'remove' and 'add' can not be applied to the root element of the JSON value. -json.exception.out_of_range.406 | number overflow parsing '10E1000' | A parsed number could not be stored as without changing it to NaN or INF. -json.exception.out_of_range.407 | number overflow serializing '9223372036854775808' | UBJSON and BSON only support integer numbers up to 9223372036854775807. (until version 3.8.0) | -json.exception.out_of_range.408 | excessive array size: 8658170730974374167 | The size (following `#`) of an UBJSON array or object exceeds the maximal capacity. 
| -json.exception.out_of_range.409 | BSON key cannot contain code point U+0000 (at byte 2) | Key identifiers to be serialized to BSON cannot contain code point U+0000, since the key is stored as zero-terminated c-string | - -@liveexample{The following code shows how an `out_of_range` exception can be -caught.,out_of_range} - -@sa - @ref exception for the base class of the library exceptions -@sa - @ref parse_error for exceptions indicating a parse error -@sa - @ref invalid_iterator for exceptions indicating errors with iterators -@sa - @ref type_error for exceptions indicating executing a member function with - a wrong type -@sa - @ref other_error for exceptions indicating other library errors - -@since version 3.0.0 -*/ -class out_of_range : public exception -{ - public: - static out_of_range create(int id_, const std::string& what_arg) - { - std::string w = exception::name("out_of_range", id_) + what_arg; - return out_of_range(id_, w.c_str()); - } - - private: - JSON_HEDLEY_NON_NULL(3) - out_of_range(int id_, const char* what_arg) : exception(id_, what_arg) {} -}; - -/*! -@brief exception indicating other library errors - -This exception is thrown in case of errors that cannot be classified with the -other exception types. - -Exceptions have ids 5xx. - -name / id | example message | description ------------------------------- | --------------- | ------------------------- -json.exception.other_error.501 | unsuccessful: {"op":"test","path":"/baz", "value":"bar"} | A JSON Patch operation 'test' failed. The unsuccessful operation is also printed. 
- -@sa - @ref exception for the base class of the library exceptions -@sa - @ref parse_error for exceptions indicating a parse error -@sa - @ref invalid_iterator for exceptions indicating errors with iterators -@sa - @ref type_error for exceptions indicating executing a member function with - a wrong type -@sa - @ref out_of_range for exceptions indicating access out of the defined range - -@liveexample{The following code shows how an `other_error` exception can be -caught.,other_error} - -@since version 3.0.0 -*/ -class other_error : public exception -{ - public: - static other_error create(int id_, const std::string& what_arg) - { - std::string w = exception::name("other_error", id_) + what_arg; - return other_error(id_, w.c_str()); - } - - private: - JSON_HEDLEY_NON_NULL(3) - other_error(int id_, const char* what_arg) : exception(id_, what_arg) {} -}; -} // namespace detail -} // namespace nlohmann - -// #include - -// #include - - -#include // size_t -#include // conditional, enable_if, false_type, integral_constant, is_constructible, is_integral, is_same, remove_cv, remove_reference, true_type - -namespace nlohmann -{ -namespace detail -{ -// alias templates to reduce boilerplate -template -using enable_if_t = typename std::enable_if::type; - -template -using uncvref_t = typename std::remove_cv::type>::type; - -// implementation of C++14 index_sequence and affiliates -// source: https://stackoverflow.com/a/32223343 -template -struct index_sequence -{ - using type = index_sequence; - using value_type = std::size_t; - static constexpr std::size_t size() noexcept - { - return sizeof...(Ints); - } -}; - -template -struct merge_and_renumber; - -template -struct merge_and_renumber, index_sequence> - : index_sequence < I1..., (sizeof...(I1) + I2)... 
> {}; - -template -struct make_index_sequence - : merge_and_renumber < typename make_index_sequence < N / 2 >::type, - typename make_index_sequence < N - N / 2 >::type > {}; - -template<> struct make_index_sequence<0> : index_sequence<> {}; -template<> struct make_index_sequence<1> : index_sequence<0> {}; - -template -using index_sequence_for = make_index_sequence; - -// dispatch utility (taken from ranges-v3) -template struct priority_tag : priority_tag < N - 1 > {}; -template<> struct priority_tag<0> {}; - -// taken from ranges-v3 -template -struct static_const -{ - static constexpr T value{}; -}; - -template -constexpr T static_const::value; -} // namespace detail -} // namespace nlohmann - -// #include - - -#include // numeric_limits -#include // false_type, is_constructible, is_integral, is_same, true_type -#include // declval - -// #include - - -#include // random_access_iterator_tag - -// #include - - -namespace nlohmann -{ -namespace detail -{ -template struct make_void -{ - using type = void; -}; -template using void_t = typename make_void::type; -} // namespace detail -} // namespace nlohmann - -// #include - - -namespace nlohmann -{ -namespace detail -{ -template -struct iterator_types {}; - -template -struct iterator_types < - It, - void_t> -{ - using difference_type = typename It::difference_type; - using value_type = typename It::value_type; - using pointer = typename It::pointer; - using reference = typename It::reference; - using iterator_category = typename It::iterator_category; -}; - -// This is required as some compilers implement std::iterator_traits in a way that -// doesn't work with SFINAE. See https://github.com/nlohmann/json/issues/1341. 
-template -struct iterator_traits -{ -}; - -template -struct iterator_traits < T, enable_if_t < !std::is_pointer::value >> - : iterator_types -{ -}; - -template -struct iterator_traits::value>> -{ - using iterator_category = std::random_access_iterator_tag; - using value_type = T; - using difference_type = ptrdiff_t; - using pointer = T*; - using reference = T&; -}; -} // namespace detail -} // namespace nlohmann - -// #include - -// #include - -// #include - - -#include - -// #include - - -// https://en.cppreference.com/w/cpp/experimental/is_detected -namespace nlohmann -{ -namespace detail -{ -struct nonesuch -{ - nonesuch() = delete; - ~nonesuch() = delete; - nonesuch(nonesuch const&) = delete; - nonesuch(nonesuch const&&) = delete; - void operator=(nonesuch const&) = delete; - void operator=(nonesuch&&) = delete; -}; - -template class Op, - class... Args> -struct detector -{ - using value_t = std::false_type; - using type = Default; -}; - -template class Op, class... Args> -struct detector>, Op, Args...> -{ - using value_t = std::true_type; - using type = Op; -}; - -template class Op, class... Args> -using is_detected = typename detector::value_t; - -template class Op, class... Args> -using detected_t = typename detector::type; - -template class Op, class... Args> -using detected_or = detector; - -template class Op, class... Args> -using detected_or_t = typename detected_or::type; - -template class Op, class... Args> -using is_detected_exact = std::is_same>; - -template class Op, class... Args> -using is_detected_convertible = - std::is_convertible, To>; -} // namespace detail -} // namespace nlohmann - -// #include -#ifndef INCLUDE_NLOHMANN_JSON_FWD_HPP_ -#define INCLUDE_NLOHMANN_JSON_FWD_HPP_ - -#include // int64_t, uint64_t -#include // map -#include // allocator -#include // string -#include // vector - -/*! -@brief namespace for Niels Lohmann -@see https://github.com/nlohmann -@since version 1.0.0 -*/ -namespace nlohmann -{ -/*! 
-@brief default JSONSerializer template argument - -This serializer ignores the template arguments and uses ADL -([argument-dependent lookup](https://en.cppreference.com/w/cpp/language/adl)) -for serialization. -*/ -template -struct adl_serializer; - -template class ObjectType = - std::map, - template class ArrayType = std::vector, - class StringType = std::string, class BooleanType = bool, - class NumberIntegerType = std::int64_t, - class NumberUnsignedType = std::uint64_t, - class NumberFloatType = double, - template class AllocatorType = std::allocator, - template class JSONSerializer = - adl_serializer, - class BinaryType = std::vector> -class basic_json; - -/*! -@brief JSON Pointer - -A JSON pointer defines a string syntax for identifying a specific value -within a JSON document. It can be used with functions `at` and -`operator[]`. Furthermore, JSON pointers are the base for JSON patches. - -@sa [RFC 6901](https://tools.ietf.org/html/rfc6901) - -@since version 2.0.0 -*/ -template -class json_pointer; - -/*! -@brief default JSON class - -This type is the default specialization of the @ref basic_json class which -uses the standard template types. - -@since version 1.0.0 -*/ -using json = basic_json<>; - -template -struct ordered_map; - -/*! -@brief ordered JSON class - -This type preserves the insertion order of object keys. - -@since version 3.9.0 -*/ -using ordered_json = basic_json; - -} // namespace nlohmann - -#endif // INCLUDE_NLOHMANN_JSON_FWD_HPP_ - - -namespace nlohmann -{ -/*! -@brief detail namespace with internal helper functions - -This namespace collects functions that should not be exposed, -implementations of some @ref basic_json methods, and meta-programming helpers. - -@since version 2.1.0 -*/ -namespace detail -{ -///////////// -// helpers // -///////////// - -// Note to maintainers: -// -// Every trait in this file expects a non CV-qualified type. -// The only exceptions are in the 'aliases for detected' section -// (i.e. 
those of the form: decltype(T::member_function(std::declval()))) -// -// In this case, T has to be properly CV-qualified to constraint the function arguments -// (e.g. to_json(BasicJsonType&, const T&)) - -template struct is_basic_json : std::false_type {}; - -NLOHMANN_BASIC_JSON_TPL_DECLARATION -struct is_basic_json : std::true_type {}; - -////////////////////// -// json_ref helpers // -////////////////////// - -template -class json_ref; - -template -struct is_json_ref : std::false_type {}; - -template -struct is_json_ref> : std::true_type {}; - -////////////////////////// -// aliases for detected // -////////////////////////// - -template -using mapped_type_t = typename T::mapped_type; - -template -using key_type_t = typename T::key_type; - -template -using value_type_t = typename T::value_type; - -template -using difference_type_t = typename T::difference_type; - -template -using pointer_t = typename T::pointer; - -template -using reference_t = typename T::reference; - -template -using iterator_category_t = typename T::iterator_category; - -template -using iterator_t = typename T::iterator; - -template -using to_json_function = decltype(T::to_json(std::declval()...)); - -template -using from_json_function = decltype(T::from_json(std::declval()...)); - -template -using get_template_function = decltype(std::declval().template get()); - -// trait checking if JSONSerializer::from_json(json const&, udt&) exists -template -struct has_from_json : std::false_type {}; - -// trait checking if j.get is valid -// use this trait instead of std::is_constructible or std::is_convertible, -// both rely on, or make use of implicit conversions, and thus fail when T -// has several constructors/operator= (see https://github.com/nlohmann/json/issues/958) -template -struct is_getable -{ - static constexpr bool value = is_detected::value; -}; - -template -struct has_from_json < BasicJsonType, T, - enable_if_t < !is_basic_json::value >> -{ - using serializer = typename 
BasicJsonType::template json_serializer; - - static constexpr bool value = - is_detected_exact::value; -}; - -// This trait checks if JSONSerializer::from_json(json const&) exists -// this overload is used for non-default-constructible user-defined-types -template -struct has_non_default_from_json : std::false_type {}; - -template -struct has_non_default_from_json < BasicJsonType, T, enable_if_t < !is_basic_json::value >> -{ - using serializer = typename BasicJsonType::template json_serializer; - - static constexpr bool value = - is_detected_exact::value; -}; - -// This trait checks if BasicJsonType::json_serializer::to_json exists -// Do not evaluate the trait when T is a basic_json type, to avoid template instantiation infinite recursion. -template -struct has_to_json : std::false_type {}; - -template -struct has_to_json < BasicJsonType, T, enable_if_t < !is_basic_json::value >> -{ - using serializer = typename BasicJsonType::template json_serializer; - - static constexpr bool value = - is_detected_exact::value; -}; - - -/////////////////// -// is_ functions // -/////////////////// - -template -struct is_iterator_traits : std::false_type {}; - -template -struct is_iterator_traits> -{ - private: - using traits = iterator_traits; - - public: - static constexpr auto value = - is_detected::value && - is_detected::value && - is_detected::value && - is_detected::value && - is_detected::value; -}; - -// source: https://stackoverflow.com/a/37193089/4116453 - -template -struct is_complete_type : std::false_type {}; - -template -struct is_complete_type : std::true_type {}; - -template -struct is_compatible_object_type_impl : std::false_type {}; - -template -struct is_compatible_object_type_impl < - BasicJsonType, CompatibleObjectType, - enable_if_t < is_detected::value&& - is_detected::value >> -{ - - using object_t = typename BasicJsonType::object_t; - - // macOS's is_constructible does not play well with nonesuch... 
- static constexpr bool value = - std::is_constructible::value && - std::is_constructible::value; -}; - -template -struct is_compatible_object_type - : is_compatible_object_type_impl {}; - -template -struct is_constructible_object_type_impl : std::false_type {}; - -template -struct is_constructible_object_type_impl < - BasicJsonType, ConstructibleObjectType, - enable_if_t < is_detected::value&& - is_detected::value >> -{ - using object_t = typename BasicJsonType::object_t; - - static constexpr bool value = - (std::is_default_constructible::value && - (std::is_move_assignable::value || - std::is_copy_assignable::value) && - (std::is_constructible::value && - std::is_same < - typename object_t::mapped_type, - typename ConstructibleObjectType::mapped_type >::value)) || - (has_from_json::value || - has_non_default_from_json < - BasicJsonType, - typename ConstructibleObjectType::mapped_type >::value); -}; - -template -struct is_constructible_object_type - : is_constructible_object_type_impl {}; - -template -struct is_compatible_string_type_impl : std::false_type {}; - -template -struct is_compatible_string_type_impl < - BasicJsonType, CompatibleStringType, - enable_if_t::value >> -{ - static constexpr auto value = - std::is_constructible::value; -}; - -template -struct is_compatible_string_type - : is_compatible_string_type_impl {}; - -template -struct is_constructible_string_type_impl : std::false_type {}; - -template -struct is_constructible_string_type_impl < - BasicJsonType, ConstructibleStringType, - enable_if_t::value >> -{ - static constexpr auto value = - std::is_constructible::value; -}; - -template -struct is_constructible_string_type - : is_constructible_string_type_impl {}; - -template -struct is_compatible_array_type_impl : std::false_type {}; - -template -struct is_compatible_array_type_impl < - BasicJsonType, CompatibleArrayType, - enable_if_t < is_detected::value&& - is_detected::value&& -// This is needed because json_reverse_iterator has a ::iterator 
type... -// Therefore it is detected as a CompatibleArrayType. -// The real fix would be to have an Iterable concept. - !is_iterator_traits < - iterator_traits>::value >> -{ - static constexpr bool value = - std::is_constructible::value; -}; - -template -struct is_compatible_array_type - : is_compatible_array_type_impl {}; - -template -struct is_constructible_array_type_impl : std::false_type {}; - -template -struct is_constructible_array_type_impl < - BasicJsonType, ConstructibleArrayType, - enable_if_t::value >> - : std::true_type {}; - -template -struct is_constructible_array_type_impl < - BasicJsonType, ConstructibleArrayType, - enable_if_t < !std::is_same::value&& - std::is_default_constructible::value&& -(std::is_move_assignable::value || - std::is_copy_assignable::value)&& -is_detected::value&& -is_detected::value&& -is_complete_type < -detected_t>::value >> -{ - static constexpr bool value = - // This is needed because json_reverse_iterator has a ::iterator type, - // furthermore, std::back_insert_iterator (and other iterators) have a - // base class `iterator`... Therefore it is detected as a - // ConstructibleArrayType. The real fix would be to have an Iterable - // concept. - !is_iterator_traits>::value && - - (std::is_same::value || - has_from_json::value || - has_non_default_from_json < - BasicJsonType, typename ConstructibleArrayType::value_type >::value); -}; - -template -struct is_constructible_array_type - : is_constructible_array_type_impl {}; - -template -struct is_compatible_integer_type_impl : std::false_type {}; - -template -struct is_compatible_integer_type_impl < - RealIntegerType, CompatibleNumberIntegerType, - enable_if_t < std::is_integral::value&& - std::is_integral::value&& - !std::is_same::value >> -{ - // is there an assert somewhere on overflows? 
- using RealLimits = std::numeric_limits; - using CompatibleLimits = std::numeric_limits; - - static constexpr auto value = - std::is_constructible::value && - CompatibleLimits::is_integer && - RealLimits::is_signed == CompatibleLimits::is_signed; -}; - -template -struct is_compatible_integer_type - : is_compatible_integer_type_impl {}; - -template -struct is_compatible_type_impl: std::false_type {}; - -template -struct is_compatible_type_impl < - BasicJsonType, CompatibleType, - enable_if_t::value >> -{ - static constexpr bool value = - has_to_json::value; -}; - -template -struct is_compatible_type - : is_compatible_type_impl {}; - -// https://en.cppreference.com/w/cpp/types/conjunction -template struct conjunction : std::true_type { }; -template struct conjunction : B1 { }; -template -struct conjunction -: std::conditional, B1>::type {}; - -template -struct is_constructible_tuple : std::false_type {}; - -template -struct is_constructible_tuple> : conjunction...> {}; -} // namespace detail -} // namespace nlohmann - -// #include - - -#include // array -#include // size_t -#include // uint8_t -#include // string - -namespace nlohmann -{ -namespace detail -{ -/////////////////////////// -// JSON type enumeration // -/////////////////////////// - -/*! -@brief the JSON type enumeration - -This enumeration collects the different JSON types. It is internally used to -distinguish the stored values, and the functions @ref basic_json::is_null(), -@ref basic_json::is_object(), @ref basic_json::is_array(), -@ref basic_json::is_string(), @ref basic_json::is_boolean(), -@ref basic_json::is_number() (with @ref basic_json::is_number_integer(), -@ref basic_json::is_number_unsigned(), and @ref basic_json::is_number_float()), -@ref basic_json::is_discarded(), @ref basic_json::is_primitive(), and -@ref basic_json::is_structured() rely on it. 
- -@note There are three enumeration entries (number_integer, number_unsigned, and -number_float), because the library distinguishes these three types for numbers: -@ref basic_json::number_unsigned_t is used for unsigned integers, -@ref basic_json::number_integer_t is used for signed integers, and -@ref basic_json::number_float_t is used for floating-point numbers or to -approximate integers which do not fit in the limits of their respective type. - -@sa @ref basic_json::basic_json(const value_t value_type) -- create a JSON -value with the default value for a given type - -@since version 1.0.0 -*/ -enum class value_t : std::uint8_t -{ - null, ///< null value - object, ///< object (unordered set of name/value pairs) - array, ///< array (ordered collection of values) - string, ///< string value - boolean, ///< boolean value - number_integer, ///< number value (signed integer) - number_unsigned, ///< number value (unsigned integer) - number_float, ///< number value (floating-point) - binary, ///< binary array (ordered collection of bytes) - discarded ///< discarded by the parser callback function -}; - -/*! -@brief comparison operator for JSON types - -Returns an ordering that is similar to Python: -- order: null < boolean < number < object < array < string < binary -- furthermore, each type is not smaller than itself -- discarded values are not comparable -- binary is represented as a b"" string in python and directly comparable to a - string; however, making a binary array directly comparable with a string would - be surprising behavior in a JSON file. 
- -@since version 1.0.0 -*/ -inline bool operator<(const value_t lhs, const value_t rhs) noexcept -{ - static constexpr std::array order = {{ - 0 /* null */, 3 /* object */, 4 /* array */, 5 /* string */, - 1 /* boolean */, 2 /* integer */, 2 /* unsigned */, 2 /* float */, - 6 /* binary */ - } - }; - - const auto l_index = static_cast(lhs); - const auto r_index = static_cast(rhs); - return l_index < order.size() && r_index < order.size() && order[l_index] < order[r_index]; -} -} // namespace detail -} // namespace nlohmann - - -namespace nlohmann -{ -namespace detail -{ -template -void from_json(const BasicJsonType& j, typename std::nullptr_t& n) -{ - if (JSON_HEDLEY_UNLIKELY(!j.is_null())) - { - JSON_THROW(type_error::create(302, "type must be null, but is " + std::string(j.type_name()))); - } - n = nullptr; -} - -// overloads for basic_json template parameters -template < typename BasicJsonType, typename ArithmeticType, - enable_if_t < std::is_arithmetic::value&& - !std::is_same::value, - int > = 0 > -void get_arithmetic_value(const BasicJsonType& j, ArithmeticType& val) -{ - switch (static_cast(j)) - { - case value_t::number_unsigned: - { - val = static_cast(*j.template get_ptr()); - break; - } - case value_t::number_integer: - { - val = static_cast(*j.template get_ptr()); - break; - } - case value_t::number_float: - { - val = static_cast(*j.template get_ptr()); - break; - } - - default: - JSON_THROW(type_error::create(302, "type must be number, but is " + std::string(j.type_name()))); - } -} - -template -void from_json(const BasicJsonType& j, typename BasicJsonType::boolean_t& b) -{ - if (JSON_HEDLEY_UNLIKELY(!j.is_boolean())) - { - JSON_THROW(type_error::create(302, "type must be boolean, but is " + std::string(j.type_name()))); - } - b = *j.template get_ptr(); -} - -template -void from_json(const BasicJsonType& j, typename BasicJsonType::string_t& s) -{ - if (JSON_HEDLEY_UNLIKELY(!j.is_string())) - { - JSON_THROW(type_error::create(302, "type must be string, 
but is " + std::string(j.type_name()))); - } - s = *j.template get_ptr(); -} - -template < - typename BasicJsonType, typename ConstructibleStringType, - enable_if_t < - is_constructible_string_type::value&& - !std::is_same::value, - int > = 0 > -void from_json(const BasicJsonType& j, ConstructibleStringType& s) -{ - if (JSON_HEDLEY_UNLIKELY(!j.is_string())) - { - JSON_THROW(type_error::create(302, "type must be string, but is " + std::string(j.type_name()))); - } - - s = *j.template get_ptr(); -} - -template -void from_json(const BasicJsonType& j, typename BasicJsonType::number_float_t& val) -{ - get_arithmetic_value(j, val); -} - -template -void from_json(const BasicJsonType& j, typename BasicJsonType::number_unsigned_t& val) -{ - get_arithmetic_value(j, val); -} - -template -void from_json(const BasicJsonType& j, typename BasicJsonType::number_integer_t& val) -{ - get_arithmetic_value(j, val); -} - -template::value, int> = 0> -void from_json(const BasicJsonType& j, EnumType& e) -{ - typename std::underlying_type::type val; - get_arithmetic_value(j, val); - e = static_cast(val); -} - -// forward_list doesn't have an insert method -template::value, int> = 0> -void from_json(const BasicJsonType& j, std::forward_list& l) -{ - if (JSON_HEDLEY_UNLIKELY(!j.is_array())) - { - JSON_THROW(type_error::create(302, "type must be array, but is " + std::string(j.type_name()))); - } - l.clear(); - std::transform(j.rbegin(), j.rend(), - std::front_inserter(l), [](const BasicJsonType & i) - { - return i.template get(); - }); -} - -// valarray doesn't have an insert method -template::value, int> = 0> -void from_json(const BasicJsonType& j, std::valarray& l) -{ - if (JSON_HEDLEY_UNLIKELY(!j.is_array())) - { - JSON_THROW(type_error::create(302, "type must be array, but is " + std::string(j.type_name()))); - } - l.resize(j.size()); - std::transform(j.begin(), j.end(), std::begin(l), - [](const BasicJsonType & elem) - { - return elem.template get(); - }); -} - -template -auto 
from_json(const BasicJsonType& j, T (&arr)[N]) --> decltype(j.template get(), void()) -{ - for (std::size_t i = 0; i < N; ++i) - { - arr[i] = j.at(i).template get(); - } -} - -template -void from_json_array_impl(const BasicJsonType& j, typename BasicJsonType::array_t& arr, priority_tag<3> /*unused*/) -{ - arr = *j.template get_ptr(); -} - -template -auto from_json_array_impl(const BasicJsonType& j, std::array& arr, - priority_tag<2> /*unused*/) --> decltype(j.template get(), void()) -{ - for (std::size_t i = 0; i < N; ++i) - { - arr[i] = j.at(i).template get(); - } -} - -template -auto from_json_array_impl(const BasicJsonType& j, ConstructibleArrayType& arr, priority_tag<1> /*unused*/) --> decltype( - arr.reserve(std::declval()), - j.template get(), - void()) -{ - using std::end; - - ConstructibleArrayType ret; - ret.reserve(j.size()); - std::transform(j.begin(), j.end(), - std::inserter(ret, end(ret)), [](const BasicJsonType & i) - { - // get() returns *this, this won't call a from_json - // method when value_type is BasicJsonType - return i.template get(); - }); - arr = std::move(ret); -} - -template -void from_json_array_impl(const BasicJsonType& j, ConstructibleArrayType& arr, - priority_tag<0> /*unused*/) -{ - using std::end; - - ConstructibleArrayType ret; - std::transform( - j.begin(), j.end(), std::inserter(ret, end(ret)), - [](const BasicJsonType & i) - { - // get() returns *this, this won't call a from_json - // method when value_type is BasicJsonType - return i.template get(); - }); - arr = std::move(ret); -} - -template < typename BasicJsonType, typename ConstructibleArrayType, - enable_if_t < - is_constructible_array_type::value&& - !is_constructible_object_type::value&& - !is_constructible_string_type::value&& - !std::is_same::value&& - !is_basic_json::value, - int > = 0 > -auto from_json(const BasicJsonType& j, ConstructibleArrayType& arr) --> decltype(from_json_array_impl(j, arr, priority_tag<3> {}), -j.template get(), -void()) -{ - if 
(JSON_HEDLEY_UNLIKELY(!j.is_array())) - { - JSON_THROW(type_error::create(302, "type must be array, but is " + - std::string(j.type_name()))); - } - - from_json_array_impl(j, arr, priority_tag<3> {}); -} - -template -void from_json(const BasicJsonType& j, typename BasicJsonType::binary_t& bin) -{ - if (JSON_HEDLEY_UNLIKELY(!j.is_binary())) - { - JSON_THROW(type_error::create(302, "type must be binary, but is " + std::string(j.type_name()))); - } - - bin = *j.template get_ptr(); -} - -template::value, int> = 0> -void from_json(const BasicJsonType& j, ConstructibleObjectType& obj) -{ - if (JSON_HEDLEY_UNLIKELY(!j.is_object())) - { - JSON_THROW(type_error::create(302, "type must be object, but is " + std::string(j.type_name()))); - } - - ConstructibleObjectType ret; - auto inner_object = j.template get_ptr(); - using value_type = typename ConstructibleObjectType::value_type; - std::transform( - inner_object->begin(), inner_object->end(), - std::inserter(ret, ret.begin()), - [](typename BasicJsonType::object_t::value_type const & p) - { - return value_type(p.first, p.second.template get()); - }); - obj = std::move(ret); -} - -// overload for arithmetic types, not chosen for basic_json template arguments -// (BooleanType, etc..); note: Is it really necessary to provide explicit -// overloads for boolean_t etc. in case of a custom BooleanType which is not -// an arithmetic type? 
-template < typename BasicJsonType, typename ArithmeticType, - enable_if_t < - std::is_arithmetic::value&& - !std::is_same::value&& - !std::is_same::value&& - !std::is_same::value&& - !std::is_same::value, - int > = 0 > -void from_json(const BasicJsonType& j, ArithmeticType& val) -{ - switch (static_cast(j)) - { - case value_t::number_unsigned: - { - val = static_cast(*j.template get_ptr()); - break; - } - case value_t::number_integer: - { - val = static_cast(*j.template get_ptr()); - break; - } - case value_t::number_float: - { - val = static_cast(*j.template get_ptr()); - break; - } - case value_t::boolean: - { - val = static_cast(*j.template get_ptr()); - break; - } - - default: - JSON_THROW(type_error::create(302, "type must be number, but is " + std::string(j.type_name()))); - } -} - -template -void from_json(const BasicJsonType& j, std::pair& p) -{ - p = {j.at(0).template get(), j.at(1).template get()}; -} - -template -void from_json_tuple_impl(const BasicJsonType& j, Tuple& t, index_sequence /*unused*/) -{ - t = std::make_tuple(j.at(Idx).template get::type>()...); -} - -template -void from_json(const BasicJsonType& j, std::tuple& t) -{ - from_json_tuple_impl(j, t, index_sequence_for {}); -} - -template < typename BasicJsonType, typename Key, typename Value, typename Compare, typename Allocator, - typename = enable_if_t < !std::is_constructible < - typename BasicJsonType::string_t, Key >::value >> -void from_json(const BasicJsonType& j, std::map& m) -{ - if (JSON_HEDLEY_UNLIKELY(!j.is_array())) - { - JSON_THROW(type_error::create(302, "type must be array, but is " + std::string(j.type_name()))); - } - m.clear(); - for (const auto& p : j) - { - if (JSON_HEDLEY_UNLIKELY(!p.is_array())) - { - JSON_THROW(type_error::create(302, "type must be array, but is " + std::string(p.type_name()))); - } - m.emplace(p.at(0).template get(), p.at(1).template get()); - } -} - -template < typename BasicJsonType, typename Key, typename Value, typename Hash, typename KeyEqual, 
typename Allocator, - typename = enable_if_t < !std::is_constructible < - typename BasicJsonType::string_t, Key >::value >> -void from_json(const BasicJsonType& j, std::unordered_map& m) -{ - if (JSON_HEDLEY_UNLIKELY(!j.is_array())) - { - JSON_THROW(type_error::create(302, "type must be array, but is " + std::string(j.type_name()))); - } - m.clear(); - for (const auto& p : j) - { - if (JSON_HEDLEY_UNLIKELY(!p.is_array())) - { - JSON_THROW(type_error::create(302, "type must be array, but is " + std::string(p.type_name()))); - } - m.emplace(p.at(0).template get(), p.at(1).template get()); - } -} - -struct from_json_fn -{ - template - auto operator()(const BasicJsonType& j, T& val) const - noexcept(noexcept(from_json(j, val))) - -> decltype(from_json(j, val), void()) - { - return from_json(j, val); - } -}; -} // namespace detail - -/// namespace to hold default `from_json` function -/// to see why this is required: -/// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2015/n4381.html -namespace -{ -constexpr const auto& from_json = detail::static_const::value; -} // namespace -} // namespace nlohmann - -// #include - - -#include // copy -#include // begin, end -#include // string -#include // tuple, get -#include // is_same, is_constructible, is_floating_point, is_enum, underlying_type -#include // move, forward, declval, pair -#include // valarray -#include // vector - -// #include - - -#include // size_t -#include // input_iterator_tag -#include // string, to_string -#include // tuple_size, get, tuple_element - -// #include - -// #include - - -namespace nlohmann -{ -namespace detail -{ -template -void int_to_string( string_type& target, std::size_t value ) -{ - // For ADL - using std::to_string; - target = to_string(value); -} -template class iteration_proxy_value -{ - public: - using difference_type = std::ptrdiff_t; - using value_type = iteration_proxy_value; - using pointer = value_type * ; - using reference = value_type & ; - using iterator_category = 
std::input_iterator_tag; - using string_type = typename std::remove_cv< typename std::remove_reference().key() ) >::type >::type; - - private: - /// the iterator - IteratorType anchor; - /// an index for arrays (used to create key names) - std::size_t array_index = 0; - /// last stringified array index - mutable std::size_t array_index_last = 0; - /// a string representation of the array index - mutable string_type array_index_str = "0"; - /// an empty string (to return a reference for primitive values) - const string_type empty_str = ""; - - public: - explicit iteration_proxy_value(IteratorType it) noexcept : anchor(it) {} - - /// dereference operator (needed for range-based for) - iteration_proxy_value& operator*() - { - return *this; - } - - /// increment operator (needed for range-based for) - iteration_proxy_value& operator++() - { - ++anchor; - ++array_index; - - return *this; - } - - /// equality operator (needed for InputIterator) - bool operator==(const iteration_proxy_value& o) const - { - return anchor == o.anchor; - } - - /// inequality operator (needed for range-based for) - bool operator!=(const iteration_proxy_value& o) const - { - return anchor != o.anchor; - } - - /// return key of the iterator - const string_type& key() const - { - JSON_ASSERT(anchor.m_object != nullptr); - - switch (anchor.m_object->type()) - { - // use integer array index as key - case value_t::array: - { - if (array_index != array_index_last) - { - int_to_string( array_index_str, array_index ); - array_index_last = array_index; - } - return array_index_str; - } - - // use key from the object - case value_t::object: - return anchor.key(); - - // use an empty key for all primitive types - default: - return empty_str; - } - } - - /// return value of the iterator - typename IteratorType::reference value() const - { - return anchor.value(); - } -}; - -/// proxy class for the items() function -template class iteration_proxy -{ - private: - /// the container to iterate - typename 
IteratorType::reference container; - - public: - /// construct iteration proxy from a container - explicit iteration_proxy(typename IteratorType::reference cont) noexcept - : container(cont) {} - - /// return iterator begin (needed for range-based for) - iteration_proxy_value begin() noexcept - { - return iteration_proxy_value(container.begin()); - } - - /// return iterator end (needed for range-based for) - iteration_proxy_value end() noexcept - { - return iteration_proxy_value(container.end()); - } -}; -// Structured Bindings Support -// For further reference see https://blog.tartanllama.xyz/structured-bindings/ -// And see https://github.com/nlohmann/json/pull/1391 -template = 0> -auto get(const nlohmann::detail::iteration_proxy_value& i) -> decltype(i.key()) -{ - return i.key(); -} -// Structured Bindings Support -// For further reference see https://blog.tartanllama.xyz/structured-bindings/ -// And see https://github.com/nlohmann/json/pull/1391 -template = 0> -auto get(const nlohmann::detail::iteration_proxy_value& i) -> decltype(i.value()) -{ - return i.value(); -} -} // namespace detail -} // namespace nlohmann - -// The Addition to the STD Namespace is required to add -// Structured Bindings Support to the iteration_proxy_value class -// For further reference see https://blog.tartanllama.xyz/structured-bindings/ -// And see https://github.com/nlohmann/json/pull/1391 -namespace std -{ -#if defined(__clang__) - // Fix: https://github.com/nlohmann/json/issues/1401 - #pragma clang diagnostic push - #pragma clang diagnostic ignored "-Wmismatched-tags" -#endif -template -class tuple_size<::nlohmann::detail::iteration_proxy_value> - : public std::integral_constant {}; - -template -class tuple_element> -{ - public: - using type = decltype( - get(std::declval < - ::nlohmann::detail::iteration_proxy_value> ())); -}; -#if defined(__clang__) - #pragma clang diagnostic pop -#endif -} // namespace std - -// #include - -// #include - -// #include - - -namespace nlohmann 
-{ -namespace detail -{ -////////////////// -// constructors // -////////////////// - -template struct external_constructor; - -template<> -struct external_constructor -{ - template - static void construct(BasicJsonType& j, typename BasicJsonType::boolean_t b) noexcept - { - j.m_type = value_t::boolean; - j.m_value = b; - j.assert_invariant(); - } -}; - -template<> -struct external_constructor -{ - template - static void construct(BasicJsonType& j, const typename BasicJsonType::string_t& s) - { - j.m_type = value_t::string; - j.m_value = s; - j.assert_invariant(); - } - - template - static void construct(BasicJsonType& j, typename BasicJsonType::string_t&& s) - { - j.m_type = value_t::string; - j.m_value = std::move(s); - j.assert_invariant(); - } - - template < typename BasicJsonType, typename CompatibleStringType, - enable_if_t < !std::is_same::value, - int > = 0 > - static void construct(BasicJsonType& j, const CompatibleStringType& str) - { - j.m_type = value_t::string; - j.m_value.string = j.template create(str); - j.assert_invariant(); - } -}; - -template<> -struct external_constructor -{ - template - static void construct(BasicJsonType& j, const typename BasicJsonType::binary_t& b) - { - j.m_type = value_t::binary; - typename BasicJsonType::binary_t value{b}; - j.m_value = value; - j.assert_invariant(); - } - - template - static void construct(BasicJsonType& j, typename BasicJsonType::binary_t&& b) - { - j.m_type = value_t::binary; - typename BasicJsonType::binary_t value{std::move(b)}; - j.m_value = value; - j.assert_invariant(); - } -}; - -template<> -struct external_constructor -{ - template - static void construct(BasicJsonType& j, typename BasicJsonType::number_float_t val) noexcept - { - j.m_type = value_t::number_float; - j.m_value = val; - j.assert_invariant(); - } -}; - -template<> -struct external_constructor -{ - template - static void construct(BasicJsonType& j, typename BasicJsonType::number_unsigned_t val) noexcept - { - j.m_type = 
value_t::number_unsigned; - j.m_value = val; - j.assert_invariant(); - } -}; - -template<> -struct external_constructor -{ - template - static void construct(BasicJsonType& j, typename BasicJsonType::number_integer_t val) noexcept - { - j.m_type = value_t::number_integer; - j.m_value = val; - j.assert_invariant(); - } -}; - -template<> -struct external_constructor -{ - template - static void construct(BasicJsonType& j, const typename BasicJsonType::array_t& arr) - { - j.m_type = value_t::array; - j.m_value = arr; - j.assert_invariant(); - } - - template - static void construct(BasicJsonType& j, typename BasicJsonType::array_t&& arr) - { - j.m_type = value_t::array; - j.m_value = std::move(arr); - j.assert_invariant(); - } - - template < typename BasicJsonType, typename CompatibleArrayType, - enable_if_t < !std::is_same::value, - int > = 0 > - static void construct(BasicJsonType& j, const CompatibleArrayType& arr) - { - using std::begin; - using std::end; - j.m_type = value_t::array; - j.m_value.array = j.template create(begin(arr), end(arr)); - j.assert_invariant(); - } - - template - static void construct(BasicJsonType& j, const std::vector& arr) - { - j.m_type = value_t::array; - j.m_value = value_t::array; - j.m_value.array->reserve(arr.size()); - for (const bool x : arr) - { - j.m_value.array->push_back(x); - } - j.assert_invariant(); - } - - template::value, int> = 0> - static void construct(BasicJsonType& j, const std::valarray& arr) - { - j.m_type = value_t::array; - j.m_value = value_t::array; - j.m_value.array->resize(arr.size()); - if (arr.size() > 0) - { - std::copy(std::begin(arr), std::end(arr), j.m_value.array->begin()); - } - j.assert_invariant(); - } -}; - -template<> -struct external_constructor -{ - template - static void construct(BasicJsonType& j, const typename BasicJsonType::object_t& obj) - { - j.m_type = value_t::object; - j.m_value = obj; - j.assert_invariant(); - } - - template - static void construct(BasicJsonType& j, typename 
BasicJsonType::object_t&& obj) - { - j.m_type = value_t::object; - j.m_value = std::move(obj); - j.assert_invariant(); - } - - template < typename BasicJsonType, typename CompatibleObjectType, - enable_if_t < !std::is_same::value, int > = 0 > - static void construct(BasicJsonType& j, const CompatibleObjectType& obj) - { - using std::begin; - using std::end; - - j.m_type = value_t::object; - j.m_value.object = j.template create(begin(obj), end(obj)); - j.assert_invariant(); - } -}; - -///////////// -// to_json // -///////////// - -template::value, int> = 0> -void to_json(BasicJsonType& j, T b) noexcept -{ - external_constructor::construct(j, b); -} - -template::value, int> = 0> -void to_json(BasicJsonType& j, const CompatibleString& s) -{ - external_constructor::construct(j, s); -} - -template -void to_json(BasicJsonType& j, typename BasicJsonType::string_t&& s) -{ - external_constructor::construct(j, std::move(s)); -} - -template::value, int> = 0> -void to_json(BasicJsonType& j, FloatType val) noexcept -{ - external_constructor::construct(j, static_cast(val)); -} - -template::value, int> = 0> -void to_json(BasicJsonType& j, CompatibleNumberUnsignedType val) noexcept -{ - external_constructor::construct(j, static_cast(val)); -} - -template::value, int> = 0> -void to_json(BasicJsonType& j, CompatibleNumberIntegerType val) noexcept -{ - external_constructor::construct(j, static_cast(val)); -} - -template::value, int> = 0> -void to_json(BasicJsonType& j, EnumType e) noexcept -{ - using underlying_type = typename std::underlying_type::type; - external_constructor::construct(j, static_cast(e)); -} - -template -void to_json(BasicJsonType& j, const std::vector& e) -{ - external_constructor::construct(j, e); -} - -template < typename BasicJsonType, typename CompatibleArrayType, - enable_if_t < is_compatible_array_type::value&& - !is_compatible_object_type::value&& - !is_compatible_string_type::value&& - !std::is_same::value&& - !is_basic_json::value, - int > = 0 > -void 
to_json(BasicJsonType& j, const CompatibleArrayType& arr) -{ - external_constructor::construct(j, arr); -} - -template -void to_json(BasicJsonType& j, const typename BasicJsonType::binary_t& bin) -{ - external_constructor::construct(j, bin); -} - -template::value, int> = 0> -void to_json(BasicJsonType& j, const std::valarray& arr) -{ - external_constructor::construct(j, std::move(arr)); -} - -template -void to_json(BasicJsonType& j, typename BasicJsonType::array_t&& arr) -{ - external_constructor::construct(j, std::move(arr)); -} - -template < typename BasicJsonType, typename CompatibleObjectType, - enable_if_t < is_compatible_object_type::value&& !is_basic_json::value, int > = 0 > -void to_json(BasicJsonType& j, const CompatibleObjectType& obj) -{ - external_constructor::construct(j, obj); -} - -template -void to_json(BasicJsonType& j, typename BasicJsonType::object_t&& obj) -{ - external_constructor::construct(j, std::move(obj)); -} - -template < - typename BasicJsonType, typename T, std::size_t N, - enable_if_t < !std::is_constructible::value, - int > = 0 > -void to_json(BasicJsonType& j, const T(&arr)[N]) -{ - external_constructor::construct(j, arr); -} - -template < typename BasicJsonType, typename T1, typename T2, enable_if_t < std::is_constructible::value&& std::is_constructible::value, int > = 0 > -void to_json(BasicJsonType& j, const std::pair& p) -{ - j = { p.first, p.second }; -} - -// for https://github.com/nlohmann/json/pull/1134 -template>::value, int> = 0> -void to_json(BasicJsonType& j, const T& b) -{ - j = { {b.key(), b.value()} }; -} - -template -void to_json_tuple_impl(BasicJsonType& j, const Tuple& t, index_sequence /*unused*/) -{ - j = { std::get(t)... 
}; -} - -template::value, int > = 0> -void to_json(BasicJsonType& j, const T& t) -{ - to_json_tuple_impl(j, t, make_index_sequence::value> {}); -} - -struct to_json_fn -{ - template - auto operator()(BasicJsonType& j, T&& val) const noexcept(noexcept(to_json(j, std::forward(val)))) - -> decltype(to_json(j, std::forward(val)), void()) - { - return to_json(j, std::forward(val)); - } -}; -} // namespace detail - -/// namespace to hold default `to_json` function -namespace -{ -constexpr const auto& to_json = detail::static_const::value; -} // namespace -} // namespace nlohmann - - -namespace nlohmann -{ - -template -struct adl_serializer -{ - /*! - @brief convert a JSON value to any value type - - This function is usually called by the `get()` function of the - @ref basic_json class (either explicit or via conversion operators). - - @param[in] j JSON value to read from - @param[in,out] val value to write to - */ - template - static auto from_json(BasicJsonType&& j, ValueType& val) noexcept( - noexcept(::nlohmann::from_json(std::forward(j), val))) - -> decltype(::nlohmann::from_json(std::forward(j), val), void()) - { - ::nlohmann::from_json(std::forward(j), val); - } - - /*! - @brief convert any value type to a JSON value - - This function is usually called by the constructors of the @ref basic_json - class. - - @param[in,out] j JSON value to write to - @param[in] val value to read from - */ - template - static auto to_json(BasicJsonType& j, ValueType&& val) noexcept( - noexcept(::nlohmann::to_json(j, std::forward(val)))) - -> decltype(::nlohmann::to_json(j, std::forward(val)), void()) - { - ::nlohmann::to_json(j, std::forward(val)); - } -}; - -} // namespace nlohmann - -// #include - - -#include // uint8_t -#include // tie -#include // move - -namespace nlohmann -{ - -/*! -@brief an internal type for a backed binary type - -This type extends the template parameter @a BinaryType provided to `basic_json` -with a subtype used by BSON and MessagePack. 
This type exists so that the user -does not have to specify a type themselves with a specific naming scheme in -order to override the binary type. - -@tparam BinaryType container to store bytes (`std::vector` by - default) - -@since version 3.8.0 -*/ -template -class byte_container_with_subtype : public BinaryType -{ - public: - /// the type of the underlying container - using container_type = BinaryType; - - byte_container_with_subtype() noexcept(noexcept(container_type())) - : container_type() - {} - - byte_container_with_subtype(const container_type& b) noexcept(noexcept(container_type(b))) - : container_type(b) - {} - - byte_container_with_subtype(container_type&& b) noexcept(noexcept(container_type(std::move(b)))) - : container_type(std::move(b)) - {} - - byte_container_with_subtype(const container_type& b, std::uint8_t subtype) noexcept(noexcept(container_type(b))) - : container_type(b) - , m_subtype(subtype) - , m_has_subtype(true) - {} - - byte_container_with_subtype(container_type&& b, std::uint8_t subtype) noexcept(noexcept(container_type(std::move(b)))) - : container_type(std::move(b)) - , m_subtype(subtype) - , m_has_subtype(true) - {} - - bool operator==(const byte_container_with_subtype& rhs) const - { - return std::tie(static_cast(*this), m_subtype, m_has_subtype) == - std::tie(static_cast(rhs), rhs.m_subtype, rhs.m_has_subtype); - } - - bool operator!=(const byte_container_with_subtype& rhs) const - { - return !(rhs == *this); - } - - /*! - @brief sets the binary subtype - - Sets the binary subtype of the value, also flags a binary JSON value as - having a subtype, which has implications for serialization. - - @complexity Constant. - - @exceptionsafety No-throw guarantee: this member function never throws - exceptions. 
- - @sa @ref subtype() -- return the binary subtype - @sa @ref clear_subtype() -- clears the binary subtype - @sa @ref has_subtype() -- returns whether or not the binary value has a - subtype - - @since version 3.8.0 - */ - void set_subtype(std::uint8_t subtype) noexcept - { - m_subtype = subtype; - m_has_subtype = true; - } - - /*! - @brief return the binary subtype - - Returns the numerical subtype of the value if it has a subtype. If it does - not have a subtype, this function will return size_t(-1) as a sentinel - value. - - @return the numerical subtype of the binary value - - @complexity Constant. - - @exceptionsafety No-throw guarantee: this member function never throws - exceptions. - - @sa @ref set_subtype() -- sets the binary subtype - @sa @ref clear_subtype() -- clears the binary subtype - @sa @ref has_subtype() -- returns whether or not the binary value has a - subtype - - @since version 3.8.0 - */ - constexpr std::uint8_t subtype() const noexcept - { - return m_subtype; - } - - /*! - @brief return whether the value has a subtype - - @return whether the value has a subtype - - @complexity Constant. - - @exceptionsafety No-throw guarantee: this member function never throws - exceptions. - - @sa @ref subtype() -- return the binary subtype - @sa @ref set_subtype() -- sets the binary subtype - @sa @ref clear_subtype() -- clears the binary subtype - - @since version 3.8.0 - */ - constexpr bool has_subtype() const noexcept - { - return m_has_subtype; - } - - /*! - @brief clears the binary subtype - - Clears the binary subtype and flags the value as not having a subtype, which - has implications for serialization; for instance MessagePack will prefer the - bin family over the ext family. - - @complexity Constant. - - @exceptionsafety No-throw guarantee: this member function never throws - exceptions. 
- - @sa @ref subtype() -- return the binary subtype - @sa @ref set_subtype() -- sets the binary subtype - @sa @ref has_subtype() -- returns whether or not the binary value has a - subtype - - @since version 3.8.0 - */ - void clear_subtype() noexcept - { - m_subtype = 0; - m_has_subtype = false; - } - - private: - std::uint8_t m_subtype = 0; - bool m_has_subtype = false; -}; - -} // namespace nlohmann - -// #include - -// #include - -// #include - -// #include - - -#include // size_t, uint8_t -#include // hash - -namespace nlohmann -{ -namespace detail -{ - -// boost::hash_combine -inline std::size_t combine(std::size_t seed, std::size_t h) noexcept -{ - seed ^= h + 0x9e3779b9 + (seed << 6U) + (seed >> 2U); - return seed; -} - -/*! -@brief hash a JSON value - -The hash function tries to rely on std::hash where possible. Furthermore, the -type of the JSON value is taken into account to have different hash values for -null, 0, 0U, and false, etc. - -@tparam BasicJsonType basic_json specialization -@param j JSON value to hash -@return hash value of j -*/ -template -std::size_t hash(const BasicJsonType& j) -{ - using string_t = typename BasicJsonType::string_t; - using number_integer_t = typename BasicJsonType::number_integer_t; - using number_unsigned_t = typename BasicJsonType::number_unsigned_t; - using number_float_t = typename BasicJsonType::number_float_t; - - const auto type = static_cast(j.type()); - switch (j.type()) - { - case BasicJsonType::value_t::null: - case BasicJsonType::value_t::discarded: - { - return combine(type, 0); - } - - case BasicJsonType::value_t::object: - { - auto seed = combine(type, j.size()); - for (const auto& element : j.items()) - { - const auto h = std::hash {}(element.key()); - seed = combine(seed, h); - seed = combine(seed, hash(element.value())); - } - return seed; - } - - case BasicJsonType::value_t::array: - { - auto seed = combine(type, j.size()); - for (const auto& element : j) - { - seed = combine(seed, hash(element)); - } - 
return seed; - } - - case BasicJsonType::value_t::string: - { - const auto h = std::hash {}(j.template get_ref()); - return combine(type, h); - } - - case BasicJsonType::value_t::boolean: - { - const auto h = std::hash {}(j.template get()); - return combine(type, h); - } - - case BasicJsonType::value_t::number_integer: - { - const auto h = std::hash {}(j.template get()); - return combine(type, h); - } - - case nlohmann::detail::value_t::number_unsigned: - { - const auto h = std::hash {}(j.template get()); - return combine(type, h); - } - - case nlohmann::detail::value_t::number_float: - { - const auto h = std::hash {}(j.template get()); - return combine(type, h); - } - - case nlohmann::detail::value_t::binary: - { - auto seed = combine(type, j.get_binary().size()); - const auto h = std::hash {}(j.get_binary().has_subtype()); - seed = combine(seed, h); - seed = combine(seed, j.get_binary().subtype()); - for (const auto byte : j.get_binary()) - { - seed = combine(seed, std::hash {}(byte)); - } - return seed; - } - - default: // LCOV_EXCL_LINE - JSON_ASSERT(false); // LCOV_EXCL_LINE - } -} - -} // namespace detail -} // namespace nlohmann - -// #include - - -#include // generate_n -#include // array -#include // ldexp -#include // size_t -#include // uint8_t, uint16_t, uint32_t, uint64_t -#include // snprintf -#include // memcpy -#include // back_inserter -#include // numeric_limits -#include // char_traits, string -#include // make_pair, move - -// #include - -// #include - - -#include // array -#include // size_t -#include //FILE * -#include // strlen -#include // istream -#include // begin, end, iterator_traits, random_access_iterator_tag, distance, next -#include // shared_ptr, make_shared, addressof -#include // accumulate -#include // string, char_traits -#include // enable_if, is_base_of, is_pointer, is_integral, remove_pointer -#include // pair, declval - -// #include - -// #include - - -namespace nlohmann -{ -namespace detail -{ -/// the supported input 
formats -enum class input_format_t { json, cbor, msgpack, ubjson, bson }; - -//////////////////// -// input adapters // -//////////////////// - -/*! -Input adapter for stdio file access. This adapter read only 1 byte and do not use any - buffer. This adapter is a very low level adapter. -*/ -class file_input_adapter -{ - public: - using char_type = char; - - JSON_HEDLEY_NON_NULL(2) - explicit file_input_adapter(std::FILE* f) noexcept - : m_file(f) - {} - - // make class move-only - file_input_adapter(const file_input_adapter&) = delete; - file_input_adapter(file_input_adapter&&) = default; - file_input_adapter& operator=(const file_input_adapter&) = delete; - file_input_adapter& operator=(file_input_adapter&&) = delete; - - std::char_traits::int_type get_character() noexcept - { - return std::fgetc(m_file); - } - - private: - /// the file pointer to read from - std::FILE* m_file; -}; - - -/*! -Input adapter for a (caching) istream. Ignores a UFT Byte Order Mark at -beginning of input. Does not support changing the underlying std::streambuf -in mid-input. Maintains underlying std::istream and std::streambuf to support -subsequent use of standard std::istream operations to process any input -characters following those used in parsing the JSON input. Clears the -std::istream flags; any input errors (e.g., EOF) will be detected by the first -subsequent call for input from the std::istream. 
-*/ -class input_stream_adapter -{ - public: - using char_type = char; - - ~input_stream_adapter() - { - // clear stream flags; we use underlying streambuf I/O, do not - // maintain ifstream flags, except eof - if (is != nullptr) - { - is->clear(is->rdstate() & std::ios::eofbit); - } - } - - explicit input_stream_adapter(std::istream& i) - : is(&i), sb(i.rdbuf()) - {} - - // delete because of pointer members - input_stream_adapter(const input_stream_adapter&) = delete; - input_stream_adapter& operator=(input_stream_adapter&) = delete; - input_stream_adapter& operator=(input_stream_adapter&& rhs) = delete; - - input_stream_adapter(input_stream_adapter&& rhs) noexcept : is(rhs.is), sb(rhs.sb) - { - rhs.is = nullptr; - rhs.sb = nullptr; - } - - // std::istream/std::streambuf use std::char_traits::to_int_type, to - // ensure that std::char_traits::eof() and the character 0xFF do not - // end up as the same value, eg. 0xFFFFFFFF. - std::char_traits::int_type get_character() - { - auto res = sb->sbumpc(); - // set eof manually, as we don't use the istream interface. - if (JSON_HEDLEY_UNLIKELY(res == EOF)) - { - is->clear(is->rdstate() | std::ios::eofbit); - } - return res; - } - - private: - /// the associated input stream - std::istream* is = nullptr; - std::streambuf* sb = nullptr; -}; - -// General-purpose iterator-based adapter. It might not be as fast as -// theoretically possible for some containers, but it is extremely versatile. 
-template -class iterator_input_adapter -{ - public: - using char_type = typename std::iterator_traits::value_type; - - iterator_input_adapter(IteratorType first, IteratorType last) - : current(std::move(first)), end(std::move(last)) {} - - typename std::char_traits::int_type get_character() - { - if (JSON_HEDLEY_LIKELY(current != end)) - { - auto result = std::char_traits::to_int_type(*current); - std::advance(current, 1); - return result; - } - else - { - return std::char_traits::eof(); - } - } - - private: - IteratorType current; - IteratorType end; - - template - friend struct wide_string_input_helper; - - bool empty() const - { - return current == end; - } - -}; - - -template -struct wide_string_input_helper; - -template -struct wide_string_input_helper -{ - // UTF-32 - static void fill_buffer(BaseInputAdapter& input, - std::array::int_type, 4>& utf8_bytes, - size_t& utf8_bytes_index, - size_t& utf8_bytes_filled) - { - utf8_bytes_index = 0; - - if (JSON_HEDLEY_UNLIKELY(input.empty())) - { - utf8_bytes[0] = std::char_traits::eof(); - utf8_bytes_filled = 1; - } - else - { - // get the current character - const auto wc = input.get_character(); - - // UTF-32 to UTF-8 encoding - if (wc < 0x80) - { - utf8_bytes[0] = static_cast::int_type>(wc); - utf8_bytes_filled = 1; - } - else if (wc <= 0x7FF) - { - utf8_bytes[0] = static_cast::int_type>(0xC0u | ((static_cast(wc) >> 6u) & 0x1Fu)); - utf8_bytes[1] = static_cast::int_type>(0x80u | (static_cast(wc) & 0x3Fu)); - utf8_bytes_filled = 2; - } - else if (wc <= 0xFFFF) - { - utf8_bytes[0] = static_cast::int_type>(0xE0u | ((static_cast(wc) >> 12u) & 0x0Fu)); - utf8_bytes[1] = static_cast::int_type>(0x80u | ((static_cast(wc) >> 6u) & 0x3Fu)); - utf8_bytes[2] = static_cast::int_type>(0x80u | (static_cast(wc) & 0x3Fu)); - utf8_bytes_filled = 3; - } - else if (wc <= 0x10FFFF) - { - utf8_bytes[0] = static_cast::int_type>(0xF0u | ((static_cast(wc) >> 18u) & 0x07u)); - utf8_bytes[1] = static_cast::int_type>(0x80u | 
((static_cast(wc) >> 12u) & 0x3Fu)); - utf8_bytes[2] = static_cast::int_type>(0x80u | ((static_cast(wc) >> 6u) & 0x3Fu)); - utf8_bytes[3] = static_cast::int_type>(0x80u | (static_cast(wc) & 0x3Fu)); - utf8_bytes_filled = 4; - } - else - { - // unknown character - utf8_bytes[0] = static_cast::int_type>(wc); - utf8_bytes_filled = 1; - } - } - } -}; - -template -struct wide_string_input_helper -{ - // UTF-16 - static void fill_buffer(BaseInputAdapter& input, - std::array::int_type, 4>& utf8_bytes, - size_t& utf8_bytes_index, - size_t& utf8_bytes_filled) - { - utf8_bytes_index = 0; - - if (JSON_HEDLEY_UNLIKELY(input.empty())) - { - utf8_bytes[0] = std::char_traits::eof(); - utf8_bytes_filled = 1; - } - else - { - // get the current character - const auto wc = input.get_character(); - - // UTF-16 to UTF-8 encoding - if (wc < 0x80) - { - utf8_bytes[0] = static_cast::int_type>(wc); - utf8_bytes_filled = 1; - } - else if (wc <= 0x7FF) - { - utf8_bytes[0] = static_cast::int_type>(0xC0u | ((static_cast(wc) >> 6u))); - utf8_bytes[1] = static_cast::int_type>(0x80u | (static_cast(wc) & 0x3Fu)); - utf8_bytes_filled = 2; - } - else if (0xD800 > wc || wc >= 0xE000) - { - utf8_bytes[0] = static_cast::int_type>(0xE0u | ((static_cast(wc) >> 12u))); - utf8_bytes[1] = static_cast::int_type>(0x80u | ((static_cast(wc) >> 6u) & 0x3Fu)); - utf8_bytes[2] = static_cast::int_type>(0x80u | (static_cast(wc) & 0x3Fu)); - utf8_bytes_filled = 3; - } - else - { - if (JSON_HEDLEY_UNLIKELY(!input.empty())) - { - const auto wc2 = static_cast(input.get_character()); - const auto charcode = 0x10000u + (((static_cast(wc) & 0x3FFu) << 10u) | (wc2 & 0x3FFu)); - utf8_bytes[0] = static_cast::int_type>(0xF0u | (charcode >> 18u)); - utf8_bytes[1] = static_cast::int_type>(0x80u | ((charcode >> 12u) & 0x3Fu)); - utf8_bytes[2] = static_cast::int_type>(0x80u | ((charcode >> 6u) & 0x3Fu)); - utf8_bytes[3] = static_cast::int_type>(0x80u | (charcode & 0x3Fu)); - utf8_bytes_filled = 4; - } - else - { - utf8_bytes[0] = 
static_cast::int_type>(wc); - utf8_bytes_filled = 1; - } - } - } - } -}; - -// Wraps another input apdater to convert wide character types into individual bytes. -template -class wide_string_input_adapter -{ - public: - using char_type = char; - - wide_string_input_adapter(BaseInputAdapter base) - : base_adapter(base) {} - - typename std::char_traits::int_type get_character() noexcept - { - // check if buffer needs to be filled - if (utf8_bytes_index == utf8_bytes_filled) - { - fill_buffer(); - - JSON_ASSERT(utf8_bytes_filled > 0); - JSON_ASSERT(utf8_bytes_index == 0); - } - - // use buffer - JSON_ASSERT(utf8_bytes_filled > 0); - JSON_ASSERT(utf8_bytes_index < utf8_bytes_filled); - return utf8_bytes[utf8_bytes_index++]; - } - - private: - BaseInputAdapter base_adapter; - - template - void fill_buffer() - { - wide_string_input_helper::fill_buffer(base_adapter, utf8_bytes, utf8_bytes_index, utf8_bytes_filled); - } - - /// a buffer for UTF-8 bytes - std::array::int_type, 4> utf8_bytes = {{0, 0, 0, 0}}; - - /// index to the utf8_codes array for the next valid byte - std::size_t utf8_bytes_index = 0; - /// number of valid bytes in the utf8_codes array - std::size_t utf8_bytes_filled = 0; -}; - - -template -struct iterator_input_adapter_factory -{ - using iterator_type = IteratorType; - using char_type = typename std::iterator_traits::value_type; - using adapter_type = iterator_input_adapter; - - static adapter_type create(IteratorType first, IteratorType last) - { - return adapter_type(std::move(first), std::move(last)); - } -}; - -template -struct is_iterator_of_multibyte -{ - using value_type = typename std::iterator_traits::value_type; - enum - { - value = sizeof(value_type) > 1 - }; -}; - -template -struct iterator_input_adapter_factory::value>> -{ - using iterator_type = IteratorType; - using char_type = typename std::iterator_traits::value_type; - using base_adapter_type = iterator_input_adapter; - using adapter_type = wide_string_input_adapter; - - static 
adapter_type create(IteratorType first, IteratorType last) - { - return adapter_type(base_adapter_type(std::move(first), std::move(last))); - } -}; - -// General purpose iterator-based input -template -typename iterator_input_adapter_factory::adapter_type input_adapter(IteratorType first, IteratorType last) -{ - using factory_type = iterator_input_adapter_factory; - return factory_type::create(first, last); -} - -// Convenience shorthand from container to iterator -template -auto input_adapter(const ContainerType& container) -> decltype(input_adapter(begin(container), end(container))) -{ - // Enable ADL - using std::begin; - using std::end; - - return input_adapter(begin(container), end(container)); -} - -// Special cases with fast paths -inline file_input_adapter input_adapter(std::FILE* file) -{ - return file_input_adapter(file); -} - -inline input_stream_adapter input_adapter(std::istream& stream) -{ - return input_stream_adapter(stream); -} - -inline input_stream_adapter input_adapter(std::istream&& stream) -{ - return input_stream_adapter(stream); -} - -using contiguous_bytes_input_adapter = decltype(input_adapter(std::declval(), std::declval())); - -// Null-delimited strings, and the like. -template < typename CharT, - typename std::enable_if < - std::is_pointer::value&& - !std::is_array::value&& - std::is_integral::type>::value&& - sizeof(typename std::remove_pointer::type) == 1, - int >::type = 0 > -contiguous_bytes_input_adapter input_adapter(CharT b) -{ - auto length = std::strlen(reinterpret_cast(b)); - const auto* ptr = reinterpret_cast(b); - return input_adapter(ptr, ptr + length); -} - -template -auto input_adapter(T (&array)[N]) -> decltype(input_adapter(array, array + N)) -{ - return input_adapter(array, array + N); -} - -// This class only handles inputs of input_buffer_adapter type. -// It's required so that expressions like {ptr, len} can be implicitely casted -// to the correct adapter. 
-class span_input_adapter -{ - public: - template < typename CharT, - typename std::enable_if < - std::is_pointer::value&& - std::is_integral::type>::value&& - sizeof(typename std::remove_pointer::type) == 1, - int >::type = 0 > - span_input_adapter(CharT b, std::size_t l) - : ia(reinterpret_cast(b), reinterpret_cast(b) + l) {} - - template::iterator_category, std::random_access_iterator_tag>::value, - int>::type = 0> - span_input_adapter(IteratorType first, IteratorType last) - : ia(input_adapter(first, last)) {} - - contiguous_bytes_input_adapter&& get() - { - return std::move(ia); - } - - private: - contiguous_bytes_input_adapter ia; -}; -} // namespace detail -} // namespace nlohmann - -// #include - - -#include -#include // string -#include // move -#include // vector - -// #include - -// #include - - -namespace nlohmann -{ - -/*! -@brief SAX interface - -This class describes the SAX interface used by @ref nlohmann::json::sax_parse. -Each function is called in different situations while the input is parsed. The -boolean return value informs the parser whether to continue processing the -input. -*/ -template -struct json_sax -{ - using number_integer_t = typename BasicJsonType::number_integer_t; - using number_unsigned_t = typename BasicJsonType::number_unsigned_t; - using number_float_t = typename BasicJsonType::number_float_t; - using string_t = typename BasicJsonType::string_t; - using binary_t = typename BasicJsonType::binary_t; - - /*! - @brief a null value was read - @return whether parsing should proceed - */ - virtual bool null() = 0; - - /*! - @brief a boolean value was read - @param[in] val boolean value - @return whether parsing should proceed - */ - virtual bool boolean(bool val) = 0; - - /*! - @brief an integer number was read - @param[in] val integer value - @return whether parsing should proceed - */ - virtual bool number_integer(number_integer_t val) = 0; - - /*! 
- @brief an unsigned integer number was read - @param[in] val unsigned integer value - @return whether parsing should proceed - */ - virtual bool number_unsigned(number_unsigned_t val) = 0; - - /*! - @brief an floating-point number was read - @param[in] val floating-point value - @param[in] s raw token value - @return whether parsing should proceed - */ - virtual bool number_float(number_float_t val, const string_t& s) = 0; - - /*! - @brief a string was read - @param[in] val string value - @return whether parsing should proceed - @note It is safe to move the passed string. - */ - virtual bool string(string_t& val) = 0; - - /*! - @brief a binary string was read - @param[in] val binary value - @return whether parsing should proceed - @note It is safe to move the passed binary. - */ - virtual bool binary(binary_t& val) = 0; - - /*! - @brief the beginning of an object was read - @param[in] elements number of object elements or -1 if unknown - @return whether parsing should proceed - @note binary formats may report the number of elements - */ - virtual bool start_object(std::size_t elements) = 0; - - /*! - @brief an object key was read - @param[in] val object key - @return whether parsing should proceed - @note It is safe to move the passed string. - */ - virtual bool key(string_t& val) = 0; - - /*! - @brief the end of an object was read - @return whether parsing should proceed - */ - virtual bool end_object() = 0; - - /*! - @brief the beginning of an array was read - @param[in] elements number of array elements or -1 if unknown - @return whether parsing should proceed - @note binary formats may report the number of elements - */ - virtual bool start_array(std::size_t elements) = 0; - - /*! - @brief the end of an array was read - @return whether parsing should proceed - */ - virtual bool end_array() = 0; - - /*! 
- @brief a parse error occurred - @param[in] position the position in the input where the error occurs - @param[in] last_token the last read token - @param[in] ex an exception object describing the error - @return whether parsing should proceed (must return false) - */ - virtual bool parse_error(std::size_t position, - const std::string& last_token, - const detail::exception& ex) = 0; - - virtual ~json_sax() = default; -}; - - -namespace detail -{ -/*! -@brief SAX implementation to create a JSON value from SAX events - -This class implements the @ref json_sax interface and processes the SAX events -to create a JSON value which makes it basically a DOM parser. The structure or -hierarchy of the JSON value is managed by the stack `ref_stack` which contains -a pointer to the respective array or object for each recursion depth. - -After successful parsing, the value that is passed by reference to the -constructor contains the parsed value. - -@tparam BasicJsonType the JSON type -*/ -template -class json_sax_dom_parser -{ - public: - using number_integer_t = typename BasicJsonType::number_integer_t; - using number_unsigned_t = typename BasicJsonType::number_unsigned_t; - using number_float_t = typename BasicJsonType::number_float_t; - using string_t = typename BasicJsonType::string_t; - using binary_t = typename BasicJsonType::binary_t; - - /*! 
- @param[in, out] r reference to a JSON value that is manipulated while - parsing - @param[in] allow_exceptions_ whether parse errors yield exceptions - */ - explicit json_sax_dom_parser(BasicJsonType& r, const bool allow_exceptions_ = true) - : root(r), allow_exceptions(allow_exceptions_) - {} - - // make class move-only - json_sax_dom_parser(const json_sax_dom_parser&) = delete; - json_sax_dom_parser(json_sax_dom_parser&&) = default; - json_sax_dom_parser& operator=(const json_sax_dom_parser&) = delete; - json_sax_dom_parser& operator=(json_sax_dom_parser&&) = default; - ~json_sax_dom_parser() = default; - - bool null() - { - handle_value(nullptr); - return true; - } - - bool boolean(bool val) - { - handle_value(val); - return true; - } - - bool number_integer(number_integer_t val) - { - handle_value(val); - return true; - } - - bool number_unsigned(number_unsigned_t val) - { - handle_value(val); - return true; - } - - bool number_float(number_float_t val, const string_t& /*unused*/) - { - handle_value(val); - return true; - } - - bool string(string_t& val) - { - handle_value(val); - return true; - } - - bool binary(binary_t& val) - { - handle_value(std::move(val)); - return true; - } - - bool start_object(std::size_t len) - { - ref_stack.push_back(handle_value(BasicJsonType::value_t::object)); - - if (JSON_HEDLEY_UNLIKELY(len != std::size_t(-1) && len > ref_stack.back()->max_size())) - { - JSON_THROW(out_of_range::create(408, - "excessive object size: " + std::to_string(len))); - } - - return true; - } - - bool key(string_t& val) - { - // add null at given key and store the reference for later - object_element = &(ref_stack.back()->m_value.object->operator[](val)); - return true; - } - - bool end_object() - { - ref_stack.pop_back(); - return true; - } - - bool start_array(std::size_t len) - { - ref_stack.push_back(handle_value(BasicJsonType::value_t::array)); - - if (JSON_HEDLEY_UNLIKELY(len != std::size_t(-1) && len > ref_stack.back()->max_size())) - { - 
JSON_THROW(out_of_range::create(408, - "excessive array size: " + std::to_string(len))); - } - - return true; - } - - bool end_array() - { - ref_stack.pop_back(); - return true; - } - - template - bool parse_error(std::size_t /*unused*/, const std::string& /*unused*/, - const Exception& ex) - { - errored = true; - static_cast(ex); - if (allow_exceptions) - { - JSON_THROW(ex); - } - return false; - } - - constexpr bool is_errored() const - { - return errored; - } - - private: - /*! - @invariant If the ref stack is empty, then the passed value will be the new - root. - @invariant If the ref stack contains a value, then it is an array or an - object to which we can add elements - */ - template - JSON_HEDLEY_RETURNS_NON_NULL - BasicJsonType* handle_value(Value&& v) - { - if (ref_stack.empty()) - { - root = BasicJsonType(std::forward(v)); - return &root; - } - - JSON_ASSERT(ref_stack.back()->is_array() || ref_stack.back()->is_object()); - - if (ref_stack.back()->is_array()) - { - ref_stack.back()->m_value.array->emplace_back(std::forward(v)); - return &(ref_stack.back()->m_value.array->back()); - } - - JSON_ASSERT(ref_stack.back()->is_object()); - JSON_ASSERT(object_element); - *object_element = BasicJsonType(std::forward(v)); - return object_element; - } - - /// the parsed JSON value - BasicJsonType& root; - /// stack to model hierarchy of values - std::vector ref_stack {}; - /// helper to hold the reference for the next object element - BasicJsonType* object_element = nullptr; - /// whether a syntax error occurred - bool errored = false; - /// whether to throw exceptions in case of errors - const bool allow_exceptions = true; -}; - -template -class json_sax_dom_callback_parser -{ - public: - using number_integer_t = typename BasicJsonType::number_integer_t; - using number_unsigned_t = typename BasicJsonType::number_unsigned_t; - using number_float_t = typename BasicJsonType::number_float_t; - using string_t = typename BasicJsonType::string_t; - using binary_t = 
typename BasicJsonType::binary_t; - using parser_callback_t = typename BasicJsonType::parser_callback_t; - using parse_event_t = typename BasicJsonType::parse_event_t; - - json_sax_dom_callback_parser(BasicJsonType& r, - const parser_callback_t cb, - const bool allow_exceptions_ = true) - : root(r), callback(cb), allow_exceptions(allow_exceptions_) - { - keep_stack.push_back(true); - } - - // make class move-only - json_sax_dom_callback_parser(const json_sax_dom_callback_parser&) = delete; - json_sax_dom_callback_parser(json_sax_dom_callback_parser&&) = default; - json_sax_dom_callback_parser& operator=(const json_sax_dom_callback_parser&) = delete; - json_sax_dom_callback_parser& operator=(json_sax_dom_callback_parser&&) = default; - ~json_sax_dom_callback_parser() = default; - - bool null() - { - handle_value(nullptr); - return true; - } - - bool boolean(bool val) - { - handle_value(val); - return true; - } - - bool number_integer(number_integer_t val) - { - handle_value(val); - return true; - } - - bool number_unsigned(number_unsigned_t val) - { - handle_value(val); - return true; - } - - bool number_float(number_float_t val, const string_t& /*unused*/) - { - handle_value(val); - return true; - } - - bool string(string_t& val) - { - handle_value(val); - return true; - } - - bool binary(binary_t& val) - { - handle_value(std::move(val)); - return true; - } - - bool start_object(std::size_t len) - { - // check callback for object start - const bool keep = callback(static_cast(ref_stack.size()), parse_event_t::object_start, discarded); - keep_stack.push_back(keep); - - auto val = handle_value(BasicJsonType::value_t::object, true); - ref_stack.push_back(val.second); - - // check object limit - if (ref_stack.back() && JSON_HEDLEY_UNLIKELY(len != std::size_t(-1) && len > ref_stack.back()->max_size())) - { - JSON_THROW(out_of_range::create(408, "excessive object size: " + std::to_string(len))); - } - - return true; - } - - bool key(string_t& val) - { - BasicJsonType k = 
BasicJsonType(val); - - // check callback for key - const bool keep = callback(static_cast(ref_stack.size()), parse_event_t::key, k); - key_keep_stack.push_back(keep); - - // add discarded value at given key and store the reference for later - if (keep && ref_stack.back()) - { - object_element = &(ref_stack.back()->m_value.object->operator[](val) = discarded); - } - - return true; - } - - bool end_object() - { - if (ref_stack.back() && !callback(static_cast(ref_stack.size()) - 1, parse_event_t::object_end, *ref_stack.back())) - { - // discard object - *ref_stack.back() = discarded; - } - - JSON_ASSERT(!ref_stack.empty()); - JSON_ASSERT(!keep_stack.empty()); - ref_stack.pop_back(); - keep_stack.pop_back(); - - if (!ref_stack.empty() && ref_stack.back() && ref_stack.back()->is_structured()) - { - // remove discarded value - for (auto it = ref_stack.back()->begin(); it != ref_stack.back()->end(); ++it) - { - if (it->is_discarded()) - { - ref_stack.back()->erase(it); - break; - } - } - } - - return true; - } - - bool start_array(std::size_t len) - { - const bool keep = callback(static_cast(ref_stack.size()), parse_event_t::array_start, discarded); - keep_stack.push_back(keep); - - auto val = handle_value(BasicJsonType::value_t::array, true); - ref_stack.push_back(val.second); - - // check array limit - if (ref_stack.back() && JSON_HEDLEY_UNLIKELY(len != std::size_t(-1) && len > ref_stack.back()->max_size())) - { - JSON_THROW(out_of_range::create(408, "excessive array size: " + std::to_string(len))); - } - - return true; - } - - bool end_array() - { - bool keep = true; - - if (ref_stack.back()) - { - keep = callback(static_cast(ref_stack.size()) - 1, parse_event_t::array_end, *ref_stack.back()); - if (!keep) - { - // discard array - *ref_stack.back() = discarded; - } - } - - JSON_ASSERT(!ref_stack.empty()); - JSON_ASSERT(!keep_stack.empty()); - ref_stack.pop_back(); - keep_stack.pop_back(); - - // remove discarded value - if (!keep && !ref_stack.empty() && 
ref_stack.back()->is_array()) - { - ref_stack.back()->m_value.array->pop_back(); - } - - return true; - } - - template - bool parse_error(std::size_t /*unused*/, const std::string& /*unused*/, - const Exception& ex) - { - errored = true; - static_cast(ex); - if (allow_exceptions) - { - JSON_THROW(ex); - } - return false; - } - - constexpr bool is_errored() const - { - return errored; - } - - private: - /*! - @param[in] v value to add to the JSON value we build during parsing - @param[in] skip_callback whether we should skip calling the callback - function; this is required after start_array() and - start_object() SAX events, because otherwise we would call the - callback function with an empty array or object, respectively. - - @invariant If the ref stack is empty, then the passed value will be the new - root. - @invariant If the ref stack contains a value, then it is an array or an - object to which we can add elements - - @return pair of boolean (whether value should be kept) and pointer (to the - passed value in the ref_stack hierarchy; nullptr if not kept) - */ - template - std::pair handle_value(Value&& v, const bool skip_callback = false) - { - JSON_ASSERT(!keep_stack.empty()); - - // do not handle this value if we know it would be added to a discarded - // container - if (!keep_stack.back()) - { - return {false, nullptr}; - } - - // create value - auto value = BasicJsonType(std::forward(v)); - - // check callback - const bool keep = skip_callback || callback(static_cast(ref_stack.size()), parse_event_t::value, value); - - // do not handle this value if we just learnt it shall be discarded - if (!keep) - { - return {false, nullptr}; - } - - if (ref_stack.empty()) - { - root = std::move(value); - return {true, &root}; - } - - // skip this value if we already decided to skip the parent - // (https://github.com/nlohmann/json/issues/971#issuecomment-413678360) - if (!ref_stack.back()) - { - return {false, nullptr}; - } - - // we now only expect arrays and objects 
- JSON_ASSERT(ref_stack.back()->is_array() || ref_stack.back()->is_object()); - - // array - if (ref_stack.back()->is_array()) - { - ref_stack.back()->m_value.array->push_back(std::move(value)); - return {true, &(ref_stack.back()->m_value.array->back())}; - } - - // object - JSON_ASSERT(ref_stack.back()->is_object()); - // check if we should store an element for the current key - JSON_ASSERT(!key_keep_stack.empty()); - const bool store_element = key_keep_stack.back(); - key_keep_stack.pop_back(); - - if (!store_element) - { - return {false, nullptr}; - } - - JSON_ASSERT(object_element); - *object_element = std::move(value); - return {true, object_element}; - } - - /// the parsed JSON value - BasicJsonType& root; - /// stack to model hierarchy of values - std::vector ref_stack {}; - /// stack to manage which values to keep - std::vector keep_stack {}; - /// stack to manage which object keys to keep - std::vector key_keep_stack {}; - /// helper to hold the reference for the next object element - BasicJsonType* object_element = nullptr; - /// whether a syntax error occurred - bool errored = false; - /// callback function - const parser_callback_t callback = nullptr; - /// whether to throw exceptions in case of errors - const bool allow_exceptions = true; - /// a discarded value for the callback - BasicJsonType discarded = BasicJsonType::value_t::discarded; -}; - -template -class json_sax_acceptor -{ - public: - using number_integer_t = typename BasicJsonType::number_integer_t; - using number_unsigned_t = typename BasicJsonType::number_unsigned_t; - using number_float_t = typename BasicJsonType::number_float_t; - using string_t = typename BasicJsonType::string_t; - using binary_t = typename BasicJsonType::binary_t; - - bool null() - { - return true; - } - - bool boolean(bool /*unused*/) - { - return true; - } - - bool number_integer(number_integer_t /*unused*/) - { - return true; - } - - bool number_unsigned(number_unsigned_t /*unused*/) - { - return true; - } - - bool 
number_float(number_float_t /*unused*/, const string_t& /*unused*/) - { - return true; - } - - bool string(string_t& /*unused*/) - { - return true; - } - - bool binary(binary_t& /*unused*/) - { - return true; - } - - bool start_object(std::size_t /*unused*/ = std::size_t(-1)) - { - return true; - } - - bool key(string_t& /*unused*/) - { - return true; - } - - bool end_object() - { - return true; - } - - bool start_array(std::size_t /*unused*/ = std::size_t(-1)) - { - return true; - } - - bool end_array() - { - return true; - } - - bool parse_error(std::size_t /*unused*/, const std::string& /*unused*/, const detail::exception& /*unused*/) - { - return false; - } -}; -} // namespace detail - -} // namespace nlohmann - -// #include - - -#include // array -#include // localeconv -#include // size_t -#include // snprintf -#include // strtof, strtod, strtold, strtoll, strtoull -#include // initializer_list -#include // char_traits, string -#include // move -#include // vector - -// #include - -// #include - -// #include - - -namespace nlohmann -{ -namespace detail -{ -/////////// -// lexer // -/////////// - -template -class lexer_base -{ - public: - /// token types for the parser - enum class token_type - { - uninitialized, ///< indicating the scanner is uninitialized - literal_true, ///< the `true` literal - literal_false, ///< the `false` literal - literal_null, ///< the `null` literal - value_string, ///< a string -- use get_string() for actual value - value_unsigned, ///< an unsigned integer -- use get_number_unsigned() for actual value - value_integer, ///< a signed integer -- use get_number_integer() for actual value - value_float, ///< an floating point number -- use get_number_float() for actual value - begin_array, ///< the character for array begin `[` - begin_object, ///< the character for object begin `{` - end_array, ///< the character for array end `]` - end_object, ///< the character for object end `}` - name_separator, ///< the name separator `:` - 
value_separator, ///< the value separator `,` - parse_error, ///< indicating a parse error - end_of_input, ///< indicating the end of the input buffer - literal_or_value ///< a literal or the begin of a value (only for diagnostics) - }; - - /// return name of values of type token_type (only used for errors) - JSON_HEDLEY_RETURNS_NON_NULL - JSON_HEDLEY_CONST - static const char* token_type_name(const token_type t) noexcept - { - switch (t) - { - case token_type::uninitialized: - return ""; - case token_type::literal_true: - return "true literal"; - case token_type::literal_false: - return "false literal"; - case token_type::literal_null: - return "null literal"; - case token_type::value_string: - return "string literal"; - case token_type::value_unsigned: - case token_type::value_integer: - case token_type::value_float: - return "number literal"; - case token_type::begin_array: - return "'['"; - case token_type::begin_object: - return "'{'"; - case token_type::end_array: - return "']'"; - case token_type::end_object: - return "'}'"; - case token_type::name_separator: - return "':'"; - case token_type::value_separator: - return "','"; - case token_type::parse_error: - return ""; - case token_type::end_of_input: - return "end of input"; - case token_type::literal_or_value: - return "'[', '{', or a literal"; - // LCOV_EXCL_START - default: // catch non-enum values - return "unknown token"; - // LCOV_EXCL_STOP - } - } -}; -/*! -@brief lexical analysis - -This class organizes the lexical analysis during JSON deserialization. 
-*/ -template -class lexer : public lexer_base -{ - using number_integer_t = typename BasicJsonType::number_integer_t; - using number_unsigned_t = typename BasicJsonType::number_unsigned_t; - using number_float_t = typename BasicJsonType::number_float_t; - using string_t = typename BasicJsonType::string_t; - using char_type = typename InputAdapterType::char_type; - using char_int_type = typename std::char_traits::int_type; - - public: - using token_type = typename lexer_base::token_type; - - explicit lexer(InputAdapterType&& adapter, bool ignore_comments_ = false) - : ia(std::move(adapter)) - , ignore_comments(ignore_comments_) - , decimal_point_char(static_cast(get_decimal_point())) - {} - - // delete because of pointer members - lexer(const lexer&) = delete; - lexer(lexer&&) = default; - lexer& operator=(lexer&) = delete; - lexer& operator=(lexer&&) = default; - ~lexer() = default; - - private: - ///////////////////// - // locales - ///////////////////// - - /// return the locale-dependent decimal point - JSON_HEDLEY_PURE - static char get_decimal_point() noexcept - { - const auto* loc = localeconv(); - JSON_ASSERT(loc != nullptr); - return (loc->decimal_point == nullptr) ? '.' : *(loc->decimal_point); - } - - ///////////////////// - // scan functions - ///////////////////// - - /*! - @brief get codepoint from 4 hex characters following `\u` - - For input "\u c1 c2 c3 c4" the codepoint is: - (c1 * 0x1000) + (c2 * 0x0100) + (c3 * 0x0010) + c4 - = (c1 << 12) + (c2 << 8) + (c3 << 4) + (c4 << 0) - - Furthermore, the possible characters '0'..'9', 'A'..'F', and 'a'..'f' - must be converted to the integers 0x0..0x9, 0xA..0xF, 0xA..0xF, resp. The - conversion is done by subtracting the offset (0x30, 0x37, and 0x57) - between the ASCII value of the character and the desired integer value. - - @return codepoint (0x0000..0xFFFF) or -1 in case of an error (e.g. 
EOF or - non-hex character) - */ - int get_codepoint() - { - // this function only makes sense after reading `\u` - JSON_ASSERT(current == 'u'); - int codepoint = 0; - - const auto factors = { 12u, 8u, 4u, 0u }; - for (const auto factor : factors) - { - get(); - - if (current >= '0' && current <= '9') - { - codepoint += static_cast((static_cast(current) - 0x30u) << factor); - } - else if (current >= 'A' && current <= 'F') - { - codepoint += static_cast((static_cast(current) - 0x37u) << factor); - } - else if (current >= 'a' && current <= 'f') - { - codepoint += static_cast((static_cast(current) - 0x57u) << factor); - } - else - { - return -1; - } - } - - JSON_ASSERT(0x0000 <= codepoint && codepoint <= 0xFFFF); - return codepoint; - } - - /*! - @brief check if the next byte(s) are inside a given range - - Adds the current byte and, for each passed range, reads a new byte and - checks if it is inside the range. If a violation was detected, set up an - error message and return false. Otherwise, return true. - - @param[in] ranges list of integers; interpreted as list of pairs of - inclusive lower and upper bound, respectively - - @pre The passed list @a ranges must have 2, 4, or 6 elements; that is, - 1, 2, or 3 pairs. This precondition is enforced by an assertion. - - @return true if and only if no range violation was detected - */ - bool next_byte_in_range(std::initializer_list ranges) - { - JSON_ASSERT(ranges.size() == 2 || ranges.size() == 4 || ranges.size() == 6); - add(current); - - for (auto range = ranges.begin(); range != ranges.end(); ++range) - { - get(); - if (JSON_HEDLEY_LIKELY(*range <= current && current <= *(++range))) - { - add(current); - } - else - { - error_message = "invalid string: ill-formed UTF-8 byte"; - return false; - } - } - - return true; - } - - /*! - @brief scan a string literal - - This function scans a string according to Sect. 7 of RFC 7159. While - scanning, bytes are escaped and copied into buffer token_buffer. 
Then the - function returns successfully, token_buffer is *not* null-terminated (as it - may contain \0 bytes), and token_buffer.size() is the number of bytes in the - string. - - @return token_type::value_string if string could be successfully scanned, - token_type::parse_error otherwise - - @note In case of errors, variable error_message contains a textual - description. - */ - token_type scan_string() - { - // reset token_buffer (ignore opening quote) - reset(); - - // we entered the function by reading an open quote - JSON_ASSERT(current == '\"'); - - while (true) - { - // get next character - switch (get()) - { - // end of file while parsing string - case std::char_traits::eof(): - { - error_message = "invalid string: missing closing quote"; - return token_type::parse_error; - } - - // closing quote - case '\"': - { - return token_type::value_string; - } - - // escapes - case '\\': - { - switch (get()) - { - // quotation mark - case '\"': - add('\"'); - break; - // reverse solidus - case '\\': - add('\\'); - break; - // solidus - case '/': - add('/'); - break; - // backspace - case 'b': - add('\b'); - break; - // form feed - case 'f': - add('\f'); - break; - // line feed - case 'n': - add('\n'); - break; - // carriage return - case 'r': - add('\r'); - break; - // tab - case 't': - add('\t'); - break; - - // unicode escapes - case 'u': - { - const int codepoint1 = get_codepoint(); - int codepoint = codepoint1; // start with codepoint1 - - if (JSON_HEDLEY_UNLIKELY(codepoint1 == -1)) - { - error_message = "invalid string: '\\u' must be followed by 4 hex digits"; - return token_type::parse_error; - } - - // check if code point is a high surrogate - if (0xD800 <= codepoint1 && codepoint1 <= 0xDBFF) - { - // expect next \uxxxx entry - if (JSON_HEDLEY_LIKELY(get() == '\\' && get() == 'u')) - { - const int codepoint2 = get_codepoint(); - - if (JSON_HEDLEY_UNLIKELY(codepoint2 == -1)) - { - error_message = "invalid string: '\\u' must be followed by 4 hex digits"; - 
return token_type::parse_error; - } - - // check if codepoint2 is a low surrogate - if (JSON_HEDLEY_LIKELY(0xDC00 <= codepoint2 && codepoint2 <= 0xDFFF)) - { - // overwrite codepoint - codepoint = static_cast( - // high surrogate occupies the most significant 22 bits - (static_cast(codepoint1) << 10u) - // low surrogate occupies the least significant 15 bits - + static_cast(codepoint2) - // there is still the 0xD800, 0xDC00 and 0x10000 noise - // in the result so we have to subtract with: - // (0xD800 << 10) + DC00 - 0x10000 = 0x35FDC00 - - 0x35FDC00u); - } - else - { - error_message = "invalid string: surrogate U+D800..U+DBFF must be followed by U+DC00..U+DFFF"; - return token_type::parse_error; - } - } - else - { - error_message = "invalid string: surrogate U+D800..U+DBFF must be followed by U+DC00..U+DFFF"; - return token_type::parse_error; - } - } - else - { - if (JSON_HEDLEY_UNLIKELY(0xDC00 <= codepoint1 && codepoint1 <= 0xDFFF)) - { - error_message = "invalid string: surrogate U+DC00..U+DFFF must follow U+D800..U+DBFF"; - return token_type::parse_error; - } - } - - // result of the above calculation yields a proper codepoint - JSON_ASSERT(0x00 <= codepoint && codepoint <= 0x10FFFF); - - // translate codepoint into bytes - if (codepoint < 0x80) - { - // 1-byte characters: 0xxxxxxx (ASCII) - add(static_cast(codepoint)); - } - else if (codepoint <= 0x7FF) - { - // 2-byte characters: 110xxxxx 10xxxxxx - add(static_cast(0xC0u | (static_cast(codepoint) >> 6u))); - add(static_cast(0x80u | (static_cast(codepoint) & 0x3Fu))); - } - else if (codepoint <= 0xFFFF) - { - // 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx - add(static_cast(0xE0u | (static_cast(codepoint) >> 12u))); - add(static_cast(0x80u | ((static_cast(codepoint) >> 6u) & 0x3Fu))); - add(static_cast(0x80u | (static_cast(codepoint) & 0x3Fu))); - } - else - { - // 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx - add(static_cast(0xF0u | (static_cast(codepoint) >> 18u))); - add(static_cast(0x80u | 
((static_cast(codepoint) >> 12u) & 0x3Fu))); - add(static_cast(0x80u | ((static_cast(codepoint) >> 6u) & 0x3Fu))); - add(static_cast(0x80u | (static_cast(codepoint) & 0x3Fu))); - } - - break; - } - - // other characters after escape - default: - error_message = "invalid string: forbidden character after backslash"; - return token_type::parse_error; - } - - break; - } - - // invalid control characters - case 0x00: - { - error_message = "invalid string: control character U+0000 (NUL) must be escaped to \\u0000"; - return token_type::parse_error; - } - - case 0x01: - { - error_message = "invalid string: control character U+0001 (SOH) must be escaped to \\u0001"; - return token_type::parse_error; - } - - case 0x02: - { - error_message = "invalid string: control character U+0002 (STX) must be escaped to \\u0002"; - return token_type::parse_error; - } - - case 0x03: - { - error_message = "invalid string: control character U+0003 (ETX) must be escaped to \\u0003"; - return token_type::parse_error; - } - - case 0x04: - { - error_message = "invalid string: control character U+0004 (EOT) must be escaped to \\u0004"; - return token_type::parse_error; - } - - case 0x05: - { - error_message = "invalid string: control character U+0005 (ENQ) must be escaped to \\u0005"; - return token_type::parse_error; - } - - case 0x06: - { - error_message = "invalid string: control character U+0006 (ACK) must be escaped to \\u0006"; - return token_type::parse_error; - } - - case 0x07: - { - error_message = "invalid string: control character U+0007 (BEL) must be escaped to \\u0007"; - return token_type::parse_error; - } - - case 0x08: - { - error_message = "invalid string: control character U+0008 (BS) must be escaped to \\u0008 or \\b"; - return token_type::parse_error; - } - - case 0x09: - { - error_message = "invalid string: control character U+0009 (HT) must be escaped to \\u0009 or \\t"; - return token_type::parse_error; - } - - case 0x0A: - { - error_message = "invalid string: control 
character U+000A (LF) must be escaped to \\u000A or \\n"; - return token_type::parse_error; - } - - case 0x0B: - { - error_message = "invalid string: control character U+000B (VT) must be escaped to \\u000B"; - return token_type::parse_error; - } - - case 0x0C: - { - error_message = "invalid string: control character U+000C (FF) must be escaped to \\u000C or \\f"; - return token_type::parse_error; - } - - case 0x0D: - { - error_message = "invalid string: control character U+000D (CR) must be escaped to \\u000D or \\r"; - return token_type::parse_error; - } - - case 0x0E: - { - error_message = "invalid string: control character U+000E (SO) must be escaped to \\u000E"; - return token_type::parse_error; - } - - case 0x0F: - { - error_message = "invalid string: control character U+000F (SI) must be escaped to \\u000F"; - return token_type::parse_error; - } - - case 0x10: - { - error_message = "invalid string: control character U+0010 (DLE) must be escaped to \\u0010"; - return token_type::parse_error; - } - - case 0x11: - { - error_message = "invalid string: control character U+0011 (DC1) must be escaped to \\u0011"; - return token_type::parse_error; - } - - case 0x12: - { - error_message = "invalid string: control character U+0012 (DC2) must be escaped to \\u0012"; - return token_type::parse_error; - } - - case 0x13: - { - error_message = "invalid string: control character U+0013 (DC3) must be escaped to \\u0013"; - return token_type::parse_error; - } - - case 0x14: - { - error_message = "invalid string: control character U+0014 (DC4) must be escaped to \\u0014"; - return token_type::parse_error; - } - - case 0x15: - { - error_message = "invalid string: control character U+0015 (NAK) must be escaped to \\u0015"; - return token_type::parse_error; - } - - case 0x16: - { - error_message = "invalid string: control character U+0016 (SYN) must be escaped to \\u0016"; - return token_type::parse_error; - } - - case 0x17: - { - error_message = "invalid string: control 
character U+0017 (ETB) must be escaped to \\u0017"; - return token_type::parse_error; - } - - case 0x18: - { - error_message = "invalid string: control character U+0018 (CAN) must be escaped to \\u0018"; - return token_type::parse_error; - } - - case 0x19: - { - error_message = "invalid string: control character U+0019 (EM) must be escaped to \\u0019"; - return token_type::parse_error; - } - - case 0x1A: - { - error_message = "invalid string: control character U+001A (SUB) must be escaped to \\u001A"; - return token_type::parse_error; - } - - case 0x1B: - { - error_message = "invalid string: control character U+001B (ESC) must be escaped to \\u001B"; - return token_type::parse_error; - } - - case 0x1C: - { - error_message = "invalid string: control character U+001C (FS) must be escaped to \\u001C"; - return token_type::parse_error; - } - - case 0x1D: - { - error_message = "invalid string: control character U+001D (GS) must be escaped to \\u001D"; - return token_type::parse_error; - } - - case 0x1E: - { - error_message = "invalid string: control character U+001E (RS) must be escaped to \\u001E"; - return token_type::parse_error; - } - - case 0x1F: - { - error_message = "invalid string: control character U+001F (US) must be escaped to \\u001F"; - return token_type::parse_error; - } - - // U+0020..U+007F (except U+0022 (quote) and U+005C (backspace)) - case 0x20: - case 0x21: - case 0x23: - case 0x24: - case 0x25: - case 0x26: - case 0x27: - case 0x28: - case 0x29: - case 0x2A: - case 0x2B: - case 0x2C: - case 0x2D: - case 0x2E: - case 0x2F: - case 0x30: - case 0x31: - case 0x32: - case 0x33: - case 0x34: - case 0x35: - case 0x36: - case 0x37: - case 0x38: - case 0x39: - case 0x3A: - case 0x3B: - case 0x3C: - case 0x3D: - case 0x3E: - case 0x3F: - case 0x40: - case 0x41: - case 0x42: - case 0x43: - case 0x44: - case 0x45: - case 0x46: - case 0x47: - case 0x48: - case 0x49: - case 0x4A: - case 0x4B: - case 0x4C: - case 0x4D: - case 0x4E: - case 0x4F: - case 0x50: - 
case 0x51: - case 0x52: - case 0x53: - case 0x54: - case 0x55: - case 0x56: - case 0x57: - case 0x58: - case 0x59: - case 0x5A: - case 0x5B: - case 0x5D: - case 0x5E: - case 0x5F: - case 0x60: - case 0x61: - case 0x62: - case 0x63: - case 0x64: - case 0x65: - case 0x66: - case 0x67: - case 0x68: - case 0x69: - case 0x6A: - case 0x6B: - case 0x6C: - case 0x6D: - case 0x6E: - case 0x6F: - case 0x70: - case 0x71: - case 0x72: - case 0x73: - case 0x74: - case 0x75: - case 0x76: - case 0x77: - case 0x78: - case 0x79: - case 0x7A: - case 0x7B: - case 0x7C: - case 0x7D: - case 0x7E: - case 0x7F: - { - add(current); - break; - } - - // U+0080..U+07FF: bytes C2..DF 80..BF - case 0xC2: - case 0xC3: - case 0xC4: - case 0xC5: - case 0xC6: - case 0xC7: - case 0xC8: - case 0xC9: - case 0xCA: - case 0xCB: - case 0xCC: - case 0xCD: - case 0xCE: - case 0xCF: - case 0xD0: - case 0xD1: - case 0xD2: - case 0xD3: - case 0xD4: - case 0xD5: - case 0xD6: - case 0xD7: - case 0xD8: - case 0xD9: - case 0xDA: - case 0xDB: - case 0xDC: - case 0xDD: - case 0xDE: - case 0xDF: - { - if (JSON_HEDLEY_UNLIKELY(!next_byte_in_range({0x80, 0xBF}))) - { - return token_type::parse_error; - } - break; - } - - // U+0800..U+0FFF: bytes E0 A0..BF 80..BF - case 0xE0: - { - if (JSON_HEDLEY_UNLIKELY(!(next_byte_in_range({0xA0, 0xBF, 0x80, 0xBF})))) - { - return token_type::parse_error; - } - break; - } - - // U+1000..U+CFFF: bytes E1..EC 80..BF 80..BF - // U+E000..U+FFFF: bytes EE..EF 80..BF 80..BF - case 0xE1: - case 0xE2: - case 0xE3: - case 0xE4: - case 0xE5: - case 0xE6: - case 0xE7: - case 0xE8: - case 0xE9: - case 0xEA: - case 0xEB: - case 0xEC: - case 0xEE: - case 0xEF: - { - if (JSON_HEDLEY_UNLIKELY(!(next_byte_in_range({0x80, 0xBF, 0x80, 0xBF})))) - { - return token_type::parse_error; - } - break; - } - - // U+D000..U+D7FF: bytes ED 80..9F 80..BF - case 0xED: - { - if (JSON_HEDLEY_UNLIKELY(!(next_byte_in_range({0x80, 0x9F, 0x80, 0xBF})))) - { - return token_type::parse_error; - } - break; - } - - // 
U+10000..U+3FFFF F0 90..BF 80..BF 80..BF - case 0xF0: - { - if (JSON_HEDLEY_UNLIKELY(!(next_byte_in_range({0x90, 0xBF, 0x80, 0xBF, 0x80, 0xBF})))) - { - return token_type::parse_error; - } - break; - } - - // U+40000..U+FFFFF F1..F3 80..BF 80..BF 80..BF - case 0xF1: - case 0xF2: - case 0xF3: - { - if (JSON_HEDLEY_UNLIKELY(!(next_byte_in_range({0x80, 0xBF, 0x80, 0xBF, 0x80, 0xBF})))) - { - return token_type::parse_error; - } - break; - } - - // U+100000..U+10FFFF F4 80..8F 80..BF 80..BF - case 0xF4: - { - if (JSON_HEDLEY_UNLIKELY(!(next_byte_in_range({0x80, 0x8F, 0x80, 0xBF, 0x80, 0xBF})))) - { - return token_type::parse_error; - } - break; - } - - // remaining bytes (80..C1 and F5..FF) are ill-formed - default: - { - error_message = "invalid string: ill-formed UTF-8 byte"; - return token_type::parse_error; - } - } - } - } - - /*! - * @brief scan a comment - * @return whether comment could be scanned successfully - */ - bool scan_comment() - { - switch (get()) - { - // single-line comments skip input until a newline or EOF is read - case '/': - { - while (true) - { - switch (get()) - { - case '\n': - case '\r': - case std::char_traits::eof(): - case '\0': - return true; - - default: - break; - } - } - } - - // multi-line comments skip input until */ is read - case '*': - { - while (true) - { - switch (get()) - { - case std::char_traits::eof(): - case '\0': - { - error_message = "invalid comment; missing closing '*/'"; - return false; - } - - case '*': - { - switch (get()) - { - case '/': - return true; - - default: - { - unget(); - continue; - } - } - } - - default: - continue; - } - } - } - - // unexpected character after reading '/' - default: - { - error_message = "invalid comment; expecting '/' or '*' after '/'"; - return false; - } - } - } - - JSON_HEDLEY_NON_NULL(2) - static void strtof(float& f, const char* str, char** endptr) noexcept - { - f = std::strtof(str, endptr); - } - - JSON_HEDLEY_NON_NULL(2) - static void strtof(double& f, const char* str, char** 
endptr) noexcept - { - f = std::strtod(str, endptr); - } - - JSON_HEDLEY_NON_NULL(2) - static void strtof(long double& f, const char* str, char** endptr) noexcept - { - f = std::strtold(str, endptr); - } - - /*! - @brief scan a number literal - - This function scans a string according to Sect. 6 of RFC 7159. - - The function is realized with a deterministic finite state machine derived - from the grammar described in RFC 7159. Starting in state "init", the - input is read and used to determined the next state. Only state "done" - accepts the number. State "error" is a trap state to model errors. In the - table below, "anything" means any character but the ones listed before. - - state | 0 | 1-9 | e E | + | - | . | anything - ---------|----------|----------|----------|---------|---------|----------|----------- - init | zero | any1 | [error] | [error] | minus | [error] | [error] - minus | zero | any1 | [error] | [error] | [error] | [error] | [error] - zero | done | done | exponent | done | done | decimal1 | done - any1 | any1 | any1 | exponent | done | done | decimal1 | done - decimal1 | decimal2 | decimal2 | [error] | [error] | [error] | [error] | [error] - decimal2 | decimal2 | decimal2 | exponent | done | done | done | done - exponent | any2 | any2 | [error] | sign | sign | [error] | [error] - sign | any2 | any2 | [error] | [error] | [error] | [error] | [error] - any2 | any2 | any2 | done | done | done | done | done - - The state machine is realized with one label per state (prefixed with - "scan_number_") and `goto` statements between them. The state machine - contains cycles, but any cycle can be left when EOF is read. Therefore, - the function is guaranteed to terminate. - - During scanning, the read bytes are stored in token_buffer. This string is - then converted to a signed integer, an unsigned integer, or a - floating-point number. 
- - @return token_type::value_unsigned, token_type::value_integer, or - token_type::value_float if number could be successfully scanned, - token_type::parse_error otherwise - - @note The scanner is independent of the current locale. Internally, the - locale's decimal point is used instead of `.` to work with the - locale-dependent converters. - */ - token_type scan_number() // lgtm [cpp/use-of-goto] - { - // reset token_buffer to store the number's bytes - reset(); - - // the type of the parsed number; initially set to unsigned; will be - // changed if minus sign, decimal point or exponent is read - token_type number_type = token_type::value_unsigned; - - // state (init): we just found out we need to scan a number - switch (current) - { - case '-': - { - add(current); - goto scan_number_minus; - } - - case '0': - { - add(current); - goto scan_number_zero; - } - - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - { - add(current); - goto scan_number_any1; - } - - // all other characters are rejected outside scan_number() - default: // LCOV_EXCL_LINE - JSON_ASSERT(false); // LCOV_EXCL_LINE - } - -scan_number_minus: - // state: we just parsed a leading minus sign - number_type = token_type::value_integer; - switch (get()) - { - case '0': - { - add(current); - goto scan_number_zero; - } - - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - { - add(current); - goto scan_number_any1; - } - - default: - { - error_message = "invalid number; expected digit after '-'"; - return token_type::parse_error; - } - } - -scan_number_zero: - // state: we just parse a zero (maybe with a leading minus sign) - switch (get()) - { - case '.': - { - add(decimal_point_char); - goto scan_number_decimal1; - } - - case 'e': - case 'E': - { - add(current); - goto scan_number_exponent; - } - - default: - goto scan_number_done; - } - -scan_number_any1: - // state: we just parsed a 
number 0-9 (maybe with a leading minus sign) - switch (get()) - { - case '0': - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - { - add(current); - goto scan_number_any1; - } - - case '.': - { - add(decimal_point_char); - goto scan_number_decimal1; - } - - case 'e': - case 'E': - { - add(current); - goto scan_number_exponent; - } - - default: - goto scan_number_done; - } - -scan_number_decimal1: - // state: we just parsed a decimal point - number_type = token_type::value_float; - switch (get()) - { - case '0': - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - { - add(current); - goto scan_number_decimal2; - } - - default: - { - error_message = "invalid number; expected digit after '.'"; - return token_type::parse_error; - } - } - -scan_number_decimal2: - // we just parsed at least one number after a decimal point - switch (get()) - { - case '0': - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - { - add(current); - goto scan_number_decimal2; - } - - case 'e': - case 'E': - { - add(current); - goto scan_number_exponent; - } - - default: - goto scan_number_done; - } - -scan_number_exponent: - // we just parsed an exponent - number_type = token_type::value_float; - switch (get()) - { - case '+': - case '-': - { - add(current); - goto scan_number_sign; - } - - case '0': - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - { - add(current); - goto scan_number_any2; - } - - default: - { - error_message = - "invalid number; expected '+', '-', or digit after exponent"; - return token_type::parse_error; - } - } - -scan_number_sign: - // we just parsed an exponent sign - switch (get()) - { - case '0': - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - { - add(current); - goto 
scan_number_any2; - } - - default: - { - error_message = "invalid number; expected digit after exponent sign"; - return token_type::parse_error; - } - } - -scan_number_any2: - // we just parsed a number after the exponent or exponent sign - switch (get()) - { - case '0': - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - { - add(current); - goto scan_number_any2; - } - - default: - goto scan_number_done; - } - -scan_number_done: - // unget the character after the number (we only read it to know that - // we are done scanning a number) - unget(); - - char* endptr = nullptr; - errno = 0; - - // try to parse integers first and fall back to floats - if (number_type == token_type::value_unsigned) - { - const auto x = std::strtoull(token_buffer.data(), &endptr, 10); - - // we checked the number format before - JSON_ASSERT(endptr == token_buffer.data() + token_buffer.size()); - - if (errno == 0) - { - value_unsigned = static_cast(x); - if (value_unsigned == x) - { - return token_type::value_unsigned; - } - } - } - else if (number_type == token_type::value_integer) - { - const auto x = std::strtoll(token_buffer.data(), &endptr, 10); - - // we checked the number format before - JSON_ASSERT(endptr == token_buffer.data() + token_buffer.size()); - - if (errno == 0) - { - value_integer = static_cast(x); - if (value_integer == x) - { - return token_type::value_integer; - } - } - } - - // this code is reached if we parse a floating-point number or if an - // integer conversion above failed - strtof(value_float, token_buffer.data(), &endptr); - - // we checked the number format before - JSON_ASSERT(endptr == token_buffer.data() + token_buffer.size()); - - return token_type::value_float; - } - - /*! 
- @param[in] literal_text the literal text to expect - @param[in] length the length of the passed literal text - @param[in] return_type the token type to return on success - */ - JSON_HEDLEY_NON_NULL(2) - token_type scan_literal(const char_type* literal_text, const std::size_t length, - token_type return_type) - { - JSON_ASSERT(std::char_traits::to_char_type(current) == literal_text[0]); - for (std::size_t i = 1; i < length; ++i) - { - if (JSON_HEDLEY_UNLIKELY(std::char_traits::to_char_type(get()) != literal_text[i])) - { - error_message = "invalid literal"; - return token_type::parse_error; - } - } - return return_type; - } - - ///////////////////// - // input management - ///////////////////// - - /// reset token_buffer; current character is beginning of token - void reset() noexcept - { - token_buffer.clear(); - token_string.clear(); - token_string.push_back(std::char_traits::to_char_type(current)); - } - - /* - @brief get next character from the input - - This function provides the interface to the used input adapter. It does - not throw in case the input reached EOF, but returns a - `std::char_traits::eof()` in that case. Stores the scanned characters - for use in error messages. - - @return character read from the input - */ - char_int_type get() - { - ++position.chars_read_total; - ++position.chars_read_current_line; - - if (next_unget) - { - // just reset the next_unget variable and work with current - next_unget = false; - } - else - { - current = ia.get_character(); - } - - if (JSON_HEDLEY_LIKELY(current != std::char_traits::eof())) - { - token_string.push_back(std::char_traits::to_char_type(current)); - } - - if (current == '\n') - { - ++position.lines_read; - position.chars_read_current_line = 0; - } - - return current; - } - - /*! - @brief unget current character (read it again on next get) - - We implement unget by setting variable next_unget to true. 
The input is not - changed - we just simulate ungetting by modifying chars_read_total, - chars_read_current_line, and token_string. The next call to get() will - behave as if the unget character is read again. - */ - void unget() - { - next_unget = true; - - --position.chars_read_total; - - // in case we "unget" a newline, we have to also decrement the lines_read - if (position.chars_read_current_line == 0) - { - if (position.lines_read > 0) - { - --position.lines_read; - } - } - else - { - --position.chars_read_current_line; - } - - if (JSON_HEDLEY_LIKELY(current != std::char_traits::eof())) - { - JSON_ASSERT(!token_string.empty()); - token_string.pop_back(); - } - } - - /// add a character to token_buffer - void add(char_int_type c) - { - token_buffer.push_back(static_cast(c)); - } - - public: - ///////////////////// - // value getters - ///////////////////// - - /// return integer value - constexpr number_integer_t get_number_integer() const noexcept - { - return value_integer; - } - - /// return unsigned integer value - constexpr number_unsigned_t get_number_unsigned() const noexcept - { - return value_unsigned; - } - - /// return floating-point value - constexpr number_float_t get_number_float() const noexcept - { - return value_float; - } - - /// return current string value (implicitly resets the token; useful only once) - string_t& get_string() - { - return token_buffer; - } - - ///////////////////// - // diagnostics - ///////////////////// - - /// return position of last read token - constexpr position_t get_position() const noexcept - { - return position; - } - - /// return the last read token (for errors only). Will never contain EOF - /// (an arbitrary value that is not a valid char value, often -1), because - /// 255 may legitimately occur. May contain NUL, which should be escaped. 
- std::string get_token_string() const - { - // escape control characters - std::string result; - for (const auto c : token_string) - { - if (static_cast(c) <= '\x1F') - { - // escape control characters - std::array cs{{}}; - (std::snprintf)(cs.data(), cs.size(), "", static_cast(c)); - result += cs.data(); - } - else - { - // add character as is - result.push_back(static_cast(c)); - } - } - - return result; - } - - /// return syntax error message - JSON_HEDLEY_RETURNS_NON_NULL - constexpr const char* get_error_message() const noexcept - { - return error_message; - } - - ///////////////////// - // actual scanner - ///////////////////// - - /*! - @brief skip the UTF-8 byte order mark - @return true iff there is no BOM or the correct BOM has been skipped - */ - bool skip_bom() - { - if (get() == 0xEF) - { - // check if we completely parse the BOM - return get() == 0xBB && get() == 0xBF; - } - - // the first character is not the beginning of the BOM; unget it to - // process is later - unget(); - return true; - } - - void skip_whitespace() - { - do - { - get(); - } - while (current == ' ' || current == '\t' || current == '\n' || current == '\r'); - } - - token_type scan() - { - // initially, skip the BOM - if (position.chars_read_total == 0 && !skip_bom()) - { - error_message = "invalid BOM; must be 0xEF 0xBB 0xBF if given"; - return token_type::parse_error; - } - - // read next character and ignore whitespace - skip_whitespace(); - - // ignore comments - while (ignore_comments && current == '/') - { - if (!scan_comment()) - { - return token_type::parse_error; - } - - // skip following whitespace - skip_whitespace(); - } - - switch (current) - { - // structural characters - case '[': - return token_type::begin_array; - case ']': - return token_type::end_array; - case '{': - return token_type::begin_object; - case '}': - return token_type::end_object; - case ':': - return token_type::name_separator; - case ',': - return token_type::value_separator; - - // literals - 
case 't': - { - std::array true_literal = {{'t', 'r', 'u', 'e'}}; - return scan_literal(true_literal.data(), true_literal.size(), token_type::literal_true); - } - case 'f': - { - std::array false_literal = {{'f', 'a', 'l', 's', 'e'}}; - return scan_literal(false_literal.data(), false_literal.size(), token_type::literal_false); - } - case 'n': - { - std::array null_literal = {{'n', 'u', 'l', 'l'}}; - return scan_literal(null_literal.data(), null_literal.size(), token_type::literal_null); - } - - // string - case '\"': - return scan_string(); - - // number - case '-': - case '0': - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - return scan_number(); - - // end of input (the null byte is needed when parsing from - // string literals) - case '\0': - case std::char_traits::eof(): - return token_type::end_of_input; - - // error - default: - error_message = "invalid literal"; - return token_type::parse_error; - } - } - - private: - /// input adapter - InputAdapterType ia; - - /// whether comments should be ignored (true) or signaled as errors (false) - const bool ignore_comments = false; - - /// the current character - char_int_type current = std::char_traits::eof(); - - /// whether the next get() call should just return current - bool next_unget = false; - - /// the start position of the current token - position_t position {}; - - /// raw input token string (for error messages) - std::vector token_string {}; - - /// buffer for variable-length tokens (numbers, strings) - string_t token_buffer {}; - - /// a description of occurred lexer errors - const char* error_message = ""; - - // number values - number_integer_t value_integer = 0; - number_unsigned_t value_unsigned = 0; - number_float_t value_float = 0; - - /// the decimal point - const char_int_type decimal_point_char = '.'; -}; -} // namespace detail -} // namespace nlohmann - -// #include - -// #include - - -#include // size_t -#include // declval 
-#include // string - -// #include - -// #include - - -namespace nlohmann -{ -namespace detail -{ -template -using null_function_t = decltype(std::declval().null()); - -template -using boolean_function_t = - decltype(std::declval().boolean(std::declval())); - -template -using number_integer_function_t = - decltype(std::declval().number_integer(std::declval())); - -template -using number_unsigned_function_t = - decltype(std::declval().number_unsigned(std::declval())); - -template -using number_float_function_t = decltype(std::declval().number_float( - std::declval(), std::declval())); - -template -using string_function_t = - decltype(std::declval().string(std::declval())); - -template -using binary_function_t = - decltype(std::declval().binary(std::declval())); - -template -using start_object_function_t = - decltype(std::declval().start_object(std::declval())); - -template -using key_function_t = - decltype(std::declval().key(std::declval())); - -template -using end_object_function_t = decltype(std::declval().end_object()); - -template -using start_array_function_t = - decltype(std::declval().start_array(std::declval())); - -template -using end_array_function_t = decltype(std::declval().end_array()); - -template -using parse_error_function_t = decltype(std::declval().parse_error( - std::declval(), std::declval(), - std::declval())); - -template -struct is_sax -{ - private: - static_assert(is_basic_json::value, - "BasicJsonType must be of type basic_json<...>"); - - using number_integer_t = typename BasicJsonType::number_integer_t; - using number_unsigned_t = typename BasicJsonType::number_unsigned_t; - using number_float_t = typename BasicJsonType::number_float_t; - using string_t = typename BasicJsonType::string_t; - using binary_t = typename BasicJsonType::binary_t; - using exception_t = typename BasicJsonType::exception; - - public: - static constexpr bool value = - is_detected_exact::value && - is_detected_exact::value && - is_detected_exact::value && - 
is_detected_exact::value && - is_detected_exact::value && - is_detected_exact::value && - is_detected_exact::value && - is_detected_exact::value && - is_detected_exact::value && - is_detected_exact::value && - is_detected_exact::value && - is_detected_exact::value && - is_detected_exact::value; -}; - -template -struct is_sax_static_asserts -{ - private: - static_assert(is_basic_json::value, - "BasicJsonType must be of type basic_json<...>"); - - using number_integer_t = typename BasicJsonType::number_integer_t; - using number_unsigned_t = typename BasicJsonType::number_unsigned_t; - using number_float_t = typename BasicJsonType::number_float_t; - using string_t = typename BasicJsonType::string_t; - using binary_t = typename BasicJsonType::binary_t; - using exception_t = typename BasicJsonType::exception; - - public: - static_assert(is_detected_exact::value, - "Missing/invalid function: bool null()"); - static_assert(is_detected_exact::value, - "Missing/invalid function: bool boolean(bool)"); - static_assert(is_detected_exact::value, - "Missing/invalid function: bool boolean(bool)"); - static_assert( - is_detected_exact::value, - "Missing/invalid function: bool number_integer(number_integer_t)"); - static_assert( - is_detected_exact::value, - "Missing/invalid function: bool number_unsigned(number_unsigned_t)"); - static_assert(is_detected_exact::value, - "Missing/invalid function: bool number_float(number_float_t, const string_t&)"); - static_assert( - is_detected_exact::value, - "Missing/invalid function: bool string(string_t&)"); - static_assert( - is_detected_exact::value, - "Missing/invalid function: bool binary(binary_t&)"); - static_assert(is_detected_exact::value, - "Missing/invalid function: bool start_object(std::size_t)"); - static_assert(is_detected_exact::value, - "Missing/invalid function: bool key(string_t&)"); - static_assert(is_detected_exact::value, - "Missing/invalid function: bool end_object()"); - static_assert(is_detected_exact::value, - 
"Missing/invalid function: bool start_array(std::size_t)"); - static_assert(is_detected_exact::value, - "Missing/invalid function: bool end_array()"); - static_assert( - is_detected_exact::value, - "Missing/invalid function: bool parse_error(std::size_t, const " - "std::string&, const exception&)"); -}; -} // namespace detail -} // namespace nlohmann - -// #include - - -namespace nlohmann -{ -namespace detail -{ - -/// how to treat CBOR tags -enum class cbor_tag_handler_t -{ - error, ///< throw a parse_error exception in case of a tag - ignore ///< ignore tags -}; - -/*! -@brief determine system byte order - -@return true if and only if system's byte order is little endian - -@note from https://stackoverflow.com/a/1001328/266378 -*/ -static inline bool little_endianess(int num = 1) noexcept -{ - return *reinterpret_cast(&num) == 1; -} - - -/////////////////// -// binary reader // -/////////////////// - -/*! -@brief deserialization of CBOR, MessagePack, and UBJSON values -*/ -template> -class binary_reader -{ - using number_integer_t = typename BasicJsonType::number_integer_t; - using number_unsigned_t = typename BasicJsonType::number_unsigned_t; - using number_float_t = typename BasicJsonType::number_float_t; - using string_t = typename BasicJsonType::string_t; - using binary_t = typename BasicJsonType::binary_t; - using json_sax_t = SAX; - using char_type = typename InputAdapterType::char_type; - using char_int_type = typename std::char_traits::int_type; - - public: - /*! - @brief create a binary reader - - @param[in] adapter input adapter to read from - */ - explicit binary_reader(InputAdapterType&& adapter) : ia(std::move(adapter)) - { - (void)detail::is_sax_static_asserts {}; - } - - // make class move-only - binary_reader(const binary_reader&) = delete; - binary_reader(binary_reader&&) = default; - binary_reader& operator=(const binary_reader&) = delete; - binary_reader& operator=(binary_reader&&) = default; - ~binary_reader() = default; - - /*! 
- @param[in] format the binary format to parse - @param[in] sax_ a SAX event processor - @param[in] strict whether to expect the input to be consumed completed - @param[in] tag_handler how to treat CBOR tags - - @return - */ - JSON_HEDLEY_NON_NULL(3) - bool sax_parse(const input_format_t format, - json_sax_t* sax_, - const bool strict = true, - const cbor_tag_handler_t tag_handler = cbor_tag_handler_t::error) - { - sax = sax_; - bool result = false; - - switch (format) - { - case input_format_t::bson: - result = parse_bson_internal(); - break; - - case input_format_t::cbor: - result = parse_cbor_internal(true, tag_handler); - break; - - case input_format_t::msgpack: - result = parse_msgpack_internal(); - break; - - case input_format_t::ubjson: - result = parse_ubjson_internal(); - break; - - default: // LCOV_EXCL_LINE - JSON_ASSERT(false); // LCOV_EXCL_LINE - } - - // strict mode: next byte must be EOF - if (result && strict) - { - if (format == input_format_t::ubjson) - { - get_ignore_noop(); - } - else - { - get(); - } - - if (JSON_HEDLEY_UNLIKELY(current != std::char_traits::eof())) - { - return sax->parse_error(chars_read, get_token_string(), - parse_error::create(110, chars_read, exception_message(format, "expected end of input; last byte: 0x" + get_token_string(), "value"))); - } - } - - return result; - } - - private: - ////////// - // BSON // - ////////// - - /*! - @brief Reads in a BSON-object and passes it to the SAX-parser. - @return whether a valid BSON-value was passed to the SAX parser - */ - bool parse_bson_internal() - { - std::int32_t document_size{}; - get_number(input_format_t::bson, document_size); - - if (JSON_HEDLEY_UNLIKELY(!sax->start_object(std::size_t(-1)))) - { - return false; - } - - if (JSON_HEDLEY_UNLIKELY(!parse_bson_element_list(/*is_array*/false))) - { - return false; - } - - return sax->end_object(); - } - - /*! - @brief Parses a C-style string from the BSON input. 
- @param[in, out] result A reference to the string variable where the read - string is to be stored. - @return `true` if the \x00-byte indicating the end of the string was - encountered before the EOF; false` indicates an unexpected EOF. - */ - bool get_bson_cstr(string_t& result) - { - auto out = std::back_inserter(result); - while (true) - { - get(); - if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::bson, "cstring"))) - { - return false; - } - if (current == 0x00) - { - return true; - } - *out++ = static_cast(current); - } - } - - /*! - @brief Parses a zero-terminated string of length @a len from the BSON - input. - @param[in] len The length (including the zero-byte at the end) of the - string to be read. - @param[in, out] result A reference to the string variable where the read - string is to be stored. - @tparam NumberType The type of the length @a len - @pre len >= 1 - @return `true` if the string was successfully parsed - */ - template - bool get_bson_string(const NumberType len, string_t& result) - { - if (JSON_HEDLEY_UNLIKELY(len < 1)) - { - auto last_token = get_token_string(); - return sax->parse_error(chars_read, last_token, parse_error::create(112, chars_read, exception_message(input_format_t::bson, "string length must be at least 1, is " + std::to_string(len), "string"))); - } - - return get_string(input_format_t::bson, len - static_cast(1), result) && get() != std::char_traits::eof(); - } - - /*! - @brief Parses a byte array input of length @a len from the BSON input. - @param[in] len The length of the byte array to be read. - @param[in, out] result A reference to the binary variable where the read - array is to be stored. 
- @tparam NumberType The type of the length @a len - @pre len >= 0 - @return `true` if the byte array was successfully parsed - */ - template - bool get_bson_binary(const NumberType len, binary_t& result) - { - if (JSON_HEDLEY_UNLIKELY(len < 0)) - { - auto last_token = get_token_string(); - return sax->parse_error(chars_read, last_token, parse_error::create(112, chars_read, exception_message(input_format_t::bson, "byte array length cannot be negative, is " + std::to_string(len), "binary"))); - } - - // All BSON binary values have a subtype - std::uint8_t subtype{}; - get_number(input_format_t::bson, subtype); - result.set_subtype(subtype); - - return get_binary(input_format_t::bson, len, result); - } - - /*! - @brief Read a BSON document element of the given @a element_type. - @param[in] element_type The BSON element type, c.f. http://bsonspec.org/spec.html - @param[in] element_type_parse_position The position in the input stream, - where the `element_type` was read. - @warning Not all BSON element types are supported yet. An unsupported - @a element_type will give rise to a parse_error.114: - Unsupported BSON record type 0x... 
- @return whether a valid BSON-object/array was passed to the SAX parser - */ - bool parse_bson_element_internal(const char_int_type element_type, - const std::size_t element_type_parse_position) - { - switch (element_type) - { - case 0x01: // double - { - double number{}; - return get_number(input_format_t::bson, number) && sax->number_float(static_cast(number), ""); - } - - case 0x02: // string - { - std::int32_t len{}; - string_t value; - return get_number(input_format_t::bson, len) && get_bson_string(len, value) && sax->string(value); - } - - case 0x03: // object - { - return parse_bson_internal(); - } - - case 0x04: // array - { - return parse_bson_array(); - } - - case 0x05: // binary - { - std::int32_t len{}; - binary_t value; - return get_number(input_format_t::bson, len) && get_bson_binary(len, value) && sax->binary(value); - } - - case 0x08: // boolean - { - return sax->boolean(get() != 0); - } - - case 0x0A: // null - { - return sax->null(); - } - - case 0x10: // int32 - { - std::int32_t value{}; - return get_number(input_format_t::bson, value) && sax->number_integer(value); - } - - case 0x12: // int64 - { - std::int64_t value{}; - return get_number(input_format_t::bson, value) && sax->number_integer(value); - } - - default: // anything else not supported (yet) - { - std::array cr{{}}; - (std::snprintf)(cr.data(), cr.size(), "%.2hhX", static_cast(element_type)); - return sax->parse_error(element_type_parse_position, std::string(cr.data()), parse_error::create(114, element_type_parse_position, "Unsupported BSON record type 0x" + std::string(cr.data()))); - } - } - } - - /*! - @brief Read a BSON element list (as specified in the BSON-spec) - - The same binary layout is used for objects and arrays, hence it must be - indicated with the argument @a is_array which one is expected - (true --> array, false --> object). 
- - @param[in] is_array Determines if the element list being read is to be - treated as an object (@a is_array == false), or as an - array (@a is_array == true). - @return whether a valid BSON-object/array was passed to the SAX parser - */ - bool parse_bson_element_list(const bool is_array) - { - string_t key; - - while (auto element_type = get()) - { - if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::bson, "element list"))) - { - return false; - } - - const std::size_t element_type_parse_position = chars_read; - if (JSON_HEDLEY_UNLIKELY(!get_bson_cstr(key))) - { - return false; - } - - if (!is_array && !sax->key(key)) - { - return false; - } - - if (JSON_HEDLEY_UNLIKELY(!parse_bson_element_internal(element_type, element_type_parse_position))) - { - return false; - } - - // get_bson_cstr only appends - key.clear(); - } - - return true; - } - - /*! - @brief Reads an array from the BSON input and passes it to the SAX-parser. - @return whether a valid BSON-array was passed to the SAX parser - */ - bool parse_bson_array() - { - std::int32_t document_size{}; - get_number(input_format_t::bson, document_size); - - if (JSON_HEDLEY_UNLIKELY(!sax->start_array(std::size_t(-1)))) - { - return false; - } - - if (JSON_HEDLEY_UNLIKELY(!parse_bson_element_list(/*is_array*/true))) - { - return false; - } - - return sax->end_array(); - } - - ////////// - // CBOR // - ////////// - - /*! - @param[in] get_char whether a new character should be retrieved from the - input (true) or whether the last read character should - be considered instead (false) - @param[in] tag_handler how CBOR tags should be treated - - @return whether a valid CBOR value was passed to the SAX parser - */ - bool parse_cbor_internal(const bool get_char, - const cbor_tag_handler_t tag_handler) - { - switch (get_char ? 
get() : current) - { - // EOF - case std::char_traits::eof(): - return unexpect_eof(input_format_t::cbor, "value"); - - // Integer 0x00..0x17 (0..23) - case 0x00: - case 0x01: - case 0x02: - case 0x03: - case 0x04: - case 0x05: - case 0x06: - case 0x07: - case 0x08: - case 0x09: - case 0x0A: - case 0x0B: - case 0x0C: - case 0x0D: - case 0x0E: - case 0x0F: - case 0x10: - case 0x11: - case 0x12: - case 0x13: - case 0x14: - case 0x15: - case 0x16: - case 0x17: - return sax->number_unsigned(static_cast(current)); - - case 0x18: // Unsigned integer (one-byte uint8_t follows) - { - std::uint8_t number{}; - return get_number(input_format_t::cbor, number) && sax->number_unsigned(number); - } - - case 0x19: // Unsigned integer (two-byte uint16_t follows) - { - std::uint16_t number{}; - return get_number(input_format_t::cbor, number) && sax->number_unsigned(number); - } - - case 0x1A: // Unsigned integer (four-byte uint32_t follows) - { - std::uint32_t number{}; - return get_number(input_format_t::cbor, number) && sax->number_unsigned(number); - } - - case 0x1B: // Unsigned integer (eight-byte uint64_t follows) - { - std::uint64_t number{}; - return get_number(input_format_t::cbor, number) && sax->number_unsigned(number); - } - - // Negative integer -1-0x00..-1-0x17 (-1..-24) - case 0x20: - case 0x21: - case 0x22: - case 0x23: - case 0x24: - case 0x25: - case 0x26: - case 0x27: - case 0x28: - case 0x29: - case 0x2A: - case 0x2B: - case 0x2C: - case 0x2D: - case 0x2E: - case 0x2F: - case 0x30: - case 0x31: - case 0x32: - case 0x33: - case 0x34: - case 0x35: - case 0x36: - case 0x37: - return sax->number_integer(static_cast(0x20 - 1 - current)); - - case 0x38: // Negative integer (one-byte uint8_t follows) - { - std::uint8_t number{}; - return get_number(input_format_t::cbor, number) && sax->number_integer(static_cast(-1) - number); - } - - case 0x39: // Negative integer -1-n (two-byte uint16_t follows) - { - std::uint16_t number{}; - return get_number(input_format_t::cbor, 
number) && sax->number_integer(static_cast(-1) - number); - } - - case 0x3A: // Negative integer -1-n (four-byte uint32_t follows) - { - std::uint32_t number{}; - return get_number(input_format_t::cbor, number) && sax->number_integer(static_cast(-1) - number); - } - - case 0x3B: // Negative integer -1-n (eight-byte uint64_t follows) - { - std::uint64_t number{}; - return get_number(input_format_t::cbor, number) && sax->number_integer(static_cast(-1) - - static_cast(number)); - } - - // Binary data (0x00..0x17 bytes follow) - case 0x40: - case 0x41: - case 0x42: - case 0x43: - case 0x44: - case 0x45: - case 0x46: - case 0x47: - case 0x48: - case 0x49: - case 0x4A: - case 0x4B: - case 0x4C: - case 0x4D: - case 0x4E: - case 0x4F: - case 0x50: - case 0x51: - case 0x52: - case 0x53: - case 0x54: - case 0x55: - case 0x56: - case 0x57: - case 0x58: // Binary data (one-byte uint8_t for n follows) - case 0x59: // Binary data (two-byte uint16_t for n follow) - case 0x5A: // Binary data (four-byte uint32_t for n follow) - case 0x5B: // Binary data (eight-byte uint64_t for n follow) - case 0x5F: // Binary data (indefinite length) - { - binary_t b; - return get_cbor_binary(b) && sax->binary(b); - } - - // UTF-8 string (0x00..0x17 bytes follow) - case 0x60: - case 0x61: - case 0x62: - case 0x63: - case 0x64: - case 0x65: - case 0x66: - case 0x67: - case 0x68: - case 0x69: - case 0x6A: - case 0x6B: - case 0x6C: - case 0x6D: - case 0x6E: - case 0x6F: - case 0x70: - case 0x71: - case 0x72: - case 0x73: - case 0x74: - case 0x75: - case 0x76: - case 0x77: - case 0x78: // UTF-8 string (one-byte uint8_t for n follows) - case 0x79: // UTF-8 string (two-byte uint16_t for n follow) - case 0x7A: // UTF-8 string (four-byte uint32_t for n follow) - case 0x7B: // UTF-8 string (eight-byte uint64_t for n follow) - case 0x7F: // UTF-8 string (indefinite length) - { - string_t s; - return get_cbor_string(s) && sax->string(s); - } - - // array (0x00..0x17 data items follow) - case 0x80: - case 
0x81: - case 0x82: - case 0x83: - case 0x84: - case 0x85: - case 0x86: - case 0x87: - case 0x88: - case 0x89: - case 0x8A: - case 0x8B: - case 0x8C: - case 0x8D: - case 0x8E: - case 0x8F: - case 0x90: - case 0x91: - case 0x92: - case 0x93: - case 0x94: - case 0x95: - case 0x96: - case 0x97: - return get_cbor_array(static_cast(static_cast(current) & 0x1Fu), tag_handler); - - case 0x98: // array (one-byte uint8_t for n follows) - { - std::uint8_t len{}; - return get_number(input_format_t::cbor, len) && get_cbor_array(static_cast(len), tag_handler); - } - - case 0x99: // array (two-byte uint16_t for n follow) - { - std::uint16_t len{}; - return get_number(input_format_t::cbor, len) && get_cbor_array(static_cast(len), tag_handler); - } - - case 0x9A: // array (four-byte uint32_t for n follow) - { - std::uint32_t len{}; - return get_number(input_format_t::cbor, len) && get_cbor_array(static_cast(len), tag_handler); - } - - case 0x9B: // array (eight-byte uint64_t for n follow) - { - std::uint64_t len{}; - return get_number(input_format_t::cbor, len) && get_cbor_array(static_cast(len), tag_handler); - } - - case 0x9F: // array (indefinite length) - return get_cbor_array(std::size_t(-1), tag_handler); - - // map (0x00..0x17 pairs of data items follow) - case 0xA0: - case 0xA1: - case 0xA2: - case 0xA3: - case 0xA4: - case 0xA5: - case 0xA6: - case 0xA7: - case 0xA8: - case 0xA9: - case 0xAA: - case 0xAB: - case 0xAC: - case 0xAD: - case 0xAE: - case 0xAF: - case 0xB0: - case 0xB1: - case 0xB2: - case 0xB3: - case 0xB4: - case 0xB5: - case 0xB6: - case 0xB7: - return get_cbor_object(static_cast(static_cast(current) & 0x1Fu), tag_handler); - - case 0xB8: // map (one-byte uint8_t for n follows) - { - std::uint8_t len{}; - return get_number(input_format_t::cbor, len) && get_cbor_object(static_cast(len), tag_handler); - } - - case 0xB9: // map (two-byte uint16_t for n follow) - { - std::uint16_t len{}; - return get_number(input_format_t::cbor, len) && 
get_cbor_object(static_cast(len), tag_handler); - } - - case 0xBA: // map (four-byte uint32_t for n follow) - { - std::uint32_t len{}; - return get_number(input_format_t::cbor, len) && get_cbor_object(static_cast(len), tag_handler); - } - - case 0xBB: // map (eight-byte uint64_t for n follow) - { - std::uint64_t len{}; - return get_number(input_format_t::cbor, len) && get_cbor_object(static_cast(len), tag_handler); - } - - case 0xBF: // map (indefinite length) - return get_cbor_object(std::size_t(-1), tag_handler); - - case 0xC6: // tagged item - case 0xC7: - case 0xC8: - case 0xC9: - case 0xCA: - case 0xCB: - case 0xCC: - case 0xCD: - case 0xCE: - case 0xCF: - case 0xD0: - case 0xD1: - case 0xD2: - case 0xD3: - case 0xD4: - case 0xD8: // tagged item (1 bytes follow) - case 0xD9: // tagged item (2 bytes follow) - case 0xDA: // tagged item (4 bytes follow) - case 0xDB: // tagged item (8 bytes follow) - { - switch (tag_handler) - { - case cbor_tag_handler_t::error: - { - auto last_token = get_token_string(); - return sax->parse_error(chars_read, last_token, parse_error::create(112, chars_read, exception_message(input_format_t::cbor, "invalid byte: 0x" + last_token, "value"))); - } - - case cbor_tag_handler_t::ignore: - { - switch (current) - { - case 0xD8: - { - std::uint8_t len{}; - get_number(input_format_t::cbor, len); - break; - } - case 0xD9: - { - std::uint16_t len{}; - get_number(input_format_t::cbor, len); - break; - } - case 0xDA: - { - std::uint32_t len{}; - get_number(input_format_t::cbor, len); - break; - } - case 0xDB: - { - std::uint64_t len{}; - get_number(input_format_t::cbor, len); - break; - } - default: - break; - } - return parse_cbor_internal(true, tag_handler); - } - - default: // LCOV_EXCL_LINE - JSON_ASSERT(false); // LCOV_EXCL_LINE - } - } - - case 0xF4: // false - return sax->boolean(false); - - case 0xF5: // true - return sax->boolean(true); - - case 0xF6: // null - return sax->null(); - - case 0xF9: // Half-Precision Float (two-byte IEEE 
754) - { - const auto byte1_raw = get(); - if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::cbor, "number"))) - { - return false; - } - const auto byte2_raw = get(); - if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::cbor, "number"))) - { - return false; - } - - const auto byte1 = static_cast(byte1_raw); - const auto byte2 = static_cast(byte2_raw); - - // code from RFC 7049, Appendix D, Figure 3: - // As half-precision floating-point numbers were only added - // to IEEE 754 in 2008, today's programming platforms often - // still only have limited support for them. It is very - // easy to include at least decoding support for them even - // without such support. An example of a small decoder for - // half-precision floating-point numbers in the C language - // is shown in Fig. 3. - const auto half = static_cast((byte1 << 8u) + byte2); - const double val = [&half] - { - const int exp = (half >> 10u) & 0x1Fu; - const unsigned int mant = half & 0x3FFu; - JSON_ASSERT(0 <= exp&& exp <= 32); - JSON_ASSERT(mant <= 1024); - switch (exp) - { - case 0: - return std::ldexp(mant, -24); - case 31: - return (mant == 0) - ? std::numeric_limits::infinity() - : std::numeric_limits::quiet_NaN(); - default: - return std::ldexp(mant + 1024, exp - 25); - } - }(); - return sax->number_float((half & 0x8000u) != 0 - ? 
static_cast(-val) - : static_cast(val), ""); - } - - case 0xFA: // Single-Precision Float (four-byte IEEE 754) - { - float number{}; - return get_number(input_format_t::cbor, number) && sax->number_float(static_cast(number), ""); - } - - case 0xFB: // Double-Precision Float (eight-byte IEEE 754) - { - double number{}; - return get_number(input_format_t::cbor, number) && sax->number_float(static_cast(number), ""); - } - - default: // anything else (0xFF is handled inside the other types) - { - auto last_token = get_token_string(); - return sax->parse_error(chars_read, last_token, parse_error::create(112, chars_read, exception_message(input_format_t::cbor, "invalid byte: 0x" + last_token, "value"))); - } - } - } - - /*! - @brief reads a CBOR string - - This function first reads starting bytes to determine the expected - string length and then copies this number of bytes into a string. - Additionally, CBOR's strings with indefinite lengths are supported. - - @param[out] result created string - - @return whether string creation completed - */ - bool get_cbor_string(string_t& result) - { - if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::cbor, "string"))) - { - return false; - } - - switch (current) - { - // UTF-8 string (0x00..0x17 bytes follow) - case 0x60: - case 0x61: - case 0x62: - case 0x63: - case 0x64: - case 0x65: - case 0x66: - case 0x67: - case 0x68: - case 0x69: - case 0x6A: - case 0x6B: - case 0x6C: - case 0x6D: - case 0x6E: - case 0x6F: - case 0x70: - case 0x71: - case 0x72: - case 0x73: - case 0x74: - case 0x75: - case 0x76: - case 0x77: - { - return get_string(input_format_t::cbor, static_cast(current) & 0x1Fu, result); - } - - case 0x78: // UTF-8 string (one-byte uint8_t for n follows) - { - std::uint8_t len{}; - return get_number(input_format_t::cbor, len) && get_string(input_format_t::cbor, len, result); - } - - case 0x79: // UTF-8 string (two-byte uint16_t for n follow) - { - std::uint16_t len{}; - return get_number(input_format_t::cbor, len) 
&& get_string(input_format_t::cbor, len, result); - } - - case 0x7A: // UTF-8 string (four-byte uint32_t for n follow) - { - std::uint32_t len{}; - return get_number(input_format_t::cbor, len) && get_string(input_format_t::cbor, len, result); - } - - case 0x7B: // UTF-8 string (eight-byte uint64_t for n follow) - { - std::uint64_t len{}; - return get_number(input_format_t::cbor, len) && get_string(input_format_t::cbor, len, result); - } - - case 0x7F: // UTF-8 string (indefinite length) - { - while (get() != 0xFF) - { - string_t chunk; - if (!get_cbor_string(chunk)) - { - return false; - } - result.append(chunk); - } - return true; - } - - default: - { - auto last_token = get_token_string(); - return sax->parse_error(chars_read, last_token, parse_error::create(113, chars_read, exception_message(input_format_t::cbor, "expected length specification (0x60-0x7B) or indefinite string type (0x7F); last byte: 0x" + last_token, "string"))); - } - } - } - - /*! - @brief reads a CBOR byte array - - This function first reads starting bytes to determine the expected - byte array length and then copies this number of bytes into the byte array. - Additionally, CBOR's byte arrays with indefinite lengths are supported. 
- - @param[out] result created byte array - - @return whether byte array creation completed - */ - bool get_cbor_binary(binary_t& result) - { - if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::cbor, "binary"))) - { - return false; - } - - switch (current) - { - // Binary data (0x00..0x17 bytes follow) - case 0x40: - case 0x41: - case 0x42: - case 0x43: - case 0x44: - case 0x45: - case 0x46: - case 0x47: - case 0x48: - case 0x49: - case 0x4A: - case 0x4B: - case 0x4C: - case 0x4D: - case 0x4E: - case 0x4F: - case 0x50: - case 0x51: - case 0x52: - case 0x53: - case 0x54: - case 0x55: - case 0x56: - case 0x57: - { - return get_binary(input_format_t::cbor, static_cast(current) & 0x1Fu, result); - } - - case 0x58: // Binary data (one-byte uint8_t for n follows) - { - std::uint8_t len{}; - return get_number(input_format_t::cbor, len) && - get_binary(input_format_t::cbor, len, result); - } - - case 0x59: // Binary data (two-byte uint16_t for n follow) - { - std::uint16_t len{}; - return get_number(input_format_t::cbor, len) && - get_binary(input_format_t::cbor, len, result); - } - - case 0x5A: // Binary data (four-byte uint32_t for n follow) - { - std::uint32_t len{}; - return get_number(input_format_t::cbor, len) && - get_binary(input_format_t::cbor, len, result); - } - - case 0x5B: // Binary data (eight-byte uint64_t for n follow) - { - std::uint64_t len{}; - return get_number(input_format_t::cbor, len) && - get_binary(input_format_t::cbor, len, result); - } - - case 0x5F: // Binary data (indefinite length) - { - while (get() != 0xFF) - { - binary_t chunk; - if (!get_cbor_binary(chunk)) - { - return false; - } - result.insert(result.end(), chunk.begin(), chunk.end()); - } - return true; - } - - default: - { - auto last_token = get_token_string(); - return sax->parse_error(chars_read, last_token, parse_error::create(113, chars_read, exception_message(input_format_t::cbor, "expected length specification (0x40-0x5B) or indefinite binary array type (0x5F); last byte: 
0x" + last_token, "binary"))); - } - } - } - - /*! - @param[in] len the length of the array or std::size_t(-1) for an - array of indefinite size - @param[in] tag_handler how CBOR tags should be treated - @return whether array creation completed - */ - bool get_cbor_array(const std::size_t len, - const cbor_tag_handler_t tag_handler) - { - if (JSON_HEDLEY_UNLIKELY(!sax->start_array(len))) - { - return false; - } - - if (len != std::size_t(-1)) - { - for (std::size_t i = 0; i < len; ++i) - { - if (JSON_HEDLEY_UNLIKELY(!parse_cbor_internal(true, tag_handler))) - { - return false; - } - } - } - else - { - while (get() != 0xFF) - { - if (JSON_HEDLEY_UNLIKELY(!parse_cbor_internal(false, tag_handler))) - { - return false; - } - } - } - - return sax->end_array(); - } - - /*! - @param[in] len the length of the object or std::size_t(-1) for an - object of indefinite size - @param[in] tag_handler how CBOR tags should be treated - @return whether object creation completed - */ - bool get_cbor_object(const std::size_t len, - const cbor_tag_handler_t tag_handler) - { - if (JSON_HEDLEY_UNLIKELY(!sax->start_object(len))) - { - return false; - } - - string_t key; - if (len != std::size_t(-1)) - { - for (std::size_t i = 0; i < len; ++i) - { - get(); - if (JSON_HEDLEY_UNLIKELY(!get_cbor_string(key) || !sax->key(key))) - { - return false; - } - - if (JSON_HEDLEY_UNLIKELY(!parse_cbor_internal(true, tag_handler))) - { - return false; - } - key.clear(); - } - } - else - { - while (get() != 0xFF) - { - if (JSON_HEDLEY_UNLIKELY(!get_cbor_string(key) || !sax->key(key))) - { - return false; - } - - if (JSON_HEDLEY_UNLIKELY(!parse_cbor_internal(true, tag_handler))) - { - return false; - } - key.clear(); - } - } - - return sax->end_object(); - } - - ///////////// - // MsgPack // - ///////////// - - /*! 
- @return whether a valid MessagePack value was passed to the SAX parser - */ - bool parse_msgpack_internal() - { - switch (get()) - { - // EOF - case std::char_traits::eof(): - return unexpect_eof(input_format_t::msgpack, "value"); - - // positive fixint - case 0x00: - case 0x01: - case 0x02: - case 0x03: - case 0x04: - case 0x05: - case 0x06: - case 0x07: - case 0x08: - case 0x09: - case 0x0A: - case 0x0B: - case 0x0C: - case 0x0D: - case 0x0E: - case 0x0F: - case 0x10: - case 0x11: - case 0x12: - case 0x13: - case 0x14: - case 0x15: - case 0x16: - case 0x17: - case 0x18: - case 0x19: - case 0x1A: - case 0x1B: - case 0x1C: - case 0x1D: - case 0x1E: - case 0x1F: - case 0x20: - case 0x21: - case 0x22: - case 0x23: - case 0x24: - case 0x25: - case 0x26: - case 0x27: - case 0x28: - case 0x29: - case 0x2A: - case 0x2B: - case 0x2C: - case 0x2D: - case 0x2E: - case 0x2F: - case 0x30: - case 0x31: - case 0x32: - case 0x33: - case 0x34: - case 0x35: - case 0x36: - case 0x37: - case 0x38: - case 0x39: - case 0x3A: - case 0x3B: - case 0x3C: - case 0x3D: - case 0x3E: - case 0x3F: - case 0x40: - case 0x41: - case 0x42: - case 0x43: - case 0x44: - case 0x45: - case 0x46: - case 0x47: - case 0x48: - case 0x49: - case 0x4A: - case 0x4B: - case 0x4C: - case 0x4D: - case 0x4E: - case 0x4F: - case 0x50: - case 0x51: - case 0x52: - case 0x53: - case 0x54: - case 0x55: - case 0x56: - case 0x57: - case 0x58: - case 0x59: - case 0x5A: - case 0x5B: - case 0x5C: - case 0x5D: - case 0x5E: - case 0x5F: - case 0x60: - case 0x61: - case 0x62: - case 0x63: - case 0x64: - case 0x65: - case 0x66: - case 0x67: - case 0x68: - case 0x69: - case 0x6A: - case 0x6B: - case 0x6C: - case 0x6D: - case 0x6E: - case 0x6F: - case 0x70: - case 0x71: - case 0x72: - case 0x73: - case 0x74: - case 0x75: - case 0x76: - case 0x77: - case 0x78: - case 0x79: - case 0x7A: - case 0x7B: - case 0x7C: - case 0x7D: - case 0x7E: - case 0x7F: - return sax->number_unsigned(static_cast(current)); - - // fixmap - case 0x80: 
- case 0x81: - case 0x82: - case 0x83: - case 0x84: - case 0x85: - case 0x86: - case 0x87: - case 0x88: - case 0x89: - case 0x8A: - case 0x8B: - case 0x8C: - case 0x8D: - case 0x8E: - case 0x8F: - return get_msgpack_object(static_cast(static_cast(current) & 0x0Fu)); - - // fixarray - case 0x90: - case 0x91: - case 0x92: - case 0x93: - case 0x94: - case 0x95: - case 0x96: - case 0x97: - case 0x98: - case 0x99: - case 0x9A: - case 0x9B: - case 0x9C: - case 0x9D: - case 0x9E: - case 0x9F: - return get_msgpack_array(static_cast(static_cast(current) & 0x0Fu)); - - // fixstr - case 0xA0: - case 0xA1: - case 0xA2: - case 0xA3: - case 0xA4: - case 0xA5: - case 0xA6: - case 0xA7: - case 0xA8: - case 0xA9: - case 0xAA: - case 0xAB: - case 0xAC: - case 0xAD: - case 0xAE: - case 0xAF: - case 0xB0: - case 0xB1: - case 0xB2: - case 0xB3: - case 0xB4: - case 0xB5: - case 0xB6: - case 0xB7: - case 0xB8: - case 0xB9: - case 0xBA: - case 0xBB: - case 0xBC: - case 0xBD: - case 0xBE: - case 0xBF: - case 0xD9: // str 8 - case 0xDA: // str 16 - case 0xDB: // str 32 - { - string_t s; - return get_msgpack_string(s) && sax->string(s); - } - - case 0xC0: // nil - return sax->null(); - - case 0xC2: // false - return sax->boolean(false); - - case 0xC3: // true - return sax->boolean(true); - - case 0xC4: // bin 8 - case 0xC5: // bin 16 - case 0xC6: // bin 32 - case 0xC7: // ext 8 - case 0xC8: // ext 16 - case 0xC9: // ext 32 - case 0xD4: // fixext 1 - case 0xD5: // fixext 2 - case 0xD6: // fixext 4 - case 0xD7: // fixext 8 - case 0xD8: // fixext 16 - { - binary_t b; - return get_msgpack_binary(b) && sax->binary(b); - } - - case 0xCA: // float 32 - { - float number{}; - return get_number(input_format_t::msgpack, number) && sax->number_float(static_cast(number), ""); - } - - case 0xCB: // float 64 - { - double number{}; - return get_number(input_format_t::msgpack, number) && sax->number_float(static_cast(number), ""); - } - - case 0xCC: // uint 8 - { - std::uint8_t number{}; - return 
get_number(input_format_t::msgpack, number) && sax->number_unsigned(number); - } - - case 0xCD: // uint 16 - { - std::uint16_t number{}; - return get_number(input_format_t::msgpack, number) && sax->number_unsigned(number); - } - - case 0xCE: // uint 32 - { - std::uint32_t number{}; - return get_number(input_format_t::msgpack, number) && sax->number_unsigned(number); - } - - case 0xCF: // uint 64 - { - std::uint64_t number{}; - return get_number(input_format_t::msgpack, number) && sax->number_unsigned(number); - } - - case 0xD0: // int 8 - { - std::int8_t number{}; - return get_number(input_format_t::msgpack, number) && sax->number_integer(number); - } - - case 0xD1: // int 16 - { - std::int16_t number{}; - return get_number(input_format_t::msgpack, number) && sax->number_integer(number); - } - - case 0xD2: // int 32 - { - std::int32_t number{}; - return get_number(input_format_t::msgpack, number) && sax->number_integer(number); - } - - case 0xD3: // int 64 - { - std::int64_t number{}; - return get_number(input_format_t::msgpack, number) && sax->number_integer(number); - } - - case 0xDC: // array 16 - { - std::uint16_t len{}; - return get_number(input_format_t::msgpack, len) && get_msgpack_array(static_cast(len)); - } - - case 0xDD: // array 32 - { - std::uint32_t len{}; - return get_number(input_format_t::msgpack, len) && get_msgpack_array(static_cast(len)); - } - - case 0xDE: // map 16 - { - std::uint16_t len{}; - return get_number(input_format_t::msgpack, len) && get_msgpack_object(static_cast(len)); - } - - case 0xDF: // map 32 - { - std::uint32_t len{}; - return get_number(input_format_t::msgpack, len) && get_msgpack_object(static_cast(len)); - } - - // negative fixint - case 0xE0: - case 0xE1: - case 0xE2: - case 0xE3: - case 0xE4: - case 0xE5: - case 0xE6: - case 0xE7: - case 0xE8: - case 0xE9: - case 0xEA: - case 0xEB: - case 0xEC: - case 0xED: - case 0xEE: - case 0xEF: - case 0xF0: - case 0xF1: - case 0xF2: - case 0xF3: - case 0xF4: - case 0xF5: - case 
0xF6: - case 0xF7: - case 0xF8: - case 0xF9: - case 0xFA: - case 0xFB: - case 0xFC: - case 0xFD: - case 0xFE: - case 0xFF: - return sax->number_integer(static_cast(current)); - - default: // anything else - { - auto last_token = get_token_string(); - return sax->parse_error(chars_read, last_token, parse_error::create(112, chars_read, exception_message(input_format_t::msgpack, "invalid byte: 0x" + last_token, "value"))); - } - } - } - - /*! - @brief reads a MessagePack string - - This function first reads starting bytes to determine the expected - string length and then copies this number of bytes into a string. - - @param[out] result created string - - @return whether string creation completed - */ - bool get_msgpack_string(string_t& result) - { - if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::msgpack, "string"))) - { - return false; - } - - switch (current) - { - // fixstr - case 0xA0: - case 0xA1: - case 0xA2: - case 0xA3: - case 0xA4: - case 0xA5: - case 0xA6: - case 0xA7: - case 0xA8: - case 0xA9: - case 0xAA: - case 0xAB: - case 0xAC: - case 0xAD: - case 0xAE: - case 0xAF: - case 0xB0: - case 0xB1: - case 0xB2: - case 0xB3: - case 0xB4: - case 0xB5: - case 0xB6: - case 0xB7: - case 0xB8: - case 0xB9: - case 0xBA: - case 0xBB: - case 0xBC: - case 0xBD: - case 0xBE: - case 0xBF: - { - return get_string(input_format_t::msgpack, static_cast(current) & 0x1Fu, result); - } - - case 0xD9: // str 8 - { - std::uint8_t len{}; - return get_number(input_format_t::msgpack, len) && get_string(input_format_t::msgpack, len, result); - } - - case 0xDA: // str 16 - { - std::uint16_t len{}; - return get_number(input_format_t::msgpack, len) && get_string(input_format_t::msgpack, len, result); - } - - case 0xDB: // str 32 - { - std::uint32_t len{}; - return get_number(input_format_t::msgpack, len) && get_string(input_format_t::msgpack, len, result); - } - - default: - { - auto last_token = get_token_string(); - return sax->parse_error(chars_read, last_token, 
parse_error::create(113, chars_read, exception_message(input_format_t::msgpack, "expected length specification (0xA0-0xBF, 0xD9-0xDB); last byte: 0x" + last_token, "string"))); - } - } - } - - /*! - @brief reads a MessagePack byte array - - This function first reads starting bytes to determine the expected - byte array length and then copies this number of bytes into a byte array. - - @param[out] result created byte array - - @return whether byte array creation completed - */ - bool get_msgpack_binary(binary_t& result) - { - // helper function to set the subtype - auto assign_and_return_true = [&result](std::int8_t subtype) - { - result.set_subtype(static_cast(subtype)); - return true; - }; - - switch (current) - { - case 0xC4: // bin 8 - { - std::uint8_t len{}; - return get_number(input_format_t::msgpack, len) && - get_binary(input_format_t::msgpack, len, result); - } - - case 0xC5: // bin 16 - { - std::uint16_t len{}; - return get_number(input_format_t::msgpack, len) && - get_binary(input_format_t::msgpack, len, result); - } - - case 0xC6: // bin 32 - { - std::uint32_t len{}; - return get_number(input_format_t::msgpack, len) && - get_binary(input_format_t::msgpack, len, result); - } - - case 0xC7: // ext 8 - { - std::uint8_t len{}; - std::int8_t subtype{}; - return get_number(input_format_t::msgpack, len) && - get_number(input_format_t::msgpack, subtype) && - get_binary(input_format_t::msgpack, len, result) && - assign_and_return_true(subtype); - } - - case 0xC8: // ext 16 - { - std::uint16_t len{}; - std::int8_t subtype{}; - return get_number(input_format_t::msgpack, len) && - get_number(input_format_t::msgpack, subtype) && - get_binary(input_format_t::msgpack, len, result) && - assign_and_return_true(subtype); - } - - case 0xC9: // ext 32 - { - std::uint32_t len{}; - std::int8_t subtype{}; - return get_number(input_format_t::msgpack, len) && - get_number(input_format_t::msgpack, subtype) && - get_binary(input_format_t::msgpack, len, result) && - 
assign_and_return_true(subtype); - } - - case 0xD4: // fixext 1 - { - std::int8_t subtype{}; - return get_number(input_format_t::msgpack, subtype) && - get_binary(input_format_t::msgpack, 1, result) && - assign_and_return_true(subtype); - } - - case 0xD5: // fixext 2 - { - std::int8_t subtype{}; - return get_number(input_format_t::msgpack, subtype) && - get_binary(input_format_t::msgpack, 2, result) && - assign_and_return_true(subtype); - } - - case 0xD6: // fixext 4 - { - std::int8_t subtype{}; - return get_number(input_format_t::msgpack, subtype) && - get_binary(input_format_t::msgpack, 4, result) && - assign_and_return_true(subtype); - } - - case 0xD7: // fixext 8 - { - std::int8_t subtype{}; - return get_number(input_format_t::msgpack, subtype) && - get_binary(input_format_t::msgpack, 8, result) && - assign_and_return_true(subtype); - } - - case 0xD8: // fixext 16 - { - std::int8_t subtype{}; - return get_number(input_format_t::msgpack, subtype) && - get_binary(input_format_t::msgpack, 16, result) && - assign_and_return_true(subtype); - } - - default: // LCOV_EXCL_LINE - return false; // LCOV_EXCL_LINE - } - } - - /*! - @param[in] len the length of the array - @return whether array creation completed - */ - bool get_msgpack_array(const std::size_t len) - { - if (JSON_HEDLEY_UNLIKELY(!sax->start_array(len))) - { - return false; - } - - for (std::size_t i = 0; i < len; ++i) - { - if (JSON_HEDLEY_UNLIKELY(!parse_msgpack_internal())) - { - return false; - } - } - - return sax->end_array(); - } - - /*! 
- @param[in] len the length of the object - @return whether object creation completed - */ - bool get_msgpack_object(const std::size_t len) - { - if (JSON_HEDLEY_UNLIKELY(!sax->start_object(len))) - { - return false; - } - - string_t key; - for (std::size_t i = 0; i < len; ++i) - { - get(); - if (JSON_HEDLEY_UNLIKELY(!get_msgpack_string(key) || !sax->key(key))) - { - return false; - } - - if (JSON_HEDLEY_UNLIKELY(!parse_msgpack_internal())) - { - return false; - } - key.clear(); - } - - return sax->end_object(); - } - - //////////// - // UBJSON // - //////////// - - /*! - @param[in] get_char whether a new character should be retrieved from the - input (true, default) or whether the last read - character should be considered instead - - @return whether a valid UBJSON value was passed to the SAX parser - */ - bool parse_ubjson_internal(const bool get_char = true) - { - return get_ubjson_value(get_char ? get_ignore_noop() : current); - } - - /*! - @brief reads a UBJSON string - - This function is either called after reading the 'S' byte explicitly - indicating a string, or in case of an object key where the 'S' byte can be - left out. - - @param[out] result created string - @param[in] get_char whether a new character should be retrieved from the - input (true, default) or whether the last read - character should be considered instead - - @return whether string creation completed - */ - bool get_ubjson_string(string_t& result, const bool get_char = true) - { - if (get_char) - { - get(); // TODO(niels): may we ignore N here? 
- } - - if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::ubjson, "value"))) - { - return false; - } - - switch (current) - { - case 'U': - { - std::uint8_t len{}; - return get_number(input_format_t::ubjson, len) && get_string(input_format_t::ubjson, len, result); - } - - case 'i': - { - std::int8_t len{}; - return get_number(input_format_t::ubjson, len) && get_string(input_format_t::ubjson, len, result); - } - - case 'I': - { - std::int16_t len{}; - return get_number(input_format_t::ubjson, len) && get_string(input_format_t::ubjson, len, result); - } - - case 'l': - { - std::int32_t len{}; - return get_number(input_format_t::ubjson, len) && get_string(input_format_t::ubjson, len, result); - } - - case 'L': - { - std::int64_t len{}; - return get_number(input_format_t::ubjson, len) && get_string(input_format_t::ubjson, len, result); - } - - default: - auto last_token = get_token_string(); - return sax->parse_error(chars_read, last_token, parse_error::create(113, chars_read, exception_message(input_format_t::ubjson, "expected length type specification (U, i, I, l, L); last byte: 0x" + last_token, "string"))); - } - } - - /*! 
- @param[out] result determined size - @return whether size determination completed - */ - bool get_ubjson_size_value(std::size_t& result) - { - switch (get_ignore_noop()) - { - case 'U': - { - std::uint8_t number{}; - if (JSON_HEDLEY_UNLIKELY(!get_number(input_format_t::ubjson, number))) - { - return false; - } - result = static_cast(number); - return true; - } - - case 'i': - { - std::int8_t number{}; - if (JSON_HEDLEY_UNLIKELY(!get_number(input_format_t::ubjson, number))) - { - return false; - } - result = static_cast(number); - return true; - } - - case 'I': - { - std::int16_t number{}; - if (JSON_HEDLEY_UNLIKELY(!get_number(input_format_t::ubjson, number))) - { - return false; - } - result = static_cast(number); - return true; - } - - case 'l': - { - std::int32_t number{}; - if (JSON_HEDLEY_UNLIKELY(!get_number(input_format_t::ubjson, number))) - { - return false; - } - result = static_cast(number); - return true; - } - - case 'L': - { - std::int64_t number{}; - if (JSON_HEDLEY_UNLIKELY(!get_number(input_format_t::ubjson, number))) - { - return false; - } - result = static_cast(number); - return true; - } - - default: - { - auto last_token = get_token_string(); - return sax->parse_error(chars_read, last_token, parse_error::create(113, chars_read, exception_message(input_format_t::ubjson, "expected length type specification (U, i, I, l, L) after '#'; last byte: 0x" + last_token, "size"))); - } - } - } - - /*! - @brief determine the type and size for a container - - In the optimized UBJSON format, a type and a size can be provided to allow - for a more compact representation. 
- - @param[out] result pair of the size and the type - - @return whether pair creation completed - */ - bool get_ubjson_size_type(std::pair& result) - { - result.first = string_t::npos; // size - result.second = 0; // type - - get_ignore_noop(); - - if (current == '$') - { - result.second = get(); // must not ignore 'N', because 'N' maybe the type - if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::ubjson, "type"))) - { - return false; - } - - get_ignore_noop(); - if (JSON_HEDLEY_UNLIKELY(current != '#')) - { - if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::ubjson, "value"))) - { - return false; - } - auto last_token = get_token_string(); - return sax->parse_error(chars_read, last_token, parse_error::create(112, chars_read, exception_message(input_format_t::ubjson, "expected '#' after type information; last byte: 0x" + last_token, "size"))); - } - - return get_ubjson_size_value(result.first); - } - - if (current == '#') - { - return get_ubjson_size_value(result.first); - } - - return true; - } - - /*! 
- @param prefix the previously read or set type prefix - @return whether value creation completed - */ - bool get_ubjson_value(const char_int_type prefix) - { - switch (prefix) - { - case std::char_traits::eof(): // EOF - return unexpect_eof(input_format_t::ubjson, "value"); - - case 'T': // true - return sax->boolean(true); - case 'F': // false - return sax->boolean(false); - - case 'Z': // null - return sax->null(); - - case 'U': - { - std::uint8_t number{}; - return get_number(input_format_t::ubjson, number) && sax->number_unsigned(number); - } - - case 'i': - { - std::int8_t number{}; - return get_number(input_format_t::ubjson, number) && sax->number_integer(number); - } - - case 'I': - { - std::int16_t number{}; - return get_number(input_format_t::ubjson, number) && sax->number_integer(number); - } - - case 'l': - { - std::int32_t number{}; - return get_number(input_format_t::ubjson, number) && sax->number_integer(number); - } - - case 'L': - { - std::int64_t number{}; - return get_number(input_format_t::ubjson, number) && sax->number_integer(number); - } - - case 'd': - { - float number{}; - return get_number(input_format_t::ubjson, number) && sax->number_float(static_cast(number), ""); - } - - case 'D': - { - double number{}; - return get_number(input_format_t::ubjson, number) && sax->number_float(static_cast(number), ""); - } - - case 'H': - { - return get_ubjson_high_precision_number(); - } - - case 'C': // char - { - get(); - if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::ubjson, "char"))) - { - return false; - } - if (JSON_HEDLEY_UNLIKELY(current > 127)) - { - auto last_token = get_token_string(); - return sax->parse_error(chars_read, last_token, parse_error::create(113, chars_read, exception_message(input_format_t::ubjson, "byte after 'C' must be in range 0x00..0x7F; last byte: 0x" + last_token, "char"))); - } - string_t s(1, static_cast(current)); - return sax->string(s); - } - - case 'S': // string - { - string_t s; - return 
get_ubjson_string(s) && sax->string(s); - } - - case '[': // array - return get_ubjson_array(); - - case '{': // object - return get_ubjson_object(); - - default: // anything else - { - auto last_token = get_token_string(); - return sax->parse_error(chars_read, last_token, parse_error::create(112, chars_read, exception_message(input_format_t::ubjson, "invalid byte: 0x" + last_token, "value"))); - } - } - } - - /*! - @return whether array creation completed - */ - bool get_ubjson_array() - { - std::pair size_and_type; - if (JSON_HEDLEY_UNLIKELY(!get_ubjson_size_type(size_and_type))) - { - return false; - } - - if (size_and_type.first != string_t::npos) - { - if (JSON_HEDLEY_UNLIKELY(!sax->start_array(size_and_type.first))) - { - return false; - } - - if (size_and_type.second != 0) - { - if (size_and_type.second != 'N') - { - for (std::size_t i = 0; i < size_and_type.first; ++i) - { - if (JSON_HEDLEY_UNLIKELY(!get_ubjson_value(size_and_type.second))) - { - return false; - } - } - } - } - else - { - for (std::size_t i = 0; i < size_and_type.first; ++i) - { - if (JSON_HEDLEY_UNLIKELY(!parse_ubjson_internal())) - { - return false; - } - } - } - } - else - { - if (JSON_HEDLEY_UNLIKELY(!sax->start_array(std::size_t(-1)))) - { - return false; - } - - while (current != ']') - { - if (JSON_HEDLEY_UNLIKELY(!parse_ubjson_internal(false))) - { - return false; - } - get_ignore_noop(); - } - } - - return sax->end_array(); - } - - /*! 
- @return whether object creation completed - */ - bool get_ubjson_object() - { - std::pair size_and_type; - if (JSON_HEDLEY_UNLIKELY(!get_ubjson_size_type(size_and_type))) - { - return false; - } - - string_t key; - if (size_and_type.first != string_t::npos) - { - if (JSON_HEDLEY_UNLIKELY(!sax->start_object(size_and_type.first))) - { - return false; - } - - if (size_and_type.second != 0) - { - for (std::size_t i = 0; i < size_and_type.first; ++i) - { - if (JSON_HEDLEY_UNLIKELY(!get_ubjson_string(key) || !sax->key(key))) - { - return false; - } - if (JSON_HEDLEY_UNLIKELY(!get_ubjson_value(size_and_type.second))) - { - return false; - } - key.clear(); - } - } - else - { - for (std::size_t i = 0; i < size_and_type.first; ++i) - { - if (JSON_HEDLEY_UNLIKELY(!get_ubjson_string(key) || !sax->key(key))) - { - return false; - } - if (JSON_HEDLEY_UNLIKELY(!parse_ubjson_internal())) - { - return false; - } - key.clear(); - } - } - } - else - { - if (JSON_HEDLEY_UNLIKELY(!sax->start_object(std::size_t(-1)))) - { - return false; - } - - while (current != '}') - { - if (JSON_HEDLEY_UNLIKELY(!get_ubjson_string(key, false) || !sax->key(key))) - { - return false; - } - if (JSON_HEDLEY_UNLIKELY(!parse_ubjson_internal())) - { - return false; - } - get_ignore_noop(); - key.clear(); - } - } - - return sax->end_object(); - } - - // Note, no reader for UBJSON binary types is implemented because they do - // not exist - - bool get_ubjson_high_precision_number() - { - // get size of following number string - std::size_t size{}; - auto res = get_ubjson_size_value(size); - if (JSON_HEDLEY_UNLIKELY(!res)) - { - return res; - } - - // get number string - std::vector number_vector; - for (std::size_t i = 0; i < size; ++i) - { - get(); - if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::ubjson, "number"))) - { - return false; - } - number_vector.push_back(static_cast(current)); - } - - // parse number string - auto number_ia = detail::input_adapter(std::forward(number_vector)); - auto 
number_lexer = detail::lexer(std::move(number_ia), false); - const auto result_number = number_lexer.scan(); - const auto number_string = number_lexer.get_token_string(); - const auto result_remainder = number_lexer.scan(); - - using token_type = typename detail::lexer_base::token_type; - - if (JSON_HEDLEY_UNLIKELY(result_remainder != token_type::end_of_input)) - { - return sax->parse_error(chars_read, number_string, parse_error::create(115, chars_read, exception_message(input_format_t::ubjson, "invalid number text: " + number_lexer.get_token_string(), "high-precision number"))); - } - - switch (result_number) - { - case token_type::value_integer: - return sax->number_integer(number_lexer.get_number_integer()); - case token_type::value_unsigned: - return sax->number_unsigned(number_lexer.get_number_unsigned()); - case token_type::value_float: - return sax->number_float(number_lexer.get_number_float(), std::move(number_string)); - default: - return sax->parse_error(chars_read, number_string, parse_error::create(115, chars_read, exception_message(input_format_t::ubjson, "invalid number text: " + number_lexer.get_token_string(), "high-precision number"))); - } - } - - /////////////////////// - // Utility functions // - /////////////////////// - - /*! - @brief get next character from the input - - This function provides the interface to the used input adapter. It does - not throw in case the input reached EOF, but returns a -'ve valued - `std::char_traits::eof()` in that case. - - @return character read from the input - */ - char_int_type get() - { - ++chars_read; - return current = ia.get_character(); - } - - /*! 
- @return character read from the input after ignoring all 'N' entries - */ - char_int_type get_ignore_noop() - { - do - { - get(); - } - while (current == 'N'); - - return current; - } - - /* - @brief read a number from the input - - @tparam NumberType the type of the number - @param[in] format the current format (for diagnostics) - @param[out] result number of type @a NumberType - - @return whether conversion completed - - @note This function needs to respect the system's endianess, because - bytes in CBOR, MessagePack, and UBJSON are stored in network order - (big endian) and therefore need reordering on little endian systems. - */ - template - bool get_number(const input_format_t format, NumberType& result) - { - // step 1: read input into array with system's byte order - std::array vec; - for (std::size_t i = 0; i < sizeof(NumberType); ++i) - { - get(); - if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(format, "number"))) - { - return false; - } - - // reverse byte order prior to conversion if necessary - if (is_little_endian != InputIsLittleEndian) - { - vec[sizeof(NumberType) - i - 1] = static_cast(current); - } - else - { - vec[i] = static_cast(current); // LCOV_EXCL_LINE - } - } - - // step 2: convert array into number of type T and return - std::memcpy(&result, vec.data(), sizeof(NumberType)); - return true; - } - - /*! - @brief create a string by reading characters from the input - - @tparam NumberType the type of the number - @param[in] format the current format (for diagnostics) - @param[in] len number of characters to read - @param[out] result string created by reading @a len bytes - - @return whether string creation completed - - @note We can not reserve @a len bytes for the result, because @a len - may be too large. Usually, @ref unexpect_eof() detects the end of - the input before we run out of string memory. 
- */ - template - bool get_string(const input_format_t format, - const NumberType len, - string_t& result) - { - bool success = true; - for (NumberType i = 0; i < len; i++) - { - get(); - if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(format, "string"))) - { - success = false; - break; - } - result.push_back(static_cast(current)); - }; - return success; - } - - /*! - @brief create a byte array by reading bytes from the input - - @tparam NumberType the type of the number - @param[in] format the current format (for diagnostics) - @param[in] len number of bytes to read - @param[out] result byte array created by reading @a len bytes - - @return whether byte array creation completed - - @note We can not reserve @a len bytes for the result, because @a len - may be too large. Usually, @ref unexpect_eof() detects the end of - the input before we run out of memory. - */ - template - bool get_binary(const input_format_t format, - const NumberType len, - binary_t& result) - { - bool success = true; - for (NumberType i = 0; i < len; i++) - { - get(); - if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(format, "binary"))) - { - success = false; - break; - } - result.push_back(static_cast(current)); - } - return success; - } - - /*! - @param[in] format the current format (for diagnostics) - @param[in] context further context information (for diagnostics) - @return whether the last read character is not EOF - */ - JSON_HEDLEY_NON_NULL(3) - bool unexpect_eof(const input_format_t format, const char* context) const - { - if (JSON_HEDLEY_UNLIKELY(current == std::char_traits::eof())) - { - return sax->parse_error(chars_read, "", - parse_error::create(110, chars_read, exception_message(format, "unexpected end of input", context))); - } - return true; - } - - /*! - @return a string representation of the last read byte - */ - std::string get_token_string() const - { - std::array cr{{}}; - (std::snprintf)(cr.data(), cr.size(), "%.2hhX", static_cast(current)); - return std::string{cr.data()}; - } - - /*! 
- @param[in] format the current format - @param[in] detail a detailed error message - @param[in] context further context information - @return a message string to use in the parse_error exceptions - */ - std::string exception_message(const input_format_t format, - const std::string& detail, - const std::string& context) const - { - std::string error_msg = "syntax error while parsing "; - - switch (format) - { - case input_format_t::cbor: - error_msg += "CBOR"; - break; - - case input_format_t::msgpack: - error_msg += "MessagePack"; - break; - - case input_format_t::ubjson: - error_msg += "UBJSON"; - break; - - case input_format_t::bson: - error_msg += "BSON"; - break; - - default: // LCOV_EXCL_LINE - JSON_ASSERT(false); // LCOV_EXCL_LINE - } - - return error_msg + " " + context + ": " + detail; - } - - private: - /// input adapter - InputAdapterType ia; - - /// the current character - char_int_type current = std::char_traits::eof(); - - /// the number of characters read - std::size_t chars_read = 0; - - /// whether we can assume little endianess - const bool is_little_endian = little_endianess(); - - /// the SAX parser - json_sax_t* sax = nullptr; -}; -} // namespace detail -} // namespace nlohmann - -// #include - -// #include - -// #include - - -#include // isfinite -#include // uint8_t -#include // function -#include // string -#include // move -#include // vector - -// #include - -// #include - -// #include - -// #include - -// #include - -// #include - -// #include - - -namespace nlohmann -{ -namespace detail -{ -//////////// -// parser // -//////////// - -enum class parse_event_t : uint8_t -{ - /// the parser read `{` and started to process a JSON object - object_start, - /// the parser read `}` and finished processing a JSON object - object_end, - /// the parser read `[` and started to process a JSON array - array_start, - /// the parser read `]` and finished processing a JSON array - array_end, - /// the parser read a key of a value in an object - key, - 
/// the parser finished reading a JSON value - value -}; - -template -using parser_callback_t = - std::function; - -/*! -@brief syntax analysis - -This class implements a recursive descent parser. -*/ -template -class parser -{ - using number_integer_t = typename BasicJsonType::number_integer_t; - using number_unsigned_t = typename BasicJsonType::number_unsigned_t; - using number_float_t = typename BasicJsonType::number_float_t; - using string_t = typename BasicJsonType::string_t; - using lexer_t = lexer; - using token_type = typename lexer_t::token_type; - - public: - /// a parser reading from an input adapter - explicit parser(InputAdapterType&& adapter, - const parser_callback_t cb = nullptr, - const bool allow_exceptions_ = true, - const bool skip_comments = false) - : callback(cb) - , m_lexer(std::move(adapter), skip_comments) - , allow_exceptions(allow_exceptions_) - { - // read first token - get_token(); - } - - /*! - @brief public parser interface - - @param[in] strict whether to expect the last token to be EOF - @param[in,out] result parsed JSON value - - @throw parse_error.101 in case of an unexpected token - @throw parse_error.102 if to_unicode fails or surrogate error - @throw parse_error.103 if to_unicode fails - */ - void parse(const bool strict, BasicJsonType& result) - { - if (callback) - { - json_sax_dom_callback_parser sdp(result, callback, allow_exceptions); - sax_parse_internal(&sdp); - result.assert_invariant(); - - // in strict mode, input must be completely read - if (strict && (get_token() != token_type::end_of_input)) - { - sdp.parse_error(m_lexer.get_position(), - m_lexer.get_token_string(), - parse_error::create(101, m_lexer.get_position(), - exception_message(token_type::end_of_input, "value"))); - } - - // in case of an error, return discarded value - if (sdp.is_errored()) - { - result = value_t::discarded; - return; - } - - // set top-level value to null if it was discarded by the callback - // function - if (result.is_discarded()) - { 
- result = nullptr; - } - } - else - { - json_sax_dom_parser sdp(result, allow_exceptions); - sax_parse_internal(&sdp); - result.assert_invariant(); - - // in strict mode, input must be completely read - if (strict && (get_token() != token_type::end_of_input)) - { - sdp.parse_error(m_lexer.get_position(), - m_lexer.get_token_string(), - parse_error::create(101, m_lexer.get_position(), - exception_message(token_type::end_of_input, "value"))); - } - - // in case of an error, return discarded value - if (sdp.is_errored()) - { - result = value_t::discarded; - return; - } - } - } - - /*! - @brief public accept interface - - @param[in] strict whether to expect the last token to be EOF - @return whether the input is a proper JSON text - */ - bool accept(const bool strict = true) - { - json_sax_acceptor sax_acceptor; - return sax_parse(&sax_acceptor, strict); - } - - template - JSON_HEDLEY_NON_NULL(2) - bool sax_parse(SAX* sax, const bool strict = true) - { - (void)detail::is_sax_static_asserts {}; - const bool result = sax_parse_internal(sax); - - // strict mode: next byte must be EOF - if (result && strict && (get_token() != token_type::end_of_input)) - { - return sax->parse_error(m_lexer.get_position(), - m_lexer.get_token_string(), - parse_error::create(101, m_lexer.get_position(), - exception_message(token_type::end_of_input, "value"))); - } - - return result; - } - - private: - template - JSON_HEDLEY_NON_NULL(2) - bool sax_parse_internal(SAX* sax) - { - // stack to remember the hierarchy of structured values we are parsing - // true = array; false = object - std::vector states; - // value to avoid a goto (see comment where set to true) - bool skip_to_state_evaluation = false; - - while (true) - { - if (!skip_to_state_evaluation) - { - // invariant: get_token() was called before each iteration - switch (last_token) - { - case token_type::begin_object: - { - if (JSON_HEDLEY_UNLIKELY(!sax->start_object(std::size_t(-1)))) - { - return false; - } - - // closing } -> we 
are done - if (get_token() == token_type::end_object) - { - if (JSON_HEDLEY_UNLIKELY(!sax->end_object())) - { - return false; - } - break; - } - - // parse key - if (JSON_HEDLEY_UNLIKELY(last_token != token_type::value_string)) - { - return sax->parse_error(m_lexer.get_position(), - m_lexer.get_token_string(), - parse_error::create(101, m_lexer.get_position(), - exception_message(token_type::value_string, "object key"))); - } - if (JSON_HEDLEY_UNLIKELY(!sax->key(m_lexer.get_string()))) - { - return false; - } - - // parse separator (:) - if (JSON_HEDLEY_UNLIKELY(get_token() != token_type::name_separator)) - { - return sax->parse_error(m_lexer.get_position(), - m_lexer.get_token_string(), - parse_error::create(101, m_lexer.get_position(), - exception_message(token_type::name_separator, "object separator"))); - } - - // remember we are now inside an object - states.push_back(false); - - // parse values - get_token(); - continue; - } - - case token_type::begin_array: - { - if (JSON_HEDLEY_UNLIKELY(!sax->start_array(std::size_t(-1)))) - { - return false; - } - - // closing ] -> we are done - if (get_token() == token_type::end_array) - { - if (JSON_HEDLEY_UNLIKELY(!sax->end_array())) - { - return false; - } - break; - } - - // remember we are now inside an array - states.push_back(true); - - // parse values (no need to call get_token) - continue; - } - - case token_type::value_float: - { - const auto res = m_lexer.get_number_float(); - - if (JSON_HEDLEY_UNLIKELY(!std::isfinite(res))) - { - return sax->parse_error(m_lexer.get_position(), - m_lexer.get_token_string(), - out_of_range::create(406, "number overflow parsing '" + m_lexer.get_token_string() + "'")); - } - - if (JSON_HEDLEY_UNLIKELY(!sax->number_float(res, m_lexer.get_string()))) - { - return false; - } - - break; - } - - case token_type::literal_false: - { - if (JSON_HEDLEY_UNLIKELY(!sax->boolean(false))) - { - return false; - } - break; - } - - case token_type::literal_null: - { - if 
(JSON_HEDLEY_UNLIKELY(!sax->null())) - { - return false; - } - break; - } - - case token_type::literal_true: - { - if (JSON_HEDLEY_UNLIKELY(!sax->boolean(true))) - { - return false; - } - break; - } - - case token_type::value_integer: - { - if (JSON_HEDLEY_UNLIKELY(!sax->number_integer(m_lexer.get_number_integer()))) - { - return false; - } - break; - } - - case token_type::value_string: - { - if (JSON_HEDLEY_UNLIKELY(!sax->string(m_lexer.get_string()))) - { - return false; - } - break; - } - - case token_type::value_unsigned: - { - if (JSON_HEDLEY_UNLIKELY(!sax->number_unsigned(m_lexer.get_number_unsigned()))) - { - return false; - } - break; - } - - case token_type::parse_error: - { - // using "uninitialized" to avoid "expected" message - return sax->parse_error(m_lexer.get_position(), - m_lexer.get_token_string(), - parse_error::create(101, m_lexer.get_position(), - exception_message(token_type::uninitialized, "value"))); - } - - default: // the last token was unexpected - { - return sax->parse_error(m_lexer.get_position(), - m_lexer.get_token_string(), - parse_error::create(101, m_lexer.get_position(), - exception_message(token_type::literal_or_value, "value"))); - } - } - } - else - { - skip_to_state_evaluation = false; - } - - // we reached this line after we successfully parsed a value - if (states.empty()) - { - // empty stack: we reached the end of the hierarchy: done - return true; - } - - if (states.back()) // array - { - // comma -> next value - if (get_token() == token_type::value_separator) - { - // parse a new value - get_token(); - continue; - } - - // closing ] - if (JSON_HEDLEY_LIKELY(last_token == token_type::end_array)) - { - if (JSON_HEDLEY_UNLIKELY(!sax->end_array())) - { - return false; - } - - // We are done with this array. Before we can parse a - // new value, we need to evaluate the new state first. - // By setting skip_to_state_evaluation to false, we - // are effectively jumping to the beginning of this if. 
- JSON_ASSERT(!states.empty()); - states.pop_back(); - skip_to_state_evaluation = true; - continue; - } - - return sax->parse_error(m_lexer.get_position(), - m_lexer.get_token_string(), - parse_error::create(101, m_lexer.get_position(), - exception_message(token_type::end_array, "array"))); - } - else // object - { - // comma -> next value - if (get_token() == token_type::value_separator) - { - // parse key - if (JSON_HEDLEY_UNLIKELY(get_token() != token_type::value_string)) - { - return sax->parse_error(m_lexer.get_position(), - m_lexer.get_token_string(), - parse_error::create(101, m_lexer.get_position(), - exception_message(token_type::value_string, "object key"))); - } - - if (JSON_HEDLEY_UNLIKELY(!sax->key(m_lexer.get_string()))) - { - return false; - } - - // parse separator (:) - if (JSON_HEDLEY_UNLIKELY(get_token() != token_type::name_separator)) - { - return sax->parse_error(m_lexer.get_position(), - m_lexer.get_token_string(), - parse_error::create(101, m_lexer.get_position(), - exception_message(token_type::name_separator, "object separator"))); - } - - // parse values - get_token(); - continue; - } - - // closing } - if (JSON_HEDLEY_LIKELY(last_token == token_type::end_object)) - { - if (JSON_HEDLEY_UNLIKELY(!sax->end_object())) - { - return false; - } - - // We are done with this object. Before we can parse a - // new value, we need to evaluate the new state first. - // By setting skip_to_state_evaluation to false, we - // are effectively jumping to the beginning of this if. 
- JSON_ASSERT(!states.empty()); - states.pop_back(); - skip_to_state_evaluation = true; - continue; - } - - return sax->parse_error(m_lexer.get_position(), - m_lexer.get_token_string(), - parse_error::create(101, m_lexer.get_position(), - exception_message(token_type::end_object, "object"))); - } - } - } - - /// get next token from lexer - token_type get_token() - { - return last_token = m_lexer.scan(); - } - - std::string exception_message(const token_type expected, const std::string& context) - { - std::string error_msg = "syntax error "; - - if (!context.empty()) - { - error_msg += "while parsing " + context + " "; - } - - error_msg += "- "; - - if (last_token == token_type::parse_error) - { - error_msg += std::string(m_lexer.get_error_message()) + "; last read: '" + - m_lexer.get_token_string() + "'"; - } - else - { - error_msg += "unexpected " + std::string(lexer_t::token_type_name(last_token)); - } - - if (expected != token_type::uninitialized) - { - error_msg += "; expected " + std::string(lexer_t::token_type_name(expected)); - } - - return error_msg; - } - - private: - /// callback function - const parser_callback_t callback = nullptr; - /// the type of the last read token - token_type last_token = token_type::uninitialized; - /// the lexer - lexer_t m_lexer; - /// whether to throw exceptions in case of errors - const bool allow_exceptions = true; -}; -} // namespace detail -} // namespace nlohmann - -// #include - - -// #include - - -#include // ptrdiff_t -#include // numeric_limits - -namespace nlohmann -{ -namespace detail -{ -/* -@brief an iterator for primitive JSON types - -This class models an iterator for primitive JSON types (boolean, number, -string). It's only purpose is to allow the iterator/const_iterator classes -to "iterate" over primitive values. Internally, the iterator is modeled by -a `difference_type` variable. Value begin_value (`0`) models the begin, -end_value (`1`) models past the end. 
-*/ -class primitive_iterator_t -{ - private: - using difference_type = std::ptrdiff_t; - static constexpr difference_type begin_value = 0; - static constexpr difference_type end_value = begin_value + 1; - - /// iterator as signed integer type - difference_type m_it = (std::numeric_limits::min)(); - - public: - constexpr difference_type get_value() const noexcept - { - return m_it; - } - - /// set iterator to a defined beginning - void set_begin() noexcept - { - m_it = begin_value; - } - - /// set iterator to a defined past the end - void set_end() noexcept - { - m_it = end_value; - } - - /// return whether the iterator can be dereferenced - constexpr bool is_begin() const noexcept - { - return m_it == begin_value; - } - - /// return whether the iterator is at end - constexpr bool is_end() const noexcept - { - return m_it == end_value; - } - - friend constexpr bool operator==(primitive_iterator_t lhs, primitive_iterator_t rhs) noexcept - { - return lhs.m_it == rhs.m_it; - } - - friend constexpr bool operator<(primitive_iterator_t lhs, primitive_iterator_t rhs) noexcept - { - return lhs.m_it < rhs.m_it; - } - - primitive_iterator_t operator+(difference_type n) noexcept - { - auto result = *this; - result += n; - return result; - } - - friend constexpr difference_type operator-(primitive_iterator_t lhs, primitive_iterator_t rhs) noexcept - { - return lhs.m_it - rhs.m_it; - } - - primitive_iterator_t& operator++() noexcept - { - ++m_it; - return *this; - } - - primitive_iterator_t const operator++(int) noexcept - { - auto result = *this; - ++m_it; - return result; - } - - primitive_iterator_t& operator--() noexcept - { - --m_it; - return *this; - } - - primitive_iterator_t const operator--(int) noexcept - { - auto result = *this; - --m_it; - return result; - } - - primitive_iterator_t& operator+=(difference_type n) noexcept - { - m_it += n; - return *this; - } - - primitive_iterator_t& operator-=(difference_type n) noexcept - { - m_it -= n; - return *this; - } -}; -} 
// namespace detail -} // namespace nlohmann - - -namespace nlohmann -{ -namespace detail -{ -/*! -@brief an iterator value - -@note This structure could easily be a union, but MSVC currently does not allow -unions members with complex constructors, see https://github.com/nlohmann/json/pull/105. -*/ -template struct internal_iterator -{ - /// iterator for JSON objects - typename BasicJsonType::object_t::iterator object_iterator {}; - /// iterator for JSON arrays - typename BasicJsonType::array_t::iterator array_iterator {}; - /// generic iterator for all other types - primitive_iterator_t primitive_iterator {}; -}; -} // namespace detail -} // namespace nlohmann - -// #include - - -#include // iterator, random_access_iterator_tag, bidirectional_iterator_tag, advance, next -#include // conditional, is_const, remove_const - -// #include - -// #include - -// #include - -// #include - -// #include - -// #include - -// #include - - -namespace nlohmann -{ -namespace detail -{ -// forward declare, to be able to friend it later on -template class iteration_proxy; -template class iteration_proxy_value; - -/*! -@brief a template for a bidirectional iterator for the @ref basic_json class -This class implements a both iterators (iterator and const_iterator) for the -@ref basic_json class. -@note An iterator is called *initialized* when a pointer to a JSON value has - been set (e.g., by a constructor or a copy assignment). If the iterator is - default-constructed, it is *uninitialized* and most methods are undefined. - **The library uses assertions to detect calls on uninitialized iterators.** -@requirement The class satisfies the following concept requirements: -- -[BidirectionalIterator](https://en.cppreference.com/w/cpp/named_req/BidirectionalIterator): - The iterator that can be moved can be moved in both directions (i.e. - incremented and decremented). 
-@since version 1.0.0, simplified in version 2.0.9, change to bidirectional - iterators in version 3.0.0 (see https://github.com/nlohmann/json/issues/593) -*/ -template -class iter_impl -{ - /// allow basic_json to access private members - friend iter_impl::value, typename std::remove_const::type, const BasicJsonType>::type>; - friend BasicJsonType; - friend iteration_proxy; - friend iteration_proxy_value; - - using object_t = typename BasicJsonType::object_t; - using array_t = typename BasicJsonType::array_t; - // make sure BasicJsonType is basic_json or const basic_json - static_assert(is_basic_json::type>::value, - "iter_impl only accepts (const) basic_json"); - - public: - - /// The std::iterator class template (used as a base class to provide typedefs) is deprecated in C++17. - /// The C++ Standard has never required user-defined iterators to derive from std::iterator. - /// A user-defined iterator should provide publicly accessible typedefs named - /// iterator_category, value_type, difference_type, pointer, and reference. - /// Note that value_type is required to be non-const, even for constant iterators. - using iterator_category = std::bidirectional_iterator_tag; - - /// the type of the values when the iterator is dereferenced - using value_type = typename BasicJsonType::value_type; - /// a type to represent differences between iterators - using difference_type = typename BasicJsonType::difference_type; - /// defines a pointer to the type iterated over (value_type) - using pointer = typename std::conditional::value, - typename BasicJsonType::const_pointer, - typename BasicJsonType::pointer>::type; - /// defines a reference to the type iterated over (value_type) - using reference = - typename std::conditional::value, - typename BasicJsonType::const_reference, - typename BasicJsonType::reference>::type; - - /// default constructor - iter_impl() = default; - - /*! 
- @brief constructor for a given JSON instance - @param[in] object pointer to a JSON object for this iterator - @pre object != nullptr - @post The iterator is initialized; i.e. `m_object != nullptr`. - */ - explicit iter_impl(pointer object) noexcept : m_object(object) - { - JSON_ASSERT(m_object != nullptr); - - switch (m_object->m_type) - { - case value_t::object: - { - m_it.object_iterator = typename object_t::iterator(); - break; - } - - case value_t::array: - { - m_it.array_iterator = typename array_t::iterator(); - break; - } - - default: - { - m_it.primitive_iterator = primitive_iterator_t(); - break; - } - } - } - - /*! - @note The conventional copy constructor and copy assignment are implicitly - defined. Combined with the following converting constructor and - assignment, they support: (1) copy from iterator to iterator, (2) - copy from const iterator to const iterator, and (3) conversion from - iterator to const iterator. However conversion from const iterator - to iterator is not defined. - */ - - /*! - @brief const copy constructor - @param[in] other const iterator to copy from - @note This copy constructor had to be defined explicitly to circumvent a bug - occurring on msvc v19.0 compiler (VS 2015) debug build. For more - information refer to: https://github.com/nlohmann/json/issues/1608 - */ - iter_impl(const iter_impl& other) noexcept - : m_object(other.m_object), m_it(other.m_it) - {} - - /*! - @brief converting assignment - @param[in] other const iterator to copy from - @return const/non-const iterator - @note It is not checked whether @a other is initialized. - */ - iter_impl& operator=(const iter_impl& other) noexcept - { - m_object = other.m_object; - m_it = other.m_it; - return *this; - } - - /*! - @brief converting constructor - @param[in] other non-const iterator to copy from - @note It is not checked whether @a other is initialized. 
- */ - iter_impl(const iter_impl::type>& other) noexcept - : m_object(other.m_object), m_it(other.m_it) - {} - - /*! - @brief converting assignment - @param[in] other non-const iterator to copy from - @return const/non-const iterator - @note It is not checked whether @a other is initialized. - */ - iter_impl& operator=(const iter_impl::type>& other) noexcept - { - m_object = other.m_object; - m_it = other.m_it; - return *this; - } - - private: - /*! - @brief set the iterator to the first value - @pre The iterator is initialized; i.e. `m_object != nullptr`. - */ - void set_begin() noexcept - { - JSON_ASSERT(m_object != nullptr); - - switch (m_object->m_type) - { - case value_t::object: - { - m_it.object_iterator = m_object->m_value.object->begin(); - break; - } - - case value_t::array: - { - m_it.array_iterator = m_object->m_value.array->begin(); - break; - } - - case value_t::null: - { - // set to end so begin()==end() is true: null is empty - m_it.primitive_iterator.set_end(); - break; - } - - default: - { - m_it.primitive_iterator.set_begin(); - break; - } - } - } - - /*! - @brief set the iterator past the last value - @pre The iterator is initialized; i.e. `m_object != nullptr`. - */ - void set_end() noexcept - { - JSON_ASSERT(m_object != nullptr); - - switch (m_object->m_type) - { - case value_t::object: - { - m_it.object_iterator = m_object->m_value.object->end(); - break; - } - - case value_t::array: - { - m_it.array_iterator = m_object->m_value.array->end(); - break; - } - - default: - { - m_it.primitive_iterator.set_end(); - break; - } - } - } - - public: - /*! - @brief return a reference to the value pointed to by the iterator - @pre The iterator is initialized; i.e. `m_object != nullptr`. 
- */ - reference operator*() const - { - JSON_ASSERT(m_object != nullptr); - - switch (m_object->m_type) - { - case value_t::object: - { - JSON_ASSERT(m_it.object_iterator != m_object->m_value.object->end()); - return m_it.object_iterator->second; - } - - case value_t::array: - { - JSON_ASSERT(m_it.array_iterator != m_object->m_value.array->end()); - return *m_it.array_iterator; - } - - case value_t::null: - JSON_THROW(invalid_iterator::create(214, "cannot get value")); - - default: - { - if (JSON_HEDLEY_LIKELY(m_it.primitive_iterator.is_begin())) - { - return *m_object; - } - - JSON_THROW(invalid_iterator::create(214, "cannot get value")); - } - } - } - - /*! - @brief dereference the iterator - @pre The iterator is initialized; i.e. `m_object != nullptr`. - */ - pointer operator->() const - { - JSON_ASSERT(m_object != nullptr); - - switch (m_object->m_type) - { - case value_t::object: - { - JSON_ASSERT(m_it.object_iterator != m_object->m_value.object->end()); - return &(m_it.object_iterator->second); - } - - case value_t::array: - { - JSON_ASSERT(m_it.array_iterator != m_object->m_value.array->end()); - return &*m_it.array_iterator; - } - - default: - { - if (JSON_HEDLEY_LIKELY(m_it.primitive_iterator.is_begin())) - { - return m_object; - } - - JSON_THROW(invalid_iterator::create(214, "cannot get value")); - } - } - } - - /*! - @brief post-increment (it++) - @pre The iterator is initialized; i.e. `m_object != nullptr`. - */ - iter_impl const operator++(int) - { - auto result = *this; - ++(*this); - return result; - } - - /*! - @brief pre-increment (++it) - @pre The iterator is initialized; i.e. `m_object != nullptr`. 
- */ - iter_impl& operator++() - { - JSON_ASSERT(m_object != nullptr); - - switch (m_object->m_type) - { - case value_t::object: - { - std::advance(m_it.object_iterator, 1); - break; - } - - case value_t::array: - { - std::advance(m_it.array_iterator, 1); - break; - } - - default: - { - ++m_it.primitive_iterator; - break; - } - } - - return *this; - } - - /*! - @brief post-decrement (it--) - @pre The iterator is initialized; i.e. `m_object != nullptr`. - */ - iter_impl const operator--(int) - { - auto result = *this; - --(*this); - return result; - } - - /*! - @brief pre-decrement (--it) - @pre The iterator is initialized; i.e. `m_object != nullptr`. - */ - iter_impl& operator--() - { - JSON_ASSERT(m_object != nullptr); - - switch (m_object->m_type) - { - case value_t::object: - { - std::advance(m_it.object_iterator, -1); - break; - } - - case value_t::array: - { - std::advance(m_it.array_iterator, -1); - break; - } - - default: - { - --m_it.primitive_iterator; - break; - } - } - - return *this; - } - - /*! - @brief comparison: equal - @pre The iterator is initialized; i.e. `m_object != nullptr`. - */ - bool operator==(const iter_impl& other) const - { - // if objects are not the same, the comparison is undefined - if (JSON_HEDLEY_UNLIKELY(m_object != other.m_object)) - { - JSON_THROW(invalid_iterator::create(212, "cannot compare iterators of different containers")); - } - - JSON_ASSERT(m_object != nullptr); - - switch (m_object->m_type) - { - case value_t::object: - return (m_it.object_iterator == other.m_it.object_iterator); - - case value_t::array: - return (m_it.array_iterator == other.m_it.array_iterator); - - default: - return (m_it.primitive_iterator == other.m_it.primitive_iterator); - } - } - - /*! - @brief comparison: not equal - @pre The iterator is initialized; i.e. `m_object != nullptr`. - */ - bool operator!=(const iter_impl& other) const - { - return !operator==(other); - } - - /*! - @brief comparison: smaller - @pre The iterator is initialized; i.e. 
`m_object != nullptr`. - */ - bool operator<(const iter_impl& other) const - { - // if objects are not the same, the comparison is undefined - if (JSON_HEDLEY_UNLIKELY(m_object != other.m_object)) - { - JSON_THROW(invalid_iterator::create(212, "cannot compare iterators of different containers")); - } - - JSON_ASSERT(m_object != nullptr); - - switch (m_object->m_type) - { - case value_t::object: - JSON_THROW(invalid_iterator::create(213, "cannot compare order of object iterators")); - - case value_t::array: - return (m_it.array_iterator < other.m_it.array_iterator); - - default: - return (m_it.primitive_iterator < other.m_it.primitive_iterator); - } - } - - /*! - @brief comparison: less than or equal - @pre The iterator is initialized; i.e. `m_object != nullptr`. - */ - bool operator<=(const iter_impl& other) const - { - return !other.operator < (*this); - } - - /*! - @brief comparison: greater than - @pre The iterator is initialized; i.e. `m_object != nullptr`. - */ - bool operator>(const iter_impl& other) const - { - return !operator<=(other); - } - - /*! - @brief comparison: greater than or equal - @pre The iterator is initialized; i.e. `m_object != nullptr`. - */ - bool operator>=(const iter_impl& other) const - { - return !operator<(other); - } - - /*! - @brief add to iterator - @pre The iterator is initialized; i.e. `m_object != nullptr`. - */ - iter_impl& operator+=(difference_type i) - { - JSON_ASSERT(m_object != nullptr); - - switch (m_object->m_type) - { - case value_t::object: - JSON_THROW(invalid_iterator::create(209, "cannot use offsets with object iterators")); - - case value_t::array: - { - std::advance(m_it.array_iterator, i); - break; - } - - default: - { - m_it.primitive_iterator += i; - break; - } - } - - return *this; - } - - /*! - @brief subtract from iterator - @pre The iterator is initialized; i.e. `m_object != nullptr`. - */ - iter_impl& operator-=(difference_type i) - { - return operator+=(-i); - } - - /*! 
- @brief add to iterator - @pre The iterator is initialized; i.e. `m_object != nullptr`. - */ - iter_impl operator+(difference_type i) const - { - auto result = *this; - result += i; - return result; - } - - /*! - @brief addition of distance and iterator - @pre The iterator is initialized; i.e. `m_object != nullptr`. - */ - friend iter_impl operator+(difference_type i, const iter_impl& it) - { - auto result = it; - result += i; - return result; - } - - /*! - @brief subtract from iterator - @pre The iterator is initialized; i.e. `m_object != nullptr`. - */ - iter_impl operator-(difference_type i) const - { - auto result = *this; - result -= i; - return result; - } - - /*! - @brief return difference - @pre The iterator is initialized; i.e. `m_object != nullptr`. - */ - difference_type operator-(const iter_impl& other) const - { - JSON_ASSERT(m_object != nullptr); - - switch (m_object->m_type) - { - case value_t::object: - JSON_THROW(invalid_iterator::create(209, "cannot use offsets with object iterators")); - - case value_t::array: - return m_it.array_iterator - other.m_it.array_iterator; - - default: - return m_it.primitive_iterator - other.m_it.primitive_iterator; - } - } - - /*! - @brief access to successor - @pre The iterator is initialized; i.e. `m_object != nullptr`. - */ - reference operator[](difference_type n) const - { - JSON_ASSERT(m_object != nullptr); - - switch (m_object->m_type) - { - case value_t::object: - JSON_THROW(invalid_iterator::create(208, "cannot use operator[] for object iterators")); - - case value_t::array: - return *std::next(m_it.array_iterator, n); - - case value_t::null: - JSON_THROW(invalid_iterator::create(214, "cannot get value")); - - default: - { - if (JSON_HEDLEY_LIKELY(m_it.primitive_iterator.get_value() == -n)) - { - return *m_object; - } - - JSON_THROW(invalid_iterator::create(214, "cannot get value")); - } - } - } - - /*! - @brief return the key of an object iterator - @pre The iterator is initialized; i.e. 
`m_object != nullptr`. - */ - const typename object_t::key_type& key() const - { - JSON_ASSERT(m_object != nullptr); - - if (JSON_HEDLEY_LIKELY(m_object->is_object())) - { - return m_it.object_iterator->first; - } - - JSON_THROW(invalid_iterator::create(207, "cannot use key() for non-object iterators")); - } - - /*! - @brief return the value of an iterator - @pre The iterator is initialized; i.e. `m_object != nullptr`. - */ - reference value() const - { - return operator*(); - } - - private: - /// associated JSON instance - pointer m_object = nullptr; - /// the actual iterator of the associated instance - internal_iterator::type> m_it {}; -}; -} // namespace detail -} // namespace nlohmann - -// #include - -// #include - - -#include // ptrdiff_t -#include // reverse_iterator -#include // declval - -namespace nlohmann -{ -namespace detail -{ -////////////////////// -// reverse_iterator // -////////////////////// - -/*! -@brief a template for a reverse iterator class - -@tparam Base the base iterator type to reverse. Valid types are @ref -iterator (to create @ref reverse_iterator) and @ref const_iterator (to -create @ref const_reverse_iterator). - -@requirement The class satisfies the following concept requirements: -- -[BidirectionalIterator](https://en.cppreference.com/w/cpp/named_req/BidirectionalIterator): - The iterator that can be moved can be moved in both directions (i.e. - incremented and decremented). -- [OutputIterator](https://en.cppreference.com/w/cpp/named_req/OutputIterator): - It is possible to write to the pointed-to element (only if @a Base is - @ref iterator). 
- -@since version 1.0.0 -*/ -template -class json_reverse_iterator : public std::reverse_iterator -{ - public: - using difference_type = std::ptrdiff_t; - /// shortcut to the reverse iterator adapter - using base_iterator = std::reverse_iterator; - /// the reference type for the pointed-to element - using reference = typename Base::reference; - - /// create reverse iterator from iterator - explicit json_reverse_iterator(const typename base_iterator::iterator_type& it) noexcept - : base_iterator(it) {} - - /// create reverse iterator from base class - explicit json_reverse_iterator(const base_iterator& it) noexcept : base_iterator(it) {} - - /// post-increment (it++) - json_reverse_iterator const operator++(int) - { - return static_cast(base_iterator::operator++(1)); - } - - /// pre-increment (++it) - json_reverse_iterator& operator++() - { - return static_cast(base_iterator::operator++()); - } - - /// post-decrement (it--) - json_reverse_iterator const operator--(int) - { - return static_cast(base_iterator::operator--(1)); - } - - /// pre-decrement (--it) - json_reverse_iterator& operator--() - { - return static_cast(base_iterator::operator--()); - } - - /// add to iterator - json_reverse_iterator& operator+=(difference_type i) - { - return static_cast(base_iterator::operator+=(i)); - } - - /// add to iterator - json_reverse_iterator operator+(difference_type i) const - { - return static_cast(base_iterator::operator+(i)); - } - - /// subtract from iterator - json_reverse_iterator operator-(difference_type i) const - { - return static_cast(base_iterator::operator-(i)); - } - - /// return difference - difference_type operator-(const json_reverse_iterator& other) const - { - return base_iterator(*this) - base_iterator(other); - } - - /// access to successor - reference operator[](difference_type n) const - { - return *(this->operator+(n)); - } - - /// return the key of an object iterator - auto key() const -> decltype(std::declval().key()) - { - auto it = 
--this->base(); - return it.key(); - } - - /// return the value of an iterator - reference value() const - { - auto it = --this->base(); - return it.operator * (); - } -}; -} // namespace detail -} // namespace nlohmann - -// #include - -// #include - - -#include // all_of -#include // isdigit -#include // max -#include // accumulate -#include // string -#include // move -#include // vector - -// #include - -// #include - -// #include - - -namespace nlohmann -{ -template -class json_pointer -{ - // allow basic_json to access private members - NLOHMANN_BASIC_JSON_TPL_DECLARATION - friend class basic_json; - - public: - /*! - @brief create JSON pointer - - Create a JSON pointer according to the syntax described in - [Section 3 of RFC6901](https://tools.ietf.org/html/rfc6901#section-3). - - @param[in] s string representing the JSON pointer; if omitted, the empty - string is assumed which references the whole JSON value - - @throw parse_error.107 if the given JSON pointer @a s is nonempty and does - not begin with a slash (`/`); see example below - - @throw parse_error.108 if a tilde (`~`) in the given JSON pointer @a s is - not followed by `0` (representing `~`) or `1` (representing `/`); see - example below - - @liveexample{The example shows the construction several valid JSON pointers - as well as the exceptional behavior.,json_pointer} - - @since version 2.0.0 - */ - explicit json_pointer(const std::string& s = "") - : reference_tokens(split(s)) - {} - - /*! 
- @brief return a string representation of the JSON pointer - - @invariant For each JSON pointer `ptr`, it holds: - @code {.cpp} - ptr == json_pointer(ptr.to_string()); - @endcode - - @return a string representation of the JSON pointer - - @liveexample{The example shows the result of `to_string`.,json_pointer__to_string} - - @since version 2.0.0 - */ - std::string to_string() const - { - return std::accumulate(reference_tokens.begin(), reference_tokens.end(), - std::string{}, - [](const std::string & a, const std::string & b) - { - return a + "/" + escape(b); - }); - } - - /// @copydoc to_string() - operator std::string() const - { - return to_string(); - } - - /*! - @brief append another JSON pointer at the end of this JSON pointer - - @param[in] ptr JSON pointer to append - @return JSON pointer with @a ptr appended - - @liveexample{The example shows the usage of `operator/=`.,json_pointer__operator_add} - - @complexity Linear in the length of @a ptr. - - @sa @ref operator/=(std::string) to append a reference token - @sa @ref operator/=(std::size_t) to append an array index - @sa @ref operator/(const json_pointer&, const json_pointer&) for a binary operator - - @since version 3.6.0 - */ - json_pointer& operator/=(const json_pointer& ptr) - { - reference_tokens.insert(reference_tokens.end(), - ptr.reference_tokens.begin(), - ptr.reference_tokens.end()); - return *this; - } - - /*! - @brief append an unescaped reference token at the end of this JSON pointer - - @param[in] token reference token to append - @return JSON pointer with @a token appended without escaping @a token - - @liveexample{The example shows the usage of `operator/=`.,json_pointer__operator_add} - - @complexity Amortized constant. 
- - @sa @ref operator/=(const json_pointer&) to append a JSON pointer - @sa @ref operator/=(std::size_t) to append an array index - @sa @ref operator/(const json_pointer&, std::size_t) for a binary operator - - @since version 3.6.0 - */ - json_pointer& operator/=(std::string token) - { - push_back(std::move(token)); - return *this; - } - - /*! - @brief append an array index at the end of this JSON pointer - - @param[in] array_idx array index to append - @return JSON pointer with @a array_idx appended - - @liveexample{The example shows the usage of `operator/=`.,json_pointer__operator_add} - - @complexity Amortized constant. - - @sa @ref operator/=(const json_pointer&) to append a JSON pointer - @sa @ref operator/=(std::string) to append a reference token - @sa @ref operator/(const json_pointer&, std::string) for a binary operator - - @since version 3.6.0 - */ - json_pointer& operator/=(std::size_t array_idx) - { - return *this /= std::to_string(array_idx); - } - - /*! - @brief create a new JSON pointer by appending the right JSON pointer at the end of the left JSON pointer - - @param[in] lhs JSON pointer - @param[in] rhs JSON pointer - @return a new JSON pointer with @a rhs appended to @a lhs - - @liveexample{The example shows the usage of `operator/`.,json_pointer__operator_add_binary} - - @complexity Linear in the length of @a lhs and @a rhs. - - @sa @ref operator/=(const json_pointer&) to append a JSON pointer - - @since version 3.6.0 - */ - friend json_pointer operator/(const json_pointer& lhs, - const json_pointer& rhs) - { - return json_pointer(lhs) /= rhs; - } - - /*! - @brief create a new JSON pointer by appending the unescaped token at the end of the JSON pointer - - @param[in] ptr JSON pointer - @param[in] token reference token - @return a new JSON pointer with unescaped @a token appended to @a ptr - - @liveexample{The example shows the usage of `operator/`.,json_pointer__operator_add_binary} - - @complexity Linear in the length of @a ptr. 
- - @sa @ref operator/=(std::string) to append a reference token - - @since version 3.6.0 - */ - friend json_pointer operator/(const json_pointer& ptr, std::string token) - { - return json_pointer(ptr) /= std::move(token); - } - - /*! - @brief create a new JSON pointer by appending the array-index-token at the end of the JSON pointer - - @param[in] ptr JSON pointer - @param[in] array_idx array index - @return a new JSON pointer with @a array_idx appended to @a ptr - - @liveexample{The example shows the usage of `operator/`.,json_pointer__operator_add_binary} - - @complexity Linear in the length of @a ptr. - - @sa @ref operator/=(std::size_t) to append an array index - - @since version 3.6.0 - */ - friend json_pointer operator/(const json_pointer& ptr, std::size_t array_idx) - { - return json_pointer(ptr) /= array_idx; - } - - /*! - @brief returns the parent of this JSON pointer - - @return parent of this JSON pointer; in case this JSON pointer is the root, - the root itself is returned - - @complexity Linear in the length of the JSON pointer. - - @liveexample{The example shows the result of `parent_pointer` for different - JSON Pointers.,json_pointer__parent_pointer} - - @since version 3.6.0 - */ - json_pointer parent_pointer() const - { - if (empty()) - { - return *this; - } - - json_pointer res = *this; - res.pop_back(); - return res; - } - - /*! - @brief remove last reference token - - @pre not `empty()` - - @liveexample{The example shows the usage of `pop_back`.,json_pointer__pop_back} - - @complexity Constant. - - @throw out_of_range.405 if JSON pointer has no parent - - @since version 3.6.0 - */ - void pop_back() - { - if (JSON_HEDLEY_UNLIKELY(empty())) - { - JSON_THROW(detail::out_of_range::create(405, "JSON pointer has no parent")); - } - - reference_tokens.pop_back(); - } - - /*! 
- @brief return last reference token - - @pre not `empty()` - @return last reference token - - @liveexample{The example shows the usage of `back`.,json_pointer__back} - - @complexity Constant. - - @throw out_of_range.405 if JSON pointer has no parent - - @since version 3.6.0 - */ - const std::string& back() const - { - if (JSON_HEDLEY_UNLIKELY(empty())) - { - JSON_THROW(detail::out_of_range::create(405, "JSON pointer has no parent")); - } - - return reference_tokens.back(); - } - - /*! - @brief append an unescaped token at the end of the reference pointer - - @param[in] token token to add - - @complexity Amortized constant. - - @liveexample{The example shows the result of `push_back` for different - JSON Pointers.,json_pointer__push_back} - - @since version 3.6.0 - */ - void push_back(const std::string& token) - { - reference_tokens.push_back(token); - } - - /// @copydoc push_back(const std::string&) - void push_back(std::string&& token) - { - reference_tokens.push_back(std::move(token)); - } - - /*! - @brief return whether pointer points to the root document - - @return true iff the JSON pointer points to the root document - - @complexity Constant. - - @exceptionsafety No-throw guarantee: this function never throws exceptions. - - @liveexample{The example shows the result of `empty` for different JSON - Pointers.,json_pointer__empty} - - @since version 3.6.0 - */ - bool empty() const noexcept - { - return reference_tokens.empty(); - } - - private: - /*! 
- @param[in] s reference token to be converted into an array index - - @return integer representation of @a s - - @throw parse_error.106 if an array index begins with '0' - @throw parse_error.109 if an array index begins not with a digit - @throw out_of_range.404 if string @a s could not be converted to an integer - @throw out_of_range.410 if an array index exceeds size_type - */ - static typename BasicJsonType::size_type array_index(const std::string& s) - { - using size_type = typename BasicJsonType::size_type; - - // error condition (cf. RFC 6901, Sect. 4) - if (JSON_HEDLEY_UNLIKELY(s.size() > 1 && s[0] == '0')) - { - JSON_THROW(detail::parse_error::create(106, 0, - "array index '" + s + - "' must not begin with '0'")); - } - - // error condition (cf. RFC 6901, Sect. 4) - if (JSON_HEDLEY_UNLIKELY(s.size() > 1 && !(s[0] >= '1' && s[0] <= '9'))) - { - JSON_THROW(detail::parse_error::create(109, 0, "array index '" + s + "' is not a number")); - } - - std::size_t processed_chars = 0; - unsigned long long res = 0; - JSON_TRY - { - res = std::stoull(s, &processed_chars); - } - JSON_CATCH(std::out_of_range&) - { - JSON_THROW(detail::out_of_range::create(404, "unresolved reference token '" + s + "'")); - } - - // check if the string was completely read - if (JSON_HEDLEY_UNLIKELY(processed_chars != s.size())) - { - JSON_THROW(detail::out_of_range::create(404, "unresolved reference token '" + s + "'")); - } - - // only triggered on special platforms (like 32bit), see also - // https://github.com/nlohmann/json/pull/2203 - if (res >= static_cast((std::numeric_limits::max)())) - { - JSON_THROW(detail::out_of_range::create(410, "array index " + s + " exceeds size_type")); // LCOV_EXCL_LINE - } - - return static_cast(res); - } - - json_pointer top() const - { - if (JSON_HEDLEY_UNLIKELY(empty())) - { - JSON_THROW(detail::out_of_range::create(405, "JSON pointer has no parent")); - } - - json_pointer result = *this; - result.reference_tokens = {reference_tokens[0]}; - return 
result; - } - - /*! - @brief create and return a reference to the pointed to value - - @complexity Linear in the number of reference tokens. - - @throw parse_error.109 if array index is not a number - @throw type_error.313 if value cannot be unflattened - */ - BasicJsonType& get_and_create(BasicJsonType& j) const - { - auto result = &j; - - // in case no reference tokens exist, return a reference to the JSON value - // j which will be overwritten by a primitive value - for (const auto& reference_token : reference_tokens) - { - switch (result->type()) - { - case detail::value_t::null: - { - if (reference_token == "0") - { - // start a new array if reference token is 0 - result = &result->operator[](0); - } - else - { - // start a new object otherwise - result = &result->operator[](reference_token); - } - break; - } - - case detail::value_t::object: - { - // create an entry in the object - result = &result->operator[](reference_token); - break; - } - - case detail::value_t::array: - { - // create an entry in the array - result = &result->operator[](array_index(reference_token)); - break; - } - - /* - The following code is only reached if there exists a reference - token _and_ the current value is primitive. In this case, we have - an error situation, because primitive values may only occur as - single value; that is, with an empty list of reference tokens. - */ - default: - JSON_THROW(detail::type_error::create(313, "invalid value to unflatten")); - } - } - - return *result; - } - - /*! - @brief return a reference to the pointed to value - - @note This version does not throw if a value is not present, but tries to - create nested values instead. For instance, calling this function - with pointer `"/this/that"` on a null value is equivalent to calling - `operator[]("this").operator[]("that")` on that value, effectively - changing the null value to an object. 
- - @param[in] ptr a JSON value - - @return reference to the JSON value pointed to by the JSON pointer - - @complexity Linear in the length of the JSON pointer. - - @throw parse_error.106 if an array index begins with '0' - @throw parse_error.109 if an array index was not a number - @throw out_of_range.404 if the JSON pointer can not be resolved - */ - BasicJsonType& get_unchecked(BasicJsonType* ptr) const - { - for (const auto& reference_token : reference_tokens) - { - // convert null values to arrays or objects before continuing - if (ptr->is_null()) - { - // check if reference token is a number - const bool nums = - std::all_of(reference_token.begin(), reference_token.end(), - [](const unsigned char x) - { - return std::isdigit(x); - }); - - // change value to array for numbers or "-" or to object otherwise - *ptr = (nums || reference_token == "-") - ? detail::value_t::array - : detail::value_t::object; - } - - switch (ptr->type()) - { - case detail::value_t::object: - { - // use unchecked object access - ptr = &ptr->operator[](reference_token); - break; - } - - case detail::value_t::array: - { - if (reference_token == "-") - { - // explicitly treat "-" as index beyond the end - ptr = &ptr->operator[](ptr->m_value.array->size()); - } - else - { - // convert array index to number; unchecked access - ptr = &ptr->operator[](array_index(reference_token)); - } - break; - } - - default: - JSON_THROW(detail::out_of_range::create(404, "unresolved reference token '" + reference_token + "'")); - } - } - - return *ptr; - } - - /*! 
- @throw parse_error.106 if an array index begins with '0' - @throw parse_error.109 if an array index was not a number - @throw out_of_range.402 if the array index '-' is used - @throw out_of_range.404 if the JSON pointer can not be resolved - */ - BasicJsonType& get_checked(BasicJsonType* ptr) const - { - for (const auto& reference_token : reference_tokens) - { - switch (ptr->type()) - { - case detail::value_t::object: - { - // note: at performs range check - ptr = &ptr->at(reference_token); - break; - } - - case detail::value_t::array: - { - if (JSON_HEDLEY_UNLIKELY(reference_token == "-")) - { - // "-" always fails the range check - JSON_THROW(detail::out_of_range::create(402, - "array index '-' (" + std::to_string(ptr->m_value.array->size()) + - ") is out of range")); - } - - // note: at performs range check - ptr = &ptr->at(array_index(reference_token)); - break; - } - - default: - JSON_THROW(detail::out_of_range::create(404, "unresolved reference token '" + reference_token + "'")); - } - } - - return *ptr; - } - - /*! 
- @brief return a const reference to the pointed to value - - @param[in] ptr a JSON value - - @return const reference to the JSON value pointed to by the JSON - pointer - - @throw parse_error.106 if an array index begins with '0' - @throw parse_error.109 if an array index was not a number - @throw out_of_range.402 if the array index '-' is used - @throw out_of_range.404 if the JSON pointer can not be resolved - */ - const BasicJsonType& get_unchecked(const BasicJsonType* ptr) const - { - for (const auto& reference_token : reference_tokens) - { - switch (ptr->type()) - { - case detail::value_t::object: - { - // use unchecked object access - ptr = &ptr->operator[](reference_token); - break; - } - - case detail::value_t::array: - { - if (JSON_HEDLEY_UNLIKELY(reference_token == "-")) - { - // "-" cannot be used for const access - JSON_THROW(detail::out_of_range::create(402, - "array index '-' (" + std::to_string(ptr->m_value.array->size()) + - ") is out of range")); - } - - // use unchecked array access - ptr = &ptr->operator[](array_index(reference_token)); - break; - } - - default: - JSON_THROW(detail::out_of_range::create(404, "unresolved reference token '" + reference_token + "'")); - } - } - - return *ptr; - } - - /*! 
- @throw parse_error.106 if an array index begins with '0' - @throw parse_error.109 if an array index was not a number - @throw out_of_range.402 if the array index '-' is used - @throw out_of_range.404 if the JSON pointer can not be resolved - */ - const BasicJsonType& get_checked(const BasicJsonType* ptr) const - { - for (const auto& reference_token : reference_tokens) - { - switch (ptr->type()) - { - case detail::value_t::object: - { - // note: at performs range check - ptr = &ptr->at(reference_token); - break; - } - - case detail::value_t::array: - { - if (JSON_HEDLEY_UNLIKELY(reference_token == "-")) - { - // "-" always fails the range check - JSON_THROW(detail::out_of_range::create(402, - "array index '-' (" + std::to_string(ptr->m_value.array->size()) + - ") is out of range")); - } - - // note: at performs range check - ptr = &ptr->at(array_index(reference_token)); - break; - } - - default: - JSON_THROW(detail::out_of_range::create(404, "unresolved reference token '" + reference_token + "'")); - } - } - - return *ptr; - } - - /*! 
- @throw parse_error.106 if an array index begins with '0' - @throw parse_error.109 if an array index was not a number - */ - bool contains(const BasicJsonType* ptr) const - { - for (const auto& reference_token : reference_tokens) - { - switch (ptr->type()) - { - case detail::value_t::object: - { - if (!ptr->contains(reference_token)) - { - // we did not find the key in the object - return false; - } - - ptr = &ptr->operator[](reference_token); - break; - } - - case detail::value_t::array: - { - if (JSON_HEDLEY_UNLIKELY(reference_token == "-")) - { - // "-" always fails the range check - return false; - } - if (JSON_HEDLEY_UNLIKELY(reference_token.size() == 1 && !("0" <= reference_token && reference_token <= "9"))) - { - // invalid char - return false; - } - if (JSON_HEDLEY_UNLIKELY(reference_token.size() > 1)) - { - if (JSON_HEDLEY_UNLIKELY(!('1' <= reference_token[0] && reference_token[0] <= '9'))) - { - // first char should be between '1' and '9' - return false; - } - for (std::size_t i = 1; i < reference_token.size(); i++) - { - if (JSON_HEDLEY_UNLIKELY(!('0' <= reference_token[i] && reference_token[i] <= '9'))) - { - // other char should be between '0' and '9' - return false; - } - } - } - - const auto idx = array_index(reference_token); - if (idx >= ptr->size()) - { - // index out of range - return false; - } - - ptr = &ptr->operator[](idx); - break; - } - - default: - { - // we do not expect primitive values if there is still a - // reference token to process - return false; - } - } - } - - // no reference token left means we found a primitive value - return true; - } - - /*! - @brief split the string input to reference tokens - - @note This function is only called by the json_pointer constructor. - All exceptions below are documented there. 
- - @throw parse_error.107 if the pointer is not empty or begins with '/' - @throw parse_error.108 if character '~' is not followed by '0' or '1' - */ - static std::vector split(const std::string& reference_string) - { - std::vector result; - - // special case: empty reference string -> no reference tokens - if (reference_string.empty()) - { - return result; - } - - // check if nonempty reference string begins with slash - if (JSON_HEDLEY_UNLIKELY(reference_string[0] != '/')) - { - JSON_THROW(detail::parse_error::create(107, 1, - "JSON pointer must be empty or begin with '/' - was: '" + - reference_string + "'")); - } - - // extract the reference tokens: - // - slash: position of the last read slash (or end of string) - // - start: position after the previous slash - for ( - // search for the first slash after the first character - std::size_t slash = reference_string.find_first_of('/', 1), - // set the beginning of the first reference token - start = 1; - // we can stop if start == 0 (if slash == std::string::npos) - start != 0; - // set the beginning of the next reference token - // (will eventually be 0 if slash == std::string::npos) - start = (slash == std::string::npos) ? 0 : slash + 1, - // find next slash - slash = reference_string.find_first_of('/', start)) - { - // use the text between the beginning of the reference token - // (start) and the last slash (slash). 
- auto reference_token = reference_string.substr(start, slash - start); - - // check reference tokens are properly escaped - for (std::size_t pos = reference_token.find_first_of('~'); - pos != std::string::npos; - pos = reference_token.find_first_of('~', pos + 1)) - { - JSON_ASSERT(reference_token[pos] == '~'); - - // ~ must be followed by 0 or 1 - if (JSON_HEDLEY_UNLIKELY(pos == reference_token.size() - 1 || - (reference_token[pos + 1] != '0' && - reference_token[pos + 1] != '1'))) - { - JSON_THROW(detail::parse_error::create(108, 0, "escape character '~' must be followed with '0' or '1'")); - } - } - - // finally, store the reference token - unescape(reference_token); - result.push_back(reference_token); - } - - return result; - } - - /*! - @brief replace all occurrences of a substring by another string - - @param[in,out] s the string to manipulate; changed so that all - occurrences of @a f are replaced with @a t - @param[in] f the substring to replace with @a t - @param[in] t the string to replace @a f - - @pre The search string @a f must not be empty. **This precondition is - enforced with an assertion.** - - @since version 2.0.0 - */ - static void replace_substring(std::string& s, const std::string& f, - const std::string& t) - { - JSON_ASSERT(!f.empty()); - for (auto pos = s.find(f); // find first occurrence of f - pos != std::string::npos; // make sure f was found - s.replace(pos, f.size(), t), // replace with t, and - pos = s.find(f, pos + t.size())) // find next occurrence of f - {} - } - - /// escape "~" to "~0" and "/" to "~1" - static std::string escape(std::string s) - { - replace_substring(s, "~", "~0"); - replace_substring(s, "/", "~1"); - return s; - } - - /// unescape "~1" to tilde and "~0" to slash (order is important!) - static void unescape(std::string& s) - { - replace_substring(s, "~1", "/"); - replace_substring(s, "~0", "~"); - } - - /*! 
- @param[in] reference_string the reference string to the current value - @param[in] value the value to consider - @param[in,out] result the result object to insert values to - - @note Empty objects or arrays are flattened to `null`. - */ - static void flatten(const std::string& reference_string, - const BasicJsonType& value, - BasicJsonType& result) - { - switch (value.type()) - { - case detail::value_t::array: - { - if (value.m_value.array->empty()) - { - // flatten empty array as null - result[reference_string] = nullptr; - } - else - { - // iterate array and use index as reference string - for (std::size_t i = 0; i < value.m_value.array->size(); ++i) - { - flatten(reference_string + "/" + std::to_string(i), - value.m_value.array->operator[](i), result); - } - } - break; - } - - case detail::value_t::object: - { - if (value.m_value.object->empty()) - { - // flatten empty object as null - result[reference_string] = nullptr; - } - else - { - // iterate object and use keys as reference string - for (const auto& element : *value.m_value.object) - { - flatten(reference_string + "/" + escape(element.first), element.second, result); - } - } - break; - } - - default: - { - // add primitive value with its reference string - result[reference_string] = value; - break; - } - } - } - - /*! 
- @param[in] value flattened JSON - - @return unflattened JSON - - @throw parse_error.109 if array index is not a number - @throw type_error.314 if value is not an object - @throw type_error.315 if object values are not primitive - @throw type_error.313 if value cannot be unflattened - */ - static BasicJsonType - unflatten(const BasicJsonType& value) - { - if (JSON_HEDLEY_UNLIKELY(!value.is_object())) - { - JSON_THROW(detail::type_error::create(314, "only objects can be unflattened")); - } - - BasicJsonType result; - - // iterate the JSON object values - for (const auto& element : *value.m_value.object) - { - if (JSON_HEDLEY_UNLIKELY(!element.second.is_primitive())) - { - JSON_THROW(detail::type_error::create(315, "values in object must be primitive")); - } - - // assign value to reference pointed to by JSON pointer; Note that if - // the JSON pointer is "" (i.e., points to the whole value), function - // get_and_create returns a reference to result itself. An assignment - // will then create a primitive value. - json_pointer(element.first).get_and_create(result) = element.second; - } - - return result; - } - - /*! - @brief compares two JSON pointers for equality - - @param[in] lhs JSON pointer to compare - @param[in] rhs JSON pointer to compare - @return whether @a lhs is equal to @a rhs - - @complexity Linear in the length of the JSON pointer - - @exceptionsafety No-throw guarantee: this function never throws exceptions. - */ - friend bool operator==(json_pointer const& lhs, - json_pointer const& rhs) noexcept - { - return lhs.reference_tokens == rhs.reference_tokens; - } - - /*! - @brief compares two JSON pointers for inequality - - @param[in] lhs JSON pointer to compare - @param[in] rhs JSON pointer to compare - @return whether @a lhs is not equal @a rhs - - @complexity Linear in the length of the JSON pointer - - @exceptionsafety No-throw guarantee: this function never throws exceptions. 
- */ - friend bool operator!=(json_pointer const& lhs, - json_pointer const& rhs) noexcept - { - return !(lhs == rhs); - } - - /// the reference tokens - std::vector reference_tokens; -}; -} // namespace nlohmann - -// #include - - -#include -#include - -// #include - - -namespace nlohmann -{ -namespace detail -{ -template -class json_ref -{ - public: - using value_type = BasicJsonType; - - json_ref(value_type&& value) - : owned_value(std::move(value)) - , value_ref(&owned_value) - , is_rvalue(true) - {} - - json_ref(const value_type& value) - : value_ref(const_cast(&value)) - , is_rvalue(false) - {} - - json_ref(std::initializer_list init) - : owned_value(init) - , value_ref(&owned_value) - , is_rvalue(true) - {} - - template < - class... Args, - enable_if_t::value, int> = 0 > - json_ref(Args && ... args) - : owned_value(std::forward(args)...) - , value_ref(&owned_value) - , is_rvalue(true) - {} - - // class should be movable only - json_ref(json_ref&&) = default; - json_ref(const json_ref&) = delete; - json_ref& operator=(const json_ref&) = delete; - json_ref& operator=(json_ref&&) = delete; - ~json_ref() = default; - - value_type moved_or_copied() const - { - if (is_rvalue) - { - return std::move(*value_ref); - } - return *value_ref; - } - - value_type const& operator*() const - { - return *static_cast(value_ref); - } - - value_type const* operator->() const - { - return static_cast(value_ref); - } - - private: - mutable value_type owned_value = nullptr; - value_type* value_ref = nullptr; - const bool is_rvalue = true; -}; -} // namespace detail -} // namespace nlohmann - -// #include - -// #include - -// #include - -// #include - - -#include // reverse -#include // array -#include // uint8_t, uint16_t, uint32_t, uint64_t -#include // memcpy -#include // numeric_limits -#include // string -#include // isnan, isinf - -// #include - -// #include - -// #include - - -#include // copy -#include // size_t -#include // streamsize -#include // back_inserter -#include // 
shared_ptr, make_shared -#include // basic_ostream -#include // basic_string -#include // vector -// #include - - -namespace nlohmann -{ -namespace detail -{ -/// abstract output adapter interface -template struct output_adapter_protocol -{ - virtual void write_character(CharType c) = 0; - virtual void write_characters(const CharType* s, std::size_t length) = 0; - virtual ~output_adapter_protocol() = default; -}; - -/// a type to simplify interfaces -template -using output_adapter_t = std::shared_ptr>; - -/// output adapter for byte vectors -template -class output_vector_adapter : public output_adapter_protocol -{ - public: - explicit output_vector_adapter(std::vector& vec) noexcept - : v(vec) - {} - - void write_character(CharType c) override - { - v.push_back(c); - } - - JSON_HEDLEY_NON_NULL(2) - void write_characters(const CharType* s, std::size_t length) override - { - std::copy(s, s + length, std::back_inserter(v)); - } - - private: - std::vector& v; -}; - -/// output adapter for output streams -template -class output_stream_adapter : public output_adapter_protocol -{ - public: - explicit output_stream_adapter(std::basic_ostream& s) noexcept - : stream(s) - {} - - void write_character(CharType c) override - { - stream.put(c); - } - - JSON_HEDLEY_NON_NULL(2) - void write_characters(const CharType* s, std::size_t length) override - { - stream.write(s, static_cast(length)); - } - - private: - std::basic_ostream& stream; -}; - -/// output adapter for basic_string -template> -class output_string_adapter : public output_adapter_protocol -{ - public: - explicit output_string_adapter(StringType& s) noexcept - : str(s) - {} - - void write_character(CharType c) override - { - str.push_back(c); - } - - JSON_HEDLEY_NON_NULL(2) - void write_characters(const CharType* s, std::size_t length) override - { - str.append(s, length); - } - - private: - StringType& str; -}; - -template> -class output_adapter -{ - public: - output_adapter(std::vector& vec) - : 
oa(std::make_shared>(vec)) {} - - output_adapter(std::basic_ostream& s) - : oa(std::make_shared>(s)) {} - - output_adapter(StringType& s) - : oa(std::make_shared>(s)) {} - - operator output_adapter_t() - { - return oa; - } - - private: - output_adapter_t oa = nullptr; -}; -} // namespace detail -} // namespace nlohmann - - -namespace nlohmann -{ -namespace detail -{ -/////////////////// -// binary writer // -/////////////////// - -/*! -@brief serialization to CBOR and MessagePack values -*/ -template -class binary_writer -{ - using string_t = typename BasicJsonType::string_t; - using binary_t = typename BasicJsonType::binary_t; - using number_float_t = typename BasicJsonType::number_float_t; - - public: - /*! - @brief create a binary writer - - @param[in] adapter output adapter to write to - */ - explicit binary_writer(output_adapter_t adapter) : oa(adapter) - { - JSON_ASSERT(oa); - } - - /*! - @param[in] j JSON value to serialize - @pre j.type() == value_t::object - */ - void write_bson(const BasicJsonType& j) - { - switch (j.type()) - { - case value_t::object: - { - write_bson_object(*j.m_value.object); - break; - } - - default: - { - JSON_THROW(type_error::create(317, "to serialize to BSON, top-level type must be object, but is " + std::string(j.type_name()))); - } - } - } - - /*! - @param[in] j JSON value to serialize - */ - void write_cbor(const BasicJsonType& j) - { - switch (j.type()) - { - case value_t::null: - { - oa->write_character(to_char_type(0xF6)); - break; - } - - case value_t::boolean: - { - oa->write_character(j.m_value.boolean - ? to_char_type(0xF5) - : to_char_type(0xF4)); - break; - } - - case value_t::number_integer: - { - if (j.m_value.number_integer >= 0) - { - // CBOR does not differentiate between positive signed - // integers and unsigned integers. Therefore, we used the - // code from the value_t::number_unsigned case here. 
- if (j.m_value.number_integer <= 0x17) - { - write_number(static_cast(j.m_value.number_integer)); - } - else if (j.m_value.number_integer <= (std::numeric_limits::max)()) - { - oa->write_character(to_char_type(0x18)); - write_number(static_cast(j.m_value.number_integer)); - } - else if (j.m_value.number_integer <= (std::numeric_limits::max)()) - { - oa->write_character(to_char_type(0x19)); - write_number(static_cast(j.m_value.number_integer)); - } - else if (j.m_value.number_integer <= (std::numeric_limits::max)()) - { - oa->write_character(to_char_type(0x1A)); - write_number(static_cast(j.m_value.number_integer)); - } - else - { - oa->write_character(to_char_type(0x1B)); - write_number(static_cast(j.m_value.number_integer)); - } - } - else - { - // The conversions below encode the sign in the first - // byte, and the value is converted to a positive number. - const auto positive_number = -1 - j.m_value.number_integer; - if (j.m_value.number_integer >= -24) - { - write_number(static_cast(0x20 + positive_number)); - } - else if (positive_number <= (std::numeric_limits::max)()) - { - oa->write_character(to_char_type(0x38)); - write_number(static_cast(positive_number)); - } - else if (positive_number <= (std::numeric_limits::max)()) - { - oa->write_character(to_char_type(0x39)); - write_number(static_cast(positive_number)); - } - else if (positive_number <= (std::numeric_limits::max)()) - { - oa->write_character(to_char_type(0x3A)); - write_number(static_cast(positive_number)); - } - else - { - oa->write_character(to_char_type(0x3B)); - write_number(static_cast(positive_number)); - } - } - break; - } - - case value_t::number_unsigned: - { - if (j.m_value.number_unsigned <= 0x17) - { - write_number(static_cast(j.m_value.number_unsigned)); - } - else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) - { - oa->write_character(to_char_type(0x18)); - write_number(static_cast(j.m_value.number_unsigned)); - } - else if (j.m_value.number_unsigned <= 
(std::numeric_limits::max)()) - { - oa->write_character(to_char_type(0x19)); - write_number(static_cast(j.m_value.number_unsigned)); - } - else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) - { - oa->write_character(to_char_type(0x1A)); - write_number(static_cast(j.m_value.number_unsigned)); - } - else - { - oa->write_character(to_char_type(0x1B)); - write_number(static_cast(j.m_value.number_unsigned)); - } - break; - } - - case value_t::number_float: - { - if (std::isnan(j.m_value.number_float)) - { - // NaN is 0xf97e00 in CBOR - oa->write_character(to_char_type(0xF9)); - oa->write_character(to_char_type(0x7E)); - oa->write_character(to_char_type(0x00)); - } - else if (std::isinf(j.m_value.number_float)) - { - // Infinity is 0xf97c00, -Infinity is 0xf9fc00 - oa->write_character(to_char_type(0xf9)); - oa->write_character(j.m_value.number_float > 0 ? to_char_type(0x7C) : to_char_type(0xFC)); - oa->write_character(to_char_type(0x00)); - } - else - { - write_compact_float(j.m_value.number_float, detail::input_format_t::cbor); - } - break; - } - - case value_t::string: - { - // step 1: write control byte and the string length - const auto N = j.m_value.string->size(); - if (N <= 0x17) - { - write_number(static_cast(0x60 + N)); - } - else if (N <= (std::numeric_limits::max)()) - { - oa->write_character(to_char_type(0x78)); - write_number(static_cast(N)); - } - else if (N <= (std::numeric_limits::max)()) - { - oa->write_character(to_char_type(0x79)); - write_number(static_cast(N)); - } - else if (N <= (std::numeric_limits::max)()) - { - oa->write_character(to_char_type(0x7A)); - write_number(static_cast(N)); - } - // LCOV_EXCL_START - else if (N <= (std::numeric_limits::max)()) - { - oa->write_character(to_char_type(0x7B)); - write_number(static_cast(N)); - } - // LCOV_EXCL_STOP - - // step 2: write the string - oa->write_characters( - reinterpret_cast(j.m_value.string->c_str()), - j.m_value.string->size()); - break; - } - - case value_t::array: - { - // 
step 1: write control byte and the array size - const auto N = j.m_value.array->size(); - if (N <= 0x17) - { - write_number(static_cast(0x80 + N)); - } - else if (N <= (std::numeric_limits::max)()) - { - oa->write_character(to_char_type(0x98)); - write_number(static_cast(N)); - } - else if (N <= (std::numeric_limits::max)()) - { - oa->write_character(to_char_type(0x99)); - write_number(static_cast(N)); - } - else if (N <= (std::numeric_limits::max)()) - { - oa->write_character(to_char_type(0x9A)); - write_number(static_cast(N)); - } - // LCOV_EXCL_START - else if (N <= (std::numeric_limits::max)()) - { - oa->write_character(to_char_type(0x9B)); - write_number(static_cast(N)); - } - // LCOV_EXCL_STOP - - // step 2: write each element - for (const auto& el : *j.m_value.array) - { - write_cbor(el); - } - break; - } - - case value_t::binary: - { - if (j.m_value.binary->has_subtype()) - { - write_number(static_cast(0xd8)); - write_number(j.m_value.binary->subtype()); - } - - // step 1: write control byte and the binary array size - const auto N = j.m_value.binary->size(); - if (N <= 0x17) - { - write_number(static_cast(0x40 + N)); - } - else if (N <= (std::numeric_limits::max)()) - { - oa->write_character(to_char_type(0x58)); - write_number(static_cast(N)); - } - else if (N <= (std::numeric_limits::max)()) - { - oa->write_character(to_char_type(0x59)); - write_number(static_cast(N)); - } - else if (N <= (std::numeric_limits::max)()) - { - oa->write_character(to_char_type(0x5A)); - write_number(static_cast(N)); - } - // LCOV_EXCL_START - else if (N <= (std::numeric_limits::max)()) - { - oa->write_character(to_char_type(0x5B)); - write_number(static_cast(N)); - } - // LCOV_EXCL_STOP - - // step 2: write each element - oa->write_characters( - reinterpret_cast(j.m_value.binary->data()), - N); - - break; - } - - case value_t::object: - { - // step 1: write control byte and the object size - const auto N = j.m_value.object->size(); - if (N <= 0x17) - { - 
write_number(static_cast(0xA0 + N)); - } - else if (N <= (std::numeric_limits::max)()) - { - oa->write_character(to_char_type(0xB8)); - write_number(static_cast(N)); - } - else if (N <= (std::numeric_limits::max)()) - { - oa->write_character(to_char_type(0xB9)); - write_number(static_cast(N)); - } - else if (N <= (std::numeric_limits::max)()) - { - oa->write_character(to_char_type(0xBA)); - write_number(static_cast(N)); - } - // LCOV_EXCL_START - else if (N <= (std::numeric_limits::max)()) - { - oa->write_character(to_char_type(0xBB)); - write_number(static_cast(N)); - } - // LCOV_EXCL_STOP - - // step 2: write each element - for (const auto& el : *j.m_value.object) - { - write_cbor(el.first); - write_cbor(el.second); - } - break; - } - - default: - break; - } - } - - /*! - @param[in] j JSON value to serialize - */ - void write_msgpack(const BasicJsonType& j) - { - switch (j.type()) - { - case value_t::null: // nil - { - oa->write_character(to_char_type(0xC0)); - break; - } - - case value_t::boolean: // true and false - { - oa->write_character(j.m_value.boolean - ? to_char_type(0xC3) - : to_char_type(0xC2)); - break; - } - - case value_t::number_integer: - { - if (j.m_value.number_integer >= 0) - { - // MessagePack does not differentiate between positive - // signed integers and unsigned integers. Therefore, we used - // the code from the value_t::number_unsigned case here. 
- if (j.m_value.number_unsigned < 128) - { - // positive fixnum - write_number(static_cast(j.m_value.number_integer)); - } - else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) - { - // uint 8 - oa->write_character(to_char_type(0xCC)); - write_number(static_cast(j.m_value.number_integer)); - } - else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) - { - // uint 16 - oa->write_character(to_char_type(0xCD)); - write_number(static_cast(j.m_value.number_integer)); - } - else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) - { - // uint 32 - oa->write_character(to_char_type(0xCE)); - write_number(static_cast(j.m_value.number_integer)); - } - else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) - { - // uint 64 - oa->write_character(to_char_type(0xCF)); - write_number(static_cast(j.m_value.number_integer)); - } - } - else - { - if (j.m_value.number_integer >= -32) - { - // negative fixnum - write_number(static_cast(j.m_value.number_integer)); - } - else if (j.m_value.number_integer >= (std::numeric_limits::min)() && - j.m_value.number_integer <= (std::numeric_limits::max)()) - { - // int 8 - oa->write_character(to_char_type(0xD0)); - write_number(static_cast(j.m_value.number_integer)); - } - else if (j.m_value.number_integer >= (std::numeric_limits::min)() && - j.m_value.number_integer <= (std::numeric_limits::max)()) - { - // int 16 - oa->write_character(to_char_type(0xD1)); - write_number(static_cast(j.m_value.number_integer)); - } - else if (j.m_value.number_integer >= (std::numeric_limits::min)() && - j.m_value.number_integer <= (std::numeric_limits::max)()) - { - // int 32 - oa->write_character(to_char_type(0xD2)); - write_number(static_cast(j.m_value.number_integer)); - } - else if (j.m_value.number_integer >= (std::numeric_limits::min)() && - j.m_value.number_integer <= (std::numeric_limits::max)()) - { - // int 64 - oa->write_character(to_char_type(0xD3)); - 
write_number(static_cast(j.m_value.number_integer)); - } - } - break; - } - - case value_t::number_unsigned: - { - if (j.m_value.number_unsigned < 128) - { - // positive fixnum - write_number(static_cast(j.m_value.number_integer)); - } - else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) - { - // uint 8 - oa->write_character(to_char_type(0xCC)); - write_number(static_cast(j.m_value.number_integer)); - } - else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) - { - // uint 16 - oa->write_character(to_char_type(0xCD)); - write_number(static_cast(j.m_value.number_integer)); - } - else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) - { - // uint 32 - oa->write_character(to_char_type(0xCE)); - write_number(static_cast(j.m_value.number_integer)); - } - else if (j.m_value.number_unsigned <= (std::numeric_limits::max)()) - { - // uint 64 - oa->write_character(to_char_type(0xCF)); - write_number(static_cast(j.m_value.number_integer)); - } - break; - } - - case value_t::number_float: - { - write_compact_float(j.m_value.number_float, detail::input_format_t::msgpack); - break; - } - - case value_t::string: - { - // step 1: write control byte and the string length - const auto N = j.m_value.string->size(); - if (N <= 31) - { - // fixstr - write_number(static_cast(0xA0 | N)); - } - else if (N <= (std::numeric_limits::max)()) - { - // str 8 - oa->write_character(to_char_type(0xD9)); - write_number(static_cast(N)); - } - else if (N <= (std::numeric_limits::max)()) - { - // str 16 - oa->write_character(to_char_type(0xDA)); - write_number(static_cast(N)); - } - else if (N <= (std::numeric_limits::max)()) - { - // str 32 - oa->write_character(to_char_type(0xDB)); - write_number(static_cast(N)); - } - - // step 2: write the string - oa->write_characters( - reinterpret_cast(j.m_value.string->c_str()), - j.m_value.string->size()); - break; - } - - case value_t::array: - { - // step 1: write control byte and the array size - const auto N = 
j.m_value.array->size(); - if (N <= 15) - { - // fixarray - write_number(static_cast(0x90 | N)); - } - else if (N <= (std::numeric_limits::max)()) - { - // array 16 - oa->write_character(to_char_type(0xDC)); - write_number(static_cast(N)); - } - else if (N <= (std::numeric_limits::max)()) - { - // array 32 - oa->write_character(to_char_type(0xDD)); - write_number(static_cast(N)); - } - - // step 2: write each element - for (const auto& el : *j.m_value.array) - { - write_msgpack(el); - } - break; - } - - case value_t::binary: - { - // step 0: determine if the binary type has a set subtype to - // determine whether or not to use the ext or fixext types - const bool use_ext = j.m_value.binary->has_subtype(); - - // step 1: write control byte and the byte string length - const auto N = j.m_value.binary->size(); - if (N <= (std::numeric_limits::max)()) - { - std::uint8_t output_type{}; - bool fixed = true; - if (use_ext) - { - switch (N) - { - case 1: - output_type = 0xD4; // fixext 1 - break; - case 2: - output_type = 0xD5; // fixext 2 - break; - case 4: - output_type = 0xD6; // fixext 4 - break; - case 8: - output_type = 0xD7; // fixext 8 - break; - case 16: - output_type = 0xD8; // fixext 16 - break; - default: - output_type = 0xC7; // ext 8 - fixed = false; - break; - } - - } - else - { - output_type = 0xC4; // bin 8 - fixed = false; - } - - oa->write_character(to_char_type(output_type)); - if (!fixed) - { - write_number(static_cast(N)); - } - } - else if (N <= (std::numeric_limits::max)()) - { - std::uint8_t output_type = use_ext - ? 0xC8 // ext 16 - : 0xC5; // bin 16 - - oa->write_character(to_char_type(output_type)); - write_number(static_cast(N)); - } - else if (N <= (std::numeric_limits::max)()) - { - std::uint8_t output_type = use_ext - ? 
0xC9 // ext 32 - : 0xC6; // bin 32 - - oa->write_character(to_char_type(output_type)); - write_number(static_cast(N)); - } - - // step 1.5: if this is an ext type, write the subtype - if (use_ext) - { - write_number(static_cast(j.m_value.binary->subtype())); - } - - // step 2: write the byte string - oa->write_characters( - reinterpret_cast(j.m_value.binary->data()), - N); - - break; - } - - case value_t::object: - { - // step 1: write control byte and the object size - const auto N = j.m_value.object->size(); - if (N <= 15) - { - // fixmap - write_number(static_cast(0x80 | (N & 0xF))); - } - else if (N <= (std::numeric_limits::max)()) - { - // map 16 - oa->write_character(to_char_type(0xDE)); - write_number(static_cast(N)); - } - else if (N <= (std::numeric_limits::max)()) - { - // map 32 - oa->write_character(to_char_type(0xDF)); - write_number(static_cast(N)); - } - - // step 2: write each element - for (const auto& el : *j.m_value.object) - { - write_msgpack(el.first); - write_msgpack(el.second); - } - break; - } - - default: - break; - } - } - - /*! - @param[in] j JSON value to serialize - @param[in] use_count whether to use '#' prefixes (optimized format) - @param[in] use_type whether to use '$' prefixes (optimized format) - @param[in] add_prefix whether prefixes need to be used for this value - */ - void write_ubjson(const BasicJsonType& j, const bool use_count, - const bool use_type, const bool add_prefix = true) - { - switch (j.type()) - { - case value_t::null: - { - if (add_prefix) - { - oa->write_character(to_char_type('Z')); - } - break; - } - - case value_t::boolean: - { - if (add_prefix) - { - oa->write_character(j.m_value.boolean - ? 
to_char_type('T') - : to_char_type('F')); - } - break; - } - - case value_t::number_integer: - { - write_number_with_ubjson_prefix(j.m_value.number_integer, add_prefix); - break; - } - - case value_t::number_unsigned: - { - write_number_with_ubjson_prefix(j.m_value.number_unsigned, add_prefix); - break; - } - - case value_t::number_float: - { - write_number_with_ubjson_prefix(j.m_value.number_float, add_prefix); - break; - } - - case value_t::string: - { - if (add_prefix) - { - oa->write_character(to_char_type('S')); - } - write_number_with_ubjson_prefix(j.m_value.string->size(), true); - oa->write_characters( - reinterpret_cast(j.m_value.string->c_str()), - j.m_value.string->size()); - break; - } - - case value_t::array: - { - if (add_prefix) - { - oa->write_character(to_char_type('[')); - } - - bool prefix_required = true; - if (use_type && !j.m_value.array->empty()) - { - JSON_ASSERT(use_count); - const CharType first_prefix = ubjson_prefix(j.front()); - const bool same_prefix = std::all_of(j.begin() + 1, j.end(), - [this, first_prefix](const BasicJsonType & v) - { - return ubjson_prefix(v) == first_prefix; - }); - - if (same_prefix) - { - prefix_required = false; - oa->write_character(to_char_type('$')); - oa->write_character(first_prefix); - } - } - - if (use_count) - { - oa->write_character(to_char_type('#')); - write_number_with_ubjson_prefix(j.m_value.array->size(), true); - } - - for (const auto& el : *j.m_value.array) - { - write_ubjson(el, use_count, use_type, prefix_required); - } - - if (!use_count) - { - oa->write_character(to_char_type(']')); - } - - break; - } - - case value_t::binary: - { - if (add_prefix) - { - oa->write_character(to_char_type('[')); - } - - if (use_type && !j.m_value.binary->empty()) - { - JSON_ASSERT(use_count); - oa->write_character(to_char_type('$')); - oa->write_character('U'); - } - - if (use_count) - { - oa->write_character(to_char_type('#')); - write_number_with_ubjson_prefix(j.m_value.binary->size(), true); - } - - if 
(use_type) - { - oa->write_characters( - reinterpret_cast(j.m_value.binary->data()), - j.m_value.binary->size()); - } - else - { - for (size_t i = 0; i < j.m_value.binary->size(); ++i) - { - oa->write_character(to_char_type('U')); - oa->write_character(j.m_value.binary->data()[i]); - } - } - - if (!use_count) - { - oa->write_character(to_char_type(']')); - } - - break; - } - - case value_t::object: - { - if (add_prefix) - { - oa->write_character(to_char_type('{')); - } - - bool prefix_required = true; - if (use_type && !j.m_value.object->empty()) - { - JSON_ASSERT(use_count); - const CharType first_prefix = ubjson_prefix(j.front()); - const bool same_prefix = std::all_of(j.begin(), j.end(), - [this, first_prefix](const BasicJsonType & v) - { - return ubjson_prefix(v) == first_prefix; - }); - - if (same_prefix) - { - prefix_required = false; - oa->write_character(to_char_type('$')); - oa->write_character(first_prefix); - } - } - - if (use_count) - { - oa->write_character(to_char_type('#')); - write_number_with_ubjson_prefix(j.m_value.object->size(), true); - } - - for (const auto& el : *j.m_value.object) - { - write_number_with_ubjson_prefix(el.first.size(), true); - oa->write_characters( - reinterpret_cast(el.first.c_str()), - el.first.size()); - write_ubjson(el.second, use_count, use_type, prefix_required); - } - - if (!use_count) - { - oa->write_character(to_char_type('}')); - } - - break; - } - - default: - break; - } - } - - private: - ////////// - // BSON // - ////////// - - /*! - @return The size of a BSON document entry header, including the id marker - and the entry name size (and its null-terminator). 
- */ - static std::size_t calc_bson_entry_header_size(const string_t& name) - { - const auto it = name.find(static_cast(0)); - if (JSON_HEDLEY_UNLIKELY(it != BasicJsonType::string_t::npos)) - { - JSON_THROW(out_of_range::create(409, - "BSON key cannot contain code point U+0000 (at byte " + std::to_string(it) + ")")); - } - - return /*id*/ 1ul + name.size() + /*zero-terminator*/1u; - } - - /*! - @brief Writes the given @a element_type and @a name to the output adapter - */ - void write_bson_entry_header(const string_t& name, - const std::uint8_t element_type) - { - oa->write_character(to_char_type(element_type)); // boolean - oa->write_characters( - reinterpret_cast(name.c_str()), - name.size() + 1u); - } - - /*! - @brief Writes a BSON element with key @a name and boolean value @a value - */ - void write_bson_boolean(const string_t& name, - const bool value) - { - write_bson_entry_header(name, 0x08); - oa->write_character(value ? to_char_type(0x01) : to_char_type(0x00)); - } - - /*! - @brief Writes a BSON element with key @a name and double value @a value - */ - void write_bson_double(const string_t& name, - const double value) - { - write_bson_entry_header(name, 0x01); - write_number(value); - } - - /*! - @return The size of the BSON-encoded string in @a value - */ - static std::size_t calc_bson_string_size(const string_t& value) - { - return sizeof(std::int32_t) + value.size() + 1ul; - } - - /*! - @brief Writes a BSON element with key @a name and string value @a value - */ - void write_bson_string(const string_t& name, - const string_t& value) - { - write_bson_entry_header(name, 0x02); - - write_number(static_cast(value.size() + 1ul)); - oa->write_characters( - reinterpret_cast(value.c_str()), - value.size() + 1); - } - - /*! - @brief Writes a BSON element with key @a name and null value - */ - void write_bson_null(const string_t& name) - { - write_bson_entry_header(name, 0x0A); - } - - /*! 
- @return The size of the BSON-encoded integer @a value - */ - static std::size_t calc_bson_integer_size(const std::int64_t value) - { - return (std::numeric_limits::min)() <= value && value <= (std::numeric_limits::max)() - ? sizeof(std::int32_t) - : sizeof(std::int64_t); - } - - /*! - @brief Writes a BSON element with key @a name and integer @a value - */ - void write_bson_integer(const string_t& name, - const std::int64_t value) - { - if ((std::numeric_limits::min)() <= value && value <= (std::numeric_limits::max)()) - { - write_bson_entry_header(name, 0x10); // int32 - write_number(static_cast(value)); - } - else - { - write_bson_entry_header(name, 0x12); // int64 - write_number(static_cast(value)); - } - } - - /*! - @return The size of the BSON-encoded unsigned integer in @a j - */ - static constexpr std::size_t calc_bson_unsigned_size(const std::uint64_t value) noexcept - { - return (value <= static_cast((std::numeric_limits::max)())) - ? sizeof(std::int32_t) - : sizeof(std::int64_t); - } - - /*! - @brief Writes a BSON element with key @a name and unsigned @a value - */ - void write_bson_unsigned(const string_t& name, - const std::uint64_t value) - { - if (value <= static_cast((std::numeric_limits::max)())) - { - write_bson_entry_header(name, 0x10 /* int32 */); - write_number(static_cast(value)); - } - else if (value <= static_cast((std::numeric_limits::max)())) - { - write_bson_entry_header(name, 0x12 /* int64 */); - write_number(static_cast(value)); - } - else - { - JSON_THROW(out_of_range::create(407, "integer number " + std::to_string(value) + " cannot be represented by BSON as it does not fit int64")); - } - } - - /*! - @brief Writes a BSON element with key @a name and object @a value - */ - void write_bson_object_entry(const string_t& name, - const typename BasicJsonType::object_t& value) - { - write_bson_entry_header(name, 0x03); // object - write_bson_object(value); - } - - /*! 
- @return The size of the BSON-encoded array @a value - */ - static std::size_t calc_bson_array_size(const typename BasicJsonType::array_t& value) - { - std::size_t array_index = 0ul; - - const std::size_t embedded_document_size = std::accumulate(std::begin(value), std::end(value), std::size_t(0), [&array_index](std::size_t result, const typename BasicJsonType::array_t::value_type & el) - { - return result + calc_bson_element_size(std::to_string(array_index++), el); - }); - - return sizeof(std::int32_t) + embedded_document_size + 1ul; - } - - /*! - @return The size of the BSON-encoded binary array @a value - */ - static std::size_t calc_bson_binary_size(const typename BasicJsonType::binary_t& value) - { - return sizeof(std::int32_t) + value.size() + 1ul; - } - - /*! - @brief Writes a BSON element with key @a name and array @a value - */ - void write_bson_array(const string_t& name, - const typename BasicJsonType::array_t& value) - { - write_bson_entry_header(name, 0x04); // array - write_number(static_cast(calc_bson_array_size(value))); - - std::size_t array_index = 0ul; - - for (const auto& el : value) - { - write_bson_element(std::to_string(array_index++), el); - } - - oa->write_character(to_char_type(0x00)); - } - - /*! - @brief Writes a BSON element with key @a name and binary value @a value - */ - void write_bson_binary(const string_t& name, - const binary_t& value) - { - write_bson_entry_header(name, 0x05); - - write_number(static_cast(value.size())); - write_number(value.has_subtype() ? value.subtype() : std::uint8_t(0x00)); - - oa->write_characters(reinterpret_cast(value.data()), value.size()); - } - - /*! - @brief Calculates the size necessary to serialize the JSON value @a j with its @a name - @return The calculated size for the BSON document entry for @a j with the given @a name. 
- */ - static std::size_t calc_bson_element_size(const string_t& name, - const BasicJsonType& j) - { - const auto header_size = calc_bson_entry_header_size(name); - switch (j.type()) - { - case value_t::object: - return header_size + calc_bson_object_size(*j.m_value.object); - - case value_t::array: - return header_size + calc_bson_array_size(*j.m_value.array); - - case value_t::binary: - return header_size + calc_bson_binary_size(*j.m_value.binary); - - case value_t::boolean: - return header_size + 1ul; - - case value_t::number_float: - return header_size + 8ul; - - case value_t::number_integer: - return header_size + calc_bson_integer_size(j.m_value.number_integer); - - case value_t::number_unsigned: - return header_size + calc_bson_unsigned_size(j.m_value.number_unsigned); - - case value_t::string: - return header_size + calc_bson_string_size(*j.m_value.string); - - case value_t::null: - return header_size + 0ul; - - // LCOV_EXCL_START - default: - JSON_ASSERT(false); - return 0ul; - // LCOV_EXCL_STOP - } - } - - /*! - @brief Serializes the JSON value @a j to BSON and associates it with the - key @a name. 
- @param name The name to associate with the JSON entity @a j within the - current BSON document - @return The size of the BSON entry - */ - void write_bson_element(const string_t& name, - const BasicJsonType& j) - { - switch (j.type()) - { - case value_t::object: - return write_bson_object_entry(name, *j.m_value.object); - - case value_t::array: - return write_bson_array(name, *j.m_value.array); - - case value_t::binary: - return write_bson_binary(name, *j.m_value.binary); - - case value_t::boolean: - return write_bson_boolean(name, j.m_value.boolean); - - case value_t::number_float: - return write_bson_double(name, j.m_value.number_float); - - case value_t::number_integer: - return write_bson_integer(name, j.m_value.number_integer); - - case value_t::number_unsigned: - return write_bson_unsigned(name, j.m_value.number_unsigned); - - case value_t::string: - return write_bson_string(name, *j.m_value.string); - - case value_t::null: - return write_bson_null(name); - - // LCOV_EXCL_START - default: - JSON_ASSERT(false); - return; - // LCOV_EXCL_STOP - } - } - - /*! - @brief Calculates the size of the BSON serialization of the given - JSON-object @a j. - @param[in] j JSON value to serialize - @pre j.type() == value_t::object - */ - static std::size_t calc_bson_object_size(const typename BasicJsonType::object_t& value) - { - std::size_t document_size = std::accumulate(value.begin(), value.end(), std::size_t(0), - [](size_t result, const typename BasicJsonType::object_t::value_type & el) - { - return result += calc_bson_element_size(el.first, el.second); - }); - - return sizeof(std::int32_t) + document_size + 1ul; - } - - /*! 
- @param[in] j JSON value to serialize - @pre j.type() == value_t::object - */ - void write_bson_object(const typename BasicJsonType::object_t& value) - { - write_number(static_cast(calc_bson_object_size(value))); - - for (const auto& el : value) - { - write_bson_element(el.first, el.second); - } - - oa->write_character(to_char_type(0x00)); - } - - ////////// - // CBOR // - ////////// - - static constexpr CharType get_cbor_float_prefix(float /*unused*/) - { - return to_char_type(0xFA); // Single-Precision Float - } - - static constexpr CharType get_cbor_float_prefix(double /*unused*/) - { - return to_char_type(0xFB); // Double-Precision Float - } - - ///////////// - // MsgPack // - ///////////// - - static constexpr CharType get_msgpack_float_prefix(float /*unused*/) - { - return to_char_type(0xCA); // float 32 - } - - static constexpr CharType get_msgpack_float_prefix(double /*unused*/) - { - return to_char_type(0xCB); // float 64 - } - - //////////// - // UBJSON // - //////////// - - // UBJSON: write number (floating point) - template::value, int>::type = 0> - void write_number_with_ubjson_prefix(const NumberType n, - const bool add_prefix) - { - if (add_prefix) - { - oa->write_character(get_ubjson_float_prefix(n)); - } - write_number(n); - } - - // UBJSON: write number (unsigned integer) - template::value, int>::type = 0> - void write_number_with_ubjson_prefix(const NumberType n, - const bool add_prefix) - { - if (n <= static_cast((std::numeric_limits::max)())) - { - if (add_prefix) - { - oa->write_character(to_char_type('i')); // int8 - } - write_number(static_cast(n)); - } - else if (n <= (std::numeric_limits::max)()) - { - if (add_prefix) - { - oa->write_character(to_char_type('U')); // uint8 - } - write_number(static_cast(n)); - } - else if (n <= static_cast((std::numeric_limits::max)())) - { - if (add_prefix) - { - oa->write_character(to_char_type('I')); // int16 - } - write_number(static_cast(n)); - } - else if (n <= 
static_cast((std::numeric_limits::max)())) - { - if (add_prefix) - { - oa->write_character(to_char_type('l')); // int32 - } - write_number(static_cast(n)); - } - else if (n <= static_cast((std::numeric_limits::max)())) - { - if (add_prefix) - { - oa->write_character(to_char_type('L')); // int64 - } - write_number(static_cast(n)); - } - else - { - if (add_prefix) - { - oa->write_character(to_char_type('H')); // high-precision number - } - - const auto number = BasicJsonType(n).dump(); - write_number_with_ubjson_prefix(number.size(), true); - for (std::size_t i = 0; i < number.size(); ++i) - { - oa->write_character(to_char_type(static_cast(number[i]))); - } - } - } - - // UBJSON: write number (signed integer) - template < typename NumberType, typename std::enable_if < - std::is_signed::value&& - !std::is_floating_point::value, int >::type = 0 > - void write_number_with_ubjson_prefix(const NumberType n, - const bool add_prefix) - { - if ((std::numeric_limits::min)() <= n && n <= (std::numeric_limits::max)()) - { - if (add_prefix) - { - oa->write_character(to_char_type('i')); // int8 - } - write_number(static_cast(n)); - } - else if (static_cast((std::numeric_limits::min)()) <= n && n <= static_cast((std::numeric_limits::max)())) - { - if (add_prefix) - { - oa->write_character(to_char_type('U')); // uint8 - } - write_number(static_cast(n)); - } - else if ((std::numeric_limits::min)() <= n && n <= (std::numeric_limits::max)()) - { - if (add_prefix) - { - oa->write_character(to_char_type('I')); // int16 - } - write_number(static_cast(n)); - } - else if ((std::numeric_limits::min)() <= n && n <= (std::numeric_limits::max)()) - { - if (add_prefix) - { - oa->write_character(to_char_type('l')); // int32 - } - write_number(static_cast(n)); - } - else if ((std::numeric_limits::min)() <= n && n <= (std::numeric_limits::max)()) - { - if (add_prefix) - { - oa->write_character(to_char_type('L')); // int64 - } - write_number(static_cast(n)); - } - // LCOV_EXCL_START - else - { - if 
(add_prefix) - { - oa->write_character(to_char_type('H')); // high-precision number - } - - const auto number = BasicJsonType(n).dump(); - write_number_with_ubjson_prefix(number.size(), true); - for (std::size_t i = 0; i < number.size(); ++i) - { - oa->write_character(to_char_type(static_cast(number[i]))); - } - } - // LCOV_EXCL_STOP - } - - /*! - @brief determine the type prefix of container values - */ - CharType ubjson_prefix(const BasicJsonType& j) const noexcept - { - switch (j.type()) - { - case value_t::null: - return 'Z'; - - case value_t::boolean: - return j.m_value.boolean ? 'T' : 'F'; - - case value_t::number_integer: - { - if ((std::numeric_limits::min)() <= j.m_value.number_integer && j.m_value.number_integer <= (std::numeric_limits::max)()) - { - return 'i'; - } - if ((std::numeric_limits::min)() <= j.m_value.number_integer && j.m_value.number_integer <= (std::numeric_limits::max)()) - { - return 'U'; - } - if ((std::numeric_limits::min)() <= j.m_value.number_integer && j.m_value.number_integer <= (std::numeric_limits::max)()) - { - return 'I'; - } - if ((std::numeric_limits::min)() <= j.m_value.number_integer && j.m_value.number_integer <= (std::numeric_limits::max)()) - { - return 'l'; - } - if ((std::numeric_limits::min)() <= j.m_value.number_integer && j.m_value.number_integer <= (std::numeric_limits::max)()) - { - return 'L'; - } - // anything else is treated as high-precision number - return 'H'; // LCOV_EXCL_LINE - } - - case value_t::number_unsigned: - { - if (j.m_value.number_unsigned <= static_cast((std::numeric_limits::max)())) - { - return 'i'; - } - if (j.m_value.number_unsigned <= static_cast((std::numeric_limits::max)())) - { - return 'U'; - } - if (j.m_value.number_unsigned <= static_cast((std::numeric_limits::max)())) - { - return 'I'; - } - if (j.m_value.number_unsigned <= static_cast((std::numeric_limits::max)())) - { - return 'l'; - } - if (j.m_value.number_unsigned <= static_cast((std::numeric_limits::max)())) - { - return 'L'; - 
} - // anything else is treated as high-precision number - return 'H'; // LCOV_EXCL_LINE - } - - case value_t::number_float: - return get_ubjson_float_prefix(j.m_value.number_float); - - case value_t::string: - return 'S'; - - case value_t::array: // fallthrough - case value_t::binary: - return '['; - - case value_t::object: - return '{'; - - default: // discarded values - return 'N'; - } - } - - static constexpr CharType get_ubjson_float_prefix(float /*unused*/) - { - return 'd'; // float 32 - } - - static constexpr CharType get_ubjson_float_prefix(double /*unused*/) - { - return 'D'; // float 64 - } - - /////////////////////// - // Utility functions // - /////////////////////// - - /* - @brief write a number to output input - @param[in] n number of type @a NumberType - @tparam NumberType the type of the number - @tparam OutputIsLittleEndian Set to true if output data is - required to be little endian - - @note This function needs to respect the system's endianess, because bytes - in CBOR, MessagePack, and UBJSON are stored in network order (big - endian) and therefore need reordering on little endian systems. - */ - template - void write_number(const NumberType n) - { - // step 1: write number to array of length NumberType - std::array vec; - std::memcpy(vec.data(), &n, sizeof(NumberType)); - - // step 2: write array to output (with possible reordering) - if (is_little_endian != OutputIsLittleEndian) - { - // reverse byte order prior to conversion if necessary - std::reverse(vec.begin(), vec.end()); - } - - oa->write_characters(vec.data(), sizeof(NumberType)); - } - - void write_compact_float(const number_float_t n, detail::input_format_t format) - { - if (static_cast(n) >= static_cast(std::numeric_limits::lowest()) && - static_cast(n) <= static_cast((std::numeric_limits::max)()) && - static_cast(static_cast(n)) == static_cast(n)) - { - oa->write_character(format == detail::input_format_t::cbor - ? 
get_cbor_float_prefix(static_cast(n)) - : get_msgpack_float_prefix(static_cast(n))); - write_number(static_cast(n)); - } - else - { - oa->write_character(format == detail::input_format_t::cbor - ? get_cbor_float_prefix(n) - : get_msgpack_float_prefix(n)); - write_number(n); - } - } - - public: - // The following to_char_type functions are implement the conversion - // between uint8_t and CharType. In case CharType is not unsigned, - // such a conversion is required to allow values greater than 128. - // See for a discussion. - template < typename C = CharType, - enable_if_t < std::is_signed::value && std::is_signed::value > * = nullptr > - static constexpr CharType to_char_type(std::uint8_t x) noexcept - { - return *reinterpret_cast(&x); - } - - template < typename C = CharType, - enable_if_t < std::is_signed::value && std::is_unsigned::value > * = nullptr > - static CharType to_char_type(std::uint8_t x) noexcept - { - static_assert(sizeof(std::uint8_t) == sizeof(CharType), "size of CharType must be equal to std::uint8_t"); - static_assert(std::is_trivial::value, "CharType must be trivial"); - CharType result; - std::memcpy(&result, &x, sizeof(x)); - return result; - } - - template::value>* = nullptr> - static constexpr CharType to_char_type(std::uint8_t x) noexcept - { - return x; - } - - template < typename InputCharType, typename C = CharType, - enable_if_t < - std::is_signed::value && - std::is_signed::value && - std::is_same::type>::value - > * = nullptr > - static constexpr CharType to_char_type(InputCharType x) noexcept - { - return x; - } - - private: - /// whether we can assume little endianess - const bool is_little_endian = little_endianess(); - - /// the output - output_adapter_t oa = nullptr; -}; -} // namespace detail -} // namespace nlohmann - -// #include - -// #include - - -#include // reverse, remove, fill, find, none_of -#include // array -#include // localeconv, lconv -#include // labs, isfinite, isnan, signbit -#include // size_t, ptrdiff_t 
-#include // uint8_t -#include // snprintf -#include // numeric_limits -#include // string, char_traits -#include // is_same -#include // move - -// #include - - -#include // array -#include // signbit, isfinite -#include // intN_t, uintN_t -#include // memcpy, memmove -#include // numeric_limits -#include // conditional - -// #include - - -namespace nlohmann -{ -namespace detail -{ - -/*! -@brief implements the Grisu2 algorithm for binary to decimal floating-point -conversion. - -This implementation is a slightly modified version of the reference -implementation which may be obtained from -http://florian.loitsch.com/publications (bench.tar.gz). - -The code is distributed under the MIT license, Copyright (c) 2009 Florian Loitsch. - -For a detailed description of the algorithm see: - -[1] Loitsch, "Printing Floating-Point Numbers Quickly and Accurately with - Integers", Proceedings of the ACM SIGPLAN 2010 Conference on Programming - Language Design and Implementation, PLDI 2010 -[2] Burger, Dybvig, "Printing Floating-Point Numbers Quickly and Accurately", - Proceedings of the ACM SIGPLAN 1996 Conference on Programming Language - Design and Implementation, PLDI 1996 -*/ -namespace dtoa_impl -{ - -template -Target reinterpret_bits(const Source source) -{ - static_assert(sizeof(Target) == sizeof(Source), "size mismatch"); - - Target target; - std::memcpy(&target, &source, sizeof(Source)); - return target; -} - -struct diyfp // f * 2^e -{ - static constexpr int kPrecision = 64; // = q - - std::uint64_t f = 0; - int e = 0; - - constexpr diyfp(std::uint64_t f_, int e_) noexcept : f(f_), e(e_) {} - - /*! - @brief returns x - y - @pre x.e == y.e and x.f >= y.f - */ - static diyfp sub(const diyfp& x, const diyfp& y) noexcept - { - JSON_ASSERT(x.e == y.e); - JSON_ASSERT(x.f >= y.f); - - return {x.f - y.f, x.e}; - } - - /*! - @brief returns x * y - @note The result is rounded. (Only the upper q bits are returned.) 
- */ - static diyfp mul(const diyfp& x, const diyfp& y) noexcept - { - static_assert(kPrecision == 64, "internal error"); - - // Computes: - // f = round((x.f * y.f) / 2^q) - // e = x.e + y.e + q - - // Emulate the 64-bit * 64-bit multiplication: - // - // p = u * v - // = (u_lo + 2^32 u_hi) (v_lo + 2^32 v_hi) - // = (u_lo v_lo ) + 2^32 ((u_lo v_hi ) + (u_hi v_lo )) + 2^64 (u_hi v_hi ) - // = (p0 ) + 2^32 ((p1 ) + (p2 )) + 2^64 (p3 ) - // = (p0_lo + 2^32 p0_hi) + 2^32 ((p1_lo + 2^32 p1_hi) + (p2_lo + 2^32 p2_hi)) + 2^64 (p3 ) - // = (p0_lo ) + 2^32 (p0_hi + p1_lo + p2_lo ) + 2^64 (p1_hi + p2_hi + p3) - // = (p0_lo ) + 2^32 (Q ) + 2^64 (H ) - // = (p0_lo ) + 2^32 (Q_lo + 2^32 Q_hi ) + 2^64 (H ) - // - // (Since Q might be larger than 2^32 - 1) - // - // = (p0_lo + 2^32 Q_lo) + 2^64 (Q_hi + H) - // - // (Q_hi + H does not overflow a 64-bit int) - // - // = p_lo + 2^64 p_hi - - const std::uint64_t u_lo = x.f & 0xFFFFFFFFu; - const std::uint64_t u_hi = x.f >> 32u; - const std::uint64_t v_lo = y.f & 0xFFFFFFFFu; - const std::uint64_t v_hi = y.f >> 32u; - - const std::uint64_t p0 = u_lo * v_lo; - const std::uint64_t p1 = u_lo * v_hi; - const std::uint64_t p2 = u_hi * v_lo; - const std::uint64_t p3 = u_hi * v_hi; - - const std::uint64_t p0_hi = p0 >> 32u; - const std::uint64_t p1_lo = p1 & 0xFFFFFFFFu; - const std::uint64_t p1_hi = p1 >> 32u; - const std::uint64_t p2_lo = p2 & 0xFFFFFFFFu; - const std::uint64_t p2_hi = p2 >> 32u; - - std::uint64_t Q = p0_hi + p1_lo + p2_lo; - - // The full product might now be computed as - // - // p_hi = p3 + p2_hi + p1_hi + (Q >> 32) - // p_lo = p0_lo + (Q << 32) - // - // But in this particular case here, the full p_lo is not required. - // Effectively we only need to add the highest bit in p_lo to p_hi (and - // Q_hi + 1 does not overflow). - - Q += std::uint64_t{1} << (64u - 32u - 1u); // round, ties up - - const std::uint64_t h = p3 + p2_hi + p1_hi + (Q >> 32u); - - return {h, x.e + y.e + 64}; - } - - /*! 
- @brief normalize x such that the significand is >= 2^(q-1) - @pre x.f != 0 - */ - static diyfp normalize(diyfp x) noexcept - { - JSON_ASSERT(x.f != 0); - - while ((x.f >> 63u) == 0) - { - x.f <<= 1u; - x.e--; - } - - return x; - } - - /*! - @brief normalize x such that the result has the exponent E - @pre e >= x.e and the upper e - x.e bits of x.f must be zero. - */ - static diyfp normalize_to(const diyfp& x, const int target_exponent) noexcept - { - const int delta = x.e - target_exponent; - - JSON_ASSERT(delta >= 0); - JSON_ASSERT(((x.f << delta) >> delta) == x.f); - - return {x.f << delta, target_exponent}; - } -}; - -struct boundaries -{ - diyfp w; - diyfp minus; - diyfp plus; -}; - -/*! -Compute the (normalized) diyfp representing the input number 'value' and its -boundaries. - -@pre value must be finite and positive -*/ -template -boundaries compute_boundaries(FloatType value) -{ - JSON_ASSERT(std::isfinite(value)); - JSON_ASSERT(value > 0); - - // Convert the IEEE representation into a diyfp. - // - // If v is denormal: - // value = 0.F * 2^(1 - bias) = ( F) * 2^(1 - bias - (p-1)) - // If v is normalized: - // value = 1.F * 2^(E - bias) = (2^(p-1) + F) * 2^(E - bias - (p-1)) - - static_assert(std::numeric_limits::is_iec559, - "internal error: dtoa_short requires an IEEE-754 floating-point implementation"); - - constexpr int kPrecision = std::numeric_limits::digits; // = p (includes the hidden bit) - constexpr int kBias = std::numeric_limits::max_exponent - 1 + (kPrecision - 1); - constexpr int kMinExp = 1 - kBias; - constexpr std::uint64_t kHiddenBit = std::uint64_t{1} << (kPrecision - 1); // = 2^(p-1) - - using bits_type = typename std::conditional::type; - - const std::uint64_t bits = reinterpret_bits(value); - const std::uint64_t E = bits >> (kPrecision - 1); - const std::uint64_t F = bits & (kHiddenBit - 1); - - const bool is_denormal = E == 0; - const diyfp v = is_denormal - ? 
diyfp(F, kMinExp) - : diyfp(F + kHiddenBit, static_cast(E) - kBias); - - // Compute the boundaries m- and m+ of the floating-point value - // v = f * 2^e. - // - // Determine v- and v+, the floating-point predecessor and successor if v, - // respectively. - // - // v- = v - 2^e if f != 2^(p-1) or e == e_min (A) - // = v - 2^(e-1) if f == 2^(p-1) and e > e_min (B) - // - // v+ = v + 2^e - // - // Let m- = (v- + v) / 2 and m+ = (v + v+) / 2. All real numbers _strictly_ - // between m- and m+ round to v, regardless of how the input rounding - // algorithm breaks ties. - // - // ---+-------------+-------------+-------------+-------------+--- (A) - // v- m- v m+ v+ - // - // -----------------+------+------+-------------+-------------+--- (B) - // v- m- v m+ v+ - - const bool lower_boundary_is_closer = F == 0 && E > 1; - const diyfp m_plus = diyfp(2 * v.f + 1, v.e - 1); - const diyfp m_minus = lower_boundary_is_closer - ? diyfp(4 * v.f - 1, v.e - 2) // (B) - : diyfp(2 * v.f - 1, v.e - 1); // (A) - - // Determine the normalized w+ = m+. - const diyfp w_plus = diyfp::normalize(m_plus); - - // Determine w- = m- such that e_(w-) = e_(w+). - const diyfp w_minus = diyfp::normalize_to(m_minus, w_plus.e); - - return {diyfp::normalize(v), w_minus, w_plus}; -} - -// Given normalized diyfp w, Grisu needs to find a (normalized) cached -// power-of-ten c, such that the exponent of the product c * w = f * 2^e lies -// within a certain range [alpha, gamma] (Definition 3.2 from [1]) -// -// alpha <= e = e_c + e_w + q <= gamma -// -// or -// -// f_c * f_w * 2^alpha <= f_c 2^(e_c) * f_w 2^(e_w) * 2^q -// <= f_c * f_w * 2^gamma -// -// Since c and w are normalized, i.e. 2^(q-1) <= f < 2^q, this implies -// -// 2^(q-1) * 2^(q-1) * 2^alpha <= c * w * 2^q < 2^q * 2^q * 2^gamma -// -// or -// -// 2^(q - 2 + alpha) <= c * w < 2^(q + gamma) -// -// The choice of (alpha,gamma) determines the size of the table and the form of -// the digit generation procedure. 
Using (alpha,gamma)=(-60,-32) works out well -// in practice: -// -// The idea is to cut the number c * w = f * 2^e into two parts, which can be -// processed independently: An integral part p1, and a fractional part p2: -// -// f * 2^e = ( (f div 2^-e) * 2^-e + (f mod 2^-e) ) * 2^e -// = (f div 2^-e) + (f mod 2^-e) * 2^e -// = p1 + p2 * 2^e -// -// The conversion of p1 into decimal form requires a series of divisions and -// modulos by (a power of) 10. These operations are faster for 32-bit than for -// 64-bit integers, so p1 should ideally fit into a 32-bit integer. This can be -// achieved by choosing -// -// -e >= 32 or e <= -32 := gamma -// -// In order to convert the fractional part -// -// p2 * 2^e = p2 / 2^-e = d[-1] / 10^1 + d[-2] / 10^2 + ... -// -// into decimal form, the fraction is repeatedly multiplied by 10 and the digits -// d[-i] are extracted in order: -// -// (10 * p2) div 2^-e = d[-1] -// (10 * p2) mod 2^-e = d[-2] / 10^1 + ... -// -// The multiplication by 10 must not overflow. It is sufficient to choose -// -// 10 * p2 < 16 * p2 = 2^4 * p2 <= 2^64. -// -// Since p2 = f mod 2^-e < 2^-e, -// -// -e <= 60 or e >= -60 := alpha - -constexpr int kAlpha = -60; -constexpr int kGamma = -32; - -struct cached_power // c = f * 2^e ~= 10^k -{ - std::uint64_t f; - int e; - int k; -}; - -/*! -For a normalized diyfp w = f * 2^e, this function returns a (normalized) cached -power-of-ten c = f_c * 2^e_c, such that the exponent of the product w * c -satisfies (Definition 3.2 from [1]) - - alpha <= e_c + e + q <= gamma. -*/ -inline cached_power get_cached_power_for_binary_exponent(int e) -{ - // Now - // - // alpha <= e_c + e + q <= gamma (1) - // ==> f_c * 2^alpha <= c * 2^e * 2^q - // - // and since the c's are normalized, 2^(q-1) <= f_c, - // - // ==> 2^(q - 1 + alpha) <= c * 2^(e + q) - // ==> 2^(alpha - e - 1) <= c - // - // If c were an exact power of ten, i.e. 
c = 10^k, one may determine k as - // - // k = ceil( log_10( 2^(alpha - e - 1) ) ) - // = ceil( (alpha - e - 1) * log_10(2) ) - // - // From the paper: - // "In theory the result of the procedure could be wrong since c is rounded, - // and the computation itself is approximated [...]. In practice, however, - // this simple function is sufficient." - // - // For IEEE double precision floating-point numbers converted into - // normalized diyfp's w = f * 2^e, with q = 64, - // - // e >= -1022 (min IEEE exponent) - // -52 (p - 1) - // -52 (p - 1, possibly normalize denormal IEEE numbers) - // -11 (normalize the diyfp) - // = -1137 - // - // and - // - // e <= +1023 (max IEEE exponent) - // -52 (p - 1) - // -11 (normalize the diyfp) - // = 960 - // - // This binary exponent range [-1137,960] results in a decimal exponent - // range [-307,324]. One does not need to store a cached power for each - // k in this range. For each such k it suffices to find a cached power - // such that the exponent of the product lies in [alpha,gamma]. - // This implies that the difference of the decimal exponents of adjacent - // table entries must be less than or equal to - // - // floor( (gamma - alpha) * log_10(2) ) = 8. - // - // (A smaller distance gamma-alpha would require a larger table.) - - // NB: - // Actually this function returns c, such that -60 <= e_c + e + 64 <= -34. 
- - constexpr int kCachedPowersMinDecExp = -300; - constexpr int kCachedPowersDecStep = 8; - - static constexpr std::array kCachedPowers = - { - { - { 0xAB70FE17C79AC6CA, -1060, -300 }, - { 0xFF77B1FCBEBCDC4F, -1034, -292 }, - { 0xBE5691EF416BD60C, -1007, -284 }, - { 0x8DD01FAD907FFC3C, -980, -276 }, - { 0xD3515C2831559A83, -954, -268 }, - { 0x9D71AC8FADA6C9B5, -927, -260 }, - { 0xEA9C227723EE8BCB, -901, -252 }, - { 0xAECC49914078536D, -874, -244 }, - { 0x823C12795DB6CE57, -847, -236 }, - { 0xC21094364DFB5637, -821, -228 }, - { 0x9096EA6F3848984F, -794, -220 }, - { 0xD77485CB25823AC7, -768, -212 }, - { 0xA086CFCD97BF97F4, -741, -204 }, - { 0xEF340A98172AACE5, -715, -196 }, - { 0xB23867FB2A35B28E, -688, -188 }, - { 0x84C8D4DFD2C63F3B, -661, -180 }, - { 0xC5DD44271AD3CDBA, -635, -172 }, - { 0x936B9FCEBB25C996, -608, -164 }, - { 0xDBAC6C247D62A584, -582, -156 }, - { 0xA3AB66580D5FDAF6, -555, -148 }, - { 0xF3E2F893DEC3F126, -529, -140 }, - { 0xB5B5ADA8AAFF80B8, -502, -132 }, - { 0x87625F056C7C4A8B, -475, -124 }, - { 0xC9BCFF6034C13053, -449, -116 }, - { 0x964E858C91BA2655, -422, -108 }, - { 0xDFF9772470297EBD, -396, -100 }, - { 0xA6DFBD9FB8E5B88F, -369, -92 }, - { 0xF8A95FCF88747D94, -343, -84 }, - { 0xB94470938FA89BCF, -316, -76 }, - { 0x8A08F0F8BF0F156B, -289, -68 }, - { 0xCDB02555653131B6, -263, -60 }, - { 0x993FE2C6D07B7FAC, -236, -52 }, - { 0xE45C10C42A2B3B06, -210, -44 }, - { 0xAA242499697392D3, -183, -36 }, - { 0xFD87B5F28300CA0E, -157, -28 }, - { 0xBCE5086492111AEB, -130, -20 }, - { 0x8CBCCC096F5088CC, -103, -12 }, - { 0xD1B71758E219652C, -77, -4 }, - { 0x9C40000000000000, -50, 4 }, - { 0xE8D4A51000000000, -24, 12 }, - { 0xAD78EBC5AC620000, 3, 20 }, - { 0x813F3978F8940984, 30, 28 }, - { 0xC097CE7BC90715B3, 56, 36 }, - { 0x8F7E32CE7BEA5C70, 83, 44 }, - { 0xD5D238A4ABE98068, 109, 52 }, - { 0x9F4F2726179A2245, 136, 60 }, - { 0xED63A231D4C4FB27, 162, 68 }, - { 0xB0DE65388CC8ADA8, 189, 76 }, - { 0x83C7088E1AAB65DB, 216, 84 }, - { 0xC45D1DF942711D9A, 242, 92 }, - { 
0x924D692CA61BE758, 269, 100 }, - { 0xDA01EE641A708DEA, 295, 108 }, - { 0xA26DA3999AEF774A, 322, 116 }, - { 0xF209787BB47D6B85, 348, 124 }, - { 0xB454E4A179DD1877, 375, 132 }, - { 0x865B86925B9BC5C2, 402, 140 }, - { 0xC83553C5C8965D3D, 428, 148 }, - { 0x952AB45CFA97A0B3, 455, 156 }, - { 0xDE469FBD99A05FE3, 481, 164 }, - { 0xA59BC234DB398C25, 508, 172 }, - { 0xF6C69A72A3989F5C, 534, 180 }, - { 0xB7DCBF5354E9BECE, 561, 188 }, - { 0x88FCF317F22241E2, 588, 196 }, - { 0xCC20CE9BD35C78A5, 614, 204 }, - { 0x98165AF37B2153DF, 641, 212 }, - { 0xE2A0B5DC971F303A, 667, 220 }, - { 0xA8D9D1535CE3B396, 694, 228 }, - { 0xFB9B7CD9A4A7443C, 720, 236 }, - { 0xBB764C4CA7A44410, 747, 244 }, - { 0x8BAB8EEFB6409C1A, 774, 252 }, - { 0xD01FEF10A657842C, 800, 260 }, - { 0x9B10A4E5E9913129, 827, 268 }, - { 0xE7109BFBA19C0C9D, 853, 276 }, - { 0xAC2820D9623BF429, 880, 284 }, - { 0x80444B5E7AA7CF85, 907, 292 }, - { 0xBF21E44003ACDD2D, 933, 300 }, - { 0x8E679C2F5E44FF8F, 960, 308 }, - { 0xD433179D9C8CB841, 986, 316 }, - { 0x9E19DB92B4E31BA9, 1013, 324 }, - } - }; - - // This computation gives exactly the same results for k as - // k = ceil((kAlpha - e - 1) * 0.30102999566398114) - // for |e| <= 1500, but doesn't require floating-point operations. - // NB: log_10(2) ~= 78913 / 2^18 - JSON_ASSERT(e >= -1500); - JSON_ASSERT(e <= 1500); - const int f = kAlpha - e - 1; - const int k = (f * 78913) / (1 << 18) + static_cast(f > 0); - - const int index = (-kCachedPowersMinDecExp + k + (kCachedPowersDecStep - 1)) / kCachedPowersDecStep; - JSON_ASSERT(index >= 0); - JSON_ASSERT(static_cast(index) < kCachedPowers.size()); - - const cached_power cached = kCachedPowers[static_cast(index)]; - JSON_ASSERT(kAlpha <= cached.e + e + 64); - JSON_ASSERT(kGamma >= cached.e + e + 64); - - return cached; -} - -/*! -For n != 0, returns k, such that pow10 := 10^(k-1) <= n < 10^k. -For n == 0, returns 1 and sets pow10 := 1. 
-*/ -inline int find_largest_pow10(const std::uint32_t n, std::uint32_t& pow10) -{ - // LCOV_EXCL_START - if (n >= 1000000000) - { - pow10 = 1000000000; - return 10; - } - // LCOV_EXCL_STOP - else if (n >= 100000000) - { - pow10 = 100000000; - return 9; - } - else if (n >= 10000000) - { - pow10 = 10000000; - return 8; - } - else if (n >= 1000000) - { - pow10 = 1000000; - return 7; - } - else if (n >= 100000) - { - pow10 = 100000; - return 6; - } - else if (n >= 10000) - { - pow10 = 10000; - return 5; - } - else if (n >= 1000) - { - pow10 = 1000; - return 4; - } - else if (n >= 100) - { - pow10 = 100; - return 3; - } - else if (n >= 10) - { - pow10 = 10; - return 2; - } - else - { - pow10 = 1; - return 1; - } -} - -inline void grisu2_round(char* buf, int len, std::uint64_t dist, std::uint64_t delta, - std::uint64_t rest, std::uint64_t ten_k) -{ - JSON_ASSERT(len >= 1); - JSON_ASSERT(dist <= delta); - JSON_ASSERT(rest <= delta); - JSON_ASSERT(ten_k > 0); - - // <--------------------------- delta ----> - // <---- dist ---------> - // --------------[------------------+-------------------]-------------- - // M- w M+ - // - // ten_k - // <------> - // <---- rest ----> - // --------------[------------------+----+--------------]-------------- - // w V - // = buf * 10^k - // - // ten_k represents a unit-in-the-last-place in the decimal representation - // stored in buf. - // Decrement buf by ten_k while this takes buf closer to w. - - // The tests are written in this order to avoid overflow in unsigned - // integer arithmetic. - - while (rest < dist - && delta - rest >= ten_k - && (rest + ten_k < dist || dist - rest > rest + ten_k - dist)) - { - JSON_ASSERT(buf[len - 1] != '0'); - buf[len - 1]--; - rest += ten_k; - } -} - -/*! -Generates V = buffer * 10^decimal_exponent, such that M- <= V <= M+. -M- and M+ must be normalized and share the same exponent -60 <= e <= -32. 
-*/ -inline void grisu2_digit_gen(char* buffer, int& length, int& decimal_exponent, - diyfp M_minus, diyfp w, diyfp M_plus) -{ - static_assert(kAlpha >= -60, "internal error"); - static_assert(kGamma <= -32, "internal error"); - - // Generates the digits (and the exponent) of a decimal floating-point - // number V = buffer * 10^decimal_exponent in the range [M-, M+]. The diyfp's - // w, M- and M+ share the same exponent e, which satisfies alpha <= e <= gamma. - // - // <--------------------------- delta ----> - // <---- dist ---------> - // --------------[------------------+-------------------]-------------- - // M- w M+ - // - // Grisu2 generates the digits of M+ from left to right and stops as soon as - // V is in [M-,M+]. - - JSON_ASSERT(M_plus.e >= kAlpha); - JSON_ASSERT(M_plus.e <= kGamma); - - std::uint64_t delta = diyfp::sub(M_plus, M_minus).f; // (significand of (M+ - M-), implicit exponent is e) - std::uint64_t dist = diyfp::sub(M_plus, w ).f; // (significand of (M+ - w ), implicit exponent is e) - - // Split M+ = f * 2^e into two parts p1 and p2 (note: e < 0): - // - // M+ = f * 2^e - // = ((f div 2^-e) * 2^-e + (f mod 2^-e)) * 2^e - // = ((p1 ) * 2^-e + (p2 )) * 2^e - // = p1 + p2 * 2^e - - const diyfp one(std::uint64_t{1} << -M_plus.e, M_plus.e); - - auto p1 = static_cast(M_plus.f >> -one.e); // p1 = f div 2^-e (Since -e >= 32, p1 fits into a 32-bit int.) 
- std::uint64_t p2 = M_plus.f & (one.f - 1); // p2 = f mod 2^-e - - // 1) - // - // Generate the digits of the integral part p1 = d[n-1]...d[1]d[0] - - JSON_ASSERT(p1 > 0); - - std::uint32_t pow10; - const int k = find_largest_pow10(p1, pow10); - - // 10^(k-1) <= p1 < 10^k, pow10 = 10^(k-1) - // - // p1 = (p1 div 10^(k-1)) * 10^(k-1) + (p1 mod 10^(k-1)) - // = (d[k-1] ) * 10^(k-1) + (p1 mod 10^(k-1)) - // - // M+ = p1 + p2 * 2^e - // = d[k-1] * 10^(k-1) + (p1 mod 10^(k-1)) + p2 * 2^e - // = d[k-1] * 10^(k-1) + ((p1 mod 10^(k-1)) * 2^-e + p2) * 2^e - // = d[k-1] * 10^(k-1) + ( rest) * 2^e - // - // Now generate the digits d[n] of p1 from left to right (n = k-1,...,0) - // - // p1 = d[k-1]...d[n] * 10^n + d[n-1]...d[0] - // - // but stop as soon as - // - // rest * 2^e = (d[n-1]...d[0] * 2^-e + p2) * 2^e <= delta * 2^e - - int n = k; - while (n > 0) - { - // Invariants: - // M+ = buffer * 10^n + (p1 + p2 * 2^e) (buffer = 0 for n = k) - // pow10 = 10^(n-1) <= p1 < 10^n - // - const std::uint32_t d = p1 / pow10; // d = p1 div 10^(n-1) - const std::uint32_t r = p1 % pow10; // r = p1 mod 10^(n-1) - // - // M+ = buffer * 10^n + (d * 10^(n-1) + r) + p2 * 2^e - // = (buffer * 10 + d) * 10^(n-1) + (r + p2 * 2^e) - // - JSON_ASSERT(d <= 9); - buffer[length++] = static_cast('0' + d); // buffer := buffer * 10 + d - // - // M+ = buffer * 10^(n-1) + (r + p2 * 2^e) - // - p1 = r; - n--; - // - // M+ = buffer * 10^n + (p1 + p2 * 2^e) - // pow10 = 10^n - // - - // Now check if enough digits have been generated. - // Compute - // - // p1 + p2 * 2^e = (p1 * 2^-e + p2) * 2^e = rest * 2^e - // - // Note: - // Since rest and delta share the same exponent e, it suffices to - // compare the significands. - const std::uint64_t rest = (std::uint64_t{p1} << -one.e) + p2; - if (rest <= delta) - { - // V = buffer * 10^n, with M- <= V <= M+. - - decimal_exponent += n; - - // We may now just stop. But instead look if the buffer could be - // decremented to bring V closer to w. 
- // - // pow10 = 10^n is now 1 ulp in the decimal representation V. - // The rounding procedure works with diyfp's with an implicit - // exponent of e. - // - // 10^n = (10^n * 2^-e) * 2^e = ulp * 2^e - // - const std::uint64_t ten_n = std::uint64_t{pow10} << -one.e; - grisu2_round(buffer, length, dist, delta, rest, ten_n); - - return; - } - - pow10 /= 10; - // - // pow10 = 10^(n-1) <= p1 < 10^n - // Invariants restored. - } - - // 2) - // - // The digits of the integral part have been generated: - // - // M+ = d[k-1]...d[1]d[0] + p2 * 2^e - // = buffer + p2 * 2^e - // - // Now generate the digits of the fractional part p2 * 2^e. - // - // Note: - // No decimal point is generated: the exponent is adjusted instead. - // - // p2 actually represents the fraction - // - // p2 * 2^e - // = p2 / 2^-e - // = d[-1] / 10^1 + d[-2] / 10^2 + ... - // - // Now generate the digits d[-m] of p1 from left to right (m = 1,2,...) - // - // p2 * 2^e = d[-1]d[-2]...d[-m] * 10^-m - // + 10^-m * (d[-m-1] / 10^1 + d[-m-2] / 10^2 + ...) - // - // using - // - // 10^m * p2 = ((10^m * p2) div 2^-e) * 2^-e + ((10^m * p2) mod 2^-e) - // = ( d) * 2^-e + ( r) - // - // or - // 10^m * p2 * 2^e = d + r * 2^e - // - // i.e. - // - // M+ = buffer + p2 * 2^e - // = buffer + 10^-m * (d + r * 2^e) - // = (buffer * 10^m + d) * 10^-m + 10^-m * r * 2^e - // - // and stop as soon as 10^-m * r * 2^e <= delta * 2^e - - JSON_ASSERT(p2 > delta); - - int m = 0; - for (;;) - { - // Invariant: - // M+ = buffer * 10^-m + 10^-m * (d[-m-1] / 10 + d[-m-2] / 10^2 + ...) 
* 2^e - // = buffer * 10^-m + 10^-m * (p2 ) * 2^e - // = buffer * 10^-m + 10^-m * (1/10 * (10 * p2) ) * 2^e - // = buffer * 10^-m + 10^-m * (1/10 * ((10*p2 div 2^-e) * 2^-e + (10*p2 mod 2^-e)) * 2^e - // - JSON_ASSERT(p2 <= (std::numeric_limits::max)() / 10); - p2 *= 10; - const std::uint64_t d = p2 >> -one.e; // d = (10 * p2) div 2^-e - const std::uint64_t r = p2 & (one.f - 1); // r = (10 * p2) mod 2^-e - // - // M+ = buffer * 10^-m + 10^-m * (1/10 * (d * 2^-e + r) * 2^e - // = buffer * 10^-m + 10^-m * (1/10 * (d + r * 2^e)) - // = (buffer * 10 + d) * 10^(-m-1) + 10^(-m-1) * r * 2^e - // - JSON_ASSERT(d <= 9); - buffer[length++] = static_cast('0' + d); // buffer := buffer * 10 + d - // - // M+ = buffer * 10^(-m-1) + 10^(-m-1) * r * 2^e - // - p2 = r; - m++; - // - // M+ = buffer * 10^-m + 10^-m * p2 * 2^e - // Invariant restored. - - // Check if enough digits have been generated. - // - // 10^-m * p2 * 2^e <= delta * 2^e - // p2 * 2^e <= 10^m * delta * 2^e - // p2 <= 10^m * delta - delta *= 10; - dist *= 10; - if (p2 <= delta) - { - break; - } - } - - // V = buffer * 10^-m, with M- <= V <= M+. - - decimal_exponent -= m; - - // 1 ulp in the decimal representation is now 10^-m. - // Since delta and dist are now scaled by 10^m, we need to do the - // same with ulp in order to keep the units in sync. - // - // 10^m * 10^-m = 1 = 2^-e * 2^e = ten_m * 2^e - // - const std::uint64_t ten_m = one.f; - grisu2_round(buffer, length, dist, delta, p2, ten_m); - - // By construction this algorithm generates the shortest possible decimal - // number (Loitsch, Theorem 6.2) which rounds back to w. - // For an input number of precision p, at least - // - // N = 1 + ceil(p * log_10(2)) - // - // decimal digits are sufficient to identify all binary floating-point - // numbers (Matula, "In-and-Out conversions"). - // This implies that the algorithm does not produce more than N decimal - // digits. 
- // - // N = 17 for p = 53 (IEEE double precision) - // N = 9 for p = 24 (IEEE single precision) -} - -/*! -v = buf * 10^decimal_exponent -len is the length of the buffer (number of decimal digits) -The buffer must be large enough, i.e. >= max_digits10. -*/ -JSON_HEDLEY_NON_NULL(1) -inline void grisu2(char* buf, int& len, int& decimal_exponent, - diyfp m_minus, diyfp v, diyfp m_plus) -{ - JSON_ASSERT(m_plus.e == m_minus.e); - JSON_ASSERT(m_plus.e == v.e); - - // --------(-----------------------+-----------------------)-------- (A) - // m- v m+ - // - // --------------------(-----------+-----------------------)-------- (B) - // m- v m+ - // - // First scale v (and m- and m+) such that the exponent is in the range - // [alpha, gamma]. - - const cached_power cached = get_cached_power_for_binary_exponent(m_plus.e); - - const diyfp c_minus_k(cached.f, cached.e); // = c ~= 10^-k - - // The exponent of the products is = v.e + c_minus_k.e + q and is in the range [alpha,gamma] - const diyfp w = diyfp::mul(v, c_minus_k); - const diyfp w_minus = diyfp::mul(m_minus, c_minus_k); - const diyfp w_plus = diyfp::mul(m_plus, c_minus_k); - - // ----(---+---)---------------(---+---)---------------(---+---)---- - // w- w w+ - // = c*m- = c*v = c*m+ - // - // diyfp::mul rounds its result and c_minus_k is approximated too. w, w- and - // w+ are now off by a small amount. - // In fact: - // - // w - v * 10^k < 1 ulp - // - // To account for this inaccuracy, add resp. subtract 1 ulp. - // - // --------+---[---------------(---+---)---------------]---+-------- - // w- M- w M+ w+ - // - // Now any number in [M-, M+] (bounds included) will round to w when input, - // regardless of how the input rounding algorithm breaks ties. - // - // And digit_gen generates the shortest possible such number in [M-, M+]. - // Note that this does not mean that Grisu2 always generates the shortest - // possible number in the interval (m-, m+). 
- const diyfp M_minus(w_minus.f + 1, w_minus.e); - const diyfp M_plus (w_plus.f - 1, w_plus.e ); - - decimal_exponent = -cached.k; // = -(-k) = k - - grisu2_digit_gen(buf, len, decimal_exponent, M_minus, w, M_plus); -} - -/*! -v = buf * 10^decimal_exponent -len is the length of the buffer (number of decimal digits) -The buffer must be large enough, i.e. >= max_digits10. -*/ -template -JSON_HEDLEY_NON_NULL(1) -void grisu2(char* buf, int& len, int& decimal_exponent, FloatType value) -{ - static_assert(diyfp::kPrecision >= std::numeric_limits::digits + 3, - "internal error: not enough precision"); - - JSON_ASSERT(std::isfinite(value)); - JSON_ASSERT(value > 0); - - // If the neighbors (and boundaries) of 'value' are always computed for double-precision - // numbers, all float's can be recovered using strtod (and strtof). However, the resulting - // decimal representations are not exactly "short". - // - // The documentation for 'std::to_chars' (https://en.cppreference.com/w/cpp/utility/to_chars) - // says "value is converted to a string as if by std::sprintf in the default ("C") locale" - // and since sprintf promotes float's to double's, I think this is exactly what 'std::to_chars' - // does. - // On the other hand, the documentation for 'std::to_chars' requires that "parsing the - // representation using the corresponding std::from_chars function recovers value exactly". That - // indicates that single precision floating-point numbers should be recovered using - // 'std::strtof'. - // - // NB: If the neighbors are computed for single-precision numbers, there is a single float - // (7.0385307e-26f) which can't be recovered using strtod. The resulting double precision - // value is off by 1 ulp. -#if 0 - const boundaries w = compute_boundaries(static_cast(value)); -#else - const boundaries w = compute_boundaries(value); -#endif - - grisu2(buf, len, decimal_exponent, w.minus, w.w, w.plus); -} - -/*! 
-@brief appends a decimal representation of e to buf -@return a pointer to the element following the exponent. -@pre -1000 < e < 1000 -*/ -JSON_HEDLEY_NON_NULL(1) -JSON_HEDLEY_RETURNS_NON_NULL -inline char* append_exponent(char* buf, int e) -{ - JSON_ASSERT(e > -1000); - JSON_ASSERT(e < 1000); - - if (e < 0) - { - e = -e; - *buf++ = '-'; - } - else - { - *buf++ = '+'; - } - - auto k = static_cast(e); - if (k < 10) - { - // Always print at least two digits in the exponent. - // This is for compatibility with printf("%g"). - *buf++ = '0'; - *buf++ = static_cast('0' + k); - } - else if (k < 100) - { - *buf++ = static_cast('0' + k / 10); - k %= 10; - *buf++ = static_cast('0' + k); - } - else - { - *buf++ = static_cast('0' + k / 100); - k %= 100; - *buf++ = static_cast('0' + k / 10); - k %= 10; - *buf++ = static_cast('0' + k); - } - - return buf; -} - -/*! -@brief prettify v = buf * 10^decimal_exponent - -If v is in the range [10^min_exp, 10^max_exp) it will be printed in fixed-point -notation. Otherwise it will be printed in exponential notation. - -@pre min_exp < 0 -@pre max_exp > 0 -*/ -JSON_HEDLEY_NON_NULL(1) -JSON_HEDLEY_RETURNS_NON_NULL -inline char* format_buffer(char* buf, int len, int decimal_exponent, - int min_exp, int max_exp) -{ - JSON_ASSERT(min_exp < 0); - JSON_ASSERT(max_exp > 0); - - const int k = len; - const int n = len + decimal_exponent; - - // v = buf * 10^(n-k) - // k is the length of the buffer (number of decimal digits) - // n is the position of the decimal point relative to the start of the buffer. 
- - if (k <= n && n <= max_exp) - { - // digits[000] - // len <= max_exp + 2 - - std::memset(buf + k, '0', static_cast(n) - static_cast(k)); - // Make it look like a floating-point number (#362, #378) - buf[n + 0] = '.'; - buf[n + 1] = '0'; - return buf + (static_cast(n) + 2); - } - - if (0 < n && n <= max_exp) - { - // dig.its - // len <= max_digits10 + 1 - - JSON_ASSERT(k > n); - - std::memmove(buf + (static_cast(n) + 1), buf + n, static_cast(k) - static_cast(n)); - buf[n] = '.'; - return buf + (static_cast(k) + 1U); - } - - if (min_exp < n && n <= 0) - { - // 0.[000]digits - // len <= 2 + (-min_exp - 1) + max_digits10 - - std::memmove(buf + (2 + static_cast(-n)), buf, static_cast(k)); - buf[0] = '0'; - buf[1] = '.'; - std::memset(buf + 2, '0', static_cast(-n)); - return buf + (2U + static_cast(-n) + static_cast(k)); - } - - if (k == 1) - { - // dE+123 - // len <= 1 + 5 - - buf += 1; - } - else - { - // d.igitsE+123 - // len <= max_digits10 + 1 + 5 - - std::memmove(buf + 2, buf + 1, static_cast(k) - 1); - buf[1] = '.'; - buf += 1 + static_cast(k); - } - - *buf++ = 'e'; - return append_exponent(buf, n - 1); -} - -} // namespace dtoa_impl - -/*! -@brief generates a decimal representation of the floating-point number value in [first, last). - -The format of the resulting decimal representation is similar to printf's %g -format. Returns an iterator pointing past-the-end of the decimal representation. - -@note The input number must be finite, i.e. NaN's and Inf's are not supported. -@note The buffer must be large enough. -@note The result is NOT null-terminated. -*/ -template -JSON_HEDLEY_NON_NULL(1, 2) -JSON_HEDLEY_RETURNS_NON_NULL -char* to_chars(char* first, const char* last, FloatType value) -{ - static_cast(last); // maybe unused - fix warning - JSON_ASSERT(std::isfinite(value)); - - // Use signbit(value) instead of (value < 0) since signbit works for -0. 
- if (std::signbit(value)) - { - value = -value; - *first++ = '-'; - } - - if (value == 0) // +-0 - { - *first++ = '0'; - // Make it look like a floating-point number (#362, #378) - *first++ = '.'; - *first++ = '0'; - return first; - } - - JSON_ASSERT(last - first >= std::numeric_limits::max_digits10); - - // Compute v = buffer * 10^decimal_exponent. - // The decimal digits are stored in the buffer, which needs to be interpreted - // as an unsigned decimal integer. - // len is the length of the buffer, i.e. the number of decimal digits. - int len = 0; - int decimal_exponent = 0; - dtoa_impl::grisu2(first, len, decimal_exponent, value); - - JSON_ASSERT(len <= std::numeric_limits::max_digits10); - - // Format the buffer like printf("%.*g", prec, value) - constexpr int kMinExp = -4; - // Use digits10 here to increase compatibility with version 2. - constexpr int kMaxExp = std::numeric_limits::digits10; - - JSON_ASSERT(last - first >= kMaxExp + 2); - JSON_ASSERT(last - first >= 2 + (-kMinExp - 1) + std::numeric_limits::max_digits10); - JSON_ASSERT(last - first >= std::numeric_limits::max_digits10 + 6); - - return dtoa_impl::format_buffer(first, len, decimal_exponent, kMinExp, kMaxExp); -} - -} // namespace detail -} // namespace nlohmann - -// #include - -// #include - -// #include - -// #include - -// #include - -// #include - - -namespace nlohmann -{ -namespace detail -{ -/////////////////// -// serialization // -/////////////////// - -/// how to treat decoding errors -enum class error_handler_t -{ - strict, ///< throw a type_error exception in case of invalid UTF-8 - replace, ///< replace invalid UTF-8 sequences with U+FFFD - ignore ///< ignore invalid UTF-8 sequences -}; - -template -class serializer -{ - using string_t = typename BasicJsonType::string_t; - using number_float_t = typename BasicJsonType::number_float_t; - using number_integer_t = typename BasicJsonType::number_integer_t; - using number_unsigned_t = typename BasicJsonType::number_unsigned_t; - using 
binary_char_t = typename BasicJsonType::binary_t::value_type; - static constexpr std::uint8_t UTF8_ACCEPT = 0; - static constexpr std::uint8_t UTF8_REJECT = 1; - - public: - /*! - @param[in] s output stream to serialize to - @param[in] ichar indentation character to use - @param[in] error_handler_ how to react on decoding errors - */ - serializer(output_adapter_t s, const char ichar, - error_handler_t error_handler_ = error_handler_t::strict) - : o(std::move(s)) - , loc(std::localeconv()) - , thousands_sep(loc->thousands_sep == nullptr ? '\0' : std::char_traits::to_char_type(* (loc->thousands_sep))) - , decimal_point(loc->decimal_point == nullptr ? '\0' : std::char_traits::to_char_type(* (loc->decimal_point))) - , indent_char(ichar) - , indent_string(512, indent_char) - , error_handler(error_handler_) - {} - - // delete because of pointer members - serializer(const serializer&) = delete; - serializer& operator=(const serializer&) = delete; - serializer(serializer&&) = delete; - serializer& operator=(serializer&&) = delete; - ~serializer() = default; - - /*! - @brief internal implementation of the serialization function - - This function is called by the public member function dump and organizes - the serialization internally. The indentation level is propagated as - additional parameter. In case of arrays and objects, the function is - called recursively. - - - strings and object keys are escaped using `escape_string()` - - integer numbers are converted implicitly via `operator<<` - - floating-point numbers are converted to a string using `"%g"` format - - binary values are serialized as objects containing the subtype and the - byte array - - @param[in] val value to serialize - @param[in] pretty_print whether the output shall be pretty-printed - @param[in] ensure_ascii If @a ensure_ascii is true, all non-ASCII characters - in the output are escaped with `\uXXXX` sequences, and the result consists - of ASCII characters only. 
- @param[in] indent_step the indent level - @param[in] current_indent the current indent level (only used internally) - */ - void dump(const BasicJsonType& val, - const bool pretty_print, - const bool ensure_ascii, - const unsigned int indent_step, - const unsigned int current_indent = 0) - { - switch (val.m_type) - { - case value_t::object: - { - if (val.m_value.object->empty()) - { - o->write_characters("{}", 2); - return; - } - - if (pretty_print) - { - o->write_characters("{\n", 2); - - // variable to hold indentation for recursive calls - const auto new_indent = current_indent + indent_step; - if (JSON_HEDLEY_UNLIKELY(indent_string.size() < new_indent)) - { - indent_string.resize(indent_string.size() * 2, ' '); - } - - // first n-1 elements - auto i = val.m_value.object->cbegin(); - for (std::size_t cnt = 0; cnt < val.m_value.object->size() - 1; ++cnt, ++i) - { - o->write_characters(indent_string.c_str(), new_indent); - o->write_character('\"'); - dump_escaped(i->first, ensure_ascii); - o->write_characters("\": ", 3); - dump(i->second, true, ensure_ascii, indent_step, new_indent); - o->write_characters(",\n", 2); - } - - // last element - JSON_ASSERT(i != val.m_value.object->cend()); - JSON_ASSERT(std::next(i) == val.m_value.object->cend()); - o->write_characters(indent_string.c_str(), new_indent); - o->write_character('\"'); - dump_escaped(i->first, ensure_ascii); - o->write_characters("\": ", 3); - dump(i->second, true, ensure_ascii, indent_step, new_indent); - - o->write_character('\n'); - o->write_characters(indent_string.c_str(), current_indent); - o->write_character('}'); - } - else - { - o->write_character('{'); - - // first n-1 elements - auto i = val.m_value.object->cbegin(); - for (std::size_t cnt = 0; cnt < val.m_value.object->size() - 1; ++cnt, ++i) - { - o->write_character('\"'); - dump_escaped(i->first, ensure_ascii); - o->write_characters("\":", 2); - dump(i->second, false, ensure_ascii, indent_step, current_indent); - o->write_character(','); - 
} - - // last element - JSON_ASSERT(i != val.m_value.object->cend()); - JSON_ASSERT(std::next(i) == val.m_value.object->cend()); - o->write_character('\"'); - dump_escaped(i->first, ensure_ascii); - o->write_characters("\":", 2); - dump(i->second, false, ensure_ascii, indent_step, current_indent); - - o->write_character('}'); - } - - return; - } - - case value_t::array: - { - if (val.m_value.array->empty()) - { - o->write_characters("[]", 2); - return; - } - - if (pretty_print) - { - o->write_characters("[\n", 2); - - // variable to hold indentation for recursive calls - const auto new_indent = current_indent + indent_step; - if (JSON_HEDLEY_UNLIKELY(indent_string.size() < new_indent)) - { - indent_string.resize(indent_string.size() * 2, ' '); - } - - // first n-1 elements - for (auto i = val.m_value.array->cbegin(); - i != val.m_value.array->cend() - 1; ++i) - { - o->write_characters(indent_string.c_str(), new_indent); - dump(*i, true, ensure_ascii, indent_step, new_indent); - o->write_characters(",\n", 2); - } - - // last element - JSON_ASSERT(!val.m_value.array->empty()); - o->write_characters(indent_string.c_str(), new_indent); - dump(val.m_value.array->back(), true, ensure_ascii, indent_step, new_indent); - - o->write_character('\n'); - o->write_characters(indent_string.c_str(), current_indent); - o->write_character(']'); - } - else - { - o->write_character('['); - - // first n-1 elements - for (auto i = val.m_value.array->cbegin(); - i != val.m_value.array->cend() - 1; ++i) - { - dump(*i, false, ensure_ascii, indent_step, current_indent); - o->write_character(','); - } - - // last element - JSON_ASSERT(!val.m_value.array->empty()); - dump(val.m_value.array->back(), false, ensure_ascii, indent_step, current_indent); - - o->write_character(']'); - } - - return; - } - - case value_t::string: - { - o->write_character('\"'); - dump_escaped(*val.m_value.string, ensure_ascii); - o->write_character('\"'); - return; - } - - case value_t::binary: - { - if 
(pretty_print) - { - o->write_characters("{\n", 2); - - // variable to hold indentation for recursive calls - const auto new_indent = current_indent + indent_step; - if (JSON_HEDLEY_UNLIKELY(indent_string.size() < new_indent)) - { - indent_string.resize(indent_string.size() * 2, ' '); - } - - o->write_characters(indent_string.c_str(), new_indent); - - o->write_characters("\"bytes\": [", 10); - - if (!val.m_value.binary->empty()) - { - for (auto i = val.m_value.binary->cbegin(); - i != val.m_value.binary->cend() - 1; ++i) - { - dump_integer(*i); - o->write_characters(", ", 2); - } - dump_integer(val.m_value.binary->back()); - } - - o->write_characters("],\n", 3); - o->write_characters(indent_string.c_str(), new_indent); - - o->write_characters("\"subtype\": ", 11); - if (val.m_value.binary->has_subtype()) - { - dump_integer(val.m_value.binary->subtype()); - } - else - { - o->write_characters("null", 4); - } - o->write_character('\n'); - o->write_characters(indent_string.c_str(), current_indent); - o->write_character('}'); - } - else - { - o->write_characters("{\"bytes\":[", 10); - - if (!val.m_value.binary->empty()) - { - for (auto i = val.m_value.binary->cbegin(); - i != val.m_value.binary->cend() - 1; ++i) - { - dump_integer(*i); - o->write_character(','); - } - dump_integer(val.m_value.binary->back()); - } - - o->write_characters("],\"subtype\":", 12); - if (val.m_value.binary->has_subtype()) - { - dump_integer(val.m_value.binary->subtype()); - o->write_character('}'); - } - else - { - o->write_characters("null}", 5); - } - } - return; - } - - case value_t::boolean: - { - if (val.m_value.boolean) - { - o->write_characters("true", 4); - } - else - { - o->write_characters("false", 5); - } - return; - } - - case value_t::number_integer: - { - dump_integer(val.m_value.number_integer); - return; - } - - case value_t::number_unsigned: - { - dump_integer(val.m_value.number_unsigned); - return; - } - - case value_t::number_float: - { - 
dump_float(val.m_value.number_float); - return; - } - - case value_t::discarded: - { - o->write_characters("", 11); - return; - } - - case value_t::null: - { - o->write_characters("null", 4); - return; - } - - default: // LCOV_EXCL_LINE - JSON_ASSERT(false); // LCOV_EXCL_LINE - } - } - - private: - /*! - @brief dump escaped string - - Escape a string by replacing certain special characters by a sequence of an - escape character (backslash) and another character and other control - characters by a sequence of "\u" followed by a four-digit hex - representation. The escaped string is written to output stream @a o. - - @param[in] s the string to escape - @param[in] ensure_ascii whether to escape non-ASCII characters with - \uXXXX sequences - - @complexity Linear in the length of string @a s. - */ - void dump_escaped(const string_t& s, const bool ensure_ascii) - { - std::uint32_t codepoint; - std::uint8_t state = UTF8_ACCEPT; - std::size_t bytes = 0; // number of bytes written to string_buffer - - // number of bytes written at the point of the last valid byte - std::size_t bytes_after_last_accept = 0; - std::size_t undumped_chars = 0; - - for (std::size_t i = 0; i < s.size(); ++i) - { - const auto byte = static_cast(s[i]); - - switch (decode(state, codepoint, byte)) - { - case UTF8_ACCEPT: // decode found a new code point - { - switch (codepoint) - { - case 0x08: // backspace - { - string_buffer[bytes++] = '\\'; - string_buffer[bytes++] = 'b'; - break; - } - - case 0x09: // horizontal tab - { - string_buffer[bytes++] = '\\'; - string_buffer[bytes++] = 't'; - break; - } - - case 0x0A: // newline - { - string_buffer[bytes++] = '\\'; - string_buffer[bytes++] = 'n'; - break; - } - - case 0x0C: // formfeed - { - string_buffer[bytes++] = '\\'; - string_buffer[bytes++] = 'f'; - break; - } - - case 0x0D: // carriage return - { - string_buffer[bytes++] = '\\'; - string_buffer[bytes++] = 'r'; - break; - } - - case 0x22: // quotation mark - { - string_buffer[bytes++] = '\\'; - 
string_buffer[bytes++] = '\"'; - break; - } - - case 0x5C: // reverse solidus - { - string_buffer[bytes++] = '\\'; - string_buffer[bytes++] = '\\'; - break; - } - - default: - { - // escape control characters (0x00..0x1F) or, if - // ensure_ascii parameter is used, non-ASCII characters - if ((codepoint <= 0x1F) || (ensure_ascii && (codepoint >= 0x7F))) - { - if (codepoint <= 0xFFFF) - { - (std::snprintf)(string_buffer.data() + bytes, 7, "\\u%04x", - static_cast(codepoint)); - bytes += 6; - } - else - { - (std::snprintf)(string_buffer.data() + bytes, 13, "\\u%04x\\u%04x", - static_cast(0xD7C0u + (codepoint >> 10u)), - static_cast(0xDC00u + (codepoint & 0x3FFu))); - bytes += 12; - } - } - else - { - // copy byte to buffer (all previous bytes - // been copied have in default case above) - string_buffer[bytes++] = s[i]; - } - break; - } - } - - // write buffer and reset index; there must be 13 bytes - // left, as this is the maximal number of bytes to be - // written ("\uxxxx\uxxxx\0") for one code point - if (string_buffer.size() - bytes < 13) - { - o->write_characters(string_buffer.data(), bytes); - bytes = 0; - } - - // remember the byte position of this accept - bytes_after_last_accept = bytes; - undumped_chars = 0; - break; - } - - case UTF8_REJECT: // decode found invalid UTF-8 byte - { - switch (error_handler) - { - case error_handler_t::strict: - { - std::string sn(3, '\0'); - (std::snprintf)(&sn[0], sn.size(), "%.2X", byte); - JSON_THROW(type_error::create(316, "invalid UTF-8 byte at index " + std::to_string(i) + ": 0x" + sn)); - } - - case error_handler_t::ignore: - case error_handler_t::replace: - { - // in case we saw this character the first time, we - // would like to read it again, because the byte - // may be OK for itself, but just not OK for the - // previous sequence - if (undumped_chars > 0) - { - --i; - } - - // reset length buffer to the last accepted index; - // thus removing/ignoring the invalid characters - bytes = bytes_after_last_accept; - - 
if (error_handler == error_handler_t::replace) - { - // add a replacement character - if (ensure_ascii) - { - string_buffer[bytes++] = '\\'; - string_buffer[bytes++] = 'u'; - string_buffer[bytes++] = 'f'; - string_buffer[bytes++] = 'f'; - string_buffer[bytes++] = 'f'; - string_buffer[bytes++] = 'd'; - } - else - { - string_buffer[bytes++] = detail::binary_writer::to_char_type('\xEF'); - string_buffer[bytes++] = detail::binary_writer::to_char_type('\xBF'); - string_buffer[bytes++] = detail::binary_writer::to_char_type('\xBD'); - } - - // write buffer and reset index; there must be 13 bytes - // left, as this is the maximal number of bytes to be - // written ("\uxxxx\uxxxx\0") for one code point - if (string_buffer.size() - bytes < 13) - { - o->write_characters(string_buffer.data(), bytes); - bytes = 0; - } - - bytes_after_last_accept = bytes; - } - - undumped_chars = 0; - - // continue processing the string - state = UTF8_ACCEPT; - break; - } - - default: // LCOV_EXCL_LINE - JSON_ASSERT(false); // LCOV_EXCL_LINE - } - break; - } - - default: // decode found yet incomplete multi-byte code point - { - if (!ensure_ascii) - { - // code point will not be escaped - copy byte to buffer - string_buffer[bytes++] = s[i]; - } - ++undumped_chars; - break; - } - } - } - - // we finished processing the string - if (JSON_HEDLEY_LIKELY(state == UTF8_ACCEPT)) - { - // write buffer - if (bytes > 0) - { - o->write_characters(string_buffer.data(), bytes); - } - } - else - { - // we finish reading, but do not accept: string was incomplete - switch (error_handler) - { - case error_handler_t::strict: - { - std::string sn(3, '\0'); - (std::snprintf)(&sn[0], sn.size(), "%.2X", static_cast(s.back())); - JSON_THROW(type_error::create(316, "incomplete UTF-8 string; last byte: 0x" + sn)); - } - - case error_handler_t::ignore: - { - // write all accepted bytes - o->write_characters(string_buffer.data(), bytes_after_last_accept); - break; - } - - case error_handler_t::replace: - { - // write all 
accepted bytes - o->write_characters(string_buffer.data(), bytes_after_last_accept); - // add a replacement character - if (ensure_ascii) - { - o->write_characters("\\ufffd", 6); - } - else - { - o->write_characters("\xEF\xBF\xBD", 3); - } - break; - } - - default: // LCOV_EXCL_LINE - JSON_ASSERT(false); // LCOV_EXCL_LINE - } - } - } - - /*! - @brief count digits - - Count the number of decimal (base 10) digits for an input unsigned integer. - - @param[in] x unsigned integer number to count its digits - @return number of decimal digits - */ - inline unsigned int count_digits(number_unsigned_t x) noexcept - { - unsigned int n_digits = 1; - for (;;) - { - if (x < 10) - { - return n_digits; - } - if (x < 100) - { - return n_digits + 1; - } - if (x < 1000) - { - return n_digits + 2; - } - if (x < 10000) - { - return n_digits + 3; - } - x = x / 10000u; - n_digits += 4; - } - } - - /*! - @brief dump an integer - - Dump a given integer to output stream @a o. Works internally with - @a number_buffer. 
- - @param[in] x integer number (signed or unsigned) to dump - @tparam NumberType either @a number_integer_t or @a number_unsigned_t - */ - template < typename NumberType, detail::enable_if_t < - std::is_same::value || - std::is_same::value || - std::is_same::value, - int > = 0 > - void dump_integer(NumberType x) - { - static constexpr std::array, 100> digits_to_99 - { - { - {{'0', '0'}}, {{'0', '1'}}, {{'0', '2'}}, {{'0', '3'}}, {{'0', '4'}}, {{'0', '5'}}, {{'0', '6'}}, {{'0', '7'}}, {{'0', '8'}}, {{'0', '9'}}, - {{'1', '0'}}, {{'1', '1'}}, {{'1', '2'}}, {{'1', '3'}}, {{'1', '4'}}, {{'1', '5'}}, {{'1', '6'}}, {{'1', '7'}}, {{'1', '8'}}, {{'1', '9'}}, - {{'2', '0'}}, {{'2', '1'}}, {{'2', '2'}}, {{'2', '3'}}, {{'2', '4'}}, {{'2', '5'}}, {{'2', '6'}}, {{'2', '7'}}, {{'2', '8'}}, {{'2', '9'}}, - {{'3', '0'}}, {{'3', '1'}}, {{'3', '2'}}, {{'3', '3'}}, {{'3', '4'}}, {{'3', '5'}}, {{'3', '6'}}, {{'3', '7'}}, {{'3', '8'}}, {{'3', '9'}}, - {{'4', '0'}}, {{'4', '1'}}, {{'4', '2'}}, {{'4', '3'}}, {{'4', '4'}}, {{'4', '5'}}, {{'4', '6'}}, {{'4', '7'}}, {{'4', '8'}}, {{'4', '9'}}, - {{'5', '0'}}, {{'5', '1'}}, {{'5', '2'}}, {{'5', '3'}}, {{'5', '4'}}, {{'5', '5'}}, {{'5', '6'}}, {{'5', '7'}}, {{'5', '8'}}, {{'5', '9'}}, - {{'6', '0'}}, {{'6', '1'}}, {{'6', '2'}}, {{'6', '3'}}, {{'6', '4'}}, {{'6', '5'}}, {{'6', '6'}}, {{'6', '7'}}, {{'6', '8'}}, {{'6', '9'}}, - {{'7', '0'}}, {{'7', '1'}}, {{'7', '2'}}, {{'7', '3'}}, {{'7', '4'}}, {{'7', '5'}}, {{'7', '6'}}, {{'7', '7'}}, {{'7', '8'}}, {{'7', '9'}}, - {{'8', '0'}}, {{'8', '1'}}, {{'8', '2'}}, {{'8', '3'}}, {{'8', '4'}}, {{'8', '5'}}, {{'8', '6'}}, {{'8', '7'}}, {{'8', '8'}}, {{'8', '9'}}, - {{'9', '0'}}, {{'9', '1'}}, {{'9', '2'}}, {{'9', '3'}}, {{'9', '4'}}, {{'9', '5'}}, {{'9', '6'}}, {{'9', '7'}}, {{'9', '8'}}, {{'9', '9'}}, - } - }; - - // special case for "0" - if (x == 0) - { - o->write_character('0'); - return; - } - - // use a pointer to fill the buffer - auto buffer_ptr = number_buffer.begin(); - - const bool 
is_negative = std::is_same::value && !(x >= 0); // see issue #755 - number_unsigned_t abs_value; - - unsigned int n_chars; - - if (is_negative) - { - *buffer_ptr = '-'; - abs_value = remove_sign(static_cast(x)); - - // account one more byte for the minus sign - n_chars = 1 + count_digits(abs_value); - } - else - { - abs_value = static_cast(x); - n_chars = count_digits(abs_value); - } - - // spare 1 byte for '\0' - JSON_ASSERT(n_chars < number_buffer.size() - 1); - - // jump to the end to generate the string from backward - // so we later avoid reversing the result - buffer_ptr += n_chars; - - // Fast int2ascii implementation inspired by "Fastware" talk by Andrei Alexandrescu - // See: https://www.youtube.com/watch?v=o4-CwDo2zpg - while (abs_value >= 100) - { - const auto digits_index = static_cast((abs_value % 100)); - abs_value /= 100; - *(--buffer_ptr) = digits_to_99[digits_index][1]; - *(--buffer_ptr) = digits_to_99[digits_index][0]; - } - - if (abs_value >= 10) - { - const auto digits_index = static_cast(abs_value); - *(--buffer_ptr) = digits_to_99[digits_index][1]; - *(--buffer_ptr) = digits_to_99[digits_index][0]; - } - else - { - *(--buffer_ptr) = static_cast('0' + abs_value); - } - - o->write_characters(number_buffer.data(), n_chars); - } - - /*! - @brief dump a floating-point number - - Dump a given floating-point number to output stream @a o. Works internally - with @a number_buffer. - - @param[in] x floating-point number to dump - */ - void dump_float(number_float_t x) - { - // NaN / inf - if (!std::isfinite(x)) - { - o->write_characters("null", 4); - return; - } - - // If number_float_t is an IEEE-754 single or double precision number, - // use the Grisu2 algorithm to produce short numbers which are - // guaranteed to round-trip, using strtof and strtod, resp. - // - // NB: The test below works if == . 
- static constexpr bool is_ieee_single_or_double - = (std::numeric_limits::is_iec559 && std::numeric_limits::digits == 24 && std::numeric_limits::max_exponent == 128) || - (std::numeric_limits::is_iec559 && std::numeric_limits::digits == 53 && std::numeric_limits::max_exponent == 1024); - - dump_float(x, std::integral_constant()); - } - - void dump_float(number_float_t x, std::true_type /*is_ieee_single_or_double*/) - { - char* begin = number_buffer.data(); - char* end = ::nlohmann::detail::to_chars(begin, begin + number_buffer.size(), x); - - o->write_characters(begin, static_cast(end - begin)); - } - - void dump_float(number_float_t x, std::false_type /*is_ieee_single_or_double*/) - { - // get number of digits for a float -> text -> float round-trip - static constexpr auto d = std::numeric_limits::max_digits10; - - // the actual conversion - std::ptrdiff_t len = (std::snprintf)(number_buffer.data(), number_buffer.size(), "%.*g", d, x); - - // negative value indicates an error - JSON_ASSERT(len > 0); - // check if buffer was large enough - JSON_ASSERT(static_cast(len) < number_buffer.size()); - - // erase thousands separator - if (thousands_sep != '\0') - { - const auto end = std::remove(number_buffer.begin(), - number_buffer.begin() + len, thousands_sep); - std::fill(end, number_buffer.end(), '\0'); - JSON_ASSERT((end - number_buffer.begin()) <= len); - len = (end - number_buffer.begin()); - } - - // convert decimal point to '.' - if (decimal_point != '\0' && decimal_point != '.') - { - const auto dec_pos = std::find(number_buffer.begin(), number_buffer.end(), decimal_point); - if (dec_pos != number_buffer.end()) - { - *dec_pos = '.'; - } - } - - o->write_characters(number_buffer.data(), static_cast(len)); - - // determine if need to append ".0" - const bool value_is_int_like = - std::none_of(number_buffer.begin(), number_buffer.begin() + len + 1, - [](char c) - { - return c == '.' 
|| c == 'e'; - }); - - if (value_is_int_like) - { - o->write_characters(".0", 2); - } - } - - /*! - @brief check whether a string is UTF-8 encoded - - The function checks each byte of a string whether it is UTF-8 encoded. The - result of the check is stored in the @a state parameter. The function must - be called initially with state 0 (accept). State 1 means the string must - be rejected, because the current byte is not allowed. If the string is - completely processed, but the state is non-zero, the string ended - prematurely; that is, the last byte indicated more bytes should have - followed. - - @param[in,out] state the state of the decoding - @param[in,out] codep codepoint (valid only if resulting state is UTF8_ACCEPT) - @param[in] byte next byte to decode - @return new state - - @note The function has been edited: a std::array is used. - - @copyright Copyright (c) 2008-2009 Bjoern Hoehrmann - @sa http://bjoern.hoehrmann.de/utf-8/decoder/dfa/ - */ - static std::uint8_t decode(std::uint8_t& state, std::uint32_t& codep, const std::uint8_t byte) noexcept - { - static const std::array utf8d = - { - { - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 00..1F - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 20..3F - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 40..5F - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 60..7F - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, // 80..9F - 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // A0..BF - 8, 8, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // C0..DF - 0xA, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x4, 0x3, 0x3, // E0..EF - 0xB, 0x6, 0x6, 0x6, 0x5, 0x8, 
0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, // F0..FF - 0x0, 0x1, 0x2, 0x3, 0x5, 0x8, 0x7, 0x1, 0x1, 0x1, 0x4, 0x6, 0x1, 0x1, 0x1, 0x1, // s0..s0 - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, // s1..s2 - 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, // s3..s4 - 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 3, 1, 1, 1, 1, 1, 1, // s5..s6 - 1, 3, 1, 1, 1, 1, 1, 3, 1, 3, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 // s7..s8 - } - }; - - const std::uint8_t type = utf8d[byte]; - - codep = (state != UTF8_ACCEPT) - ? (byte & 0x3fu) | (codep << 6u) - : (0xFFu >> type) & (byte); - - std::size_t index = 256u + static_cast(state) * 16u + static_cast(type); - JSON_ASSERT(index < 400); - state = utf8d[index]; - return state; - } - - /* - * Overload to make the compiler happy while it is instantiating - * dump_integer for number_unsigned_t. - * Must never be called. - */ - number_unsigned_t remove_sign(number_unsigned_t x) - { - JSON_ASSERT(false); // LCOV_EXCL_LINE - return x; // LCOV_EXCL_LINE - } - - /* - * Helper function for dump_integer - * - * This function takes a negative signed integer and returns its absolute - * value as unsigned integer. The plus/minus shuffling is necessary as we can - * not directly remove the sign of an arbitrary signed integer as the - * absolute values of INT_MIN and INT_MAX are usually not the same. See - * #1708 for details. 
- */ - inline number_unsigned_t remove_sign(number_integer_t x) noexcept - { - JSON_ASSERT(x < 0 && x < (std::numeric_limits::max)()); - return static_cast(-(x + 1)) + 1; - } - - private: - /// the output of the serializer - output_adapter_t o = nullptr; - - /// a (hopefully) large enough character buffer - std::array number_buffer{{}}; - - /// the locale - const std::lconv* loc = nullptr; - /// the locale's thousand separator character - const char thousands_sep = '\0'; - /// the locale's decimal point character - const char decimal_point = '\0'; - - /// string buffer - std::array string_buffer{{}}; - - /// the indentation character - const char indent_char; - /// the indentation string - string_t indent_string; - - /// error_handler how to react on decoding errors - const error_handler_t error_handler; -}; -} // namespace detail -} // namespace nlohmann - -// #include - -// #include - -// #include - - -#include // less -#include // allocator -#include // pair -#include // vector - -namespace nlohmann -{ - -/// ordered_map: a minimal map-like container that preserves insertion order -/// for use within nlohmann::basic_json -template , - class Allocator = std::allocator>> - struct ordered_map : std::vector, Allocator> -{ - using key_type = Key; - using mapped_type = T; - using Container = std::vector, Allocator>; - using typename Container::iterator; - using typename Container::const_iterator; - using typename Container::size_type; - using typename Container::value_type; - - // Explicit constructors instead of `using Container::Container` - // otherwise older compilers choke on it (GCC <= 5.5, xcode <= 9.4) - ordered_map(const Allocator& alloc = Allocator()) : Container{alloc} {} - template - ordered_map(It first, It last, const Allocator& alloc = Allocator()) - : Container{first, last, alloc} {} - ordered_map(std::initializer_list init, const Allocator& alloc = Allocator() ) - : Container{init, alloc} {} - - std::pair emplace(const key_type& key, T&& t) - { - for 
(auto it = this->begin(); it != this->end(); ++it) - { - if (it->first == key) - { - return {it, false}; - } - } - Container::emplace_back(key, t); - return {--this->end(), true}; - } - - T& operator[](const Key& key) - { - return emplace(key, T{}).first->second; - } - - const T& operator[](const Key& key) const - { - return at(key); - } - - T& at(const Key& key) - { - for (auto it = this->begin(); it != this->end(); ++it) - { - if (it->first == key) - { - return it->second; - } - } - - throw std::out_of_range("key not found"); - } - - const T& at(const Key& key) const - { - for (auto it = this->begin(); it != this->end(); ++it) - { - if (it->first == key) - { - return it->second; - } - } - - throw std::out_of_range("key not found"); - } - - size_type erase(const Key& key) - { - for (auto it = this->begin(); it != this->end(); ++it) - { - if (it->first == key) - { - // Since we cannot move const Keys, re-construct them in place - for (auto next = it; ++next != this->end(); ++it) - { - it->~value_type(); // Destroy but keep allocation - new (&*it) value_type{std::move(*next)}; - } - Container::pop_back(); - return 1; - } - } - return 0; - } - - iterator erase(iterator pos) - { - auto it = pos; - - // Since we cannot move const Keys, re-construct them in place - for (auto next = it; ++next != this->end(); ++it) - { - it->~value_type(); // Destroy but keep allocation - new (&*it) value_type{std::move(*next)}; - } - Container::pop_back(); - return pos; - } - - size_type count(const Key& key) const - { - for (auto it = this->begin(); it != this->end(); ++it) - { - if (it->first == key) - { - return 1; - } - } - return 0; - } - - iterator find(const Key& key) - { - for (auto it = this->begin(); it != this->end(); ++it) - { - if (it->first == key) - { - return it; - } - } - return Container::end(); - } - - const_iterator find(const Key& key) const - { - for (auto it = this->begin(); it != this->end(); ++it) - { - if (it->first == key) - { - return it; - } - } - return 
Container::end(); - } - - std::pair insert( value_type&& value ) - { - return emplace(value.first, std::move(value.second)); - } - - std::pair insert( const value_type& value ) - { - for (auto it = this->begin(); it != this->end(); ++it) - { - if (it->first == value.first) - { - return {it, false}; - } - } - Container::push_back(value); - return {--this->end(), true}; - } -}; - -} // namespace nlohmann - - -/*! -@brief namespace for Niels Lohmann -@see https://github.com/nlohmann -@since version 1.0.0 -*/ -namespace nlohmann -{ - -/*! -@brief a class to store JSON values - -@tparam ObjectType type for JSON objects (`std::map` by default; will be used -in @ref object_t) -@tparam ArrayType type for JSON arrays (`std::vector` by default; will be used -in @ref array_t) -@tparam StringType type for JSON strings and object keys (`std::string` by -default; will be used in @ref string_t) -@tparam BooleanType type for JSON booleans (`bool` by default; will be used -in @ref boolean_t) -@tparam NumberIntegerType type for JSON integer numbers (`int64_t` by -default; will be used in @ref number_integer_t) -@tparam NumberUnsignedType type for JSON unsigned integer numbers (@c -`uint64_t` by default; will be used in @ref number_unsigned_t) -@tparam NumberFloatType type for JSON floating-point numbers (`double` by -default; will be used in @ref number_float_t) -@tparam BinaryType type for packed binary data for compatibility with binary -serialization formats (`std::vector` by default; will be used in -@ref binary_t) -@tparam AllocatorType type of the allocator to use (`std::allocator` by -default) -@tparam JSONSerializer the serializer to resolve internal calls to `to_json()` -and `from_json()` (@ref adl_serializer by default) - -@requirement The class satisfies the following concept requirements: -- Basic - - [DefaultConstructible](https://en.cppreference.com/w/cpp/named_req/DefaultConstructible): - JSON values can be default constructed. The result will be a JSON null - value. 
- - [MoveConstructible](https://en.cppreference.com/w/cpp/named_req/MoveConstructible): - A JSON value can be constructed from an rvalue argument. - - [CopyConstructible](https://en.cppreference.com/w/cpp/named_req/CopyConstructible): - A JSON value can be copy-constructed from an lvalue expression. - - [MoveAssignable](https://en.cppreference.com/w/cpp/named_req/MoveAssignable): - A JSON value van be assigned from an rvalue argument. - - [CopyAssignable](https://en.cppreference.com/w/cpp/named_req/CopyAssignable): - A JSON value can be copy-assigned from an lvalue expression. - - [Destructible](https://en.cppreference.com/w/cpp/named_req/Destructible): - JSON values can be destructed. -- Layout - - [StandardLayoutType](https://en.cppreference.com/w/cpp/named_req/StandardLayoutType): - JSON values have - [standard layout](https://en.cppreference.com/w/cpp/language/data_members#Standard_layout): - All non-static data members are private and standard layout types, the - class has no virtual functions or (virtual) base classes. -- Library-wide - - [EqualityComparable](https://en.cppreference.com/w/cpp/named_req/EqualityComparable): - JSON values can be compared with `==`, see @ref - operator==(const_reference,const_reference). - - [LessThanComparable](https://en.cppreference.com/w/cpp/named_req/LessThanComparable): - JSON values can be compared with `<`, see @ref - operator<(const_reference,const_reference). - - [Swappable](https://en.cppreference.com/w/cpp/named_req/Swappable): - Any JSON lvalue or rvalue of can be swapped with any lvalue or rvalue of - other compatible types, using unqualified function call @ref swap(). - - [NullablePointer](https://en.cppreference.com/w/cpp/named_req/NullablePointer): - JSON values can be compared against `std::nullptr_t` objects which are used - to model the `null` value. -- Container - - [Container](https://en.cppreference.com/w/cpp/named_req/Container): - JSON values can be used like STL containers and provide iterator access. 
- - [ReversibleContainer](https://en.cppreference.com/w/cpp/named_req/ReversibleContainer); - JSON values can be used like STL containers and provide reverse iterator - access. - -@invariant The member variables @a m_value and @a m_type have the following -relationship: -- If `m_type == value_t::object`, then `m_value.object != nullptr`. -- If `m_type == value_t::array`, then `m_value.array != nullptr`. -- If `m_type == value_t::string`, then `m_value.string != nullptr`. -The invariants are checked by member function assert_invariant(). - -@internal -@note ObjectType trick from https://stackoverflow.com/a/9860911 -@endinternal - -@see [RFC 7159: The JavaScript Object Notation (JSON) Data Interchange -Format](http://rfc7159.net/rfc7159) - -@since version 1.0.0 - -@nosubgrouping -*/ -NLOHMANN_BASIC_JSON_TPL_DECLARATION -class basic_json -{ - private: - template friend struct detail::external_constructor; - friend ::nlohmann::json_pointer; - - template - friend class ::nlohmann::detail::parser; - friend ::nlohmann::detail::serializer; - template - friend class ::nlohmann::detail::iter_impl; - template - friend class ::nlohmann::detail::binary_writer; - template - friend class ::nlohmann::detail::binary_reader; - template - friend class ::nlohmann::detail::json_sax_dom_parser; - template - friend class ::nlohmann::detail::json_sax_dom_callback_parser; - - /// workaround type for MSVC - using basic_json_t = NLOHMANN_BASIC_JSON_TPL; - - // convenience aliases for types residing in namespace detail; - using lexer = ::nlohmann::detail::lexer_base; - - template - static ::nlohmann::detail::parser parser( - InputAdapterType adapter, - detail::parser_callback_tcb = nullptr, - const bool allow_exceptions = true, - const bool ignore_comments = false - ) - { - return ::nlohmann::detail::parser(std::move(adapter), - std::move(cb), allow_exceptions, ignore_comments); - } - - using primitive_iterator_t = ::nlohmann::detail::primitive_iterator_t; - template - using internal_iterator 
= ::nlohmann::detail::internal_iterator; - template - using iter_impl = ::nlohmann::detail::iter_impl; - template - using iteration_proxy = ::nlohmann::detail::iteration_proxy; - template using json_reverse_iterator = ::nlohmann::detail::json_reverse_iterator; - - template - using output_adapter_t = ::nlohmann::detail::output_adapter_t; - - template - using binary_reader = ::nlohmann::detail::binary_reader; - template using binary_writer = ::nlohmann::detail::binary_writer; - - using serializer = ::nlohmann::detail::serializer; - - public: - using value_t = detail::value_t; - /// JSON Pointer, see @ref nlohmann::json_pointer - using json_pointer = ::nlohmann::json_pointer; - template - using json_serializer = JSONSerializer; - /// how to treat decoding errors - using error_handler_t = detail::error_handler_t; - /// how to treat CBOR tags - using cbor_tag_handler_t = detail::cbor_tag_handler_t; - /// helper type for initializer lists of basic_json values - using initializer_list_t = std::initializer_list>; - - using input_format_t = detail::input_format_t; - /// SAX interface type, see @ref nlohmann::json_sax - using json_sax_t = json_sax; - - //////////////// - // exceptions // - //////////////// - - /// @name exceptions - /// Classes to implement user-defined exceptions. - /// @{ - - /// @copydoc detail::exception - using exception = detail::exception; - /// @copydoc detail::parse_error - using parse_error = detail::parse_error; - /// @copydoc detail::invalid_iterator - using invalid_iterator = detail::invalid_iterator; - /// @copydoc detail::type_error - using type_error = detail::type_error; - /// @copydoc detail::out_of_range - using out_of_range = detail::out_of_range; - /// @copydoc detail::other_error - using other_error = detail::other_error; - - /// @} - - - ///////////////////// - // container types // - ///////////////////// - - /// @name container types - /// The canonic container types to use @ref basic_json like any other STL - /// container. 
- /// @{ - - /// the type of elements in a basic_json container - using value_type = basic_json; - - /// the type of an element reference - using reference = value_type&; - /// the type of an element const reference - using const_reference = const value_type&; - - /// a type to represent differences between iterators - using difference_type = std::ptrdiff_t; - /// a type to represent container sizes - using size_type = std::size_t; - - /// the allocator type - using allocator_type = AllocatorType; - - /// the type of an element pointer - using pointer = typename std::allocator_traits::pointer; - /// the type of an element const pointer - using const_pointer = typename std::allocator_traits::const_pointer; - - /// an iterator for a basic_json container - using iterator = iter_impl; - /// a const iterator for a basic_json container - using const_iterator = iter_impl; - /// a reverse iterator for a basic_json container - using reverse_iterator = json_reverse_iterator; - /// a const reverse iterator for a basic_json container - using const_reverse_iterator = json_reverse_iterator; - - /// @} - - - /*! - @brief returns the allocator associated with the container - */ - static allocator_type get_allocator() - { - return allocator_type(); - } - - /*! - @brief returns version information on the library - - This function returns a JSON object with information about the library, - including the version number and information on the platform and compiler. - - @return JSON object holding version information - key | description - ----------- | --------------- - `compiler` | Information on the used compiler. It is an object with the following keys: `c++` (the used C++ standard), `family` (the compiler family; possible values are `clang`, `icc`, `gcc`, `ilecpp`, `msvc`, `pgcpp`, `sunpro`, and `unknown`), and `version` (the compiler version). - `copyright` | The copyright line for the library as string. - `name` | The name of the library as string. 
- `platform` | The used platform as string. Possible values are `win32`, `linux`, `apple`, `unix`, and `unknown`. - `url` | The URL of the project as string. - `version` | The version of the library. It is an object with the following keys: `major`, `minor`, and `patch` as defined by [Semantic Versioning](http://semver.org), and `string` (the version string). - - @liveexample{The following code shows an example output of the `meta()` - function.,meta} - - @exceptionsafety Strong guarantee: if an exception is thrown, there are no - changes to any JSON value. - - @complexity Constant. - - @since 2.1.0 - */ - JSON_HEDLEY_WARN_UNUSED_RESULT - static basic_json meta() - { - basic_json result; - - result["copyright"] = "(C) 2013-2020 Niels Lohmann"; - result["name"] = "JSON for Modern C++"; - result["url"] = "https://github.com/nlohmann/json"; - result["version"]["string"] = - std::to_string(NLOHMANN_JSON_VERSION_MAJOR) + "." + - std::to_string(NLOHMANN_JSON_VERSION_MINOR) + "." + - std::to_string(NLOHMANN_JSON_VERSION_PATCH); - result["version"]["major"] = NLOHMANN_JSON_VERSION_MAJOR; - result["version"]["minor"] = NLOHMANN_JSON_VERSION_MINOR; - result["version"]["patch"] = NLOHMANN_JSON_VERSION_PATCH; - -#ifdef _WIN32 - result["platform"] = "win32"; -#elif defined __linux__ - result["platform"] = "linux"; -#elif defined __APPLE__ - result["platform"] = "apple"; -#elif defined __unix__ - result["platform"] = "unix"; -#else - result["platform"] = "unknown"; -#endif - -#if defined(__ICC) || defined(__INTEL_COMPILER) - result["compiler"] = {{"family", "icc"}, {"version", __INTEL_COMPILER}}; -#elif defined(__clang__) - result["compiler"] = {{"family", "clang"}, {"version", __clang_version__}}; -#elif defined(__GNUC__) || defined(__GNUG__) - result["compiler"] = {{"family", "gcc"}, {"version", std::to_string(__GNUC__) + "." + std::to_string(__GNUC_MINOR__) + "." 
+ std::to_string(__GNUC_PATCHLEVEL__)}}; -#elif defined(__HP_cc) || defined(__HP_aCC) - result["compiler"] = "hp" -#elif defined(__IBMCPP__) - result["compiler"] = {{"family", "ilecpp"}, {"version", __IBMCPP__}}; -#elif defined(_MSC_VER) - result["compiler"] = {{"family", "msvc"}, {"version", _MSC_VER}}; -#elif defined(__PGI) - result["compiler"] = {{"family", "pgcpp"}, {"version", __PGI}}; -#elif defined(__SUNPRO_CC) - result["compiler"] = {{"family", "sunpro"}, {"version", __SUNPRO_CC}}; -#else - result["compiler"] = {{"family", "unknown"}, {"version", "unknown"}}; -#endif - -#ifdef __cplusplus - result["compiler"]["c++"] = std::to_string(__cplusplus); -#else - result["compiler"]["c++"] = "unknown"; -#endif - return result; - } - - - /////////////////////////// - // JSON value data types // - /////////////////////////// - - /// @name JSON value data types - /// The data types to store a JSON value. These types are derived from - /// the template arguments passed to class @ref basic_json. - /// @{ - -#if defined(JSON_HAS_CPP_14) - // Use transparent comparator if possible, combined with perfect forwarding - // on find() and count() calls prevents unnecessary string construction. - using object_comparator_t = std::less<>; -#else - using object_comparator_t = std::less; -#endif - - /*! - @brief a type for an object - - [RFC 7159](http://rfc7159.net/rfc7159) describes JSON objects as follows: - > An object is an unordered collection of zero or more name/value pairs, - > where a name is a string and a value is a string, number, boolean, null, - > object, or array. - - To store objects in C++, a type is defined by the template parameters - described below. - - @tparam ObjectType the container to store objects (e.g., `std::map` or - `std::unordered_map`) - @tparam StringType the type of the keys or names (e.g., `std::string`). - The comparison function `std::less` is used to order elements - inside the container. 
- @tparam AllocatorType the allocator to use for objects (e.g., - `std::allocator`) - - #### Default type - - With the default values for @a ObjectType (`std::map`), @a StringType - (`std::string`), and @a AllocatorType (`std::allocator`), the default - value for @a object_t is: - - @code {.cpp} - std::map< - std::string, // key_type - basic_json, // value_type - std::less, // key_compare - std::allocator> // allocator_type - > - @endcode - - #### Behavior - - The choice of @a object_t influences the behavior of the JSON class. With - the default type, objects have the following behavior: - - - When all names are unique, objects will be interoperable in the sense - that all software implementations receiving that object will agree on - the name-value mappings. - - When the names within an object are not unique, it is unspecified which - one of the values for a given key will be chosen. For instance, - `{"key": 2, "key": 1}` could be equal to either `{"key": 1}` or - `{"key": 2}`. - - Internally, name/value pairs are stored in lexicographical order of the - names. Objects will also be serialized (see @ref dump) in this order. - For instance, `{"b": 1, "a": 2}` and `{"a": 2, "b": 1}` will be stored - and serialized as `{"a": 2, "b": 1}`. - - When comparing objects, the order of the name/value pairs is irrelevant. - This makes objects interoperable in the sense that they will not be - affected by these differences. For instance, `{"b": 1, "a": 2}` and - `{"a": 2, "b": 1}` will be treated as equal. - - #### Limits - - [RFC 7159](http://rfc7159.net/rfc7159) specifies: - > An implementation may set limits on the maximum depth of nesting. - - In this class, the object's limit of nesting is not explicitly constrained. - However, a maximum depth of nesting may be introduced by the compiler or - runtime environment. A theoretical limit can be queried by calling the - @ref max_size function of a JSON object. 
- - #### Storage - - Objects are stored as pointers in a @ref basic_json type. That is, for any - access to object values, a pointer of type `object_t*` must be - dereferenced. - - @sa @ref array_t -- type for an array value - - @since version 1.0.0 - - @note The order name/value pairs are added to the object is *not* - preserved by the library. Therefore, iterating an object may return - name/value pairs in a different order than they were originally stored. In - fact, keys will be traversed in alphabetical order as `std::map` with - `std::less` is used by default. Please note this behavior conforms to [RFC - 7159](http://rfc7159.net/rfc7159), because any order implements the - specified "unordered" nature of JSON objects. - */ - using object_t = ObjectType>>; - - /*! - @brief a type for an array - - [RFC 7159](http://rfc7159.net/rfc7159) describes JSON arrays as follows: - > An array is an ordered sequence of zero or more values. - - To store objects in C++, a type is defined by the template parameters - explained below. - - @tparam ArrayType container type to store arrays (e.g., `std::vector` or - `std::list`) - @tparam AllocatorType allocator to use for arrays (e.g., `std::allocator`) - - #### Default type - - With the default values for @a ArrayType (`std::vector`) and @a - AllocatorType (`std::allocator`), the default value for @a array_t is: - - @code {.cpp} - std::vector< - basic_json, // value_type - std::allocator // allocator_type - > - @endcode - - #### Limits - - [RFC 7159](http://rfc7159.net/rfc7159) specifies: - > An implementation may set limits on the maximum depth of nesting. - - In this class, the array's limit of nesting is not explicitly constrained. - However, a maximum depth of nesting may be introduced by the compiler or - runtime environment. A theoretical limit can be queried by calling the - @ref max_size function of a JSON array. - - #### Storage - - Arrays are stored as pointers in a @ref basic_json type. 
That is, for any - access to array values, a pointer of type `array_t*` must be dereferenced. - - @sa @ref object_t -- type for an object value - - @since version 1.0.0 - */ - using array_t = ArrayType>; - - /*! - @brief a type for a string - - [RFC 7159](http://rfc7159.net/rfc7159) describes JSON strings as follows: - > A string is a sequence of zero or more Unicode characters. - - To store objects in C++, a type is defined by the template parameter - described below. Unicode values are split by the JSON class into - byte-sized characters during deserialization. - - @tparam StringType the container to store strings (e.g., `std::string`). - Note this container is used for keys/names in objects, see @ref object_t. - - #### Default type - - With the default values for @a StringType (`std::string`), the default - value for @a string_t is: - - @code {.cpp} - std::string - @endcode - - #### Encoding - - Strings are stored in UTF-8 encoding. Therefore, functions like - `std::string::size()` or `std::string::length()` return the number of - bytes in the string rather than the number of characters or glyphs. - - #### String comparison - - [RFC 7159](http://rfc7159.net/rfc7159) states: - > Software implementations are typically required to test names of object - > members for equality. Implementations that transform the textual - > representation into sequences of Unicode code units and then perform the - > comparison numerically, code unit by code unit, are interoperable in the - > sense that implementations will agree in all cases on equality or - > inequality of two strings. For example, implementations that compare - > strings with escaped characters unconverted may incorrectly find that - > `"a\\b"` and `"a\u005Cb"` are not equal. - - This implementation is interoperable as it does compare strings code unit - by code unit. - - #### Storage - - String values are stored as pointers in a @ref basic_json type. 
That is, - for any access to string values, a pointer of type `string_t*` must be - dereferenced. - - @since version 1.0.0 - */ - using string_t = StringType; - - /*! - @brief a type for a boolean - - [RFC 7159](http://rfc7159.net/rfc7159) implicitly describes a boolean as a - type which differentiates the two literals `true` and `false`. - - To store objects in C++, a type is defined by the template parameter @a - BooleanType which chooses the type to use. - - #### Default type - - With the default values for @a BooleanType (`bool`), the default value for - @a boolean_t is: - - @code {.cpp} - bool - @endcode - - #### Storage - - Boolean values are stored directly inside a @ref basic_json type. - - @since version 1.0.0 - */ - using boolean_t = BooleanType; - - /*! - @brief a type for a number (integer) - - [RFC 7159](http://rfc7159.net/rfc7159) describes numbers as follows: - > The representation of numbers is similar to that used in most - > programming languages. A number is represented in base 10 using decimal - > digits. It contains an integer component that may be prefixed with an - > optional minus sign, which may be followed by a fraction part and/or an - > exponent part. Leading zeros are not allowed. (...) Numeric values that - > cannot be represented in the grammar below (such as Infinity and NaN) - > are not permitted. - - This description includes both integer and floating-point numbers. - However, C++ allows more precise storage if it is known whether the number - is a signed integer, an unsigned integer or a floating-point number. - Therefore, three different types, @ref number_integer_t, @ref - number_unsigned_t and @ref number_float_t are used. - - To store integer numbers in C++, a type is defined by the template - parameter @a NumberIntegerType which chooses the type to use. 
- - #### Default type - - With the default values for @a NumberIntegerType (`int64_t`), the default - value for @a number_integer_t is: - - @code {.cpp} - int64_t - @endcode - - #### Default behavior - - - The restrictions about leading zeros is not enforced in C++. Instead, - leading zeros in integer literals lead to an interpretation as octal - number. Internally, the value will be stored as decimal number. For - instance, the C++ integer literal `010` will be serialized to `8`. - During deserialization, leading zeros yield an error. - - Not-a-number (NaN) values will be serialized to `null`. - - #### Limits - - [RFC 7159](http://rfc7159.net/rfc7159) specifies: - > An implementation may set limits on the range and precision of numbers. - - When the default type is used, the maximal integer number that can be - stored is `9223372036854775807` (INT64_MAX) and the minimal integer number - that can be stored is `-9223372036854775808` (INT64_MIN). Integer numbers - that are out of range will yield over/underflow when used in a - constructor. During deserialization, too large or small integer numbers - will be automatically be stored as @ref number_unsigned_t or @ref - number_float_t. - - [RFC 7159](http://rfc7159.net/rfc7159) further states: - > Note that when such software is used, numbers that are integers and are - > in the range \f$[-2^{53}+1, 2^{53}-1]\f$ are interoperable in the sense - > that implementations will agree exactly on their numeric values. - - As this range is a subrange of the exactly supported range [INT64_MIN, - INT64_MAX], this class's integer type is interoperable. - - #### Storage - - Integer number values are stored directly inside a @ref basic_json type. - - @sa @ref number_float_t -- type for number values (floating-point) - - @sa @ref number_unsigned_t -- type for number values (unsigned integer) - - @since version 1.0.0 - */ - using number_integer_t = NumberIntegerType; - - /*! 
- @brief a type for a number (unsigned) - - [RFC 7159](http://rfc7159.net/rfc7159) describes numbers as follows: - > The representation of numbers is similar to that used in most - > programming languages. A number is represented in base 10 using decimal - > digits. It contains an integer component that may be prefixed with an - > optional minus sign, which may be followed by a fraction part and/or an - > exponent part. Leading zeros are not allowed. (...) Numeric values that - > cannot be represented in the grammar below (such as Infinity and NaN) - > are not permitted. - - This description includes both integer and floating-point numbers. - However, C++ allows more precise storage if it is known whether the number - is a signed integer, an unsigned integer or a floating-point number. - Therefore, three different types, @ref number_integer_t, @ref - number_unsigned_t and @ref number_float_t are used. - - To store unsigned integer numbers in C++, a type is defined by the - template parameter @a NumberUnsignedType which chooses the type to use. - - #### Default type - - With the default values for @a NumberUnsignedType (`uint64_t`), the - default value for @a number_unsigned_t is: - - @code {.cpp} - uint64_t - @endcode - - #### Default behavior - - - The restrictions about leading zeros is not enforced in C++. Instead, - leading zeros in integer literals lead to an interpretation as octal - number. Internally, the value will be stored as decimal number. For - instance, the C++ integer literal `010` will be serialized to `8`. - During deserialization, leading zeros yield an error. - - Not-a-number (NaN) values will be serialized to `null`. - - #### Limits - - [RFC 7159](http://rfc7159.net/rfc7159) specifies: - > An implementation may set limits on the range and precision of numbers. - - When the default type is used, the maximal integer number that can be - stored is `18446744073709551615` (UINT64_MAX) and the minimal integer - number that can be stored is `0`. 
Integer numbers that are out of range - will yield over/underflow when used in a constructor. During - deserialization, too large or small integer numbers will be automatically - be stored as @ref number_integer_t or @ref number_float_t. - - [RFC 7159](http://rfc7159.net/rfc7159) further states: - > Note that when such software is used, numbers that are integers and are - > in the range \f$[-2^{53}+1, 2^{53}-1]\f$ are interoperable in the sense - > that implementations will agree exactly on their numeric values. - - As this range is a subrange (when considered in conjunction with the - number_integer_t type) of the exactly supported range [0, UINT64_MAX], - this class's integer type is interoperable. - - #### Storage - - Integer number values are stored directly inside a @ref basic_json type. - - @sa @ref number_float_t -- type for number values (floating-point) - @sa @ref number_integer_t -- type for number values (integer) - - @since version 2.0.0 - */ - using number_unsigned_t = NumberUnsignedType; - - /*! - @brief a type for a number (floating-point) - - [RFC 7159](http://rfc7159.net/rfc7159) describes numbers as follows: - > The representation of numbers is similar to that used in most - > programming languages. A number is represented in base 10 using decimal - > digits. It contains an integer component that may be prefixed with an - > optional minus sign, which may be followed by a fraction part and/or an - > exponent part. Leading zeros are not allowed. (...) Numeric values that - > cannot be represented in the grammar below (such as Infinity and NaN) - > are not permitted. - - This description includes both integer and floating-point numbers. - However, C++ allows more precise storage if it is known whether the number - is a signed integer, an unsigned integer or a floating-point number. - Therefore, three different types, @ref number_integer_t, @ref - number_unsigned_t and @ref number_float_t are used. 
- - To store floating-point numbers in C++, a type is defined by the template - parameter @a NumberFloatType which chooses the type to use. - - #### Default type - - With the default values for @a NumberFloatType (`double`), the default - value for @a number_float_t is: - - @code {.cpp} - double - @endcode - - #### Default behavior - - - The restrictions about leading zeros is not enforced in C++. Instead, - leading zeros in floating-point literals will be ignored. Internally, - the value will be stored as decimal number. For instance, the C++ - floating-point literal `01.2` will be serialized to `1.2`. During - deserialization, leading zeros yield an error. - - Not-a-number (NaN) values will be serialized to `null`. - - #### Limits - - [RFC 7159](http://rfc7159.net/rfc7159) states: - > This specification allows implementations to set limits on the range and - > precision of numbers accepted. Since software that implements IEEE - > 754-2008 binary64 (double precision) numbers is generally available and - > widely used, good interoperability can be achieved by implementations - > that expect no more precision or range than these provide, in the sense - > that implementations will approximate JSON numbers within the expected - > precision. - - This implementation does exactly follow this approach, as it uses double - precision floating-point numbers. Note values smaller than - `-1.79769313486232e+308` and values greater than `1.79769313486232e+308` - will be stored as NaN internally and be serialized to `null`. - - #### Storage - - Floating-point number values are stored directly inside a @ref basic_json - type. - - @sa @ref number_integer_t -- type for number values (integer) - - @sa @ref number_unsigned_t -- type for number values (unsigned integer) - - @since version 1.0.0 - */ - using number_float_t = NumberFloatType; - - /*! 
- @brief a type for a packed binary type - - This type is a type designed to carry binary data that appears in various - serialized formats, such as CBOR's Major Type 2, MessagePack's bin, and - BSON's generic binary subtype. This type is NOT a part of standard JSON and - exists solely for compatibility with these binary types. As such, it is - simply defined as an ordered sequence of zero or more byte values. - - Additionally, as an implementation detail, the subtype of the binary data is - carried around as a `std::uint8_t`, which is compatible with both of the - binary data formats that use binary subtyping, (though the specific - numbering is incompatible with each other, and it is up to the user to - translate between them). - - [CBOR's RFC 7049](https://tools.ietf.org/html/rfc7049) describes this type - as: - > Major type 2: a byte string. The string's length in bytes is represented - > following the rules for positive integers (major type 0). - - [MessagePack's documentation on the bin type - family](https://github.com/msgpack/msgpack/blob/master/spec.md#bin-format-family) - describes this type as: - > Bin format family stores an byte array in 2, 3, or 5 bytes of extra bytes - > in addition to the size of the byte array. - - [BSON's specifications](http://bsonspec.org/spec.html) describe several - binary types; however, this type is intended to represent the generic binary - type which has the description: - > Generic binary subtype - This is the most commonly used binary subtype and - > should be the 'default' for drivers and tools. - - None of these impose any limitations on the internal representation other - than the basic unit of storage be some type of array whose parts are - decomposable into bytes. - - The default representation of this binary format is a - `std::vector`, which is a very common way to represent a byte - array in modern C++. 
- - #### Default type - - The default values for @a BinaryType is `std::vector` - - #### Storage - - Binary Arrays are stored as pointers in a @ref basic_json type. That is, - for any access to array values, a pointer of the type `binary_t*` must be - dereferenced. - - #### Notes on subtypes - - - CBOR - - Binary values are represented as byte strings. No subtypes are - supported and will be ignored when CBOR is written. - - MessagePack - - If a subtype is given and the binary array contains exactly 1, 2, 4, 8, - or 16 elements, the fixext family (fixext1, fixext2, fixext4, fixext8) - is used. For other sizes, the ext family (ext8, ext16, ext32) is used. - The subtype is then added as singed 8-bit integer. - - If no subtype is given, the bin family (bin8, bin16, bin32) is used. - - BSON - - If a subtype is given, it is used and added as unsigned 8-bit integer. - - If no subtype is given, the generic binary subtype 0x00 is used. - - @sa @ref binary -- create a binary array - - @since version 3.8.0 - */ - using binary_t = nlohmann::byte_container_with_subtype; - /// @} - - private: - - /// helper for exception-safe object creation - template - JSON_HEDLEY_RETURNS_NON_NULL - static T* create(Args&& ... args) - { - AllocatorType alloc; - using AllocatorTraits = std::allocator_traits>; - - auto deleter = [&](T * object) - { - AllocatorTraits::deallocate(alloc, object, 1); - }; - std::unique_ptr object(AllocatorTraits::allocate(alloc, 1), deleter); - AllocatorTraits::construct(alloc, object.get(), std::forward(args)...); - JSON_ASSERT(object != nullptr); - return object.release(); - } - - //////////////////////// - // JSON value storage // - //////////////////////// - - /*! - @brief a JSON value - - The actual storage for a JSON value of the @ref basic_json class. This - union combines the different storage types for the JSON value types - defined in @ref value_t. 
- - JSON type | value_t type | used type - --------- | --------------- | ------------------------ - object | object | pointer to @ref object_t - array | array | pointer to @ref array_t - string | string | pointer to @ref string_t - boolean | boolean | @ref boolean_t - number | number_integer | @ref number_integer_t - number | number_unsigned | @ref number_unsigned_t - number | number_float | @ref number_float_t - binary | binary | pointer to @ref binary_t - null | null | *no value is stored* - - @note Variable-length types (objects, arrays, and strings) are stored as - pointers. The size of the union should not exceed 64 bits if the default - value types are used. - - @since version 1.0.0 - */ - union json_value - { - /// object (stored with pointer to save storage) - object_t* object; - /// array (stored with pointer to save storage) - array_t* array; - /// string (stored with pointer to save storage) - string_t* string; - /// binary (stored with pointer to save storage) - binary_t* binary; - /// boolean - boolean_t boolean; - /// number (integer) - number_integer_t number_integer; - /// number (unsigned integer) - number_unsigned_t number_unsigned; - /// number (floating-point) - number_float_t number_float; - - /// default constructor (for null values) - json_value() = default; - /// constructor for booleans - json_value(boolean_t v) noexcept : boolean(v) {} - /// constructor for numbers (integer) - json_value(number_integer_t v) noexcept : number_integer(v) {} - /// constructor for numbers (unsigned) - json_value(number_unsigned_t v) noexcept : number_unsigned(v) {} - /// constructor for numbers (floating-point) - json_value(number_float_t v) noexcept : number_float(v) {} - /// constructor for empty values of a given type - json_value(value_t t) - { - switch (t) - { - case value_t::object: - { - object = create(); - break; - } - - case value_t::array: - { - array = create(); - break; - } - - case value_t::string: - { - string = create(""); - break; - } - - case 
value_t::binary: - { - binary = create(); - break; - } - - case value_t::boolean: - { - boolean = boolean_t(false); - break; - } - - case value_t::number_integer: - { - number_integer = number_integer_t(0); - break; - } - - case value_t::number_unsigned: - { - number_unsigned = number_unsigned_t(0); - break; - } - - case value_t::number_float: - { - number_float = number_float_t(0.0); - break; - } - - case value_t::null: - { - object = nullptr; // silence warning, see #821 - break; - } - - default: - { - object = nullptr; // silence warning, see #821 - if (JSON_HEDLEY_UNLIKELY(t == value_t::null)) - { - JSON_THROW(other_error::create(500, "961c151d2e87f2686a955a9be24d316f1362bf21 3.9.1")); // LCOV_EXCL_LINE - } - break; - } - } - } - - /// constructor for strings - json_value(const string_t& value) - { - string = create(value); - } - - /// constructor for rvalue strings - json_value(string_t&& value) - { - string = create(std::move(value)); - } - - /// constructor for objects - json_value(const object_t& value) - { - object = create(value); - } - - /// constructor for rvalue objects - json_value(object_t&& value) - { - object = create(std::move(value)); - } - - /// constructor for arrays - json_value(const array_t& value) - { - array = create(value); - } - - /// constructor for rvalue arrays - json_value(array_t&& value) - { - array = create(std::move(value)); - } - - /// constructor for binary arrays - json_value(const typename binary_t::container_type& value) - { - binary = create(value); - } - - /// constructor for rvalue binary arrays - json_value(typename binary_t::container_type&& value) - { - binary = create(std::move(value)); - } - - /// constructor for binary arrays (internal type) - json_value(const binary_t& value) - { - binary = create(value); - } - - /// constructor for rvalue binary arrays (internal type) - json_value(binary_t&& value) - { - binary = create(std::move(value)); - } - - void destroy(value_t t) noexcept - { - // flatten the current 
json_value to a heap-allocated stack - std::vector stack; - - // move the top-level items to stack - if (t == value_t::array) - { - stack.reserve(array->size()); - std::move(array->begin(), array->end(), std::back_inserter(stack)); - } - else if (t == value_t::object) - { - stack.reserve(object->size()); - for (auto&& it : *object) - { - stack.push_back(std::move(it.second)); - } - } - - while (!stack.empty()) - { - // move the last item to local variable to be processed - basic_json current_item(std::move(stack.back())); - stack.pop_back(); - - // if current_item is array/object, move - // its children to the stack to be processed later - if (current_item.is_array()) - { - std::move(current_item.m_value.array->begin(), current_item.m_value.array->end(), - std::back_inserter(stack)); - - current_item.m_value.array->clear(); - } - else if (current_item.is_object()) - { - for (auto&& it : *current_item.m_value.object) - { - stack.push_back(std::move(it.second)); - } - - current_item.m_value.object->clear(); - } - - // it's now safe that current_item get destructed - // since it doesn't have any children - } - - switch (t) - { - case value_t::object: - { - AllocatorType alloc; - std::allocator_traits::destroy(alloc, object); - std::allocator_traits::deallocate(alloc, object, 1); - break; - } - - case value_t::array: - { - AllocatorType alloc; - std::allocator_traits::destroy(alloc, array); - std::allocator_traits::deallocate(alloc, array, 1); - break; - } - - case value_t::string: - { - AllocatorType alloc; - std::allocator_traits::destroy(alloc, string); - std::allocator_traits::deallocate(alloc, string, 1); - break; - } - - case value_t::binary: - { - AllocatorType alloc; - std::allocator_traits::destroy(alloc, binary); - std::allocator_traits::deallocate(alloc, binary, 1); - break; - } - - default: - { - break; - } - } - } - }; - - /*! - @brief checks the class invariants - - This function asserts the class invariants. 
It needs to be called at the - end of every constructor to make sure that created objects respect the - invariant. Furthermore, it has to be called each time the type of a JSON - value is changed, because the invariant expresses a relationship between - @a m_type and @a m_value. - */ - void assert_invariant() const noexcept - { - JSON_ASSERT(m_type != value_t::object || m_value.object != nullptr); - JSON_ASSERT(m_type != value_t::array || m_value.array != nullptr); - JSON_ASSERT(m_type != value_t::string || m_value.string != nullptr); - JSON_ASSERT(m_type != value_t::binary || m_value.binary != nullptr); - } - - public: - ////////////////////////// - // JSON parser callback // - ////////////////////////// - - /*! - @brief parser event types - - The parser callback distinguishes the following events: - - `object_start`: the parser read `{` and started to process a JSON object - - `key`: the parser read a key of a value in an object - - `object_end`: the parser read `}` and finished processing a JSON object - - `array_start`: the parser read `[` and started to process a JSON array - - `array_end`: the parser read `]` and finished processing a JSON array - - `value`: the parser finished reading a JSON value - - @image html callback_events.png "Example when certain parse events are triggered" - - @sa @ref parser_callback_t for more information and examples - */ - using parse_event_t = detail::parse_event_t; - - /*! - @brief per-element parser callback type - - With a parser callback function, the result of parsing a JSON text can be - influenced. When passed to @ref parse, it is called on certain events - (passed as @ref parse_event_t via parameter @a event) with a set recursion - depth @a depth and context JSON value @a parsed. The return value of the - callback function is a boolean indicating whether the element that emitted - the callback shall be kept or not. 
- - We distinguish six scenarios (determined by the event type) in which the - callback function can be called. The following table describes the values - of the parameters @a depth, @a event, and @a parsed. - - parameter @a event | description | parameter @a depth | parameter @a parsed - ------------------ | ----------- | ------------------ | ------------------- - parse_event_t::object_start | the parser read `{` and started to process a JSON object | depth of the parent of the JSON object | a JSON value with type discarded - parse_event_t::key | the parser read a key of a value in an object | depth of the currently parsed JSON object | a JSON string containing the key - parse_event_t::object_end | the parser read `}` and finished processing a JSON object | depth of the parent of the JSON object | the parsed JSON object - parse_event_t::array_start | the parser read `[` and started to process a JSON array | depth of the parent of the JSON array | a JSON value with type discarded - parse_event_t::array_end | the parser read `]` and finished processing a JSON array | depth of the parent of the JSON array | the parsed JSON array - parse_event_t::value | the parser finished reading a JSON value | depth of the value | the parsed JSON value - - @image html callback_events.png "Example when certain parse events are triggered" - - Discarding a value (i.e., returning `false`) has different effects - depending on the context in which function was called: - - - Discarded values in structured types are skipped. That is, the parser - will behave as if the discarded value was never read. - - In case a value outside a structured type is skipped, it is replaced - with `null`. This case happens if the top-level element is skipped. 
- - @param[in] depth the depth of the recursion during parsing - - @param[in] event an event of type parse_event_t indicating the context in - the callback function has been called - - @param[in,out] parsed the current intermediate parse result; note that - writing to this value has no effect for parse_event_t::key events - - @return Whether the JSON value which called the function during parsing - should be kept (`true`) or not (`false`). In the latter case, it is either - skipped completely or replaced by an empty discarded object. - - @sa @ref parse for examples - - @since version 1.0.0 - */ - using parser_callback_t = detail::parser_callback_t; - - ////////////////// - // constructors // - ////////////////// - - /// @name constructors and destructors - /// Constructors of class @ref basic_json, copy/move constructor, copy - /// assignment, static functions creating objects, and the destructor. - /// @{ - - /*! - @brief create an empty value with a given type - - Create an empty JSON value with a given type. The value will be default - initialized with an empty value which depends on the type: - - Value type | initial value - ----------- | ------------- - null | `null` - boolean | `false` - string | `""` - number | `0` - object | `{}` - array | `[]` - binary | empty array - - @param[in] v the type of the value to create - - @complexity Constant. - - @exceptionsafety Strong guarantee: if an exception is thrown, there are no - changes to any JSON value. - - @liveexample{The following code shows the constructor for different @ref - value_t values,basic_json__value_t} - - @sa @ref clear() -- restores the postcondition of this constructor - - @since version 1.0.0 - */ - basic_json(const value_t v) - : m_type(v), m_value(v) - { - assert_invariant(); - } - - /*! - @brief create a null object - - Create a `null` JSON value. It either takes a null pointer as parameter - (explicitly creating `null`) or no parameter (implicitly creating `null`). 
- The passed null pointer itself is not read -- it is only used to choose - the right constructor. - - @complexity Constant. - - @exceptionsafety No-throw guarantee: this constructor never throws - exceptions. - - @liveexample{The following code shows the constructor with and without a - null pointer parameter.,basic_json__nullptr_t} - - @since version 1.0.0 - */ - basic_json(std::nullptr_t = nullptr) noexcept - : basic_json(value_t::null) - { - assert_invariant(); - } - - /*! - @brief create a JSON value - - This is a "catch all" constructor for all compatible JSON types; that is, - types for which a `to_json()` method exists. The constructor forwards the - parameter @a val to that method (to `json_serializer::to_json` method - with `U = uncvref_t`, to be exact). - - Template type @a CompatibleType includes, but is not limited to, the - following types: - - **arrays**: @ref array_t and all kinds of compatible containers such as - `std::vector`, `std::deque`, `std::list`, `std::forward_list`, - `std::array`, `std::valarray`, `std::set`, `std::unordered_set`, - `std::multiset`, and `std::unordered_multiset` with a `value_type` from - which a @ref basic_json value can be constructed. - - **objects**: @ref object_t and all kinds of compatible associative - containers such as `std::map`, `std::unordered_map`, `std::multimap`, - and `std::unordered_multimap` with a `key_type` compatible to - @ref string_t and a `value_type` from which a @ref basic_json value can - be constructed. - - **strings**: @ref string_t, string literals, and all compatible string - containers can be used. - - **numbers**: @ref number_integer_t, @ref number_unsigned_t, - @ref number_float_t, and all convertible number types such as `int`, - `size_t`, `int64_t`, `float` or `double` can be used. - - **boolean**: @ref boolean_t / `bool` can be used. 
- - **binary**: @ref binary_t / `std::vector` may be used, - unfortunately because string literals cannot be distinguished from binary - character arrays by the C++ type system, all types compatible with `const - char*` will be directed to the string constructor instead. This is both - for backwards compatibility, and due to the fact that a binary type is not - a standard JSON type. - - See the examples below. - - @tparam CompatibleType a type such that: - - @a CompatibleType is not derived from `std::istream`, - - @a CompatibleType is not @ref basic_json (to avoid hijacking copy/move - constructors), - - @a CompatibleType is not a different @ref basic_json type (i.e. with different template arguments) - - @a CompatibleType is not a @ref basic_json nested type (e.g., - @ref json_pointer, @ref iterator, etc ...) - - @ref @ref json_serializer has a - `to_json(basic_json_t&, CompatibleType&&)` method - - @tparam U = `uncvref_t` - - @param[in] val the value to be forwarded to the respective constructor - - @complexity Usually linear in the size of the passed @a val, also - depending on the implementation of the called `to_json()` - method. - - @exceptionsafety Depends on the called constructor. For types directly - supported by the library (i.e., all types for which no `to_json()` function - was provided), strong guarantee holds: if an exception is thrown, there are - no changes to any JSON value. - - @liveexample{The following code shows the constructor with several - compatible types.,basic_json__CompatibleType} - - @since version 2.1.0 - */ - template < typename CompatibleType, - typename U = detail::uncvref_t, - detail::enable_if_t < - !detail::is_basic_json::value && detail::is_compatible_type::value, int > = 0 > - basic_json(CompatibleType && val) noexcept(noexcept( - JSONSerializer::to_json(std::declval(), - std::forward(val)))) - { - JSONSerializer::to_json(*this, std::forward(val)); - assert_invariant(); - } - - /*! 
- @brief create a JSON value from an existing one - - This is a constructor for existing @ref basic_json types. - It does not hijack copy/move constructors, since the parameter has different - template arguments than the current ones. - - The constructor tries to convert the internal @ref m_value of the parameter. - - @tparam BasicJsonType a type such that: - - @a BasicJsonType is a @ref basic_json type. - - @a BasicJsonType has different template arguments than @ref basic_json_t. - - @param[in] val the @ref basic_json value to be converted. - - @complexity Usually linear in the size of the passed @a val, also - depending on the implementation of the called `to_json()` - method. - - @exceptionsafety Depends on the called constructor. For types directly - supported by the library (i.e., all types for which no `to_json()` function - was provided), strong guarantee holds: if an exception is thrown, there are - no changes to any JSON value. - - @since version 3.2.0 - */ - template < typename BasicJsonType, - detail::enable_if_t < - detail::is_basic_json::value&& !std::is_same::value, int > = 0 > - basic_json(const BasicJsonType& val) - { - using other_boolean_t = typename BasicJsonType::boolean_t; - using other_number_float_t = typename BasicJsonType::number_float_t; - using other_number_integer_t = typename BasicJsonType::number_integer_t; - using other_number_unsigned_t = typename BasicJsonType::number_unsigned_t; - using other_string_t = typename BasicJsonType::string_t; - using other_object_t = typename BasicJsonType::object_t; - using other_array_t = typename BasicJsonType::array_t; - using other_binary_t = typename BasicJsonType::binary_t; - - switch (val.type()) - { - case value_t::boolean: - JSONSerializer::to_json(*this, val.template get()); - break; - case value_t::number_float: - JSONSerializer::to_json(*this, val.template get()); - break; - case value_t::number_integer: - JSONSerializer::to_json(*this, val.template get()); - break; - case 
value_t::number_unsigned: - JSONSerializer::to_json(*this, val.template get()); - break; - case value_t::string: - JSONSerializer::to_json(*this, val.template get_ref()); - break; - case value_t::object: - JSONSerializer::to_json(*this, val.template get_ref()); - break; - case value_t::array: - JSONSerializer::to_json(*this, val.template get_ref()); - break; - case value_t::binary: - JSONSerializer::to_json(*this, val.template get_ref()); - break; - case value_t::null: - *this = nullptr; - break; - case value_t::discarded: - m_type = value_t::discarded; - break; - default: // LCOV_EXCL_LINE - JSON_ASSERT(false); // LCOV_EXCL_LINE - } - assert_invariant(); - } - - /*! - @brief create a container (array or object) from an initializer list - - Creates a JSON value of type array or object from the passed initializer - list @a init. In case @a type_deduction is `true` (default), the type of - the JSON value to be created is deducted from the initializer list @a init - according to the following rules: - - 1. If the list is empty, an empty JSON object value `{}` is created. - 2. If the list consists of pairs whose first element is a string, a JSON - object value is created where the first elements of the pairs are - treated as keys and the second elements are as values. - 3. In all other cases, an array is created. - - The rules aim to create the best fit between a C++ initializer list and - JSON values. The rationale is as follows: - - 1. The empty initializer list is written as `{}` which is exactly an empty - JSON object. - 2. C++ has no way of describing mapped types other than to list a list of - pairs. As JSON requires that keys must be of type string, rule 2 is the - weakest constraint one can pose on initializer lists to interpret them - as an object. - 3. In all other cases, the initializer list could not be interpreted as - JSON object type, so interpreting it as JSON array type is safe. 
- - With the rules described above, the following JSON values cannot be - expressed by an initializer list: - - - the empty array (`[]`): use @ref array(initializer_list_t) - with an empty initializer list in this case - - arrays whose elements satisfy rule 2: use @ref - array(initializer_list_t) with the same initializer list - in this case - - @note When used without parentheses around an empty initializer list, @ref - basic_json() is called instead of this function, yielding the JSON null - value. - - @param[in] init initializer list with JSON values - - @param[in] type_deduction internal parameter; when set to `true`, the type - of the JSON value is deducted from the initializer list @a init; when set - to `false`, the type provided via @a manual_type is forced. This mode is - used by the functions @ref array(initializer_list_t) and - @ref object(initializer_list_t). - - @param[in] manual_type internal parameter; when @a type_deduction is set - to `false`, the created JSON value will use the provided type (only @ref - value_t::array and @ref value_t::object are valid); when @a type_deduction - is set to `true`, this parameter has no effect - - @throw type_error.301 if @a type_deduction is `false`, @a manual_type is - `value_t::object`, but @a init contains an element which is not a pair - whose first element is a string. In this case, the constructor could not - create an object. If @a type_deduction would have be `true`, an array - would have been created. See @ref object(initializer_list_t) - for an example. - - @complexity Linear in the size of the initializer list @a init. - - @exceptionsafety Strong guarantee: if an exception is thrown, there are no - changes to any JSON value. 
- - @liveexample{The example below shows how JSON values are created from - initializer lists.,basic_json__list_init_t} - - @sa @ref array(initializer_list_t) -- create a JSON array - value from an initializer list - @sa @ref object(initializer_list_t) -- create a JSON object - value from an initializer list - - @since version 1.0.0 - */ - basic_json(initializer_list_t init, - bool type_deduction = true, - value_t manual_type = value_t::array) - { - // check if each element is an array with two elements whose first - // element is a string - bool is_an_object = std::all_of(init.begin(), init.end(), - [](const detail::json_ref& element_ref) - { - return element_ref->is_array() && element_ref->size() == 2 && (*element_ref)[0].is_string(); - }); - - // adjust type if type deduction is not wanted - if (!type_deduction) - { - // if array is wanted, do not create an object though possible - if (manual_type == value_t::array) - { - is_an_object = false; - } - - // if object is wanted but impossible, throw an exception - if (JSON_HEDLEY_UNLIKELY(manual_type == value_t::object && !is_an_object)) - { - JSON_THROW(type_error::create(301, "cannot create object from initializer list")); - } - } - - if (is_an_object) - { - // the initializer list is a list of pairs -> create object - m_type = value_t::object; - m_value = value_t::object; - - std::for_each(init.begin(), init.end(), [this](const detail::json_ref& element_ref) - { - auto element = element_ref.moved_or_copied(); - m_value.object->emplace( - std::move(*((*element.m_value.array)[0].m_value.string)), - std::move((*element.m_value.array)[1])); - }); - } - else - { - // the initializer list describes an array -> create array - m_type = value_t::array; - m_value.array = create(init.begin(), init.end()); - } - - assert_invariant(); - } - - /*! - @brief explicitly create a binary array (without subtype) - - Creates a JSON binary array value from a given binary container. 
Binary - values are part of various binary formats, such as CBOR, MessagePack, and - BSON. This constructor is used to create a value for serialization to those - formats. - - @note Note, this function exists because of the difficulty in correctly - specifying the correct template overload in the standard value ctor, as both - JSON arrays and JSON binary arrays are backed with some form of a - `std::vector`. Because JSON binary arrays are a non-standard extension it - was decided that it would be best to prevent automatic initialization of a - binary array type, for backwards compatibility and so it does not happen on - accident. - - @param[in] init container containing bytes to use as binary type - - @return JSON binary array value - - @complexity Linear in the size of @a init. - - @exceptionsafety Strong guarantee: if an exception is thrown, there are no - changes to any JSON value. - - @since version 3.8.0 - */ - JSON_HEDLEY_WARN_UNUSED_RESULT - static basic_json binary(const typename binary_t::container_type& init) - { - auto res = basic_json(); - res.m_type = value_t::binary; - res.m_value = init; - return res; - } - - /*! - @brief explicitly create a binary array (with subtype) - - Creates a JSON binary array value from a given binary container. Binary - values are part of various binary formats, such as CBOR, MessagePack, and - BSON. This constructor is used to create a value for serialization to those - formats. - - @note Note, this function exists because of the difficulty in correctly - specifying the correct template overload in the standard value ctor, as both - JSON arrays and JSON binary arrays are backed with some form of a - `std::vector`. Because JSON binary arrays are a non-standard extension it - was decided that it would be best to prevent automatic initialization of a - binary array type, for backwards compatibility and so it does not happen on - accident. 
- - @param[in] init container containing bytes to use as binary type - @param[in] subtype subtype to use in MessagePack and BSON - - @return JSON binary array value - - @complexity Linear in the size of @a init. - - @exceptionsafety Strong guarantee: if an exception is thrown, there are no - changes to any JSON value. - - @since version 3.8.0 - */ - JSON_HEDLEY_WARN_UNUSED_RESULT - static basic_json binary(const typename binary_t::container_type& init, std::uint8_t subtype) - { - auto res = basic_json(); - res.m_type = value_t::binary; - res.m_value = binary_t(init, subtype); - return res; - } - - /// @copydoc binary(const typename binary_t::container_type&) - JSON_HEDLEY_WARN_UNUSED_RESULT - static basic_json binary(typename binary_t::container_type&& init) - { - auto res = basic_json(); - res.m_type = value_t::binary; - res.m_value = std::move(init); - return res; - } - - /// @copydoc binary(const typename binary_t::container_type&, std::uint8_t) - JSON_HEDLEY_WARN_UNUSED_RESULT - static basic_json binary(typename binary_t::container_type&& init, std::uint8_t subtype) - { - auto res = basic_json(); - res.m_type = value_t::binary; - res.m_value = binary_t(std::move(init), subtype); - return res; - } - - /*! - @brief explicitly create an array from an initializer list - - Creates a JSON array value from a given initializer list. That is, given a - list of values `a, b, c`, creates the JSON value `[a, b, c]`. If the - initializer list is empty, the empty array `[]` is created. - - @note This function is only needed to express two edge cases that cannot - be realized with the initializer list constructor (@ref - basic_json(initializer_list_t, bool, value_t)). These cases - are: - 1. creating an array whose elements are all pairs whose first element is a - string -- in this case, the initializer list constructor would create an - object, taking the first elements as keys - 2. 
creating an empty array -- passing the empty initializer list to the - initializer list constructor yields an empty object - - @param[in] init initializer list with JSON values to create an array from - (optional) - - @return JSON array value - - @complexity Linear in the size of @a init. - - @exceptionsafety Strong guarantee: if an exception is thrown, there are no - changes to any JSON value. - - @liveexample{The following code shows an example for the `array` - function.,array} - - @sa @ref basic_json(initializer_list_t, bool, value_t) -- - create a JSON value from an initializer list - @sa @ref object(initializer_list_t) -- create a JSON object - value from an initializer list - - @since version 1.0.0 - */ - JSON_HEDLEY_WARN_UNUSED_RESULT - static basic_json array(initializer_list_t init = {}) - { - return basic_json(init, false, value_t::array); - } - - /*! - @brief explicitly create an object from an initializer list - - Creates a JSON object value from a given initializer list. The initializer - lists elements must be pairs, and their first elements must be strings. If - the initializer list is empty, the empty object `{}` is created. - - @note This function is only added for symmetry reasons. In contrast to the - related function @ref array(initializer_list_t), there are - no cases which can only be expressed by this function. That is, any - initializer list @a init can also be passed to the initializer list - constructor @ref basic_json(initializer_list_t, bool, value_t). - - @param[in] init initializer list to create an object from (optional) - - @return JSON object value - - @throw type_error.301 if @a init is not a list of pairs whose first - elements are strings. In this case, no object can be created. When such a - value is passed to @ref basic_json(initializer_list_t, bool, value_t), - an array would have been created from the passed initializer list @a init. - See example below. - - @complexity Linear in the size of @a init. 
- - @exceptionsafety Strong guarantee: if an exception is thrown, there are no - changes to any JSON value. - - @liveexample{The following code shows an example for the `object` - function.,object} - - @sa @ref basic_json(initializer_list_t, bool, value_t) -- - create a JSON value from an initializer list - @sa @ref array(initializer_list_t) -- create a JSON array - value from an initializer list - - @since version 1.0.0 - */ - JSON_HEDLEY_WARN_UNUSED_RESULT - static basic_json object(initializer_list_t init = {}) - { - return basic_json(init, false, value_t::object); - } - - /*! - @brief construct an array with count copies of given value - - Constructs a JSON array value by creating @a cnt copies of a passed value. - In case @a cnt is `0`, an empty array is created. - - @param[in] cnt the number of JSON copies of @a val to create - @param[in] val the JSON value to copy - - @post `std::distance(begin(),end()) == cnt` holds. - - @complexity Linear in @a cnt. - - @exceptionsafety Strong guarantee: if an exception is thrown, there are no - changes to any JSON value. - - @liveexample{The following code shows examples for the @ref - basic_json(size_type\, const basic_json&) - constructor.,basic_json__size_type_basic_json} - - @since version 1.0.0 - */ - basic_json(size_type cnt, const basic_json& val) - : m_type(value_t::array) - { - m_value.array = create(cnt, val); - assert_invariant(); - } - - /*! - @brief construct a JSON container given an iterator range - - Constructs the JSON value with the contents of the range `[first, last)`. - The semantics depends on the different types a JSON value can have: - - In case of a null type, invalid_iterator.206 is thrown. - - In case of other primitive types (number, boolean, or string), @a first - must be `begin()` and @a last must be `end()`. In this case, the value is - copied. Otherwise, invalid_iterator.204 is thrown. 
- - In case of structured types (array, object), the constructor behaves as - similar versions for `std::vector` or `std::map`; that is, a JSON array - or object is constructed from the values in the range. - - @tparam InputIT an input iterator type (@ref iterator or @ref - const_iterator) - - @param[in] first begin of the range to copy from (included) - @param[in] last end of the range to copy from (excluded) - - @pre Iterators @a first and @a last must be initialized. **This - precondition is enforced with an assertion (see warning).** If - assertions are switched off, a violation of this precondition yields - undefined behavior. - - @pre Range `[first, last)` is valid. Usually, this precondition cannot be - checked efficiently. Only certain edge cases are detected; see the - description of the exceptions below. A violation of this precondition - yields undefined behavior. - - @warning A precondition is enforced with a runtime assertion that will - result in calling `std::abort` if this precondition is not met. - Assertions can be disabled by defining `NDEBUG` at compile time. - See https://en.cppreference.com/w/cpp/error/assert for more - information. - - @throw invalid_iterator.201 if iterators @a first and @a last are not - compatible (i.e., do not belong to the same JSON value). In this case, - the range `[first, last)` is undefined. - @throw invalid_iterator.204 if iterators @a first and @a last belong to a - primitive type (number, boolean, or string), but @a first does not point - to the first element any more. In this case, the range `[first, last)` is - undefined. See example code below. - @throw invalid_iterator.206 if iterators @a first and @a last belong to a - null value. In this case, the range `[first, last)` is undefined. - - @complexity Linear in distance between @a first and @a last. - - @exceptionsafety Strong guarantee: if an exception is thrown, there are no - changes to any JSON value. 
- - @liveexample{The example below shows several ways to create JSON values by - specifying a subrange with iterators.,basic_json__InputIt_InputIt} - - @since version 1.0.0 - */ - template < class InputIT, typename std::enable_if < - std::is_same::value || - std::is_same::value, int >::type = 0 > - basic_json(InputIT first, InputIT last) - { - JSON_ASSERT(first.m_object != nullptr); - JSON_ASSERT(last.m_object != nullptr); - - // make sure iterator fits the current value - if (JSON_HEDLEY_UNLIKELY(first.m_object != last.m_object)) - { - JSON_THROW(invalid_iterator::create(201, "iterators are not compatible")); - } - - // copy type from first iterator - m_type = first.m_object->m_type; - - // check if iterator range is complete for primitive values - switch (m_type) - { - case value_t::boolean: - case value_t::number_float: - case value_t::number_integer: - case value_t::number_unsigned: - case value_t::string: - { - if (JSON_HEDLEY_UNLIKELY(!first.m_it.primitive_iterator.is_begin() - || !last.m_it.primitive_iterator.is_end())) - { - JSON_THROW(invalid_iterator::create(204, "iterators out of range")); - } - break; - } - - default: - break; - } - - switch (m_type) - { - case value_t::number_integer: - { - m_value.number_integer = first.m_object->m_value.number_integer; - break; - } - - case value_t::number_unsigned: - { - m_value.number_unsigned = first.m_object->m_value.number_unsigned; - break; - } - - case value_t::number_float: - { - m_value.number_float = first.m_object->m_value.number_float; - break; - } - - case value_t::boolean: - { - m_value.boolean = first.m_object->m_value.boolean; - break; - } - - case value_t::string: - { - m_value = *first.m_object->m_value.string; - break; - } - - case value_t::object: - { - m_value.object = create(first.m_it.object_iterator, - last.m_it.object_iterator); - break; - } - - case value_t::array: - { - m_value.array = create(first.m_it.array_iterator, - last.m_it.array_iterator); - break; - } - - case value_t::binary: - { 
- m_value = *first.m_object->m_value.binary; - break; - } - - default: - JSON_THROW(invalid_iterator::create(206, "cannot construct with iterators from " + - std::string(first.m_object->type_name()))); - } - - assert_invariant(); - } - - - /////////////////////////////////////// - // other constructors and destructor // - /////////////////////////////////////// - - template, - std::is_same>::value, int> = 0 > - basic_json(const JsonRef& ref) : basic_json(ref.moved_or_copied()) {} - - /*! - @brief copy constructor - - Creates a copy of a given JSON value. - - @param[in] other the JSON value to copy - - @post `*this == other` - - @complexity Linear in the size of @a other. - - @exceptionsafety Strong guarantee: if an exception is thrown, there are no - changes to any JSON value. - - @requirement This function helps `basic_json` satisfying the - [Container](https://en.cppreference.com/w/cpp/named_req/Container) - requirements: - - The complexity is linear. - - As postcondition, it holds: `other == basic_json(other)`. 
- - @liveexample{The following code shows an example for the copy - constructor.,basic_json__basic_json} - - @since version 1.0.0 - */ - basic_json(const basic_json& other) - : m_type(other.m_type) - { - // check of passed value is valid - other.assert_invariant(); - - switch (m_type) - { - case value_t::object: - { - m_value = *other.m_value.object; - break; - } - - case value_t::array: - { - m_value = *other.m_value.array; - break; - } - - case value_t::string: - { - m_value = *other.m_value.string; - break; - } - - case value_t::boolean: - { - m_value = other.m_value.boolean; - break; - } - - case value_t::number_integer: - { - m_value = other.m_value.number_integer; - break; - } - - case value_t::number_unsigned: - { - m_value = other.m_value.number_unsigned; - break; - } - - case value_t::number_float: - { - m_value = other.m_value.number_float; - break; - } - - case value_t::binary: - { - m_value = *other.m_value.binary; - break; - } - - default: - break; - } - - assert_invariant(); - } - - /*! - @brief move constructor - - Move constructor. Constructs a JSON value with the contents of the given - value @a other using move semantics. It "steals" the resources from @a - other and leaves it as JSON null value. - - @param[in,out] other value to move to this object - - @post `*this` has the same value as @a other before the call. - @post @a other is a JSON null value. - - @complexity Constant. - - @exceptionsafety No-throw guarantee: this constructor never throws - exceptions. - - @requirement This function helps `basic_json` satisfying the - [MoveConstructible](https://en.cppreference.com/w/cpp/named_req/MoveConstructible) - requirements. 
- - @liveexample{The code below shows the move constructor explicitly called - via std::move.,basic_json__moveconstructor} - - @since version 1.0.0 - */ - basic_json(basic_json&& other) noexcept - : m_type(std::move(other.m_type)), - m_value(std::move(other.m_value)) - { - // check that passed value is valid - other.assert_invariant(); - - // invalidate payload - other.m_type = value_t::null; - other.m_value = {}; - - assert_invariant(); - } - - /*! - @brief copy assignment - - Copy assignment operator. Copies a JSON value via the "copy and swap" - strategy: It is expressed in terms of the copy constructor, destructor, - and the `swap()` member function. - - @param[in] other value to copy from - - @complexity Linear. - - @requirement This function helps `basic_json` satisfying the - [Container](https://en.cppreference.com/w/cpp/named_req/Container) - requirements: - - The complexity is linear. - - @liveexample{The code below shows and example for the copy assignment. It - creates a copy of value `a` which is then swapped with `b`. Finally\, the - copy of `a` (which is the null value after the swap) is - destroyed.,basic_json__copyassignment} - - @since version 1.0.0 - */ - basic_json& operator=(basic_json other) noexcept ( - std::is_nothrow_move_constructible::value&& - std::is_nothrow_move_assignable::value&& - std::is_nothrow_move_constructible::value&& - std::is_nothrow_move_assignable::value - ) - { - // check that passed value is valid - other.assert_invariant(); - - using std::swap; - swap(m_type, other.m_type); - swap(m_value, other.m_value); - - assert_invariant(); - return *this; - } - - /*! - @brief destructor - - Destroys the JSON value and frees all allocated memory. - - @complexity Linear. - - @requirement This function helps `basic_json` satisfying the - [Container](https://en.cppreference.com/w/cpp/named_req/Container) - requirements: - - The complexity is linear. - - All stored elements are destroyed and all memory is freed. 
- - @since version 1.0.0 - */ - ~basic_json() noexcept - { - assert_invariant(); - m_value.destroy(m_type); - } - - /// @} - - public: - /////////////////////// - // object inspection // - /////////////////////// - - /// @name object inspection - /// Functions to inspect the type of a JSON value. - /// @{ - - /*! - @brief serialization - - Serialization function for JSON values. The function tries to mimic - Python's `json.dumps()` function, and currently supports its @a indent - and @a ensure_ascii parameters. - - @param[in] indent If indent is nonnegative, then array elements and object - members will be pretty-printed with that indent level. An indent level of - `0` will only insert newlines. `-1` (the default) selects the most compact - representation. - @param[in] indent_char The character to use for indentation if @a indent is - greater than `0`. The default is ` ` (space). - @param[in] ensure_ascii If @a ensure_ascii is true, all non-ASCII characters - in the output are escaped with `\uXXXX` sequences, and the result consists - of ASCII characters only. - @param[in] error_handler how to react on decoding errors; there are three - possible values: `strict` (throws and exception in case a decoding error - occurs; default), `replace` (replace invalid UTF-8 sequences with U+FFFD), - and `ignore` (ignore invalid UTF-8 sequences during serialization; all - bytes are copied to the output unchanged). - - @return string containing the serialization of the JSON value - - @throw type_error.316 if a string stored inside the JSON value is not - UTF-8 encoded and @a error_handler is set to strict - - @note Binary values are serialized as object containing two keys: - - "bytes": an array of bytes as integers - - "subtype": the subtype as integer or "null" if the binary has no subtype - - @complexity Linear. - - @exceptionsafety Strong guarantee: if an exception is thrown, there are no - changes in the JSON value. 
- - @liveexample{The following example shows the effect of different @a indent\, - @a indent_char\, and @a ensure_ascii parameters to the result of the - serialization.,dump} - - @see https://docs.python.org/2/library/json.html#json.dump - - @since version 1.0.0; indentation character @a indent_char, option - @a ensure_ascii and exceptions added in version 3.0.0; error - handlers added in version 3.4.0; serialization of binary values added - in version 3.8.0. - */ - string_t dump(const int indent = -1, - const char indent_char = ' ', - const bool ensure_ascii = false, - const error_handler_t error_handler = error_handler_t::strict) const - { - string_t result; - serializer s(detail::output_adapter(result), indent_char, error_handler); - - if (indent >= 0) - { - s.dump(*this, true, ensure_ascii, static_cast(indent)); - } - else - { - s.dump(*this, false, ensure_ascii, 0); - } - - return result; - } - - /*! - @brief return the type of the JSON value (explicit) - - Return the type of the JSON value as a value from the @ref value_t - enumeration. - - @return the type of the JSON value - Value type | return value - ------------------------- | ------------------------- - null | value_t::null - boolean | value_t::boolean - string | value_t::string - number (integer) | value_t::number_integer - number (unsigned integer) | value_t::number_unsigned - number (floating-point) | value_t::number_float - object | value_t::object - array | value_t::array - binary | value_t::binary - discarded | value_t::discarded - - @complexity Constant. - - @exceptionsafety No-throw guarantee: this member function never throws - exceptions. - - @liveexample{The following code exemplifies `type()` for all JSON - types.,type} - - @sa @ref operator value_t() -- return the type of the JSON value (implicit) - @sa @ref type_name() -- return the type as string - - @since version 1.0.0 - */ - constexpr value_t type() const noexcept - { - return m_type; - } - - /*! 
- @brief return whether type is primitive - - This function returns true if and only if the JSON type is primitive - (string, number, boolean, or null). - - @return `true` if type is primitive (string, number, boolean, or null), - `false` otherwise. - - @complexity Constant. - - @exceptionsafety No-throw guarantee: this member function never throws - exceptions. - - @liveexample{The following code exemplifies `is_primitive()` for all JSON - types.,is_primitive} - - @sa @ref is_structured() -- returns whether JSON value is structured - @sa @ref is_null() -- returns whether JSON value is `null` - @sa @ref is_string() -- returns whether JSON value is a string - @sa @ref is_boolean() -- returns whether JSON value is a boolean - @sa @ref is_number() -- returns whether JSON value is a number - @sa @ref is_binary() -- returns whether JSON value is a binary array - - @since version 1.0.0 - */ - constexpr bool is_primitive() const noexcept - { - return is_null() || is_string() || is_boolean() || is_number() || is_binary(); - } - - /*! - @brief return whether type is structured - - This function returns true if and only if the JSON type is structured - (array or object). - - @return `true` if type is structured (array or object), `false` otherwise. - - @complexity Constant. - - @exceptionsafety No-throw guarantee: this member function never throws - exceptions. - - @liveexample{The following code exemplifies `is_structured()` for all JSON - types.,is_structured} - - @sa @ref is_primitive() -- returns whether value is primitive - @sa @ref is_array() -- returns whether value is an array - @sa @ref is_object() -- returns whether value is an object - - @since version 1.0.0 - */ - constexpr bool is_structured() const noexcept - { - return is_array() || is_object(); - } - - /*! - @brief return whether value is null - - This function returns true if and only if the JSON value is null. - - @return `true` if type is null, `false` otherwise. - - @complexity Constant. 
- - @exceptionsafety No-throw guarantee: this member function never throws - exceptions. - - @liveexample{The following code exemplifies `is_null()` for all JSON - types.,is_null} - - @since version 1.0.0 - */ - constexpr bool is_null() const noexcept - { - return m_type == value_t::null; - } - - /*! - @brief return whether value is a boolean - - This function returns true if and only if the JSON value is a boolean. - - @return `true` if type is boolean, `false` otherwise. - - @complexity Constant. - - @exceptionsafety No-throw guarantee: this member function never throws - exceptions. - - @liveexample{The following code exemplifies `is_boolean()` for all JSON - types.,is_boolean} - - @since version 1.0.0 - */ - constexpr bool is_boolean() const noexcept - { - return m_type == value_t::boolean; - } - - /*! - @brief return whether value is a number - - This function returns true if and only if the JSON value is a number. This - includes both integer (signed and unsigned) and floating-point values. - - @return `true` if type is number (regardless whether integer, unsigned - integer or floating-type), `false` otherwise. - - @complexity Constant. - - @exceptionsafety No-throw guarantee: this member function never throws - exceptions. - - @liveexample{The following code exemplifies `is_number()` for all JSON - types.,is_number} - - @sa @ref is_number_integer() -- check if value is an integer or unsigned - integer number - @sa @ref is_number_unsigned() -- check if value is an unsigned integer - number - @sa @ref is_number_float() -- check if value is a floating-point number - - @since version 1.0.0 - */ - constexpr bool is_number() const noexcept - { - return is_number_integer() || is_number_float(); - } - - /*! - @brief return whether value is an integer number - - This function returns true if and only if the JSON value is a signed or - unsigned integer number. This excludes floating-point values. 
- - @return `true` if type is an integer or unsigned integer number, `false` - otherwise. - - @complexity Constant. - - @exceptionsafety No-throw guarantee: this member function never throws - exceptions. - - @liveexample{The following code exemplifies `is_number_integer()` for all - JSON types.,is_number_integer} - - @sa @ref is_number() -- check if value is a number - @sa @ref is_number_unsigned() -- check if value is an unsigned integer - number - @sa @ref is_number_float() -- check if value is a floating-point number - - @since version 1.0.0 - */ - constexpr bool is_number_integer() const noexcept - { - return m_type == value_t::number_integer || m_type == value_t::number_unsigned; - } - - /*! - @brief return whether value is an unsigned integer number - - This function returns true if and only if the JSON value is an unsigned - integer number. This excludes floating-point and signed integer values. - - @return `true` if type is an unsigned integer number, `false` otherwise. - - @complexity Constant. - - @exceptionsafety No-throw guarantee: this member function never throws - exceptions. - - @liveexample{The following code exemplifies `is_number_unsigned()` for all - JSON types.,is_number_unsigned} - - @sa @ref is_number() -- check if value is a number - @sa @ref is_number_integer() -- check if value is an integer or unsigned - integer number - @sa @ref is_number_float() -- check if value is a floating-point number - - @since version 2.0.0 - */ - constexpr bool is_number_unsigned() const noexcept - { - return m_type == value_t::number_unsigned; - } - - /*! - @brief return whether value is a floating-point number - - This function returns true if and only if the JSON value is a - floating-point number. This excludes signed and unsigned integer values. - - @return `true` if type is a floating-point number, `false` otherwise. - - @complexity Constant. - - @exceptionsafety No-throw guarantee: this member function never throws - exceptions. 
- - @liveexample{The following code exemplifies `is_number_float()` for all - JSON types.,is_number_float} - - @sa @ref is_number() -- check if value is number - @sa @ref is_number_integer() -- check if value is an integer number - @sa @ref is_number_unsigned() -- check if value is an unsigned integer - number - - @since version 1.0.0 - */ - constexpr bool is_number_float() const noexcept - { - return m_type == value_t::number_float; - } - - /*! - @brief return whether value is an object - - This function returns true if and only if the JSON value is an object. - - @return `true` if type is object, `false` otherwise. - - @complexity Constant. - - @exceptionsafety No-throw guarantee: this member function never throws - exceptions. - - @liveexample{The following code exemplifies `is_object()` for all JSON - types.,is_object} - - @since version 1.0.0 - */ - constexpr bool is_object() const noexcept - { - return m_type == value_t::object; - } - - /*! - @brief return whether value is an array - - This function returns true if and only if the JSON value is an array. - - @return `true` if type is array, `false` otherwise. - - @complexity Constant. - - @exceptionsafety No-throw guarantee: this member function never throws - exceptions. - - @liveexample{The following code exemplifies `is_array()` for all JSON - types.,is_array} - - @since version 1.0.0 - */ - constexpr bool is_array() const noexcept - { - return m_type == value_t::array; - } - - /*! - @brief return whether value is a string - - This function returns true if and only if the JSON value is a string. - - @return `true` if type is string, `false` otherwise. - - @complexity Constant. - - @exceptionsafety No-throw guarantee: this member function never throws - exceptions. - - @liveexample{The following code exemplifies `is_string()` for all JSON - types.,is_string} - - @since version 1.0.0 - */ - constexpr bool is_string() const noexcept - { - return m_type == value_t::string; - } - - /*! 
- @brief return whether value is a binary array - - This function returns true if and only if the JSON value is a binary array. - - @return `true` if type is binary array, `false` otherwise. - - @complexity Constant. - - @exceptionsafety No-throw guarantee: this member function never throws - exceptions. - - @liveexample{The following code exemplifies `is_binary()` for all JSON - types.,is_binary} - - @since version 3.8.0 - */ - constexpr bool is_binary() const noexcept - { - return m_type == value_t::binary; - } - - /*! - @brief return whether value is discarded - - This function returns true if and only if the JSON value was discarded - during parsing with a callback function (see @ref parser_callback_t). - - @note This function will always be `false` for JSON values after parsing. - That is, discarded values can only occur during parsing, but will be - removed when inside a structured value or replaced by null in other cases. - - @return `true` if type is discarded, `false` otherwise. - - @complexity Constant. - - @exceptionsafety No-throw guarantee: this member function never throws - exceptions. - - @liveexample{The following code exemplifies `is_discarded()` for all JSON - types.,is_discarded} - - @since version 1.0.0 - */ - constexpr bool is_discarded() const noexcept - { - return m_type == value_t::discarded; - } - - /*! - @brief return the type of the JSON value (implicit) - - Implicitly return the type of the JSON value as a value from the @ref - value_t enumeration. - - @return the type of the JSON value - - @complexity Constant. - - @exceptionsafety No-throw guarantee: this member function never throws - exceptions. 
- - @liveexample{The following code exemplifies the @ref value_t operator for - all JSON types.,operator__value_t} - - @sa @ref type() -- return the type of the JSON value (explicit) - @sa @ref type_name() -- return the type as string - - @since version 1.0.0 - */ - constexpr operator value_t() const noexcept - { - return m_type; - } - - /// @} - - private: - ////////////////// - // value access // - ////////////////// - - /// get a boolean (explicit) - boolean_t get_impl(boolean_t* /*unused*/) const - { - if (JSON_HEDLEY_LIKELY(is_boolean())) - { - return m_value.boolean; - } - - JSON_THROW(type_error::create(302, "type must be boolean, but is " + std::string(type_name()))); - } - - /// get a pointer to the value (object) - object_t* get_impl_ptr(object_t* /*unused*/) noexcept - { - return is_object() ? m_value.object : nullptr; - } - - /// get a pointer to the value (object) - constexpr const object_t* get_impl_ptr(const object_t* /*unused*/) const noexcept - { - return is_object() ? m_value.object : nullptr; - } - - /// get a pointer to the value (array) - array_t* get_impl_ptr(array_t* /*unused*/) noexcept - { - return is_array() ? m_value.array : nullptr; - } - - /// get a pointer to the value (array) - constexpr const array_t* get_impl_ptr(const array_t* /*unused*/) const noexcept - { - return is_array() ? m_value.array : nullptr; - } - - /// get a pointer to the value (string) - string_t* get_impl_ptr(string_t* /*unused*/) noexcept - { - return is_string() ? m_value.string : nullptr; - } - - /// get a pointer to the value (string) - constexpr const string_t* get_impl_ptr(const string_t* /*unused*/) const noexcept - { - return is_string() ? m_value.string : nullptr; - } - - /// get a pointer to the value (boolean) - boolean_t* get_impl_ptr(boolean_t* /*unused*/) noexcept - { - return is_boolean() ? 
&m_value.boolean : nullptr; - } - - /// get a pointer to the value (boolean) - constexpr const boolean_t* get_impl_ptr(const boolean_t* /*unused*/) const noexcept - { - return is_boolean() ? &m_value.boolean : nullptr; - } - - /// get a pointer to the value (integer number) - number_integer_t* get_impl_ptr(number_integer_t* /*unused*/) noexcept - { - return is_number_integer() ? &m_value.number_integer : nullptr; - } - - /// get a pointer to the value (integer number) - constexpr const number_integer_t* get_impl_ptr(const number_integer_t* /*unused*/) const noexcept - { - return is_number_integer() ? &m_value.number_integer : nullptr; - } - - /// get a pointer to the value (unsigned number) - number_unsigned_t* get_impl_ptr(number_unsigned_t* /*unused*/) noexcept - { - return is_number_unsigned() ? &m_value.number_unsigned : nullptr; - } - - /// get a pointer to the value (unsigned number) - constexpr const number_unsigned_t* get_impl_ptr(const number_unsigned_t* /*unused*/) const noexcept - { - return is_number_unsigned() ? &m_value.number_unsigned : nullptr; - } - - /// get a pointer to the value (floating-point number) - number_float_t* get_impl_ptr(number_float_t* /*unused*/) noexcept - { - return is_number_float() ? &m_value.number_float : nullptr; - } - - /// get a pointer to the value (floating-point number) - constexpr const number_float_t* get_impl_ptr(const number_float_t* /*unused*/) const noexcept - { - return is_number_float() ? &m_value.number_float : nullptr; - } - - /// get a pointer to the value (binary) - binary_t* get_impl_ptr(binary_t* /*unused*/) noexcept - { - return is_binary() ? m_value.binary : nullptr; - } - - /// get a pointer to the value (binary) - constexpr const binary_t* get_impl_ptr(const binary_t* /*unused*/) const noexcept - { - return is_binary() ? m_value.binary : nullptr; - } - - /*! 
- @brief helper function to implement get_ref() - - This function helps to implement get_ref() without code duplication for - const and non-const overloads - - @tparam ThisType will be deduced as `basic_json` or `const basic_json` - - @throw type_error.303 if ReferenceType does not match underlying value - type of the current JSON - */ - template - static ReferenceType get_ref_impl(ThisType& obj) - { - // delegate the call to get_ptr<>() - auto ptr = obj.template get_ptr::type>(); - - if (JSON_HEDLEY_LIKELY(ptr != nullptr)) - { - return *ptr; - } - - JSON_THROW(type_error::create(303, "incompatible ReferenceType for get_ref, actual type is " + std::string(obj.type_name()))); - } - - public: - /// @name value access - /// Direct access to the stored value of a JSON value. - /// @{ - - /*! - @brief get special-case overload - - This overloads avoids a lot of template boilerplate, it can be seen as the - identity method - - @tparam BasicJsonType == @ref basic_json - - @return a copy of *this - - @complexity Constant. - - @since version 2.1.0 - */ - template::type, basic_json_t>::value, - int> = 0> - basic_json get() const - { - return *this; - } - - /*! - @brief get special-case overload - - This overloads converts the current @ref basic_json in a different - @ref basic_json type - - @tparam BasicJsonType == @ref basic_json - - @return a copy of *this, converted into @tparam BasicJsonType - - @complexity Depending on the implementation of the called `from_json()` - method. - - @since version 3.2.0 - */ - template < typename BasicJsonType, detail::enable_if_t < - !std::is_same::value&& - detail::is_basic_json::value, int > = 0 > - BasicJsonType get() const - { - return *this; - } - - /*! 
- @brief get a value (explicit) - - Explicit type conversion between the JSON value and a compatible value - which is [CopyConstructible](https://en.cppreference.com/w/cpp/named_req/CopyConstructible) - and [DefaultConstructible](https://en.cppreference.com/w/cpp/named_req/DefaultConstructible). - The value is converted by calling the @ref json_serializer - `from_json()` method. - - The function is equivalent to executing - @code {.cpp} - ValueType ret; - JSONSerializer::from_json(*this, ret); - return ret; - @endcode - - This overloads is chosen if: - - @a ValueType is not @ref basic_json, - - @ref json_serializer has a `from_json()` method of the form - `void from_json(const basic_json&, ValueType&)`, and - - @ref json_serializer does not have a `from_json()` method of - the form `ValueType from_json(const basic_json&)` - - @tparam ValueTypeCV the provided value type - @tparam ValueType the returned value type - - @return copy of the JSON value, converted to @a ValueType - - @throw what @ref json_serializer `from_json()` method throws - - @liveexample{The example below shows several conversions from JSON values - to other types. 
There a few things to note: (1) Floating-point numbers can - be converted to integers\, (2) A JSON array can be converted to a standard - `std::vector`\, (3) A JSON object can be converted to C++ - associative containers such as `std::unordered_map`.,get__ValueType_const} - - @since version 2.1.0 - */ - template < typename ValueTypeCV, typename ValueType = detail::uncvref_t, - detail::enable_if_t < - !detail::is_basic_json::value && - detail::has_from_json::value && - !detail::has_non_default_from_json::value, - int > = 0 > - ValueType get() const noexcept(noexcept( - JSONSerializer::from_json(std::declval(), std::declval()))) - { - // we cannot static_assert on ValueTypeCV being non-const, because - // there is support for get(), which is why we - // still need the uncvref - static_assert(!std::is_reference::value, - "get() cannot be used with reference types, you might want to use get_ref()"); - static_assert(std::is_default_constructible::value, - "types must be DefaultConstructible when used with get()"); - - ValueType ret; - JSONSerializer::from_json(*this, ret); - return ret; - } - - /*! - @brief get a value (explicit); special case - - Explicit type conversion between the JSON value and a compatible value - which is **not** [CopyConstructible](https://en.cppreference.com/w/cpp/named_req/CopyConstructible) - and **not** [DefaultConstructible](https://en.cppreference.com/w/cpp/named_req/DefaultConstructible). - The value is converted by calling the @ref json_serializer - `from_json()` method. - - The function is equivalent to executing - @code {.cpp} - return JSONSerializer::from_json(*this); - @endcode - - This overloads is chosen if: - - @a ValueType is not @ref basic_json and - - @ref json_serializer has a `from_json()` method of the form - `ValueType from_json(const basic_json&)` - - @note If @ref json_serializer has both overloads of - `from_json()`, this one is chosen. 
- - @tparam ValueTypeCV the provided value type - @tparam ValueType the returned value type - - @return copy of the JSON value, converted to @a ValueType - - @throw what @ref json_serializer `from_json()` method throws - - @since version 2.1.0 - */ - template < typename ValueTypeCV, typename ValueType = detail::uncvref_t, - detail::enable_if_t < !std::is_same::value && - detail::has_non_default_from_json::value, - int > = 0 > - ValueType get() const noexcept(noexcept( - JSONSerializer::from_json(std::declval()))) - { - static_assert(!std::is_reference::value, - "get() cannot be used with reference types, you might want to use get_ref()"); - return JSONSerializer::from_json(*this); - } - - /*! - @brief get a value (explicit) - - Explicit type conversion between the JSON value and a compatible value. - The value is filled into the input parameter by calling the @ref json_serializer - `from_json()` method. - - The function is equivalent to executing - @code {.cpp} - ValueType v; - JSONSerializer::from_json(*this, v); - @endcode - - This overloads is chosen if: - - @a ValueType is not @ref basic_json, - - @ref json_serializer has a `from_json()` method of the form - `void from_json(const basic_json&, ValueType&)`, and - - @tparam ValueType the input parameter type. - - @return the input parameter, allowing chaining calls. - - @throw what @ref json_serializer `from_json()` method throws - - @liveexample{The example below shows several conversions from JSON values - to other types. 
There a few things to note: (1) Floating-point numbers can - be converted to integers\, (2) A JSON array can be converted to a standard - `std::vector`\, (3) A JSON object can be converted to C++ - associative containers such as `std::unordered_map`.,get_to} - - @since version 3.3.0 - */ - template < typename ValueType, - detail::enable_if_t < - !detail::is_basic_json::value&& - detail::has_from_json::value, - int > = 0 > - ValueType & get_to(ValueType& v) const noexcept(noexcept( - JSONSerializer::from_json(std::declval(), v))) - { - JSONSerializer::from_json(*this, v); - return v; - } - - // specialization to allow to call get_to with a basic_json value - // see https://github.com/nlohmann/json/issues/2175 - template::value, - int> = 0> - ValueType & get_to(ValueType& v) const - { - v = *this; - return v; - } - - template < - typename T, std::size_t N, - typename Array = T (&)[N], - detail::enable_if_t < - detail::has_from_json::value, int > = 0 > - Array get_to(T (&v)[N]) const - noexcept(noexcept(JSONSerializer::from_json( - std::declval(), v))) - { - JSONSerializer::from_json(*this, v); - return v; - } - - - /*! - @brief get a pointer value (implicit) - - Implicit pointer access to the internally stored JSON value. No copies are - made. - - @warning Writing data to the pointee of the result yields an undefined - state. - - @tparam PointerType pointer type; must be a pointer to @ref array_t, @ref - object_t, @ref string_t, @ref boolean_t, @ref number_integer_t, - @ref number_unsigned_t, or @ref number_float_t. Enforced by a static - assertion. - - @return pointer to the internally stored JSON value if the requested - pointer type @a PointerType fits to the JSON value; `nullptr` otherwise - - @complexity Constant. - - @liveexample{The example below shows how pointers to internal values of a - JSON value can be requested. 
Note that no type conversions are made and a - `nullptr` is returned if the value and the requested pointer type does not - match.,get_ptr} - - @since version 1.0.0 - */ - template::value, int>::type = 0> - auto get_ptr() noexcept -> decltype(std::declval().get_impl_ptr(std::declval())) - { - // delegate the call to get_impl_ptr<>() - return get_impl_ptr(static_cast(nullptr)); - } - - /*! - @brief get a pointer value (implicit) - @copydoc get_ptr() - */ - template < typename PointerType, typename std::enable_if < - std::is_pointer::value&& - std::is_const::type>::value, int >::type = 0 > - constexpr auto get_ptr() const noexcept -> decltype(std::declval().get_impl_ptr(std::declval())) - { - // delegate the call to get_impl_ptr<>() const - return get_impl_ptr(static_cast(nullptr)); - } - - /*! - @brief get a pointer value (explicit) - - Explicit pointer access to the internally stored JSON value. No copies are - made. - - @warning The pointer becomes invalid if the underlying JSON object - changes. - - @tparam PointerType pointer type; must be a pointer to @ref array_t, @ref - object_t, @ref string_t, @ref boolean_t, @ref number_integer_t, - @ref number_unsigned_t, or @ref number_float_t. - - @return pointer to the internally stored JSON value if the requested - pointer type @a PointerType fits to the JSON value; `nullptr` otherwise - - @complexity Constant. - - @liveexample{The example below shows how pointers to internal values of a - JSON value can be requested. Note that no type conversions are made and a - `nullptr` is returned if the value and the requested pointer type does not - match.,get__PointerType} - - @sa @ref get_ptr() for explicit pointer-member access - - @since version 1.0.0 - */ - template::value, int>::type = 0> - auto get() noexcept -> decltype(std::declval().template get_ptr()) - { - // delegate the call to get_ptr - return get_ptr(); - } - - /*! 
- @brief get a pointer value (explicit) - @copydoc get() - */ - template::value, int>::type = 0> - constexpr auto get() const noexcept -> decltype(std::declval().template get_ptr()) - { - // delegate the call to get_ptr - return get_ptr(); - } - - /*! - @brief get a reference value (implicit) - - Implicit reference access to the internally stored JSON value. No copies - are made. - - @warning Writing data to the referee of the result yields an undefined - state. - - @tparam ReferenceType reference type; must be a reference to @ref array_t, - @ref object_t, @ref string_t, @ref boolean_t, @ref number_integer_t, or - @ref number_float_t. Enforced by static assertion. - - @return reference to the internally stored JSON value if the requested - reference type @a ReferenceType fits to the JSON value; throws - type_error.303 otherwise - - @throw type_error.303 in case passed type @a ReferenceType is incompatible - with the stored JSON value; see example below - - @complexity Constant. - - @liveexample{The example shows several calls to `get_ref()`.,get_ref} - - @since version 1.1.0 - */ - template::value, int>::type = 0> - ReferenceType get_ref() - { - // delegate call to get_ref_impl - return get_ref_impl(*this); - } - - /*! - @brief get a reference value (implicit) - @copydoc get_ref() - */ - template < typename ReferenceType, typename std::enable_if < - std::is_reference::value&& - std::is_const::type>::value, int >::type = 0 > - ReferenceType get_ref() const - { - // delegate call to get_ref_impl - return get_ref_impl(*this); - } - - /*! - @brief get a value (implicit) - - Implicit type conversion between the JSON value and a compatible value. - The call is realized by calling @ref get() const. - - @tparam ValueType non-pointer type compatible to the JSON value, for - instance `int` for JSON integer numbers, `bool` for JSON booleans, or - `std::vector` types for JSON arrays. 
The character type of @ref string_t - as well as an initializer list of this type is excluded to avoid - ambiguities as these types implicitly convert to `std::string`. - - @return copy of the JSON value, converted to type @a ValueType - - @throw type_error.302 in case passed type @a ValueType is incompatible - to the JSON value type (e.g., the JSON value is of type boolean, but a - string is requested); see example below - - @complexity Linear in the size of the JSON value. - - @liveexample{The example below shows several conversions from JSON values - to other types. There a few things to note: (1) Floating-point numbers can - be converted to integers\, (2) A JSON array can be converted to a standard - `std::vector`\, (3) A JSON object can be converted to C++ - associative containers such as `std::unordered_map`.,operator__ValueType} - - @since version 1.0.0 - */ - template < typename ValueType, typename std::enable_if < - !std::is_pointer::value&& - !std::is_same>::value&& - !std::is_same::value&& - !detail::is_basic_json::value - && !std::is_same>::value -#if defined(JSON_HAS_CPP_17) && (defined(__GNUC__) || (defined(_MSC_VER) && _MSC_VER >= 1910 && _MSC_VER <= 1914)) - && !std::is_same::value -#endif - && detail::is_detected::value - , int >::type = 0 > - JSON_EXPLICIT operator ValueType() const - { - // delegate the call to get<>() const - return get(); - } - - /*! 
- @return reference to the binary value - - @throw type_error.302 if the value is not binary - - @sa @ref is_binary() to check if the value is binary - - @since version 3.8.0 - */ - binary_t& get_binary() - { - if (!is_binary()) - { - JSON_THROW(type_error::create(302, "type must be binary, but is " + std::string(type_name()))); - } - - return *get_ptr(); - } - - /// @copydoc get_binary() - const binary_t& get_binary() const - { - if (!is_binary()) - { - JSON_THROW(type_error::create(302, "type must be binary, but is " + std::string(type_name()))); - } - - return *get_ptr(); - } - - /// @} - - - //////////////////// - // element access // - //////////////////// - - /// @name element access - /// Access to the JSON value. - /// @{ - - /*! - @brief access specified array element with bounds checking - - Returns a reference to the element at specified location @a idx, with - bounds checking. - - @param[in] idx index of the element to access - - @return reference to the element at index @a idx - - @throw type_error.304 if the JSON value is not an array; in this case, - calling `at` with an index makes no sense. See example below. - @throw out_of_range.401 if the index @a idx is out of range of the array; - that is, `idx >= size()`. See example below. - - @exceptionsafety Strong guarantee: if an exception is thrown, there are no - changes in the JSON value. - - @complexity Constant. - - @since version 1.0.0 - - @liveexample{The example below shows how array elements can be read and - written using `at()`. 
It also demonstrates the different exceptions that - can be thrown.,at__size_type} - */ - reference at(size_type idx) - { - // at only works for arrays - if (JSON_HEDLEY_LIKELY(is_array())) - { - JSON_TRY - { - return m_value.array->at(idx); - } - JSON_CATCH (std::out_of_range&) - { - // create better exception explanation - JSON_THROW(out_of_range::create(401, "array index " + std::to_string(idx) + " is out of range")); - } - } - else - { - JSON_THROW(type_error::create(304, "cannot use at() with " + std::string(type_name()))); - } - } - - /*! - @brief access specified array element with bounds checking - - Returns a const reference to the element at specified location @a idx, - with bounds checking. - - @param[in] idx index of the element to access - - @return const reference to the element at index @a idx - - @throw type_error.304 if the JSON value is not an array; in this case, - calling `at` with an index makes no sense. See example below. - @throw out_of_range.401 if the index @a idx is out of range of the array; - that is, `idx >= size()`. See example below. - - @exceptionsafety Strong guarantee: if an exception is thrown, there are no - changes in the JSON value. - - @complexity Constant. - - @since version 1.0.0 - - @liveexample{The example below shows how array elements can be read using - `at()`. It also demonstrates the different exceptions that can be thrown., - at__size_type_const} - */ - const_reference at(size_type idx) const - { - // at only works for arrays - if (JSON_HEDLEY_LIKELY(is_array())) - { - JSON_TRY - { - return m_value.array->at(idx); - } - JSON_CATCH (std::out_of_range&) - { - // create better exception explanation - JSON_THROW(out_of_range::create(401, "array index " + std::to_string(idx) + " is out of range")); - } - } - else - { - JSON_THROW(type_error::create(304, "cannot use at() with " + std::string(type_name()))); - } - } - - /*! 
- @brief access specified object element with bounds checking - - Returns a reference to the element at with specified key @a key, with - bounds checking. - - @param[in] key key of the element to access - - @return reference to the element at key @a key - - @throw type_error.304 if the JSON value is not an object; in this case, - calling `at` with a key makes no sense. See example below. - @throw out_of_range.403 if the key @a key is is not stored in the object; - that is, `find(key) == end()`. See example below. - - @exceptionsafety Strong guarantee: if an exception is thrown, there are no - changes in the JSON value. - - @complexity Logarithmic in the size of the container. - - @sa @ref operator[](const typename object_t::key_type&) for unchecked - access by reference - @sa @ref value() for access by value with a default value - - @since version 1.0.0 - - @liveexample{The example below shows how object elements can be read and - written using `at()`. It also demonstrates the different exceptions that - can be thrown.,at__object_t_key_type} - */ - reference at(const typename object_t::key_type& key) - { - // at only works for objects - if (JSON_HEDLEY_LIKELY(is_object())) - { - JSON_TRY - { - return m_value.object->at(key); - } - JSON_CATCH (std::out_of_range&) - { - // create better exception explanation - JSON_THROW(out_of_range::create(403, "key '" + key + "' not found")); - } - } - else - { - JSON_THROW(type_error::create(304, "cannot use at() with " + std::string(type_name()))); - } - } - - /*! - @brief access specified object element with bounds checking - - Returns a const reference to the element at with specified key @a key, - with bounds checking. - - @param[in] key key of the element to access - - @return const reference to the element at key @a key - - @throw type_error.304 if the JSON value is not an object; in this case, - calling `at` with a key makes no sense. See example below. 
- @throw out_of_range.403 if the key @a key is is not stored in the object; - that is, `find(key) == end()`. See example below. - - @exceptionsafety Strong guarantee: if an exception is thrown, there are no - changes in the JSON value. - - @complexity Logarithmic in the size of the container. - - @sa @ref operator[](const typename object_t::key_type&) for unchecked - access by reference - @sa @ref value() for access by value with a default value - - @since version 1.0.0 - - @liveexample{The example below shows how object elements can be read using - `at()`. It also demonstrates the different exceptions that can be thrown., - at__object_t_key_type_const} - */ - const_reference at(const typename object_t::key_type& key) const - { - // at only works for objects - if (JSON_HEDLEY_LIKELY(is_object())) - { - JSON_TRY - { - return m_value.object->at(key); - } - JSON_CATCH (std::out_of_range&) - { - // create better exception explanation - JSON_THROW(out_of_range::create(403, "key '" + key + "' not found")); - } - } - else - { - JSON_THROW(type_error::create(304, "cannot use at() with " + std::string(type_name()))); - } - } - - /*! - @brief access specified array element - - Returns a reference to the element at specified location @a idx. - - @note If @a idx is beyond the range of the array (i.e., `idx >= size()`), - then the array is silently filled up with `null` values to make `idx` a - valid reference to the last stored element. - - @param[in] idx index of the element to access - - @return reference to the element at index @a idx - - @throw type_error.305 if the JSON value is not an array or null; in that - cases, using the [] operator with an index makes no sense. - - @complexity Constant if @a idx is in the range of the array. Otherwise - linear in `idx - size()`. - - @liveexample{The example below shows how array elements can be read and - written using `[]` operator. 
Note the addition of `null` - values.,operatorarray__size_type} - - @since version 1.0.0 - */ - reference operator[](size_type idx) - { - // implicitly convert null value to an empty array - if (is_null()) - { - m_type = value_t::array; - m_value.array = create(); - assert_invariant(); - } - - // operator[] only works for arrays - if (JSON_HEDLEY_LIKELY(is_array())) - { - // fill up array with null values if given idx is outside range - if (idx >= m_value.array->size()) - { - m_value.array->insert(m_value.array->end(), - idx - m_value.array->size() + 1, - basic_json()); - } - - return m_value.array->operator[](idx); - } - - JSON_THROW(type_error::create(305, "cannot use operator[] with a numeric argument with " + std::string(type_name()))); - } - - /*! - @brief access specified array element - - Returns a const reference to the element at specified location @a idx. - - @param[in] idx index of the element to access - - @return const reference to the element at index @a idx - - @throw type_error.305 if the JSON value is not an array; in that case, - using the [] operator with an index makes no sense. - - @complexity Constant. - - @liveexample{The example below shows how array elements can be read using - the `[]` operator.,operatorarray__size_type_const} - - @since version 1.0.0 - */ - const_reference operator[](size_type idx) const - { - // const operator[] only works for arrays - if (JSON_HEDLEY_LIKELY(is_array())) - { - return m_value.array->operator[](idx); - } - - JSON_THROW(type_error::create(305, "cannot use operator[] with a numeric argument with " + std::string(type_name()))); - } - - /*! - @brief access specified object element - - Returns a reference to the element at with specified key @a key. - - @note If @a key is not found in the object, then it is silently added to - the object and filled with a `null` value to make `key` a valid reference. - In case the value was `null` before, it is converted to an object. 
- - @param[in] key key of the element to access - - @return reference to the element at key @a key - - @throw type_error.305 if the JSON value is not an object or null; in that - cases, using the [] operator with a key makes no sense. - - @complexity Logarithmic in the size of the container. - - @liveexample{The example below shows how object elements can be read and - written using the `[]` operator.,operatorarray__key_type} - - @sa @ref at(const typename object_t::key_type&) for access by reference - with range checking - @sa @ref value() for access by value with a default value - - @since version 1.0.0 - */ - reference operator[](const typename object_t::key_type& key) - { - // implicitly convert null value to an empty object - if (is_null()) - { - m_type = value_t::object; - m_value.object = create(); - assert_invariant(); - } - - // operator[] only works for objects - if (JSON_HEDLEY_LIKELY(is_object())) - { - return m_value.object->operator[](key); - } - - JSON_THROW(type_error::create(305, "cannot use operator[] with a string argument with " + std::string(type_name()))); - } - - /*! - @brief read-only access specified object element - - Returns a const reference to the element at with specified key @a key. No - bounds checking is performed. - - @warning If the element with key @a key does not exist, the behavior is - undefined. - - @param[in] key key of the element to access - - @return const reference to the element at key @a key - - @pre The element with key @a key must exist. **This precondition is - enforced with an assertion.** - - @throw type_error.305 if the JSON value is not an object; in that case, - using the [] operator with a key makes no sense. - - @complexity Logarithmic in the size of the container. 
- - @liveexample{The example below shows how object elements can be read using - the `[]` operator.,operatorarray__key_type_const} - - @sa @ref at(const typename object_t::key_type&) for access by reference - with range checking - @sa @ref value() for access by value with a default value - - @since version 1.0.0 - */ - const_reference operator[](const typename object_t::key_type& key) const - { - // const operator[] only works for objects - if (JSON_HEDLEY_LIKELY(is_object())) - { - JSON_ASSERT(m_value.object->find(key) != m_value.object->end()); - return m_value.object->find(key)->second; - } - - JSON_THROW(type_error::create(305, "cannot use operator[] with a string argument with " + std::string(type_name()))); - } - - /*! - @brief access specified object element - - Returns a reference to the element at with specified key @a key. - - @note If @a key is not found in the object, then it is silently added to - the object and filled with a `null` value to make `key` a valid reference. - In case the value was `null` before, it is converted to an object. - - @param[in] key key of the element to access - - @return reference to the element at key @a key - - @throw type_error.305 if the JSON value is not an object or null; in that - cases, using the [] operator with a key makes no sense. - - @complexity Logarithmic in the size of the container. 
- - @liveexample{The example below shows how object elements can be read and - written using the `[]` operator.,operatorarray__key_type} - - @sa @ref at(const typename object_t::key_type&) for access by reference - with range checking - @sa @ref value() for access by value with a default value - - @since version 1.1.0 - */ - template - JSON_HEDLEY_NON_NULL(2) - reference operator[](T* key) - { - // implicitly convert null to object - if (is_null()) - { - m_type = value_t::object; - m_value = value_t::object; - assert_invariant(); - } - - // at only works for objects - if (JSON_HEDLEY_LIKELY(is_object())) - { - return m_value.object->operator[](key); - } - - JSON_THROW(type_error::create(305, "cannot use operator[] with a string argument with " + std::string(type_name()))); - } - - /*! - @brief read-only access specified object element - - Returns a const reference to the element at with specified key @a key. No - bounds checking is performed. - - @warning If the element with key @a key does not exist, the behavior is - undefined. - - @param[in] key key of the element to access - - @return const reference to the element at key @a key - - @pre The element with key @a key must exist. **This precondition is - enforced with an assertion.** - - @throw type_error.305 if the JSON value is not an object; in that case, - using the [] operator with a key makes no sense. - - @complexity Logarithmic in the size of the container. 
- - @liveexample{The example below shows how object elements can be read using - the `[]` operator.,operatorarray__key_type_const} - - @sa @ref at(const typename object_t::key_type&) for access by reference - with range checking - @sa @ref value() for access by value with a default value - - @since version 1.1.0 - */ - template - JSON_HEDLEY_NON_NULL(2) - const_reference operator[](T* key) const - { - // at only works for objects - if (JSON_HEDLEY_LIKELY(is_object())) - { - JSON_ASSERT(m_value.object->find(key) != m_value.object->end()); - return m_value.object->find(key)->second; - } - - JSON_THROW(type_error::create(305, "cannot use operator[] with a string argument with " + std::string(type_name()))); - } - - /*! - @brief access specified object element with default value - - Returns either a copy of an object's element at the specified key @a key - or a given default value if no element with key @a key exists. - - The function is basically equivalent to executing - @code {.cpp} - try { - return at(key); - } catch(out_of_range) { - return default_value; - } - @endcode - - @note Unlike @ref at(const typename object_t::key_type&), this function - does not throw if the given key @a key was not found. - - @note Unlike @ref operator[](const typename object_t::key_type& key), this - function does not implicitly add an element to the position defined by @a - key. This function is furthermore also applicable to const objects. - - @param[in] key key of the element to access - @param[in] default_value the value to return if @a key is not found - - @tparam ValueType type compatible to JSON values, for instance `int` for - JSON integer numbers, `bool` for JSON booleans, or `std::vector` types for - JSON arrays. Note the type of the expected value at @a key and the default - value @a default_value must be compatible. 
- - @return copy of the element at key @a key or @a default_value if @a key - is not found - - @throw type_error.302 if @a default_value does not match the type of the - value at @a key - @throw type_error.306 if the JSON value is not an object; in that case, - using `value()` with a key makes no sense. - - @complexity Logarithmic in the size of the container. - - @liveexample{The example below shows how object elements can be queried - with a default value.,basic_json__value} - - @sa @ref at(const typename object_t::key_type&) for access by reference - with range checking - @sa @ref operator[](const typename object_t::key_type&) for unchecked - access by reference - - @since version 1.0.0 - */ - // using std::is_convertible in a std::enable_if will fail when using explicit conversions - template < class ValueType, typename std::enable_if < - detail::is_getable::value - && !std::is_same::value, int >::type = 0 > - ValueType value(const typename object_t::key_type& key, const ValueType& default_value) const - { - // at only works for objects - if (JSON_HEDLEY_LIKELY(is_object())) - { - // if key is found, return value and given default value otherwise - const auto it = find(key); - if (it != end()) - { - return it->template get(); - } - - return default_value; - } - - JSON_THROW(type_error::create(306, "cannot use value() with " + std::string(type_name()))); - } - - /*! - @brief overload for a default value of type const char* - @copydoc basic_json::value(const typename object_t::key_type&, const ValueType&) const - */ - string_t value(const typename object_t::key_type& key, const char* default_value) const - { - return value(key, string_t(default_value)); - } - - /*! - @brief access specified object element via JSON Pointer with default value - - Returns either a copy of an object's element at the specified key @a key - or a given default value if no element with key @a key exists. 
- - The function is basically equivalent to executing - @code {.cpp} - try { - return at(ptr); - } catch(out_of_range) { - return default_value; - } - @endcode - - @note Unlike @ref at(const json_pointer&), this function does not throw - if the given key @a key was not found. - - @param[in] ptr a JSON pointer to the element to access - @param[in] default_value the value to return if @a ptr found no value - - @tparam ValueType type compatible to JSON values, for instance `int` for - JSON integer numbers, `bool` for JSON booleans, or `std::vector` types for - JSON arrays. Note the type of the expected value at @a key and the default - value @a default_value must be compatible. - - @return copy of the element at key @a key or @a default_value if @a key - is not found - - @throw type_error.302 if @a default_value does not match the type of the - value at @a ptr - @throw type_error.306 if the JSON value is not an object; in that case, - using `value()` with a key makes no sense. - - @complexity Logarithmic in the size of the container. - - @liveexample{The example below shows how object elements can be queried - with a default value.,basic_json__value_ptr} - - @sa @ref operator[](const json_pointer&) for unchecked access by reference - - @since version 2.0.2 - */ - template::value, int>::type = 0> - ValueType value(const json_pointer& ptr, const ValueType& default_value) const - { - // at only works for objects - if (JSON_HEDLEY_LIKELY(is_object())) - { - // if pointer resolves a value, return it or use default value - JSON_TRY - { - return ptr.get_checked(this).template get(); - } - JSON_INTERNAL_CATCH (out_of_range&) - { - return default_value; - } - } - - JSON_THROW(type_error::create(306, "cannot use value() with " + std::string(type_name()))); - } - - /*! 
- @brief overload for a default value of type const char* - @copydoc basic_json::value(const json_pointer&, ValueType) const - */ - JSON_HEDLEY_NON_NULL(3) - string_t value(const json_pointer& ptr, const char* default_value) const - { - return value(ptr, string_t(default_value)); - } - - /*! - @brief access the first element - - Returns a reference to the first element in the container. For a JSON - container `c`, the expression `c.front()` is equivalent to `*c.begin()`. - - @return In case of a structured type (array or object), a reference to the - first element is returned. In case of number, string, boolean, or binary - values, a reference to the value is returned. - - @complexity Constant. - - @pre The JSON value must not be `null` (would throw `std::out_of_range`) - or an empty array or object (undefined behavior, **guarded by - assertions**). - @post The JSON value remains unchanged. - - @throw invalid_iterator.214 when called on `null` value - - @liveexample{The following code shows an example for `front()`.,front} - - @sa @ref back() -- access the last element - - @since version 1.0.0 - */ - reference front() - { - return *begin(); - } - - /*! - @copydoc basic_json::front() - */ - const_reference front() const - { - return *cbegin(); - } - - /*! - @brief access the last element - - Returns a reference to the last element in the container. For a JSON - container `c`, the expression `c.back()` is equivalent to - @code {.cpp} - auto tmp = c.end(); - --tmp; - return *tmp; - @endcode - - @return In case of a structured type (array or object), a reference to the - last element is returned. In case of number, string, boolean, or binary - values, a reference to the value is returned. - - @complexity Constant. - - @pre The JSON value must not be `null` (would throw `std::out_of_range`) - or an empty array or object (undefined behavior, **guarded by - assertions**). - @post The JSON value remains unchanged. 
- - @throw invalid_iterator.214 when called on a `null` value. See example - below. - - @liveexample{The following code shows an example for `back()`.,back} - - @sa @ref front() -- access the first element - - @since version 1.0.0 - */ - reference back() - { - auto tmp = end(); - --tmp; - return *tmp; - } - - /*! - @copydoc basic_json::back() - */ - const_reference back() const - { - auto tmp = cend(); - --tmp; - return *tmp; - } - - /*! - @brief remove element given an iterator - - Removes the element specified by iterator @a pos. The iterator @a pos must - be valid and dereferenceable. Thus the `end()` iterator (which is valid, - but is not dereferenceable) cannot be used as a value for @a pos. - - If called on a primitive type other than `null`, the resulting JSON value - will be `null`. - - @param[in] pos iterator to the element to remove - @return Iterator following the last removed element. If the iterator @a - pos refers to the last element, the `end()` iterator is returned. - - @tparam IteratorType an @ref iterator or @ref const_iterator - - @post Invalidates iterators and references at or after the point of the - erase, including the `end()` iterator. 
- - @throw type_error.307 if called on a `null` value; example: `"cannot use - erase() with null"` - @throw invalid_iterator.202 if called on an iterator which does not belong - to the current JSON value; example: `"iterator does not fit current - value"` - @throw invalid_iterator.205 if called on a primitive type with invalid - iterator (i.e., any iterator which is not `begin()`); example: `"iterator - out of range"` - - @complexity The complexity depends on the type: - - objects: amortized constant - - arrays: linear in distance between @a pos and the end of the container - - strings and binary: linear in the length of the member - - other types: constant - - @liveexample{The example shows the result of `erase()` for different JSON - types.,erase__IteratorType} - - @sa @ref erase(IteratorType, IteratorType) -- removes the elements in - the given range - @sa @ref erase(const typename object_t::key_type&) -- removes the element - from an object at the given key - @sa @ref erase(const size_type) -- removes the element from an array at - the given index - - @since version 1.0.0 - */ - template < class IteratorType, typename std::enable_if < - std::is_same::value || - std::is_same::value, int >::type - = 0 > - IteratorType erase(IteratorType pos) - { - // make sure iterator fits the current value - if (JSON_HEDLEY_UNLIKELY(this != pos.m_object)) - { - JSON_THROW(invalid_iterator::create(202, "iterator does not fit current value")); - } - - IteratorType result = end(); - - switch (m_type) - { - case value_t::boolean: - case value_t::number_float: - case value_t::number_integer: - case value_t::number_unsigned: - case value_t::string: - case value_t::binary: - { - if (JSON_HEDLEY_UNLIKELY(!pos.m_it.primitive_iterator.is_begin())) - { - JSON_THROW(invalid_iterator::create(205, "iterator out of range")); - } - - if (is_string()) - { - AllocatorType alloc; - std::allocator_traits::destroy(alloc, m_value.string); - std::allocator_traits::deallocate(alloc, m_value.string, 
1); - m_value.string = nullptr; - } - else if (is_binary()) - { - AllocatorType alloc; - std::allocator_traits::destroy(alloc, m_value.binary); - std::allocator_traits::deallocate(alloc, m_value.binary, 1); - m_value.binary = nullptr; - } - - m_type = value_t::null; - assert_invariant(); - break; - } - - case value_t::object: - { - result.m_it.object_iterator = m_value.object->erase(pos.m_it.object_iterator); - break; - } - - case value_t::array: - { - result.m_it.array_iterator = m_value.array->erase(pos.m_it.array_iterator); - break; - } - - default: - JSON_THROW(type_error::create(307, "cannot use erase() with " + std::string(type_name()))); - } - - return result; - } - - /*! - @brief remove elements given an iterator range - - Removes the element specified by the range `[first; last)`. The iterator - @a first does not need to be dereferenceable if `first == last`: erasing - an empty range is a no-op. - - If called on a primitive type other than `null`, the resulting JSON value - will be `null`. - - @param[in] first iterator to the beginning of the range to remove - @param[in] last iterator past the end of the range to remove - @return Iterator following the last removed element. If the iterator @a - second refers to the last element, the `end()` iterator is returned. - - @tparam IteratorType an @ref iterator or @ref const_iterator - - @post Invalidates iterators and references at or after the point of the - erase, including the `end()` iterator. 
- - @throw type_error.307 if called on a `null` value; example: `"cannot use - erase() with null"` - @throw invalid_iterator.203 if called on iterators which does not belong - to the current JSON value; example: `"iterators do not fit current value"` - @throw invalid_iterator.204 if called on a primitive type with invalid - iterators (i.e., if `first != begin()` and `last != end()`); example: - `"iterators out of range"` - - @complexity The complexity depends on the type: - - objects: `log(size()) + std::distance(first, last)` - - arrays: linear in the distance between @a first and @a last, plus linear - in the distance between @a last and end of the container - - strings and binary: linear in the length of the member - - other types: constant - - @liveexample{The example shows the result of `erase()` for different JSON - types.,erase__IteratorType_IteratorType} - - @sa @ref erase(IteratorType) -- removes the element at a given position - @sa @ref erase(const typename object_t::key_type&) -- removes the element - from an object at the given key - @sa @ref erase(const size_type) -- removes the element from an array at - the given index - - @since version 1.0.0 - */ - template < class IteratorType, typename std::enable_if < - std::is_same::value || - std::is_same::value, int >::type - = 0 > - IteratorType erase(IteratorType first, IteratorType last) - { - // make sure iterator fits the current value - if (JSON_HEDLEY_UNLIKELY(this != first.m_object || this != last.m_object)) - { - JSON_THROW(invalid_iterator::create(203, "iterators do not fit current value")); - } - - IteratorType result = end(); - - switch (m_type) - { - case value_t::boolean: - case value_t::number_float: - case value_t::number_integer: - case value_t::number_unsigned: - case value_t::string: - case value_t::binary: - { - if (JSON_HEDLEY_LIKELY(!first.m_it.primitive_iterator.is_begin() - || !last.m_it.primitive_iterator.is_end())) - { - JSON_THROW(invalid_iterator::create(204, "iterators out of 
range")); - } - - if (is_string()) - { - AllocatorType alloc; - std::allocator_traits::destroy(alloc, m_value.string); - std::allocator_traits::deallocate(alloc, m_value.string, 1); - m_value.string = nullptr; - } - else if (is_binary()) - { - AllocatorType alloc; - std::allocator_traits::destroy(alloc, m_value.binary); - std::allocator_traits::deallocate(alloc, m_value.binary, 1); - m_value.binary = nullptr; - } - - m_type = value_t::null; - assert_invariant(); - break; - } - - case value_t::object: - { - result.m_it.object_iterator = m_value.object->erase(first.m_it.object_iterator, - last.m_it.object_iterator); - break; - } - - case value_t::array: - { - result.m_it.array_iterator = m_value.array->erase(first.m_it.array_iterator, - last.m_it.array_iterator); - break; - } - - default: - JSON_THROW(type_error::create(307, "cannot use erase() with " + std::string(type_name()))); - } - - return result; - } - - /*! - @brief remove element from a JSON object given a key - - Removes elements from a JSON object with the key value @a key. - - @param[in] key value of the elements to remove - - @return Number of elements removed. If @a ObjectType is the default - `std::map` type, the return value will always be `0` (@a key was not - found) or `1` (@a key was found). - - @post References and iterators to the erased elements are invalidated. - Other references and iterators are not affected. 
- - @throw type_error.307 when called on a type other than JSON object; - example: `"cannot use erase() with null"` - - @complexity `log(size()) + count(key)` - - @liveexample{The example shows the effect of `erase()`.,erase__key_type} - - @sa @ref erase(IteratorType) -- removes the element at a given position - @sa @ref erase(IteratorType, IteratorType) -- removes the elements in - the given range - @sa @ref erase(const size_type) -- removes the element from an array at - the given index - - @since version 1.0.0 - */ - size_type erase(const typename object_t::key_type& key) - { - // this erase only works for objects - if (JSON_HEDLEY_LIKELY(is_object())) - { - return m_value.object->erase(key); - } - - JSON_THROW(type_error::create(307, "cannot use erase() with " + std::string(type_name()))); - } - - /*! - @brief remove element from a JSON array given an index - - Removes element from a JSON array at the index @a idx. - - @param[in] idx index of the element to remove - - @throw type_error.307 when called on a type other than JSON object; - example: `"cannot use erase() with null"` - @throw out_of_range.401 when `idx >= size()`; example: `"array index 17 - is out of range"` - - @complexity Linear in distance between @a idx and the end of the container. 
- - @liveexample{The example shows the effect of `erase()`.,erase__size_type} - - @sa @ref erase(IteratorType) -- removes the element at a given position - @sa @ref erase(IteratorType, IteratorType) -- removes the elements in - the given range - @sa @ref erase(const typename object_t::key_type&) -- removes the element - from an object at the given key - - @since version 1.0.0 - */ - void erase(const size_type idx) - { - // this erase only works for arrays - if (JSON_HEDLEY_LIKELY(is_array())) - { - if (JSON_HEDLEY_UNLIKELY(idx >= size())) - { - JSON_THROW(out_of_range::create(401, "array index " + std::to_string(idx) + " is out of range")); - } - - m_value.array->erase(m_value.array->begin() + static_cast(idx)); - } - else - { - JSON_THROW(type_error::create(307, "cannot use erase() with " + std::string(type_name()))); - } - } - - /// @} - - - //////////// - // lookup // - //////////// - - /// @name lookup - /// @{ - - /*! - @brief find an element in a JSON object - - Finds an element in a JSON object with key equivalent to @a key. If the - element is not found or the JSON value is not an object, end() is - returned. - - @note This method always returns @ref end() when executed on a JSON type - that is not an object. - - @param[in] key key value of the element to search for. - - @return Iterator to an element with key equivalent to @a key. If no such - element is found or the JSON value is not an object, past-the-end (see - @ref end()) iterator is returned. - - @complexity Logarithmic in the size of the JSON object. - - @liveexample{The example shows how `find()` is used.,find__key_type} - - @sa @ref contains(KeyT&&) const -- checks whether a key exists - - @since version 1.0.0 - */ - template - iterator find(KeyT&& key) - { - auto result = end(); - - if (is_object()) - { - result.m_it.object_iterator = m_value.object->find(std::forward(key)); - } - - return result; - } - - /*! 
- @brief find an element in a JSON object - @copydoc find(KeyT&&) - */ - template - const_iterator find(KeyT&& key) const - { - auto result = cend(); - - if (is_object()) - { - result.m_it.object_iterator = m_value.object->find(std::forward(key)); - } - - return result; - } - - /*! - @brief returns the number of occurrences of a key in a JSON object - - Returns the number of elements with key @a key. If ObjectType is the - default `std::map` type, the return value will always be `0` (@a key was - not found) or `1` (@a key was found). - - @note This method always returns `0` when executed on a JSON type that is - not an object. - - @param[in] key key value of the element to count - - @return Number of elements with key @a key. If the JSON value is not an - object, the return value will be `0`. - - @complexity Logarithmic in the size of the JSON object. - - @liveexample{The example shows how `count()` is used.,count} - - @since version 1.0.0 - */ - template - size_type count(KeyT&& key) const - { - // return 0 for all nonobject types - return is_object() ? m_value.object->count(std::forward(key)) : 0; - } - - /*! - @brief check the existence of an element in a JSON object - - Check whether an element exists in a JSON object with key equivalent to - @a key. If the element is not found or the JSON value is not an object, - false is returned. - - @note This method always returns false when executed on a JSON type - that is not an object. - - @param[in] key key value to check its existence. - - @return true if an element with specified @a key exists. If no such - element with such key is found or the JSON value is not an object, - false is returned. - - @complexity Logarithmic in the size of the JSON object. 
- - @liveexample{The following code shows an example for `contains()`.,contains} - - @sa @ref find(KeyT&&) -- returns an iterator to an object element - @sa @ref contains(const json_pointer&) const -- checks the existence for a JSON pointer - - @since version 3.6.0 - */ - template < typename KeyT, typename std::enable_if < - !std::is_same::type, json_pointer>::value, int >::type = 0 > - bool contains(KeyT && key) const - { - return is_object() && m_value.object->find(std::forward(key)) != m_value.object->end(); - } - - /*! - @brief check the existence of an element in a JSON object given a JSON pointer - - Check whether the given JSON pointer @a ptr can be resolved in the current - JSON value. - - @note This method can be executed on any JSON value type. - - @param[in] ptr JSON pointer to check its existence. - - @return true if the JSON pointer can be resolved to a stored value, false - otherwise. - - @post If `j.contains(ptr)` returns true, it is safe to call `j[ptr]`. - - @throw parse_error.106 if an array index begins with '0' - @throw parse_error.109 if an array index was not a number - - @complexity Logarithmic in the size of the JSON object. - - @liveexample{The following code shows an example for `contains()`.,contains_json_pointer} - - @sa @ref contains(KeyT &&) const -- checks the existence of a key - - @since version 3.7.0 - */ - bool contains(const json_pointer& ptr) const - { - return ptr.contains(this); - } - - /// @} - - - /////////////// - // iterators // - /////////////// - - /// @name iterators - /// @{ - - /*! - @brief returns an iterator to the first element - - Returns an iterator to the first element. - - @image html range-begin-end.svg "Illustration from cppreference.com" - - @return iterator to the first element - - @complexity Constant. - - @requirement This function helps `basic_json` satisfying the - [Container](https://en.cppreference.com/w/cpp/named_req/Container) - requirements: - - The complexity is constant. 
- - @liveexample{The following code shows an example for `begin()`.,begin} - - @sa @ref cbegin() -- returns a const iterator to the beginning - @sa @ref end() -- returns an iterator to the end - @sa @ref cend() -- returns a const iterator to the end - - @since version 1.0.0 - */ - iterator begin() noexcept - { - iterator result(this); - result.set_begin(); - return result; - } - - /*! - @copydoc basic_json::cbegin() - */ - const_iterator begin() const noexcept - { - return cbegin(); - } - - /*! - @brief returns a const iterator to the first element - - Returns a const iterator to the first element. - - @image html range-begin-end.svg "Illustration from cppreference.com" - - @return const iterator to the first element - - @complexity Constant. - - @requirement This function helps `basic_json` satisfying the - [Container](https://en.cppreference.com/w/cpp/named_req/Container) - requirements: - - The complexity is constant. - - Has the semantics of `const_cast(*this).begin()`. - - @liveexample{The following code shows an example for `cbegin()`.,cbegin} - - @sa @ref begin() -- returns an iterator to the beginning - @sa @ref end() -- returns an iterator to the end - @sa @ref cend() -- returns a const iterator to the end - - @since version 1.0.0 - */ - const_iterator cbegin() const noexcept - { - const_iterator result(this); - result.set_begin(); - return result; - } - - /*! - @brief returns an iterator to one past the last element - - Returns an iterator to one past the last element. - - @image html range-begin-end.svg "Illustration from cppreference.com" - - @return iterator one past the last element - - @complexity Constant. - - @requirement This function helps `basic_json` satisfying the - [Container](https://en.cppreference.com/w/cpp/named_req/Container) - requirements: - - The complexity is constant. 
- - @liveexample{The following code shows an example for `end()`.,end} - - @sa @ref cend() -- returns a const iterator to the end - @sa @ref begin() -- returns an iterator to the beginning - @sa @ref cbegin() -- returns a const iterator to the beginning - - @since version 1.0.0 - */ - iterator end() noexcept - { - iterator result(this); - result.set_end(); - return result; - } - - /*! - @copydoc basic_json::cend() - */ - const_iterator end() const noexcept - { - return cend(); - } - - /*! - @brief returns a const iterator to one past the last element - - Returns a const iterator to one past the last element. - - @image html range-begin-end.svg "Illustration from cppreference.com" - - @return const iterator one past the last element - - @complexity Constant. - - @requirement This function helps `basic_json` satisfying the - [Container](https://en.cppreference.com/w/cpp/named_req/Container) - requirements: - - The complexity is constant. - - Has the semantics of `const_cast(*this).end()`. - - @liveexample{The following code shows an example for `cend()`.,cend} - - @sa @ref end() -- returns an iterator to the end - @sa @ref begin() -- returns an iterator to the beginning - @sa @ref cbegin() -- returns a const iterator to the beginning - - @since version 1.0.0 - */ - const_iterator cend() const noexcept - { - const_iterator result(this); - result.set_end(); - return result; - } - - /*! - @brief returns an iterator to the reverse-beginning - - Returns an iterator to the reverse-beginning; that is, the last element. - - @image html range-rbegin-rend.svg "Illustration from cppreference.com" - - @complexity Constant. - - @requirement This function helps `basic_json` satisfying the - [ReversibleContainer](https://en.cppreference.com/w/cpp/named_req/ReversibleContainer) - requirements: - - The complexity is constant. - - Has the semantics of `reverse_iterator(end())`. 
- - @liveexample{The following code shows an example for `rbegin()`.,rbegin} - - @sa @ref crbegin() -- returns a const reverse iterator to the beginning - @sa @ref rend() -- returns a reverse iterator to the end - @sa @ref crend() -- returns a const reverse iterator to the end - - @since version 1.0.0 - */ - reverse_iterator rbegin() noexcept - { - return reverse_iterator(end()); - } - - /*! - @copydoc basic_json::crbegin() - */ - const_reverse_iterator rbegin() const noexcept - { - return crbegin(); - } - - /*! - @brief returns an iterator to the reverse-end - - Returns an iterator to the reverse-end; that is, one before the first - element. - - @image html range-rbegin-rend.svg "Illustration from cppreference.com" - - @complexity Constant. - - @requirement This function helps `basic_json` satisfying the - [ReversibleContainer](https://en.cppreference.com/w/cpp/named_req/ReversibleContainer) - requirements: - - The complexity is constant. - - Has the semantics of `reverse_iterator(begin())`. - - @liveexample{The following code shows an example for `rend()`.,rend} - - @sa @ref crend() -- returns a const reverse iterator to the end - @sa @ref rbegin() -- returns a reverse iterator to the beginning - @sa @ref crbegin() -- returns a const reverse iterator to the beginning - - @since version 1.0.0 - */ - reverse_iterator rend() noexcept - { - return reverse_iterator(begin()); - } - - /*! - @copydoc basic_json::crend() - */ - const_reverse_iterator rend() const noexcept - { - return crend(); - } - - /*! - @brief returns a const reverse iterator to the last element - - Returns a const iterator to the reverse-beginning; that is, the last - element. - - @image html range-rbegin-rend.svg "Illustration from cppreference.com" - - @complexity Constant. - - @requirement This function helps `basic_json` satisfying the - [ReversibleContainer](https://en.cppreference.com/w/cpp/named_req/ReversibleContainer) - requirements: - - The complexity is constant. 
- - Has the semantics of `const_cast(*this).rbegin()`. - - @liveexample{The following code shows an example for `crbegin()`.,crbegin} - - @sa @ref rbegin() -- returns a reverse iterator to the beginning - @sa @ref rend() -- returns a reverse iterator to the end - @sa @ref crend() -- returns a const reverse iterator to the end - - @since version 1.0.0 - */ - const_reverse_iterator crbegin() const noexcept - { - return const_reverse_iterator(cend()); - } - - /*! - @brief returns a const reverse iterator to one before the first - - Returns a const reverse iterator to the reverse-end; that is, one before - the first element. - - @image html range-rbegin-rend.svg "Illustration from cppreference.com" - - @complexity Constant. - - @requirement This function helps `basic_json` satisfying the - [ReversibleContainer](https://en.cppreference.com/w/cpp/named_req/ReversibleContainer) - requirements: - - The complexity is constant. - - Has the semantics of `const_cast(*this).rend()`. - - @liveexample{The following code shows an example for `crend()`.,crend} - - @sa @ref rend() -- returns a reverse iterator to the end - @sa @ref rbegin() -- returns a reverse iterator to the beginning - @sa @ref crbegin() -- returns a const reverse iterator to the beginning - - @since version 1.0.0 - */ - const_reverse_iterator crend() const noexcept - { - return const_reverse_iterator(cbegin()); - } - - public: - /*! - @brief wrapper to access iterator member functions in range-based for - - This function allows to access @ref iterator::key() and @ref - iterator::value() during range-based for loops. In these loops, a - reference to the JSON values is returned, so there is no access to the - underlying iterator. 
- - For loop without iterator_wrapper: - - @code{cpp} - for (auto it = j_object.begin(); it != j_object.end(); ++it) - { - std::cout << "key: " << it.key() << ", value:" << it.value() << '\n'; - } - @endcode - - Range-based for loop without iterator proxy: - - @code{cpp} - for (auto it : j_object) - { - // "it" is of type json::reference and has no key() member - std::cout << "value: " << it << '\n'; - } - @endcode - - Range-based for loop with iterator proxy: - - @code{cpp} - for (auto it : json::iterator_wrapper(j_object)) - { - std::cout << "key: " << it.key() << ", value:" << it.value() << '\n'; - } - @endcode - - @note When iterating over an array, `key()` will return the index of the - element as string (see example). - - @param[in] ref reference to a JSON value - @return iteration proxy object wrapping @a ref with an interface to use in - range-based for loops - - @liveexample{The following code shows how the wrapper is used,iterator_wrapper} - - @exceptionsafety Strong guarantee: if an exception is thrown, there are no - changes in the JSON value. - - @complexity Constant. - - @note The name of this function is not yet final and may change in the - future. - - @deprecated This stream operator is deprecated and will be removed in - future 4.0.0 of the library. Please use @ref items() instead; - that is, replace `json::iterator_wrapper(j)` with `j.items()`. - */ - JSON_HEDLEY_DEPRECATED_FOR(3.1.0, items()) - static iteration_proxy iterator_wrapper(reference ref) noexcept - { - return ref.items(); - } - - /*! - @copydoc iterator_wrapper(reference) - */ - JSON_HEDLEY_DEPRECATED_FOR(3.1.0, items()) - static iteration_proxy iterator_wrapper(const_reference ref) noexcept - { - return ref.items(); - } - - /*! - @brief helper to access iterator member functions in range-based for - - This function allows to access @ref iterator::key() and @ref - iterator::value() during range-based for loops. 
In these loops, a - reference to the JSON values is returned, so there is no access to the - underlying iterator. - - For loop without `items()` function: - - @code{cpp} - for (auto it = j_object.begin(); it != j_object.end(); ++it) - { - std::cout << "key: " << it.key() << ", value:" << it.value() << '\n'; - } - @endcode - - Range-based for loop without `items()` function: - - @code{cpp} - for (auto it : j_object) - { - // "it" is of type json::reference and has no key() member - std::cout << "value: " << it << '\n'; - } - @endcode - - Range-based for loop with `items()` function: - - @code{cpp} - for (auto& el : j_object.items()) - { - std::cout << "key: " << el.key() << ", value:" << el.value() << '\n'; - } - @endcode - - The `items()` function also allows to use - [structured bindings](https://en.cppreference.com/w/cpp/language/structured_binding) - (C++17): - - @code{cpp} - for (auto& [key, val] : j_object.items()) - { - std::cout << "key: " << key << ", value:" << val << '\n'; - } - @endcode - - @note When iterating over an array, `key()` will return the index of the - element as string (see example). For primitive types (e.g., numbers), - `key()` returns an empty string. - - @warning Using `items()` on temporary objects is dangerous. Make sure the - object's lifetime exeeds the iteration. See - for more - information. - - @return iteration proxy object wrapping @a ref with an interface to use in - range-based for loops - - @liveexample{The following code shows how the function is used.,items} - - @exceptionsafety Strong guarantee: if an exception is thrown, there are no - changes in the JSON value. - - @complexity Constant. - - @since version 3.1.0, structured bindings support since 3.5.0. - */ - iteration_proxy items() noexcept - { - return iteration_proxy(*this); - } - - /*! 
- @copydoc items() - */ - iteration_proxy items() const noexcept - { - return iteration_proxy(*this); - } - - /// @} - - - ////////////// - // capacity // - ////////////// - - /// @name capacity - /// @{ - - /*! - @brief checks whether the container is empty. - - Checks if a JSON value has no elements (i.e. whether its @ref size is `0`). - - @return The return value depends on the different types and is - defined as follows: - Value type | return value - ----------- | ------------- - null | `true` - boolean | `false` - string | `false` - number | `false` - binary | `false` - object | result of function `object_t::empty()` - array | result of function `array_t::empty()` - - @liveexample{The following code uses `empty()` to check if a JSON - object contains any elements.,empty} - - @complexity Constant, as long as @ref array_t and @ref object_t satisfy - the Container concept; that is, their `empty()` functions have constant - complexity. - - @iterators No changes. - - @exceptionsafety No-throw guarantee: this function never throws exceptions. - - @note This function does not return whether a string stored as JSON value - is empty - it returns whether the JSON container itself is empty which is - false in the case of a string. - - @requirement This function helps `basic_json` satisfying the - [Container](https://en.cppreference.com/w/cpp/named_req/Container) - requirements: - - The complexity is constant. - - Has the semantics of `begin() == end()`. - - @sa @ref size() -- returns the number of elements - - @since version 1.0.0 - */ - bool empty() const noexcept - { - switch (m_type) - { - case value_t::null: - { - // null values are empty - return true; - } - - case value_t::array: - { - // delegate call to array_t::empty() - return m_value.array->empty(); - } - - case value_t::object: - { - // delegate call to object_t::empty() - return m_value.object->empty(); - } - - default: - { - // all other types are nonempty - return false; - } - } - } - - /*! 
- @brief returns the number of elements - - Returns the number of elements in a JSON value. - - @return The return value depends on the different types and is - defined as follows: - Value type | return value - ----------- | ------------- - null | `0` - boolean | `1` - string | `1` - number | `1` - binary | `1` - object | result of function object_t::size() - array | result of function array_t::size() - - @liveexample{The following code calls `size()` on the different value - types.,size} - - @complexity Constant, as long as @ref array_t and @ref object_t satisfy - the Container concept; that is, their size() functions have constant - complexity. - - @iterators No changes. - - @exceptionsafety No-throw guarantee: this function never throws exceptions. - - @note This function does not return the length of a string stored as JSON - value - it returns the number of elements in the JSON value which is 1 in - the case of a string. - - @requirement This function helps `basic_json` satisfying the - [Container](https://en.cppreference.com/w/cpp/named_req/Container) - requirements: - - The complexity is constant. - - Has the semantics of `std::distance(begin(), end())`. - - @sa @ref empty() -- checks whether the container is empty - @sa @ref max_size() -- returns the maximal number of elements - - @since version 1.0.0 - */ - size_type size() const noexcept - { - switch (m_type) - { - case value_t::null: - { - // null values are empty - return 0; - } - - case value_t::array: - { - // delegate call to array_t::size() - return m_value.array->size(); - } - - case value_t::object: - { - // delegate call to object_t::size() - return m_value.object->size(); - } - - default: - { - // all other types have size 1 - return 1; - } - } - } - - /*! - @brief returns the maximum possible number of elements - - Returns the maximum number of elements a JSON value is able to hold due to - system or library implementation limitations, i.e. `std::distance(begin(), - end())` for the JSON value. 
- - @return The return value depends on the different types and is - defined as follows: - Value type | return value - ----------- | ------------- - null | `0` (same as `size()`) - boolean | `1` (same as `size()`) - string | `1` (same as `size()`) - number | `1` (same as `size()`) - binary | `1` (same as `size()`) - object | result of function `object_t::max_size()` - array | result of function `array_t::max_size()` - - @liveexample{The following code calls `max_size()` on the different value - types. Note the output is implementation specific.,max_size} - - @complexity Constant, as long as @ref array_t and @ref object_t satisfy - the Container concept; that is, their `max_size()` functions have constant - complexity. - - @iterators No changes. - - @exceptionsafety No-throw guarantee: this function never throws exceptions. - - @requirement This function helps `basic_json` satisfying the - [Container](https://en.cppreference.com/w/cpp/named_req/Container) - requirements: - - The complexity is constant. - - Has the semantics of returning `b.size()` where `b` is the largest - possible JSON value. - - @sa @ref size() -- returns the number of elements - - @since version 1.0.0 - */ - size_type max_size() const noexcept - { - switch (m_type) - { - case value_t::array: - { - // delegate call to array_t::max_size() - return m_value.array->max_size(); - } - - case value_t::object: - { - // delegate call to object_t::max_size() - return m_value.object->max_size(); - } - - default: - { - // all other types have max_size() == size() - return size(); - } - } - } - - /// @} - - - /////////////// - // modifiers // - /////////////// - - /// @name modifiers - /// @{ - - /*! 
- @brief clears the contents - - Clears the content of a JSON value and resets it to the default value as - if @ref basic_json(value_t) would have been called with the current value - type from @ref type(): - - Value type | initial value - ----------- | ------------- - null | `null` - boolean | `false` - string | `""` - number | `0` - binary | An empty byte vector - object | `{}` - array | `[]` - - @post Has the same effect as calling - @code {.cpp} - *this = basic_json(type()); - @endcode - - @liveexample{The example below shows the effect of `clear()` to different - JSON types.,clear} - - @complexity Linear in the size of the JSON value. - - @iterators All iterators, pointers and references related to this container - are invalidated. - - @exceptionsafety No-throw guarantee: this function never throws exceptions. - - @sa @ref basic_json(value_t) -- constructor that creates an object with the - same value than calling `clear()` - - @since version 1.0.0 - */ - void clear() noexcept - { - switch (m_type) - { - case value_t::number_integer: - { - m_value.number_integer = 0; - break; - } - - case value_t::number_unsigned: - { - m_value.number_unsigned = 0; - break; - } - - case value_t::number_float: - { - m_value.number_float = 0.0; - break; - } - - case value_t::boolean: - { - m_value.boolean = false; - break; - } - - case value_t::string: - { - m_value.string->clear(); - break; - } - - case value_t::binary: - { - m_value.binary->clear(); - break; - } - - case value_t::array: - { - m_value.array->clear(); - break; - } - - case value_t::object: - { - m_value.object->clear(); - break; - } - - default: - break; - } - } - - /*! - @brief add an object to an array - - Appends the given element @a val to the end of the JSON value. If the - function is called on a JSON null value, an empty array is created before - appending @a val. 
- - @param[in] val the value to add to the JSON array - - @throw type_error.308 when called on a type other than JSON array or - null; example: `"cannot use push_back() with number"` - - @complexity Amortized constant. - - @liveexample{The example shows how `push_back()` and `+=` can be used to - add elements to a JSON array. Note how the `null` value was silently - converted to a JSON array.,push_back} - - @since version 1.0.0 - */ - void push_back(basic_json&& val) - { - // push_back only works for null objects or arrays - if (JSON_HEDLEY_UNLIKELY(!(is_null() || is_array()))) - { - JSON_THROW(type_error::create(308, "cannot use push_back() with " + std::string(type_name()))); - } - - // transform null object into an array - if (is_null()) - { - m_type = value_t::array; - m_value = value_t::array; - assert_invariant(); - } - - // add element to array (move semantics) - m_value.array->push_back(std::move(val)); - // if val is moved from, basic_json move constructor marks it null so we do not call the destructor - } - - /*! - @brief add an object to an array - @copydoc push_back(basic_json&&) - */ - reference operator+=(basic_json&& val) - { - push_back(std::move(val)); - return *this; - } - - /*! - @brief add an object to an array - @copydoc push_back(basic_json&&) - */ - void push_back(const basic_json& val) - { - // push_back only works for null objects or arrays - if (JSON_HEDLEY_UNLIKELY(!(is_null() || is_array()))) - { - JSON_THROW(type_error::create(308, "cannot use push_back() with " + std::string(type_name()))); - } - - // transform null object into an array - if (is_null()) - { - m_type = value_t::array; - m_value = value_t::array; - assert_invariant(); - } - - // add element to array - m_value.array->push_back(val); - } - - /*! - @brief add an object to an array - @copydoc push_back(basic_json&&) - */ - reference operator+=(const basic_json& val) - { - push_back(val); - return *this; - } - - /*! 
- @brief add an object to an object - - Inserts the given element @a val to the JSON object. If the function is - called on a JSON null value, an empty object is created before inserting - @a val. - - @param[in] val the value to add to the JSON object - - @throw type_error.308 when called on a type other than JSON object or - null; example: `"cannot use push_back() with number"` - - @complexity Logarithmic in the size of the container, O(log(`size()`)). - - @liveexample{The example shows how `push_back()` and `+=` can be used to - add elements to a JSON object. Note how the `null` value was silently - converted to a JSON object.,push_back__object_t__value} - - @since version 1.0.0 - */ - void push_back(const typename object_t::value_type& val) - { - // push_back only works for null objects or objects - if (JSON_HEDLEY_UNLIKELY(!(is_null() || is_object()))) - { - JSON_THROW(type_error::create(308, "cannot use push_back() with " + std::string(type_name()))); - } - - // transform null object into an object - if (is_null()) - { - m_type = value_t::object; - m_value = value_t::object; - assert_invariant(); - } - - // add element to array - m_value.object->insert(val); - } - - /*! - @brief add an object to an object - @copydoc push_back(const typename object_t::value_type&) - */ - reference operator+=(const typename object_t::value_type& val) - { - push_back(val); - return *this; - } - - /*! - @brief add an object to an object - - This function allows to use `push_back` with an initializer list. In case - - 1. the current value is an object, - 2. the initializer list @a init contains only two elements, and - 3. the first element of @a init is a string, - - @a init is converted into an object element and added using - @ref push_back(const typename object_t::value_type&). Otherwise, @a init - is converted to a JSON value and added using @ref push_back(basic_json&&). - - @param[in] init an initializer list - - @complexity Linear in the size of the initializer list @a init. 
- - @note This function is required to resolve an ambiguous overload error, - because pairs like `{"key", "value"}` can be both interpreted as - `object_t::value_type` or `std::initializer_list`, see - https://github.com/nlohmann/json/issues/235 for more information. - - @liveexample{The example shows how initializer lists are treated as - objects when possible.,push_back__initializer_list} - */ - void push_back(initializer_list_t init) - { - if (is_object() && init.size() == 2 && (*init.begin())->is_string()) - { - basic_json&& key = init.begin()->moved_or_copied(); - push_back(typename object_t::value_type( - std::move(key.get_ref()), (init.begin() + 1)->moved_or_copied())); - } - else - { - push_back(basic_json(init)); - } - } - - /*! - @brief add an object to an object - @copydoc push_back(initializer_list_t) - */ - reference operator+=(initializer_list_t init) - { - push_back(init); - return *this; - } - - /*! - @brief add an object to an array - - Creates a JSON value from the passed parameters @a args to the end of the - JSON value. If the function is called on a JSON null value, an empty array - is created before appending the value created from @a args. - - @param[in] args arguments to forward to a constructor of @ref basic_json - @tparam Args compatible types to create a @ref basic_json object - - @return reference to the inserted element - - @throw type_error.311 when called on a type other than JSON array or - null; example: `"cannot use emplace_back() with number"` - - @complexity Amortized constant. - - @liveexample{The example shows how `push_back()` can be used to add - elements to a JSON array. Note how the `null` value was silently converted - to a JSON array.,emplace_back} - - @since version 2.0.8, returns reference since 3.7.0 - */ - template - reference emplace_back(Args&& ... 
args) - { - // emplace_back only works for null objects or arrays - if (JSON_HEDLEY_UNLIKELY(!(is_null() || is_array()))) - { - JSON_THROW(type_error::create(311, "cannot use emplace_back() with " + std::string(type_name()))); - } - - // transform null object into an array - if (is_null()) - { - m_type = value_t::array; - m_value = value_t::array; - assert_invariant(); - } - - // add element to array (perfect forwarding) -#ifdef JSON_HAS_CPP_17 - return m_value.array->emplace_back(std::forward(args)...); -#else - m_value.array->emplace_back(std::forward(args)...); - return m_value.array->back(); -#endif - } - - /*! - @brief add an object to an object if key does not exist - - Inserts a new element into a JSON object constructed in-place with the - given @a args if there is no element with the key in the container. If the - function is called on a JSON null value, an empty object is created before - appending the value created from @a args. - - @param[in] args arguments to forward to a constructor of @ref basic_json - @tparam Args compatible types to create a @ref basic_json object - - @return a pair consisting of an iterator to the inserted element, or the - already-existing element if no insertion happened, and a bool - denoting whether the insertion took place. - - @throw type_error.311 when called on a type other than JSON object or - null; example: `"cannot use emplace() with number"` - - @complexity Logarithmic in the size of the container, O(log(`size()`)). - - @liveexample{The example shows how `emplace()` can be used to add elements - to a JSON object. Note how the `null` value was silently converted to a - JSON object. Further note how no value is added if there was already one - value stored with the same key.,emplace} - - @since version 2.0.8 - */ - template - std::pair emplace(Args&& ... 
args) - { - // emplace only works for null objects or arrays - if (JSON_HEDLEY_UNLIKELY(!(is_null() || is_object()))) - { - JSON_THROW(type_error::create(311, "cannot use emplace() with " + std::string(type_name()))); - } - - // transform null object into an object - if (is_null()) - { - m_type = value_t::object; - m_value = value_t::object; - assert_invariant(); - } - - // add element to array (perfect forwarding) - auto res = m_value.object->emplace(std::forward(args)...); - // create result iterator and set iterator to the result of emplace - auto it = begin(); - it.m_it.object_iterator = res.first; - - // return pair of iterator and boolean - return {it, res.second}; - } - - /// Helper for insertion of an iterator - /// @note: This uses std::distance to support GCC 4.8, - /// see https://github.com/nlohmann/json/pull/1257 - template - iterator insert_iterator(const_iterator pos, Args&& ... args) - { - iterator result(this); - JSON_ASSERT(m_value.array != nullptr); - - auto insert_pos = std::distance(m_value.array->begin(), pos.m_it.array_iterator); - m_value.array->insert(pos.m_it.array_iterator, std::forward(args)...); - result.m_it.array_iterator = m_value.array->begin() + insert_pos; - - // This could have been written as: - // result.m_it.array_iterator = m_value.array->insert(pos.m_it.array_iterator, cnt, val); - // but the return value of insert is missing in GCC 4.8, so it is written this way instead. - - return result; - } - - /*! - @brief inserts element - - Inserts element @a val before iterator @a pos. - - @param[in] pos iterator before which the content will be inserted; may be - the end() iterator - @param[in] val element to insert - @return iterator pointing to the inserted @a val. 
- - @throw type_error.309 if called on JSON values other than arrays; - example: `"cannot use insert() with string"` - @throw invalid_iterator.202 if @a pos is not an iterator of *this; - example: `"iterator does not fit current value"` - - @complexity Constant plus linear in the distance between @a pos and end of - the container. - - @liveexample{The example shows how `insert()` is used.,insert} - - @since version 1.0.0 - */ - iterator insert(const_iterator pos, const basic_json& val) - { - // insert only works for arrays - if (JSON_HEDLEY_LIKELY(is_array())) - { - // check if iterator pos fits to this JSON value - if (JSON_HEDLEY_UNLIKELY(pos.m_object != this)) - { - JSON_THROW(invalid_iterator::create(202, "iterator does not fit current value")); - } - - // insert to array and return iterator - return insert_iterator(pos, val); - } - - JSON_THROW(type_error::create(309, "cannot use insert() with " + std::string(type_name()))); - } - - /*! - @brief inserts element - @copydoc insert(const_iterator, const basic_json&) - */ - iterator insert(const_iterator pos, basic_json&& val) - { - return insert(pos, val); - } - - /*! - @brief inserts elements - - Inserts @a cnt copies of @a val before iterator @a pos. - - @param[in] pos iterator before which the content will be inserted; may be - the end() iterator - @param[in] cnt number of copies of @a val to insert - @param[in] val element to insert - @return iterator pointing to the first element inserted, or @a pos if - `cnt==0` - - @throw type_error.309 if called on JSON values other than arrays; example: - `"cannot use insert() with string"` - @throw invalid_iterator.202 if @a pos is not an iterator of *this; - example: `"iterator does not fit current value"` - - @complexity Linear in @a cnt plus linear in the distance between @a pos - and end of the container. 
- - @liveexample{The example shows how `insert()` is used.,insert__count} - - @since version 1.0.0 - */ - iterator insert(const_iterator pos, size_type cnt, const basic_json& val) - { - // insert only works for arrays - if (JSON_HEDLEY_LIKELY(is_array())) - { - // check if iterator pos fits to this JSON value - if (JSON_HEDLEY_UNLIKELY(pos.m_object != this)) - { - JSON_THROW(invalid_iterator::create(202, "iterator does not fit current value")); - } - - // insert to array and return iterator - return insert_iterator(pos, cnt, val); - } - - JSON_THROW(type_error::create(309, "cannot use insert() with " + std::string(type_name()))); - } - - /*! - @brief inserts elements - - Inserts elements from range `[first, last)` before iterator @a pos. - - @param[in] pos iterator before which the content will be inserted; may be - the end() iterator - @param[in] first begin of the range of elements to insert - @param[in] last end of the range of elements to insert - - @throw type_error.309 if called on JSON values other than arrays; example: - `"cannot use insert() with string"` - @throw invalid_iterator.202 if @a pos is not an iterator of *this; - example: `"iterator does not fit current value"` - @throw invalid_iterator.210 if @a first and @a last do not belong to the - same JSON value; example: `"iterators do not fit"` - @throw invalid_iterator.211 if @a first or @a last are iterators into - container for which insert is called; example: `"passed iterators may not - belong to container"` - - @return iterator pointing to the first element inserted, or @a pos if - `first==last` - - @complexity Linear in `std::distance(first, last)` plus linear in the - distance between @a pos and end of the container. 
- - @liveexample{The example shows how `insert()` is used.,insert__range} - - @since version 1.0.0 - */ - iterator insert(const_iterator pos, const_iterator first, const_iterator last) - { - // insert only works for arrays - if (JSON_HEDLEY_UNLIKELY(!is_array())) - { - JSON_THROW(type_error::create(309, "cannot use insert() with " + std::string(type_name()))); - } - - // check if iterator pos fits to this JSON value - if (JSON_HEDLEY_UNLIKELY(pos.m_object != this)) - { - JSON_THROW(invalid_iterator::create(202, "iterator does not fit current value")); - } - - // check if range iterators belong to the same JSON object - if (JSON_HEDLEY_UNLIKELY(first.m_object != last.m_object)) - { - JSON_THROW(invalid_iterator::create(210, "iterators do not fit")); - } - - if (JSON_HEDLEY_UNLIKELY(first.m_object == this)) - { - JSON_THROW(invalid_iterator::create(211, "passed iterators may not belong to container")); - } - - // insert to array and return iterator - return insert_iterator(pos, first.m_it.array_iterator, last.m_it.array_iterator); - } - - /*! - @brief inserts elements - - Inserts elements from initializer list @a ilist before iterator @a pos. - - @param[in] pos iterator before which the content will be inserted; may be - the end() iterator - @param[in] ilist initializer list to insert the values from - - @throw type_error.309 if called on JSON values other than arrays; example: - `"cannot use insert() with string"` - @throw invalid_iterator.202 if @a pos is not an iterator of *this; - example: `"iterator does not fit current value"` - - @return iterator pointing to the first element inserted, or @a pos if - `ilist` is empty - - @complexity Linear in `ilist.size()` plus linear in the distance between - @a pos and end of the container. 
- - @liveexample{The example shows how `insert()` is used.,insert__ilist} - - @since version 1.0.0 - */ - iterator insert(const_iterator pos, initializer_list_t ilist) - { - // insert only works for arrays - if (JSON_HEDLEY_UNLIKELY(!is_array())) - { - JSON_THROW(type_error::create(309, "cannot use insert() with " + std::string(type_name()))); - } - - // check if iterator pos fits to this JSON value - if (JSON_HEDLEY_UNLIKELY(pos.m_object != this)) - { - JSON_THROW(invalid_iterator::create(202, "iterator does not fit current value")); - } - - // insert to array and return iterator - return insert_iterator(pos, ilist.begin(), ilist.end()); - } - - /*! - @brief inserts elements - - Inserts elements from range `[first, last)`. - - @param[in] first begin of the range of elements to insert - @param[in] last end of the range of elements to insert - - @throw type_error.309 if called on JSON values other than objects; example: - `"cannot use insert() with string"` - @throw invalid_iterator.202 if iterator @a first or @a last does does not - point to an object; example: `"iterators first and last must point to - objects"` - @throw invalid_iterator.210 if @a first and @a last do not belong to the - same JSON value; example: `"iterators do not fit"` - - @complexity Logarithmic: `O(N*log(size() + N))`, where `N` is the number - of elements to insert. 
- - @liveexample{The example shows how `insert()` is used.,insert__range_object} - - @since version 3.0.0 - */ - void insert(const_iterator first, const_iterator last) - { - // insert only works for objects - if (JSON_HEDLEY_UNLIKELY(!is_object())) - { - JSON_THROW(type_error::create(309, "cannot use insert() with " + std::string(type_name()))); - } - - // check if range iterators belong to the same JSON object - if (JSON_HEDLEY_UNLIKELY(first.m_object != last.m_object)) - { - JSON_THROW(invalid_iterator::create(210, "iterators do not fit")); - } - - // passed iterators must belong to objects - if (JSON_HEDLEY_UNLIKELY(!first.m_object->is_object())) - { - JSON_THROW(invalid_iterator::create(202, "iterators first and last must point to objects")); - } - - m_value.object->insert(first.m_it.object_iterator, last.m_it.object_iterator); - } - - /*! - @brief updates a JSON object from another object, overwriting existing keys - - Inserts all values from JSON object @a j and overwrites existing keys. - - @param[in] j JSON object to read values from - - @throw type_error.312 if called on JSON values other than objects; example: - `"cannot use update() with string"` - - @complexity O(N*log(size() + N)), where N is the number of elements to - insert. 
- - @liveexample{The example shows how `update()` is used.,update} - - @sa https://docs.python.org/3.6/library/stdtypes.html#dict.update - - @since version 3.0.0 - */ - void update(const_reference j) - { - // implicitly convert null value to an empty object - if (is_null()) - { - m_type = value_t::object; - m_value.object = create(); - assert_invariant(); - } - - if (JSON_HEDLEY_UNLIKELY(!is_object())) - { - JSON_THROW(type_error::create(312, "cannot use update() with " + std::string(type_name()))); - } - if (JSON_HEDLEY_UNLIKELY(!j.is_object())) - { - JSON_THROW(type_error::create(312, "cannot use update() with " + std::string(j.type_name()))); - } - - for (auto it = j.cbegin(); it != j.cend(); ++it) - { - m_value.object->operator[](it.key()) = it.value(); - } - } - - /*! - @brief updates a JSON object from another object, overwriting existing keys - - Inserts all values from from range `[first, last)` and overwrites existing - keys. - - @param[in] first begin of the range of elements to insert - @param[in] last end of the range of elements to insert - - @throw type_error.312 if called on JSON values other than objects; example: - `"cannot use update() with string"` - @throw invalid_iterator.202 if iterator @a first or @a last does does not - point to an object; example: `"iterators first and last must point to - objects"` - @throw invalid_iterator.210 if @a first and @a last do not belong to the - same JSON value; example: `"iterators do not fit"` - - @complexity O(N*log(size() + N)), where N is the number of elements to - insert. 
- - @liveexample{The example shows how `update()` is used__range.,update} - - @sa https://docs.python.org/3.6/library/stdtypes.html#dict.update - - @since version 3.0.0 - */ - void update(const_iterator first, const_iterator last) - { - // implicitly convert null value to an empty object - if (is_null()) - { - m_type = value_t::object; - m_value.object = create(); - assert_invariant(); - } - - if (JSON_HEDLEY_UNLIKELY(!is_object())) - { - JSON_THROW(type_error::create(312, "cannot use update() with " + std::string(type_name()))); - } - - // check if range iterators belong to the same JSON object - if (JSON_HEDLEY_UNLIKELY(first.m_object != last.m_object)) - { - JSON_THROW(invalid_iterator::create(210, "iterators do not fit")); - } - - // passed iterators must belong to objects - if (JSON_HEDLEY_UNLIKELY(!first.m_object->is_object() - || !last.m_object->is_object())) - { - JSON_THROW(invalid_iterator::create(202, "iterators first and last must point to objects")); - } - - for (auto it = first; it != last; ++it) - { - m_value.object->operator[](it.key()) = it.value(); - } - } - - /*! - @brief exchanges the values - - Exchanges the contents of the JSON value with those of @a other. Does not - invoke any move, copy, or swap operations on individual elements. All - iterators and references remain valid. The past-the-end iterator is - invalidated. - - @param[in,out] other JSON value to exchange the contents with - - @complexity Constant. - - @liveexample{The example below shows how JSON values can be swapped with - `swap()`.,swap__reference} - - @since version 1.0.0 - */ - void swap(reference other) noexcept ( - std::is_nothrow_move_constructible::value&& - std::is_nothrow_move_assignable::value&& - std::is_nothrow_move_constructible::value&& - std::is_nothrow_move_assignable::value - ) - { - std::swap(m_type, other.m_type); - std::swap(m_value, other.m_value); - assert_invariant(); - } - - /*! 
- @brief exchanges the values - - Exchanges the contents of the JSON value from @a left with those of @a right. Does not - invoke any move, copy, or swap operations on individual elements. All - iterators and references remain valid. The past-the-end iterator is - invalidated. implemented as a friend function callable via ADL. - - @param[in,out] left JSON value to exchange the contents with - @param[in,out] right JSON value to exchange the contents with - - @complexity Constant. - - @liveexample{The example below shows how JSON values can be swapped with - `swap()`.,swap__reference} - - @since version 1.0.0 - */ - friend void swap(reference left, reference right) noexcept ( - std::is_nothrow_move_constructible::value&& - std::is_nothrow_move_assignable::value&& - std::is_nothrow_move_constructible::value&& - std::is_nothrow_move_assignable::value - ) - { - left.swap(right); - } - - /*! - @brief exchanges the values - - Exchanges the contents of a JSON array with those of @a other. Does not - invoke any move, copy, or swap operations on individual elements. All - iterators and references remain valid. The past-the-end iterator is - invalidated. - - @param[in,out] other array to exchange the contents with - - @throw type_error.310 when JSON value is not an array; example: `"cannot - use swap() with string"` - - @complexity Constant. - - @liveexample{The example below shows how arrays can be swapped with - `swap()`.,swap__array_t} - - @since version 1.0.0 - */ - void swap(array_t& other) - { - // swap only works for arrays - if (JSON_HEDLEY_LIKELY(is_array())) - { - std::swap(*(m_value.array), other); - } - else - { - JSON_THROW(type_error::create(310, "cannot use swap() with " + std::string(type_name()))); - } - } - - /*! - @brief exchanges the values - - Exchanges the contents of a JSON object with those of @a other. Does not - invoke any move, copy, or swap operations on individual elements. All - iterators and references remain valid. 
The past-the-end iterator is - invalidated. - - @param[in,out] other object to exchange the contents with - - @throw type_error.310 when JSON value is not an object; example: - `"cannot use swap() with string"` - - @complexity Constant. - - @liveexample{The example below shows how objects can be swapped with - `swap()`.,swap__object_t} - - @since version 1.0.0 - */ - void swap(object_t& other) - { - // swap only works for objects - if (JSON_HEDLEY_LIKELY(is_object())) - { - std::swap(*(m_value.object), other); - } - else - { - JSON_THROW(type_error::create(310, "cannot use swap() with " + std::string(type_name()))); - } - } - - /*! - @brief exchanges the values - - Exchanges the contents of a JSON string with those of @a other. Does not - invoke any move, copy, or swap operations on individual elements. All - iterators and references remain valid. The past-the-end iterator is - invalidated. - - @param[in,out] other string to exchange the contents with - - @throw type_error.310 when JSON value is not a string; example: `"cannot - use swap() with boolean"` - - @complexity Constant. - - @liveexample{The example below shows how strings can be swapped with - `swap()`.,swap__string_t} - - @since version 1.0.0 - */ - void swap(string_t& other) - { - // swap only works for strings - if (JSON_HEDLEY_LIKELY(is_string())) - { - std::swap(*(m_value.string), other); - } - else - { - JSON_THROW(type_error::create(310, "cannot use swap() with " + std::string(type_name()))); - } - } - - /*! - @brief exchanges the values - - Exchanges the contents of a JSON string with those of @a other. Does not - invoke any move, copy, or swap operations on individual elements. All - iterators and references remain valid. The past-the-end iterator is - invalidated. - - @param[in,out] other binary to exchange the contents with - - @throw type_error.310 when JSON value is not a string; example: `"cannot - use swap() with boolean"` - - @complexity Constant. 
- - @liveexample{The example below shows how strings can be swapped with - `swap()`.,swap__binary_t} - - @since version 3.8.0 - */ - void swap(binary_t& other) - { - // swap only works for strings - if (JSON_HEDLEY_LIKELY(is_binary())) - { - std::swap(*(m_value.binary), other); - } - else - { - JSON_THROW(type_error::create(310, "cannot use swap() with " + std::string(type_name()))); - } - } - - /// @copydoc swap(binary_t) - void swap(typename binary_t::container_type& other) - { - // swap only works for strings - if (JSON_HEDLEY_LIKELY(is_binary())) - { - std::swap(*(m_value.binary), other); - } - else - { - JSON_THROW(type_error::create(310, "cannot use swap() with " + std::string(type_name()))); - } - } - - /// @} - - public: - ////////////////////////////////////////// - // lexicographical comparison operators // - ////////////////////////////////////////// - - /// @name lexicographical comparison operators - /// @{ - - /*! - @brief comparison: equal - - Compares two JSON values for equality according to the following rules: - - Two JSON values are equal if (1) they are from the same type and (2) - their stored values are the same according to their respective - `operator==`. - - Integer and floating-point numbers are automatically converted before - comparison. Note that two NaN values are always treated as unequal. - - Two JSON null values are equal. - - @note Floating-point inside JSON values numbers are compared with - `json::number_float_t::operator==` which is `double::operator==` by - default. 
To compare floating-point while respecting an epsilon, an alternative - [comparison function](https://github.com/mariokonrad/marnav/blob/master/include/marnav/math/floatingpoint.hpp#L34-#L39) - could be used, for instance - @code {.cpp} - template::value, T>::type> - inline bool is_same(T a, T b, T epsilon = std::numeric_limits::epsilon()) noexcept - { - return std::abs(a - b) <= epsilon; - } - @endcode - Or you can self-defined operator equal function like this: - @code {.cpp} - bool my_equal(const_reference lhs, const_reference rhs) { - const auto lhs_type lhs.type(); - const auto rhs_type rhs.type(); - if (lhs_type == rhs_type) { - switch(lhs_type) - // self_defined case - case value_t::number_float: - return std::abs(lhs - rhs) <= std::numeric_limits::epsilon(); - // other cases remain the same with the original - ... - } - ... - } - @endcode - - @note NaN values never compare equal to themselves or to other NaN values. - - @param[in] lhs first JSON value to consider - @param[in] rhs second JSON value to consider - @return whether the values @a lhs and @a rhs are equal - - @exceptionsafety No-throw guarantee: this function never throws exceptions. - - @complexity Linear. 
- - @liveexample{The example demonstrates comparing several JSON - types.,operator__equal} - - @since version 1.0.0 - */ - friend bool operator==(const_reference lhs, const_reference rhs) noexcept - { - const auto lhs_type = lhs.type(); - const auto rhs_type = rhs.type(); - - if (lhs_type == rhs_type) - { - switch (lhs_type) - { - case value_t::array: - return *lhs.m_value.array == *rhs.m_value.array; - - case value_t::object: - return *lhs.m_value.object == *rhs.m_value.object; - - case value_t::null: - return true; - - case value_t::string: - return *lhs.m_value.string == *rhs.m_value.string; - - case value_t::boolean: - return lhs.m_value.boolean == rhs.m_value.boolean; - - case value_t::number_integer: - return lhs.m_value.number_integer == rhs.m_value.number_integer; - - case value_t::number_unsigned: - return lhs.m_value.number_unsigned == rhs.m_value.number_unsigned; - - case value_t::number_float: - return lhs.m_value.number_float == rhs.m_value.number_float; - - case value_t::binary: - return *lhs.m_value.binary == *rhs.m_value.binary; - - default: - return false; - } - } - else if (lhs_type == value_t::number_integer && rhs_type == value_t::number_float) - { - return static_cast(lhs.m_value.number_integer) == rhs.m_value.number_float; - } - else if (lhs_type == value_t::number_float && rhs_type == value_t::number_integer) - { - return lhs.m_value.number_float == static_cast(rhs.m_value.number_integer); - } - else if (lhs_type == value_t::number_unsigned && rhs_type == value_t::number_float) - { - return static_cast(lhs.m_value.number_unsigned) == rhs.m_value.number_float; - } - else if (lhs_type == value_t::number_float && rhs_type == value_t::number_unsigned) - { - return lhs.m_value.number_float == static_cast(rhs.m_value.number_unsigned); - } - else if (lhs_type == value_t::number_unsigned && rhs_type == value_t::number_integer) - { - return static_cast(lhs.m_value.number_unsigned) == rhs.m_value.number_integer; - } - else if (lhs_type == 
value_t::number_integer && rhs_type == value_t::number_unsigned) - { - return lhs.m_value.number_integer == static_cast(rhs.m_value.number_unsigned); - } - - return false; - } - - /*! - @brief comparison: equal - @copydoc operator==(const_reference, const_reference) - */ - template::value, int>::type = 0> - friend bool operator==(const_reference lhs, const ScalarType rhs) noexcept - { - return lhs == basic_json(rhs); - } - - /*! - @brief comparison: equal - @copydoc operator==(const_reference, const_reference) - */ - template::value, int>::type = 0> - friend bool operator==(const ScalarType lhs, const_reference rhs) noexcept - { - return basic_json(lhs) == rhs; - } - - /*! - @brief comparison: not equal - - Compares two JSON values for inequality by calculating `not (lhs == rhs)`. - - @param[in] lhs first JSON value to consider - @param[in] rhs second JSON value to consider - @return whether the values @a lhs and @a rhs are not equal - - @complexity Linear. - - @exceptionsafety No-throw guarantee: this function never throws exceptions. - - @liveexample{The example demonstrates comparing several JSON - types.,operator__notequal} - - @since version 1.0.0 - */ - friend bool operator!=(const_reference lhs, const_reference rhs) noexcept - { - return !(lhs == rhs); - } - - /*! - @brief comparison: not equal - @copydoc operator!=(const_reference, const_reference) - */ - template::value, int>::type = 0> - friend bool operator!=(const_reference lhs, const ScalarType rhs) noexcept - { - return lhs != basic_json(rhs); - } - - /*! - @brief comparison: not equal - @copydoc operator!=(const_reference, const_reference) - */ - template::value, int>::type = 0> - friend bool operator!=(const ScalarType lhs, const_reference rhs) noexcept - { - return basic_json(lhs) != rhs; - } - - /*! 
- @brief comparison: less than - - Compares whether one JSON value @a lhs is less than another JSON value @a - rhs according to the following rules: - - If @a lhs and @a rhs have the same type, the values are compared using - the default `<` operator. - - Integer and floating-point numbers are automatically converted before - comparison - - In case @a lhs and @a rhs have different types, the values are ignored - and the order of the types is considered, see - @ref operator<(const value_t, const value_t). - - @param[in] lhs first JSON value to consider - @param[in] rhs second JSON value to consider - @return whether @a lhs is less than @a rhs - - @complexity Linear. - - @exceptionsafety No-throw guarantee: this function never throws exceptions. - - @liveexample{The example demonstrates comparing several JSON - types.,operator__less} - - @since version 1.0.0 - */ - friend bool operator<(const_reference lhs, const_reference rhs) noexcept - { - const auto lhs_type = lhs.type(); - const auto rhs_type = rhs.type(); - - if (lhs_type == rhs_type) - { - switch (lhs_type) - { - case value_t::array: - // note parentheses are necessary, see - // https://github.com/nlohmann/json/issues/1530 - return (*lhs.m_value.array) < (*rhs.m_value.array); - - case value_t::object: - return (*lhs.m_value.object) < (*rhs.m_value.object); - - case value_t::null: - return false; - - case value_t::string: - return (*lhs.m_value.string) < (*rhs.m_value.string); - - case value_t::boolean: - return (lhs.m_value.boolean) < (rhs.m_value.boolean); - - case value_t::number_integer: - return (lhs.m_value.number_integer) < (rhs.m_value.number_integer); - - case value_t::number_unsigned: - return (lhs.m_value.number_unsigned) < (rhs.m_value.number_unsigned); - - case value_t::number_float: - return (lhs.m_value.number_float) < (rhs.m_value.number_float); - - case value_t::binary: - return (*lhs.m_value.binary) < (*rhs.m_value.binary); - - default: - return false; - } - } - else if (lhs_type == 
value_t::number_integer && rhs_type == value_t::number_float) - { - return static_cast(lhs.m_value.number_integer) < rhs.m_value.number_float; - } - else if (lhs_type == value_t::number_float && rhs_type == value_t::number_integer) - { - return lhs.m_value.number_float < static_cast(rhs.m_value.number_integer); - } - else if (lhs_type == value_t::number_unsigned && rhs_type == value_t::number_float) - { - return static_cast(lhs.m_value.number_unsigned) < rhs.m_value.number_float; - } - else if (lhs_type == value_t::number_float && rhs_type == value_t::number_unsigned) - { - return lhs.m_value.number_float < static_cast(rhs.m_value.number_unsigned); - } - else if (lhs_type == value_t::number_integer && rhs_type == value_t::number_unsigned) - { - return lhs.m_value.number_integer < static_cast(rhs.m_value.number_unsigned); - } - else if (lhs_type == value_t::number_unsigned && rhs_type == value_t::number_integer) - { - return static_cast(lhs.m_value.number_unsigned) < rhs.m_value.number_integer; - } - - // We only reach this line if we cannot compare values. In that case, - // we compare types. Note we have to call the operator explicitly, - // because MSVC has problems otherwise. - return operator<(lhs_type, rhs_type); - } - - /*! - @brief comparison: less than - @copydoc operator<(const_reference, const_reference) - */ - template::value, int>::type = 0> - friend bool operator<(const_reference lhs, const ScalarType rhs) noexcept - { - return lhs < basic_json(rhs); - } - - /*! - @brief comparison: less than - @copydoc operator<(const_reference, const_reference) - */ - template::value, int>::type = 0> - friend bool operator<(const ScalarType lhs, const_reference rhs) noexcept - { - return basic_json(lhs) < rhs; - } - - /*! - @brief comparison: less than or equal - - Compares whether one JSON value @a lhs is less than or equal to another - JSON value by calculating `not (rhs < lhs)`. 
- - @param[in] lhs first JSON value to consider - @param[in] rhs second JSON value to consider - @return whether @a lhs is less than or equal to @a rhs - - @complexity Linear. - - @exceptionsafety No-throw guarantee: this function never throws exceptions. - - @liveexample{The example demonstrates comparing several JSON - types.,operator__greater} - - @since version 1.0.0 - */ - friend bool operator<=(const_reference lhs, const_reference rhs) noexcept - { - return !(rhs < lhs); - } - - /*! - @brief comparison: less than or equal - @copydoc operator<=(const_reference, const_reference) - */ - template::value, int>::type = 0> - friend bool operator<=(const_reference lhs, const ScalarType rhs) noexcept - { - return lhs <= basic_json(rhs); - } - - /*! - @brief comparison: less than or equal - @copydoc operator<=(const_reference, const_reference) - */ - template::value, int>::type = 0> - friend bool operator<=(const ScalarType lhs, const_reference rhs) noexcept - { - return basic_json(lhs) <= rhs; - } - - /*! - @brief comparison: greater than - - Compares whether one JSON value @a lhs is greater than another - JSON value by calculating `not (lhs <= rhs)`. - - @param[in] lhs first JSON value to consider - @param[in] rhs second JSON value to consider - @return whether @a lhs is greater than to @a rhs - - @complexity Linear. - - @exceptionsafety No-throw guarantee: this function never throws exceptions. - - @liveexample{The example demonstrates comparing several JSON - types.,operator__lessequal} - - @since version 1.0.0 - */ - friend bool operator>(const_reference lhs, const_reference rhs) noexcept - { - return !(lhs <= rhs); - } - - /*! - @brief comparison: greater than - @copydoc operator>(const_reference, const_reference) - */ - template::value, int>::type = 0> - friend bool operator>(const_reference lhs, const ScalarType rhs) noexcept - { - return lhs > basic_json(rhs); - } - - /*! 
- @brief comparison: greater than - @copydoc operator>(const_reference, const_reference) - */ - template::value, int>::type = 0> - friend bool operator>(const ScalarType lhs, const_reference rhs) noexcept - { - return basic_json(lhs) > rhs; - } - - /*! - @brief comparison: greater than or equal - - Compares whether one JSON value @a lhs is greater than or equal to another - JSON value by calculating `not (lhs < rhs)`. - - @param[in] lhs first JSON value to consider - @param[in] rhs second JSON value to consider - @return whether @a lhs is greater than or equal to @a rhs - - @complexity Linear. - - @exceptionsafety No-throw guarantee: this function never throws exceptions. - - @liveexample{The example demonstrates comparing several JSON - types.,operator__greaterequal} - - @since version 1.0.0 - */ - friend bool operator>=(const_reference lhs, const_reference rhs) noexcept - { - return !(lhs < rhs); - } - - /*! - @brief comparison: greater than or equal - @copydoc operator>=(const_reference, const_reference) - */ - template::value, int>::type = 0> - friend bool operator>=(const_reference lhs, const ScalarType rhs) noexcept - { - return lhs >= basic_json(rhs); - } - - /*! - @brief comparison: greater than or equal - @copydoc operator>=(const_reference, const_reference) - */ - template::value, int>::type = 0> - friend bool operator>=(const ScalarType lhs, const_reference rhs) noexcept - { - return basic_json(lhs) >= rhs; - } - - /// @} - - /////////////////// - // serialization // - /////////////////// - - /// @name serialization - /// @{ - - /*! - @brief serialize to stream - - Serialize the given JSON value @a j to the output stream @a o. The JSON - value will be serialized using the @ref dump member function. - - - The indentation of the output can be controlled with the member variable - `width` of the output stream @a o. 
For instance, using the manipulator - `std::setw(4)` on @a o sets the indentation level to `4` and the - serialization result is the same as calling `dump(4)`. - - - The indentation character can be controlled with the member variable - `fill` of the output stream @a o. For instance, the manipulator - `std::setfill('\\t')` sets indentation to use a tab character rather than - the default space character. - - @param[in,out] o stream to serialize to - @param[in] j JSON value to serialize - - @return the stream @a o - - @throw type_error.316 if a string stored inside the JSON value is not - UTF-8 encoded - - @complexity Linear. - - @liveexample{The example below shows the serialization with different - parameters to `width` to adjust the indentation level.,operator_serialize} - - @since version 1.0.0; indentation character added in version 3.0.0 - */ - friend std::ostream& operator<<(std::ostream& o, const basic_json& j) - { - // read width member and use it as indentation parameter if nonzero - const bool pretty_print = o.width() > 0; - const auto indentation = pretty_print ? o.width() : 0; - - // reset width to 0 for subsequent calls to this stream - o.width(0); - - // do the actual serialization - serializer s(detail::output_adapter(o), o.fill()); - s.dump(j, pretty_print, false, static_cast(indentation)); - return o; - } - - /*! - @brief serialize to stream - @deprecated This stream operator is deprecated and will be removed in - future 4.0.0 of the library. Please use - @ref operator<<(std::ostream&, const basic_json&) - instead; that is, replace calls like `j >> o;` with `o << j;`. - @since version 1.0.0; deprecated since version 3.0.0 - */ - JSON_HEDLEY_DEPRECATED_FOR(3.0.0, operator<<(std::ostream&, const basic_json&)) - friend std::ostream& operator>>(const basic_json& j, std::ostream& o) - { - return o << j; - } - - /// @} - - - ///////////////////// - // deserialization // - ///////////////////// - - /// @name deserialization - /// @{ - - /*! 
- @brief deserialize from a compatible input - - @tparam InputType A compatible input, for instance - - an std::istream object - - a FILE pointer - - a C-style array of characters - - a pointer to a null-terminated string of single byte characters - - an object obj for which begin(obj) and end(obj) produces a valid pair of - iterators. - - @param[in] i input to read from - @param[in] cb a parser callback function of type @ref parser_callback_t - which is used to control the deserialization by filtering unwanted values - (optional) - @param[in] allow_exceptions whether to throw exceptions in case of a - parse error (optional, true by default) - @param[in] ignore_comments whether comments should be ignored and treated - like whitespace (true) or yield a parse error (true); (optional, false by - default) - - @return deserialized JSON value; in case of a parse error and - @a allow_exceptions set to `false`, the return value will be - value_t::discarded. - - @throw parse_error.101 if a parse error occurs; example: `""unexpected end - of input; expected string literal""` - @throw parse_error.102 if to_unicode fails or surrogate error - @throw parse_error.103 if to_unicode fails - - @complexity Linear in the length of the input. The parser is a predictive - LL(1) parser. The complexity can be higher if the parser callback function - @a cb or reading from the input @a i has a super-linear complexity. - - @note A UTF-8 byte order mark is silently ignored. 
- - @liveexample{The example below demonstrates the `parse()` function reading - from an array.,parse__array__parser_callback_t} - - @liveexample{The example below demonstrates the `parse()` function with - and without callback function.,parse__string__parser_callback_t} - - @liveexample{The example below demonstrates the `parse()` function with - and without callback function.,parse__istream__parser_callback_t} - - @liveexample{The example below demonstrates the `parse()` function reading - from a contiguous container.,parse__contiguouscontainer__parser_callback_t} - - @since version 2.0.3 (contiguous containers); version 3.9.0 allowed to - ignore comments. - */ - template - JSON_HEDLEY_WARN_UNUSED_RESULT - static basic_json parse(InputType&& i, - const parser_callback_t cb = nullptr, - const bool allow_exceptions = true, - const bool ignore_comments = false) - { - basic_json result; - parser(detail::input_adapter(std::forward(i)), cb, allow_exceptions, ignore_comments).parse(true, result); - return result; - } - - /*! - @brief deserialize from a pair of character iterators - - The value_type of the iterator must be an integral type with size of 1, 2 or - 4 bytes, which will be interpreted respectively as UTF-8, UTF-16 and UTF-32. - - @param[in] first iterator to start of character range - @param[in] last iterator to end of character range - @param[in] cb a parser callback function of type @ref parser_callback_t - which is used to control the deserialization by filtering unwanted values - (optional) - @param[in] allow_exceptions whether to throw exceptions in case of a - parse error (optional, true by default) - @param[in] ignore_comments whether comments should be ignored and treated - like whitespace (true) or yield a parse error (true); (optional, false by - default) - - @return deserialized JSON value; in case of a parse error and - @a allow_exceptions set to `false`, the return value will be - value_t::discarded. 
- - @throw parse_error.101 if a parse error occurs; example: `""unexpected end - of input; expected string literal""` - @throw parse_error.102 if to_unicode fails or surrogate error - @throw parse_error.103 if to_unicode fails - */ - template - JSON_HEDLEY_WARN_UNUSED_RESULT - static basic_json parse(IteratorType first, - IteratorType last, - const parser_callback_t cb = nullptr, - const bool allow_exceptions = true, - const bool ignore_comments = false) - { - basic_json result; - parser(detail::input_adapter(std::move(first), std::move(last)), cb, allow_exceptions, ignore_comments).parse(true, result); - return result; - } - - JSON_HEDLEY_WARN_UNUSED_RESULT - JSON_HEDLEY_DEPRECATED_FOR(3.8.0, parse(ptr, ptr + len)) - static basic_json parse(detail::span_input_adapter&& i, - const parser_callback_t cb = nullptr, - const bool allow_exceptions = true, - const bool ignore_comments = false) - { - basic_json result; - parser(i.get(), cb, allow_exceptions, ignore_comments).parse(true, result); - return result; - } - - /*! - @brief check if the input is valid JSON - - Unlike the @ref parse(InputType&&, const parser_callback_t,const bool) - function, this function neither throws an exception in case of invalid JSON - input (i.e., a parse error) nor creates diagnostic information. - - @tparam InputType A compatible input, for instance - - an std::istream object - - a FILE pointer - - a C-style array of characters - - a pointer to a null-terminated string of single byte characters - - an object obj for which begin(obj) and end(obj) produces a valid pair of - iterators. - - @param[in] i input to read from - @param[in] ignore_comments whether comments should be ignored and treated - like whitespace (true) or yield a parse error (true); (optional, false by - default) - - @return Whether the input read from @a i is valid JSON. - - @complexity Linear in the length of the input. The parser is a predictive - LL(1) parser. - - @note A UTF-8 byte order mark is silently ignored. 
- - @liveexample{The example below demonstrates the `accept()` function reading - from a string.,accept__string} - */ - template - static bool accept(InputType&& i, - const bool ignore_comments = false) - { - return parser(detail::input_adapter(std::forward(i)), nullptr, false, ignore_comments).accept(true); - } - - template - static bool accept(IteratorType first, IteratorType last, - const bool ignore_comments = false) - { - return parser(detail::input_adapter(std::move(first), std::move(last)), nullptr, false, ignore_comments).accept(true); - } - - JSON_HEDLEY_WARN_UNUSED_RESULT - JSON_HEDLEY_DEPRECATED_FOR(3.8.0, accept(ptr, ptr + len)) - static bool accept(detail::span_input_adapter&& i, - const bool ignore_comments = false) - { - return parser(i.get(), nullptr, false, ignore_comments).accept(true); - } - - /*! - @brief generate SAX events - - The SAX event lister must follow the interface of @ref json_sax. - - This function reads from a compatible input. Examples are: - - an std::istream object - - a FILE pointer - - a C-style array of characters - - a pointer to a null-terminated string of single byte characters - - an object obj for which begin(obj) and end(obj) produces a valid pair of - iterators. - - @param[in] i input to read from - @param[in,out] sax SAX event listener - @param[in] format the format to parse (JSON, CBOR, MessagePack, or UBJSON) - @param[in] strict whether the input has to be consumed completely - @param[in] ignore_comments whether comments should be ignored and treated - like whitespace (true) or yield a parse error (true); (optional, false by - default); only applies to the JSON file format. - - @return return value of the last processed SAX event - - @throw parse_error.101 if a parse error occurs; example: `""unexpected end - of input; expected string literal""` - @throw parse_error.102 if to_unicode fails or surrogate error - @throw parse_error.103 if to_unicode fails - - @complexity Linear in the length of the input. 
The parser is a predictive - LL(1) parser. The complexity can be higher if the SAX consumer @a sax has - a super-linear complexity. - - @note A UTF-8 byte order mark is silently ignored. - - @liveexample{The example below demonstrates the `sax_parse()` function - reading from string and processing the events with a user-defined SAX - event consumer.,sax_parse} - - @since version 3.2.0 - */ - template - JSON_HEDLEY_NON_NULL(2) - static bool sax_parse(InputType&& i, SAX* sax, - input_format_t format = input_format_t::json, - const bool strict = true, - const bool ignore_comments = false) - { - auto ia = detail::input_adapter(std::forward(i)); - return format == input_format_t::json - ? parser(std::move(ia), nullptr, true, ignore_comments).sax_parse(sax, strict) - : detail::binary_reader(std::move(ia)).sax_parse(format, sax, strict); - } - - template - JSON_HEDLEY_NON_NULL(3) - static bool sax_parse(IteratorType first, IteratorType last, SAX* sax, - input_format_t format = input_format_t::json, - const bool strict = true, - const bool ignore_comments = false) - { - auto ia = detail::input_adapter(std::move(first), std::move(last)); - return format == input_format_t::json - ? parser(std::move(ia), nullptr, true, ignore_comments).sax_parse(sax, strict) - : detail::binary_reader(std::move(ia)).sax_parse(format, sax, strict); - } - - template - JSON_HEDLEY_DEPRECATED_FOR(3.8.0, sax_parse(ptr, ptr + len, ...)) - JSON_HEDLEY_NON_NULL(2) - static bool sax_parse(detail::span_input_adapter&& i, SAX* sax, - input_format_t format = input_format_t::json, - const bool strict = true, - const bool ignore_comments = false) - { - auto ia = i.get(); - return format == input_format_t::json - ? parser(std::move(ia), nullptr, true, ignore_comments).sax_parse(sax, strict) - : detail::binary_reader(std::move(ia)).sax_parse(format, sax, strict); - } - - /*! - @brief deserialize from stream - @deprecated This stream operator is deprecated and will be removed in - version 4.0.0 of the library. 
Please use - @ref operator>>(std::istream&, basic_json&) - instead; that is, replace calls like `j << i;` with `i >> j;`. - @since version 1.0.0; deprecated since version 3.0.0 - */ - JSON_HEDLEY_DEPRECATED_FOR(3.0.0, operator>>(std::istream&, basic_json&)) - friend std::istream& operator<<(basic_json& j, std::istream& i) - { - return operator>>(i, j); - } - - /*! - @brief deserialize from stream - - Deserializes an input stream to a JSON value. - - @param[in,out] i input stream to read a serialized JSON value from - @param[in,out] j JSON value to write the deserialized input to - - @throw parse_error.101 in case of an unexpected token - @throw parse_error.102 if to_unicode fails or surrogate error - @throw parse_error.103 if to_unicode fails - - @complexity Linear in the length of the input. The parser is a predictive - LL(1) parser. - - @note A UTF-8 byte order mark is silently ignored. - - @liveexample{The example below shows how a JSON value is constructed by - reading a serialization from a stream.,operator_deserialize} - - @sa parse(std::istream&, const parser_callback_t) for a variant with a - parser callback function to filter values while parsing - - @since version 1.0.0 - */ - friend std::istream& operator>>(std::istream& i, basic_json& j) - { - parser(detail::input_adapter(i)).parse(false, j); - return i; - } - - /// @} - - /////////////////////////// - // convenience functions // - /////////////////////////// - - /*! - @brief return the type as string - - Returns the type name as string to be used in error messages - usually to - indicate that a function was called on a wrong JSON type. 
- - @return a string representation of a the @a m_type member: - Value type | return value - ----------- | ------------- - null | `"null"` - boolean | `"boolean"` - string | `"string"` - number | `"number"` (for all number types) - object | `"object"` - array | `"array"` - binary | `"binary"` - discarded | `"discarded"` - - @exceptionsafety No-throw guarantee: this function never throws exceptions. - - @complexity Constant. - - @liveexample{The following code exemplifies `type_name()` for all JSON - types.,type_name} - - @sa @ref type() -- return the type of the JSON value - @sa @ref operator value_t() -- return the type of the JSON value (implicit) - - @since version 1.0.0, public since 2.1.0, `const char*` and `noexcept` - since 3.0.0 - */ - JSON_HEDLEY_RETURNS_NON_NULL - const char* type_name() const noexcept - { - { - switch (m_type) - { - case value_t::null: - return "null"; - case value_t::object: - return "object"; - case value_t::array: - return "array"; - case value_t::string: - return "string"; - case value_t::boolean: - return "boolean"; - case value_t::binary: - return "binary"; - case value_t::discarded: - return "discarded"; - default: - return "number"; - } - } - } - - - private: - ////////////////////// - // member variables // - ////////////////////// - - /// the type of the current element - value_t m_type = value_t::null; - - /// the value of the current element - json_value m_value = {}; - - ////////////////////////////////////////// - // binary serialization/deserialization // - ////////////////////////////////////////// - - /// @name binary serialization/deserialization support - /// @{ - - public: - /*! - @brief create a CBOR serialization of a given JSON value - - Serializes a given JSON value @a j to a byte vector using the CBOR (Concise - Binary Object Representation) serialization format. CBOR is a binary - serialization format which aims to be more compact than JSON itself, yet - more efficient to parse. 
- - The library uses the following mapping from JSON values types to - CBOR types according to the CBOR specification (RFC 7049): - - JSON value type | value/range | CBOR type | first byte - --------------- | ------------------------------------------ | ---------------------------------- | --------------- - null | `null` | Null | 0xF6 - boolean | `true` | True | 0xF5 - boolean | `false` | False | 0xF4 - number_integer | -9223372036854775808..-2147483649 | Negative integer (8 bytes follow) | 0x3B - number_integer | -2147483648..-32769 | Negative integer (4 bytes follow) | 0x3A - number_integer | -32768..-129 | Negative integer (2 bytes follow) | 0x39 - number_integer | -128..-25 | Negative integer (1 byte follow) | 0x38 - number_integer | -24..-1 | Negative integer | 0x20..0x37 - number_integer | 0..23 | Integer | 0x00..0x17 - number_integer | 24..255 | Unsigned integer (1 byte follow) | 0x18 - number_integer | 256..65535 | Unsigned integer (2 bytes follow) | 0x19 - number_integer | 65536..4294967295 | Unsigned integer (4 bytes follow) | 0x1A - number_integer | 4294967296..18446744073709551615 | Unsigned integer (8 bytes follow) | 0x1B - number_unsigned | 0..23 | Integer | 0x00..0x17 - number_unsigned | 24..255 | Unsigned integer (1 byte follow) | 0x18 - number_unsigned | 256..65535 | Unsigned integer (2 bytes follow) | 0x19 - number_unsigned | 65536..4294967295 | Unsigned integer (4 bytes follow) | 0x1A - number_unsigned | 4294967296..18446744073709551615 | Unsigned integer (8 bytes follow) | 0x1B - number_float | *any value representable by a float* | Single-Precision Float | 0xFA - number_float | *any value NOT representable by a float* | Double-Precision Float | 0xFB - string | *length*: 0..23 | UTF-8 string | 0x60..0x77 - string | *length*: 23..255 | UTF-8 string (1 byte follow) | 0x78 - string | *length*: 256..65535 | UTF-8 string (2 bytes follow) | 0x79 - string | *length*: 65536..4294967295 | UTF-8 string (4 bytes follow) | 0x7A - string | *length*: 
4294967296..18446744073709551615 | UTF-8 string (8 bytes follow) | 0x7B - array | *size*: 0..23 | array | 0x80..0x97 - array | *size*: 23..255 | array (1 byte follow) | 0x98 - array | *size*: 256..65535 | array (2 bytes follow) | 0x99 - array | *size*: 65536..4294967295 | array (4 bytes follow) | 0x9A - array | *size*: 4294967296..18446744073709551615 | array (8 bytes follow) | 0x9B - object | *size*: 0..23 | map | 0xA0..0xB7 - object | *size*: 23..255 | map (1 byte follow) | 0xB8 - object | *size*: 256..65535 | map (2 bytes follow) | 0xB9 - object | *size*: 65536..4294967295 | map (4 bytes follow) | 0xBA - object | *size*: 4294967296..18446744073709551615 | map (8 bytes follow) | 0xBB - binary | *size*: 0..23 | byte string | 0x40..0x57 - binary | *size*: 23..255 | byte string (1 byte follow) | 0x58 - binary | *size*: 256..65535 | byte string (2 bytes follow) | 0x59 - binary | *size*: 65536..4294967295 | byte string (4 bytes follow) | 0x5A - binary | *size*: 4294967296..18446744073709551615 | byte string (8 bytes follow) | 0x5B - - @note The mapping is **complete** in the sense that any JSON value type - can be converted to a CBOR value. - - @note If NaN or Infinity are stored inside a JSON number, they are - serialized properly. This behavior differs from the @ref dump() - function which serializes NaN or Infinity to `null`. - - @note The following CBOR types are not used in the conversion: - - UTF-8 strings terminated by "break" (0x7F) - - arrays terminated by "break" (0x9F) - - maps terminated by "break" (0xBF) - - byte strings terminated by "break" (0x5F) - - date/time (0xC0..0xC1) - - bignum (0xC2..0xC3) - - decimal fraction (0xC4) - - bigfloat (0xC5) - - expected conversions (0xD5..0xD7) - - simple values (0xE0..0xF3, 0xF8) - - undefined (0xF7) - - half-precision floats (0xF9) - - break (0xFF) - - @param[in] j JSON value to serialize - @return CBOR serialization as byte vector - - @complexity Linear in the size of the JSON value @a j. 
- - @liveexample{The example shows the serialization of a JSON value to a byte - vector in CBOR format.,to_cbor} - - @sa http://cbor.io - @sa @ref from_cbor(detail::input_adapter&&, const bool, const bool, const cbor_tag_handler_t) for the - analogous deserialization - @sa @ref to_msgpack(const basic_json&) for the related MessagePack format - @sa @ref to_ubjson(const basic_json&, const bool, const bool) for the - related UBJSON format - - @since version 2.0.9; compact representation of floating-point numbers - since version 3.8.0 - */ - static std::vector to_cbor(const basic_json& j) - { - std::vector result; - to_cbor(j, result); - return result; - } - - static void to_cbor(const basic_json& j, detail::output_adapter o) - { - binary_writer(o).write_cbor(j); - } - - static void to_cbor(const basic_json& j, detail::output_adapter o) - { - binary_writer(o).write_cbor(j); - } - - /*! - @brief create a MessagePack serialization of a given JSON value - - Serializes a given JSON value @a j to a byte vector using the MessagePack - serialization format. MessagePack is a binary serialization format which - aims to be more compact than JSON itself, yet more efficient to parse. 
- - The library uses the following mapping from JSON values types to - MessagePack types according to the MessagePack specification: - - JSON value type | value/range | MessagePack type | first byte - --------------- | --------------------------------- | ---------------- | ---------- - null | `null` | nil | 0xC0 - boolean | `true` | true | 0xC3 - boolean | `false` | false | 0xC2 - number_integer | -9223372036854775808..-2147483649 | int64 | 0xD3 - number_integer | -2147483648..-32769 | int32 | 0xD2 - number_integer | -32768..-129 | int16 | 0xD1 - number_integer | -128..-33 | int8 | 0xD0 - number_integer | -32..-1 | negative fixint | 0xE0..0xFF - number_integer | 0..127 | positive fixint | 0x00..0x7F - number_integer | 128..255 | uint 8 | 0xCC - number_integer | 256..65535 | uint 16 | 0xCD - number_integer | 65536..4294967295 | uint 32 | 0xCE - number_integer | 4294967296..18446744073709551615 | uint 64 | 0xCF - number_unsigned | 0..127 | positive fixint | 0x00..0x7F - number_unsigned | 128..255 | uint 8 | 0xCC - number_unsigned | 256..65535 | uint 16 | 0xCD - number_unsigned | 65536..4294967295 | uint 32 | 0xCE - number_unsigned | 4294967296..18446744073709551615 | uint 64 | 0xCF - number_float | *any value representable by a float* | float 32 | 0xCA - number_float | *any value NOT representable by a float* | float 64 | 0xCB - string | *length*: 0..31 | fixstr | 0xA0..0xBF - string | *length*: 32..255 | str 8 | 0xD9 - string | *length*: 256..65535 | str 16 | 0xDA - string | *length*: 65536..4294967295 | str 32 | 0xDB - array | *size*: 0..15 | fixarray | 0x90..0x9F - array | *size*: 16..65535 | array 16 | 0xDC - array | *size*: 65536..4294967295 | array 32 | 0xDD - object | *size*: 0..15 | fix map | 0x80..0x8F - object | *size*: 16..65535 | map 16 | 0xDE - object | *size*: 65536..4294967295 | map 32 | 0xDF - binary | *size*: 0..255 | bin 8 | 0xC4 - binary | *size*: 256..65535 | bin 16 | 0xC5 - binary | *size*: 65536..4294967295 | bin 32 | 0xC6 - - @note The mapping 
is **complete** in the sense that any JSON value type - can be converted to a MessagePack value. - - @note The following values can **not** be converted to a MessagePack value: - - strings with more than 4294967295 bytes - - byte strings with more than 4294967295 bytes - - arrays with more than 4294967295 elements - - objects with more than 4294967295 elements - - @note Any MessagePack output created @ref to_msgpack can be successfully - parsed by @ref from_msgpack. - - @note If NaN or Infinity are stored inside a JSON number, they are - serialized properly. This behavior differs from the @ref dump() - function which serializes NaN or Infinity to `null`. - - @param[in] j JSON value to serialize - @return MessagePack serialization as byte vector - - @complexity Linear in the size of the JSON value @a j. - - @liveexample{The example shows the serialization of a JSON value to a byte - vector in MessagePack format.,to_msgpack} - - @sa http://msgpack.org - @sa @ref from_msgpack for the analogous deserialization - @sa @ref to_cbor(const basic_json& for the related CBOR format - @sa @ref to_ubjson(const basic_json&, const bool, const bool) for the - related UBJSON format - - @since version 2.0.9 - */ - static std::vector to_msgpack(const basic_json& j) - { - std::vector result; - to_msgpack(j, result); - return result; - } - - static void to_msgpack(const basic_json& j, detail::output_adapter o) - { - binary_writer(o).write_msgpack(j); - } - - static void to_msgpack(const basic_json& j, detail::output_adapter o) - { - binary_writer(o).write_msgpack(j); - } - - /*! - @brief create a UBJSON serialization of a given JSON value - - Serializes a given JSON value @a j to a byte vector using the UBJSON - (Universal Binary JSON) serialization format. UBJSON aims to be more compact - than JSON itself, yet more efficient to parse. 
- - The library uses the following mapping from JSON values types to - UBJSON types according to the UBJSON specification: - - JSON value type | value/range | UBJSON type | marker - --------------- | --------------------------------- | ----------- | ------ - null | `null` | null | `Z` - boolean | `true` | true | `T` - boolean | `false` | false | `F` - number_integer | -9223372036854775808..-2147483649 | int64 | `L` - number_integer | -2147483648..-32769 | int32 | `l` - number_integer | -32768..-129 | int16 | `I` - number_integer | -128..127 | int8 | `i` - number_integer | 128..255 | uint8 | `U` - number_integer | 256..32767 | int16 | `I` - number_integer | 32768..2147483647 | int32 | `l` - number_integer | 2147483648..9223372036854775807 | int64 | `L` - number_unsigned | 0..127 | int8 | `i` - number_unsigned | 128..255 | uint8 | `U` - number_unsigned | 256..32767 | int16 | `I` - number_unsigned | 32768..2147483647 | int32 | `l` - number_unsigned | 2147483648..9223372036854775807 | int64 | `L` - number_unsigned | 2147483649..18446744073709551615 | high-precision | `H` - number_float | *any value* | float64 | `D` - string | *with shortest length indicator* | string | `S` - array | *see notes on optimized format* | array | `[` - object | *see notes on optimized format* | map | `{` - - @note The mapping is **complete** in the sense that any JSON value type - can be converted to a UBJSON value. - - @note The following values can **not** be converted to a UBJSON value: - - strings with more than 9223372036854775807 bytes (theoretical) - - @note The following markers are not used in the conversion: - - `Z`: no-op values are not created. - - `C`: single-byte strings are serialized with `S` markers. - - @note Any UBJSON output created @ref to_ubjson can be successfully parsed - by @ref from_ubjson. - - @note If NaN or Infinity are stored inside a JSON number, they are - serialized properly. 
This behavior differs from the @ref dump() - function which serializes NaN or Infinity to `null`. - - @note The optimized formats for containers are supported: Parameter - @a use_size adds size information to the beginning of a container and - removes the closing marker. Parameter @a use_type further checks - whether all elements of a container have the same type and adds the - type marker to the beginning of the container. The @a use_type - parameter must only be used together with @a use_size = true. Note - that @a use_size = true alone may result in larger representations - - the benefit of this parameter is that the receiving side is - immediately informed on the number of elements of the container. - - @note If the JSON data contains the binary type, the value stored is a list - of integers, as suggested by the UBJSON documentation. In particular, - this means that serialization and the deserialization of a JSON - containing binary values into UBJSON and back will result in a - different JSON object. - - @param[in] j JSON value to serialize - @param[in] use_size whether to add size annotations to container types - @param[in] use_type whether to add type annotations to container types - (must be combined with @a use_size = true) - @return UBJSON serialization as byte vector - - @complexity Linear in the size of the JSON value @a j. 
- - @liveexample{The example shows the serialization of a JSON value to a byte - vector in UBJSON format.,to_ubjson} - - @sa http://ubjson.org - @sa @ref from_ubjson(detail::input_adapter&&, const bool, const bool) for the - analogous deserialization - @sa @ref to_cbor(const basic_json& for the related CBOR format - @sa @ref to_msgpack(const basic_json&) for the related MessagePack format - - @since version 3.1.0 - */ - static std::vector to_ubjson(const basic_json& j, - const bool use_size = false, - const bool use_type = false) - { - std::vector result; - to_ubjson(j, result, use_size, use_type); - return result; - } - - static void to_ubjson(const basic_json& j, detail::output_adapter o, - const bool use_size = false, const bool use_type = false) - { - binary_writer(o).write_ubjson(j, use_size, use_type); - } - - static void to_ubjson(const basic_json& j, detail::output_adapter o, - const bool use_size = false, const bool use_type = false) - { - binary_writer(o).write_ubjson(j, use_size, use_type); - } - - - /*! - @brief Serializes the given JSON object `j` to BSON and returns a vector - containing the corresponding BSON-representation. - - BSON (Binary JSON) is a binary format in which zero or more ordered key/value pairs are - stored as a single entity (a so-called document). 
- - The library uses the following mapping from JSON values types to BSON types: - - JSON value type | value/range | BSON type | marker - --------------- | --------------------------------- | ----------- | ------ - null | `null` | null | 0x0A - boolean | `true`, `false` | boolean | 0x08 - number_integer | -9223372036854775808..-2147483649 | int64 | 0x12 - number_integer | -2147483648..2147483647 | int32 | 0x10 - number_integer | 2147483648..9223372036854775807 | int64 | 0x12 - number_unsigned | 0..2147483647 | int32 | 0x10 - number_unsigned | 2147483648..9223372036854775807 | int64 | 0x12 - number_unsigned | 9223372036854775808..18446744073709551615| -- | -- - number_float | *any value* | double | 0x01 - string | *any value* | string | 0x02 - array | *any value* | document | 0x04 - object | *any value* | document | 0x03 - binary | *any value* | binary | 0x05 - - @warning The mapping is **incomplete**, since only JSON-objects (and things - contained therein) can be serialized to BSON. - Also, integers larger than 9223372036854775807 cannot be serialized to BSON, - and the keys may not contain U+0000, since they are serialized a - zero-terminated c-strings. - - @throw out_of_range.407 if `j.is_number_unsigned() && j.get() > 9223372036854775807` - @throw out_of_range.409 if a key in `j` contains a NULL (U+0000) - @throw type_error.317 if `!j.is_object()` - - @pre The input `j` is required to be an object: `j.is_object() == true`. - - @note Any BSON output created via @ref to_bson can be successfully parsed - by @ref from_bson. - - @param[in] j JSON value to serialize - @return BSON serialization as byte vector - - @complexity Linear in the size of the JSON value @a j. 
- - @liveexample{The example shows the serialization of a JSON value to a byte - vector in BSON format.,to_bson} - - @sa http://bsonspec.org/spec.html - @sa @ref from_bson(detail::input_adapter&&, const bool strict) for the - analogous deserialization - @sa @ref to_ubjson(const basic_json&, const bool, const bool) for the - related UBJSON format - @sa @ref to_cbor(const basic_json&) for the related CBOR format - @sa @ref to_msgpack(const basic_json&) for the related MessagePack format - */ - static std::vector to_bson(const basic_json& j) - { - std::vector result; - to_bson(j, result); - return result; - } - - /*! - @brief Serializes the given JSON object `j` to BSON and forwards the - corresponding BSON-representation to the given output_adapter `o`. - @param j The JSON object to convert to BSON. - @param o The output adapter that receives the binary BSON representation. - @pre The input `j` shall be an object: `j.is_object() == true` - @sa @ref to_bson(const basic_json&) - */ - static void to_bson(const basic_json& j, detail::output_adapter o) - { - binary_writer(o).write_bson(j); - } - - /*! - @copydoc to_bson(const basic_json&, detail::output_adapter) - */ - static void to_bson(const basic_json& j, detail::output_adapter o) - { - binary_writer(o).write_bson(j); - } - - - /*! - @brief create a JSON value from an input in CBOR format - - Deserializes a given input @a i to a JSON value using the CBOR (Concise - Binary Object Representation) serialization format. 
- - The library maps CBOR types to JSON value types as follows: - - CBOR type | JSON value type | first byte - ---------------------- | --------------- | ---------- - Integer | number_unsigned | 0x00..0x17 - Unsigned integer | number_unsigned | 0x18 - Unsigned integer | number_unsigned | 0x19 - Unsigned integer | number_unsigned | 0x1A - Unsigned integer | number_unsigned | 0x1B - Negative integer | number_integer | 0x20..0x37 - Negative integer | number_integer | 0x38 - Negative integer | number_integer | 0x39 - Negative integer | number_integer | 0x3A - Negative integer | number_integer | 0x3B - Byte string | binary | 0x40..0x57 - Byte string | binary | 0x58 - Byte string | binary | 0x59 - Byte string | binary | 0x5A - Byte string | binary | 0x5B - UTF-8 string | string | 0x60..0x77 - UTF-8 string | string | 0x78 - UTF-8 string | string | 0x79 - UTF-8 string | string | 0x7A - UTF-8 string | string | 0x7B - UTF-8 string | string | 0x7F - array | array | 0x80..0x97 - array | array | 0x98 - array | array | 0x99 - array | array | 0x9A - array | array | 0x9B - array | array | 0x9F - map | object | 0xA0..0xB7 - map | object | 0xB8 - map | object | 0xB9 - map | object | 0xBA - map | object | 0xBB - map | object | 0xBF - False | `false` | 0xF4 - True | `true` | 0xF5 - Null | `null` | 0xF6 - Half-Precision Float | number_float | 0xF9 - Single-Precision Float | number_float | 0xFA - Double-Precision Float | number_float | 0xFB - - @warning The mapping is **incomplete** in the sense that not all CBOR - types can be converted to a JSON value. The following CBOR types - are not supported and will yield parse errors (parse_error.112): - - date/time (0xC0..0xC1) - - bignum (0xC2..0xC3) - - decimal fraction (0xC4) - - bigfloat (0xC5) - - expected conversions (0xD5..0xD7) - - simple values (0xE0..0xF3, 0xF8) - - undefined (0xF7) - - @warning CBOR allows map keys of any type, whereas JSON only allows - strings as keys in object values. 
Therefore, CBOR maps with keys - other than UTF-8 strings are rejected (parse_error.113). - - @note Any CBOR output created @ref to_cbor can be successfully parsed by - @ref from_cbor. - - @param[in] i an input in CBOR format convertible to an input adapter - @param[in] strict whether to expect the input to be consumed until EOF - (true by default) - @param[in] allow_exceptions whether to throw exceptions in case of a - parse error (optional, true by default) - @param[in] tag_handler how to treat CBOR tags (optional, error by default) - - @return deserialized JSON value; in case of a parse error and - @a allow_exceptions set to `false`, the return value will be - value_t::discarded. - - @throw parse_error.110 if the given input ends prematurely or the end of - file was not reached when @a strict was set to true - @throw parse_error.112 if unsupported features from CBOR were - used in the given input @a v or if the input is not valid CBOR - @throw parse_error.113 if a string was expected as map key, but not found - - @complexity Linear in the size of the input @a i. - - @liveexample{The example shows the deserialization of a byte vector in CBOR - format to a JSON value.,from_cbor} - - @sa http://cbor.io - @sa @ref to_cbor(const basic_json&) for the analogous serialization - @sa @ref from_msgpack(detail::input_adapter&&, const bool, const bool) for the - related MessagePack format - @sa @ref from_ubjson(detail::input_adapter&&, const bool, const bool) for the - related UBJSON format - - @since version 2.0.9; parameter @a start_index since 2.1.1; changed to - consume input adapters, removed start_index parameter, and added - @a strict parameter since 3.0.0; added @a allow_exceptions parameter - since 3.2.0; added @a tag_handler parameter since 3.9.0. 
- */ - template - JSON_HEDLEY_WARN_UNUSED_RESULT - static basic_json from_cbor(InputType&& i, - const bool strict = true, - const bool allow_exceptions = true, - const cbor_tag_handler_t tag_handler = cbor_tag_handler_t::error) - { - basic_json result; - detail::json_sax_dom_parser sdp(result, allow_exceptions); - auto ia = detail::input_adapter(std::forward(i)); - const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::cbor, &sdp, strict, tag_handler); - return res ? result : basic_json(value_t::discarded); - } - - /*! - @copydoc from_cbor(detail::input_adapter&&, const bool, const bool, const cbor_tag_handler_t) - */ - template - JSON_HEDLEY_WARN_UNUSED_RESULT - static basic_json from_cbor(IteratorType first, IteratorType last, - const bool strict = true, - const bool allow_exceptions = true, - const cbor_tag_handler_t tag_handler = cbor_tag_handler_t::error) - { - basic_json result; - detail::json_sax_dom_parser sdp(result, allow_exceptions); - auto ia = detail::input_adapter(std::move(first), std::move(last)); - const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::cbor, &sdp, strict, tag_handler); - return res ? 
result : basic_json(value_t::discarded); - } - - template - JSON_HEDLEY_WARN_UNUSED_RESULT - JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_cbor(ptr, ptr + len)) - static basic_json from_cbor(const T* ptr, std::size_t len, - const bool strict = true, - const bool allow_exceptions = true, - const cbor_tag_handler_t tag_handler = cbor_tag_handler_t::error) - { - return from_cbor(ptr, ptr + len, strict, allow_exceptions, tag_handler); - } - - - JSON_HEDLEY_WARN_UNUSED_RESULT - JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_cbor(ptr, ptr + len)) - static basic_json from_cbor(detail::span_input_adapter&& i, - const bool strict = true, - const bool allow_exceptions = true, - const cbor_tag_handler_t tag_handler = cbor_tag_handler_t::error) - { - basic_json result; - detail::json_sax_dom_parser sdp(result, allow_exceptions); - auto ia = i.get(); - const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::cbor, &sdp, strict, tag_handler); - return res ? result : basic_json(value_t::discarded); - } - - /*! - @brief create a JSON value from an input in MessagePack format - - Deserializes a given input @a i to a JSON value using the MessagePack - serialization format. 
- - The library maps MessagePack types to JSON value types as follows: - - MessagePack type | JSON value type | first byte - ---------------- | --------------- | ---------- - positive fixint | number_unsigned | 0x00..0x7F - fixmap | object | 0x80..0x8F - fixarray | array | 0x90..0x9F - fixstr | string | 0xA0..0xBF - nil | `null` | 0xC0 - false | `false` | 0xC2 - true | `true` | 0xC3 - float 32 | number_float | 0xCA - float 64 | number_float | 0xCB - uint 8 | number_unsigned | 0xCC - uint 16 | number_unsigned | 0xCD - uint 32 | number_unsigned | 0xCE - uint 64 | number_unsigned | 0xCF - int 8 | number_integer | 0xD0 - int 16 | number_integer | 0xD1 - int 32 | number_integer | 0xD2 - int 64 | number_integer | 0xD3 - str 8 | string | 0xD9 - str 16 | string | 0xDA - str 32 | string | 0xDB - array 16 | array | 0xDC - array 32 | array | 0xDD - map 16 | object | 0xDE - map 32 | object | 0xDF - bin 8 | binary | 0xC4 - bin 16 | binary | 0xC5 - bin 32 | binary | 0xC6 - ext 8 | binary | 0xC7 - ext 16 | binary | 0xC8 - ext 32 | binary | 0xC9 - fixext 1 | binary | 0xD4 - fixext 2 | binary | 0xD5 - fixext 4 | binary | 0xD6 - fixext 8 | binary | 0xD7 - fixext 16 | binary | 0xD8 - negative fixint | number_integer | 0xE0-0xFF - - @note Any MessagePack output created @ref to_msgpack can be successfully - parsed by @ref from_msgpack. - - @param[in] i an input in MessagePack format convertible to an input - adapter - @param[in] strict whether to expect the input to be consumed until EOF - (true by default) - @param[in] allow_exceptions whether to throw exceptions in case of a - parse error (optional, true by default) - - @return deserialized JSON value; in case of a parse error and - @a allow_exceptions set to `false`, the return value will be - value_t::discarded. 
- - @throw parse_error.110 if the given input ends prematurely or the end of - file was not reached when @a strict was set to true - @throw parse_error.112 if unsupported features from MessagePack were - used in the given input @a i or if the input is not valid MessagePack - @throw parse_error.113 if a string was expected as map key, but not found - - @complexity Linear in the size of the input @a i. - - @liveexample{The example shows the deserialization of a byte vector in - MessagePack format to a JSON value.,from_msgpack} - - @sa http://msgpack.org - @sa @ref to_msgpack(const basic_json&) for the analogous serialization - @sa @ref from_cbor(detail::input_adapter&&, const bool, const bool, const cbor_tag_handler_t) for the - related CBOR format - @sa @ref from_ubjson(detail::input_adapter&&, const bool, const bool) for - the related UBJSON format - @sa @ref from_bson(detail::input_adapter&&, const bool, const bool) for - the related BSON format - - @since version 2.0.9; parameter @a start_index since 2.1.1; changed to - consume input adapters, removed start_index parameter, and added - @a strict parameter since 3.0.0; added @a allow_exceptions parameter - since 3.2.0 - */ - template - JSON_HEDLEY_WARN_UNUSED_RESULT - static basic_json from_msgpack(InputType&& i, - const bool strict = true, - const bool allow_exceptions = true) - { - basic_json result; - detail::json_sax_dom_parser sdp(result, allow_exceptions); - auto ia = detail::input_adapter(std::forward(i)); - const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::msgpack, &sdp, strict); - return res ? result : basic_json(value_t::discarded); - } - - /*! 
- @copydoc from_msgpack(detail::input_adapter&&, const bool, const bool) - */ - template - JSON_HEDLEY_WARN_UNUSED_RESULT - static basic_json from_msgpack(IteratorType first, IteratorType last, - const bool strict = true, - const bool allow_exceptions = true) - { - basic_json result; - detail::json_sax_dom_parser sdp(result, allow_exceptions); - auto ia = detail::input_adapter(std::move(first), std::move(last)); - const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::msgpack, &sdp, strict); - return res ? result : basic_json(value_t::discarded); - } - - - template - JSON_HEDLEY_WARN_UNUSED_RESULT - JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_msgpack(ptr, ptr + len)) - static basic_json from_msgpack(const T* ptr, std::size_t len, - const bool strict = true, - const bool allow_exceptions = true) - { - return from_msgpack(ptr, ptr + len, strict, allow_exceptions); - } - - JSON_HEDLEY_WARN_UNUSED_RESULT - JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_msgpack(ptr, ptr + len)) - static basic_json from_msgpack(detail::span_input_adapter&& i, - const bool strict = true, - const bool allow_exceptions = true) - { - basic_json result; - detail::json_sax_dom_parser sdp(result, allow_exceptions); - auto ia = i.get(); - const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::msgpack, &sdp, strict); - return res ? result : basic_json(value_t::discarded); - } - - - /*! - @brief create a JSON value from an input in UBJSON format - - Deserializes a given input @a i to a JSON value using the UBJSON (Universal - Binary JSON) serialization format. 
- - The library maps UBJSON types to JSON value types as follows: - - UBJSON type | JSON value type | marker - ----------- | --------------------------------------- | ------ - no-op | *no value, next value is read* | `N` - null | `null` | `Z` - false | `false` | `F` - true | `true` | `T` - float32 | number_float | `d` - float64 | number_float | `D` - uint8 | number_unsigned | `U` - int8 | number_integer | `i` - int16 | number_integer | `I` - int32 | number_integer | `l` - int64 | number_integer | `L` - high-precision number | number_integer, number_unsigned, or number_float - depends on number string | 'H' - string | string | `S` - char | string | `C` - array | array (optimized values are supported) | `[` - object | object (optimized values are supported) | `{` - - @note The mapping is **complete** in the sense that any UBJSON value can - be converted to a JSON value. - - @param[in] i an input in UBJSON format convertible to an input adapter - @param[in] strict whether to expect the input to be consumed until EOF - (true by default) - @param[in] allow_exceptions whether to throw exceptions in case of a - parse error (optional, true by default) - - @return deserialized JSON value; in case of a parse error and - @a allow_exceptions set to `false`, the return value will be - value_t::discarded. - - @throw parse_error.110 if the given input ends prematurely or the end of - file was not reached when @a strict was set to true - @throw parse_error.112 if a parse error occurs - @throw parse_error.113 if a string could not be parsed successfully - - @complexity Linear in the size of the input @a i. 
- - @liveexample{The example shows the deserialization of a byte vector in - UBJSON format to a JSON value.,from_ubjson} - - @sa http://ubjson.org - @sa @ref to_ubjson(const basic_json&, const bool, const bool) for the - analogous serialization - @sa @ref from_cbor(detail::input_adapter&&, const bool, const bool, const cbor_tag_handler_t) for the - related CBOR format - @sa @ref from_msgpack(detail::input_adapter&&, const bool, const bool) for - the related MessagePack format - @sa @ref from_bson(detail::input_adapter&&, const bool, const bool) for - the related BSON format - - @since version 3.1.0; added @a allow_exceptions parameter since 3.2.0 - */ - template - JSON_HEDLEY_WARN_UNUSED_RESULT - static basic_json from_ubjson(InputType&& i, - const bool strict = true, - const bool allow_exceptions = true) - { - basic_json result; - detail::json_sax_dom_parser sdp(result, allow_exceptions); - auto ia = detail::input_adapter(std::forward(i)); - const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::ubjson, &sdp, strict); - return res ? result : basic_json(value_t::discarded); - } - - /*! - @copydoc from_ubjson(detail::input_adapter&&, const bool, const bool) - */ - template - JSON_HEDLEY_WARN_UNUSED_RESULT - static basic_json from_ubjson(IteratorType first, IteratorType last, - const bool strict = true, - const bool allow_exceptions = true) - { - basic_json result; - detail::json_sax_dom_parser sdp(result, allow_exceptions); - auto ia = detail::input_adapter(std::move(first), std::move(last)); - const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::ubjson, &sdp, strict); - return res ? 
result : basic_json(value_t::discarded); - } - - template - JSON_HEDLEY_WARN_UNUSED_RESULT - JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_ubjson(ptr, ptr + len)) - static basic_json from_ubjson(const T* ptr, std::size_t len, - const bool strict = true, - const bool allow_exceptions = true) - { - return from_ubjson(ptr, ptr + len, strict, allow_exceptions); - } - - JSON_HEDLEY_WARN_UNUSED_RESULT - JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_ubjson(ptr, ptr + len)) - static basic_json from_ubjson(detail::span_input_adapter&& i, - const bool strict = true, - const bool allow_exceptions = true) - { - basic_json result; - detail::json_sax_dom_parser sdp(result, allow_exceptions); - auto ia = i.get(); - const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::ubjson, &sdp, strict); - return res ? result : basic_json(value_t::discarded); - } - - - /*! - @brief Create a JSON value from an input in BSON format - - Deserializes a given input @a i to a JSON value using the BSON (Binary JSON) - serialization format. - - The library maps BSON record types to JSON value types as follows: - - BSON type | BSON marker byte | JSON value type - --------------- | ---------------- | --------------------------- - double | 0x01 | number_float - string | 0x02 | string - document | 0x03 | object - array | 0x04 | array - binary | 0x05 | still unsupported - undefined | 0x06 | still unsupported - ObjectId | 0x07 | still unsupported - boolean | 0x08 | boolean - UTC Date-Time | 0x09 | still unsupported - null | 0x0A | null - Regular Expr. | 0x0B | still unsupported - DB Pointer | 0x0C | still unsupported - JavaScript Code | 0x0D | still unsupported - Symbol | 0x0E | still unsupported - JavaScript Code | 0x0F | still unsupported - int32 | 0x10 | number_integer - Timestamp | 0x11 | still unsupported - 128-bit decimal float | 0x13 | still unsupported - Max Key | 0x7F | still unsupported - Min Key | 0xFF | still unsupported - - @warning The mapping is **incomplete**. 
The unsupported mappings - are indicated in the table above. - - @param[in] i an input in BSON format convertible to an input adapter - @param[in] strict whether to expect the input to be consumed until EOF - (true by default) - @param[in] allow_exceptions whether to throw exceptions in case of a - parse error (optional, true by default) - - @return deserialized JSON value; in case of a parse error and - @a allow_exceptions set to `false`, the return value will be - value_t::discarded. - - @throw parse_error.114 if an unsupported BSON record type is encountered - - @complexity Linear in the size of the input @a i. - - @liveexample{The example shows the deserialization of a byte vector in - BSON format to a JSON value.,from_bson} - - @sa http://bsonspec.org/spec.html - @sa @ref to_bson(const basic_json&) for the analogous serialization - @sa @ref from_cbor(detail::input_adapter&&, const bool, const bool, const cbor_tag_handler_t) for the - related CBOR format - @sa @ref from_msgpack(detail::input_adapter&&, const bool, const bool) for - the related MessagePack format - @sa @ref from_ubjson(detail::input_adapter&&, const bool, const bool) for the - related UBJSON format - */ - template - JSON_HEDLEY_WARN_UNUSED_RESULT - static basic_json from_bson(InputType&& i, - const bool strict = true, - const bool allow_exceptions = true) - { - basic_json result; - detail::json_sax_dom_parser sdp(result, allow_exceptions); - auto ia = detail::input_adapter(std::forward(i)); - const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::bson, &sdp, strict); - return res ? result : basic_json(value_t::discarded); - } - - /*! 
- @copydoc from_bson(detail::input_adapter&&, const bool, const bool) - */ - template - JSON_HEDLEY_WARN_UNUSED_RESULT - static basic_json from_bson(IteratorType first, IteratorType last, - const bool strict = true, - const bool allow_exceptions = true) - { - basic_json result; - detail::json_sax_dom_parser sdp(result, allow_exceptions); - auto ia = detail::input_adapter(std::move(first), std::move(last)); - const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::bson, &sdp, strict); - return res ? result : basic_json(value_t::discarded); - } - - template - JSON_HEDLEY_WARN_UNUSED_RESULT - JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_bson(ptr, ptr + len)) - static basic_json from_bson(const T* ptr, std::size_t len, - const bool strict = true, - const bool allow_exceptions = true) - { - return from_bson(ptr, ptr + len, strict, allow_exceptions); - } - - JSON_HEDLEY_WARN_UNUSED_RESULT - JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_bson(ptr, ptr + len)) - static basic_json from_bson(detail::span_input_adapter&& i, - const bool strict = true, - const bool allow_exceptions = true) - { - basic_json result; - detail::json_sax_dom_parser sdp(result, allow_exceptions); - auto ia = i.get(); - const bool res = binary_reader(std::move(ia)).sax_parse(input_format_t::bson, &sdp, strict); - return res ? result : basic_json(value_t::discarded); - } - /// @} - - ////////////////////////// - // JSON Pointer support // - ////////////////////////// - - /// @name JSON Pointer functions - /// @{ - - /*! - @brief access specified element via JSON Pointer - - Uses a JSON pointer to retrieve a reference to the respective JSON value. - No bound checking is performed. Similar to @ref operator[](const typename - object_t::key_type&), `null` values are created in arrays and objects if - necessary. - - In particular: - - If the JSON pointer points to an object key that does not exist, it - is created an filled with a `null` value before a reference to it - is returned. 
- - If the JSON pointer points to an array index that does not exist, it - is created an filled with a `null` value before a reference to it - is returned. All indices between the current maximum and the given - index are also filled with `null`. - - The special value `-` is treated as a synonym for the index past the - end. - - @param[in] ptr a JSON pointer - - @return reference to the element pointed to by @a ptr - - @complexity Constant. - - @throw parse_error.106 if an array index begins with '0' - @throw parse_error.109 if an array index was not a number - @throw out_of_range.404 if the JSON pointer can not be resolved - - @liveexample{The behavior is shown in the example.,operatorjson_pointer} - - @since version 2.0.0 - */ - reference operator[](const json_pointer& ptr) - { - return ptr.get_unchecked(this); - } - - /*! - @brief access specified element via JSON Pointer - - Uses a JSON pointer to retrieve a reference to the respective JSON value. - No bound checking is performed. The function does not change the JSON - value; no `null` values are created. In particular, the special value - `-` yields an exception. - - @param[in] ptr JSON pointer to the desired element - - @return const reference to the element pointed to by @a ptr - - @complexity Constant. - - @throw parse_error.106 if an array index begins with '0' - @throw parse_error.109 if an array index was not a number - @throw out_of_range.402 if the array index '-' is used - @throw out_of_range.404 if the JSON pointer can not be resolved - - @liveexample{The behavior is shown in the example.,operatorjson_pointer_const} - - @since version 2.0.0 - */ - const_reference operator[](const json_pointer& ptr) const - { - return ptr.get_unchecked(this); - } - - /*! - @brief access specified element via JSON Pointer - - Returns a reference to the element at with specified JSON pointer @a ptr, - with bounds checking. 
- - @param[in] ptr JSON pointer to the desired element - - @return reference to the element pointed to by @a ptr - - @throw parse_error.106 if an array index in the passed JSON pointer @a ptr - begins with '0'. See example below. - - @throw parse_error.109 if an array index in the passed JSON pointer @a ptr - is not a number. See example below. - - @throw out_of_range.401 if an array index in the passed JSON pointer @a ptr - is out of range. See example below. - - @throw out_of_range.402 if the array index '-' is used in the passed JSON - pointer @a ptr. As `at` provides checked access (and no elements are - implicitly inserted), the index '-' is always invalid. See example below. - - @throw out_of_range.403 if the JSON pointer describes a key of an object - which cannot be found. See example below. - - @throw out_of_range.404 if the JSON pointer @a ptr can not be resolved. - See example below. - - @exceptionsafety Strong guarantee: if an exception is thrown, there are no - changes in the JSON value. - - @complexity Constant. - - @since version 2.0.0 - - @liveexample{The behavior is shown in the example.,at_json_pointer} - */ - reference at(const json_pointer& ptr) - { - return ptr.get_checked(this); - } - - /*! - @brief access specified element via JSON Pointer - - Returns a const reference to the element at with specified JSON pointer @a - ptr, with bounds checking. - - @param[in] ptr JSON pointer to the desired element - - @return reference to the element pointed to by @a ptr - - @throw parse_error.106 if an array index in the passed JSON pointer @a ptr - begins with '0'. See example below. - - @throw parse_error.109 if an array index in the passed JSON pointer @a ptr - is not a number. See example below. - - @throw out_of_range.401 if an array index in the passed JSON pointer @a ptr - is out of range. See example below. - - @throw out_of_range.402 if the array index '-' is used in the passed JSON - pointer @a ptr. 
As `at` provides checked access (and no elements are - implicitly inserted), the index '-' is always invalid. See example below. - - @throw out_of_range.403 if the JSON pointer describes a key of an object - which cannot be found. See example below. - - @throw out_of_range.404 if the JSON pointer @a ptr can not be resolved. - See example below. - - @exceptionsafety Strong guarantee: if an exception is thrown, there are no - changes in the JSON value. - - @complexity Constant. - - @since version 2.0.0 - - @liveexample{The behavior is shown in the example.,at_json_pointer_const} - */ - const_reference at(const json_pointer& ptr) const - { - return ptr.get_checked(this); - } - - /*! - @brief return flattened JSON value - - The function creates a JSON object whose keys are JSON pointers (see [RFC - 6901](https://tools.ietf.org/html/rfc6901)) and whose values are all - primitive. The original JSON value can be restored using the @ref - unflatten() function. - - @return an object that maps JSON pointers to primitive values - - @note Empty objects and arrays are flattened to `null` and will not be - reconstructed correctly by the @ref unflatten() function. - - @complexity Linear in the size the JSON value. - - @liveexample{The following code shows how a JSON object is flattened to an - object whose keys consist of JSON pointers.,flatten} - - @sa @ref unflatten() for the reverse function - - @since version 2.0.0 - */ - basic_json flatten() const - { - basic_json result(value_t::object); - json_pointer::flatten("", *this, result); - return result; - } - - /*! - @brief unflatten a previously flattened JSON value - - The function restores the arbitrary nesting of a JSON value that has been - flattened before using the @ref flatten() function. The JSON value must - meet certain constraints: - 1. The value must be an object. - 2. The keys must be JSON pointers (see - [RFC 6901](https://tools.ietf.org/html/rfc6901)) - 3. The mapped values must be primitive JSON types. 
- - @return the original JSON from a flattened version - - @note Empty objects and arrays are flattened by @ref flatten() to `null` - values and can not unflattened to their original type. Apart from - this example, for a JSON value `j`, the following is always true: - `j == j.flatten().unflatten()`. - - @complexity Linear in the size the JSON value. - - @throw type_error.314 if value is not an object - @throw type_error.315 if object values are not primitive - - @liveexample{The following code shows how a flattened JSON object is - unflattened into the original nested JSON object.,unflatten} - - @sa @ref flatten() for the reverse function - - @since version 2.0.0 - */ - basic_json unflatten() const - { - return json_pointer::unflatten(*this); - } - - /// @} - - ////////////////////////// - // JSON Patch functions // - ////////////////////////// - - /// @name JSON Patch functions - /// @{ - - /*! - @brief applies a JSON patch - - [JSON Patch](http://jsonpatch.com) defines a JSON document structure for - expressing a sequence of operations to apply to a JSON) document. With - this function, a JSON Patch is applied to the current JSON value by - executing all operations from the patch. - - @param[in] json_patch JSON patch document - @return patched document - - @note The application of a patch is atomic: Either all operations succeed - and the patched document is returned or an exception is thrown. In - any case, the original value is not changed: the patch is applied - to a copy of the value. - - @throw parse_error.104 if the JSON patch does not consist of an array of - objects - - @throw parse_error.105 if the JSON patch is malformed (e.g., mandatory - attributes are missing); example: `"operation add must have member path"` - - @throw out_of_range.401 if an array index is out of range. 
- - @throw out_of_range.403 if a JSON pointer inside the patch could not be - resolved successfully in the current JSON value; example: `"key baz not - found"` - - @throw out_of_range.405 if JSON pointer has no parent ("add", "remove", - "move") - - @throw other_error.501 if "test" operation was unsuccessful - - @complexity Linear in the size of the JSON value and the length of the - JSON patch. As usually only a fraction of the JSON value is affected by - the patch, the complexity can usually be neglected. - - @liveexample{The following code shows how a JSON patch is applied to a - value.,patch} - - @sa @ref diff -- create a JSON patch by comparing two JSON values - - @sa [RFC 6902 (JSON Patch)](https://tools.ietf.org/html/rfc6902) - @sa [RFC 6901 (JSON Pointer)](https://tools.ietf.org/html/rfc6901) - - @since version 2.0.0 - */ - basic_json patch(const basic_json& json_patch) const - { - // make a working copy to apply the patch to - basic_json result = *this; - - // the valid JSON Patch operations - enum class patch_operations {add, remove, replace, move, copy, test, invalid}; - - const auto get_op = [](const std::string & op) - { - if (op == "add") - { - return patch_operations::add; - } - if (op == "remove") - { - return patch_operations::remove; - } - if (op == "replace") - { - return patch_operations::replace; - } - if (op == "move") - { - return patch_operations::move; - } - if (op == "copy") - { - return patch_operations::copy; - } - if (op == "test") - { - return patch_operations::test; - } - - return patch_operations::invalid; - }; - - // wrapper for "add" operation; add value at ptr - const auto operation_add = [&result](json_pointer & ptr, basic_json val) - { - // adding to the root of the target document means replacing it - if (ptr.empty()) - { - result = val; - return; - } - - // make sure the top element of the pointer exists - json_pointer top_pointer = ptr.top(); - if (top_pointer != ptr) - { - result.at(top_pointer); - } - - // get reference to 
parent of JSON pointer ptr - const auto last_path = ptr.back(); - ptr.pop_back(); - basic_json& parent = result[ptr]; - - switch (parent.m_type) - { - case value_t::null: - case value_t::object: - { - // use operator[] to add value - parent[last_path] = val; - break; - } - - case value_t::array: - { - if (last_path == "-") - { - // special case: append to back - parent.push_back(val); - } - else - { - const auto idx = json_pointer::array_index(last_path); - if (JSON_HEDLEY_UNLIKELY(idx > parent.size())) - { - // avoid undefined behavior - JSON_THROW(out_of_range::create(401, "array index " + std::to_string(idx) + " is out of range")); - } - - // default case: insert add offset - parent.insert(parent.begin() + static_cast(idx), val); - } - break; - } - - // if there exists a parent it cannot be primitive - default: // LCOV_EXCL_LINE - JSON_ASSERT(false); // LCOV_EXCL_LINE - } - }; - - // wrapper for "remove" operation; remove value at ptr - const auto operation_remove = [&result](json_pointer & ptr) - { - // get reference to parent of JSON pointer ptr - const auto last_path = ptr.back(); - ptr.pop_back(); - basic_json& parent = result.at(ptr); - - // remove child - if (parent.is_object()) - { - // perform range check - auto it = parent.find(last_path); - if (JSON_HEDLEY_LIKELY(it != parent.end())) - { - parent.erase(it); - } - else - { - JSON_THROW(out_of_range::create(403, "key '" + last_path + "' not found")); - } - } - else if (parent.is_array()) - { - // note erase performs range check - parent.erase(json_pointer::array_index(last_path)); - } - }; - - // type check: top level value must be an array - if (JSON_HEDLEY_UNLIKELY(!json_patch.is_array())) - { - JSON_THROW(parse_error::create(104, 0, "JSON patch must be an array of objects")); - } - - // iterate and apply the operations - for (const auto& val : json_patch) - { - // wrapper to get a value for an operation - const auto get_value = [&val](const std::string & op, - const std::string & member, - bool 
string_type) -> basic_json & - { - // find value - auto it = val.m_value.object->find(member); - - // context-sensitive error message - const auto error_msg = (op == "op") ? "operation" : "operation '" + op + "'"; - - // check if desired value is present - if (JSON_HEDLEY_UNLIKELY(it == val.m_value.object->end())) - { - JSON_THROW(parse_error::create(105, 0, error_msg + " must have member '" + member + "'")); - } - - // check if result is of type string - if (JSON_HEDLEY_UNLIKELY(string_type && !it->second.is_string())) - { - JSON_THROW(parse_error::create(105, 0, error_msg + " must have string member '" + member + "'")); - } - - // no error: return value - return it->second; - }; - - // type check: every element of the array must be an object - if (JSON_HEDLEY_UNLIKELY(!val.is_object())) - { - JSON_THROW(parse_error::create(104, 0, "JSON patch must be an array of objects")); - } - - // collect mandatory members - const auto op = get_value("op", "op", true).template get(); - const auto path = get_value(op, "path", true).template get(); - json_pointer ptr(path); - - switch (get_op(op)) - { - case patch_operations::add: - { - operation_add(ptr, get_value("add", "value", false)); - break; - } - - case patch_operations::remove: - { - operation_remove(ptr); - break; - } - - case patch_operations::replace: - { - // the "path" location must exist - use at() - result.at(ptr) = get_value("replace", "value", false); - break; - } - - case patch_operations::move: - { - const auto from_path = get_value("move", "from", true).template get(); - json_pointer from_ptr(from_path); - - // the "from" location must exist - use at() - basic_json v = result.at(from_ptr); - - // The move operation is functionally identical to a - // "remove" operation on the "from" location, followed - // immediately by an "add" operation at the target - // location with the value that was just removed. 
- operation_remove(from_ptr); - operation_add(ptr, v); - break; - } - - case patch_operations::copy: - { - const auto from_path = get_value("copy", "from", true).template get(); - const json_pointer from_ptr(from_path); - - // the "from" location must exist - use at() - basic_json v = result.at(from_ptr); - - // The copy is functionally identical to an "add" - // operation at the target location using the value - // specified in the "from" member. - operation_add(ptr, v); - break; - } - - case patch_operations::test: - { - bool success = false; - JSON_TRY - { - // check if "value" matches the one at "path" - // the "path" location must exist - use at() - success = (result.at(ptr) == get_value("test", "value", false)); - } - JSON_INTERNAL_CATCH (out_of_range&) - { - // ignore out of range errors: success remains false - } - - // throw an exception if test fails - if (JSON_HEDLEY_UNLIKELY(!success)) - { - JSON_THROW(other_error::create(501, "unsuccessful: " + val.dump())); - } - - break; - } - - default: - { - // op must be "add", "remove", "replace", "move", "copy", or - // "test" - JSON_THROW(parse_error::create(105, 0, "operation value '" + op + "' is invalid")); - } - } - } - - return result; - } - - /*! - @brief creates a diff as a JSON patch - - Creates a [JSON Patch](http://jsonpatch.com) so that value @a source can - be changed into the value @a target by calling @ref patch function. - - @invariant For two JSON values @a source and @a target, the following code - yields always `true`: - @code {.cpp} - source.patch(diff(source, target)) == target; - @endcode - - @note Currently, only `remove`, `add`, and `replace` operations are - generated. - - @param[in] source JSON value to compare from - @param[in] target JSON value to compare against - @param[in] path helper value to create JSON pointers - - @return a JSON patch to convert the @a source to @a target - - @complexity Linear in the lengths of @a source and @a target. 
- - @liveexample{The following code shows how a JSON patch is created as a - diff for two JSON values.,diff} - - @sa @ref patch -- apply a JSON patch - @sa @ref merge_patch -- apply a JSON Merge Patch - - @sa [RFC 6902 (JSON Patch)](https://tools.ietf.org/html/rfc6902) - - @since version 2.0.0 - */ - JSON_HEDLEY_WARN_UNUSED_RESULT - static basic_json diff(const basic_json& source, const basic_json& target, - const std::string& path = "") - { - // the patch - basic_json result(value_t::array); - - // if the values are the same, return empty patch - if (source == target) - { - return result; - } - - if (source.type() != target.type()) - { - // different types: replace value - result.push_back( - { - {"op", "replace"}, {"path", path}, {"value", target} - }); - return result; - } - - switch (source.type()) - { - case value_t::array: - { - // first pass: traverse common elements - std::size_t i = 0; - while (i < source.size() && i < target.size()) - { - // recursive call to compare array values at index i - auto temp_diff = diff(source[i], target[i], path + "/" + std::to_string(i)); - result.insert(result.end(), temp_diff.begin(), temp_diff.end()); - ++i; - } - - // i now reached the end of at least one array - // in a second pass, traverse the remaining elements - - // remove my remaining elements - const auto end_index = static_cast(result.size()); - while (i < source.size()) - { - // add operations in reverse order to avoid invalid - // indices - result.insert(result.begin() + end_index, object( - { - {"op", "remove"}, - {"path", path + "/" + std::to_string(i)} - })); - ++i; - } - - // add other remaining elements - while (i < target.size()) - { - result.push_back( - { - {"op", "add"}, - {"path", path + "/-"}, - {"value", target[i]} - }); - ++i; - } - - break; - } - - case value_t::object: - { - // first pass: traverse this object's elements - for (auto it = source.cbegin(); it != source.cend(); ++it) - { - // escape the key name to be used in a JSON patch - const 
auto key = json_pointer::escape(it.key()); - - if (target.find(it.key()) != target.end()) - { - // recursive call to compare object values at key it - auto temp_diff = diff(it.value(), target[it.key()], path + "/" + key); - result.insert(result.end(), temp_diff.begin(), temp_diff.end()); - } - else - { - // found a key that is not in o -> remove it - result.push_back(object( - { - {"op", "remove"}, {"path", path + "/" + key} - })); - } - } - - // second pass: traverse other object's elements - for (auto it = target.cbegin(); it != target.cend(); ++it) - { - if (source.find(it.key()) == source.end()) - { - // found a key that is not in this -> add it - const auto key = json_pointer::escape(it.key()); - result.push_back( - { - {"op", "add"}, {"path", path + "/" + key}, - {"value", it.value()} - }); - } - } - - break; - } - - default: - { - // both primitive type: replace value - result.push_back( - { - {"op", "replace"}, {"path", path}, {"value", target} - }); - break; - } - } - - return result; - } - - /// @} - - //////////////////////////////// - // JSON Merge Patch functions // - //////////////////////////////// - - /// @name JSON Merge Patch functions - /// @{ - - /*! - @brief applies a JSON Merge Patch - - The merge patch format is primarily intended for use with the HTTP PATCH - method as a means of describing a set of modifications to a target - resource's content. This function applies a merge patch to the current - JSON value. 
- - The function implements the following algorithm from Section 2 of - [RFC 7396 (JSON Merge Patch)](https://tools.ietf.org/html/rfc7396): - - ``` - define MergePatch(Target, Patch): - if Patch is an Object: - if Target is not an Object: - Target = {} // Ignore the contents and set it to an empty Object - for each Name/Value pair in Patch: - if Value is null: - if Name exists in Target: - remove the Name/Value pair from Target - else: - Target[Name] = MergePatch(Target[Name], Value) - return Target - else: - return Patch - ``` - - Thereby, `Target` is the current object; that is, the patch is applied to - the current value. - - @param[in] apply_patch the patch to apply - - @complexity Linear in the lengths of @a patch. - - @liveexample{The following code shows how a JSON Merge Patch is applied to - a JSON document.,merge_patch} - - @sa @ref patch -- apply a JSON patch - @sa [RFC 7396 (JSON Merge Patch)](https://tools.ietf.org/html/rfc7396) - - @since version 3.0.0 - */ - void merge_patch(const basic_json& apply_patch) - { - if (apply_patch.is_object()) - { - if (!is_object()) - { - *this = object(); - } - for (auto it = apply_patch.begin(); it != apply_patch.end(); ++it) - { - if (it.value().is_null()) - { - erase(it.key()); - } - else - { - operator[](it.key()).merge_patch(it.value()); - } - } - } - else - { - *this = apply_patch; - } - } - - /// @} -}; - -/*! -@brief user-defined to_string function for JSON values - -This function implements a user-defined to_string for JSON objects. - -@param[in] j a JSON object -@return a std::string object -*/ - -NLOHMANN_BASIC_JSON_TPL_DECLARATION -std::string to_string(const NLOHMANN_BASIC_JSON_TPL& j) -{ - return j.dump(); -} -} // namespace nlohmann - -/////////////////////// -// nonmember support // -/////////////////////// - -// specialization of std::swap, and std::hash -namespace std -{ - -/// hash value for JSON objects -template<> -struct hash -{ - /*! 
- @brief return a hash value for a JSON object - - @since version 1.0.0 - */ - std::size_t operator()(const nlohmann::json& j) const - { - return nlohmann::detail::hash(j); - } -}; - -/// specialization for std::less -/// @note: do not remove the space after '<', -/// see https://github.com/nlohmann/json/pull/679 -template<> -struct less<::nlohmann::detail::value_t> -{ - /*! - @brief compare two value_t enum values - @since version 3.0.0 - */ - bool operator()(nlohmann::detail::value_t lhs, - nlohmann::detail::value_t rhs) const noexcept - { - return nlohmann::detail::operator<(lhs, rhs); - } -}; - -// C++20 prohibit function specialization in the std namespace. -#ifndef JSON_HAS_CPP_20 - -/*! -@brief exchanges the values of two JSON objects - -@since version 1.0.0 -*/ -template<> -inline void swap(nlohmann::json& j1, nlohmann::json& j2) noexcept( - is_nothrow_move_constructible::value&& - is_nothrow_move_assignable::value - ) -{ - j1.swap(j2); -} - -#endif - -} // namespace std - -/*! -@brief user-defined string literal for JSON values - -This operator implements a user-defined string literal for JSON objects. It -can be used by adding `"_json"` to a string literal and returns a JSON object -if no parse error occurred. - -@param[in] s a string representation of a JSON object -@param[in] n the length of string @a s -@return a JSON object - -@since version 1.0.0 -*/ -JSON_HEDLEY_NON_NULL(1) -inline nlohmann::json operator "" _json(const char* s, std::size_t n) -{ - return nlohmann::json::parse(s, s + n); -} - -/*! -@brief user-defined string literal for JSON pointer - -This operator implements a user-defined string literal for JSON Pointers. It -can be used by adding `"_json_pointer"` to a string literal and returns a JSON pointer -object if no parse error occurred. 
- -@param[in] s a string representation of a JSON Pointer -@param[in] n the length of string @a s -@return a JSON pointer object - -@since version 2.0.0 -*/ -JSON_HEDLEY_NON_NULL(1) -inline nlohmann::json::json_pointer operator "" _json_pointer(const char* s, std::size_t n) -{ - return nlohmann::json::json_pointer(std::string(s, n)); -} - -// #include - - -// restore GCC/clang diagnostic settings -#if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__) - #pragma GCC diagnostic pop -#endif -#if defined(__clang__) - #pragma GCC diagnostic pop -#endif - -// clean up -#undef JSON_ASSERT -#undef JSON_INTERNAL_CATCH -#undef JSON_CATCH -#undef JSON_THROW -#undef JSON_TRY -#undef JSON_HAS_CPP_14 -#undef JSON_HAS_CPP_17 -#undef NLOHMANN_BASIC_JSON_TPL_DECLARATION -#undef NLOHMANN_BASIC_JSON_TPL -#undef JSON_EXPLICIT - -// #include -#undef JSON_HEDLEY_ALWAYS_INLINE -#undef JSON_HEDLEY_ARM_VERSION -#undef JSON_HEDLEY_ARM_VERSION_CHECK -#undef JSON_HEDLEY_ARRAY_PARAM -#undef JSON_HEDLEY_ASSUME -#undef JSON_HEDLEY_BEGIN_C_DECLS -#undef JSON_HEDLEY_CLANG_HAS_ATTRIBUTE -#undef JSON_HEDLEY_CLANG_HAS_BUILTIN -#undef JSON_HEDLEY_CLANG_HAS_CPP_ATTRIBUTE -#undef JSON_HEDLEY_CLANG_HAS_DECLSPEC_DECLSPEC_ATTRIBUTE -#undef JSON_HEDLEY_CLANG_HAS_EXTENSION -#undef JSON_HEDLEY_CLANG_HAS_FEATURE -#undef JSON_HEDLEY_CLANG_HAS_WARNING -#undef JSON_HEDLEY_COMPCERT_VERSION -#undef JSON_HEDLEY_COMPCERT_VERSION_CHECK -#undef JSON_HEDLEY_CONCAT -#undef JSON_HEDLEY_CONCAT3 -#undef JSON_HEDLEY_CONCAT3_EX -#undef JSON_HEDLEY_CONCAT_EX -#undef JSON_HEDLEY_CONST -#undef JSON_HEDLEY_CONSTEXPR -#undef JSON_HEDLEY_CONST_CAST -#undef JSON_HEDLEY_CPP_CAST -#undef JSON_HEDLEY_CRAY_VERSION -#undef JSON_HEDLEY_CRAY_VERSION_CHECK -#undef JSON_HEDLEY_C_DECL -#undef JSON_HEDLEY_DEPRECATED -#undef JSON_HEDLEY_DEPRECATED_FOR -#undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL -#undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_ -#undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED -#undef 
JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES -#undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS -#undef JSON_HEDLEY_DIAGNOSTIC_POP -#undef JSON_HEDLEY_DIAGNOSTIC_PUSH -#undef JSON_HEDLEY_DMC_VERSION -#undef JSON_HEDLEY_DMC_VERSION_CHECK -#undef JSON_HEDLEY_EMPTY_BASES -#undef JSON_HEDLEY_EMSCRIPTEN_VERSION -#undef JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK -#undef JSON_HEDLEY_END_C_DECLS -#undef JSON_HEDLEY_FLAGS -#undef JSON_HEDLEY_FLAGS_CAST -#undef JSON_HEDLEY_GCC_HAS_ATTRIBUTE -#undef JSON_HEDLEY_GCC_HAS_BUILTIN -#undef JSON_HEDLEY_GCC_HAS_CPP_ATTRIBUTE -#undef JSON_HEDLEY_GCC_HAS_DECLSPEC_ATTRIBUTE -#undef JSON_HEDLEY_GCC_HAS_EXTENSION -#undef JSON_HEDLEY_GCC_HAS_FEATURE -#undef JSON_HEDLEY_GCC_HAS_WARNING -#undef JSON_HEDLEY_GCC_NOT_CLANG_VERSION_CHECK -#undef JSON_HEDLEY_GCC_VERSION -#undef JSON_HEDLEY_GCC_VERSION_CHECK -#undef JSON_HEDLEY_GNUC_HAS_ATTRIBUTE -#undef JSON_HEDLEY_GNUC_HAS_BUILTIN -#undef JSON_HEDLEY_GNUC_HAS_CPP_ATTRIBUTE -#undef JSON_HEDLEY_GNUC_HAS_DECLSPEC_ATTRIBUTE -#undef JSON_HEDLEY_GNUC_HAS_EXTENSION -#undef JSON_HEDLEY_GNUC_HAS_FEATURE -#undef JSON_HEDLEY_GNUC_HAS_WARNING -#undef JSON_HEDLEY_GNUC_VERSION -#undef JSON_HEDLEY_GNUC_VERSION_CHECK -#undef JSON_HEDLEY_HAS_ATTRIBUTE -#undef JSON_HEDLEY_HAS_BUILTIN -#undef JSON_HEDLEY_HAS_CPP_ATTRIBUTE -#undef JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS -#undef JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE -#undef JSON_HEDLEY_HAS_EXTENSION -#undef JSON_HEDLEY_HAS_FEATURE -#undef JSON_HEDLEY_HAS_WARNING -#undef JSON_HEDLEY_IAR_VERSION -#undef JSON_HEDLEY_IAR_VERSION_CHECK -#undef JSON_HEDLEY_IBM_VERSION -#undef JSON_HEDLEY_IBM_VERSION_CHECK -#undef JSON_HEDLEY_IMPORT -#undef JSON_HEDLEY_INLINE -#undef JSON_HEDLEY_INTEL_VERSION -#undef JSON_HEDLEY_INTEL_VERSION_CHECK -#undef JSON_HEDLEY_IS_CONSTANT -#undef JSON_HEDLEY_IS_CONSTEXPR_ -#undef JSON_HEDLEY_LIKELY -#undef JSON_HEDLEY_MALLOC -#undef JSON_HEDLEY_MESSAGE -#undef JSON_HEDLEY_MSVC_VERSION -#undef JSON_HEDLEY_MSVC_VERSION_CHECK -#undef 
JSON_HEDLEY_NEVER_INLINE -#undef JSON_HEDLEY_NON_NULL -#undef JSON_HEDLEY_NO_ESCAPE -#undef JSON_HEDLEY_NO_RETURN -#undef JSON_HEDLEY_NO_THROW -#undef JSON_HEDLEY_NULL -#undef JSON_HEDLEY_PELLES_VERSION -#undef JSON_HEDLEY_PELLES_VERSION_CHECK -#undef JSON_HEDLEY_PGI_VERSION -#undef JSON_HEDLEY_PGI_VERSION_CHECK -#undef JSON_HEDLEY_PREDICT -#undef JSON_HEDLEY_PRINTF_FORMAT -#undef JSON_HEDLEY_PRIVATE -#undef JSON_HEDLEY_PUBLIC -#undef JSON_HEDLEY_PURE -#undef JSON_HEDLEY_REINTERPRET_CAST -#undef JSON_HEDLEY_REQUIRE -#undef JSON_HEDLEY_REQUIRE_CONSTEXPR -#undef JSON_HEDLEY_REQUIRE_MSG -#undef JSON_HEDLEY_RESTRICT -#undef JSON_HEDLEY_RETURNS_NON_NULL -#undef JSON_HEDLEY_SENTINEL -#undef JSON_HEDLEY_STATIC_ASSERT -#undef JSON_HEDLEY_STATIC_CAST -#undef JSON_HEDLEY_STRINGIFY -#undef JSON_HEDLEY_STRINGIFY_EX -#undef JSON_HEDLEY_SUNPRO_VERSION -#undef JSON_HEDLEY_SUNPRO_VERSION_CHECK -#undef JSON_HEDLEY_TINYC_VERSION -#undef JSON_HEDLEY_TINYC_VERSION_CHECK -#undef JSON_HEDLEY_TI_ARMCL_VERSION -#undef JSON_HEDLEY_TI_ARMCL_VERSION_CHECK -#undef JSON_HEDLEY_TI_CL2000_VERSION -#undef JSON_HEDLEY_TI_CL2000_VERSION_CHECK -#undef JSON_HEDLEY_TI_CL430_VERSION -#undef JSON_HEDLEY_TI_CL430_VERSION_CHECK -#undef JSON_HEDLEY_TI_CL6X_VERSION -#undef JSON_HEDLEY_TI_CL6X_VERSION_CHECK -#undef JSON_HEDLEY_TI_CL7X_VERSION -#undef JSON_HEDLEY_TI_CL7X_VERSION_CHECK -#undef JSON_HEDLEY_TI_CLPRU_VERSION -#undef JSON_HEDLEY_TI_CLPRU_VERSION_CHECK -#undef JSON_HEDLEY_TI_VERSION -#undef JSON_HEDLEY_TI_VERSION_CHECK -#undef JSON_HEDLEY_UNAVAILABLE -#undef JSON_HEDLEY_UNLIKELY -#undef JSON_HEDLEY_UNPREDICTABLE -#undef JSON_HEDLEY_UNREACHABLE -#undef JSON_HEDLEY_UNREACHABLE_RETURN -#undef JSON_HEDLEY_VERSION -#undef JSON_HEDLEY_VERSION_DECODE_MAJOR -#undef JSON_HEDLEY_VERSION_DECODE_MINOR -#undef JSON_HEDLEY_VERSION_DECODE_REVISION -#undef JSON_HEDLEY_VERSION_ENCODE -#undef JSON_HEDLEY_WARNING -#undef JSON_HEDLEY_WARN_UNUSED_RESULT -#undef JSON_HEDLEY_WARN_UNUSED_RESULT_MSG -#undef 
JSON_HEDLEY_FALL_THROUGH - - - -#endif // INCLUDE_NLOHMANN_JSON_HPP_ diff --git a/internal/core/thirdparty/profiler/CMakeLists.txt b/internal/core/thirdparty/profiler/CMakeLists.txt deleted file mode 100644 index afbca20891ca8..0000000000000 --- a/internal/core/thirdparty/profiler/CMakeLists.txt +++ /dev/null @@ -1,105 +0,0 @@ -#------------------------------------------------------------------------------- -# Copyright (C) 2019-2020 Zilliz. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software distributed under the License -# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing permissions and limitations under the License. 
-#------------------------------------------------------------------------------- - -if ( DEFINED ENV{KNOWHERE_LIBUNWIND_URL} ) - set( LIBUNWIND_SOURCE_URL "$ENV{KNOWHERE_LIBUNWIND_URL}" ) -else () - set( LIBUNWIND_SOURCE_URL - "https://github.com/libunwind/libunwind/releases/download/v${LIBUNWIND_VERSION}/libunwind-${LIBUNWIND_VERSION}.tar.gz" ) -endif () - -if ( DEFINED ENV{KNOWHERE_GPERFTOOLS_URL} ) - set( GPERFTOOLS_SOURCE_URL "$ENV{KNOWHERE_GPERFTOOLS_URL}" ) -else () - set( GPERFTOOLS_SOURCE_URL - "https://github.com/gperftools/gperftools/releases/download/gperftools-${GPERFTOOLS_VERSION}/gperftools-${GPERFTOOLS_VERSION}.tar.gz" ) -endif () - -# ---------------------------------------------------------------------- -# libunwind - -macro( build_libunwind ) - message( STATUS "Building libunwind-${LIBUNWIND_VERSION} from source" ) - - set( LIBUNWIND_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/libunwind) - ExternalProject_Add( - libunwind_ep - DOWNLOAD_DIR ${THIRDPARTY_DOWNLOAD_PATH} - INSTALL_DIR ${CMAKE_CURRENT_BINARY_DIR}/libunwind - URL ${LIBUNWIND_SOURCE_URL} - URL_MD5 f625b6a98ac1976116c71708a73dc44a - CONFIGURE_COMMAND /configure - "--prefix=" - "--quiet" - "--disable-tests" - "cc=${EP_C_COMPILER}" - "cxx=${EP_CXX_COMPILER}" - BUILD_COMMAND ${MAKE} ${MAKE_BUILD_ARGS} - INSTALL_COMMAND ${MAKE} install - ${EP_LOG_OPTIONS} ) - - ExternalProject_Get_Property( libunwind_ep INSTALL_DIR ) - file(MAKE_DIRECTORY ${INSTALL_DIR}/include) - add_library( libunwind SHARED IMPORTED ) - set_target_properties( - libunwind PROPERTIES - IMPORTED_GLOBAL TRUE - IMPORTED_LOCATION "${INSTALL_DIR}/lib/libunwind.so" - INTERFACE_INCLUDE_DIRECTORIES "${INSTALL_DIR}/include" ) - - add_dependencies( libunwind libunwind_ep ) -endmacro() - - -# ---------------------------------------------------------------------- -# gperftools - -macro( build_gperftools ) - message( STATUS "Building gperftools-${GPERFTOOLS_VERSION} from source" ) - - ExternalProject_Add( - gperftools_ep - DEPENDS libunwind_ep - 
DOWNLOAD_DIR ${CMAKE_BINARY_DIR}/3rdparty_download/download - INSTALL_DIR ${CMAKE_CURRENT_BINARY_DIR}/gperftools - URL ${GPERFTOOLS_SOURCE_URL} - URL_MD5 cb21f2ebe71bbc8d5ad101b310be980a - CONFIGURE_COMMAND /configure - "--prefix=" - "--quiet" - "cc=${EP_C_COMPILER}" - "cxx=${EP_CXX_COMPILER}" - BUILD_COMMAND ${MAKE} ${MAKE_BUILD_ARGS} - INSTALL_COMMAND ${MAKE} install - ${EP_LOG_OPTIONS} ) - - ExternalProject_Get_Property( gperftools_ep INSTALL_DIR ) - file(MAKE_DIRECTORY ${INSTALL_DIR}/include) - # libprofiler.so - add_library( gperftools SHARED IMPORTED ) - set_target_properties( gperftools - PROPERTIES - IMPORTED_GLOBAL TRUE - IMPORTED_LOCATION "${INSTALL_DIR}/lib/libtcmalloc_and_profiler.so" - INTERFACE_INCLUDE_DIRECTORIES "${INSTALL_DIR}/include" - INTERFACE_LINK_LIBRARIES libunwind - ) - add_dependencies( gperftools gperftools_ep ) -endmacro() - - -build_libunwind() -build_gperftools() -get_target_property( GPERFTOOLS_LIB gperftools LOCATION ) -install(FILES ${GPERFTOOLS_LIB} DESTINATION ${CMAKE_INSTALL_PREFIX}) - diff --git a/internal/core/unittest/CMakeLists.txt b/internal/core/unittest/CMakeLists.txt index 6179b38b61c27..4d5971b9b9f4a 100644 --- a/internal/core/unittest/CMakeLists.txt +++ b/internal/core/unittest/CMakeLists.txt @@ -25,6 +25,7 @@ set(MILVUS_TEST_FILES test_concurrent_vector.cpp test_c_api.cpp test_expr.cpp + test_float16.cpp test_growing.cpp test_growing_index.cpp test_indexing.cpp @@ -54,11 +55,14 @@ set(MILVUS_TEST_FILES test_offset_ordered_map.cpp test_offset_ordered_array.cpp test_always_true_expr.cpp - test_plan_proto.cpp) + test_plan_proto.cpp + test_chunk_cache.cpp + ) if ( BUILD_DISK_ANN STREQUAL "ON" ) set(MILVUS_TEST_FILES ${MILVUS_TEST_FILES} + #need update aws-sdk-cpp, see more from https://github.com/aws/aws-sdk-cpp/issues/1757 #test_minio_chunk_manager.cpp ) endif() @@ -68,7 +72,17 @@ if (LINUX OR APPLE) ${MILVUS_TEST_FILES} test_scalar_index_creator.cpp test_string_index.cpp - ) + test_array.cpp test_array_expr.cpp) +endif() 
+ +if (DEFINED AZURE_BUILD_DIR) + set(MILVUS_TEST_FILES + ${MILVUS_TEST_FILES} + test_azure_chunk_manager.cpp + #need update aws-sdk-cpp, see more from https://github.com/aws/aws-sdk-cpp/issues/2119 + #test_remote_chunk_manager.cpp + ) + include_directories("${AZURE_BUILD_DIR}/vcpkg_installed/${VCPKG_TARGET_TRIPLET}/include") endif() if (LINUX) diff --git a/internal/core/unittest/bench/CMakeLists.txt b/internal/core/unittest/bench/CMakeLists.txt index b59c27926dc89..6147e3733205d 100644 --- a/internal/core/unittest/bench/CMakeLists.txt +++ b/internal/core/unittest/bench/CMakeLists.txt @@ -12,7 +12,7 @@ include_directories(${CMAKE_HOME_DIRECTORY}/src) include_directories(${CMAKE_HOME_DIRECTORY}/unittest) -set(bench_srcs +set(bench_srcs bench_naive.cpp bench_search.cpp ) @@ -38,7 +38,6 @@ target_link_libraries(indexbuilder_bench milvus_log pthread knowhere - milvus_utils ) target_link_libraries(indexbuilder_bench benchmark_main) diff --git a/internal/core/unittest/bench/bench_indexbuilder.cpp b/internal/core/unittest/bench/bench_indexbuilder.cpp index e85c927cbf456..1b5c91a48786e 100644 --- a/internal/core/unittest/bench/bench_indexbuilder.cpp +++ b/internal/core/unittest/bench/bench_indexbuilder.cpp @@ -60,6 +60,8 @@ IndexBuilder_build(benchmark::State& state) { const auto& param = index_params.params(i); config[param.key()] = param.value(); } + config[milvus::index::INDEX_ENGINE_VERSION] = + std::to_string(knowhere::Version::GetCurrentVersion().VersionNumber()); auto is_binary = state.range(2); auto dataset = GenDataset(NB, metric_type, is_binary); @@ -68,7 +70,7 @@ IndexBuilder_build(benchmark::State& state) { for (auto _ : state) { auto index = std::make_unique( - milvus::DataType::VECTOR_FLOAT, config, nullptr); + milvus::DataType::VECTOR_FLOAT, config); index->Build(xb_dataset); } } @@ -102,7 +104,7 @@ IndexBuilder_build_and_codec(benchmark::State& state) { for (auto _ : state) { auto index = std::make_unique( - milvus::DataType::VECTOR_FLOAT, config, nullptr); 
+ milvus::DataType::VECTOR_FLOAT, config); index->Build(xb_dataset); index->Serialize(); diff --git a/internal/core/unittest/test_array.cpp b/internal/core/unittest/test_array.cpp new file mode 100644 index 0000000000000..de30aaad91f1b --- /dev/null +++ b/internal/core/unittest/test_array.cpp @@ -0,0 +1,121 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include +#include + +#include "common/Array.h" + +TEST(Array, TestConstructArray) { + using namespace milvus; + + int N = 10; + milvus::proto::schema::ScalarField field_int_data; + milvus::proto::plan::Array field_int_array; + field_int_array.set_same_type(true); + for (int i = 0; i < N; i++) { + field_int_data.mutable_int_data()->add_data(i); + field_int_array.mutable_array()->Add()->set_int64_val(i); + } + auto int_array = Array(field_int_data); + ASSERT_EQ(N, int_array.length()); + ASSERT_EQ(N * sizeof(int), int_array.byte_size()); + for (int i = 0; i < N; ++i) { + ASSERT_EQ(int_array.get_data(i), i); + } + ASSERT_TRUE(int_array.is_same_array(field_int_array)); + + milvus::proto::schema::ScalarField field_long_data; + milvus::proto::plan::Array field_long_array; + field_long_array.set_same_type(true); + for (int i = 0; i < N; i++) { + field_long_data.mutable_long_data()->add_data(i); + field_long_array.mutable_array()->Add()->set_int64_val(i); + } + auto long_array = Array(field_long_data); + ASSERT_EQ(N, long_array.length()); + ASSERT_EQ(N * 
sizeof(int64_t), long_array.byte_size()); + for (int i = 0; i < N; ++i) { + ASSERT_EQ(long_array.get_data(i), i); + } + ASSERT_TRUE(long_array.is_same_array(field_int_array)); + + milvus::proto::schema::ScalarField field_string_data; + milvus::proto::plan::Array field_string_array; + field_string_array.set_same_type(true); + for (int i = 0; i < N; i++) { + field_string_data.mutable_string_data()->add_data(std::to_string(i)); + proto::plan::GenericValue string_val; + string_val.set_string_val(std::to_string(i)); + field_string_array.mutable_array()->Add()->CopyFrom(string_val); + } + auto string_array = Array(field_string_data); + ASSERT_EQ(N, string_array.length()); + // ASSERT_EQ(N, string_array.size()); + for (int i = 0; i < N; ++i) { + ASSERT_EQ(string_array.get_data(i), + std::to_string(i)); + } + ASSERT_TRUE(string_array.is_same_array(field_string_array)); + + milvus::proto::schema::ScalarField field_bool_data; + milvus::proto::plan::Array field_bool_array; + field_bool_array.set_same_type(true); + for (int i = 0; i < N; i++) { + field_bool_data.mutable_bool_data()->add_data(bool(i)); + field_bool_array.mutable_array()->Add()->set_bool_val(bool(i)); + } + auto bool_array = Array(field_bool_data); + ASSERT_EQ(N, bool_array.length()); + ASSERT_EQ(N * sizeof(bool), bool_array.byte_size()); + for (int i = 0; i < N; ++i) { + ASSERT_EQ(bool_array.get_data(i), bool(i)); + } + ASSERT_TRUE(bool_array.is_same_array(field_bool_array)); + + milvus::proto::schema::ScalarField field_float_data; + milvus::proto::plan::Array field_float_array; + field_float_array.set_same_type(true); + for (int i = 0; i < N; i++) { + field_float_data.mutable_float_data()->add_data(float(i) * 0.1); + field_float_array.mutable_array()->Add()->set_float_val(float(i * 0.1)); + } + auto float_array = Array(field_float_data); + ASSERT_EQ(N, float_array.length()); + ASSERT_EQ(N * sizeof(float), float_array.byte_size()); + for (int i = 0; i < N; ++i) { + ASSERT_DOUBLE_EQ(float_array.get_data(i), 
float(i * 0.1)); + } + ASSERT_TRUE(float_array.is_same_array(field_float_array)); + + milvus::proto::schema::ScalarField field_double_data; + milvus::proto::plan::Array field_double_array; + field_double_array.set_same_type(true); + for (int i = 0; i < N; i++) { + field_double_data.mutable_double_data()->add_data(double(i) * 0.1); + field_double_array.mutable_array()->Add()->set_float_val( + double(i * 0.1)); + } + auto double_array = Array(field_double_data); + ASSERT_EQ(N, double_array.length()); + ASSERT_EQ(N * sizeof(double), double_array.byte_size()); + for (int i = 0; i < N; ++i) { + ASSERT_DOUBLE_EQ(double_array.get_data(i), double(i * 0.1)); + } + ASSERT_TRUE(double_array.is_same_array(field_double_array)); + + milvus::proto::schema::ScalarField field_empty_data; + milvus::proto::plan::Array field_empty_array; + auto empty_array = Array(field_empty_data); + ASSERT_EQ(0, empty_array.length()); + ASSERT_EQ(0, empty_array.byte_size()); + ASSERT_TRUE(empty_array.is_same_array(field_empty_array)); +} diff --git a/internal/core/unittest/test_array_expr.cpp b/internal/core/unittest/test_array_expr.cpp new file mode 100644 index 0000000000000..0b03e1d180983 --- /dev/null +++ b/internal/core/unittest/test_array_expr.cpp @@ -0,0 +1,1652 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. 
See the License for the specific language governing permissions and limitations under the License + +#include +#include +#include +#include +#include +#include + +#include "common/Types.h" +#include "pb/plan.pb.h" +#include "query/Expr.h" +#include "query/ExprImpl.h" +#include "query/Plan.h" +#include "query/PlanNode.h" +#include "query/generated/ExecExprVisitor.h" +#include "segcore/SegmentGrowingImpl.h" +#include "simdjson/padded_string.h" +#include "test_utils/DataGen.h" +#include "index/IndexFactory.h" + +TEST(Expr, TestArrayRange) { + using namespace milvus; + using namespace milvus::query; + using namespace milvus::segcore; + std::vector>> + testcases = { + {R"(binary_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + lower_inclusive: false, + upper_inclusive: false, + lower_value: < + int64_val: 1 + > + upper_value: < + int64_val: 10000 + > + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return 1 < val && val < 10000; + }}, + {R"(binary_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + lower_inclusive: true, + upper_inclusive: false, + lower_value: < + int64_val: 1 + > + upper_value: < + int64_val: 10000 + > + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return 1 <= val && val < 10000; + }}, + {R"(binary_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + lower_inclusive: false, + upper_inclusive: true, + lower_value: < + int64_val: 1 + > + upper_value: < + int64_val: 10000 + > + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return 1 < val && val <= 10000; + }}, + {R"(binary_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + lower_inclusive: true, + upper_inclusive: true, + lower_value: < + int64_val: 1 + > + upper_value: < + 
int64_val: 10000 + > + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return 1 <= val && val <= 10000; + }}, + {R"(binary_range_expr: < + column_info: < + field_id: 104 + data_type: Array + nested_path:"0" + element_type:VarChar + > + lower_inclusive: true, + upper_inclusive: true, + lower_value: < + string_val: "aaa" + > + upper_value: < + string_val: "zzz" + > + >)", + "string", + [](milvus::Array& array) { + auto val = array.get_data(0); + return "aaa" <= val && val <= "zzz"; + }}, + {R"(binary_range_expr: < + column_info: < + field_id: 105 + data_type: Array + nested_path:"0" + element_type:Float + > + lower_inclusive: true, + upper_inclusive: true, + lower_value: < + float_val: 1.1 + > + upper_value: < + float_val: 2048.12 + > + >)", + "float", + [](milvus::Array& array) { + auto val = array.get_data(0); + return 1.1 <= val && val <= 2048.12; + }}, + {R"(unary_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + op: GreaterEqual, + value: < + int64_val: 10000 + > + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val >= 10000; + }}, + {R"(unary_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + op: GreaterThan, + value: < + int64_val: 2000 + > + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val > 2000; + }}, + {R"(unary_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + op: LessEqual, + value: < + int64_val: 2000 + > + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val <= 2000; + }}, + {R"(unary_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + op: LessThan, + value: < + int64_val: 2000 + > + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val < 
2000; + }}, + {R"(unary_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + op: Equal, + value: < + int64_val: 2000 + > + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val == 2000; + }}, + {R"(unary_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + op: NotEqual, + value: < + int64_val: 2000 + > + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val != 2000; + }}, + {R"(unary_range_expr: < + column_info: < + field_id: 103 + data_type: Array + nested_path:"0" + element_type:Bool + > + op: Equal, + value: < + bool_val: false + > + >)", + "bool", + [](milvus::Array& array) { + auto val = array.get_data(0); + return !val; + }}, + {R"(unary_range_expr: < + column_info: < + field_id: 104 + data_type: Array + nested_path:"0" + element_type:VarChar + > + op: Equal, + value: < + string_val: "abc" + > + >)", + "string", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val == "abc"; + }}, + {R"(unary_range_expr: < + column_info: < + field_id: 105 + data_type: Array + nested_path:"0" + element_type:Float + > + op: Equal, + value: < + float_val: 2.2 + > + >)", + "float", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val == 2.2; + }}, + + }; + std::string raw_plan_tmp = R"(vector_anns: < + field_id: 100 + predicates: < + @@@@ + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField( + "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto long_array_fid = + schema->AddDebugField("long_array", DataType::ARRAY, DataType::INT64); + auto bool_array_fid = + schema->AddDebugField("bool_array", DataType::ARRAY, 
DataType::BOOL); + auto string_array_fid = schema->AddDebugField( + "string_array", DataType::ARRAY, DataType::VARCHAR); + auto float_array_fid = + schema->AddDebugField("double_array", DataType::ARRAY, DataType::FLOAT); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::map> array_cols; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_long_array_col = raw_data.get_col(long_array_fid); + auto new_bool_array_col = raw_data.get_col(bool_array_fid); + auto new_string_array_col = + raw_data.get_col(string_array_fid); + auto new_float_array_col = + raw_data.get_col(float_array_fid); + + array_cols["long"].insert(array_cols["long"].end(), + new_long_array_col.begin(), + new_long_array_col.end()); + array_cols["bool"].insert(array_cols["bool"].end(), + new_bool_array_col.begin(), + new_bool_array_col.end()); + array_cols["string"].insert(array_cols["string"].end(), + new_string_array_col.begin(), + new_string_array_col.end()); + array_cols["float"].insert(array_cols["float"].end(), + new_float_array_col.begin(), + new_float_array_col.end()); + + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + ExecExprVisitor visitor( + *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + for (auto [clause, array_type, ref_func] : testcases) { + auto loc = raw_plan_tmp.find("@@@@"); + auto raw_plan = raw_plan_tmp; + raw_plan.replace(loc, 4, clause); + auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str()); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + auto final = visitor.call_child(*plan->plan_node_->predicate_.value()); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = 
milvus::Array(array_cols[array_type][i]); + auto ref = ref_func(array); + ASSERT_EQ(ans, ref); + } + } +} + +TEST(Expr, TestArrayEqual) { + using namespace milvus; + using namespace milvus::query; + using namespace milvus::segcore; + std::vector< + std::tuple)>>> + testcases = { + {R"(unary_range_expr: < + column_info: < + field_id: 102 + data_type: Array + element_type:Int64 + > + op:Equal + value:< + array_val: array: array: + same_type:true + element_type:Int64 + >> + >)", + [](std::vector v) { + if (v.size() != 3) { + return false; + } + for (int i = 0; i < 3; ++i) { + if (v[i] != i + 1) { + return false; + } + } + return true; + }}, + {R"(unary_range_expr: < + column_info: < + field_id: 102 + data_type: Array + element_type:Int64 + > + op:NotEqual + value: array: array: + same_type:true + element_type:Int64 + >> + >)", + [](std::vector v) { + if (v.size() != 3) { + return true; + } + for (int i = 0; i < 3; ++i) { + if (v[i] != i + 1) { + return true; + } + } + return false; + }}, + }; + + std::string raw_plan_tmp = R"(vector_anns: < + field_id: 100 + predicates: < + @@@@ + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField( + "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto long_array_fid = + schema->AddDebugField("long_array", DataType::ARRAY, DataType::INT64); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector long_array_col; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter, 0, 1, 3); + auto new_long_array_col = raw_data.get_col(long_array_fid); + long_array_col.insert(long_array_col.end(), + new_long_array_col.begin(), + new_long_array_col.end()); + seg->PreInsert(N); + 
seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + ExecExprVisitor visitor( + *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + for (auto [clause, ref_func] : testcases) { + auto loc = raw_plan_tmp.find("@@@@"); + auto raw_plan = raw_plan_tmp; + raw_plan.replace(loc, 4, clause); + auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str()); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + auto final = visitor.call_child(*plan->plan_node_->predicate_.value()); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Array(long_array_col[i]); + std::vector array_values(array.length()); + for (int j = 0; j < array.length(); ++j) { + array_values.push_back(array.get_data(j)); + } + auto ref = ref_func(array_values); + ASSERT_EQ(ans, ref); + } + } +} + +TEST(Expr, PraseArrayContainsExpr) { + using namespace milvus; + using namespace milvus::query; + using namespace milvus::segcore; + + std::vector raw_plans{ + R"(vector_anns:< + field_id:100 + predicates:< + json_contains_expr:< + column_info:< + field_id:101 + data_type:Array + element_type:Int64 + > + elements: + op:Contains + elements_same_type:true + > + > + query_info:< + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > placeholder_tag:"$0" + >)", + R"(vector_anns:< + field_id:100 + predicates:< + json_contains_expr:< + column_info:< + field_id:101 + data_type:Array + element_type:Int64 + > + elements: elements: elements: + op:ContainsAll + elements_same_type:true + > + > + query_info:< + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > placeholder_tag:"$0" + >)", + R"(vector_anns:< + field_id:100 + predicates:< + json_contains_expr:< + column_info:< + field_id:101 + data_type:Array + 
element_type:Int64 + > + elements: elements: elements: + op:ContainsAny + elements_same_type:true + > + > + query_info:< + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > placeholder_tag:"$0" + >)", + }; + + for (auto& raw_plan : raw_plans) { + auto plan_str = translate_text_plan_to_binary_plan(raw_plan); + auto schema = std::make_shared(); + schema->AddDebugField( + "fakevec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2); + schema->AddField( + FieldName("array"), FieldId(101), DataType::ARRAY, DataType::INT64); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + } +} + +template +struct ArrayTestcase { + std::vector term; + std::vector nested_path; +}; + +TEST(Expr, TestArrayContains) { + using namespace milvus; + using namespace milvus::query; + using namespace milvus::segcore; + + auto schema = std::make_shared(); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto int_array_fid = + schema->AddDebugField("int_array", DataType::ARRAY, DataType::INT8); + auto long_array_fid = + schema->AddDebugField("long_array", DataType::ARRAY, DataType::INT64); + auto bool_array_fid = + schema->AddDebugField("bool_array", DataType::ARRAY, DataType::BOOL); + auto float_array_fid = + schema->AddDebugField("float_array", DataType::ARRAY, DataType::FLOAT); + auto double_array_fid = schema->AddDebugField( + "double_array", DataType::ARRAY, DataType::DOUBLE); + auto string_array_fid = schema->AddDebugField( + "string_array", DataType::ARRAY, DataType::VARCHAR); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::map> array_cols; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_int_array_col = raw_data.get_col(int_array_fid); + auto new_long_array_col = raw_data.get_col(long_array_fid); + auto new_bool_array_col = 
raw_data.get_col(bool_array_fid); + auto new_float_array_col = + raw_data.get_col(float_array_fid); + auto new_double_array_col = + raw_data.get_col(double_array_fid); + auto new_string_array_col = + raw_data.get_col(string_array_fid); + + array_cols["int"].insert(array_cols["int"].end(), + new_int_array_col.begin(), + new_int_array_col.end()); + array_cols["long"].insert(array_cols["long"].end(), + new_long_array_col.begin(), + new_long_array_col.end()); + array_cols["bool"].insert(array_cols["bool"].end(), + new_bool_array_col.begin(), + new_bool_array_col.end()); + array_cols["float"].insert(array_cols["float"].end(), + new_float_array_col.begin(), + new_float_array_col.end()); + array_cols["double"].insert(array_cols["double"].end(), + new_double_array_col.begin(), + new_double_array_col.end()); + array_cols["string"].insert(array_cols["string"].end(), + new_string_array_col.begin(), + new_string_array_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + ExecExprVisitor visitor( + *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + + std::vector> bool_testcases{{{true, true}, {}}, + {{false, false}, {}}}; + + for (auto testcase : bool_testcases) { + auto check = [&](const std::vector& values) { + for (auto const& e : testcase.term) { + if (std::find(values.begin(), values.end(), e) != + values.end()) { + return true; + } + } + return false; + }; + RetrievePlanNode plan; + plan.predicate_ = std::make_unique>( + ColumnInfo(bool_array_fid, DataType::ARRAY), + testcase.term, + true, + proto::plan::JSONContainsExpr_JSONOp_Contains, + proto::plan::GenericValue::ValCase::kBoolVal); + auto start = std::chrono::steady_clock::now(); + auto final = visitor.call_child(*plan.predicate_.value()); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + 
EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Array(array_cols["bool"][i]); + std::vector res; + for (int j = 0; j < array.length(); ++j) { + res.push_back(array.get_data(j)); + } + ASSERT_EQ(ans, check(res)); + } + } + + std::vector> double_testcases{ + {{1.123, 10.34}, {"double"}}, + {{10.34, 100.234}, {"double"}}, + {{100.234, 1000.4546}, {"double"}}, + {{1000.4546, 1.123}, {"double"}}, + {{1000.4546, 10.34}, {"double"}}, + {{1.123, 100.234}, {"double"}}, + }; + + for (auto testcase : double_testcases) { + auto check = [&](const std::vector& values) { + for (auto const& e : testcase.term) { + if (std::find(values.begin(), values.end(), e) != + values.end()) { + return true; + } + } + return false; + }; + RetrievePlanNode plan; + plan.predicate_ = std::make_unique>( + ColumnInfo(double_array_fid, DataType::ARRAY), + testcase.term, + true, + proto::plan::JSONContainsExpr_JSONOp_Contains, + proto::plan::GenericValue::ValCase::kFloatVal); + auto start = std::chrono::steady_clock::now(); + auto final = visitor.call_child(*plan.predicate_.value()); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Array(array_cols["double"][i]); + std::vector res; + for (int j = 0; j < array.length(); ++j) { + res.push_back(array.get_data(j)); + } + ASSERT_EQ(ans, check(res)); + } + } + + for (auto testcase : double_testcases) { + auto check = [&](const std::vector& values) { + for (auto const& e : testcase.term) { + if (std::find(values.begin(), values.end(), e) != + values.end()) { + return true; + } + } + return false; + }; + RetrievePlanNode plan; + plan.predicate_ = std::make_unique>( + ColumnInfo(float_array_fid, DataType::ARRAY), + testcase.term, + true, + 
proto::plan::JSONContainsExpr_JSONOp_Contains, + proto::plan::GenericValue::ValCase::kFloatVal); + auto start = std::chrono::steady_clock::now(); + auto final = visitor.call_child(*plan.predicate_.value()); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Array(array_cols["float"][i]); + std::vector res; + for (int j = 0; j < array.length(); ++j) { + res.push_back(array.get_data(j)); + } + ASSERT_EQ(ans, check(res)); + } + } + + std::vector> testcases{ + {{1, 10}, {"int"}}, + {{10, 100}, {"int"}}, + {{100, 1000}, {"int"}}, + {{1000, 10}, {"int"}}, + {{2, 4, 6, 8, 10}, {"int"}}, + {{1, 2, 3, 4, 5}, {"int"}}, + }; + + for (auto testcase : testcases) { + auto check = [&](const std::vector& values) { + for (auto const& e : testcase.term) { + if (std::find(values.begin(), values.end(), e) == + values.end()) { + return false; + } + } + return true; + }; + RetrievePlanNode plan; + plan.predicate_ = std::make_unique>( + ColumnInfo(int_array_fid, DataType::ARRAY), + testcase.term, + true, + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + proto::plan::GenericValue::ValCase::kInt64Val); + auto start = std::chrono::steady_clock::now(); + auto final = visitor.call_child(*plan.predicate_.value()); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Array(array_cols["int"][i]); + std::vector res; + for (int j = 0; j < array.length(); ++j) { + res.push_back(array.get_data(j)); + } + ASSERT_EQ(ans, check(res)); + } + } + + for (auto testcase : testcases) { + auto check = [&](const std::vector& values) { + for (auto const& e : testcase.term) { + if 
(std::find(values.begin(), values.end(), e) == + values.end()) { + return false; + } + } + return true; + }; + RetrievePlanNode plan; + plan.predicate_ = std::make_unique>( + ColumnInfo(long_array_fid, DataType::ARRAY), + testcase.term, + true, + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + proto::plan::GenericValue::ValCase::kInt64Val); + auto start = std::chrono::steady_clock::now(); + auto final = visitor.call_child(*plan.predicate_.value()); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Array(array_cols["long"][i]); + std::vector res; + for (int j = 0; j < array.length(); ++j) { + res.push_back(array.get_data(j)); + } + ASSERT_EQ(ans, check(res)); + } + } + + std::vector> testcases_string = { + {{"1sads", "10dsf"}, {"string"}}, + {{"10dsf", "100"}, {"string"}}, + {{"100", "10dsf", "1sads"}, {"string"}}, + {{"100ddfdsssdfdsfsd0", "100"}, {"string"}}, + }; + + for (auto testcase : testcases_string) { + auto check = [&](const std::vector& values) { + for (auto const& e : testcase.term) { + if (std::find(values.begin(), values.end(), e) == + values.end()) { + return false; + } + } + return true; + }; + RetrievePlanNode plan; + plan.predicate_ = std::make_unique>( + ColumnInfo(string_array_fid, DataType::ARRAY), + testcase.term, + true, + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + proto::plan::GenericValue::ValCase::kStringVal); + auto start = std::chrono::steady_clock::now(); + auto final = visitor.call_child(*plan.predicate_.value()); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Array(array_cols["string"][i]); + std::vector 
res; + for (int j = 0; j < array.length(); ++j) { + res.push_back(array.get_data(j)); + } + ASSERT_EQ(ans, check(res)); + } + } +} + +TEST(Expr, TestArrayBinaryArith) { + using namespace milvus; + using namespace milvus::query; + using namespace milvus::segcore; + + auto schema = std::make_shared(); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto int_array_fid = + schema->AddDebugField("int_array", DataType::ARRAY, DataType::INT8); + auto long_array_fid = + schema->AddDebugField("long_array", DataType::ARRAY, DataType::INT64); + auto float_array_fid = + schema->AddDebugField("float_array", DataType::ARRAY, DataType::FLOAT); + auto double_array_fid = schema->AddDebugField( + "double_array", DataType::ARRAY, DataType::DOUBLE); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::map> array_cols; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_int_array_col = raw_data.get_col(int_array_fid); + auto new_long_array_col = raw_data.get_col(long_array_fid); + auto new_float_array_col = + raw_data.get_col(float_array_fid); + auto new_double_array_col = + raw_data.get_col(double_array_fid); + + array_cols["int"].insert(array_cols["int"].end(), + new_int_array_col.begin(), + new_int_array_col.end()); + array_cols["long"].insert(array_cols["long"].end(), + new_long_array_col.begin(), + new_long_array_col.end()); + array_cols["float"].insert(array_cols["float"].end(), + new_float_array_col.begin(), + new_float_array_col.end()); + array_cols["double"].insert(array_cols["double"].end(), + new_double_array_col.begin(), + new_double_array_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + ExecExprVisitor visitor( + *seg_promote, seg_promote->get_row_count(), 
MAX_TIMESTAMP); + + std::vector>> + testcases = { + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Array + nested_path:"0" + element_type:Int8 + > + arith_op:Add + right_operand: + op:Equal + value: + >)", + "int", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val + 2 == 5; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Array + nested_path:"0" + element_type:Int8 + > + arith_op:Add + right_operand: + op:NotEqual + value: + >)", + "int", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val + 2 != 5; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + arith_op:Sub + right_operand: + op:Equal + value: + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val - 1 == 144; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + arith_op:Sub + right_operand: + op:NotEqual + value: + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val - 1 != 144; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Array + nested_path:"0" + element_type:Float + > + arith_op:Add + right_operand: + op:Equal + value: + >)", + "float", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val + 2.2 == 133.2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 103 + data_type: Array + nested_path:"0" + element_type:Float + > + arith_op:Add + right_operand: + op:NotEqual + value: + >)", + "float", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val + 2.2 != 133.2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 104 + data_type: Array + nested_path:"0" + element_type:Double + > + arith_op:Sub + 
right_operand: + op:Equal + value: + >)", + "double", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val - 11.1 == 125.7; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 104 + data_type: Array + nested_path:"0" + element_type:Double + > + arith_op:Sub + right_operand: + op:NotEqual + value: + >)", + "double", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val - 11.1 != 125.7; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + arith_op:Mul + right_operand: + op:Equal + value: + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val * 2 == 8; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + arith_op:Mul + right_operand: + op:NotEqual + value: + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val * 2 != 20; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + arith_op:Div + right_operand: + op:Equal + value: + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val / 2 == 8; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + arith_op:Div + right_operand: + op:NotEqual + value: + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val / 2 != 20; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + arith_op:Mod + right_operand: + op:Equal + value: + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val % 3 == 0; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < 
+ field_id: 102 + data_type: Array + nested_path:"0" + element_type:Int64 + > + arith_op:Mod + right_operand: + op:NotEqual + value: + >)", + "long", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val % 3 != 2; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Array + nested_path:"0" + element_type:Int8 + > + arith_op:ArrayLength + op:Equal + value: + >)", + "int", + [](milvus::Array& array) { + return array.length() == 10; + }}, + {R"(binary_arith_op_eval_range_expr: < + column_info: < + field_id: 101 + data_type: Array + nested_path:"0" + element_type:Int8 + > + arith_op:ArrayLength + op:NotEqual + value: + >)", + "int", + [](milvus::Array& array) { + return array.length() != 8; + }}, + }; + + std::string raw_plan_tmp = R"(vector_anns: < + field_id: 100 + predicates: < + @@@@ + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + for (auto [clause, array_type, ref_func] : testcases) { + auto loc = raw_plan_tmp.find("@@@@"); + auto raw_plan = raw_plan_tmp; + raw_plan.replace(loc, 4, clause); + auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str()); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + auto final = visitor.call_child(*plan->plan_node_->predicate_.value()); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Array(array_cols[array_type][i]); + auto ref = ref_func(array); + ASSERT_EQ(ans, ref); + } + } +} + +template +struct UnaryRangeTestcase { + milvus::OpType op_type; + T value; + std::vector nested_path; + std::function check_func; +}; + +TEST(Expr, TestArrayStringMatch) { + using namespace milvus; + using namespace milvus::query; + using namespace milvus::segcore; + + auto schema = std::make_shared(); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto 
string_array_fid = schema->AddDebugField( + "string_array", DataType::ARRAY, DataType::VARCHAR); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::map> array_cols; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_string_array_col = + raw_data.get_col(string_array_fid); + array_cols["string"].insert(array_cols["string"].end(), + new_string_array_col.begin(), + new_string_array_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + ExecExprVisitor visitor( + *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + + std::vector> prefix_testcases{ + {OpType::PrefixMatch, + "abc", + {"0"}, + [](milvus::Array& array) { + return PrefixMatch(array.get_data(0), "abc"); + }}, + {OpType::PrefixMatch, + "def", + {"1"}, + [](milvus::Array& array) { + return PrefixMatch(array.get_data(1), "def"); + }}, + }; + //vector_anns: op:PrefixMatch value: > > query_info:<> placeholder_tag:"$0" > + for (auto& testcase : prefix_testcases) { + RetrievePlanNode plan; + plan.predicate_ = std::make_unique>( + ColumnInfo(string_array_fid, DataType::ARRAY, testcase.nested_path), + testcase.op_type, + testcase.value, + proto::plan::GenericValue::ValCase::kStringVal); + auto start = std::chrono::steady_clock::now(); + auto final = visitor.call_child(*plan.predicate_.value()); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Array(array_cols["string"][i]); + ASSERT_EQ(ans, testcase.check_func(array)); + } + } +} + +TEST(Expr, TestArrayInTerm) { + using namespace milvus; + using namespace 
milvus::query; + using namespace milvus::segcore; + + auto schema = std::make_shared(); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto long_array_fid = + schema->AddDebugField("long_array", DataType::ARRAY, DataType::INT64); + auto bool_array_fid = + schema->AddDebugField("bool_array", DataType::ARRAY, DataType::BOOL); + auto float_array_fid = + schema->AddDebugField("float_array", DataType::ARRAY, DataType::FLOAT); + auto string_array_fid = schema->AddDebugField( + "string_array", DataType::ARRAY, DataType::VARCHAR); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::map> array_cols; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_long_array_col = raw_data.get_col(long_array_fid); + auto new_bool_array_col = raw_data.get_col(bool_array_fid); + auto new_float_array_col = + raw_data.get_col(float_array_fid); + auto new_string_array_col = + raw_data.get_col(string_array_fid); + array_cols["long"].insert(array_cols["long"].end(), + new_long_array_col.begin(), + new_long_array_col.end()); + array_cols["bool"].insert(array_cols["bool"].end(), + new_bool_array_col.begin(), + new_bool_array_col.end()); + array_cols["float"].insert(array_cols["float"].end(), + new_float_array_col.begin(), + new_float_array_col.end()); + array_cols["string"].insert(array_cols["string"].end(), + new_string_array_col.begin(), + new_string_array_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + ExecExprVisitor visitor( + *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + + std::vector>> + testcases = { + {R"(term_expr: < + column_info: < + field_id: 101 + data_type: Array + nested_path:"0" + element_type:Int64 + > + values: values: values: + >)", + "long", + 
[](milvus::Array& array) { + auto val = array.get_data(0); + return val == 1 || val ==2 || val == 3; + }}, + {R"(term_expr: < + column_info: < + field_id: 101 + data_type: Array + nested_path:"0" + element_type:Int64 + > + >)", + "long", + [](milvus::Array& array) { + return false; + }}, + {R"(term_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Bool + > + values: values: + >)", + "bool", + [](milvus::Array& array) { + auto val = array.get_data(0); + return !val; + }}, + {R"(term_expr: < + column_info: < + field_id: 102 + data_type: Array + nested_path:"0" + element_type:Bool + > + >)", + "bool", + [](milvus::Array& array) { + return false; + }}, + {R"(term_expr: < + column_info: < + field_id: 103 + data_type: Array + nested_path:"0" + element_type:Float + > + values: values: + >)", + "float", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val == 1.23 || val == 124.31; + }}, + {R"(term_expr: < + column_info: < + field_id: 103 + data_type: Array + nested_path:"0" + element_type:Float + > + >)", + "float", + [](milvus::Array& array) { + return false; + }}, + {R"(term_expr: < + column_info: < + field_id: 104 + data_type: Array + nested_path:"0" + element_type:VarChar + > + values: values: + >)", + "string", + [](milvus::Array& array) { + auto val = array.get_data(0); + return val == "abc" || val == "idhgf1s"; + }}, + {R"(term_expr: < + column_info: < + field_id: 104 + data_type: Array + nested_path:"0" + element_type:VarChar + > + >)", + "string", + [](milvus::Array& array) { + return false; + }}, + }; + + std::string raw_plan_tmp = R"(vector_anns: < + field_id: 100 + predicates: < + @@@@ + > + query_info: < + topk: 10 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + + for (auto [clause, array_type, ref_func] : testcases) { + auto loc = raw_plan_tmp.find("@@@@"); + auto raw_plan = raw_plan_tmp; + raw_plan.replace(loc, 4, clause); + auto 
plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str()); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + auto final = visitor.call_child(*plan->plan_node_->predicate_.value()); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Array(array_cols[array_type][i]); + ASSERT_EQ(ans, ref_func(array)); + } + } +} + +TEST(Expr, TestTermInArray) { + using namespace milvus; + using namespace milvus::query; + using namespace milvus::segcore; + + auto schema = std::make_shared(); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto long_array_fid = + schema->AddDebugField("long_array", DataType::ARRAY, DataType::INT64); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::map> array_cols; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGen(schema, N, iter); + auto new_long_array_col = raw_data.get_col(long_array_fid); + array_cols["long"].insert(array_cols["long"].end(), + new_long_array_col.begin(), + new_long_array_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + ExecExprVisitor visitor( + *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + + struct TermTestCases { + std::vector values; + std::vector nested_path; + std::function check_func; + }; + std::vector testcases = { + {{100}, + {}, + [](milvus::Array& array) { + for (int i = 0; i < array.length(); ++i) { + auto val = array.get_data(i); + if (val == 100) { + return true; + } + } + return false; + }}, + {{1024}, + {}, + [](milvus::Array& array) { + for (int i = 0; i < array.length(); ++i) { + auto val = array.get_data(i); + if (val == 1024) { + return true; + } + } + return false; + }}, + }; + + for (auto& 
testcase : testcases) { + RetrievePlanNode plan; + plan.predicate_ = std::make_unique>( + ColumnInfo(long_array_fid, DataType::ARRAY, testcase.nested_path), + testcase.values, + proto::plan::GenericValue::ValCase::kInt64Val, + true); + auto start = std::chrono::steady_clock::now(); + auto final = visitor.call_child(*plan.predicate_.value()); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto array = milvus::Array(array_cols["long"][i]); + ASSERT_EQ(ans, testcase.check_func(array)); + } + } +} diff --git a/internal/core/unittest/test_azure_chunk_manager.cpp b/internal/core/unittest/test_azure_chunk_manager.cpp new file mode 100644 index 0000000000000..71e9a78d82ab1 --- /dev/null +++ b/internal/core/unittest/test_azure_chunk_manager.cpp @@ -0,0 +1,288 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. 
See the License for the specific language governing permissions and limitations under the License + +#include +#include +#include + +#include "common/EasyAssert.h" +#include "storage/AzureChunkManager.h" +#include "storage/Util.h" + +using namespace std; +using namespace milvus; +using namespace milvus::storage; + +StorageConfig +get_default_storage_config() { + auto endpoint = "core.windows.net"; + auto accessKey = "devstoreaccount1"; + auto accessValue = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="; + auto rootPath = "files"; + auto useSSL = false; + auto useIam = false; + auto iamEndPoint = ""; + auto bucketName = "a-bucket"; + + return StorageConfig{endpoint, + bucketName, + accessKey, + accessValue, + rootPath, + "remote", + "azure", + iamEndPoint, + "error", + "", + useSSL, + useIam}; +} + +class AzureChunkManagerTest : public testing::Test { + public: + AzureChunkManagerTest() { + } + ~AzureChunkManagerTest() { + } + + virtual void + SetUp() { + configs_ = get_default_storage_config(); + chunk_manager_ = make_unique(configs_); + chunk_manager_ptr_ = CreateChunkManager(configs_); + } + + protected: + AzureChunkManagerPtr chunk_manager_; + ChunkManagerPtr chunk_manager_ptr_; + StorageConfig configs_; +}; + +TEST_F(AzureChunkManagerTest, BasicFunctions) { + EXPECT_TRUE(chunk_manager_->GetName() == "AzureChunkManager"); + EXPECT_TRUE(chunk_manager_ptr_->GetName() == "AzureChunkManager"); + EXPECT_TRUE(chunk_manager_->GetRootPath() == "files"); + + + string path = "test"; + uint8_t readdata[20] = {0}; + try { + chunk_manager_->Read(path, 0, readdata, sizeof(readdata)); + } catch (SegcoreError& e) { + EXPECT_TRUE(string(e.what()).find("Read") != string::npos); + } + try { + chunk_manager_->Write(path, 0, readdata, sizeof(readdata)); + } catch (SegcoreError& e) { + EXPECT_TRUE(string(e.what()).find("Write") != string::npos); + } +} + +TEST_F(AzureChunkManagerTest, BucketPositive) { + string testBucketName = 
"test-bucket"; + bool exist = chunk_manager_->BucketExists(testBucketName); + EXPECT_EQ(exist, false); + chunk_manager_->CreateBucket(testBucketName); + exist = chunk_manager_->BucketExists(testBucketName); + EXPECT_EQ(exist, true); + vector buckets = chunk_manager_->ListBuckets(); + EXPECT_EQ(buckets[0], testBucketName); + chunk_manager_->DeleteBucket(testBucketName); +} + +TEST_F(AzureChunkManagerTest, BucketNegtive) { + string testBucketName = "test-bucket-ng"; + try { + chunk_manager_->DeleteBucket(testBucketName); + } catch (SegcoreError& e) { + EXPECT_TRUE(string(e.what()).find("not") != string::npos); + } + + // create already exist bucket + chunk_manager_->CreateBucket(testBucketName); + try { + chunk_manager_->CreateBucket(testBucketName); + } catch (SegcoreError& e) { + EXPECT_TRUE(string(e.what()).find("exists") != string::npos); + } + chunk_manager_->DeleteBucket(testBucketName); +} + +TEST_F(AzureChunkManagerTest, ObjectExist) { + string testBucketName = configs_.bucket_name; + string objPath = "1/3"; + if (!chunk_manager_->BucketExists(testBucketName)) { + chunk_manager_->CreateBucket(testBucketName); + } + + bool exist = chunk_manager_->Exist(objPath); + EXPECT_EQ(exist, false); + chunk_manager_->DeleteBucket(testBucketName); +} + +TEST_F(AzureChunkManagerTest, WritePositive) { + string testBucketName = configs_.bucket_name; + EXPECT_EQ(chunk_manager_->GetBucketName(), testBucketName); + + if (!chunk_manager_->BucketExists(testBucketName)) { + chunk_manager_->CreateBucket(testBucketName); + } + auto has_bucket = chunk_manager_->BucketExists(testBucketName); + uint8_t data[5] = {0x17, 0x32, 0x45, 0x34, 0x23}; + string path = "1"; + chunk_manager_->Write(path, data, sizeof(data)); + + bool exist = chunk_manager_->Exist(path); + EXPECT_EQ(exist, true); + + auto size = chunk_manager_->Size(path); + EXPECT_EQ(size, 5); + + int datasize = 10000; + uint8_t* bigdata = new uint8_t[datasize]; + srand((unsigned)time(NULL)); + for (int i = 0; i < datasize; ++i) 
{ + bigdata[i] = rand() % 256; + } + chunk_manager_->Write(path, bigdata, datasize); + size = chunk_manager_->Size(path); + EXPECT_EQ(size, datasize); + delete[] bigdata; + + chunk_manager_->Remove(path); + chunk_manager_->DeleteBucket(testBucketName); +} + +TEST_F(AzureChunkManagerTest, ReadPositive) { + string testBucketName = configs_.bucket_name; + EXPECT_EQ(chunk_manager_->GetBucketName(), testBucketName); + + if (!chunk_manager_->BucketExists(testBucketName)) { + chunk_manager_->CreateBucket(testBucketName); + } + uint8_t data[5] = {0x17, 0x32, 0x45, 0x34, 0x23}; + string path = "1/4/6"; + chunk_manager_->Write(path, data, sizeof(data)); + bool exist = chunk_manager_->Exist(path); + EXPECT_EQ(exist, true); + auto size = chunk_manager_->Size(path); + EXPECT_EQ(size, sizeof(data)); + + uint8_t readdata[20] = {0}; + size = chunk_manager_->Read(path, readdata, sizeof(data)); + EXPECT_EQ(size, sizeof(data)); + EXPECT_EQ(readdata[0], 0x17); + EXPECT_EQ(readdata[1], 0x32); + EXPECT_EQ(readdata[2], 0x45); + EXPECT_EQ(readdata[3], 0x34); + EXPECT_EQ(readdata[4], 0x23); + + size = chunk_manager_->Read(path, readdata, 3); + EXPECT_EQ(size, 3); + EXPECT_EQ(readdata[0], 0x17); + EXPECT_EQ(readdata[1], 0x32); + EXPECT_EQ(readdata[2], 0x45); + + uint8_t dataWithNULL[] = {0x17, 0x32, 0x00, 0x34, 0x23}; + chunk_manager_->Write(path, dataWithNULL, sizeof(dataWithNULL)); + exist = chunk_manager_->Exist(path); + EXPECT_EQ(exist, true); + size = chunk_manager_->Size(path); + EXPECT_EQ(size, sizeof(dataWithNULL)); + size = chunk_manager_->Read(path, readdata, sizeof(dataWithNULL)); + EXPECT_EQ(size, sizeof(dataWithNULL)); + EXPECT_EQ(readdata[0], 0x17); + EXPECT_EQ(readdata[1], 0x32); + EXPECT_EQ(readdata[2], 0x00); + EXPECT_EQ(readdata[3], 0x34); + EXPECT_EQ(readdata[4], 0x23); + + chunk_manager_->Remove(path); + + try { + chunk_manager_->Read(path, readdata, sizeof(dataWithNULL)); + } catch (SegcoreError& e) { + EXPECT_TRUE(string(e.what()).find("exists") != string::npos); + } + 
+ chunk_manager_->DeleteBucket(testBucketName); +} + +TEST_F(AzureChunkManagerTest, RemovePositive) { + string testBucketName = configs_.bucket_name; + EXPECT_EQ(chunk_manager_->GetBucketName(), testBucketName); + + if (!chunk_manager_->BucketExists(testBucketName)) { + chunk_manager_->CreateBucket(testBucketName); + } + uint8_t data[5] = {0x17, 0x32, 0x45, 0x34, 0x23}; + string path = "1/7/8"; + chunk_manager_->Write(path, data, sizeof(data)); + + bool exist = chunk_manager_->Exist(path); + EXPECT_EQ(exist, true); + + chunk_manager_->Remove(path); + + exist = chunk_manager_->Exist(path); + EXPECT_EQ(exist, false); + + try { + chunk_manager_->Remove(path); + } catch (SegcoreError& e) { + EXPECT_TRUE(string(e.what()).find("not") != string::npos); + } + + try { + chunk_manager_->Size(path); + } catch (SegcoreError& e) { + EXPECT_TRUE(string(e.what()).find("not") != string::npos); + } + + chunk_manager_->DeleteBucket(testBucketName); +} + +TEST_F(AzureChunkManagerTest, ListWithPrefixPositive) { + string testBucketName = configs_.bucket_name; + EXPECT_EQ(chunk_manager_->GetBucketName(), testBucketName); + + if (!chunk_manager_->BucketExists(testBucketName)) { + chunk_manager_->CreateBucket(testBucketName); + } + + string path1 = "1/7/8"; + string path2 = "1/7/4"; + string path3 = "1/4/8"; + uint8_t data[5] = {0x17, 0x32, 0x45, 0x34, 0x23}; + chunk_manager_->Write(path1, data, sizeof(data)); + chunk_manager_->Write(path2, data, sizeof(data)); + chunk_manager_->Write(path3, data, sizeof(data)); + + vector objs = chunk_manager_->ListWithPrefix("1/7"); + EXPECT_EQ(objs.size(), 2); + sort(objs.begin(), objs.end()); + EXPECT_EQ(objs[0], "1/7/4"); + EXPECT_EQ(objs[1], "1/7/8"); + + objs = chunk_manager_->ListWithPrefix("//1/7"); + EXPECT_EQ(objs.size(), 0); + + objs = chunk_manager_->ListWithPrefix("1"); + EXPECT_EQ(objs.size(), 3); + sort(objs.begin(), objs.end()); + EXPECT_EQ(objs[0], "1/4/8"); + EXPECT_EQ(objs[1], "1/7/4"); + + chunk_manager_->Remove(path1); + 
chunk_manager_->Remove(path2); + chunk_manager_->Remove(path3); + chunk_manager_->DeleteBucket(testBucketName); +} diff --git a/internal/core/unittest/test_bf.cpp b/internal/core/unittest/test_bf.cpp index e512c678a8aa5..f0e64b087b4c3 100644 --- a/internal/core/unittest/test_bf.cpp +++ b/internal/core/unittest/test_bf.cpp @@ -56,7 +56,7 @@ Distances(const float* base, } return res; } else { - PanicInfo("invalid metric type"); + PanicInfo(MetricTypeInvalid, "invalid metric type"); } } @@ -85,7 +85,7 @@ Ref(const float* base, } else if (milvus::IsMetricType(metric, knowhere::metric::IP)) { std::reverse(res.begin(), res.end()); } else { - PanicInfo("invalid metric type"); + PanicInfo(MetricTypeInvalid, "invalid metric type"); } return GetOffsets(res, topk); } diff --git a/internal/core/unittest/test_c_api.cpp b/internal/core/unittest/test_c_api.cpp index bb8e806781fb4..26df258bbdfe7 100644 --- a/internal/core/unittest/test_c_api.cpp +++ b/internal/core/unittest/test_c_api.cpp @@ -12,6 +12,7 @@ #include #include +#include #include #include #include @@ -19,8 +20,10 @@ #include #include +#include "boost/container/vector.hpp" #include "common/LoadInfo.h" #include "common/Types.h" +#include "common/type_c.h" #include "index/IndexFactory.h" #include "knowhere/comp/index_param.h" #include "pb/plan.pb.h" @@ -28,6 +31,7 @@ #include "segcore/Collection.h" #include "segcore/Reduce.h" #include "segcore/reduce_c.h" +#include "segcore/segment_c.h" #include "test_utils/DataGen.h" #include "test_utils/PbHelper.h" #include "test_utils/indexbuilder_test_utils.h" @@ -245,9 +249,12 @@ generate_index(void* raw_data, IndexType index_type, int64_t dim, int64_t N) { - CreateIndexInfo create_index_info{field_type, index_type, metric_type}; + auto engine_version = + knowhere::Version::GetCurrentVersion().VersionNumber(); + CreateIndexInfo create_index_info{ + field_type, index_type, metric_type, engine_version}; auto indexing = milvus::index::IndexFactory::GetInstance().CreateIndex( - 
create_index_info, nullptr); + create_index_info, milvus::storage::FileManagerContext()); auto database = knowhere::GenDataSet(N, dim, raw_data); auto build_config = generate_build_conf(index_type, metric_type); @@ -311,7 +318,7 @@ TEST(CApiTest, CPlan) { milvus::proto::plan::PlanNode plan_node; auto vector_anns = plan_node.mutable_vector_anns(); - vector_anns->set_is_binary(true); + vector_anns->set_vector_type(milvus::proto::plan::VectorType::BinaryVector); vector_anns->set_placeholder_tag("$0"); vector_anns->set_field_id(100); auto query_info = vector_anns->mutable_query_info(); @@ -950,7 +957,7 @@ TEST(CApiTest, SearchTest) { milvus::proto::plan::PlanNode plan_node; auto vector_anns = plan_node.mutable_vector_anns(); - vector_anns->set_is_binary(false); + vector_anns->set_vector_type(milvus::proto::plan::VectorType::FloatVector); vector_anns->set_placeholder_tag("$0"); vector_anns->set_field_id(100); auto query_info = vector_anns->mutable_query_info(); @@ -1151,7 +1158,7 @@ TEST(CApiTest, GetDeletedCountTest) { // TODO: assert(deleted_count == len(delete_row_ids)) auto deleted_count = GetDeletedCount(segment); - ASSERT_EQ(deleted_count, delete_row_ids.size()); + ASSERT_EQ(deleted_count, 0); DeleteCollection(collection); DeleteSegment(segment); @@ -1275,7 +1282,7 @@ TEST(CApiTest, ReudceNullResult) { milvus::proto::plan::PlanNode plan_node; auto vector_anns = plan_node.mutable_vector_anns(); - vector_anns->set_is_binary(false); + vector_anns->set_vector_type(milvus::proto::plan::VectorType::FloatVector); vector_anns->set_placeholder_tag("$0"); vector_anns->set_field_id(100); auto query_info = vector_anns->mutable_query_info(); @@ -1359,7 +1366,7 @@ TEST(CApiTest, ReduceRemoveDuplicates) { milvus::proto::plan::PlanNode plan_node; auto vector_anns = plan_node.mutable_vector_anns(); - vector_anns->set_is_binary(false); + vector_anns->set_vector_type(milvus::proto::plan::VectorType::FloatVector); vector_anns->set_placeholder_tag("$0"); 
vector_anns->set_field_id(100); auto query_info = vector_anns->mutable_query_info(); @@ -1593,7 +1600,8 @@ TEST(CApiTest, LoadIndexInfo) { auto N = 1024 * 10; auto [raw_data, timestamps, uids] = generate_data(N); auto indexing = knowhere::IndexFactory::Instance().Create( - knowhere::IndexEnum::INDEX_FAISS_IVFSQ8); + knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, + knowhere::Version::GetCurrentVersion().VersionNumber()); auto conf = knowhere::Json{{knowhere::meta::METRIC_TYPE, knowhere::metric::L2}, {knowhere::meta::DIM, DIM}, @@ -1626,6 +1634,9 @@ TEST(CApiTest, LoadIndexInfo) { status = AppendFieldInfo( c_load_index_info, 0, 0, 0, 0, CDataType::FloatVector, ""); ASSERT_EQ(status.error_code, Success); + AppendIndexEngineVersionToLoadInfo( + c_load_index_info, + knowhere::Version::GetCurrentVersion().VersionNumber()); status = AppendIndex(c_load_index_info, c_binary_set); ASSERT_EQ(status.error_code, Success); DeleteLoadIndexInfo(c_load_index_info); @@ -1639,7 +1650,8 @@ TEST(CApiTest, LoadIndexSearch) { auto num_query = 100; auto [raw_data, timestamps, uids] = generate_data(N); auto indexing = knowhere::IndexFactory::Instance().Create( - knowhere::IndexEnum::INDEX_FAISS_IVFSQ8); + knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, + knowhere::Version::GetCurrentVersion().VersionNumber()); auto conf = knowhere::Json{{knowhere::meta::METRIC_TYPE, knowhere::metric::L2}, {knowhere::meta::DIM, DIM}, @@ -1663,7 +1675,9 @@ TEST(CApiTest, LoadIndexSearch) { auto& index_params = load_index_info.index_params; index_params["index_type"] = knowhere::IndexEnum::INDEX_FAISS_IVFSQ8; load_index_info.index = std::make_unique( - index_params["index_type"], knowhere::metric::L2); + index_params["index_type"], + knowhere::metric::L2, + knowhere::Version::GetCurrentVersion().VersionNumber()); load_index_info.index->Load(binary_set); // search @@ -1703,7 +1717,7 @@ TEST(CApiTest, Indexing_Without_Predicate) { milvus::proto::plan::PlanNode plan_node; auto vector_anns = plan_node.mutable_vector_anns(); - 
vector_anns->set_is_binary(false); + vector_anns->set_vector_type(milvus::proto::plan::VectorType::FloatVector); vector_anns->set_placeholder_tag("$0"); vector_anns->set_field_id(100); auto query_info = vector_anns->mutable_query_info(); @@ -1781,8 +1795,10 @@ TEST(CApiTest, Indexing_Without_Predicate) { c_load_index_info, metric_type_key.c_str(), metric_type_value.c_str()); AppendFieldInfo( c_load_index_info, 0, 0, 0, 100, CDataType::FloatVector, ""); + AppendIndexEngineVersionToLoadInfo( + c_load_index_info, + knowhere::Version::GetCurrentVersion().VersionNumber()); AppendIndex(c_load_index_info, (CBinarySet)&binary_set); - // load index for vec field, load raw data for scalar field auto sealed_segment = SealedCreator(schema, dataset); sealed_segment->DropFieldData(FieldId(100)); @@ -1921,6 +1937,9 @@ TEST(CApiTest, Indexing_Expr_Without_Predicate) { c_load_index_info, metric_type_key.c_str(), metric_type_value.c_str()); AppendFieldInfo( c_load_index_info, 0, 0, 0, 100, CDataType::FloatVector, ""); + AppendIndexEngineVersionToLoadInfo( + c_load_index_info, + knowhere::Version::GetCurrentVersion().VersionNumber()); AppendIndex(c_load_index_info, (CBinarySet)&binary_set); // load index for vec field, load raw data for scalar field @@ -2090,6 +2109,9 @@ TEST(CApiTest, Indexing_With_float_Predicate_Range) { c_load_index_info, metric_type_key.c_str(), metric_type_value.c_str()); AppendFieldInfo( c_load_index_info, 0, 0, 0, 100, CDataType::FloatVector, ""); + AppendIndexEngineVersionToLoadInfo( + c_load_index_info, + knowhere::Version::GetCurrentVersion().VersionNumber()); AppendIndex(c_load_index_info, (CBinarySet)&binary_set); // load index for vec field, load raw data for scalar field @@ -2261,6 +2283,9 @@ TEST(CApiTest, Indexing_Expr_With_float_Predicate_Range) { c_load_index_info, metric_type_key.c_str(), metric_type_value.c_str()); AppendFieldInfo( c_load_index_info, 0, 0, 0, 100, CDataType::FloatVector, ""); + AppendIndexEngineVersionToLoadInfo( + 
c_load_index_info, + knowhere::Version::GetCurrentVersion().VersionNumber()); AppendIndex(c_load_index_info, (CBinarySet)&binary_set); // load index for vec field, load raw data for scalar field @@ -2424,6 +2449,9 @@ TEST(CApiTest, Indexing_With_float_Predicate_Term) { c_load_index_info, metric_type_key.c_str(), metric_type_value.c_str()); AppendFieldInfo( c_load_index_info, 0, 0, 0, 100, CDataType::FloatVector, ""); + AppendIndexEngineVersionToLoadInfo( + c_load_index_info, + knowhere::Version::GetCurrentVersion().VersionNumber()); AppendIndex(c_load_index_info, (CBinarySet)&binary_set); // load index for vec field, load raw data for scalar field @@ -2588,6 +2616,9 @@ TEST(CApiTest, Indexing_Expr_With_float_Predicate_Term) { c_load_index_info, metric_type_key.c_str(), metric_type_value.c_str()); AppendFieldInfo( c_load_index_info, 0, 0, 0, 100, CDataType::FloatVector, ""); + AppendIndexEngineVersionToLoadInfo( + c_load_index_info, + knowhere::Version::GetCurrentVersion().VersionNumber()); AppendIndex(c_load_index_info, (CBinarySet)&binary_set); // load index for vec field, load raw data for scalar field @@ -2758,6 +2789,9 @@ TEST(CApiTest, Indexing_With_binary_Predicate_Range) { c_load_index_info, metric_type_key.c_str(), metric_type_value.c_str()); AppendFieldInfo( c_load_index_info, 0, 0, 0, 100, CDataType::BinaryVector, ""); + AppendIndexEngineVersionToLoadInfo( + c_load_index_info, + knowhere::Version::GetCurrentVersion().VersionNumber()); AppendIndex(c_load_index_info, (CBinarySet)&binary_set); // load index for vec field, load raw data for scalar field @@ -2928,6 +2962,9 @@ TEST(CApiTest, Indexing_Expr_With_binary_Predicate_Range) { c_load_index_info, metric_type_key.c_str(), metric_type_value.c_str()); AppendFieldInfo( c_load_index_info, 0, 0, 0, 100, CDataType::BinaryVector, ""); + AppendIndexEngineVersionToLoadInfo( + c_load_index_info, + knowhere::Version::GetCurrentVersion().VersionNumber()); AppendIndex(c_load_index_info, (CBinarySet)&binary_set); // 
load index for vec field, load raw data for scalar field @@ -3092,6 +3129,9 @@ TEST(CApiTest, Indexing_With_binary_Predicate_Term) { c_load_index_info, metric_type_key.c_str(), metric_type_value.c_str()); AppendFieldInfo( c_load_index_info, 0, 0, 0, 100, CDataType::BinaryVector, ""); + AppendIndexEngineVersionToLoadInfo( + c_load_index_info, + knowhere::Version::GetCurrentVersion().VersionNumber()); AppendIndex(c_load_index_info, (CBinarySet)&binary_set); // load index for vec field, load raw data for scalar field @@ -3279,6 +3319,9 @@ TEST(CApiTest, Indexing_Expr_With_binary_Predicate_Term) { c_load_index_info, metric_type_key.c_str(), metric_type_value.c_str()); AppendFieldInfo( c_load_index_info, 0, 0, 0, 100, CDataType::BinaryVector, ""); + AppendIndexEngineVersionToLoadInfo( + c_load_index_info, + knowhere::Version::GetCurrentVersion().VersionNumber()); AppendIndex(c_load_index_info, (CBinarySet)&binary_set); // load index for vec field, load raw data for scalar field @@ -3449,6 +3492,9 @@ TEST(CApiTest, SealedSegment_search_float_Predicate_Range) { c_load_index_info, metric_type_key.c_str(), metric_type_value.c_str()); AppendFieldInfo( c_load_index_info, 0, 0, 0, 100, CDataType::FloatVector, ""); + AppendIndexEngineVersionToLoadInfo( + c_load_index_info, + knowhere::Version::GetCurrentVersion().VersionNumber()); AppendIndex(c_load_index_info, (CBinarySet)&binary_set); auto query_dataset = knowhere::GenDataSet(num_queries, DIM, query_ptr); @@ -3673,6 +3719,9 @@ TEST(CApiTest, SealedSegment_search_float_With_Expr_Predicate_Range) { c_load_index_info, metric_type_key.c_str(), metric_type_value.c_str()); AppendFieldInfo( c_load_index_info, 0, 0, 0, 100, CDataType::FloatVector, ""); + AppendIndexEngineVersionToLoadInfo( + c_load_index_info, + knowhere::Version::GetCurrentVersion().VersionNumber()); AppendIndex(c_load_index_info, (CBinarySet)&binary_set); // load vec index @@ -3933,7 +3982,7 @@ TEST(CApiTest, RetriveScalarFieldFromSealedSegmentWithIndex) { break; } 
default: { - PanicInfo("not supported type"); + PanicInfo(DataTypeInvalid, "not supported type"); } } } diff --git a/internal/core/unittest/test_chunk_cache.cpp b/internal/core/unittest/test_chunk_cache.cpp new file mode 100644 index 0000000000000..5386b745f00a8 --- /dev/null +++ b/internal/core/unittest/test_chunk_cache.cpp @@ -0,0 +1,167 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include +#include + +#include "fmt/format.h" +#include "common/Schema.h" +#include "test_utils/DataGen.h" +#include "test_utils/storage_test_utils.h" +#include "storage/ChunkCache.h" +#include "storage/LocalChunkManagerSingleton.h" + +#define DEFAULT_READ_AHEAD_POLICY "willneed" + +TEST(ChunkCacheTest, Read) { + auto N = 10000; + auto dim = 128; + auto metric_type = knowhere::metric::L2; + + auto mmap_dir = "/tmp/test_chunk_cache/mmap"; + auto local_storage_path = "/tmp/test_chunk_cache/local"; + auto file_name = std::string("chunk_cache_test/insert_log/1/101/1000000"); + + milvus::storage::LocalChunkManagerSingleton::GetInstance().Init( + local_storage_path); + + auto schema = std::make_shared(); + auto fake_id = schema->AddDebugField( + "fakevec", milvus::DataType::VECTOR_FLOAT, dim, metric_type); + auto i64_fid = schema->AddDebugField("counter", milvus::DataType::INT64); + schema->set_primary_field_id(i64_fid); + + auto dataset = milvus::segcore::DataGen(schema, N); + + auto field_data_meta = + milvus::storage::FieldDataMeta{1, 2, 3, fake_id.get()}; + auto field_meta = milvus::FieldMeta(milvus::FieldName("facevec"), + fake_id, + milvus::DataType::VECTOR_FLOAT, + dim, + metric_type); + + auto lcm = milvus::storage::LocalChunkManagerSingleton::GetInstance() + .GetChunkManager(); + auto data = dataset.get_col(fake_id); + auto data_slices = std::vector{(uint8_t*)data.data()}; + auto slice_sizes = std::vector{static_cast(N)}; + auto slice_names = std::vector{file_name}; + PutFieldData(lcm.get(), + data_slices, + slice_sizes, + slice_names, + field_data_meta, + field_meta); + + auto cc = std::make_shared( + mmap_dir, DEFAULT_READ_AHEAD_POLICY, lcm); + const auto& column = cc->Read(file_name); + Assert(column->ByteSize() == dim * N * 4); + + cc->Prefetch(file_name); + auto actual = (float*)column->Data(); + for (auto i = 0; i < N; i++) { + AssertInfo(data[i] == actual[i], + fmt::format("expect {}, actual {}", data[i], actual[i])); + } + + 
cc->Remove(file_name); + lcm->Remove(file_name); + std::filesystem::remove_all(mmap_dir); + + auto exist = lcm->Exist(file_name); + Assert(!exist); + exist = std::filesystem::exists(mmap_dir); + Assert(!exist); +} + +TEST(ChunkCacheTest, TestMultithreads) { + auto N = 1000; + auto dim = 128; + auto metric_type = knowhere::metric::L2; + + auto mmap_dir = "/tmp/test_chunk_cache/mmap"; + auto local_storage_path = "/tmp/test_chunk_cache/local"; + auto file_name = std::string("chunk_cache_test/insert_log/2/101/1000000"); + + milvus::storage::LocalChunkManagerSingleton::GetInstance().Init( + local_storage_path); + + auto schema = std::make_shared(); + auto fake_id = schema->AddDebugField( + "fakevec", milvus::DataType::VECTOR_FLOAT, dim, metric_type); + auto i64_fid = schema->AddDebugField("counter", milvus::DataType::INT64); + schema->set_primary_field_id(i64_fid); + + auto dataset = milvus::segcore::DataGen(schema, N); + + auto field_data_meta = + milvus::storage::FieldDataMeta{1, 2, 3, fake_id.get()}; + auto field_meta = milvus::FieldMeta(milvus::FieldName("facevec"), + fake_id, + milvus::DataType::VECTOR_FLOAT, + dim, + metric_type); + + auto lcm = milvus::storage::LocalChunkManagerSingleton::GetInstance() + .GetChunkManager(); + auto data = dataset.get_col(fake_id); + auto data_slices = std::vector{(uint8_t*)data.data()}; + auto slice_sizes = std::vector{static_cast(N)}; + auto slice_names = std::vector{file_name}; + PutFieldData(lcm.get(), + data_slices, + slice_sizes, + slice_names, + field_data_meta, + field_meta); + + auto cc = std::make_shared( + mmap_dir, DEFAULT_READ_AHEAD_POLICY, lcm); + + constexpr int threads = 16; + std::vector total_counts(threads); + auto executor = [&](int thread_id) { + const auto& column = cc->Read(file_name); + Assert(column->ByteSize() == dim * N * 4); + + cc->Prefetch(file_name); + auto actual = (float*)column->Data(); + for (auto i = 0; i < N; i++) { + AssertInfo(data[i] == actual[i], + fmt::format("expect {}, actual {}", 
data[i], actual[i])); + } + }; + std::vector pool; + for (int i = 0; i < threads; ++i) { + pool.emplace_back(executor, i); + } + for (auto& thread : pool) { + thread.join(); + } + + cc->Remove(file_name); + lcm->Remove(file_name); + std::filesystem::remove_all(mmap_dir); + + auto exist = lcm->Exist(file_name); + Assert(!exist); + exist = std::filesystem::exists(mmap_dir); + Assert(!exist); +} diff --git a/internal/core/unittest/test_data_codec.cpp b/internal/core/unittest/test_data_codec.cpp index f8939e72db92c..e8075e7ce4302 100644 --- a/internal/core/unittest/test_data_codec.cpp +++ b/internal/core/unittest/test_data_codec.cpp @@ -21,7 +21,7 @@ #include "storage/IndexData.h" #include "storage/Util.h" #include "common/Consts.h" -#include "utils/Json.h" +#include "common/Json.h" using namespace milvus; @@ -302,6 +302,36 @@ TEST(storage, InsertDataBinaryVector) { ASSERT_EQ(data, new_data); } +TEST(storage, InsertDataFloat16Vector) { + std::vector data = {1, 2, 3, 4, 5, 6, 7, 8}; + int DIM = 2; + auto field_data = milvus::storage::CreateFieldData( + storage::DataType::VECTOR_FLOAT16, DIM); + field_data->FillFieldData(data.data(), data.size() / DIM); + + storage::InsertData insert_data(field_data); + storage::FieldDataMeta field_data_meta{100, 101, 102, 103}; + insert_data.SetFieldDataMeta(field_data_meta); + insert_data.SetTimestamps(0, 100); + + auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote); + std::shared_ptr serialized_data_ptr(serialized_bytes.data(), + [&](uint8_t*) {}); + auto new_insert_data = storage::DeserializeFileData( + serialized_data_ptr, serialized_bytes.size()); + ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType); + ASSERT_EQ(new_insert_data->GetTimeRage(), + std::make_pair(Timestamp(0), Timestamp(100))); + auto new_payload = new_insert_data->GetFieldData(); + ASSERT_EQ(new_payload->get_data_type(), storage::DataType::VECTOR_FLOAT16); + ASSERT_EQ(new_payload->get_num_rows(), data.size() / DIM); + 
std::vector new_data(data.size()); + memcpy(new_data.data(), + new_payload->Data(), + new_payload->get_num_rows() * sizeof(float16) * DIM); + ASSERT_EQ(data, new_data); +} + TEST(storage, IndexData) { std::vector data = {1, 2, 3, 4, 5, 6, 7, 8}; auto field_data = milvus::storage::CreateFieldData(storage::DataType::INT8); @@ -329,3 +359,40 @@ TEST(storage, IndexData) { memcpy(new_data.data(), new_field_data->Data(), new_field_data->Size()); ASSERT_EQ(data, new_data); } + +TEST(storage, InsertDataStringArray) { + milvus::proto::schema::ScalarField field_string_data; + field_string_data.mutable_string_data()->add_data("test_array1"); + field_string_data.mutable_string_data()->add_data("test_array2"); + field_string_data.mutable_string_data()->add_data("test_array3"); + field_string_data.mutable_string_data()->add_data("test_array4"); + field_string_data.mutable_string_data()->add_data("test_array5"); + auto string_array = Array(field_string_data); + FixedVector data = {string_array}; + auto field_data = + milvus::storage::CreateFieldData(storage::DataType::ARRAY); + field_data->FillFieldData(data.data(), data.size()); + + storage::InsertData insert_data(field_data); + storage::FieldDataMeta field_data_meta{100, 101, 102, 103}; + insert_data.SetFieldDataMeta(field_data_meta); + insert_data.SetTimestamps(0, 100); + + auto serialized_bytes = insert_data.Serialize(storage::StorageType::Remote); + std::shared_ptr serialized_data_ptr(serialized_bytes.data(), + [&](uint8_t*) {}); + auto new_insert_data = storage::DeserializeFileData( + serialized_data_ptr, serialized_bytes.size()); + ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType); + ASSERT_EQ(new_insert_data->GetTimeRage(), + std::make_pair(Timestamp(0), Timestamp(100))); + auto new_payload = new_insert_data->GetFieldData(); + ASSERT_EQ(new_payload->get_data_type(), storage::DataType::ARRAY); + ASSERT_EQ(new_payload->get_num_rows(), data.size()); + FixedVector new_data(data.size()); + for (int i = 0; i < 
data.size(); ++i) { + new_data[i] = *static_cast(new_payload->RawValue(i)); + ASSERT_EQ(new_payload->Size(i), data[i].byte_size()); + ASSERT_TRUE(data[i].operator==(new_data[i])); + } +} diff --git a/internal/core/unittest/test_disk_file_manager_test.cpp b/internal/core/unittest/test_disk_file_manager_test.cpp index 00ac8f74b5c29..310dec776caea 100644 --- a/internal/core/unittest/test_disk_file_manager_test.cpp +++ b/internal/core/unittest/test_disk_file_manager_test.cpp @@ -61,8 +61,8 @@ TEST_F(DiskAnnFileManagerTest, AddFilePositiveParallel) { IndexMeta index_meta = {3, 100, 1000, 1, "index"}; int64_t slice_size = milvus::FILE_SLICE_SIZE; - auto diskAnnFileManager = - std::make_shared(filed_data_meta, index_meta, cm_); + auto diskAnnFileManager = std::make_shared( + storage::FileManagerContext(filed_data_meta, index_meta, cm_)); auto ok = diskAnnFileManager->AddFile(indexFilePath); EXPECT_EQ(ok, true); @@ -101,7 +101,7 @@ TEST_F(DiskAnnFileManagerTest, AddFilePositiveParallel) { int test_worker(string s) { std::cout << s << std::endl; - sleep(4); + std::this_thread::sleep_for(std::chrono::seconds(4)); std::cout << s << std::endl; return 1; } @@ -163,7 +163,7 @@ TEST_F(DiskAnnFileManagerTest, TestThreadPool) { int test_exception(string s) { if (s == "test_id60") { - throw std::runtime_error("run time error"); + throw SegcoreError(ErrorCode::UnexpectedError, "run time error"); } return 1; } diff --git a/internal/core/unittest/test_expr.cpp b/internal/core/unittest/test_expr.cpp index 1998d32afdd81..eacc3970e6590 100644 --- a/internal/core/unittest/test_expr.cpp +++ b/internal/core/unittest/test_expr.cpp @@ -486,7 +486,7 @@ TEST(Expr, TestRange) { auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; std::vector age_col; - int num_iters = 100; + int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGen(schema, N, iter); auto new_age_col = raw_data.get_col(i64_fid); @@ -553,7 +553,7 @@ TEST(Expr, TestBinaryRangeJSON) 
{ auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; std::vector json_col; - int num_iters = 100; + int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGen(schema, N, iter); auto new_json_col = raw_data.get_col(json_fid); @@ -641,7 +641,7 @@ TEST(Expr, TestExistsJson) { auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; std::vector json_col; - int num_iters = 100; + int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGen(schema, N, iter); auto new_json_col = raw_data.get_col(json_fid); @@ -706,7 +706,7 @@ TEST(Expr, TestUnaryRangeJson) { auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; std::vector json_col; - int num_iters = 100; + int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGen(schema, N, iter); auto new_json_col = raw_data.get_col(json_fid); @@ -762,7 +762,7 @@ TEST(Expr, TestUnaryRangeJson) { break; } default: { - PanicInfo("unsupported range node"); + PanicInfo(Unsupported, "unsupported range node"); } } @@ -796,6 +796,54 @@ TEST(Expr, TestUnaryRangeJson) { } } } + + struct TestArrayCase { + proto::plan::Array val; + std::vector nested_path; + }; + + proto::plan::Array arr; + arr.set_same_type(true); + proto::plan::GenericValue int_val1; + int_val1.set_int64_val(int64_t(1)); + arr.add_array()->CopyFrom(int_val1); + + proto::plan::GenericValue int_val2; + int_val2.set_int64_val(int64_t(2)); + arr.add_array()->CopyFrom(int_val2); + + proto::plan::GenericValue int_val3; + int_val3.set_int64_val(int64_t(3)); + arr.add_array()->CopyFrom(int_val3); + + std::vector array_cases = {{arr, {"array"}}}; + + for (const auto& testcase : array_cases) { + auto check = [&](OpType op) { + if (testcase.nested_path[0] == "array" && op == OpType::Equal) { + return true; + } + return false; + }; + for (auto& op : ops) { + RetrievePlanNode plan; + auto pointer = milvus::Json::pointer(testcase.nested_path); + 
plan.predicate_ = + std::make_unique>( + ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), + op, + testcase.val, + proto::plan::GenericValue::ValCase::kArrayVal); + auto final = visitor.call_child(*plan.predicate_.value()); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + auto ref = check(op); + ASSERT_EQ(ans, ref); + } + } + } } TEST(Expr, TestTermJson) { @@ -1065,7 +1113,7 @@ TEST(Expr, TestCompare) { int N = 1000; std::vector age1_col; std::vector age2_col; - int num_iters = 100; + int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGen(schema, N, iter); auto new_age1_col = raw_data.get_col(i32_fid); @@ -1371,7 +1419,7 @@ TEST(Expr, TestMultiLogicalExprsOptimization) { schema->set_primary_field_id(str1_fid); auto seg = CreateSealedSegment(schema); - size_t N = 1000000; + size_t N = 10000; auto raw_data = DataGen(schema, N); auto fields = schema->get_fields(); for (auto field_data : raw_data.raw_->fields_data()) { @@ -1470,7 +1518,7 @@ TEST(Expr, TestExprs) { schema->set_primary_field_id(str1_fid); auto seg = CreateSealedSegment(schema); - int N = 1000000; + int N = 10000; auto raw_data = DataGen(schema, N); // load field data @@ -1998,7 +2046,7 @@ TEST(Expr, TestBinaryArithOpEvalRange) { std::vector age64_col; std::vector age_float_col; std::vector age_double_col; - int num_iters = 100; + int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGen(schema, N, iter); @@ -2125,7 +2173,7 @@ TEST(Expr, TestBinaryArithOpEvalRangeJSON) { auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; std::vector json_col; - int num_iters = 100; + int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGen(schema, N, iter); auto new_json_col = raw_data.get_col(json_fid); @@ -2213,7 +2261,7 @@ TEST(Expr, TestBinaryArithOpEvalRangeJSONFloat) { auto seg = CreateGrowingSegment(schema, 
empty_index_meta); int N = 1000; std::vector json_col; - int num_iters = 100; + int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGen(schema, N, iter); auto new_json_col = raw_data.get_col(json_fid); @@ -2261,6 +2309,46 @@ TEST(Expr, TestBinaryArithOpEvalRangeJSONFloat) { ASSERT_EQ(ans, ref) << testcase.value << " " << val; } } + + std::vector array_testcases{ + {0, 3, OpType::Equal, {"array"}}, + {0, 5, OpType::NotEqual, {"array"}}, + }; + + for (auto testcase : array_testcases) { + auto check = [&](int64_t value) { + if (testcase.op == OpType::Equal) { + return value == testcase.value; + } + return value != testcase.value; + }; + RetrievePlanNode plan; + auto pointer = milvus::Json::pointer(testcase.nested_path); + plan.predicate_ = + std::make_unique>( + ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), + proto::plan::GenericValue::ValCase::kInt64Val, + ArithOpType::ArrayLength, + testcase.right_operand, + testcase.op, + testcase.value); + auto final = visitor.call_child(*plan.predicate_.value()); + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + + auto json = milvus::Json(simdjson::padded_string(json_col[i])); + int64_t array_length = 0; + auto doc = json.doc(); + auto array = doc.at_pointer(pointer).get_array(); + if (!array.error()) { + array_length = array.count_elements(); + } + auto ref = check(array_length); + ASSERT_EQ(ans, ref) << testcase.value << " " << array_length; + } + } } TEST(Expr, TestBinaryArithOpEvalRangeWithScalarSortIndex) { @@ -2683,7 +2771,7 @@ TEST(Expr, TestUnaryRangeWithJSON) { auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; std::vector json_col; - int num_iters = 100; + int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGen(schema, N, iter); auto new_json_col = raw_data.get_col(json_fid); @@ -2861,7 +2949,7 @@ TEST(Expr, TestTermWithJSON) { auto seg = 
CreateGrowingSegment(schema, empty_index_meta); int N = 1000; std::vector json_col; - int num_iters = 100; + int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGen(schema, N, iter); auto new_json_col = raw_data.get_col(json_fid); @@ -3006,7 +3094,7 @@ TEST(Expr, TestExistsWithJSON) { auto seg = CreateGrowingSegment(schema, empty_index_meta); int N = 1000; std::vector json_col; - int num_iters = 100; + int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGen(schema, N, iter); auto new_json_col = raw_data.get_col(json_fid); @@ -3116,6 +3204,7 @@ template struct Testcase { std::vector term; std::vector nested_path; + bool res; }; TEST(Expr, TestTermInFieldJson) { @@ -3129,9 +3218,9 @@ TEST(Expr, TestTermInFieldJson) { schema->set_primary_field_id(i64_fid); auto seg = CreateGrowingSegment(schema, empty_index_meta); - int N = 10000; + int N = 1000; std::vector json_col; - int num_iters = 2; + int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGenForJsonArray(schema, N, iter); auto new_json_col = raw_data.get_col(json_fid); @@ -3465,9 +3554,9 @@ TEST(Expr, TestJsonContainsAny) { schema->set_primary_field_id(i64_fid); auto seg = CreateGrowingSegment(schema, empty_index_meta); - int N = 10000; + int N = 1000; std::vector json_col; - int num_iters = 2; + int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGenForJsonArray(schema, N, iter); auto new_json_col = raw_data.get_col(json_fid); @@ -3658,9 +3747,9 @@ TEST(Expr, TestJsonContainsAll) { schema->set_primary_field_id(i64_fid); auto seg = CreateGrowingSegment(schema, empty_index_meta); - int N = 10000; + int N = 1000; std::vector json_col; - int num_iters = 2; + int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGenForJsonArray(schema, N, iter); auto new_json_col = raw_data.get_col(json_fid); @@ -3875,9 +3964,9 @@ TEST(Expr, TestJsonContainsArray) { 
schema->set_primary_field_id(i64_fid); auto seg = CreateGrowingSegment(schema, empty_index_meta); - int N = 10000; + int N = 1000; std::vector json_col; - int num_iters = 2; + int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGenForJsonArray(schema, N, iter); auto new_json_col = raw_data.get_col(json_fid); @@ -4017,12 +4106,11 @@ TEST(Expr, TestJsonContainsArray) { proto::plan::GenericValue int_val22; int_val22.set_int64_val(int64_t(4)); sub_arr2.add_array()->CopyFrom(int_val22); - std::vector> diff_testcases2{{{sub_arr1, sub_arr2}, {"array2"}}}; + std::vector> diff_testcases2{ + {{sub_arr1, sub_arr2}, {"array2"}}}; for (auto& testcase : diff_testcases2) { - auto check = [&](const std::vector& values, int i) { - return true; - }; + auto check = [&]() { return true; }; RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); plan.predicate_ = @@ -4043,8 +4131,7 @@ TEST(Expr, TestJsonContainsArray) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - std::vector res; - ASSERT_EQ(ans, check(res, i)); + ASSERT_EQ(ans, check()); } } @@ -4096,11 +4183,12 @@ TEST(Expr, TestJsonContainsArray) { proto::plan::GenericValue int_val42; int_val42.set_int64_val(int64_t(8)); sub_arr4.add_array()->CopyFrom(int_val42); - std::vector> diff_testcases3{{{sub_arr3, sub_arr4}, {"array2"}}}; + std::vector> diff_testcases3{ + {{sub_arr3, sub_arr4}, {"array2"}}}; - for (auto& testcase : diff_testcases2) { + for (auto& testcase : diff_testcases3) { auto check = [&](const std::vector& values, int i) { - return true; + return false; }; RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); @@ -4127,9 +4215,9 @@ TEST(Expr, TestJsonContainsArray) { } } - for (auto& testcase : diff_testcases2) { + for (auto& testcase : diff_testcases3) { auto check = [&](const std::vector& values, int i) { - return true; + return false; }; RetrievePlanNode plan; auto pointer = 
milvus::Json::pointer(testcase.nested_path); @@ -4157,6 +4245,141 @@ TEST(Expr, TestJsonContainsArray) { } } +milvus::proto::plan::GenericValue +generatedArrayWithFourDiffType(int64_t int_val, + double float_val, + bool bool_val, + std::string string_val) { + using namespace milvus; + + proto::plan::GenericValue value; + proto::plan::Array diff_type_array; + diff_type_array.set_same_type(false); + proto::plan::GenericValue int_value; + int_value.set_int64_val(int_val); + diff_type_array.add_array()->CopyFrom(int_value); + + proto::plan::GenericValue float_value; + float_value.set_float_val(float_val); + diff_type_array.add_array()->CopyFrom(float_value); + + proto::plan::GenericValue bool_value; + bool_value.set_bool_val(bool_val); + diff_type_array.add_array()->CopyFrom(bool_value); + + proto::plan::GenericValue string_value; + string_value.set_string_val(string_val); + diff_type_array.add_array()->CopyFrom(string_value); + + value.mutable_array_val()->CopyFrom(diff_type_array); + return value; +} + +TEST(Expr, TestJsonContainsDiffTypeArray) { + using namespace milvus; + using namespace milvus::query; + using namespace milvus::segcore; + + auto schema = std::make_shared(); + auto i64_fid = schema->AddDebugField("id", DataType::INT64); + auto json_fid = schema->AddDebugField("json", DataType::JSON); + schema->set_primary_field_id(i64_fid); + + auto seg = CreateGrowingSegment(schema, empty_index_meta); + int N = 1000; + std::vector json_col; + int num_iters = 1; + for (int iter = 0; iter < num_iters; ++iter) { + auto raw_data = DataGenForJsonArray(schema, N, iter); + auto new_json_col = raw_data.get_col(json_fid); + + json_col.insert( + json_col.end(), new_json_col.begin(), new_json_col.end()); + seg->PreInsert(N); + seg->Insert(iter * N, + N, + raw_data.row_ids_.data(), + raw_data.timestamps_.data(), + raw_data.raw_); + } + + auto seg_promote = dynamic_cast(seg.get()); + ExecExprVisitor visitor( + *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); + + 
proto::plan::GenericValue int_value; + int_value.set_int64_val(1); + auto diff_type_array1 = + generatedArrayWithFourDiffType(1, 2.2, false, "abc"); + auto diff_type_array2 = + generatedArrayWithFourDiffType(1, 2.2, false, "def"); + auto diff_type_array3 = generatedArrayWithFourDiffType(1, 2.2, true, "abc"); + auto diff_type_array4 = + generatedArrayWithFourDiffType(1, 3.3, false, "abc"); + auto diff_type_array5 = + generatedArrayWithFourDiffType(2, 2.2, false, "abc"); + + std::vector> diff_testcases{ + {{diff_type_array1, int_value}, {"array3"}, true}, + {{diff_type_array2, int_value}, {"array3"}, false}, + {{diff_type_array3, int_value}, {"array3"}, false}, + {{diff_type_array4, int_value}, {"array3"}, false}, + {{diff_type_array5, int_value}, {"array3"}, false}, + }; + + for (auto& testcase : diff_testcases) { + auto check = [&]() { return testcase.res; }; + RetrievePlanNode plan; + auto pointer = milvus::Json::pointer(testcase.nested_path); + plan.predicate_ = + std::make_unique>( + ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), + testcase.term, + false, + proto::plan::JSONContainsExpr_JSONOp_ContainsAny, + proto::plan::GenericValue::ValCase::kArrayVal); + auto start = std::chrono::steady_clock::now(); + auto final = visitor.call_child(*plan.predicate_.value()); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + ASSERT_EQ(ans, check()); + } + } + + for (auto& testcase : diff_testcases) { + auto check = [&]() { return false; }; + RetrievePlanNode plan; + auto pointer = milvus::Json::pointer(testcase.nested_path); + plan.predicate_ = + std::make_unique>( + ColumnInfo(json_fid, DataType::JSON, testcase.nested_path), + testcase.term, + false, + proto::plan::JSONContainsExpr_JSONOp_ContainsAll, + proto::plan::GenericValue::ValCase::kArrayVal); + auto start = 
std::chrono::steady_clock::now(); + auto final = visitor.call_child(*plan.predicate_.value()); + std::cout << "cost" + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + EXPECT_EQ(final.size(), N * num_iters); + + for (int i = 0; i < N * num_iters; ++i) { + auto ans = final[i]; + ASSERT_EQ(ans, check()); + } + } +} + TEST(Expr, TestJsonContainsDiffType) { using namespace milvus; using namespace milvus::query; @@ -4168,9 +4391,9 @@ TEST(Expr, TestJsonContainsDiffType) { schema->set_primary_field_id(i64_fid); auto seg = CreateGrowingSegment(schema, empty_index_meta); - int N = 10000; + int N = 1000; std::vector json_col; - int num_iters = 2; + int num_iters = 1; for (int iter = 0; iter < num_iters; ++iter) { auto raw_data = DataGenForJsonArray(schema, N, iter); auto new_json_col = raw_data.get_col(json_fid); @@ -4189,32 +4412,34 @@ TEST(Expr, TestJsonContainsDiffType) { ExecExprVisitor visitor( *seg_promote, seg_promote->get_row_count(), MAX_TIMESTAMP); - proto::plan::GenericValue v; - auto a = v.mutable_array_val(); proto::plan::GenericValue int_val; - int_val.set_int64_val(int64_t(1)); - a->add_array()->CopyFrom(int_val); + int_val.set_int64_val(int64_t(3)); proto::plan::GenericValue bool_val; bool_val.set_bool_val(bool(false)); - a->add_array()->CopyFrom(bool_val); proto::plan::GenericValue float_val; float_val.set_float_val(double(100.34)); - a->add_array()->CopyFrom(float_val); proto::plan::GenericValue string_val; string_val.set_string_val("10dsf"); - a->add_array()->CopyFrom(string_val); - // a->set_same_type(false); - // v.set_allocated_array_val(a); + + proto::plan::GenericValue string_val2; + string_val2.set_string_val("abc"); + proto::plan::GenericValue bool_val2; + bool_val2.set_bool_val(bool(true)); + proto::plan::GenericValue float_val2; + float_val2.set_float_val(double(2.2)); + proto::plan::GenericValue int_val2; + int_val2.set_int64_val(int64_t(1)); std::vector> diff_testcases{ - {{v}, 
{"string"}}}; + {{int_val, bool_val, float_val, string_val}, + {"diff_type_array"}, + false}, + {{string_val2, bool_val2, float_val2, int_val2}, + {"diff_type_array"}, + true}, + }; for (auto& testcase : diff_testcases) { - auto check = [&](const std::vector& values) { - return std::find(values.begin(), - values.end(), - std::string_view("10dsf")) != values.end(); - }; RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); plan.predicate_ = @@ -4235,21 +4460,11 @@ TEST(Expr, TestJsonContainsDiffType) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - auto array = milvus::Json(simdjson::padded_string(json_col[i])) - .array_at(pointer); - std::vector res; - for (const auto& element : array) { - res.push_back(element.template get()); - } - - ASSERT_EQ(ans, check(res)); + ASSERT_EQ(ans, testcase.res); } } for (auto& testcase : diff_testcases) { - auto check = [&](const std::vector& values) { - return false; - }; RetrievePlanNode plan; auto pointer = milvus::Json::pointer(testcase.nested_path); plan.predicate_ = @@ -4270,14 +4485,7 @@ TEST(Expr, TestJsonContainsDiffType) { for (int i = 0; i < N * num_iters; ++i) { auto ans = final[i]; - auto array = milvus::Json(simdjson::padded_string(json_col[i])) - .array_at(pointer); - std::vector res; - for (const auto& element : array) { - res.push_back(element.template get()); - } - - ASSERT_EQ(ans, check(res)); + ASSERT_EQ(ans, testcase.res); } } } diff --git a/internal/core/unittest/test_float16.cpp b/internal/core/unittest/test_float16.cpp new file mode 100644 index 0000000000000..d83c433bb3cbb --- /dev/null +++ b/internal/core/unittest/test_float16.cpp @@ -0,0 +1,419 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include + +#include "common/LoadInfo.h" +#include "common/Types.h" +#include "index/IndexFactory.h" +#include "knowhere/comp/index_param.h" +#include "query/ExprImpl.h" +#include "segcore/Reduce.h" +#include "segcore/reduce_c.h" +#include "test_utils/PbHelper.h" +#include "test_utils/indexbuilder_test_utils.h" + +#include "pb/schema.pb.h" +#include "pb/plan.pb.h" +#include "query/Expr.h" +#include "query/Plan.h" +#include "query/Utils.h" +#include "query/PlanImpl.h" +#include "query/PlanNode.h" +#include "query/PlanProto.h" +#include "query/SearchBruteForce.h" +#include "query/generated/ExecPlanNodeVisitor.h" +#include "query/generated/PlanNodeVisitor.h" +#include "query/generated/ExecExprVisitor.h" +#include "query/generated/ExprVisitor.h" +#include "query/generated/ShowPlanNodeVisitor.h" +#include "segcore/Collection.h" +#include "segcore/SegmentSealed.h" +#include "segcore/SegmentGrowing.h" +#include "segcore/SegmentGrowingImpl.h" +#include "test_utils/AssertUtils.h" +#include "test_utils/DataGen.h" + +using namespace milvus::segcore; +using namespace milvus; +using namespace milvus::index; +using namespace knowhere; +using milvus::index::VectorIndex; +using milvus::segcore::LoadIndexInfo; + +const int64_t ROW_COUNT = 100 * 1000; + +TEST(Float16, Insert) { + using namespace milvus; + using namespace milvus::query; + using namespace milvus::segcore; + int64_t N = ROW_COUNT; + constexpr int64_t size_per_chunk = 32 * 1024; + auto schema = std::make_shared(); + auto float16_vec_fid = schema->AddDebugField( + "float16vec", DataType::VECTOR_FLOAT16, 32, 
knowhere::metric::L2); + auto i64_fid = schema->AddDebugField("counter", DataType::INT64); + schema->set_primary_field_id(i64_fid); + + auto dataset = DataGen(schema, N); + // auto seg_conf = SegcoreConfig::default_config(); + auto segment = CreateGrowingSegment(schema, empty_index_meta); + segment->PreInsert(N); + segment->Insert(0, + N, + dataset.row_ids_.data(), + dataset.timestamps_.data(), + dataset.raw_); + auto float16_ptr = dataset.get_col(float16_vec_fid); + SegmentInternalInterface& interface = *segment; + auto num_chunk = interface.num_chunk(); + ASSERT_EQ(num_chunk, upper_div(N, size_per_chunk)); + auto row_count = interface.get_row_count(); + ASSERT_EQ(N, row_count); + for (auto chunk_id = 0; chunk_id < num_chunk; ++chunk_id) { + auto float16_span = interface.chunk_data( + float16_vec_fid, chunk_id); + auto begin = chunk_id * size_per_chunk; + auto end = std::min((chunk_id + 1) * size_per_chunk, N); + auto size_of_chunk = end - begin; + for (int i = 0; i < size_of_chunk; ++i) { + // std::cout << float16_span.data()[i] << " " << float16_ptr[i + begin * 32] << std::endl; + ASSERT_EQ(float16_span.data()[i], float16_ptr[i + begin * 32]); + } + } +} + +TEST(Float16, ShowExecutor) { + using namespace milvus::query; + using namespace milvus::segcore; + using namespace milvus; + auto metric_type = knowhere::metric::L2; + auto node = std::make_unique(); + auto schema = std::make_shared(); + auto field_id = schema->AddDebugField( + "fakevec", DataType::VECTOR_FLOAT16, 16, metric_type); + int64_t num_queries = 100L; + auto raw_data = DataGen(schema, num_queries); + auto& info = node->search_info_; + info.metric_type_ = metric_type; + info.topk_ = 20; + info.field_id_ = field_id; + node->predicate_ = std::nullopt; + ShowPlanNodeVisitor show_visitor; + PlanNodePtr base(node.release()); + auto res = show_visitor.call_child(*base); + auto dup = res; + std::cout << dup.dump(4); +} + +TEST(Float16, ExecWithoutPredicateFlat) { + using namespace milvus::query; + using 
namespace milvus::segcore; + using namespace milvus; + auto schema = std::make_shared(); + auto vec_fid = schema->AddDebugField( + "fakevec", DataType::VECTOR_FLOAT16, 32, knowhere::metric::L2); + schema->AddDebugField("age", DataType::FLOAT); + auto i64_fid = schema->AddDebugField("counter", DataType::INT64); + schema->set_primary_field_id(i64_fid); + const char* raw_plan = R"(vector_anns: < + field_id: 100 + query_info: < + topk: 5 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + auto plan_str = translate_text_plan_to_binary_plan(raw_plan); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + int64_t N = ROW_COUNT; + auto dataset = DataGen(schema, N); + auto segment = CreateGrowingSegment(schema, empty_index_meta); + segment->PreInsert(N); + segment->Insert(0, + N, + dataset.row_ids_.data(), + dataset.timestamps_.data(), + dataset.raw_); + auto vec_ptr = dataset.get_col(vec_fid); + + auto num_queries = 5; + auto ph_group_raw = CreateFloat16PlaceholderGroup(num_queries, 32, 1024); + auto ph_group = + ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); + + auto sr = segment->Search(plan.get(), ph_group.get()); + int topk = 5; + + query::Json json = SearchResultToJson(*sr); + std::cout << json.dump(2); +} + +TEST(Float16, GetVector) { + auto metricType = knowhere::metric::L2; + auto schema = std::make_shared(); + auto pk = schema->AddDebugField("pk", DataType::INT64); + auto random = schema->AddDebugField("random", DataType::DOUBLE); + auto vec = schema->AddDebugField( + "embeddings", DataType::VECTOR_FLOAT16, 128, metricType); + schema->set_primary_field_id(pk); + std::map index_params = { + {"index_type", "IVF_FLAT"}, + {"metric_type", metricType}, + {"nlist", "128"}}; + std::map type_params = {{"dim", "128"}}; + FieldIndexMeta fieldIndexMeta( + vec, std::move(index_params), std::move(type_params)); + auto& config = SegcoreConfig::default_config(); + 
config.set_chunk_rows(1024); + config.set_enable_growing_segment_index(true); + std::map filedMap = {{vec, fieldIndexMeta}}; + IndexMetaPtr metaPtr = + std::make_shared(100000, std::move(filedMap)); + auto segment_growing = CreateGrowingSegment(schema, metaPtr); + auto segment = dynamic_cast(segment_growing.get()); + + int64_t per_batch = 5000; + int64_t n_batch = 20; + int64_t dim = 128; + for (int64_t i = 0; i < n_batch; i++) { + auto dataset = DataGen(schema, per_batch); + auto fakevec = dataset.get_col(vec); + auto offset = segment->PreInsert(per_batch); + segment->Insert(offset, + per_batch, + dataset.row_ids_.data(), + dataset.timestamps_.data(), + dataset.raw_); + auto num_inserted = (i + 1) * per_batch; + auto ids_ds = GenRandomIds(num_inserted); + auto result = + segment->bulk_subscript(vec, ids_ds->GetIds(), num_inserted); + + auto vector = result.get()->mutable_vectors()->float16_vector(); + EXPECT_TRUE(vector.size() == num_inserted * dim * sizeof(float16)); + // EXPECT_TRUE(vector.size() == num_inserted * dim); + // for (size_t i = 0; i < num_inserted; ++i) { + // auto id = ids_ds->GetIds()[i]; + // for (size_t j = 0; j < 128; ++j) { + // EXPECT_TRUE(vector[i * dim + j] == + // fakevec[(id % per_batch) * dim + j]); + // } + // } + } +} + +std::string +generate_collection_schema(std::string metric_type, int dim, bool is_fp16) { + namespace schema = milvus::proto::schema; + schema::CollectionSchema collection_schema; + collection_schema.set_name("collection_test"); + + auto vec_field_schema = collection_schema.add_fields(); + vec_field_schema->set_name("fakevec"); + vec_field_schema->set_fieldid(100); + if (is_fp16) { + vec_field_schema->set_data_type(schema::DataType::Float16Vector); + } else { + vec_field_schema->set_data_type(schema::DataType::FloatVector); + } + auto metric_type_param = vec_field_schema->add_index_params(); + metric_type_param->set_key("metric_type"); + metric_type_param->set_value(metric_type); + auto dim_param = 
vec_field_schema->add_type_params(); + dim_param->set_key("dim"); + dim_param->set_value(std::to_string(dim)); + + auto other_field_schema = collection_schema.add_fields(); + other_field_schema->set_name("counter"); + other_field_schema->set_fieldid(101); + other_field_schema->set_data_type(schema::DataType::Int64); + other_field_schema->set_is_primary_key(true); + + auto other_field_schema2 = collection_schema.add_fields(); + other_field_schema2->set_name("doubleField"); + other_field_schema2->set_fieldid(102); + other_field_schema2->set_data_type(schema::DataType::Double); + + std::string schema_string; + auto marshal = google::protobuf::TextFormat::PrintToString( + collection_schema, &schema_string); + assert(marshal); + return schema_string; +} + +CCollection +NewCollection(const char* schema_proto_blob) { + auto proto = std::string(schema_proto_blob); + auto collection = std::make_unique(proto); + return (void*)collection.release(); +} + +TEST(Float16, CApiCPlan) { + std::string schema_string = + generate_collection_schema(knowhere::metric::L2, 16, true); + auto collection = NewCollection(schema_string.c_str()); + + // const char* dsl_string = R"( + // { + // "bool": { + // "vector": { + // "fakevec": { + // "metric_type": "L2", + // "params": { + // "nprobe": 10 + // }, + // "query": "$0", + // "topk": 10, + // "round_decimal": 3 + // } + // } + // } + // })"; + + milvus::proto::plan::PlanNode plan_node; + auto vector_anns = plan_node.mutable_vector_anns(); + vector_anns->set_vector_type( + milvus::proto::plan::VectorType::Float16Vector); + vector_anns->set_placeholder_tag("$0"); + vector_anns->set_field_id(100); + auto query_info = vector_anns->mutable_query_info(); + query_info->set_topk(10); + query_info->set_round_decimal(3); + query_info->set_metric_type("L2"); + query_info->set_search_params(R"({"nprobe": 10})"); + auto plan_str = plan_node.SerializeAsString(); + + void* plan = nullptr; + auto status = CreateSearchPlanByExpr( + collection, 
plan_str.data(), plan_str.size(), &plan); + ASSERT_EQ(status.error_code, Success); + + int64_t field_id = -1; + status = GetFieldID(plan, &field_id); + ASSERT_EQ(status.error_code, Success); + + auto col = static_cast(collection); + for (auto& [target_field_id, field_meta] : + col->get_schema()->get_fields()) { + if (field_meta.is_vector()) { + ASSERT_EQ(field_id, target_field_id.get()); + } + } + ASSERT_NE(field_id, -1); + + DeleteSearchPlan(plan); + DeleteCollection(collection); +} + +TEST(Float16, RetrieveEmpty) { + auto schema = std::make_shared(); + auto fid_64 = schema->AddDebugField("i64", DataType::INT64); + auto DIM = 16; + auto fid_vec = schema->AddDebugField( + "vector_64", DataType::VECTOR_FLOAT16, DIM, knowhere::metric::L2); + schema->set_primary_field_id(fid_64); + + int64_t N = 100; + int64_t req_size = 10; + auto choose = [=](int i) { return i * 3 % N; }; + + auto segment = CreateSealedSegment(schema); + + auto plan = std::make_unique(*schema); + std::vector values; + for (int i = 0; i < req_size; ++i) { + values.emplace_back(choose(i)); + } + auto term_expr = std::make_unique>( + milvus::query::ColumnInfo( + fid_64, DataType::INT64, std::vector()), + values, + proto::plan::GenericValue::kInt64Val); + plan->plan_node_ = std::make_unique(); + plan->plan_node_->predicate_ = std::move(term_expr); + std::vector target_offsets{fid_64, fid_vec}; + plan->field_ids_ = target_offsets; + + auto retrieve_results = + segment->Retrieve(plan.get(), 100, DEFAULT_MAX_OUTPUT_SIZE); + + Assert(retrieve_results->fields_data_size() == target_offsets.size()); + auto field0 = retrieve_results->fields_data(0); + auto field1 = retrieve_results->fields_data(1); + Assert(field0.has_scalars()); + auto field0_data = field0.scalars().long_data(); + Assert(field0_data.data_size() == 0); + Assert(field1.vectors().float16_vector().size() == 0); +} + +TEST(Float16, ExecWithPredicate) { + using namespace milvus::query; + using namespace milvus::segcore; + auto schema = 
std::make_shared(); + schema->AddDebugField( + "fakevec", DataType::VECTOR_FLOAT16, 16, knowhere::metric::L2); + schema->AddDebugField("age", DataType::FLOAT); + auto i64_fid = schema->AddDebugField("counter", DataType::INT64); + schema->set_primary_field_id(i64_fid); + const char* raw_plan = R"(vector_anns: < + field_id: 100 + predicates: < + binary_range_expr: < + column_info: < + field_id: 101 + data_type: Float + > + lower_inclusive: true, + upper_inclusive: false, + lower_value: < + float_val: -1 + > + upper_value: < + float_val: 1 + > + > + > + query_info: < + topk: 5 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + int64_t N = ROW_COUNT; + auto dataset = DataGen(schema, N); + auto segment = CreateGrowingSegment(schema, empty_index_meta); + segment->PreInsert(N); + segment->Insert(0, + N, + dataset.row_ids_.data(), + dataset.timestamps_.data(), + dataset.raw_); + + auto plan_str = translate_text_plan_to_binary_plan(raw_plan); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + auto num_queries = 5; + auto ph_group_raw = CreateFloat16PlaceholderGroup(num_queries, 16, 1024); + auto ph_group = + ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); + + auto sr = segment->Search(plan.get(), ph_group.get()); + int topk = 5; + + query::Json json = SearchResultToJson(*sr); + std::cout << json.dump(2); +} \ No newline at end of file diff --git a/internal/core/unittest/test_growing.cpp b/internal/core/unittest/test_growing.cpp index 12ee1e6b44596..34272c72f948f 100644 --- a/internal/core/unittest/test_growing.cpp +++ b/internal/core/unittest/test_growing.cpp @@ -31,10 +31,18 @@ TEST(Growing, DeleteCount) { int64_t c = 10; auto offset = 0; + auto dataset = DataGen(schema, c); + auto pks = dataset.get_col(pk); + segment->Insert(offset, + c, + dataset.row_ids_.data(), + dataset.timestamps_.data(), + dataset.raw_); + Timestamp begin_ts = 100; auto tss = GenTss(c, 
begin_ts); - auto pks = GenPKs(c, 0); - auto status = segment->Delete(offset, c, pks.get(), tss.data()); + auto del_pks = GenPKs(pks.begin(), pks.end()); + auto status = segment->Delete(offset, c, del_pks.get(), tss.data()); ASSERT_TRUE(status.ok()); auto cnt = segment->get_deleted_count(); @@ -100,6 +108,18 @@ TEST(Growing, FillData) { auto double_field = schema->AddDebugField("double", DataType::DOUBLE); auto varchar_field = schema->AddDebugField("varchar", DataType::VARCHAR); auto json_field = schema->AddDebugField("json", DataType::JSON); + auto int_array_field = + schema->AddDebugField("int_array", DataType::ARRAY, DataType::INT8); + auto long_array_field = + schema->AddDebugField("long_array", DataType::ARRAY, DataType::INT64); + auto bool_array_field = + schema->AddDebugField("bool_array", DataType::ARRAY, DataType::BOOL); + auto string_array_field = schema->AddDebugField( + "string_array", DataType::ARRAY, DataType::VARCHAR); + auto double_array_field = schema->AddDebugField( + "double_array", DataType::ARRAY, DataType::DOUBLE); + auto float_array_field = + schema->AddDebugField("float_array", DataType::ARRAY, DataType::FLOAT); auto vec = schema->AddDebugField( "embeddings", DataType::VECTOR_FLOAT, 128, metric_type); schema->set_primary_field_id(int64_field); @@ -133,6 +153,15 @@ TEST(Growing, FillData) { auto double_values = dataset.get_col(double_field); auto varchar_values = dataset.get_col(varchar_field); auto json_values = dataset.get_col(json_field); + auto int_array_values = dataset.get_col(int_array_field); + auto long_array_values = dataset.get_col(long_array_field); + auto bool_array_values = dataset.get_col(bool_array_field); + auto string_array_values = + dataset.get_col(string_array_field); + auto double_array_values = + dataset.get_col(double_array_field); + auto float_array_values = + dataset.get_col(float_array_field); auto vector_values = dataset.get_col(vec); auto offset = segment->PreInsert(per_batch); @@ -159,13 +188,26 @@ TEST(Growing, 
FillData) { varchar_field, ids_ds->GetIds(), num_inserted); auto json_result = segment->bulk_subscript(json_field, ids_ds->GetIds(), num_inserted); + auto int_array_result = segment->bulk_subscript( + int_array_field, ids_ds->GetIds(), num_inserted); + auto long_array_result = segment->bulk_subscript( + long_array_field, ids_ds->GetIds(), num_inserted); + auto bool_array_result = segment->bulk_subscript( + bool_array_field, ids_ds->GetIds(), num_inserted); + auto string_array_result = segment->bulk_subscript( + string_array_field, ids_ds->GetIds(), num_inserted); + auto double_array_result = segment->bulk_subscript( + double_array_field, ids_ds->GetIds(), num_inserted); + auto float_array_result = segment->bulk_subscript( + float_array_field, ids_ds->GetIds(), num_inserted); auto vec_result = segment->bulk_subscript(vec, ids_ds->GetIds(), num_inserted); EXPECT_EQ(int8_result->scalars().int_data().data_size(), num_inserted); EXPECT_EQ(int16_result->scalars().int_data().data_size(), num_inserted); EXPECT_EQ(int32_result->scalars().int_data().data_size(), num_inserted); - EXPECT_EQ(int64_result->scalars().long_data().data_size(), num_inserted); + EXPECT_EQ(int64_result->scalars().long_data().data_size(), + num_inserted); EXPECT_EQ(float_result->scalars().float_data().data_size(), num_inserted); EXPECT_EQ(double_result->scalars().double_data().data_size(), @@ -175,5 +217,17 @@ TEST(Growing, FillData) { EXPECT_EQ(json_result->scalars().json_data().data_size(), num_inserted); EXPECT_EQ(vec_result->vectors().float_vector().data_size(), num_inserted * dim); + EXPECT_EQ(int_array_result->scalars().array_data().data_size(), + num_inserted); + EXPECT_EQ(long_array_result->scalars().array_data().data_size(), + num_inserted); + EXPECT_EQ(bool_array_result->scalars().array_data().data_size(), + num_inserted); + EXPECT_EQ(string_array_result->scalars().array_data().data_size(), + num_inserted); + EXPECT_EQ(double_array_result->scalars().array_data().data_size(), + num_inserted); 
+ EXPECT_EQ(float_array_result->scalars().array_data().data_size(), + num_inserted); } } diff --git a/internal/core/unittest/test_growing_index.cpp b/internal/core/unittest/test_growing_index.cpp index 634dbb06d2037..fa612fc31c4be 100644 --- a/internal/core/unittest/test_growing_index.cpp +++ b/internal/core/unittest/test_growing_index.cpp @@ -46,7 +46,7 @@ TEST(GrowingIndex, Correctness) { milvus::proto::plan::PlanNode plan_node; auto vector_anns = plan_node.mutable_vector_anns(); - vector_anns->set_is_binary(false); + vector_anns->set_vector_type(milvus::proto::plan::VectorType::FloatVector); vector_anns->set_placeholder_tag("$0"); vector_anns->set_field_id(102); auto query_info = vector_anns->mutable_query_info(); @@ -58,7 +58,8 @@ TEST(GrowingIndex, Correctness) { milvus::proto::plan::PlanNode range_query_plan_node; auto vector_range_querys = range_query_plan_node.mutable_vector_anns(); - vector_range_querys->set_is_binary(false); + vector_range_querys->set_vector_type( + milvus::proto::plan::VectorType::FloatVector); vector_range_querys->set_placeholder_tag("$0"); vector_range_querys->set_field_id(102); auto range_query_info = vector_range_querys->mutable_query_info(); diff --git a/internal/core/unittest/test_index_c_api.cpp b/internal/core/unittest/test_index_c_api.cpp index 155b23e951008..54c528b3a542f 100644 --- a/internal/core/unittest/test_index_c_api.cpp +++ b/internal/core/unittest/test_index_c_api.cpp @@ -48,34 +48,34 @@ TEST(FloatVecIndex, All) { { status = CreateIndex( dtype, type_params_str.c_str(), index_params_str.c_str(), &index); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { status = BuildFloatVecIndex(index, NB * DIM, xb_data.data()); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { status = SerializeIndexToBinarySet(index, &binary_set); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { status = 
CreateIndex(dtype, type_params_str.c_str(), index_params_str.c_str(), ©_index); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { status = LoadIndexFromBinarySet(copy_index, binary_set); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { status = DeleteIndex(index); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { status = DeleteIndex(copy_index); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { DeleteBinarySet(binary_set); } } @@ -107,34 +107,34 @@ TEST(BinaryVecIndex, All) { { status = CreateIndex( dtype, type_params_str.c_str(), index_params_str.c_str(), &index); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { status = BuildBinaryVecIndex(index, NB * DIM / 8, xb_data.data()); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { status = SerializeIndexToBinarySet(index, &binary_set); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { status = CreateIndex(dtype, type_params_str.c_str(), index_params_str.c_str(), ©_index); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { status = LoadIndexFromBinarySet(copy_index, binary_set); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { status = DeleteIndex(index); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { status = DeleteIndex(copy_index); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { DeleteBinarySet(binary_set); } } @@ -166,35 +166,35 @@ TEST(CBoolIndexTest, All) { type_params_str.c_str(), index_params_str.c_str(), &index); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { status = 
BuildScalarIndex( index, half_ds->GetRows(), half_ds->GetTensor()); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { status = SerializeIndexToBinarySet(index, &binary_set); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { status = CreateIndex(dtype, type_params_str.c_str(), index_params_str.c_str(), ©_index); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { status = LoadIndexFromBinarySet(copy_index, binary_set); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { status = DeleteIndex(index); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { status = DeleteIndex(copy_index); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { DeleteBinarySet(binary_set); } } @@ -224,34 +224,34 @@ TEST(CInt64IndexTest, All) { type_params_str.c_str(), index_params_str.c_str(), &index); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { status = BuildScalarIndex(index, arr.size(), arr.data()); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { status = SerializeIndexToBinarySet(index, &binary_set); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { status = CreateIndex(dtype, type_params_str.c_str(), index_params_str.c_str(), ©_index); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { status = LoadIndexFromBinarySet(copy_index, binary_set); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { status = DeleteIndex(index); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { status = DeleteIndex(copy_index); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, 
status.error_code); } { DeleteBinarySet(binary_set); } } @@ -283,35 +283,35 @@ TEST(CStringIndexTest, All) { type_params_str.c_str(), index_params_str.c_str(), &index); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { status = BuildScalarIndex( index, (str_ds->GetRows()), (str_ds->GetTensor())); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { status = SerializeIndexToBinarySet(index, &binary_set); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { status = CreateIndex(dtype, type_params_str.c_str(), index_params_str.c_str(), ©_index); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { status = LoadIndexFromBinarySet(copy_index, binary_set); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { status = DeleteIndex(index); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { status = DeleteIndex(copy_index); - ASSERT_EQ(Success, status.error_code); + ASSERT_EQ(milvus::Success, status.error_code); } { DeleteBinarySet(binary_set); } } diff --git a/internal/core/unittest/test_index_wrapper.cpp b/internal/core/unittest/test_index_wrapper.cpp index 2d44206c612a6..f466becaab168 100644 --- a/internal/core/unittest/test_index_wrapper.cpp +++ b/internal/core/unittest/test_index_wrapper.cpp @@ -70,10 +70,9 @@ class IndexWrapperTest : public ::testing::TestWithParam { is_binary = is_binary_map[index_type]; if (is_binary) { - vec_field_data_type = DataType::VECTOR_FLOAT; - ; - } else { vec_field_data_type = DataType::VECTOR_BINARY; + } else { + vec_field_data_type = DataType::VECTOR_FLOAT; } auto dataset = GenDataset(NB, metric_type, is_binary); @@ -129,8 +128,16 @@ INSTANTIATE_TEST_CASE_P( std::pair(knowhere::IndexEnum::INDEX_HNSW, knowhere::metric::L2))); TEST_P(IndexWrapperTest, BuildAndQuery) { + 
milvus::storage::FieldDataMeta field_data_meta{1, 2, 3, 100}; + milvus::storage::IndexMeta index_meta{3, 100, 1000, 1}; + auto chunk_manager = milvus::storage::CreateChunkManager(storage_config_); + + storage::FileManagerContext file_manager_context( + field_data_meta, index_meta, chunk_manager); + config[milvus::index::INDEX_ENGINE_VERSION] = + std::to_string(knowhere::Version::GetCurrentVersion().VersionNumber()); auto index = milvus::indexbuilder::IndexFactory::GetInstance().CreateIndex( - vec_field_data_type, config, nullptr); + vec_field_data_type, config, file_manager_context); auto dataset = GenDataset(NB, metric_type, is_binary); knowhere::DataSetPtr xb_dataset; @@ -146,12 +153,18 @@ TEST_P(IndexWrapperTest, BuildAndQuery) { ASSERT_NO_THROW(index->Build(xb_dataset)); auto binary_set = index->Serialize(); + std::vector index_files; + for (auto& binary : binary_set.binary_map_) { + index_files.emplace_back(binary.first); + } + config["index_files"] = index_files; auto copy_index = milvus::indexbuilder::IndexFactory::GetInstance().CreateIndex( - vec_field_data_type, config, nullptr); + vec_field_data_type, config, file_manager_context); auto vec_index = static_cast(copy_index.get()); ASSERT_EQ(vec_index->dim(), DIM); + ASSERT_NO_THROW(vec_index->Load(binary_set)); milvus::SearchInfo search_info; diff --git a/internal/core/unittest/test_indexing.cpp b/internal/core/unittest/test_indexing.cpp index 59a38322c2dbe..cb2c46961918d 100644 --- a/internal/core/unittest/test_indexing.cpp +++ b/internal/core/unittest/test_indexing.cpp @@ -16,7 +16,9 @@ #include #include +#include "common/EasyAssert.h" #include "knowhere/comp/index_param.h" +#include "nlohmann/json.hpp" #include "query/SearchBruteForce.h" #include "segcore/Reduce.h" #include "index/IndexFactory.h" @@ -55,7 +57,7 @@ generate_data(int N) { } } // namespace -Status +SegcoreError merge_into(int64_t queries, int64_t topk, float* distances, @@ -90,7 +92,7 @@ merge_into(int64_t queries, 
std::copy_n(buf_dis.data(), topk, src2_dis); std::copy_n(buf_uids.data(), topk, src2_uids); } - return Status::OK(); + return SegcoreError::success(); } /* @@ -173,7 +175,7 @@ TEST(Indexing, BinaryBruteForce) { auto json = SearchResultToJson(sr); std::cout << json.dump(2); #ifdef __linux__ - auto ref = json::parse(R"( + auto ref = nlohmann::json::parse(R"( [ [ [ "1024->0.000000", "48942->0.642000", "18494->0.644000", "68225->0.644000", "93557->0.644000" ], @@ -190,7 +192,7 @@ TEST(Indexing, BinaryBruteForce) { ] )"); #else // for mac - auto ref = json::parse(R"( + auto ref = nlohmann::json::parse(R"( [ [ [ "1024->0.000000", "59169->0.645000", "98548->0.646000", "3356->0.646000", "90373->0.647000" ], @@ -221,8 +223,10 @@ TEST(Indexing, Naive) { create_index_info.field_type = DataType::VECTOR_FLOAT; create_index_info.metric_type = knowhere::metric::L2; create_index_info.index_type = knowhere::IndexEnum::INDEX_FAISS_IVFPQ; + create_index_info.index_engine_version = + knowhere::Version::GetCurrentVersion().VersionNumber(); auto index = milvus::index::IndexFactory::GetInstance().CreateIndex( - create_index_info, nullptr); + create_index_info, milvus::storage::FileManagerContext()); auto build_conf = knowhere::Json{ {knowhere::meta::METRIC_TYPE, knowhere::metric::L2}, @@ -292,8 +296,14 @@ class IndexTest : public ::testing::TestWithParam { auto param = GetParam(); index_type = param.first; metric_type = param.second; - NB = 10000; - if (index_type == knowhere::IndexEnum::INDEX_HNSW) { + NB = 3000; + + // try to reduce the test time, + // but the large dataset is needed for the case below. 
+ auto test_name = std::string( + testing::UnitTest::GetInstance()->current_test_info()->name()); + if (test_name == "Mmap" && + index_type == knowhere::IndexEnum::INDEX_HNSW) { NB = 270000; } build_conf = generate_build_conf(index_type, metric_type); @@ -350,7 +360,7 @@ class IndexTest : public ::testing::TestWithParam { std::vector xb_bin_data; knowhere::DataSetPtr xq_dataset; int64_t query_offset = 100; - int64_t NB = 10000; + int64_t NB = 3000; StorageConfig storage_config_; }; @@ -378,15 +388,17 @@ TEST_P(IndexTest, BuildAndQuery) { create_index_info.index_type = index_type; create_index_info.metric_type = metric_type; create_index_info.field_type = vec_field_data_type; + create_index_info.index_engine_version = + knowhere::Version::GetCurrentVersion().VersionNumber(); index::IndexBasePtr index; milvus::storage::FieldDataMeta field_data_meta{1, 2, 3, 100}; milvus::storage::IndexMeta index_meta{3, 100, 1000, 1}; auto chunk_manager = milvus::storage::CreateChunkManager(storage_config_); - auto file_manager = milvus::storage::CreateFileManager( - index_type, field_data_meta, index_meta, chunk_manager); + milvus::storage::FileManagerContext file_manager_context( + field_data_meta, index_meta, chunk_manager); index = milvus::index::IndexFactory::GetInstance().CreateIndex( - create_index_info, file_manager); + create_index_info, file_manager_context); ASSERT_NO_THROW(index->BuildWithDataset(xb_dataset, build_conf)); milvus::index::IndexBasePtr new_index; @@ -396,13 +408,14 @@ TEST_P(IndexTest, BuildAndQuery) { index.reset(); new_index = milvus::index::IndexFactory::GetInstance().CreateIndex( - create_index_info, file_manager); + create_index_info, file_manager_context); vec_index = dynamic_cast(new_index.get()); std::vector index_files; for (auto& binary : binary_set.binary_map_) { index_files.emplace_back(binary.first); } + load_conf = generate_load_conf(index_type, metric_type, 0); load_conf["index_files"] = index_files; 
ASSERT_NO_THROW(vec_index->Load(load_conf)); EXPECT_EQ(vec_index->Count(), NB); @@ -429,15 +442,17 @@ TEST_P(IndexTest, Mmap) { create_index_info.index_type = index_type; create_index_info.metric_type = metric_type; create_index_info.field_type = vec_field_data_type; + create_index_info.index_engine_version = + knowhere::Version::GetCurrentVersion().VersionNumber(); index::IndexBasePtr index; milvus::storage::FieldDataMeta field_data_meta{1, 2, 3, 100}; milvus::storage::IndexMeta index_meta{3, 100, 1000, 1}; auto chunk_manager = milvus::storage::CreateChunkManager(storage_config_); - auto file_manager = milvus::storage::CreateFileManager( - index_type, field_data_meta, index_meta, chunk_manager); + milvus::storage::FileManagerContext file_manager_context( + field_data_meta, index_meta, chunk_manager); index = milvus::index::IndexFactory::GetInstance().CreateIndex( - create_index_info, file_manager); + create_index_info, file_manager_context); ASSERT_NO_THROW(index->BuildWithDataset(xb_dataset, build_conf)); milvus::index::IndexBasePtr new_index; @@ -447,7 +462,7 @@ TEST_P(IndexTest, Mmap) { index.reset(); new_index = milvus::index::IndexFactory::GetInstance().CreateIndex( - create_index_info, file_manager); + create_index_info, file_manager_context); if (!new_index->IsMmapSupported()) { return; } @@ -457,6 +472,7 @@ TEST_P(IndexTest, Mmap) { for (auto& binary : binary_set.binary_map_) { index_files.emplace_back(binary.first); } + load_conf = generate_load_conf(index_type, metric_type, 0); load_conf["index_files"] = index_files; load_conf["mmap_filepath"] = "mmap/test_index_mmap_" + index_type; vec_index->Load(load_conf); @@ -484,39 +500,39 @@ TEST_P(IndexTest, GetVector) { create_index_info.index_type = index_type; create_index_info.metric_type = metric_type; create_index_info.field_type = vec_field_data_type; + create_index_info.index_engine_version = + knowhere::Version::GetCurrentVersion().VersionNumber(); index::IndexBasePtr index; 
milvus::storage::FieldDataMeta field_data_meta{1, 2, 3, 100}; milvus::storage::IndexMeta index_meta{3, 100, 1000, 1}; auto chunk_manager = milvus::storage::CreateChunkManager(storage_config_); - auto file_manager = milvus::storage::CreateFileManager( - index_type, field_data_meta, index_meta, chunk_manager); + milvus::storage::FileManagerContext file_manager_context( + field_data_meta, index_meta, chunk_manager); index = milvus::index::IndexFactory::GetInstance().CreateIndex( - create_index_info, file_manager); + create_index_info, file_manager_context); ASSERT_NO_THROW(index->BuildWithDataset(xb_dataset, build_conf)); milvus::index::IndexBasePtr new_index; milvus::index::VectorIndex* vec_index = nullptr; - if (index_type == knowhere::IndexEnum::INDEX_DISKANN) { - // TODO ::diskann.query need load first, ugly - auto binary_set = index->Serialize(milvus::Config{}); - index.reset(); - - new_index = milvus::index::IndexFactory::GetInstance().CreateIndex( - create_index_info, file_manager); - - vec_index = dynamic_cast(new_index.get()); + auto binary_set = index->Upload(); + index.reset(); + std::vector index_files; + for (auto& binary : binary_set.binary_map_) { + index_files.emplace_back(binary.first); + } + new_index = milvus::index::IndexFactory::GetInstance().CreateIndex( + create_index_info, file_manager_context); + load_conf = generate_load_conf(index_type, metric_type, 0); + load_conf["index_files"] = index_files; - std::vector index_files; - for (auto& binary : binary_set.binary_map_) { - index_files.emplace_back(binary.first); - } - load_conf["index_files"] = index_files; + vec_index = dynamic_cast(new_index.get()); + if (index_type == knowhere::IndexEnum::INDEX_DISKANN) { vec_index->Load(binary_set, load_conf); EXPECT_EQ(vec_index->Count(), NB); } else { - vec_index = dynamic_cast(index.get()); + vec_index->Load(load_conf); } EXPECT_EQ(vec_index->GetDim(), DIM); EXPECT_EQ(vec_index->Count(), NB); @@ -561,6 +577,8 @@ TEST(Indexing, 
SearchDiskAnnWithInvalidParam) { create_index_info.index_type = index_type; create_index_info.metric_type = metric_type; create_index_info.field_type = milvus::DataType::VECTOR_FLOAT; + create_index_info.index_engine_version = + knowhere::Version::GetCurrentVersion().VersionNumber(); int64_t collection_id = 1; int64_t partition_id = 2; @@ -575,10 +593,10 @@ TEST(Indexing, SearchDiskAnnWithInvalidParam) { milvus::storage::IndexMeta index_meta{ segment_id, field_id, build_id, index_version}; auto chunk_manager = storage::CreateChunkManager(storage_config); - auto file_manager = milvus::storage::CreateFileManager( - index_type, field_data_meta, index_meta, chunk_manager); + milvus::storage::FileManagerContext file_manager_context( + field_data_meta, index_meta, chunk_manager); auto index = milvus::index::IndexFactory::GetInstance().CreateIndex( - create_index_info, file_manager); + create_index_info, file_manager_context); auto build_conf = Config{ {knowhere::meta::METRIC_TYPE, metric_type}, @@ -603,7 +621,7 @@ TEST(Indexing, SearchDiskAnnWithInvalidParam) { index.reset(); auto new_index = milvus::index::IndexFactory::GetInstance().CreateIndex( - create_index_info, file_manager); + create_index_info, file_manager_context); auto vec_index = dynamic_cast(new_index.get()); std::vector index_files; for (auto& binary : binary_set.binary_map_) { diff --git a/internal/core/unittest/test_minio_chunk_manager.cpp b/internal/core/unittest/test_minio_chunk_manager.cpp index 9417311b4c136..64a94c5a949e4 100644 --- a/internal/core/unittest/test_minio_chunk_manager.cpp +++ b/internal/core/unittest/test_minio_chunk_manager.cpp @@ -19,7 +19,6 @@ using namespace std; using namespace milvus; using namespace milvus::storage; -using namespace boost::filesystem; class MinioChunkManagerTest : public testing::Test { public: @@ -30,7 +29,7 @@ class MinioChunkManagerTest : public testing::Test { virtual void SetUp() { - configs_ = get_default_remote_storage_config(); + configs_ = 
StorageConfig{}; chunk_manager_ = std::make_unique(configs_); } @@ -39,68 +38,50 @@ class MinioChunkManagerTest : public testing::Test { StorageConfig configs_; }; -StorageConfig -get_google_cloud_storage_config() { - auto endpoint = "storage.googleapis.com:443"; - auto accessKey = ""; - auto accessValue = ""; - auto rootPath = "files"; - auto useSSL = true; - auto useIam = true; - auto iamEndPoint = ""; - auto bucketName = "gcp-zilliz-infra-test"; - - return StorageConfig{endpoint, - bucketName, - accessKey, - accessValue, - rootPath, - "minio", - iamEndPoint, - "error", - "", - useSSL, - useIam}; -} - -StorageConfig -get_aliyun_cloud_storage_config() { - auto endpoint = "oss-cn-shanghai.aliyuncs.com:443"; - auto accessKey = ""; - auto accessValue = ""; - auto rootPath = "files"; - auto useSSL = true; - auto useIam = true; - auto iamEndPoint = ""; - auto bucketName = "vdc-infra-poc"; - - return StorageConfig{endpoint, - bucketName, - accessKey, - accessValue, - rootPath, - "minio", - iamEndPoint, - useSSL, - useIam}; -} - -class AliyunChunkManagerTest : public testing::Test { - public: - AliyunChunkManagerTest() { - } - ~AliyunChunkManagerTest() { - } - - virtual void - SetUp() { - chunk_manager_ = std::make_unique( - get_aliyun_cloud_storage_config()); - } - - protected: - MinioChunkManagerPtr chunk_manager_; -}; +//StorageConfig +//get_aliyun_cloud_storage_config() { +// auto endpoint = "oss-cn-shanghai.aliyuncs.com:443"; +// auto accessKey = ""; +// auto accessValue = ""; +// auto rootPath = "files"; +// auto useSSL = true; +// auto useIam = true; +// auto iamEndPoint = ""; +// auto bucketName = "vdc-infra-poc"; +// auto cloudProvider = "aliyun"; +// auto logLevel = "error"; +// auto region = ""; +// +// return StorageConfig{endpoint, +// bucketName, +// accessKey, +// accessValue, +// rootPath, +// "minio", +// cloudProvider, +// iamEndPoint, +// logLevel, +// region, +// useSSL, +// useIam}; +//} + +//class AliyunChunkManagerTest : public testing::Test { +// 
public: +// AliyunChunkManagerTest() { +// } +// ~AliyunChunkManagerTest() { +// } +// +// virtual void +// SetUp() { +// chunk_manager_ = std::make_unique( +// get_aliyun_cloud_storage_config()); +// } +// +// protected: +// MinioChunkManagerPtr chunk_manager_; +//}; TEST_F(MinioChunkManagerTest, BucketPositive) { string testBucketName = "test-bucket"; @@ -120,12 +101,8 @@ TEST_F(MinioChunkManagerTest, BucketNegtive) { // create already exist bucket chunk_manager_->CreateBucket(testBucketName); - try { - chunk_manager_->CreateBucket(testBucketName); - } catch (S3ErrorException& e) { - EXPECT_TRUE(std::string(e.what()).find("BucketAlreadyOwnedByYou") != - string::npos); - } + bool created = chunk_manager_->CreateBucket(testBucketName); + EXPECT_EQ(created, false); chunk_manager_->DeleteBucket(testBucketName); } @@ -147,9 +124,9 @@ TEST_F(MinioChunkManagerTest, WritePositive) { chunk_manager_->SetBucketName(testBucketName); EXPECT_EQ(chunk_manager_->GetBucketName(), testBucketName); - // if (!chunk_manager_->BucketExists(testBucketName)) { - // chunk_manager_->CreateBucket(testBucketName); - // } + if (!chunk_manager_->BucketExists(testBucketName)) { + chunk_manager_->CreateBucket(testBucketName); + } auto has_bucket = chunk_manager_->BucketExists(testBucketName); uint8_t data[5] = {0x17, 0x32, 0x45, 0x34, 0x23}; string path = "1"; @@ -225,6 +202,30 @@ TEST_F(MinioChunkManagerTest, ReadPositive) { chunk_manager_->DeleteBucket(testBucketName); } +TEST_F(MinioChunkManagerTest, ReadNotExist) { + string testBucketName = configs_.bucket_name; + chunk_manager_->SetBucketName(testBucketName); + EXPECT_EQ(chunk_manager_->GetBucketName(), testBucketName); + + if (!chunk_manager_->BucketExists(testBucketName)) { + chunk_manager_->CreateBucket(testBucketName); + } + string path = "1/5/8"; + uint8_t readdata[20] = {0}; + + EXPECT_THROW( + try { + chunk_manager_->Read(path, readdata, sizeof(readdata)); + } catch (SegcoreError& e) { + 
EXPECT_TRUE(std::string(e.what()).find("exist") != string::npos); + throw e; + }, + SegcoreError); + + chunk_manager_->Remove(path); + chunk_manager_->DeleteBucket(testBucketName); +} + TEST_F(MinioChunkManagerTest, RemovePositive) { string testBucketName = "test-remove"; chunk_manager_->SetBucketName(testBucketName); @@ -240,7 +241,12 @@ TEST_F(MinioChunkManagerTest, RemovePositive) { bool exist = chunk_manager_->Exist(path); EXPECT_EQ(exist, true); - chunk_manager_->Remove(path); + bool deleted = chunk_manager_->Remove(path); + EXPECT_EQ(deleted, true); + + // test double deleted + deleted = chunk_manager_->Remove(path); + EXPECT_EQ(deleted, false); exist = chunk_manager_->Exist(path); EXPECT_EQ(exist, false); @@ -286,45 +292,45 @@ TEST_F(MinioChunkManagerTest, ListWithPrefixPositive) { chunk_manager_->DeleteBucket(testBucketName); } -TEST_F(AliyunChunkManagerTest, ReadPositive) { - string testBucketName = "vdc-infra-poc"; - chunk_manager_->SetBucketName(testBucketName); - EXPECT_EQ(chunk_manager_->GetBucketName(), testBucketName); - - uint8_t data[5] = {0x17, 0x32, 0x45, 0x34, 0x23}; - string path = "1/4/6"; - chunk_manager_->Write(path, data, sizeof(data)); - bool exist = chunk_manager_->Exist(path); - EXPECT_EQ(exist, true); - auto size = chunk_manager_->Size(path); - EXPECT_EQ(size, 5); - - uint8_t readdata[20] = {0}; - size = chunk_manager_->Read(path, readdata, 20); - EXPECT_EQ(readdata[0], 0x17); - EXPECT_EQ(readdata[1], 0x32); - EXPECT_EQ(readdata[2], 0x45); - EXPECT_EQ(readdata[3], 0x34); - EXPECT_EQ(readdata[4], 0x23); - - size = chunk_manager_->Read(path, readdata, 3); - EXPECT_EQ(size, 3); - EXPECT_EQ(readdata[0], 0x17); - EXPECT_EQ(readdata[1], 0x32); - EXPECT_EQ(readdata[2], 0x45); - - uint8_t dataWithNULL[] = {0x17, 0x32, 0x00, 0x34, 0x23}; - chunk_manager_->Write(path, dataWithNULL, sizeof(dataWithNULL)); - exist = chunk_manager_->Exist(path); - EXPECT_EQ(exist, true); - size = chunk_manager_->Size(path); - EXPECT_EQ(size, 5); - size = 
chunk_manager_->Read(path, readdata, 20); - EXPECT_EQ(readdata[0], 0x17); - EXPECT_EQ(readdata[1], 0x32); - EXPECT_EQ(readdata[2], 0x00); - EXPECT_EQ(readdata[3], 0x34); - EXPECT_EQ(readdata[4], 0x23); - - chunk_manager_->Remove(path); -} +//TEST_F(AliyunChunkManagerTest, ReadPositive) { +// string testBucketName = "vdc-infra-poc"; +// chunk_manager_->SetBucketName(testBucketName); +// EXPECT_EQ(chunk_manager_->GetBucketName(), testBucketName); +// +// uint8_t data[5] = {0x17, 0x32, 0x45, 0x34, 0x23}; +// string path = "1/4/6"; +// chunk_manager_->Write(path, data, sizeof(data)); +// bool exist = chunk_manager_->Exist(path); +// EXPECT_EQ(exist, true); +// auto size = chunk_manager_->Size(path); +// EXPECT_EQ(size, 5); +// +// uint8_t readdata[20] = {0}; +// size = chunk_manager_->Read(path, readdata, 20); +// EXPECT_EQ(readdata[0], 0x17); +// EXPECT_EQ(readdata[1], 0x32); +// EXPECT_EQ(readdata[2], 0x45); +// EXPECT_EQ(readdata[3], 0x34); +// EXPECT_EQ(readdata[4], 0x23); +// +// size = chunk_manager_->Read(path, readdata, 3); +// EXPECT_EQ(size, 3); +// EXPECT_EQ(readdata[0], 0x17); +// EXPECT_EQ(readdata[1], 0x32); +// EXPECT_EQ(readdata[2], 0x45); +// +// uint8_t dataWithNULL[] = {0x17, 0x32, 0x00, 0x34, 0x23}; +// chunk_manager_->Write(path, dataWithNULL, sizeof(dataWithNULL)); +// exist = chunk_manager_->Exist(path); +// EXPECT_EQ(exist, true); +// size = chunk_manager_->Size(path); +// EXPECT_EQ(size, 5); +// size = chunk_manager_->Read(path, readdata, 20); +// EXPECT_EQ(readdata[0], 0x17); +// EXPECT_EQ(readdata[1], 0x32); +// EXPECT_EQ(readdata[2], 0x00); +// EXPECT_EQ(readdata[3], 0x34); +// EXPECT_EQ(readdata[4], 0x23); +// +// chunk_manager_->Remove(path); +//} diff --git a/internal/core/unittest/test_parquet_c.cpp b/internal/core/unittest/test_parquet_c.cpp index 1084f04dbf400..552277e7f4f27 100644 --- a/internal/core/unittest/test_parquet_c.cpp +++ b/internal/core/unittest/test_parquet_c.cpp @@ -17,11 +17,13 @@ #include #include +#include 
"common/EasyAssert.h" #include "storage/parquet_c.h" #include "storage/PayloadReader.h" #include "storage/PayloadWriter.h" namespace wrapper = milvus::storage; +using ErrorCode = milvus::ErrorCode; static void WriteToFile(CBuffer cb) { diff --git a/internal/core/unittest/test_query.cpp b/internal/core/unittest/test_query.cpp index 7b8798f5126f1..de53f6299447c 100644 --- a/internal/core/unittest/test_query.cpp +++ b/internal/core/unittest/test_query.cpp @@ -22,6 +22,7 @@ #include "test_utils/AssertUtils.h" #include "test_utils/DataGen.h" +using json = nlohmann::json; using namespace milvus; using namespace milvus::query; using namespace milvus::segcore; @@ -31,9 +32,6 @@ const int64_t ROW_COUNT = 100 * 1000; } TEST(Query, ShowExecutor) { - using namespace milvus::query; - using namespace milvus::segcore; - using namespace milvus; auto metric_type = knowhere::metric::L2; auto node = std::make_unique(); auto schema = std::make_shared(); diff --git a/internal/core/unittest/test_remote_chunk_manager.cpp b/internal/core/unittest/test_remote_chunk_manager.cpp new file mode 100644 index 0000000000000..77f3ec9498b10 --- /dev/null +++ b/internal/core/unittest/test_remote_chunk_manager.cpp @@ -0,0 +1,277 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. 
See the License for the specific language governing permissions and limitations under the License + +#include +#include +#include + +#include "storage/MinioChunkManager.h" +#include "storage/Util.h" + +using namespace std; +using namespace milvus; +using namespace milvus::storage; + +const string +get_default_bucket_name() { + return "a-bucket"; +} + +StorageConfig +get_default_remote_storage_config() { + StorageConfig storage_config; + storage_config.storage_type = "remote"; + storage_config.address = "localhost:9000"; + char const* tmp = getenv("MINIO_ADDRESS"); + if (tmp != NULL) { + storage_config.address = string(tmp); + } + storage_config.bucket_name = get_default_bucket_name(); + storage_config.access_key_id = "minioadmin"; + storage_config.access_key_value = "minioadmin"; + storage_config.root_path = "files"; + storage_config.storage_type = "remote"; + storage_config.cloud_provider = ""; + storage_config.useSSL = false; + storage_config.useIAM = false; + return storage_config; +} + +class RemoteChunkManagerTest : public testing::Test { + public: + RemoteChunkManagerTest() { + } + ~RemoteChunkManagerTest() { + } + + virtual void + SetUp() { + configs_ = get_default_remote_storage_config(); + aws_chunk_manager_ = make_unique(configs_); + chunk_manager_ptr_ = CreateChunkManager(configs_); + } + + protected: + std::unique_ptr aws_chunk_manager_; + ChunkManagerPtr chunk_manager_ptr_; + StorageConfig configs_; +}; + +TEST_F(RemoteChunkManagerTest, BasicFunctions) { + EXPECT_TRUE(aws_chunk_manager_->GetName() == "AwsChunkManager"); + EXPECT_TRUE(chunk_manager_ptr_->GetName() == "MinioChunkManager"); + + ChunkManagerPtr the_chunk_manager_; + configs_.cloud_provider = "aws"; + the_chunk_manager_ = CreateChunkManager(configs_); + EXPECT_TRUE(the_chunk_manager_->GetName() == "AwsChunkManager"); + + configs_.cloud_provider = "gcp"; + the_chunk_manager_ = CreateChunkManager(configs_); + EXPECT_TRUE(the_chunk_manager_->GetName() == "GcpChunkManager"); + + 
configs_.cloud_provider = "aliyun"; + the_chunk_manager_ = CreateChunkManager(configs_); + EXPECT_TRUE(the_chunk_manager_->GetName() == "AliyunChunkManager"); + +#ifdef AZURE_BUILD_DIR + configs_.cloud_provider = "azure"; + the_chunk_manager_ = CreateChunkManager(configs_); + EXPECT_TRUE(the_chunk_manager_->GetName() == "AzureChunkManager"); +#endif + + configs_.cloud_provider = ""; +} + +TEST_F(RemoteChunkManagerTest, BucketPositive) { + string testBucketName = get_default_bucket_name(); + aws_chunk_manager_->SetBucketName(testBucketName); + bool exist = aws_chunk_manager_->BucketExists(testBucketName); + EXPECT_EQ(exist, false); + aws_chunk_manager_->CreateBucket(testBucketName); + exist = aws_chunk_manager_->BucketExists(testBucketName); + EXPECT_EQ(exist, true); + aws_chunk_manager_->DeleteBucket(testBucketName); +} + +TEST_F(RemoteChunkManagerTest, BucketNegtive) { + string testBucketName = get_default_bucket_name(); + aws_chunk_manager_->SetBucketName(testBucketName); + aws_chunk_manager_->DeleteBucket(testBucketName); + + // create already exist bucket + aws_chunk_manager_->CreateBucket(testBucketName); + try { + aws_chunk_manager_->CreateBucket(testBucketName); + } catch (SegcoreError& e) { + EXPECT_TRUE(std::string(e.what()).find("exists") != + string::npos); + } + aws_chunk_manager_->DeleteBucket(testBucketName); +} + +TEST_F(RemoteChunkManagerTest, ObjectExist) { + string testBucketName = get_default_bucket_name(); + string objPath = "1/3"; + aws_chunk_manager_->SetBucketName(testBucketName); + if (!aws_chunk_manager_->BucketExists(testBucketName)) { + aws_chunk_manager_->CreateBucket(testBucketName); + } + + bool exist = aws_chunk_manager_->Exist(objPath); + EXPECT_EQ(exist, false); + exist = chunk_manager_ptr_->Exist(objPath); + EXPECT_EQ(exist, false); + aws_chunk_manager_->DeleteBucket(testBucketName); +} + +TEST_F(RemoteChunkManagerTest, WritePositive) { + string testBucketName = get_default_bucket_name(); + 
aws_chunk_manager_->SetBucketName(testBucketName); + EXPECT_EQ(aws_chunk_manager_->GetBucketName(), testBucketName); + + if (!aws_chunk_manager_->BucketExists(testBucketName)) { + aws_chunk_manager_->CreateBucket(testBucketName); + } + uint8_t data[5] = {0x17, 0x32, 0x45, 0x34, 0x23}; + string path = "1"; + aws_chunk_manager_->Write(path, data, sizeof(data)); + + bool exist = aws_chunk_manager_->Exist(path); + EXPECT_EQ(exist, true); + + auto size = aws_chunk_manager_->Size(path); + EXPECT_EQ(size, 5); + + int datasize = 10000; + uint8_t* bigdata = new uint8_t[datasize]; + srand((unsigned)time(NULL)); + for (int i = 0; i < datasize; ++i) { + bigdata[i] = rand() % 256; + } + aws_chunk_manager_->Write(path, bigdata, datasize); + size = aws_chunk_manager_->Size(path); + EXPECT_EQ(size, datasize); + delete[] bigdata; + + aws_chunk_manager_->Remove(path); + aws_chunk_manager_->DeleteBucket(testBucketName); +} + +TEST_F(RemoteChunkManagerTest, ReadPositive) { + string testBucketName = get_default_bucket_name(); + aws_chunk_manager_->SetBucketName(testBucketName); + EXPECT_EQ(aws_chunk_manager_->GetBucketName(), testBucketName); + + if (!aws_chunk_manager_->BucketExists(testBucketName)) { + aws_chunk_manager_->CreateBucket(testBucketName); + } + uint8_t data[5] = {0x17, 0x32, 0x45, 0x34, 0x23}; + string path = "1/4/6"; + aws_chunk_manager_->Write(path, data, sizeof(data)); + bool exist = aws_chunk_manager_->Exist(path); + EXPECT_EQ(exist, true); + auto size = aws_chunk_manager_->Size(path); + EXPECT_EQ(size, sizeof(data)); + + uint8_t readdata[20] = {0}; + size = aws_chunk_manager_->Read(path, readdata, sizeof(data)); + EXPECT_EQ(size, sizeof(data)); + EXPECT_EQ(readdata[0], 0x17); + EXPECT_EQ(readdata[1], 0x32); + EXPECT_EQ(readdata[2], 0x45); + EXPECT_EQ(readdata[3], 0x34); + EXPECT_EQ(readdata[4], 0x23); + + size = aws_chunk_manager_->Read(path, readdata, 3); + EXPECT_EQ(size, 3); + EXPECT_EQ(readdata[0], 0x17); + EXPECT_EQ(readdata[1], 0x32); + EXPECT_EQ(readdata[2], 
0x45); + + uint8_t dataWithNULL[] = {0x17, 0x32, 0x00, 0x34, 0x23}; + aws_chunk_manager_->Write(path, dataWithNULL, sizeof(dataWithNULL)); + exist = aws_chunk_manager_->Exist(path); + EXPECT_EQ(exist, true); + size = aws_chunk_manager_->Size(path); + EXPECT_EQ(size, sizeof(dataWithNULL)); + size = aws_chunk_manager_->Read(path, readdata, sizeof(dataWithNULL)); + EXPECT_EQ(size, sizeof(dataWithNULL)); + EXPECT_EQ(readdata[0], 0x17); + EXPECT_EQ(readdata[1], 0x32); + EXPECT_EQ(readdata[2], 0x00); + EXPECT_EQ(readdata[3], 0x34); + EXPECT_EQ(readdata[4], 0x23); + + aws_chunk_manager_->Remove(path); + aws_chunk_manager_->DeleteBucket(testBucketName); +} + +TEST_F(RemoteChunkManagerTest, RemovePositive) { + string testBucketName = get_default_bucket_name(); + aws_chunk_manager_->SetBucketName(testBucketName); + EXPECT_EQ(aws_chunk_manager_->GetBucketName(), testBucketName); + + if (!aws_chunk_manager_->BucketExists(testBucketName)) { + aws_chunk_manager_->CreateBucket(testBucketName); + } + uint8_t data[5] = {0x17, 0x32, 0x45, 0x34, 0x23}; + string path = "1/7/8"; + aws_chunk_manager_->Write(path, data, sizeof(data)); + + bool exist = aws_chunk_manager_->Exist(path); + EXPECT_EQ(exist, true); + + aws_chunk_manager_->Remove(path); + + exist = aws_chunk_manager_->Exist(path); + EXPECT_EQ(exist, false); + + aws_chunk_manager_->DeleteBucket(testBucketName); +} + +TEST_F(RemoteChunkManagerTest, ListWithPrefixPositive) { + string testBucketName = get_default_bucket_name(); + aws_chunk_manager_->SetBucketName(testBucketName); + EXPECT_EQ(aws_chunk_manager_->GetBucketName(), testBucketName); + + if (!aws_chunk_manager_->BucketExists(testBucketName)) { + aws_chunk_manager_->CreateBucket(testBucketName); + } + + string path1 = "1/7/8"; + string path2 = "1/7/4"; + string path3 = "1/4/8"; + uint8_t data[5] = {0x17, 0x32, 0x45, 0x34, 0x23}; + aws_chunk_manager_->Write(path1, data, sizeof(data)); + aws_chunk_manager_->Write(path2, data, sizeof(data)); + 
aws_chunk_manager_->Write(path3, data, sizeof(data)); + + vector objs = aws_chunk_manager_->ListWithPrefix("1/7"); + EXPECT_EQ(objs.size(), 2); + std::sort(objs.begin(), objs.end()); + EXPECT_EQ(objs[0], "1/7/4"); + EXPECT_EQ(objs[1], "1/7/8"); + + objs = aws_chunk_manager_->ListWithPrefix("//1/7"); + EXPECT_EQ(objs.size(), 2); + + objs = aws_chunk_manager_->ListWithPrefix("1"); + EXPECT_EQ(objs.size(), 3); + std::sort(objs.begin(), objs.end()); + EXPECT_EQ(objs[0], "1/4/8"); + EXPECT_EQ(objs[1], "1/7/4"); + + aws_chunk_manager_->Remove(path1); + aws_chunk_manager_->Remove(path2); + aws_chunk_manager_->Remove(path3); + aws_chunk_manager_->DeleteBucket(testBucketName); +} diff --git a/internal/core/unittest/test_scalar_index_creator.cpp b/internal/core/unittest/test_scalar_index_creator.cpp index 987c6ce87dd00..e23e1630181f3 100644 --- a/internal/core/unittest/test_scalar_index_creator.cpp +++ b/internal/core/unittest/test_scalar_index_creator.cpp @@ -112,7 +112,9 @@ TYPED_TEST_P(TypedScalarIndexCreatorTest, Constructor) { } auto creator = milvus::indexbuilder::CreateScalarIndex( - milvus::DataType(dtype), config, nullptr); + milvus::DataType(dtype), + config, + milvus::storage::FileManagerContext()); } } @@ -133,12 +135,16 @@ TYPED_TEST_P(TypedScalarIndexCreatorTest, Codec) { config[iter->first] = iter->second; } auto creator = milvus::indexbuilder::CreateScalarIndex( - milvus::DataType(dtype), config, nullptr); + milvus::DataType(dtype), + config, + milvus::storage::FileManagerContext()); auto arr = GenArr(nb); build_index(creator, arr); auto binary_set = creator->Serialize(); auto copy_creator = milvus::indexbuilder::CreateScalarIndex( - milvus::DataType(dtype), config, nullptr); + milvus::DataType(dtype), + config, + milvus::storage::FileManagerContext()); copy_creator->Load(binary_set); } } diff --git a/internal/core/unittest/test_sealed.cpp b/internal/core/unittest/test_sealed.cpp index 42a882ded0d71..f4b507d9553e2 100644 --- 
a/internal/core/unittest/test_sealed.cpp +++ b/internal/core/unittest/test_sealed.cpp @@ -15,7 +15,14 @@ #include "common/Types.h" #include "segcore/SegmentSealedImpl.h" #include "test_utils/DataGen.h" +#include "test_utils/storage_test_utils.h" #include "index/IndexFactory.h" +#include "storage/Util.h" +#include "knowhere/version.h" +#include "storage/ChunkCacheSingleton.h" +#include "storage/RemoteChunkManagerSingleton.h" +#include "storage/MinioChunkManager.h" +#include "test_utils/indexbuilder_test_utils.h" using namespace milvus; using namespace milvus::query; @@ -82,9 +89,11 @@ TEST(Sealed, without_predicate) { create_index_info.field_type = DataType::VECTOR_FLOAT; create_index_info.metric_type = knowhere::metric::L2; create_index_info.index_type = knowhere::IndexEnum::INDEX_FAISS_IVFFLAT; + create_index_info.index_engine_version = + knowhere::Version::GetCurrentVersion().VersionNumber(); auto indexing = milvus::index::IndexFactory::GetInstance().CreateIndex( - create_index_info, nullptr); + create_index_info, milvus::storage::FileManagerContext()); auto build_conf = knowhere::Json{{knowhere::meta::METRIC_TYPE, knowhere::metric::L2}, @@ -195,8 +204,10 @@ TEST(Sealed, with_predicate) { create_index_info.field_type = DataType::VECTOR_FLOAT; create_index_info.metric_type = knowhere::metric::L2; create_index_info.index_type = knowhere::IndexEnum::INDEX_FAISS_IVFFLAT; + create_index_info.index_engine_version = + knowhere::Version::GetCurrentVersion().VersionNumber(); auto indexing = milvus::index::IndexFactory::GetInstance().CreateIndex( - create_index_info, nullptr); + create_index_info, milvus::storage::FileManagerContext()); auto build_conf = knowhere::Json{{knowhere::meta::METRIC_TYPE, knowhere::metric::L2}, @@ -299,8 +310,10 @@ TEST(Sealed, with_predicate_filter_all) { create_index_info.field_type = DataType::VECTOR_FLOAT; create_index_info.metric_type = knowhere::metric::L2; create_index_info.index_type = knowhere::IndexEnum::INDEX_FAISS_IVFFLAT; + 
create_index_info.index_engine_version = + knowhere::Version::GetCurrentVersion().VersionNumber(); auto ivf_indexing = milvus::index::IndexFactory::GetInstance().CreateIndex( - create_index_info, nullptr); + create_index_info, milvus::storage::FileManagerContext()); auto ivf_build_conf = knowhere::Json{{knowhere::meta::DIM, std::to_string(dim)}, @@ -337,8 +350,10 @@ TEST(Sealed, with_predicate_filter_all) { create_index_info.field_type = DataType::VECTOR_FLOAT; create_index_info.metric_type = knowhere::metric::L2; create_index_info.index_type = knowhere::IndexEnum::INDEX_HNSW; + create_index_info.index_engine_version = + knowhere::Version::GetCurrentVersion().VersionNumber(); auto hnsw_indexing = milvus::index::IndexFactory::GetInstance().CreateIndex( - create_index_info, nullptr); + create_index_info, milvus::storage::FileManagerContext()); hnsw_indexing->BuildWithDataset(database, hnsw_conf); auto hnsw_vec_index = @@ -376,6 +391,7 @@ TEST(Sealed, LoadFieldData) { schema->AddDebugField("int16", DataType::INT16); schema->AddDebugField("float", DataType::FLOAT); schema->AddDebugField("json", DataType::JSON); + schema->AddDebugField("array", DataType::ARRAY, DataType::INT64); schema->set_primary_field_id(counter_id); auto dataset = DataGen(schema, N); @@ -500,6 +516,7 @@ TEST(Sealed, LoadFieldDataMmap) { schema->AddDebugField("int16", DataType::INT16); schema->AddDebugField("float", DataType::FLOAT); schema->AddDebugField("json", DataType::JSON); + schema->AddDebugField("array", DataType::ARRAY, DataType::INT64); schema->set_primary_field_id(counter_id); auto dataset = DataGen(schema, N); @@ -1053,7 +1070,7 @@ TEST(Sealed, DeleteCount) { ASSERT_TRUE(status.ok()); auto cnt = segment->get_deleted_count(); - ASSERT_EQ(cnt, c); + ASSERT_EQ(cnt, 0); } TEST(Sealed, RealCount) { @@ -1146,3 +1163,221 @@ TEST(Sealed, GetVector) { } } } + +TEST(Sealed, GetVectorFromChunkCache) { + // skip test due to mem leak from AWS::InitSDK + return; + + auto dim = 16; + auto topK = 5; + 
auto N = ROW_COUNT; + auto metric_type = knowhere::metric::L2; + auto index_type = knowhere::IndexEnum::INDEX_FAISS_IVFPQ; + + auto mmap_dir = "/tmp/mmap"; + auto file_name = std::string( + "sealed_test_get_vector_from_chunk_cache/insert_log/1/101/1000000"); + + auto sc = milvus::storage::StorageConfig{}; + milvus::storage::RemoteChunkManagerSingleton::GetInstance().Init(sc); + auto mcm = std::make_unique(sc); + mcm->CreateBucket(sc.bucket_name); + milvus::storage::ChunkCacheSingleton::GetInstance().Init(mmap_dir, + "willneed"); + + auto schema = std::make_shared(); + auto fakevec_id = schema->AddDebugField( + "fakevec", DataType::VECTOR_FLOAT, dim, metric_type); + auto counter_id = schema->AddDebugField("counter", DataType::INT64); + auto double_id = schema->AddDebugField("double", DataType::DOUBLE); + auto nothing_id = schema->AddDebugField("nothing", DataType::INT32); + auto str_id = schema->AddDebugField("str", DataType::VARCHAR); + schema->AddDebugField("int8", DataType::INT8); + schema->AddDebugField("int16", DataType::INT16); + schema->AddDebugField("float", DataType::FLOAT); + schema->set_primary_field_id(counter_id); + + auto dataset = DataGen(schema, N); + auto field_data_meta = + milvus::storage::FieldDataMeta{1, 2, 3, fakevec_id.get()}; + auto field_meta = milvus::FieldMeta(milvus::FieldName("facevec"), + fakevec_id, + milvus::DataType::VECTOR_FLOAT, + dim, + metric_type); + + auto rcm = milvus::storage::RemoteChunkManagerSingleton::GetInstance() + .GetRemoteChunkManager(); + auto data = dataset.get_col(fakevec_id); + auto data_slices = std::vector{(uint8_t*)data.data()}; + auto slice_sizes = std::vector{static_cast(N)}; + auto slice_names = std::vector{file_name}; + PutFieldData(rcm.get(), + data_slices, + slice_sizes, + slice_names, + field_data_meta, + field_meta); + + auto fakevec = dataset.get_col(fakevec_id); + auto conf = generate_build_conf(index_type, metric_type); + auto ds = knowhere::GenDataSet(N, dim, fakevec.data()); + auto indexing = 
std::make_unique( + index_type, + metric_type, + knowhere::Version::GetCurrentVersion().VersionNumber()); + indexing->BuildWithDataset(ds, conf); + auto segment_sealed = CreateSealedSegment(schema); + + LoadIndexInfo vec_info; + vec_info.field_id = fakevec_id.get(); + vec_info.index = std::move(indexing); + vec_info.index_params["metric_type"] = knowhere::metric::L2; + segment_sealed->LoadIndex(vec_info); + + auto field_binlog_info = + FieldBinlogInfo{fakevec_id.get(), + N, + std::vector{N}, + std::vector{file_name}}; + segment_sealed->AddFieldDataInfoForSealed(LoadFieldDataInfo{ + std::map{ + {fakevec_id.get(), field_binlog_info}}, + mmap_dir, + }); + + auto segment = dynamic_cast(segment_sealed.get()); + auto has = segment->HasRawData(vec_info.field_id); + EXPECT_FALSE(has); + + auto ids_ds = GenRandomIds(N); + auto result = + segment->get_vector(fakevec_id, ids_ds->GetIds(), ids_ds->GetRows()); + + auto vector = result.get()->mutable_vectors()->float_vector().data(); + EXPECT_TRUE(vector.size() == fakevec.size()); + for (size_t i = 0; i < N; ++i) { + auto id = ids_ds->GetIds()[i]; + for (size_t j = 0; j < dim; ++j) { + auto expect = fakevec[id * dim + j]; + auto actual = vector[i * dim + j]; + AssertInfo(expect == actual, + fmt::format("expect {}, actual {}", expect, actual)); + } + } + + rcm->Remove(file_name); + std::filesystem::remove_all(mmap_dir); + auto exist = rcm->Exist(file_name); + Assert(!exist); + exist = std::filesystem::exists(mmap_dir); + Assert(!exist); +} + +TEST(Sealed, LoadArrayFieldData) { + auto dim = 16; + auto topK = 5; + auto N = 10; + auto metric_type = knowhere::metric::L2; + auto schema = std::make_shared(); + auto fakevec_id = schema->AddDebugField( + "fakevec", DataType::VECTOR_FLOAT, dim, metric_type); + auto counter_id = schema->AddDebugField("counter", DataType::INT64); + auto array_id = + schema->AddDebugField("array", DataType::ARRAY, DataType::INT64); + schema->set_primary_field_id(counter_id); + + auto dataset = 
DataGen(schema, N); + auto fakevec = dataset.get_col(fakevec_id); + auto segment = CreateSealedSegment(schema); + + const char* raw_plan = R"(vector_anns:< + field_id:100 + predicates:< + json_contains_expr:< + column_info:< + field_id:102 + data_type:Array + element_type:Int64 + > + elements: + op:Contains + elements_same_type:true + > + > + query_info:< + topk: 5 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > placeholder_tag:"$0" + >)"; + + auto plan_str = translate_text_plan_to_binary_plan(raw_plan); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + auto num_queries = 5; + auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024); + auto ph_group = + ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); + + SealedLoadFieldData(dataset, *segment); + segment->Search(plan.get(), ph_group.get()); + + auto ids_ds = GenRandomIds(N); + auto s = dynamic_cast(segment.get()); + auto int64_result = s->bulk_subscript(array_id, ids_ds->GetIds(), N); + auto result_count = int64_result->scalars().array_data().data().size(); + ASSERT_EQ(result_count, N); +} + +TEST(Sealed, LoadArrayFieldDataWithMMap) { + auto dim = 16; + auto topK = 5; + auto N = ROW_COUNT; + auto metric_type = knowhere::metric::L2; + auto schema = std::make_shared(); + auto fakevec_id = schema->AddDebugField( + "fakevec", DataType::VECTOR_FLOAT, dim, metric_type); + auto counter_id = schema->AddDebugField("counter", DataType::INT64); + auto array_id = + schema->AddDebugField("array", DataType::ARRAY, DataType::INT64); + schema->set_primary_field_id(counter_id); + + auto dataset = DataGen(schema, N); + auto fakevec = dataset.get_col(fakevec_id); + auto segment = CreateSealedSegment(schema); + + const char* raw_plan = R"(vector_anns:< + field_id:100 + predicates:< + json_contains_expr:< + column_info:< + field_id:102 + data_type:Array + element_type:Int64 + > + elements: + op:Contains + elements_same_type:true + > + > + 
query_info:< + topk: 5 + round_decimal: 3 + metric_type: "L2" + search_params: "{\"nprobe\": 10}" + > placeholder_tag:"$0" + >)"; + + auto plan_str = translate_text_plan_to_binary_plan(raw_plan); + auto plan = + CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + auto num_queries = 5; + auto ph_group_raw = CreatePlaceholderGroup(num_queries, 16, 1024); + auto ph_group = + ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); + + SealedLoadFieldData(dataset, *segment, {}, true); + segment->Search(plan.get(), ph_group.get()); +} diff --git a/internal/core/unittest/test_string_expr.cpp b/internal/core/unittest/test_string_expr.cpp index c2dd48a7b450f..32aa555d5f8d8 100644 --- a/internal/core/unittest/test_string_expr.cpp +++ b/internal/core/unittest/test_string_expr.cpp @@ -73,12 +73,12 @@ GenQueryInfo(int64_t topk, auto GenAnns(proto::plan::Expr* predicate, - bool is_binary, + proto::plan::VectorType vectorType, int64_t field_id, std::string placeholder_tag = "$0") { auto query_info = GenQueryInfo(10, "L2", "{\"nprobe\": 10}", -1); auto anns = new proto::plan::VectorANNS(); - anns->set_is_binary(is_binary); + anns->set_vector_type(vectorType); anns->set_field_id(field_id); anns->set_allocated_predicates(predicate); anns->set_allocated_query_info(query_info); @@ -177,10 +177,16 @@ GenTermPlan(const FieldMeta& fvec_meta, auto expr = GenExpr().release(); expr->set_allocated_term_expr(term_expr); - auto anns = GenAnns(expr, - fvec_meta.get_data_type() == DataType::VECTOR_BINARY, - fvec_meta.get_id().get(), - "$0"); + proto::plan::VectorType vector_type; + if (fvec_meta.get_data_type() == DataType::VECTOR_FLOAT) { + vector_type = proto::plan::VectorType::FloatVector; + } else if (fvec_meta.get_data_type() == DataType::VECTOR_BINARY) { + vector_type = proto::plan::VectorType::BinaryVector; + } else if (fvec_meta.get_data_type() == DataType::VECTOR_FLOAT16) { + vector_type = proto::plan::VectorType::Float16Vector; + } + + auto anns = 
GenAnns(expr, vector_type, fvec_meta.get_id().get(), "$0"); auto plan_node = GenPlanNode(); plan_node->set_allocated_vector_anns(anns); @@ -215,10 +221,16 @@ GenAlwaysTrueExpr(const FieldMeta& fvec_meta, const FieldMeta& str_meta) { auto GenAlwaysFalsePlan(const FieldMeta& fvec_meta, const FieldMeta& str_meta) { auto always_false_expr = GenAlwaysFalseExpr(fvec_meta, str_meta); - auto anns = GenAnns(always_false_expr, - fvec_meta.get_data_type() == DataType::VECTOR_BINARY, - fvec_meta.get_id().get(), - "$0"); + proto::plan::VectorType vector_type; + if (fvec_meta.get_data_type() == DataType::VECTOR_FLOAT) { + vector_type = proto::plan::VectorType::FloatVector; + } else if (fvec_meta.get_data_type() == DataType::VECTOR_BINARY) { + vector_type = proto::plan::VectorType::BinaryVector; + } else if (fvec_meta.get_data_type() == DataType::VECTOR_FLOAT16) { + vector_type = proto::plan::VectorType::Float16Vector; + } + auto anns = + GenAnns(always_false_expr, vector_type, fvec_meta.get_id().get(), "$0"); auto plan_node = GenPlanNode(); plan_node->set_allocated_vector_anns(anns); @@ -228,10 +240,16 @@ GenAlwaysFalsePlan(const FieldMeta& fvec_meta, const FieldMeta& str_meta) { auto GenAlwaysTruePlan(const FieldMeta& fvec_meta, const FieldMeta& str_meta) { auto always_true_expr = GenAlwaysTrueExpr(fvec_meta, str_meta); - auto anns = GenAnns(always_true_expr, - fvec_meta.get_data_type() == DataType::VECTOR_BINARY, - fvec_meta.get_id().get(), - "$0"); + proto::plan::VectorType vector_type; + if (fvec_meta.get_data_type() == DataType::VECTOR_FLOAT) { + vector_type = proto::plan::VectorType::FloatVector; + } else if (fvec_meta.get_data_type() == DataType::VECTOR_BINARY) { + vector_type = proto::plan::VectorType::BinaryVector; + } else if (fvec_meta.get_data_type() == DataType::VECTOR_FLOAT16) { + vector_type = proto::plan::VectorType::Float16Vector; + } + auto anns = + GenAnns(always_true_expr, vector_type, fvec_meta.get_id().get(), "$0"); auto plan_node = GenPlanNode(); 
plan_node->set_allocated_vector_anns(anns); @@ -353,11 +371,15 @@ TEST(StringExpr, Compare) { auto expr = GenExpr().release(); expr->set_allocated_compare_expr(compare_expr); - auto anns = - GenAnns(expr, - fvec_meta.get_data_type() == DataType::VECTOR_BINARY, - fvec_meta.get_id().get(), - "$0"); + proto::plan::VectorType vector_type; + if (fvec_meta.get_data_type() == DataType::VECTOR_FLOAT) { + vector_type = proto::plan::VectorType::FloatVector; + } else if (fvec_meta.get_data_type() == DataType::VECTOR_BINARY) { + vector_type = proto::plan::VectorType::BinaryVector; + } else if (fvec_meta.get_data_type() == DataType::VECTOR_FLOAT16) { + vector_type = proto::plan::VectorType::Float16Vector; + } + auto anns = GenAnns(expr, vector_type, fvec_meta.get_id().get(), "$0"); auto plan_node = std::make_unique(); plan_node->set_allocated_vector_anns(anns); @@ -456,11 +478,15 @@ TEST(StringExpr, UnaryRange) { auto expr = GenExpr().release(); expr->set_allocated_unary_range_expr(unary_range_expr); - auto anns = - GenAnns(expr, - fvec_meta.get_data_type() == DataType::VECTOR_BINARY, - fvec_meta.get_id().get(), - "$0"); + proto::plan::VectorType vector_type; + if (fvec_meta.get_data_type() == DataType::VECTOR_FLOAT) { + vector_type = proto::plan::VectorType::FloatVector; + } else if (fvec_meta.get_data_type() == DataType::VECTOR_BINARY) { + vector_type = proto::plan::VectorType::BinaryVector; + } else if (fvec_meta.get_data_type() == DataType::VECTOR_FLOAT16) { + vector_type = proto::plan::VectorType::Float16Vector; + } + auto anns = GenAnns(expr, vector_type, fvec_meta.get_id().get(), "$0"); auto plan_node = std::make_unique(); plan_node->set_allocated_vector_anns(anns); @@ -551,11 +577,15 @@ TEST(StringExpr, BinaryRange) { auto expr = GenExpr().release(); expr->set_allocated_binary_range_expr(binary_range_expr); - auto anns = - GenAnns(expr, - fvec_meta.get_data_type() == DataType::VECTOR_BINARY, - fvec_meta.get_id().get(), - "$0"); + proto::plan::VectorType vector_type; + 
if (fvec_meta.get_data_type() == DataType::VECTOR_FLOAT) { + vector_type = proto::plan::VectorType::FloatVector; + } else if (fvec_meta.get_data_type() == DataType::VECTOR_BINARY) { + vector_type = proto::plan::VectorType::BinaryVector; + } else if (fvec_meta.get_data_type() == DataType::VECTOR_FLOAT16) { + vector_type = proto::plan::VectorType::Float16Vector; + } + auto anns = GenAnns(expr, vector_type, fvec_meta.get_id().get(), "$0"); auto plan_node = std::make_unique(); plan_node->set_allocated_vector_anns(anns); diff --git a/internal/core/unittest/test_tracer.cpp b/internal/core/unittest/test_tracer.cpp index a9d568bac9d5f..7393b671e69c5 100644 --- a/internal/core/unittest/test_tracer.cpp +++ b/internal/core/unittest/test_tracer.cpp @@ -15,7 +15,7 @@ #include #include "common/Tracer.h" -#include "exceptions/EasyAssert.h" +#include "common/EasyAssert.h" using namespace milvus; using namespace milvus::tracer; diff --git a/internal/core/unittest/test_utils.cpp b/internal/core/unittest/test_utils.cpp index cf35a89a73901..60a31a20e13d7 100644 --- a/internal/core/unittest/test_utils.cpp +++ b/internal/core/unittest/test_utils.cpp @@ -15,11 +15,11 @@ #include #include +#include "common/EasyAssert.h" #include "common/Utils.h" #include "query/Utils.h" #include "test_utils/DataGen.h" #include "common/Types.h" -#include "index/Exception.h" TEST(Util, StringMatch) { using namespace milvus; @@ -138,29 +138,51 @@ TEST(Util, upper_bound) { ASSERT_EQ(10, upper_bound(timestamps, 0, data.size(), 10)); } +// A simple wrapper that removes a temporary file. 
+struct TmpFileWrapper { + int fd = -1; + std::string filename; + + TmpFileWrapper(const std::string& _filename) : filename{_filename} { + fd = open( + filename.c_str(), O_RDWR | O_CREAT | O_EXCL, S_IRUSR | S_IWUSR | S_IXUSR); + } + TmpFileWrapper(const TmpFileWrapper&) = delete; + TmpFileWrapper(TmpFileWrapper&&) = delete; + TmpFileWrapper& operator =(const TmpFileWrapper&) = delete; + TmpFileWrapper& operator =(TmpFileWrapper&&) = delete; + ~TmpFileWrapper() { + if (fd != -1) { + close(fd); + remove(filename.c_str()); + } + } +}; + TEST(Util, read_from_fd) { auto uuid = boost::uuids::random_generator()(); auto uuid_string = boost::uuids::to_string(uuid); auto file = std::string("/tmp/") + uuid_string; - auto fd = open( - file.c_str(), O_RDWR | O_CREAT | O_EXCL, S_IRUSR | S_IWUSR | S_IXUSR); - ASSERT_NE(fd, -1); + auto tmp_file = TmpFileWrapper(file); + ASSERT_NE(tmp_file.fd, -1); + size_t data_size = 100 * 1024 * 1024; // 100M auto index_data = std::shared_ptr(new uint8_t[data_size]); auto max_loop = size_t(INT_MAX) / data_size + 1; // insert data > 2G for (int i = 0; i < max_loop; ++i) { - auto size_write = write(fd, index_data.get(), data_size); + auto size_write = write(tmp_file.fd, index_data.get(), data_size); ASSERT_GE(size_write, 0); } auto read_buf = std::shared_ptr(new uint8_t[data_size * max_loop]); EXPECT_NO_THROW(milvus::index::ReadDataFromFD( - fd, read_buf.get(), data_size * max_loop)); + tmp_file.fd, read_buf.get(), data_size * max_loop)); // On Linux, read() (and similar system calls) will transfer at most 0x7ffff000 (2,147,479,552) bytes once EXPECT_THROW(milvus::index::ReadDataFromFD( - fd, read_buf.get(), data_size * max_loop, INT_MAX), - milvus::index::UnistdException); + tmp_file.fd, read_buf.get(), data_size * max_loop, INT_MAX), + milvus::SegcoreError); } + diff --git a/internal/core/unittest/test_utils/DataGen.h b/internal/core/unittest/test_utils/DataGen.h index df2c3bdd84a89..0bf70e8b4556c 100644 --- 
a/internal/core/unittest/test_utils/DataGen.h +++ b/internal/core/unittest/test_utils/DataGen.h @@ -19,10 +19,11 @@ #include #include "Constants.h" +#include "common/EasyAssert.h" #include "common/Schema.h" #include "index/ScalarIndexSort.h" #include "index/StringIndexSort.h" -#include "index/VectorMemNMIndex.h" +#include "index/VectorMemIndex.h" #include "query/SearchOnIndex.h" #include "segcore/SegmentGrowingImpl.h" #include "segcore/SegmentSealedImpl.h" @@ -105,65 +106,92 @@ struct GeneratedData { auto src_data = reinterpret_cast( target_field_data.vectors().binary_vector().data()); std::copy_n(src_data, len, ret.data()); + } else if (field_meta.get_data_type() == + DataType::VECTOR_FLOAT16) { + // int len = raw_->num_rows() * field_meta.get_dim() * sizeof(float16); + int len = raw_->num_rows() * field_meta.get_dim(); + ret.resize(len); + auto src_data = reinterpret_cast( + target_field_data.vectors().float16_vector().data()); + std::copy_n(src_data, len, ret.data()); } else { - PanicInfo("unsupported"); + PanicInfo(Unsupported, "unsupported"); } return std::move(ret); } - switch (field_meta.get_data_type()) { - case DataType::BOOL: { - auto src_data = reinterpret_cast( - target_field_data.scalars().bool_data().data().data()); - std::copy_n(src_data, raw_->num_rows(), ret.data()); - break; - } - case DataType::INT8: - case DataType::INT16: - case DataType::INT32: { - auto src_data = reinterpret_cast( - target_field_data.scalars().int_data().data().data()); - std::copy_n(src_data, raw_->num_rows(), ret.data()); - break; - } - case DataType::INT64: { - auto src_data = reinterpret_cast( - target_field_data.scalars().long_data().data().data()); - std::copy_n(src_data, raw_->num_rows(), ret.data()); - break; - } - case DataType::FLOAT: { - auto src_data = reinterpret_cast( - target_field_data.scalars().float_data().data().data()); - std::copy_n(src_data, raw_->num_rows(), ret.data()); - break; - } - case DataType::DOUBLE: { - auto src_data = - 
reinterpret_cast(target_field_data.scalars() - .double_data() - .data() - .data()); - std::copy_n(src_data, raw_->num_rows(), ret.data()); - break; - } - case DataType::VARCHAR: { - auto ret_data = reinterpret_cast(ret.data()); - auto src_data = - target_field_data.scalars().string_data().data(); - std::copy(src_data.begin(), src_data.end(), ret_data); - - break; - } - case DataType::JSON: { - auto ret_data = reinterpret_cast(ret.data()); - auto src_data = - target_field_data.scalars().json_data().data(); - std::copy(src_data.begin(), src_data.end(), ret_data); - break; - } - default: { - PanicInfo("unsupported"); + if constexpr (std::is_same_v) { + auto ret_data = reinterpret_cast(ret.data()); + auto src_data = target_field_data.scalars().array_data().data(); + std::copy(src_data.begin(), src_data.end(), ret_data); + } else { + switch (field_meta.get_data_type()) { + case DataType::BOOL: { + auto src_data = reinterpret_cast( + target_field_data.scalars() + .bool_data() + .data() + .data()); + std::copy_n(src_data, raw_->num_rows(), ret.data()); + break; + } + case DataType::INT8: + case DataType::INT16: + case DataType::INT32: { + auto src_data = reinterpret_cast( + target_field_data.scalars() + .int_data() + .data() + .data()); + std::copy_n(src_data, raw_->num_rows(), ret.data()); + break; + } + case DataType::INT64: { + auto src_data = reinterpret_cast( + target_field_data.scalars() + .long_data() + .data() + .data()); + std::copy_n(src_data, raw_->num_rows(), ret.data()); + break; + } + case DataType::FLOAT: { + auto src_data = reinterpret_cast( + target_field_data.scalars() + .float_data() + .data() + .data()); + std::copy_n(src_data, raw_->num_rows(), ret.data()); + break; + } + case DataType::DOUBLE: { + auto src_data = reinterpret_cast( + target_field_data.scalars() + .double_data() + .data() + .data()); + std::copy_n(src_data, raw_->num_rows(), ret.data()); + break; + } + case DataType::VARCHAR: { + auto ret_data = + reinterpret_cast(ret.data()); + auto 
src_data = + target_field_data.scalars().string_data().data(); + std::copy(src_data.begin(), src_data.end(), ret_data); + break; + } + case DataType::JSON: { + auto ret_data = + reinterpret_cast(ret.data()); + auto src_data = + target_field_data.scalars().json_data().data(); + std::copy(src_data.begin(), src_data.end(), ret_data); + break; + } + default: { + PanicInfo(Unsupported, "unsupported"); + } } } } @@ -172,13 +200,13 @@ struct GeneratedData { std::unique_ptr get_col(FieldId field_id) const { - for (auto target_field_data : raw_->fields_data()) { + for (const auto& target_field_data : raw_->fields_data()) { if (field_id.get() == target_field_data.field_id()) { return std::make_unique(target_field_data); } } - PanicInfo("field id not find"); + PanicInfo(FieldIDInvalid, "field id not find"); } private: @@ -188,7 +216,8 @@ struct GeneratedData { int64_t N, uint64_t seed, uint64_t ts_offset, - int repeat_count); + int repeat_count, + int array_len); friend GeneratedData DataGenForJsonArray(SchemaPtr schema, int64_t N, @@ -203,7 +232,8 @@ DataGen(SchemaPtr schema, int64_t N, uint64_t seed = 42, uint64_t ts_offset = 0, - int repeat_count = 1) { + int repeat_count = 1, + int array_len = 10) { using std::vector; std::default_random_engine er(seed); std::normal_distribution<> distr(0, 1); @@ -259,6 +289,15 @@ DataGen(SchemaPtr schema, insert_cols(data, N, field_meta); break; } + case DataType::VECTOR_FLOAT16: { + auto dim = field_meta.get_dim(); + vector final(dim * N); + for (auto& x : final) { + x = float16(distr(er) + offset); + } + insert_cols(final, N, field_meta); + break; + } case DataType::BOOL: { FixedVector data(N); for (int i = 0; i < N; ++i) { @@ -329,18 +368,102 @@ DataGen(SchemaPtr schema, case DataType::JSON: { vector data(N); for (int i = 0; i < N / repeat_count; i++) { - auto str = R"({"int":)" + std::to_string(er()) + - R"(,"double":)" + - std::to_string(static_cast(er())) + - R"(,"string":")" + std::to_string(er()) + - R"(","bool": true)" + "}"; + 
auto str = + R"({"int":)" + std::to_string(er()) + R"(,"double":)" + + std::to_string(static_cast(er())) + + R"(,"string":")" + std::to_string(er()) + + R"(","bool": true)" + R"(, "array": [1,2,3])" + "}"; data[i] = str; } insert_cols(data, N, field_meta); break; } + case DataType::ARRAY: { + vector data(N); + switch (field_meta.get_element_type()) { + case DataType::BOOL: { + for (int i = 0; i < N / repeat_count; i++) { + milvus::proto::schema::ScalarField field_data; + + for (int j = 0; j < array_len; j++) { + field_data.mutable_bool_data()->add_data( + static_cast(er())); + } + data[i] = field_data; + } + break; + } + case DataType::INT8: + case DataType::INT16: + case DataType::INT32: { + for (int i = 0; i < N / repeat_count; i++) { + milvus::proto::schema::ScalarField field_data; + + for (int j = 0; j < array_len; j++) { + field_data.mutable_int_data()->add_data( + static_cast(er())); + } + data[i] = field_data; + } + break; + } + case DataType::INT64: { + for (int i = 0; i < N / repeat_count; i++) { + milvus::proto::schema::ScalarField field_data; + for (int j = 0; j < array_len; j++) { + field_data.mutable_long_data()->add_data( + static_cast(er())); + } + data[i] = field_data; + } + break; + } + case DataType::STRING: + case DataType::VARCHAR: { + for (int i = 0; i < N / repeat_count; i++) { + milvus::proto::schema::ScalarField field_data; + + for (int j = 0; j < array_len; j++) { + field_data.mutable_string_data()->add_data( + std::to_string(er())); + } + data[i] = field_data; + } + break; + } + case DataType::FLOAT: { + for (int i = 0; i < N / repeat_count; i++) { + milvus::proto::schema::ScalarField field_data; + + for (int j = 0; j < array_len; j++) { + field_data.mutable_float_data()->add_data( + static_cast(er())); + } + data[i] = field_data; + } + break; + } + case DataType::DOUBLE: { + for (int i = 0; i < N / repeat_count; i++) { + milvus::proto::schema::ScalarField field_data; + + for (int j = 0; j < array_len; j++) { + 
field_data.mutable_double_data()->add_data( + static_cast(er())); + } + data[i] = field_data; + } + break; + } + default: { + throw std::runtime_error("unsupported data type"); + } + } + insert_cols(data, N, field_meta); + break; + } default: { - throw std::runtime_error("unimplemented"); + throw SegcoreError(ErrorCode::NotImplemented, "unimplemented"); } } ++offset; @@ -414,22 +537,24 @@ DataGenForJsonArray(SchemaPtr schema, std::to_string(static_cast(er()))); stringVec.push_back("\"" + std::to_string(er()) + "\""); boolVec.push_back(i % 2 == 0 ? "true" : "false"); - arrayVec.push_back(fmt::format("[{}, {}, {}]", i, i+1, i+2)); + arrayVec.push_back( + fmt::format("[{}, {}, {}]", i, i + 1, i + 2)); } - auto str = R"({"int":[)" + join(intVec, ",") + - R"(],"double":[)" + join(doubleVec, ",") + - R"(],"string":[)" + join(stringVec, ",") + - R"(],"bool": [)" + join(boolVec, ",") + - R"(],"array": [)" + join(arrayVec, ",") + - R"(],"array2": [[1,2], [3,4]])" + "}"; - //std::cout << str << std::endl; + auto str = + R"({"int":[)" + join(intVec, ",") + R"(],"double":[)" + + join(doubleVec, ",") + R"(],"string":[)" + + join(stringVec, ",") + R"(],"bool": [)" + + join(boolVec, ",") + R"(],"array": [)" + + join(arrayVec, ",") + R"(],"array2": [[1,2], [3,4]])" + + R"(,"array3": [[1,2.2,false,"abc"]])" + + R"(,"diff_type_array": [1,2.2,true,"abc"])" + "}"; data[i] = str; } insert_cols(data, N, field_meta); break; } default: { - throw std::runtime_error("unimplemented"); + throw SegcoreError(ErrorCode::NotImplemented, "unimplemented"); } } } @@ -549,6 +674,27 @@ CreateBinaryPlaceholderGroupFromBlob(int64_t num_queries, return raw_group; } +inline auto +CreateFloat16PlaceholderGroup(int64_t num_queries, + int64_t dim, + int64_t seed = 42) { + namespace ser = milvus::proto::common; + ser::PlaceholderGroup raw_group; + auto value = raw_group.add_placeholders(); + value->set_tag("$0"); + value->set_type(ser::PlaceholderType::Float16Vector); + std::normal_distribution dis(0, 1); + 
std::default_random_engine e(seed); + for (int i = 0; i < num_queries; ++i) { + std::vector vec; + for (int d = 0; d < dim; ++d) { + vec.push_back(float16(dis(e))); + } + value->add_values(vec.data(), vec.size() * sizeof(float16)); + } + return raw_group; +} + inline auto SearchResultToVector(const SearchResult& sr) { int64_t num_queries = sr.total_nq_; @@ -564,7 +710,7 @@ SearchResultToVector(const SearchResult& sr) { return result; } -inline json +inline nlohmann::json SearchResultToJson(const SearchResult& sr) { int64_t num_queries = sr.total_nq_; int64_t topk = sr.unity_topK_; @@ -578,7 +724,7 @@ SearchResultToJson(const SearchResult& sr) { } results.emplace_back(std::move(result)); } - return json{results}; + return nlohmann::json{results}; }; inline storage::FieldDataPtr @@ -611,7 +757,7 @@ CreateFieldDataFromDataArray(ssize_t raw_count, break; } default: { - PanicInfo("unsupported"); + PanicInfo(Unsupported, "unsupported"); } } } else { @@ -672,8 +818,17 @@ CreateFieldDataFromDataArray(ssize_t raw_count, createFieldData(data_raw.data(), DataType::JSON, dim); break; } + case DataType::ARRAY: { + auto src_data = data->scalars().array_data().data(); + std::vector data_raw(src_data.size()); + for (int i = 0; i < src_data.size(); i++) { + data_raw[i] = Array(src_data.at(i)); + } + createFieldData(data_raw.data(), DataType::ARRAY, dim); + break; + } default: { - PanicInfo("unsupported"); + PanicInfo(Unsupported, "unsupported"); } } } @@ -747,16 +902,36 @@ SealedCreator(SchemaPtr schema, const GeneratedData& dataset) { inline std::unique_ptr GenVecIndexing(int64_t N, int64_t dim, const float* vec) { - // {knowhere::IndexParams::nprobe, 10}, auto conf = knowhere::Json{{knowhere::meta::METRIC_TYPE, knowhere::metric::L2}, {knowhere::meta::DIM, std::to_string(dim)}, {knowhere::indexparam::NLIST, "1024"}, {knowhere::meta::DEVICE_ID, 0}}; auto database = knowhere::GenDataSet(N, dim, vec); - auto indexing = std::make_unique( - knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, 
knowhere::metric::L2); + milvus::storage::FieldDataMeta field_data_meta{1, 2, 3, 100}; + milvus::storage::IndexMeta index_meta{3, 100, 1000, 1}; + milvus::storage::StorageConfig storage_config; + storage_config.storage_type = "local"; + storage_config.root_path = TestRemotePath; + auto chunk_manager = milvus::storage::CreateChunkManager(storage_config); + milvus::storage::FileManagerContext file_manager_context( + field_data_meta, index_meta, chunk_manager); + auto indexing = std::make_unique( + knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, + knowhere::metric::L2, + knowhere::Version::GetCurrentVersion().VersionNumber(), + file_manager_context); indexing->BuildWithDataset(database, conf); + auto binary_set = indexing->Upload(); + + std::vector index_files; + for (auto& binary : binary_set.binary_map_) { + index_files.emplace_back(binary.first); + } + conf["index_files"] = index_files; + // we need a load stage to use index as the producation does + // knowhere would do some data preparation in this stage + indexing->Load(conf); return indexing; } diff --git a/internal/core/unittest/test_utils/indexbuilder_test_utils.h b/internal/core/unittest/test_utils/indexbuilder_test_utils.h index 9f9d9a092bda3..174cb0fbdc94e 100644 --- a/internal/core/unittest/test_utils/indexbuilder_test_utils.h +++ b/internal/core/unittest/test_utils/indexbuilder_test_utils.h @@ -114,7 +114,10 @@ generate_load_conf(const milvus::IndexType& index_type, std::to_string(0.0002)}, }; } - return knowhere::Json(); + return knowhere::Json{ + {knowhere::meta::METRIC_TYPE, metric_type}, + {knowhere::meta::DIM, std::to_string(DIM)}, + }; } std::vector diff --git a/internal/core/unittest/test_utils/storage_test_utils.h b/internal/core/unittest/test_utils/storage_test_utils.h index 589d52290475e..02a69c0b98934 100644 --- a/internal/core/unittest/test_utils/storage_test_utils.h +++ b/internal/core/unittest/test_utils/storage_test_utils.h @@ -22,6 +22,7 @@ #include "common/LoadInfo.h" #include 
"storage/Types.h" #include "storage/InsertData.h" +#include "storage/ThreadPools.h" using milvus::DataType; using milvus::FieldId; @@ -68,6 +69,7 @@ PrepareInsertBinlog(int64_t collection_id, field_id, FieldBinlogInfo{field_id, static_cast(row_count), + std::vector{int64_t(row_count)}, std::vector{file}}); }; @@ -98,4 +100,40 @@ PrepareInsertBinlog(int64_t collection_id, return load_info; } +std::map +PutFieldData(milvus::storage::ChunkManager* remote_chunk_manager, + const std::vector& buffers, + const std::vector& element_counts, + const std::vector& object_keys, + FieldDataMeta& field_data_meta, + milvus::FieldMeta& field_meta) { + auto& pool = + milvus::ThreadPools::GetThreadPool(milvus::ThreadPoolPriority::MIDDLE); + std::vector>> futures; + AssertInfo(buffers.size() == element_counts.size(), + "inconsistent size of data slices with slice sizes!"); + AssertInfo(buffers.size() == object_keys.size(), + "inconsistent size of data slices with slice names!"); + + for (int64_t i = 0; i < buffers.size(); ++i) { + futures.push_back( + pool.Submit(milvus::storage::EncodeAndUploadFieldSlice, + remote_chunk_manager, + const_cast(buffers[i]), + element_counts[i], + field_data_meta, + field_meta, + object_keys[i])); + } + + std::map remote_paths_to_size; + for (auto& future : futures) { + auto res = future.get(); + remote_paths_to_size[res.first] = res.second; + } + + milvus::storage::ReleaseArrowUnused(); + return remote_paths_to_size; +} + } // namespace diff --git a/internal/datacoord/allocator.go b/internal/datacoord/allocator.go index 6f6e58c3ff257..d4c687bd72984 100644 --- a/internal/datacoord/allocator.go +++ b/internal/datacoord/allocator.go @@ -37,13 +37,13 @@ var _ allocator = (*rootCoordAllocator)(nil) // rootCoordAllocator use RootCoord as allocator type rootCoordAllocator struct { - types.RootCoord + types.RootCoordClient } // newRootCoordAllocator gets an allocator from RootCoord -func newRootCoordAllocator(rootCoordClient types.RootCoord) allocator { +func 
newRootCoordAllocator(rootCoordClient types.RootCoordClient) allocator { return &rootCoordAllocator{ - RootCoord: rootCoordClient, + RootCoordClient: rootCoordClient, } } diff --git a/internal/datacoord/allocator_test.go b/internal/datacoord/allocator_test.go index 4b0a137e85e54..923c2449282b9 100644 --- a/internal/datacoord/allocator_test.go +++ b/internal/datacoord/allocator_test.go @@ -20,13 +20,14 @@ import ( "context" "testing" - "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/pkg/util/paramtable" ) func TestAllocator_Basic(t *testing.T) { paramtable.Init() - ms := newMockRootCoordService() + ms := newMockRootCoordClient() allocator := newRootCoordAllocator(ms) ctx := context.Background() @@ -41,7 +42,7 @@ func TestAllocator_Basic(t *testing.T) { }) t.Run("Test Unhealthy Root", func(t *testing.T) { - ms := newMockRootCoordService() + ms := newMockRootCoordClient() allocator := newRootCoordAllocator(ms) err := ms.Stop() assert.NoError(t, err) diff --git a/internal/datacoord/build_index_policy.go b/internal/datacoord/build_index_policy.go index 0446e50431698..9397378f9c927 100644 --- a/internal/datacoord/build_index_policy.go +++ b/internal/datacoord/build_index_policy.go @@ -24,5 +24,4 @@ func defaultBuildIndexPolicy(buildIDs []UniqueID) { sort.Slice(buildIDs, func(i, j int) bool { return buildIDs[i] < buildIDs[j] }) - } diff --git a/internal/datacoord/channel_checker.go b/internal/datacoord/channel_checker.go index 6ef7dd41763b7..9ab1555b72cf1 100644 --- a/internal/datacoord/channel_checker.go +++ b/internal/datacoord/channel_checker.go @@ -41,8 +41,8 @@ type channelStateTimer struct { etcdWatcher clientv3.WatchChan timeoutWatcher chan *ackEvent - //Modifies afterwards must guarantee that runningTimerCount is updated synchronized with runningTimers - //in order to keep consistency + // Modifies afterwards must guarantee that runningTimerCount is updated synchronized with runningTimers + // 
in order to keep consistency runningTimerCount atomic.Int32 } @@ -185,7 +185,6 @@ func parseWatchInfo(key string, data []byte) (*datapb.ChannelWatchInfo, error) { watchInfo := datapb.ChannelWatchInfo{} if err := proto.Unmarshal(data, &watchInfo); err != nil { return nil, fmt.Errorf("invalid event data: fail to parse ChannelWatchInfo, key: %s, err: %v", key, err) - } if watchInfo.Vchan == nil { diff --git a/internal/datacoord/channel_checker_test.go b/internal/datacoord/channel_checker_test.go index a15cd4bd5f780..5ed5e900f2851 100644 --- a/internal/datacoord/channel_checker_test.go +++ b/internal/datacoord/channel_checker_test.go @@ -21,11 +21,11 @@ import ( "testing" "time" - "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/golang/protobuf/proto" + "github.com/milvus-io/milvus/internal/proto/datapb" ) func TestChannelStateTimer(t *testing.T) { @@ -242,6 +242,5 @@ func TestChannelStateTimer_parses(t *testing.T) { for _, test := range tests { assert.Equal(t, test.outAckType, getAckType(test.inState)) } - }) } diff --git a/internal/datacoord/channel_manager.go b/internal/datacoord/channel_manager.go index daa078e5bc97c..0d77425c08b97 100644 --- a/internal/datacoord/channel_manager.go +++ b/internal/datacoord/channel_manager.go @@ -431,7 +431,8 @@ func (c *ChannelManager) unsubAttempt(ncInfo *NodeChannelInfo) { } // Watch tries to add the channel to cluster. Watch is a no op if the channel already exists. -func (c *ChannelManager) Watch(ch *channel) error { +func (c *ChannelManager) Watch(ctx context.Context, ch *channel) error { + log := log.Ctx(ctx) c.mu.Lock() defer c.mu.Unlock() @@ -467,7 +468,7 @@ func (c *ChannelManager) fillChannelWatchInfo(op *ChannelOp) { // fillChannelWatchInfoWithState updates the channel op by filling in channel watch info. 
func (c *ChannelManager) fillChannelWatchInfoWithState(op *ChannelOp, state datapb.ChannelWatchState) []string { - var channelsWithTimer = []string{} + channelsWithTimer := []string{} startTs := time.Now().Unix() checkInterval := Params.DataCoordCfg.WatchTimeoutInterval.GetAsDuration(time.Second) for _, ch := range op.Channels { @@ -490,14 +491,27 @@ func (c *ChannelManager) fillChannelWatchInfoWithState(op *ChannelOp, state data return channelsWithTimer } -// GetChannels gets channels info of registered nodes. -func (c *ChannelManager) GetChannels() []*NodeChannelInfo { +// GetAssignedChannels gets channels info of registered nodes. +func (c *ChannelManager) GetAssignedChannels() []*NodeChannelInfo { c.mu.RLock() defer c.mu.RUnlock() return c.store.GetNodesChannels() } +func (c *ChannelManager) GetChannelsByCollectionID(collectionID UniqueID) []*channel { + channels := make([]*channel, 0) + for _, nodeChannels := range c.store.GetChannels() { + for _, channelInfo := range nodeChannels.Channels { + if collectionID == channelInfo.CollectionID { + channels = append(channels, channelInfo) + } + } + } + log.Info("get channel", zap.Any("collection", collectionID), zap.Any("channel", channels)) + return channels +} + // GetBufferChannels gets buffer channels. func (c *ChannelManager) GetBufferChannels() *NodeChannelInfo { c.mu.RLock() @@ -607,7 +621,7 @@ type ackEvent struct { } func (c *ChannelManager) updateWithTimer(updates ChannelOpSet, state datapb.ChannelWatchState) error { - var channelsWithTimer = []string{} + channelsWithTimer := []string{} for _, op := range updates { if op.Type == Add { channelsWithTimer = append(channelsWithTimer, c.fillChannelWatchInfoWithState(op, state)...) @@ -765,17 +779,25 @@ func (c *ChannelManager) Release(nodeID UniqueID, channelName string) error { // Reassign reassigns a channel to another DataNode. 
func (c *ChannelManager) Reassign(originNodeID UniqueID, channelName string) error { - c.mu.Lock() - defer c.mu.Unlock() - + c.mu.RLock() ch := c.getChannelByNodeAndName(originNodeID, channelName) if ch == nil { + c.mu.RUnlock() return fmt.Errorf("fail to find matching nodeID: %d with channelName: %s", originNodeID, channelName) } + c.mu.RUnlock() reallocates := &NodeChannelInfo{originNodeID, []*channel{ch}} + isDropped := c.isMarkedDrop(channelName, ch.CollectionID) - if c.isMarkedDrop(channelName, ch.CollectionID) { + c.mu.Lock() + defer c.mu.Unlock() + ch = c.getChannelByNodeAndName(originNodeID, channelName) + if ch == nil { + return fmt.Errorf("fail to find matching nodeID: %d with channelName: %s", originNodeID, channelName) + } + + if isDropped { if err := c.remove(originNodeID, ch); err != nil { return fmt.Errorf("failed to remove watch info: %v,%s", ch, err.Error()) } @@ -804,13 +826,13 @@ func (c *ChannelManager) Reassign(originNodeID UniqueID, channelName string) err // CleanupAndReassign tries to clean up datanode's subscription, and then reassigns the channel to another DataNode. 
func (c *ChannelManager) CleanupAndReassign(nodeID UniqueID, channelName string) error { - c.mu.Lock() - defer c.mu.Unlock() - + c.mu.RLock() chToCleanUp := c.getChannelByNodeAndName(nodeID, channelName) if chToCleanUp == nil { + c.mu.RUnlock() return fmt.Errorf("failed to find matching channel: %s and node: %d", channelName, nodeID) } + c.mu.RUnlock() if c.msgstreamFactory == nil { log.Warn("msgstream factory is not set, unable to clean up topics") @@ -821,8 +843,16 @@ func (c *ChannelManager) CleanupAndReassign(nodeID UniqueID, channelName string) } reallocates := &NodeChannelInfo{nodeID, []*channel{chToCleanUp}} + isDropped := c.isMarkedDrop(channelName, chToCleanUp.CollectionID) + + c.mu.Lock() + defer c.mu.Unlock() + chToCleanUp = c.getChannelByNodeAndName(nodeID, channelName) + if chToCleanUp == nil { + return fmt.Errorf("failed to find matching channel: %s and node: %d", channelName, nodeID) + } - if c.isMarkedDrop(channelName, chToCleanUp.CollectionID) { + if isDropped { if err := c.remove(nodeID, chToCleanUp); err != nil { return fmt.Errorf("failed to remove watch info: %v,%s", chToCleanUp, err.Error()) } @@ -870,7 +900,7 @@ func (c *ChannelManager) getChannelByNodeAndName(nodeID UniqueID, channelName st } func (c *ChannelManager) getNodeIDByChannelName(chName string) (bool, UniqueID) { - for _, nodeChannel := range c.GetChannels() { + for _, nodeChannel := range c.GetAssignedChannels() { for _, ch := range nodeChannel.Channels { if ch.Name == chName { return true, nodeChannel.NodeID diff --git a/internal/datacoord/channel_manager_test.go b/internal/datacoord/channel_manager_test.go index 799308b3b2a60..14b14d3a421dd 100644 --- a/internal/datacoord/channel_manager_test.go +++ b/internal/datacoord/channel_manager_test.go @@ -26,6 +26,7 @@ import ( "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/milvus-io/milvus/internal/kv" @@ -120,7 +121,7 @@ func 
TestChannelManager_StateTransfer(t *testing.T) { }() chManager.AddNode(nodeID) - chManager.Watch(&channel{Name: cName, CollectionID: collectionID}) + chManager.Watch(ctx, &channel{Name: cName, CollectionID: collectionID}) key := path.Join(prefix, strconv.FormatInt(nodeID, 10), cName) waitAndStore(t, watchkv, key, datapb.ChannelWatchState_ToWatch, datapb.ChannelWatchState_WatchSuccess) @@ -150,7 +151,7 @@ func TestChannelManager_StateTransfer(t *testing.T) { }() chManager.AddNode(nodeID) - chManager.Watch(&channel{Name: cName, CollectionID: collectionID}) + chManager.Watch(ctx, &channel{Name: cName, CollectionID: collectionID}) key := path.Join(prefix, strconv.FormatInt(nodeID, 10), cName) waitAndStore(t, watchkv, key, datapb.ChannelWatchState_ToWatch, datapb.ChannelWatchState_WatchFailure) @@ -181,7 +182,7 @@ func TestChannelManager_StateTransfer(t *testing.T) { }() chManager.AddNode(nodeID) - chManager.Watch(&channel{Name: cName, CollectionID: collectionID}) + chManager.Watch(ctx, &channel{Name: cName, CollectionID: collectionID}) // simulating timeout behavior of startOne, cuz 20s is a long wait e := &ackEvent{ @@ -203,7 +204,7 @@ func TestChannelManager_StateTransfer(t *testing.T) { }) t.Run("ToRelease-ReleaseSuccess-Reassign-ToWatch-2-DN", func(t *testing.T) { - var oldNode = UniqueID(120) + oldNode := UniqueID(120) cName := channelNamePrefix + "ToRelease-ReleaseSuccess-Reassign-ToWatch-2-DN" watchkv.RemoveWithPrefix("") @@ -289,7 +290,7 @@ func TestChannelManager_StateTransfer(t *testing.T) { }) t.Run("ToRelease-ReleaseFail-CleanUpAndDelete-Reassign-ToWatch-2-DN", func(t *testing.T) { - var oldNode = UniqueID(121) + oldNode := UniqueID(121) cName := channelNamePrefix + "ToRelease-ReleaseFail-CleanUpAndDelete-Reassign-ToWatch-2-DN" watchkv.RemoveWithPrefix("") @@ -417,7 +418,7 @@ func TestChannelManager(t *testing.T) { assert.False(t, chManager.Match(nodeToAdd, channel1)) assert.False(t, chManager.Match(nodeToAdd, channel2)) - err = 
chManager.Watch(&channel{Name: "channel-3", CollectionID: collectionID}) + err = chManager.Watch(context.TODO(), &channel{Name: "channel-3", CollectionID: collectionID}) assert.NoError(t, err) assert.True(t, chManager.Match(nodeToAdd, "channel-3")) @@ -459,7 +460,7 @@ func TestChannelManager(t *testing.T) { assert.True(t, chManager.Match(nodeID, channel1)) assert.True(t, chManager.Match(nodeID, channel2)) - err = chManager.Watch(&channel{Name: "channel-3", CollectionID: collectionID}) + err = chManager.Watch(context.TODO(), &channel{Name: "channel-3", CollectionID: collectionID}) assert.NoError(t, err) waitAndCheckState(t, watchkv, datapb.ChannelWatchState_ToWatch, nodeID, "channel-3", collectionID) @@ -478,13 +479,13 @@ func TestChannelManager(t *testing.T) { chManager, err := NewChannelManager(watchkv, newMockHandler()) require.NoError(t, err) - err = chManager.Watch(&channel{Name: bufferCh, CollectionID: collectionID}) + err = chManager.Watch(context.TODO(), &channel{Name: bufferCh, CollectionID: collectionID}) assert.NoError(t, err) waitAndCheckState(t, watchkv, datapb.ChannelWatchState_ToWatch, bufferID, bufferCh, collectionID) chManager.store.Add(nodeID) - err = chManager.Watch(&channel{Name: chanToAdd, CollectionID: collectionID}) + err = chManager.Watch(context.TODO(), &channel{Name: chanToAdd, CollectionID: collectionID}) assert.NoError(t, err) waitAndCheckState(t, watchkv, datapb.ChannelWatchState_ToWatch, nodeID, chanToAdd, collectionID) @@ -520,7 +521,7 @@ func TestChannelManager(t *testing.T) { t.Run("test Reassign", func(t *testing.T) { defer watchkv.RemoveWithPrefix("") - var collectionID = UniqueID(5) + collectionID := UniqueID(5) tests := []struct { nodeID UniqueID @@ -568,12 +569,117 @@ func TestChannelManager(t *testing.T) { waitAndCheckState(t, watchkv, datapb.ChannelWatchState_ToWatch, remainTest.nodeID, remainTest.chName, collectionID) }) + t.Run("test Reassign with get channel fail", func(t *testing.T) { + chManager, err := 
NewChannelManager(watchkv, newMockHandler()) + require.NoError(t, err) + + err = chManager.Reassign(1, "not-exists-channelName") + assert.Error(t, err) + }) + + t.Run("test Reassign with dropped channel", func(t *testing.T) { + collectionID := UniqueID(5) + handler := NewNMockHandler(t) + handler.EXPECT(). + CheckShouldDropChannel(mock.Anything, mock.Anything). + Return(true) + handler.EXPECT().FinishDropChannel(mock.Anything).Return(nil) + chManager, err := NewChannelManager(watchkv, handler) + require.NoError(t, err) + + chManager.store.Add(1) + ops := getOpsWithWatchInfo(1, &channel{Name: "chan", CollectionID: collectionID}) + err = chManager.store.Update(ops) + require.NoError(t, err) + + assert.Equal(t, 1, chManager.store.GetNodeChannelCount(1)) + err = chManager.Reassign(1, "chan") + assert.NoError(t, err) + assert.Equal(t, 0, chManager.store.GetNodeChannelCount(1)) + }) + + t.Run("test Reassign-channel not found", func(t *testing.T) { + var chManager *ChannelManager + var err error + handler := NewNMockHandler(t) + handler.EXPECT(). + CheckShouldDropChannel(mock.Anything, mock.Anything). + Run(func(channel string, collectionID int64) { + channels, err := chManager.store.Delete(1) + assert.NoError(t, err) + assert.Equal(t, 1, len(channels)) + }).Return(true).Once() + + chManager, err = NewChannelManager(watchkv, handler) + require.NoError(t, err) + + chManager.store.Add(1) + ops := getOpsWithWatchInfo(1, &channel{Name: "chan", CollectionID: 1}) + err = chManager.store.Update(ops) + require.NoError(t, err) + + assert.Equal(t, 1, chManager.store.GetNodeChannelCount(1)) + err = chManager.Reassign(1, "chan") + assert.Error(t, err) + }) + + t.Run("test CleanupAndReassign-channel not found", func(t *testing.T) { + var chManager *ChannelManager + var err error + handler := NewNMockHandler(t) + handler.EXPECT(). + CheckShouldDropChannel(mock.Anything, mock.Anything). 
+ Run(func(channel string, collectionID int64) { + channels, err := chManager.store.Delete(1) + assert.NoError(t, err) + assert.Equal(t, 1, len(channels)) + }).Return(true).Once() + + chManager, err = NewChannelManager(watchkv, handler) + require.NoError(t, err) + + chManager.store.Add(1) + ops := getOpsWithWatchInfo(1, &channel{Name: "chan", CollectionID: 1}) + err = chManager.store.Update(ops) + require.NoError(t, err) + + assert.Equal(t, 1, chManager.store.GetNodeChannelCount(1)) + err = chManager.CleanupAndReassign(1, "chan") + assert.Error(t, err) + }) + + t.Run("test CleanupAndReassign with get channel fail", func(t *testing.T) { + chManager, err := NewChannelManager(watchkv, newMockHandler()) + require.NoError(t, err) + + err = chManager.CleanupAndReassign(1, "not-exists-channelName") + assert.Error(t, err) + }) + + t.Run("test CleanupAndReassign with dropped channel", func(t *testing.T) { + handler := NewNMockHandler(t) + handler.EXPECT(). + CheckShouldDropChannel(mock.Anything, mock.Anything). 
+ Return(true) + handler.EXPECT().FinishDropChannel(mock.Anything).Return(nil) + chManager, err := NewChannelManager(watchkv, handler) + require.NoError(t, err) + + chManager.store.Add(1) + ops := getOpsWithWatchInfo(1, &channel{Name: "chan", CollectionID: 1}) + err = chManager.store.Update(ops) + require.NoError(t, err) + + assert.Equal(t, 1, chManager.store.GetNodeChannelCount(1)) + err = chManager.CleanupAndReassign(1, "chan") + assert.NoError(t, err) + assert.Equal(t, 0, chManager.store.GetNodeChannelCount(1)) + }) + t.Run("test DeleteNode", func(t *testing.T) { defer watchkv.RemoveWithPrefix("") - var ( - collectionID = UniqueID(999) - ) + collectionID := UniqueID(999) chManager, err := NewChannelManager(watchkv, newMockHandler(), withStateChecker()) require.NoError(t, err) chManager.store = &ChannelStore{ @@ -581,7 +687,8 @@ func TestChannelManager(t *testing.T) { channelsInfo: map[int64]*NodeChannelInfo{ 1: {1, []*channel{ {Name: "channel-1", CollectionID: collectionID}, - {Name: "channel-2", CollectionID: collectionID}}}, + {Name: "channel-2", CollectionID: collectionID}, + }}, bufferID: {bufferID, []*channel{}}, }, } @@ -596,7 +703,7 @@ func TestChannelManager(t *testing.T) { t.Run("test CleanupAndReassign", func(t *testing.T) { defer watchkv.RemoveWithPrefix("") - var collectionID = UniqueID(6) + collectionID := UniqueID(6) tests := []struct { nodeID UniqueID @@ -745,7 +852,7 @@ func TestChannelManager(t *testing.T) { ) cName := channelNamePrefix + "TestBgChecker" - //1. set up channel_manager + // 1. set up channel_manager ctx, cancel := context.WithCancel(context.TODO()) defer cancel() chManager, err := NewChannelManager(watchkv, newMockHandler(), withBgChecker()) @@ -753,20 +860,20 @@ func TestChannelManager(t *testing.T) { assert.NotNil(t, chManager.bgChecker) chManager.Startup(ctx, []int64{nodeID}) - //2. test isSilent function running correctly + // 2. 
test isSilent function running correctly Params.Save(Params.DataCoordCfg.ChannelBalanceSilentDuration.Key, "3") assert.False(t, chManager.isSilent()) assert.False(t, chManager.stateTimer.hasRunningTimers()) - //3. watch one channel - chManager.Watch(&channel{Name: cName, CollectionID: collectionID}) + // 3. watch one channel + chManager.Watch(ctx, &channel{Name: cName, CollectionID: collectionID}) assert.False(t, chManager.isSilent()) assert.True(t, chManager.stateTimer.hasRunningTimers()) key := path.Join(prefix, strconv.FormatInt(nodeID, 10), cName) waitAndStore(t, watchkv, key, datapb.ChannelWatchState_ToWatch, datapb.ChannelWatchState_WatchSuccess) waitAndCheckState(t, watchkv, datapb.ChannelWatchState_WatchSuccess, nodeID, cName, collectionID) - //4. wait for duration and check silent again + // 4. wait for duration and check silent again time.Sleep(Params.DataCoordCfg.ChannelBalanceSilentDuration.GetAsDuration(time.Second)) chManager.stateTimer.removeTimers([]string{cName}) assert.True(t, chManager.isSilent()) @@ -839,7 +946,8 @@ func TestChannelManager_Reload(t *testing.T) { chManager.store = &ChannelStore{ store: watchkv, channelsInfo: map[int64]*NodeChannelInfo{ - nodeID: {nodeID, []*channel{{Name: channelName, CollectionID: collectionID}}}}, + nodeID: {nodeID, []*channel{{Name: channelName, CollectionID: collectionID}}}, + }, } data, err := proto.Marshal(getWatchInfoWithState(datapb.ChannelWatchState_WatchFailure, collectionID, channelName)) @@ -861,7 +969,8 @@ func TestChannelManager_Reload(t *testing.T) { chManager.store = &ChannelStore{ store: watchkv, channelsInfo: map[int64]*NodeChannelInfo{ - nodeID: {nodeID, []*channel{{Name: channelName, CollectionID: collectionID}}}}, + nodeID: {nodeID, []*channel{{Name: channelName, CollectionID: collectionID}}}, + }, } require.NoError(t, err) @@ -902,7 +1011,6 @@ func TestChannelManager_Reload(t *testing.T) { v, err := watchkv.Load(path.Join(prefix, strconv.FormatInt(nodeID, 10), channelName)) assert.Error(t, 
err) assert.Empty(t, v) - }) }) @@ -958,9 +1066,7 @@ func TestChannelManager_BalanceBehaviour(t *testing.T) { t.Run("one node with three channels add a new node", func(t *testing.T) { defer watchkv.RemoveWithPrefix("") - var ( - collectionID = UniqueID(999) - ) + collectionID := UniqueID(999) chManager, err := NewChannelManager(watchkv, newMockHandler(), withStateChecker()) require.NoError(t, err) @@ -976,12 +1082,12 @@ func TestChannelManager_BalanceBehaviour(t *testing.T) { 1: {1, []*channel{ {Name: "channel-1", CollectionID: collectionID}, {Name: "channel-2", CollectionID: collectionID}, - {Name: "channel-3", CollectionID: collectionID}}}}, + {Name: "channel-3", CollectionID: collectionID}, + }}, + }, } - var ( - channelBalanced string - ) + var channelBalanced string chManager.AddNode(2) channelBalanced = "channel-1" @@ -1018,7 +1124,7 @@ func TestChannelManager_BalanceBehaviour(t *testing.T) { assert.True(t, chManager.Match(2, "channel-1")) chManager.AddNode(3) - chManager.Watch(&channel{Name: "channel-4", CollectionID: collectionID}) + chManager.Watch(ctx, &channel{Name: "channel-4", CollectionID: collectionID}) key = path.Join(prefix, "3", "channel-4") waitAndStore(t, watchkv, key, datapb.ChannelWatchState_ToWatch, datapb.ChannelWatchState_WatchSuccess) @@ -1047,7 +1153,6 @@ func TestChannelManager_BalanceBehaviour(t *testing.T) { assert.True(t, chManager.Match(1, "channel-1")) assert.True(t, chManager.Match(1, "channel-4")) }) - } func TestChannelManager_RemoveChannel(t *testing.T) { @@ -1153,6 +1258,5 @@ func TestChannelManager_HelperFunc(t *testing.T) { assert.ElementsMatch(t, test.expectedOut, nodes) }) } - }) } diff --git a/internal/datacoord/channel_store_test.go b/internal/datacoord/channel_store_test.go index ace7e086f888c..22f545fb40131 100644 --- a/internal/datacoord/channel_store_test.go +++ b/internal/datacoord/channel_store_test.go @@ -28,6 +28,7 @@ import ( "github.com/milvus-io/milvus/internal/kv" 
"github.com/milvus-io/milvus/internal/kv/mocks" + "github.com/milvus-io/milvus/internal/kv/predicates" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/testutils" @@ -73,7 +74,7 @@ func genChannelOperations(from, to int64, num int) ChannelOpSet { func TestChannelStore_Update(t *testing.T) { txnKv := mocks.NewTxnKV(t) - txnKv.EXPECT().MultiSaveAndRemove(mock.Anything, mock.Anything).Run(func(saves map[string]string, removals []string) { + txnKv.EXPECT().MultiSaveAndRemove(mock.Anything, mock.Anything).Run(func(saves map[string]string, removals []string, preds ...predicates.Predicate) { assert.False(t, len(saves)+len(removals) > 128, "too many operations") }).Return(nil) diff --git a/internal/datacoord/cluster.go b/internal/datacoord/cluster.go index 9e137134bae63..dae6698065ac4 100644 --- a/internal/datacoord/cluster.go +++ b/internal/datacoord/cluster.go @@ -20,13 +20,14 @@ import ( "context" "fmt" + "github.com/samber/lo" + "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/samber/lo" - "go.uber.org/zap" ) // Cluster provides interfaces to interact with datanode cluster @@ -70,14 +71,15 @@ func (c *Cluster) UnRegister(node *NodeInfo) error { } // Watch tries to add a channel in datanode cluster -func (c *Cluster) Watch(ch string, collectionID UniqueID) error { - return c.channelManager.Watch(&channel{Name: ch, CollectionID: collectionID}) +func (c *Cluster) Watch(ctx context.Context, ch string, collectionID UniqueID) error { + return c.channelManager.Watch(ctx, &channel{Name: ch, CollectionID: collectionID}) } // Flush sends flush requests to dataNodes specified // which also according to channels where segments are assigned to. 
func (c *Cluster) Flush(ctx context.Context, nodeID int64, channel string, - segments []*datapb.SegmentInfo) error { + segments []*datapb.SegmentInfo, +) error { if !c.channelManager.Match(nodeID, channel) { log.Warn("node is not matched with channel", zap.String("channel", channel), @@ -100,26 +102,39 @@ func (c *Cluster) Flush(ctx context.Context, nodeID int64, channel string, ), CollectionID: ch.CollectionID, SegmentIDs: lo.Map(segments, getSegmentID), + ChannelName: channel, } c.sessionManager.Flush(ctx, nodeID, req) return nil } -// Import sends import requests to DataNodes whose ID==nodeID. -func (c *Cluster) Import(ctx context.Context, nodeID int64, it *datapb.ImportTaskRequest) { - c.sessionManager.Import(ctx, nodeID, it) -} +func (c *Cluster) FlushChannels(ctx context.Context, nodeID int64, flushTs Timestamp, channels []string) error { + if len(channels) == 0 { + return nil + } -// ReCollectSegmentStats triggers a ReCollectSegmentStats call from session manager. -func (c *Cluster) ReCollectSegmentStats(ctx context.Context) error { - for _, node := range c.sessionManager.getLiveNodeIDs() { - err := c.sessionManager.ReCollectSegmentStats(ctx, node) - if err != nil { - return err + for _, channel := range channels { + if !c.channelManager.Match(nodeID, channel) { + return fmt.Errorf("channel %s is not watched on node %d", channel, nodeID) } } - return nil + + req := &datapb.FlushChannelsRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithSourceID(paramtable.GetNodeID()), + commonpbutil.WithTargetID(nodeID), + ), + FlushTs: flushTs, + Channels: channels, + } + + return c.sessionManager.FlushChannels(ctx, nodeID, req) +} + +// Import sends import requests to DataNodes whose ID==nodeID. 
+func (c *Cluster) Import(ctx context.Context, nodeID int64, it *datapb.ImportTaskRequest) { + c.sessionManager.Import(ctx, nodeID, it) } // GetSessions returns all sessions diff --git a/internal/datacoord/cluster_test.go b/internal/datacoord/cluster_test.go index e9dd11d929981..2df57c6081719 100644 --- a/internal/datacoord/cluster_test.go +++ b/internal/datacoord/cluster_test.go @@ -132,7 +132,7 @@ func (suite *ClusterSuite) TestCreate() { err = cluster.Startup(ctx, []*NodeInfo{{NodeID: 1, Address: "localhost:9999"}}) suite.NoError(err) - channels := channelManager.GetChannels() + channels := channelManager.GetAssignedChannels() suite.EqualValues([]*NodeChannelInfo{{1, []*channel{{Name: "channel1", CollectionID: 1}}}}, channels) }) @@ -181,7 +181,7 @@ func (suite *ClusterSuite) TestCreate() { suite.EqualValues(1, len(sessions)) suite.EqualValues(2, sessions[0].info.NodeID) suite.EqualValues(addr, sessions[0].info.Address) - channels := channelManager2.GetChannels() + channels := channelManager2.GetAssignedChannels() suite.EqualValues(1, len(channels)) suite.EqualValues(2, channels[0].NodeID) }) @@ -235,7 +235,7 @@ func (suite *ClusterSuite) TestRegister() { sessionManager := NewSessionManager() channelManager, err := NewChannelManager(kv, newMockHandler()) suite.NoError(err) - err = channelManager.Watch(&channel{ + err = channelManager.Watch(context.TODO(), &channel{ Name: "ch1", CollectionID: 0, }) @@ -253,7 +253,7 @@ func (suite *ClusterSuite) TestRegister() { suite.NoError(err) bufferChannels := channelManager.GetBufferChannels() suite.Empty(bufferChannels.Channels) - nodeChannels := channelManager.GetChannels() + nodeChannels := channelManager.GetAssignedChannels() suite.EqualValues(1, len(nodeChannels)) suite.EqualValues(1, nodeChannels[0].NodeID) suite.EqualValues("ch1", nodeChannels[0].Channels[0].Name) @@ -287,7 +287,7 @@ func (suite *ClusterSuite) TestRegister() { suite.NoError(err) restartCluster := NewCluster(sessionManager2, channelManager2) defer 
restartCluster.Close() - channels := channelManager2.GetChannels() + channels := channelManager2.GetAssignedChannels() suite.Empty(channels) suite.MetricsEqual(metrics.DataCoordNumDataNodes, 1) @@ -347,12 +347,12 @@ func (suite *ClusterSuite) TestUnregister() { nodes := []*NodeInfo{nodeInfo1, nodeInfo2} err = cluster.Startup(ctx, nodes) suite.NoError(err) - err = cluster.Watch("ch1", 1) + err = cluster.Watch(ctx, "ch1", 1) suite.NoError(err) err = cluster.UnRegister(nodeInfo1) suite.NoError(err) - channels := channelManager.GetChannels() + channels := channelManager.GetAssignedChannels() suite.EqualValues(1, len(channels)) suite.EqualValues(2, channels[0].NodeID) suite.EqualValues(1, len(channels[0].Channels)) @@ -367,7 +367,7 @@ func (suite *ClusterSuite) TestUnregister() { ctx, cancel := context.WithCancel(context.TODO()) defer cancel() - var mockSessionCreator = func(ctx context.Context, addr string, nodeID int64) (types.DataNode, error) { + mockSessionCreator := func(ctx context.Context, addr string, nodeID int64) (types.DataNodeClient, error) { return newMockDataNodeClient(1, nil) } sessionManager := NewSessionManager(withSessionCreator(mockSessionCreator)) @@ -382,11 +382,11 @@ func (suite *ClusterSuite) TestUnregister() { } err = cluster.Startup(ctx, []*NodeInfo{nodeInfo}) suite.NoError(err) - err = cluster.Watch("ch_1", 1) + err = cluster.Watch(ctx, "ch_1", 1) suite.NoError(err) err = cluster.UnRegister(nodeInfo) suite.NoError(err) - channels := channelManager.GetChannels() + channels := channelManager.GetAssignedChannels() suite.Empty(channels) channel := channelManager.GetBufferChannels() suite.NotNil(channel) @@ -414,7 +414,7 @@ func TestWatchIfNeeded(t *testing.T) { ctx, cancel := context.WithCancel(context.TODO()) defer cancel() - var mockSessionCreator = func(ctx context.Context, addr string, nodeID int64) (types.DataNode, error) { + mockSessionCreator := func(ctx context.Context, addr string, nodeID int64) (types.DataNodeClient, error) { return 
newMockDataNodeClient(1, nil) } sessionManager := NewSessionManager(withSessionCreator(mockSessionCreator)) @@ -431,9 +431,9 @@ func TestWatchIfNeeded(t *testing.T) { err = cluster.Startup(ctx, []*NodeInfo{info}) assert.NoError(t, err) - err = cluster.Watch("ch1", 1) + err = cluster.Watch(ctx, "ch1", 1) assert.NoError(t, err) - channels := channelManager.GetChannels() + channels := channelManager.GetAssignedChannels() assert.EqualValues(t, 1, len(channels)) assert.EqualValues(t, "ch1", channels[0].Channels[0].Name) }) @@ -441,16 +441,18 @@ func TestWatchIfNeeded(t *testing.T) { t.Run("watch channel to empty cluster", func(t *testing.T) { defer kv.RemoveWithPrefix("") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() sessionManager := NewSessionManager() channelManager, err := NewChannelManager(kv, newMockHandler()) assert.NoError(t, err) cluster := NewCluster(sessionManager, channelManager) defer cluster.Close() - err = cluster.Watch("ch1", 1) + err = cluster.Watch(ctx, "ch1", 1) assert.NoError(t, err) - channels := channelManager.GetChannels() + channels := channelManager.GetAssignedChannels() assert.Empty(t, channels) channel := channelManager.GetBufferChannels() assert.NotNil(t, channel) @@ -499,7 +501,7 @@ func TestConsistentHashPolicy(t *testing.T) { channels := []string{"ch1", "ch2", "ch3"} for _, c := range channels { - err = cluster.Watch(c, 1) + err = cluster.Watch(context.TODO(), c, 1) assert.NoError(t, err) idstr, err := hash.Get(c) assert.NoError(t, err) @@ -563,7 +565,7 @@ func TestCluster_Flush(t *testing.T) { err = cluster.Startup(ctx, nodes) assert.NoError(t, err) - err = cluster.Watch("chan-1", 1) + err = cluster.Watch(context.Background(), "chan-1", 1) assert.NoError(t, err) // flush empty should impact nothing @@ -584,7 +586,7 @@ func TestCluster_Flush(t *testing.T) { assert.Error(t, err) }) - //TODO add a method to verify datanode has flush request after client injection is available + // TODO add a method to verify 
datanode has flush request after client injection is available } func TestCluster_Import(t *testing.T) { @@ -610,7 +612,7 @@ func TestCluster_Import(t *testing.T) { err = cluster.Startup(ctx, nodes) assert.NoError(t, err) - err = cluster.Watch("chan-1", 1) + err = cluster.Watch(ctx, "chan-1", 1) assert.NoError(t, err) assert.NotPanics(t, func() { @@ -618,66 +620,3 @@ func TestCluster_Import(t *testing.T) { }) time.Sleep(500 * time.Millisecond) } - -func TestCluster_ReCollectSegmentStats(t *testing.T) { - kv := getWatchKV(t) - defer func() { - kv.RemoveWithPrefix("") - kv.Close() - }() - - t.Run("recollect succeed", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - var mockSessionCreator = func(ctx context.Context, addr string, nodeID int64) (types.DataNode, error) { - return newMockDataNodeClient(1, nil) - } - sessionManager := NewSessionManager(withSessionCreator(mockSessionCreator)) - channelManager, err := NewChannelManager(kv, newMockHandler()) - assert.NoError(t, err) - cluster := NewCluster(sessionManager, channelManager) - defer cluster.Close() - addr := "localhost:8080" - info := &NodeInfo{ - Address: addr, - NodeID: 1, - } - nodes := []*NodeInfo{info} - err = cluster.Startup(ctx, nodes) - assert.NoError(t, err) - - err = cluster.Watch("chan-1", 1) - assert.NoError(t, err) - - assert.NotPanics(t, func() { - cluster.ReCollectSegmentStats(ctx) - }) - time.Sleep(500 * time.Millisecond) - }) - - t.Run("recollect failed", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - sessionManager := NewSessionManager() - channelManager, err := NewChannelManager(kv, newMockHandler()) - assert.NoError(t, err) - cluster := NewCluster(sessionManager, channelManager) - defer cluster.Close() - addr := "localhost:8080" - info := &NodeInfo{ - Address: addr, - NodeID: 1, - } - nodes := []*NodeInfo{info} - err = cluster.Startup(ctx, nodes) - assert.NoError(t, err) - - err = cluster.Watch("chan-1", 1) 
- assert.NoError(t, err) - - assert.NotPanics(t, func() { - cluster.ReCollectSegmentStats(ctx) - }) - time.Sleep(500 * time.Millisecond) - }) -} diff --git a/internal/datacoord/compaction.go b/internal/datacoord/compaction.go index 221fa9bfe4ff5..bece81e5bd11d 100644 --- a/internal/datacoord/compaction.go +++ b/internal/datacoord/compaction.go @@ -23,13 +23,13 @@ import ( "time" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus/pkg/util/tsoutil" "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/util/tsoutil" ) // TODO this num should be determined by resources of datanode, for now, we set to a fixed value for simple @@ -102,12 +102,13 @@ type compactionPlanHandler struct { quit chan struct{} wg sync.WaitGroup flushCh chan UniqueID - //segRefer *SegmentReferenceManager + // segRefer *SegmentReferenceManager parallelCh map[int64]chan struct{} } func newCompactionPlanHandler(sessions *SessionManager, cm *ChannelManager, meta *meta, - allocator allocator, flush chan UniqueID) *compactionPlanHandler { + allocator allocator, flush chan UniqueID, +) *compactionPlanHandler { return &compactionPlanHandler{ plans: make(map[int64]*compactionTask), chManager: cm, @@ -115,7 +116,7 @@ func newCompactionPlanHandler(sessions *SessionManager, cm *ChannelManager, meta sessions: sessions, allocator: allocator, flushCh: flush, - //segRefer: segRefer, + // segRefer: segRefer, parallelCh: make(map[int64]chan struct{}), } } @@ -263,7 +264,7 @@ func (c *compactionPlanHandler) handleMergeCompactionResult(plan *datapb.Compact return err } - var nodeID = c.plans[plan.GetPlanID()].dataNodeID + nodeID := c.plans[plan.GetPlanID()].dataNodeID req := &datapb.SyncSegmentsRequest{ PlanID: plan.PlanID, CompactedTo: newSegment.GetID(), diff --git a/internal/datacoord/compaction_test.go 
b/internal/datacoord/compaction_test.go index f4cbc3a326825..3f2e7f6b21449 100644 --- a/internal/datacoord/compaction_test.go +++ b/internal/datacoord/compaction_test.go @@ -23,22 +23,22 @@ import ( "time" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "go.uber.org/zap" + "google.golang.org/grpc" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + mockkv "github.com/milvus-io/milvus/internal/kv/mocks" "github.com/milvus-io/milvus/internal/metastore/kv/datacoord" "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/metautil" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/milvus-io/milvus/pkg/util/typeutil" - - mockkv "github.com/milvus-io/milvus/internal/kv/mocks" ) func Test_compactionPlanHandler_execCompactionPlan(t *testing.T) { @@ -182,7 +182,6 @@ func Test_compactionPlanHandler_execCompactionPlan(t *testing.T) { assert.Equal(t, tt.args.signal, task.triggerInfo) assert.Equal(t, 1, c.executingTaskNum) } else { - assert.Eventually(t, func() bool { c.mu.RLock() @@ -198,8 +197,7 @@ func Test_compactionPlanHandler_execCompactionPlan(t *testing.T) { } func Test_compactionPlanHandler_execWithParallels(t *testing.T) { - - mockDataNode := &mocks.MockDataNode{} + mockDataNode := &mocks.MockDataNodeClient{} paramtable.Get().Save(Params.DataCoordCfg.CompactionCheckIntervalInSeconds.Key, "1") defer paramtable.Get().Reset(Params.DataCoordCfg.CompactionCheckIntervalInSeconds.Key) c := &compactionPlanHandler{ @@ -235,11 +233,12 @@ func Test_compactionPlanHandler_execWithParallels(t *testing.T) { var mut sync.RWMutex called := 0 - mockDataNode.EXPECT().Compaction(mock.Anything, mock.Anything).Run(func(ctx context.Context, req *datapb.CompactionPlan) { 
- mut.Lock() - defer mut.Unlock() - called++ - }).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil).Times(3) + mockDataNode.EXPECT().Compaction(mock.Anything, mock.Anything, mock.Anything). + Run(func(ctx context.Context, req *datapb.CompactionPlan, opts ...grpc.CallOption) { + mut.Lock() + defer mut.Unlock() + called++ + }).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil).Times(3) go func() { c.execCompactionPlan(signal, plan1) c.execCompactionPlan(signal, plan2) @@ -286,8 +285,10 @@ func getDeltaLogPath(rootPath string, segmentID typeutil.UniqueID) string { } func TestCompactionPlanHandler_handleMergeCompactionResult(t *testing.T) { - mockDataNode := &mocks.MockDataNode{} - call := mockDataNode.EXPECT().SyncSegments(mock.Anything, mock.Anything).Run(func(ctx context.Context, req *datapb.SyncSegmentsRequest) {}).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil) + mockDataNode := &mocks.MockDataNodeClient{} + call := mockDataNode.EXPECT().SyncSegments(mock.Anything, mock.Anything, mock.Anything). + Run(func(ctx context.Context, req *datapb.SyncSegmentsRequest, opts ...grpc.CallOption) {}). + Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil) dataNodeID := UniqueID(111) @@ -330,7 +331,8 @@ func TestCompactionPlanHandler_handleMergeCompactionResult(t *testing.T) { data map[int64]*Session }{ data: map[int64]*Session{ - dataNodeID: {client: mockDataNode}}, + dataNodeID: {client: mockDataNode}, + }, }, } @@ -419,7 +421,8 @@ func TestCompactionPlanHandler_handleMergeCompactionResult(t *testing.T) { require.True(t, has) call.Unset() - call = mockDataNode.EXPECT().SyncSegments(mock.Anything, mock.Anything).Run(func(ctx context.Context, req *datapb.SyncSegmentsRequest) {}).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, nil) + mockDataNode.EXPECT().SyncSegments(mock.Anything, mock.Anything, mock.Anything). 
+ Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, nil) err = c.handleMergeCompactionResult(plan, compactionResult2) assert.Error(t, err) } @@ -441,8 +444,10 @@ func TestCompactionPlanHandler_completeCompaction(t *testing.T) { }) t.Run("test complete merge compaction task", func(t *testing.T) { - mockDataNode := &mocks.MockDataNode{} - mockDataNode.EXPECT().SyncSegments(mock.Anything, mock.Anything).Run(func(ctx context.Context, req *datapb.SyncSegmentsRequest) {}).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil) + mockDataNode := &mocks.MockDataNodeClient{} + mockDataNode.EXPECT().SyncSegments(mock.Anything, mock.Anything, mock.Anything). + Run(func(ctx context.Context, req *datapb.SyncSegmentsRequest, opts ...grpc.CallOption) {}). + Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil) dataNodeID := UniqueID(111) @@ -485,7 +490,8 @@ func TestCompactionPlanHandler_completeCompaction(t *testing.T) { data map[int64]*Session }{ data: map[int64]*Session{ - dataNodeID: {client: mockDataNode}}, + dataNodeID: {client: mockDataNode}, + }, }, } @@ -533,8 +539,10 @@ func TestCompactionPlanHandler_completeCompaction(t *testing.T) { }) t.Run("test empty result merge compaction task", func(t *testing.T) { - mockDataNode := &mocks.MockDataNode{} - mockDataNode.EXPECT().SyncSegments(mock.Anything, mock.Anything).Run(func(ctx context.Context, req *datapb.SyncSegmentsRequest) {}).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil) + mockDataNode := &mocks.MockDataNodeClient{} + mockDataNode.EXPECT().SyncSegments(mock.Anything, mock.Anything, mock.Anything). + Run(func(ctx context.Context, req *datapb.SyncSegmentsRequest, opts ...grpc.CallOption) {}). 
+ Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil) dataNodeID := UniqueID(111) @@ -577,7 +585,8 @@ func TestCompactionPlanHandler_completeCompaction(t *testing.T) { data map[int64]*Session }{ data: map[int64]*Session{ - dataNodeID: {client: mockDataNode}}, + dataNodeID: {client: mockDataNode}, + }, }, } @@ -600,8 +609,8 @@ func TestCompactionPlanHandler_completeCompaction(t *testing.T) { }, } - meta.AddSegment(NewSegmentInfo(seg1)) - meta.AddSegment(NewSegmentInfo(seg2)) + meta.AddSegment(context.TODO(), NewSegmentInfo(seg1)) + meta.AddSegment(context.TODO(), NewSegmentInfo(seg2)) segments := meta.GetAllSegmentsUnsafe() assert.Equal(t, len(segments), 2) diff --git a/internal/datacoord/compaction_trigger.go b/internal/datacoord/compaction_trigger.go index 7cdf2d08c98af..781cbdba9d27b 100644 --- a/internal/datacoord/compaction_trigger.go +++ b/internal/datacoord/compaction_trigger.go @@ -23,7 +23,6 @@ import ( "sync" "time" - "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/samber/lo" "go.uber.org/zap" @@ -32,6 +31,7 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/logutil" + "github.com/milvus-io/milvus/pkg/util/tsoutil" ) type compactTime struct { @@ -72,8 +72,8 @@ type compactionTrigger struct { forceMu sync.Mutex quit chan struct{} wg sync.WaitGroup - //segRefer *SegmentReferenceManager - //indexCoord types.IndexCoord + // segRefer *SegmentReferenceManager + // indexCoord types.IndexCoord estimateNonDiskSegmentPolicy calUpperLimitPolicy estimateDiskSegmentPolicy calUpperLimitPolicy // A sloopy hack, so we can test with different segment row count without worrying that @@ -85,8 +85,8 @@ func newCompactionTrigger( meta *meta, compactionHandler compactionPlanContext, allocator allocator, - //segRefer *SegmentReferenceManager, - //indexCoord types.IndexCoord, + // segRefer *SegmentReferenceManager, + // indexCoord types.IndexCoord, handler 
Handler, ) *compactionTrigger { return &compactionTrigger{ @@ -94,8 +94,8 @@ func newCompactionTrigger( allocator: allocator, signals: make(chan *compactionSignal, 100), compactionHandler: compactionHandler, - //segRefer: segRefer, - //indexCoord: indexCoord, + // segRefer: segRefer, + // indexCoord: indexCoord, estimateDiskSegmentPolicy: calBySchemaPolicyWithDiskIndex, estimateNonDiskSegmentPolicy: calBySchemaPolicy, handler: handler, @@ -211,7 +211,6 @@ func (t *compactionTrigger) getCompactTime(ts Timestamp, coll *collectionInfo) ( // triggerCompaction trigger a compaction if any compaction condition satisfy. func (t *compactionTrigger) triggerCompaction() error { - id, err := t.allocSignalID() if err != nil { return err @@ -585,7 +584,7 @@ func (t *compactionTrigger) generatePlans(segments []*SegmentInfo, force bool, i } // greedy pick from large segment to small, the goal is to fill each segment to reach 512M // we must ensure all prioritized candidates is in a plan - //TODO the compaction selection policy should consider if compaction workload is high + // TODO the compaction selection policy should consider if compaction workload is high for len(prioritizedCandidates) > 0 { var bucket []*SegmentInfo // pop out the first element diff --git a/internal/datacoord/compaction_trigger_test.go b/internal/datacoord/compaction_trigger_test.go index fe21e152507af..c00f4605141b7 100644 --- a/internal/datacoord/compaction_trigger_test.go +++ b/internal/datacoord/compaction_trigger_test.go @@ -22,8 +22,6 @@ import ( "time" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" @@ -33,6 +31,7 @@ import ( "github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/pkg/common" + 
"github.com/milvus-io/milvus/pkg/util/tsoutil" ) type spyCompactionHandler struct { @@ -466,9 +465,7 @@ func Test_compactionTrigger_force(t *testing.T) { }) t.Run(tt.name+" with DiskANN index", func(t *testing.T) { - segmentIDs := make([]int64, 0) for _, segment := range tt.fields.meta.segments.GetSegments() { - segmentIDs = append(segmentIDs, segment.GetID()) // Collection 1000 means it has DiskANN index segment.CollectionID = 1000 } @@ -493,7 +490,7 @@ func Test_compactionTrigger_force(t *testing.T) { }) t.Run(tt.name+" with allocate ts error", func(t *testing.T) { - //indexCood := newMockIndexCoord() + // indexCood := newMockIndexCoord() tr := &compactionTrigger{ meta: tt.fields.meta, handler: newMockHandlerWithMeta(tt.fields.meta), @@ -623,7 +620,6 @@ func Test_compactionTrigger_force_maxSegmentLimit(t *testing.T) { collectionID int64 compactTime *compactTime } - paramtable.Init() vecFieldID := int64(201) segmentInfos := &SegmentsInfo{ segments: make(map[UniqueID]*SegmentInfo), @@ -830,7 +826,6 @@ func Test_compactionTrigger_noplan(t *testing.T) { collectionID int64 compactTime *compactTime } - paramtable.Init() Params.DataCoordCfg.MinSegmentToMerge.DefaultValue = "4" vecFieldID := int64(201) tests := []struct { @@ -931,7 +926,6 @@ func Test_compactionTrigger_noplan(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tr := &compactionTrigger{ meta: tt.fields.meta, handler: newMockHandlerWithMeta(tt.fields.meta), @@ -972,7 +966,6 @@ func Test_compactionTrigger_PrioritizedCandi(t *testing.T) { collectionID int64 compactTime *compactTime } - paramtable.Init() vecFieldID := int64(201) genSeg := func(segID, numRows int64) *datapb.SegmentInfo { @@ -1155,7 +1148,6 @@ func Test_compactionTrigger_SmallCandi(t *testing.T) { collectionID int64 compactTime *compactTime } - paramtable.Init() vecFieldID := int64(201) genSeg := func(segID, numRows int64) *datapb.SegmentInfo { @@ -1338,7 +1330,6 @@ func 
Test_compactionTrigger_SqueezeNonPlannedSegs(t *testing.T) { collectionID int64 compactTime *compactTime } - paramtable.Init() vecFieldID := int64(201) genSeg := func(segID, numRows int64) *datapb.SegmentInfo { @@ -1517,7 +1508,6 @@ func Test_compactionTrigger_noplan_random_size(t *testing.T) { collectionID int64 compactTime *compactTime } - paramtable.Init() segmentInfos := &SegmentsInfo{ segments: make(map[UniqueID]*SegmentInfo), @@ -1668,7 +1658,6 @@ func Test_compactionTrigger_noplan_random_size(t *testing.T) { } for _, plan := range plans { - size := int64(0) for _, log := range plan.SegmentBinlogs { size += log.FieldBinlogs[0].GetBinlogs()[0].LogSize @@ -1689,8 +1678,6 @@ func Test_compactionTrigger_noplan_random_size(t *testing.T) { // Test shouldDoSingleCompaction func Test_compactionTrigger_shouldDoSingleCompaction(t *testing.T) { - paramtable.Init() - trigger := newCompactionTrigger(&meta{}, &compactionPlanHandler{}, newMockAllocator(), newMockHandler()) // Test too many deltalogs. 
@@ -1719,7 +1706,7 @@ func Test_compactionTrigger_shouldDoSingleCompaction(t *testing.T) { couldDo := trigger.ShouldDoSingleCompaction(info, false, &compactTime{}) assert.True(t, couldDo) - //Test too many stats log + // Test too many stats log info = &SegmentInfo{ SegmentInfo: &datapb.SegmentInfo{ ID: 1, @@ -1747,12 +1734,12 @@ func Test_compactionTrigger_shouldDoSingleCompaction(t *testing.T) { couldDo = trigger.ShouldDoSingleCompaction(info, true, &compactTime{}) assert.False(t, couldDo) - //Test too many stats log but compacted + // Test too many stats log but compacted info.CompactionFrom = []int64{0, 1} couldDo = trigger.ShouldDoSingleCompaction(info, false, &compactTime{}) assert.False(t, couldDo) - //Test expire triggered compaction + // Test expire triggered compaction var binlogs2 []*datapb.FieldBinlog for i := UniqueID(0); i < 100; i++ { binlogs2 = append(binlogs2, &datapb.FieldBinlog{ @@ -1985,40 +1972,41 @@ func (s *CompactionTriggerSuite) SetupTest() { s.indexID = 300 s.vecFieldID = 400 s.channel = "dml_0_100v0" - s.meta = &meta{segments: &SegmentsInfo{ - map[int64]*SegmentInfo{ - 1: { - SegmentInfo: s.genSeg(1, 60), - lastFlushTime: time.Now().Add(-100 * time.Minute), - segmentIndexes: s.genSegIndex(1, indexID, 60), - }, - 2: { - SegmentInfo: s.genSeg(2, 60), - lastFlushTime: time.Now(), - segmentIndexes: s.genSegIndex(2, indexID, 60), - }, - 3: { - SegmentInfo: s.genSeg(3, 60), - lastFlushTime: time.Now(), - segmentIndexes: s.genSegIndex(3, indexID, 60), - }, - 4: { - SegmentInfo: s.genSeg(4, 60), - lastFlushTime: time.Now(), - segmentIndexes: s.genSegIndex(4, indexID, 60), - }, - 5: { - SegmentInfo: s.genSeg(5, 26), - lastFlushTime: time.Now(), - segmentIndexes: s.genSegIndex(5, indexID, 26), - }, - 6: { - SegmentInfo: s.genSeg(6, 26), - lastFlushTime: time.Now(), - segmentIndexes: s.genSegIndex(6, indexID, 26), + s.meta = &meta{ + segments: &SegmentsInfo{ + map[int64]*SegmentInfo{ + 1: { + SegmentInfo: s.genSeg(1, 60), + lastFlushTime: 
time.Now().Add(-100 * time.Minute), + segmentIndexes: s.genSegIndex(1, indexID, 60), + }, + 2: { + SegmentInfo: s.genSeg(2, 60), + lastFlushTime: time.Now(), + segmentIndexes: s.genSegIndex(2, indexID, 60), + }, + 3: { + SegmentInfo: s.genSeg(3, 60), + lastFlushTime: time.Now(), + segmentIndexes: s.genSegIndex(3, indexID, 60), + }, + 4: { + SegmentInfo: s.genSeg(4, 60), + lastFlushTime: time.Now(), + segmentIndexes: s.genSegIndex(4, indexID, 60), + }, + 5: { + SegmentInfo: s.genSeg(5, 26), + lastFlushTime: time.Now(), + segmentIndexes: s.genSegIndex(5, indexID, 26), + }, + 6: { + SegmentInfo: s.genSeg(6, 26), + lastFlushTime: time.Now(), + segmentIndexes: s.genSegIndex(6, indexID, 26), + }, }, }, - }, collections: map[int64]*collectionInfo{ s.collectionID: { ID: s.collectionID, @@ -2072,7 +2060,7 @@ func (s *CompactionTriggerSuite) TestHandleSignal() { defer s.SetupTest() tr := s.tr s.compactionHandler.EXPECT().isFull().Return(false) - //s.allocator.EXPECT().allocTimestamp(mock.Anything).Return(10000, nil) + // s.allocator.EXPECT().allocTimestamp(mock.Anything).Return(10000, nil) s.handler.EXPECT().GetCollection(mock.Anything, int64(100)).Return(nil, errors.New("mocked")) tr.handleSignal(&compactionSignal{ segmentID: 1, @@ -2089,7 +2077,7 @@ func (s *CompactionTriggerSuite) TestHandleSignal() { defer s.SetupTest() tr := s.tr s.compactionHandler.EXPECT().isFull().Return(false) - //s.allocator.EXPECT().allocTimestamp(mock.Anything).Return(10000, nil) + // s.allocator.EXPECT().allocTimestamp(mock.Anything).Return(10000, nil) s.handler.EXPECT().GetCollection(mock.Anything, int64(100)).Return(&collectionInfo{ Properties: map[string]string{ common.CollectionAutoCompactionKey: "bad_value", diff --git a/internal/datacoord/coordinator_broker.go b/internal/datacoord/coordinator_broker.go index 91252d2ec4ae4..3a3e603b532ad 100644 --- a/internal/datacoord/coordinator_broker.go +++ b/internal/datacoord/coordinator_broker.go @@ -20,6 +20,8 @@ import ( "time" 
"github.com/cockroachdb/errors" + "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/types" @@ -27,7 +29,6 @@ import ( "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" - "go.uber.org/zap" ) const ( @@ -43,10 +44,10 @@ type Broker interface { } type CoordinatorBroker struct { - rootCoord types.RootCoord + rootCoord types.RootCoordClient } -func NewCoordinatorBroker(rootCoord types.RootCoord) *CoordinatorBroker { +func NewCoordinatorBroker(rootCoord types.RootCoordClient) *CoordinatorBroker { return &CoordinatorBroker{ rootCoord: rootCoord, } diff --git a/internal/datacoord/errors.go b/internal/datacoord/errors.go index 05497cb514c99..9127a28f8aaac 100644 --- a/internal/datacoord/errors.go +++ b/internal/datacoord/errors.go @@ -22,16 +22,12 @@ import ( "github.com/cockroachdb/errors" ) -// errNilKvClient stands for a nil kv client is detected when initialized -var errNilKvClient = errors.New("kv client not initialized") - -// serverNotServingErrMsg used for Status Reason when DataCoord is not healthy -const serverNotServingErrMsg = "DataCoord is not serving" - // errors for VerifyResponse -var errNilResponse = errors.New("response is nil") -var errNilStatusResponse = errors.New("response has nil status") -var errUnknownResponseType = errors.New("unknown response type") +var ( + errNilResponse = errors.New("response is nil") + errNilStatusResponse = errors.New("response has nil status") + errUnknownResponseType = errors.New("unknown response type") +) func msgDataCoordIsUnhealthy(coordID UniqueID) string { return fmt.Sprintf("DataCoord %d is not ready", coordID) diff --git a/internal/datacoord/errors_test.go b/internal/datacoord/errors_test.go index 581df1f412fcb..5e1d722b4490c 100644 --- a/internal/datacoord/errors_test.go +++ 
b/internal/datacoord/errors_test.go @@ -19,9 +19,10 @@ package datacoord import ( "testing" + "go.uber.org/zap" + "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/typeutil" - "go.uber.org/zap" ) func TestMsgDataCoordIsUnhealthy(t *testing.T) { diff --git a/internal/datacoord/garbage_collector.go b/internal/datacoord/garbage_collector.go index 49a5d762b4762..cdec36bda3dff 100644 --- a/internal/datacoord/garbage_collector.go +++ b/internal/datacoord/garbage_collector.go @@ -18,6 +18,7 @@ package datacoord import ( "context" + "fmt" "path" "sort" "strings" @@ -33,17 +34,12 @@ import ( "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/metautil" + "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" ) -const ( - //TODO silverxia change to configuration - insertLogPrefix = `insert_log` - statsLogPrefix = `stats_log` - deltaLogPrefix = `delta_log` -) - // GcOption garbage collection options type GcOption struct { cli storage.ChunkManager // client @@ -129,35 +125,43 @@ func (gc *garbageCollector) scan() { total = 0 valid = 0 missing = 0 - - segmentMap = typeutil.NewUniqueSet() - filesMap = typeutil.NewSet[string]() ) - segments := gc.meta.GetAllSegmentsUnsafe() - for _, segment := range segments { - segmentMap.Insert(segment.GetID()) - for _, log := range getLogs(segment) { - filesMap.Insert(log.GetLogPath()) + getMetaMap := func() (typeutil.UniqueSet, typeutil.Set[string]) { + segmentMap := typeutil.NewUniqueSet() + filesMap := typeutil.NewSet[string]() + segments := gc.meta.GetAllSegmentsUnsafe() + for _, segment := range segments { + segmentMap.Insert(segment.GetID()) + for _, log := range getLogs(segment) { + filesMap.Insert(log.GetLogPath()) + } } + return segmentMap, filesMap } // walk only data cluster related prefixes prefixes := 
make([]string, 0, 3) - prefixes = append(prefixes, path.Join(gc.option.cli.RootPath(), insertLogPrefix)) - prefixes = append(prefixes, path.Join(gc.option.cli.RootPath(), statsLogPrefix)) - prefixes = append(prefixes, path.Join(gc.option.cli.RootPath(), deltaLogPrefix)) + prefixes = append(prefixes, path.Join(gc.option.cli.RootPath(), common.SegmentInsertLogPath)) + prefixes = append(prefixes, path.Join(gc.option.cli.RootPath(), common.SegmentStatslogPath)) + prefixes = append(prefixes, path.Join(gc.option.cli.RootPath(), common.SegmentDeltaLogPath)) + labels := []string{metrics.InsertFileLabel, metrics.StatFileLabel, metrics.DeleteFileLabel} var removedKeys []string - for _, prefix := range prefixes { + for idx, prefix := range prefixes { startTs := time.Now() infoKeys, modTimes, err := gc.option.cli.ListWithPrefix(ctx, prefix, true) if err != nil { log.Error("failed to list files with prefix", zap.String("prefix", prefix), - zap.String("error", err.Error()), + zap.Error(err), ) } - log.Info("gc scan finish list object", zap.String("prefix", prefix), zap.Duration("time spent", time.Since(startTs)), zap.Int("keys", len(infoKeys))) + cost := time.Since(startTs) + segmentMap, filesMap := getMetaMap() + metrics.GarbageCollectorListLatency. + WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), labels[idx]). 
+ Observe(float64(cost.Milliseconds())) + log.Info("gc scan finish list object", zap.String("prefix", prefix), zap.Duration("time spent", cost), zap.Int("keys", len(infoKeys))) for i, infoKey := range infoKeys { total++ _, has := filesMap[infoKey] @@ -175,7 +179,7 @@ func (gc *garbageCollector) scan() { continue } - if strings.Contains(prefix, statsLogPrefix) && + if strings.Contains(prefix, common.SegmentInsertLogPath) && segmentMap.Contain(segmentID) { valid++ continue @@ -195,6 +199,7 @@ func (gc *garbageCollector) scan() { } } } + metrics.GarbageCollectorRunCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Add(1) log.Info("scan file to do garbage collection", zap.Int("total", total), zap.Int("valid", valid), @@ -205,7 +210,8 @@ func (gc *garbageCollector) scan() { func (gc *garbageCollector) checkDroppedSegmentGC(segment *SegmentInfo, childSegment *SegmentInfo, indexSet typeutil.UniqueSet, - cpTimestamp Timestamp) bool { + cpTimestamp Timestamp, +) bool { log := log.With(zap.Int64("segmentID", segment.ID)) isCompacted := childSegment != nil || segment.GetCompacted() @@ -250,7 +256,7 @@ func (gc *garbageCollector) clearEtcd() { if segment.GetState() == commonpb.SegmentState_Dropped { drops[segment.GetID()] = segment channels.Insert(segment.GetInsertChannel()) - //continue + // continue // A(indexed), B(indexed) -> C(no indexed), D(no indexed) -> E(no indexed), A, B can not be GC } for _, from := range segment.GetCompactionFrom() { diff --git a/internal/datacoord/garbage_collector_test.go b/internal/datacoord/garbage_collector_test.go index 97847bb1909b6..86367f2d13514 100644 --- a/internal/datacoord/garbage_collector_test.go +++ b/internal/datacoord/garbage_collector_test.go @@ -45,6 +45,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/funcutil" 
"github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -52,7 +53,7 @@ import ( func Test_garbageCollector_basic(t *testing.T) { bucketName := `datacoord-ut` + strings.ToLower(funcutil.RandomString(8)) rootPath := `gc` + funcutil.RandomString(8) - //TODO change to Params + // TODO change to Params cli, _, _, _, _, err := initUtOSSEnv(bucketName, rootPath, 0) require.NoError(t, err) @@ -91,7 +92,6 @@ func Test_garbageCollector_basic(t *testing.T) { gc.close() }) }) - } func validateMinioPrefixElements(t *testing.T, cli *minio.Client, bucketName string, prefix string, elements []string) { @@ -105,7 +105,7 @@ func validateMinioPrefixElements(t *testing.T, cli *minio.Client, bucketName str func Test_garbageCollector_scan(t *testing.T) { bucketName := `datacoord-ut` + strings.ToLower(funcutil.RandomString(8)) rootPath := `gc` + funcutil.RandomString(8) - //TODO change to Params + // TODO change to Params cli, inserts, stats, delta, others, err := initUtOSSEnv(bucketName, rootPath, 4) require.NoError(t, err) @@ -122,9 +122,9 @@ func Test_garbageCollector_scan(t *testing.T) { }) gc.scan() - validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, insertLogPrefix), inserts) - validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, statsLogPrefix), stats) - validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, deltaLogPrefix), delta) + validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, common.SegmentInsertLogPath), inserts) + validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, common.SegmentStatslogPath), stats) + validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, common.SegmentDeltaLogPath), delta) validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, `indexes`), others) gc.close() }) @@ -139,9 +139,9 @@ func Test_garbageCollector_scan(t *testing.T) { }) gc.scan() - validateMinioPrefixElements(t, cli.Client, bucketName, 
path.Join(rootPath, insertLogPrefix), inserts) - validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, statsLogPrefix), stats) - validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, deltaLogPrefix), delta) + validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, common.SegmentInsertLogPath), inserts) + validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, common.SegmentStatslogPath), stats) + validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, common.SegmentDeltaLogPath), delta) validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, `indexes`), others) gc.close() @@ -152,7 +152,7 @@ func Test_garbageCollector_scan(t *testing.T) { segment.Binlogs = []*datapb.FieldBinlog{getFieldBinlogPaths(0, inserts[0])} segment.Statslogs = []*datapb.FieldBinlog{getFieldBinlogPaths(0, stats[0])} segment.Deltalogs = []*datapb.FieldBinlog{getFieldBinlogPaths(0, delta[0])} - err = meta.AddSegment(segment) + err = meta.AddSegment(context.TODO(), segment) require.NoError(t, err) gc := newGarbageCollector(meta, newMockHandler(), GcOption{ @@ -164,9 +164,9 @@ func Test_garbageCollector_scan(t *testing.T) { }) gc.start() gc.scan() - validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, insertLogPrefix), inserts) - validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, statsLogPrefix), stats) - validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, deltaLogPrefix), delta) + validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, common.SegmentInsertLogPath), inserts) + validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, common.SegmentStatslogPath), stats) + validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, common.SegmentDeltaLogPath), delta) validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, `indexes`), others) 
gc.close() @@ -180,7 +180,7 @@ func Test_garbageCollector_scan(t *testing.T) { segment.Statslogs = []*datapb.FieldBinlog{getFieldBinlogPaths(0, stats[0])} segment.Deltalogs = []*datapb.FieldBinlog{getFieldBinlogPaths(0, delta[0])} - err = meta.AddSegment(segment) + err = meta.AddSegment(context.TODO(), segment) require.NoError(t, err) gc := newGarbageCollector(meta, newMockHandler(), GcOption{ @@ -191,9 +191,9 @@ func Test_garbageCollector_scan(t *testing.T) { dropTolerance: 0, }) gc.clearEtcd() - validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, insertLogPrefix), inserts[1:]) - validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, statsLogPrefix), stats[1:]) - validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, deltaLogPrefix), delta[1:]) + validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, common.SegmentInsertLogPath), inserts[1:]) + validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, common.SegmentStatslogPath), stats[1:]) + validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, common.SegmentDeltaLogPath), delta[1:]) validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, `indexes`), others) gc.close() @@ -211,9 +211,9 @@ func Test_garbageCollector_scan(t *testing.T) { gc.clearEtcd() // bad path shall remains since datacoord cannot determine file is garbage or not if path is not valid - validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, insertLogPrefix), inserts[1:2]) - validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, statsLogPrefix), stats[1:2]) - validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, deltaLogPrefix), delta[1:2]) + validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, common.SegmentInsertLogPath), inserts[1:2]) + validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, 
common.SegmentStatslogPath), stats[1:2]) + validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, common.SegmentDeltaLogPath), delta[1:2]) validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, `indexes`), others) gc.close() @@ -231,9 +231,9 @@ func Test_garbageCollector_scan(t *testing.T) { gc.scan() // bad path shall remains since datacoord cannot determine file is garbage or not if path is not valid - validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, insertLogPrefix), inserts[1:2]) - validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, statsLogPrefix), stats[1:2]) - validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, deltaLogPrefix), delta[1:2]) + validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, common.SegmentInsertLogPath), inserts[1:2]) + validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, common.SegmentStatslogPath), stats[1:2]) + validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, common.SegmentDeltaLogPath), delta[1:2]) validateMinioPrefixElements(t, cli.Client, bucketName, path.Join(rootPath, `indexes`), others) gc.close() @@ -280,14 +280,14 @@ func initUtOSSEnv(bucket, root string, n int) (mcm *storage.MinioChunkManager, i token = path.Join(strconv.Itoa(1+i), strconv.Itoa(10+i), strconv.Itoa(100+i), funcutil.RandomString(8), funcutil.RandomString(8)) } // insert - filePath := path.Join(root, insertLogPrefix, token) + filePath := path.Join(root, common.SegmentInsertLogPath, token) info, err := cli.PutObject(context.TODO(), bucket, filePath, reader, int64(len(content)), minio.PutObjectOptions{}) if err != nil { return nil, nil, nil, nil, nil, err } inserts = append(inserts, info.Key) // stats - filePath = path.Join(root, statsLogPrefix, token) + filePath = path.Join(root, common.SegmentStatslogPath, token) info, err = cli.PutObject(context.TODO(), bucket, filePath, reader, 
int64(len(content)), minio.PutObjectOptions{}) if err != nil { return nil, nil, nil, nil, nil, err @@ -300,7 +300,7 @@ func initUtOSSEnv(bucket, root string, n int) (mcm *storage.MinioChunkManager, i } else { token = path.Join(strconv.Itoa(1+i), strconv.Itoa(10+i), strconv.Itoa(100+i), funcutil.RandomString(8)) } - filePath = path.Join(root, deltaLogPrefix, token) + filePath = path.Join(root, common.SegmentDeltaLogPath, token) info, err = cli.PutObject(context.TODO(), bucket, filePath, reader, int64(len(content)), minio.PutObjectOptions{}) if err != nil { return nil, nil, nil, nil, nil, err @@ -332,7 +332,7 @@ func createMetaForRecycleUnusedIndexes(catalog metastore.DataCoordCatalog) *meta var ( ctx = context.Background() collID = UniqueID(100) - //partID = UniqueID(200) + // partID = UniqueID(200) fieldID = UniqueID(300) indexID = UniqueID(400) ) @@ -426,7 +426,7 @@ func createMetaForRecycleUnusedSegIndexes(catalog metastore.DataCoordCatalog) *m ctx = context.Background() collID = UniqueID(100) partID = UniqueID(200) - //fieldID = UniqueID(300) + // fieldID = UniqueID(300) indexID = UniqueID(400) segID = UniqueID(500) ) @@ -569,7 +569,7 @@ func createMetaTableForRecycleUnusedIndexFiles(catalog *datacoord.Catalog) *meta ctx = context.Background() collID = UniqueID(100) partID = UniqueID(200) - //fieldID = UniqueID(300) + // fieldID = UniqueID(300) indexID = UniqueID(400) segID = UniqueID(500) buildID = UniqueID(600) diff --git a/internal/datacoord/handler.go b/internal/datacoord/handler.go index ce05ed4b31c90..4db996cda0ace 100644 --- a/internal/datacoord/handler.go +++ b/internal/datacoord/handler.go @@ -20,7 +20,6 @@ import ( "context" "time" - "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/samber/lo" "go.uber.org/zap" @@ -30,6 +29,7 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/retry" + "github.com/milvus-io/milvus/pkg/util/tsoutil" 
"github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -366,7 +366,7 @@ func trimSegmentInfo(info *datapb.SegmentInfo) *datapb.SegmentInfo { // HasCollection returns whether the collection exist from user's perspective. func (h *ServerHandler) HasCollection(ctx context.Context, collectionID UniqueID) (bool, error) { var hasCollection bool - ctx2, cancel := context.WithTimeout(ctx, time.Minute*30) + ctx2, cancel := context.WithTimeout(ctx, time.Second*10) defer cancel() if err := retry.Do(ctx2, func() error { has, err := h.s.broker.HasCollection(ctx2, collectionID) @@ -376,9 +376,13 @@ func (h *ServerHandler) HasCollection(ctx context.Context, collectionID UniqueID } hasCollection = has return nil - }, retry.Attempts(500)); err != nil { - log.Ctx(ctx2).Error("datacoord ServerHandler HasCollection finally failed", zap.Int64("collectionID", collectionID)) - log.Panic("datacoord ServerHandler HasCollection finally failed") + }, retry.Attempts(5)); err != nil { + log.Ctx(ctx2).Error("datacoord ServerHandler HasCollection finally failed", + zap.Int64("collectionID", collectionID), + zap.Error(err)) + // A workaround for https://github.com/milvus-io/milvus/issues/26863. The collection may be considered as not + // dropped when any exception happened, but there are chances that finally the collection will be cleaned. 
+ return true, nil } return hasCollection, nil } diff --git a/internal/datacoord/index_builder.go b/internal/datacoord/index_builder.go index fd13493adba58..34bb6d56d9bab 100644 --- a/internal/datacoord/index_builder.go +++ b/internal/datacoord/index_builder.go @@ -23,8 +23,6 @@ import ( "sync" "time" - "github.com/cockroachdb/errors" - "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" @@ -34,6 +32,7 @@ import ( "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/merr" ) type indexTaskState int32 @@ -80,24 +79,31 @@ type indexBuilder struct { meta *meta - policy buildIndexPolicy - nodeManager *IndexNodeManager - chunkManager storage.ChunkManager + policy buildIndexPolicy + nodeManager *IndexNodeManager + chunkManager storage.ChunkManager + indexEngineVersionManager *IndexEngineVersionManager } -func newIndexBuilder(ctx context.Context, metaTable *meta, nodeManager *IndexNodeManager, chunkManager storage.ChunkManager) *indexBuilder { +func newIndexBuilder( + ctx context.Context, + metaTable *meta, nodeManager *IndexNodeManager, + chunkManager storage.ChunkManager, + indexEngineVersionManager *IndexEngineVersionManager, +) *indexBuilder { ctx, cancel := context.WithCancel(ctx) ib := &indexBuilder{ - ctx: ctx, - cancel: cancel, - meta: metaTable, - tasks: make(map[int64]indexTaskState), - notifyChan: make(chan struct{}, 1), - scheduleDuration: Params.DataCoordCfg.IndexTaskSchedulerInterval.GetAsDuration(time.Millisecond), - policy: defaultBuildIndexPolicy, - nodeManager: nodeManager, - chunkManager: chunkManager, + ctx: ctx, + cancel: cancel, + meta: metaTable, + tasks: make(map[int64]indexTaskState), + notifyChan: make(chan struct{}, 1), + scheduleDuration: Params.DataCoordCfg.IndexTaskSchedulerInterval.GetAsDuration(time.Millisecond), + policy: defaultBuildIndexPolicy, + nodeManager: nodeManager, + chunkManager: chunkManager, + 
indexEngineVersionManager: indexEngineVersionManager, } ib.reloadFromKV() return ib @@ -231,7 +237,7 @@ func (ib *indexBuilder) process(buildID UniqueID) bool { } indexParams := ib.meta.GetIndexParams(meta.CollectionID, meta.IndexID) if isFlatIndex(getIndexType(indexParams)) || meta.NumRows < Params.DataCoordCfg.MinSegmentNumRowsToEnableIndex.GetAsInt64() { - log.Ctx(ib.ctx).Debug("segment does not need index really", zap.Int64("buildID", buildID), + log.Ctx(ib.ctx).Info("segment does not need index really", zap.Int64("buildID", buildID), zap.Int64("segmentID", meta.SegmentID), zap.Int64("num rows", meta.NumRows)) if err := ib.meta.FinishTask(&indexpb.IndexTaskInfo{ BuildID: buildID, @@ -280,17 +286,19 @@ func (ib *indexBuilder) process(buildID UniqueID) bool { } } else { storageConfig = &indexpb.StorageConfig{ - Address: Params.MinioCfg.Address.GetValue(), - AccessKeyID: Params.MinioCfg.AccessKeyID.GetValue(), - SecretAccessKey: Params.MinioCfg.SecretAccessKey.GetValue(), - UseSSL: Params.MinioCfg.UseSSL.GetAsBool(), - BucketName: Params.MinioCfg.BucketName.GetValue(), - RootPath: Params.MinioCfg.RootPath.GetValue(), - UseIAM: Params.MinioCfg.UseIAM.GetAsBool(), - IAMEndpoint: Params.MinioCfg.IAMEndpoint.GetValue(), - StorageType: Params.CommonCfg.StorageType.GetValue(), - Region: Params.MinioCfg.Region.GetValue(), - UseVirtualHost: Params.MinioCfg.UseVirtualHost.GetAsBool(), + Address: Params.MinioCfg.Address.GetValue(), + AccessKeyID: Params.MinioCfg.AccessKeyID.GetValue(), + SecretAccessKey: Params.MinioCfg.SecretAccessKey.GetValue(), + UseSSL: Params.MinioCfg.UseSSL.GetAsBool(), + BucketName: Params.MinioCfg.BucketName.GetValue(), + RootPath: Params.MinioCfg.RootPath.GetValue(), + UseIAM: Params.MinioCfg.UseIAM.GetAsBool(), + IAMEndpoint: Params.MinioCfg.IAMEndpoint.GetValue(), + StorageType: Params.CommonCfg.StorageType.GetValue(), + Region: Params.MinioCfg.Region.GetValue(), + UseVirtualHost: Params.MinioCfg.UseVirtualHost.GetAsBool(), + CloudProvider: 
Params.MinioCfg.CloudProvider.GetValue(), + RequestTimeoutMs: Params.MinioCfg.RequestTimeoutMs.GetAsInt64(), } } @@ -313,25 +321,26 @@ func (ib *indexBuilder) process(buildID UniqueID) bool { } req := &indexpb.CreateJobRequest{ - ClusterID: Params.CommonCfg.ClusterPrefix.GetValue(), - IndexFilePrefix: path.Join(ib.chunkManager.RootPath(), common.SegmentIndexPath), - BuildID: buildID, - DataPaths: binLogs, - IndexVersion: meta.IndexVersion + 1, - StorageConfig: storageConfig, - IndexParams: indexParams, - TypeParams: typeParams, - NumRows: meta.NumRows, - CollectionID: segment.GetCollectionID(), - PartitionID: segment.GetPartitionID(), - SegmentID: segment.GetID(), - FieldID: fieldID, - FieldName: field.Name, - FieldType: field.DataType, - StorePath: fmt.Sprintf("s3://%s:%s@%s/%d?scheme=%s&endpoint_override=%s&allow_bucket_creation=true", Params.MinioCfg.AccessKeyID.GetValue(), Params.MinioCfg.SecretAccessKey.GetValue(), Params.MinioCfg.BucketName.GetValue(), segment.GetID(), scheme, Params.MinioCfg.Address.GetValue()), - StoreVersion: segment.GetStorageVersion(), - IndexStorePath: fmt.Sprintf("s3://%s:%s@%s/index/%d?scheme=%s&endpoint_override=%s&allow_bucket_creation=true", Params.MinioCfg.AccessKeyID.GetValue(), Params.MinioCfg.SecretAccessKey.GetValue(), Params.MinioCfg.BucketName.GetValue(), segment.GetID(), scheme, Params.MinioCfg.Address.GetValue()), - Dim: int64(dim), + ClusterID: Params.CommonCfg.ClusterPrefix.GetValue(), + IndexFilePrefix: path.Join(ib.chunkManager.RootPath(), common.SegmentIndexPath), + BuildID: buildID, + DataPaths: binLogs, + IndexVersion: meta.IndexVersion + 1, + StorageConfig: storageConfig, + IndexParams: indexParams, + TypeParams: typeParams, + NumRows: meta.NumRows, + CollectionID: segment.GetCollectionID(), + PartitionID: segment.GetPartitionID(), + SegmentID: segment.GetID(), + FieldID: fieldID, + FieldName: field.Name, + FieldType: field.DataType, + StorePath: 
fmt.Sprintf("s3://%s:%s@%s/%d?scheme=%s&endpoint_override=%s&allow_bucket_creation=true", Params.MinioCfg.AccessKeyID.GetValue(), Params.MinioCfg.SecretAccessKey.GetValue(), Params.MinioCfg.BucketName.GetValue(), segment.GetID(), scheme, Params.MinioCfg.Address.GetValue()), + StoreVersion: segment.GetStorageVersion(), + IndexStorePath: fmt.Sprintf("s3://%s:%s@%s/index/%d?scheme=%s&endpoint_override=%s&allow_bucket_creation=true", Params.MinioCfg.AccessKeyID.GetValue(), Params.MinioCfg.SecretAccessKey.GetValue(), Params.MinioCfg.BucketName.GetValue(), segment.GetID(), scheme, Params.MinioCfg.Address.GetValue()), + Dim: int64(dim), + CurrentIndexVersion: ib.indexEngineVersionManager.GetCurrentIndexEngineVersion(), } if err := ib.assignTask(client, req); err != nil { @@ -385,26 +394,26 @@ func (ib *indexBuilder) getTaskState(buildID, nodeID UniqueID) indexTaskState { zap.Error(err)) return indexTaskInProgress } - if response.Status.ErrorCode != commonpb.ErrorCode_Success { + if response.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { log.Ctx(ib.ctx).Warn("IndexCoord get jobs info from IndexNode fail", zap.Int64("nodeID", nodeID), - zap.Int64("buildID", buildID), zap.String("fail reason", response.Status.Reason)) + zap.Int64("buildID", buildID), zap.String("fail reason", response.GetStatus().GetReason())) return indexTaskInProgress } // indexInfos length is always one. 
- for _, info := range response.IndexInfos { - if info.BuildID == buildID { - if info.State == commonpb.IndexState_Failed || info.State == commonpb.IndexState_Finished { - log.Ctx(ib.ctx).Info("this task has been finished", zap.Int64("buildID", info.BuildID), - zap.String("index state", info.State.String())) + for _, info := range response.GetIndexInfos() { + if info.GetBuildID() == buildID { + if info.GetState() == commonpb.IndexState_Failed || info.GetState() == commonpb.IndexState_Finished { + log.Ctx(ib.ctx).Info("this task has been finished", zap.Int64("buildID", info.GetBuildID()), + zap.String("index state", info.GetState().String())) if err := ib.meta.FinishTask(info); err != nil { - log.Ctx(ib.ctx).Warn("IndexCoord update index state fail", zap.Int64("buildID", info.BuildID), - zap.String("index state", info.State.String()), zap.Error(err)) + log.Ctx(ib.ctx).Warn("IndexCoord update index state fail", zap.Int64("buildID", info.GetBuildID()), + zap.String("index state", info.GetState().String()), zap.Error(err)) return indexTaskInProgress } return indexTaskDone - } else if info.State == commonpb.IndexState_Retry || info.State == commonpb.IndexState_IndexStateNone { - log.Ctx(ib.ctx).Info("this task should be retry", zap.Int64("buildID", buildID), zap.String("fail reason", info.FailReason)) + } else if info.GetState() == commonpb.IndexState_Retry || info.GetState() == commonpb.IndexState_IndexStateNone { + log.Ctx(ib.ctx).Info("this task should be retry", zap.Int64("buildID", buildID), zap.String("fail reason", info.GetFailReason())) return indexTaskRetry } return indexTaskInProgress @@ -434,9 +443,9 @@ func (ib *indexBuilder) dropIndexTask(buildID, nodeID UniqueID) bool { zap.Int64("nodeID", nodeID), zap.Error(err)) return false } - if status.ErrorCode != commonpb.ErrorCode_Success { + if status.GetErrorCode() != commonpb.ErrorCode_Success { log.Ctx(ib.ctx).Warn("IndexCoord notify IndexNode drop the index task fail", zap.Int64("buildID", buildID), - 
zap.Int64("nodeID", nodeID), zap.String("fail reason", status.Reason)) + zap.Int64("nodeID", nodeID), zap.String("fail reason", status.GetReason())) return false } log.Ctx(ib.ctx).Info("IndexCoord notify IndexNode drop the index task success", @@ -450,19 +459,18 @@ func (ib *indexBuilder) dropIndexTask(buildID, nodeID UniqueID) bool { // assignTask sends the index task to the IndexNode, it has a timeout interval, if the IndexNode doesn't respond within // the interval, it is considered that the task sending failed. -func (ib *indexBuilder) assignTask(builderClient types.IndexNode, req *indexpb.CreateJobRequest) error { +func (ib *indexBuilder) assignTask(builderClient types.IndexNodeClient, req *indexpb.CreateJobRequest) error { ctx, cancel := context.WithTimeout(context.Background(), reqTimeoutInterval) defer cancel() resp, err := builderClient.CreateJob(ctx, req) + if err == nil { + err = merr.Error(resp) + } if err != nil { log.Error("IndexCoord assignmentTasksLoop builderClient.CreateIndex failed", zap.Error(err)) return err } - if resp.ErrorCode != commonpb.ErrorCode_Success { - log.Error("IndexCoord assignmentTasksLoop builderClient.CreateIndex failed", zap.String("Reason", resp.Reason)) - return errors.New(resp.Reason) - } return nil } diff --git a/internal/datacoord/index_builder_test.go b/internal/datacoord/index_builder_test.go index 80c6510c78fe9..5bd1f6bcb4b3f 100644 --- a/internal/datacoord/index_builder_test.go +++ b/internal/datacoord/index_builder_test.go @@ -22,12 +22,11 @@ import ( "time" "github.com/cockroachdb/errors" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "google.golang.org/grpc" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus/internal/indexnode" "github.com/milvus-io/milvus/internal/metastore" catalogmocks "github.com/milvus-io/milvus/internal/metastore/mocks" "github.com/milvus-io/milvus/internal/metastore/model" @@ -35,7 +34,9 @@ import ( 
"github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/types" + mclient "github.com/milvus-io/milvus/internal/util/mock" "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -615,17 +616,57 @@ func TestIndexBuilder(t *testing.T) { mock.Anything, ).Return(nil) + ic := mocks.NewMockIndexNodeClient(t) + ic.EXPECT().GetJobStats(mock.Anything, mock.Anything, mock.Anything). + Return(&indexpb.GetJobStatsResponse{ + Status: merr.Success(), + TotalJobNum: 1, + EnqueueJobNum: 0, + InProgressJobNum: 1, + TaskSlots: 1, + JobInfos: []*indexpb.JobInfo{ + { + NumRows: 1024, + Dim: 128, + StartTime: 1, + EndTime: 10, + PodID: 1, + }, + }, + }, nil) + ic.EXPECT().QueryJobs(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn( + func(ctx context.Context, in *indexpb.QueryJobsRequest, option ...grpc.CallOption) (*indexpb.QueryJobsResponse, error) { + indexInfos := make([]*indexpb.IndexTaskInfo, 0) + for _, buildID := range in.BuildIDs { + indexInfos = append(indexInfos, &indexpb.IndexTaskInfo{ + BuildID: buildID, + State: commonpb.IndexState_Finished, + IndexFileKeys: []string{"file1", "file2"}, + }) + } + return &indexpb.QueryJobsResponse{ + Status: merr.Success(), + ClusterID: in.ClusterID, + IndexInfos: indexInfos, + }, nil + }) + + ic.EXPECT().CreateJob(mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(merr.Success(), nil) + + ic.EXPECT().DropJobs(mock.Anything, mock.Anything, mock.Anything). 
+ Return(merr.Success(), nil) mt := createMetaTable(catalog) nodeManager := &IndexNodeManager{ ctx: ctx, - nodeClients: map[UniqueID]types.IndexNode{ - 4: indexnode.NewIndexNodeMock(), + nodeClients: map[UniqueID]types.IndexNodeClient{ + 4: ic, }, } chunkManager := &mocks.ChunkManager{} chunkManager.EXPECT().RootPath().Return("root") - ib := newIndexBuilder(ctx, mt, nodeManager, chunkManager) + ib := newIndexBuilder(ctx, mt, nodeManager, chunkManager, newIndexEngineVersionManager()) assert.Equal(t, 6, len(ib.tasks)) assert.Equal(t, indexTaskInit, ib.tasks[buildID]) @@ -696,8 +737,9 @@ func TestIndexBuilder_Error(t *testing.T) { tasks: map[int64]indexTaskState{ buildID: indexTaskInit, }, - meta: createMetaTable(ec), - chunkManager: chunkManager, + meta: createMetaTable(ec), + chunkManager: chunkManager, + indexEngineVersionManager: newIndexEngineVersionManager(), } t.Run("meta not exist", func(t *testing.T) { @@ -719,7 +761,7 @@ func TestIndexBuilder_Error(t *testing.T) { t.Run("peek client fail", func(t *testing.T) { ib.tasks[buildID] = indexTaskInit - ib.nodeManager = &IndexNodeManager{nodeClients: map[UniqueID]types.IndexNode{}} + ib.nodeManager = &IndexNodeManager{nodeClients: map[UniqueID]types.IndexNodeClient{}} ib.process(buildID) state, ok := ib.tasks[buildID] @@ -730,7 +772,7 @@ func TestIndexBuilder_Error(t *testing.T) { t.Run("update version fail", func(t *testing.T) { ib.nodeManager = &IndexNodeManager{ ctx: context.Background(), - nodeClients: map[UniqueID]types.IndexNode{1: indexnode.NewIndexNodeMock()}, + nodeClients: map[UniqueID]types.IndexNodeClient{1: &mclient.GrpcIndexNodeClient{Err: nil}}, } ib.process(buildID) @@ -765,23 +807,18 @@ func TestIndexBuilder_Error(t *testing.T) { paramtable.Get().Save(Params.CommonCfg.StorageType.Key, "local") ib.tasks[buildID] = indexTaskInit ib.meta.catalog = sc + + ic := mocks.NewMockIndexNodeClient(t) + ic.EXPECT().CreateJob(mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("error")) + 
ic.EXPECT().GetJobStats(mock.Anything, mock.Anything, mock.Anything).Return(&indexpb.GetJobStatsResponse{ + Status: merr.Success(), + TaskSlots: 1, + }, nil) + ib.nodeManager = &IndexNodeManager{ ctx: context.Background(), - nodeClients: map[UniqueID]types.IndexNode{ - 1: &indexnode.Mock{ - CallCreateJob: func(ctx context.Context, req *indexpb.CreateJobRequest) (*commonpb.Status, error) { - return nil, errors.New("error") - }, - CallGetJobStats: func(ctx context.Context, in *indexpb.GetJobStatsRequest) (*indexpb.GetJobStatsResponse, error) { - return &indexpb.GetJobStatsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, - TaskSlots: 1, - }, nil - }, - }, + nodeClients: map[UniqueID]types.IndexNodeClient{ + 1: ic, }, } ib.process(buildID) @@ -793,26 +830,20 @@ func TestIndexBuilder_Error(t *testing.T) { t.Run("assign task fail", func(t *testing.T) { paramtable.Get().Save(Params.CommonCfg.StorageType.Key, "local") ib.meta.catalog = sc + ic := mocks.NewMockIndexNodeClient(t) + ic.EXPECT().CreateJob(mock.Anything, mock.Anything, mock.Anything).Return(&commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "mock fail", + }, nil) + ic.EXPECT().GetJobStats(mock.Anything, mock.Anything, mock.Anything).Return(&indexpb.GetJobStatsResponse{ + Status: merr.Success(), + TaskSlots: 1, + }, nil) + ib.nodeManager = &IndexNodeManager{ ctx: context.Background(), - nodeClients: map[UniqueID]types.IndexNode{ - 1: &indexnode.Mock{ - CallCreateJob: func(ctx context.Context, req *indexpb.CreateJobRequest) (*commonpb.Status, error) { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "mock fail", - }, nil - }, - CallGetJobStats: func(ctx context.Context, in *indexpb.GetJobStatsRequest) (*indexpb.GetJobStatsResponse, error) { - return &indexpb.GetJobStatsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, - TaskSlots: 1, - }, nil - }, - }, + 
nodeClients: map[UniqueID]types.IndexNodeClient{ + 1: ic, }, } ib.tasks[buildID] = indexTaskInit @@ -826,16 +857,15 @@ func TestIndexBuilder_Error(t *testing.T) { t.Run("drop job error", func(t *testing.T) { ib.meta.buildID2SegmentIndex[buildID].NodeID = nodeID ib.meta.catalog = sc + ic := mocks.NewMockIndexNodeClient(t) + ic.EXPECT().DropJobs(mock.Anything, mock.Anything, mock.Anything).Return(&commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + }, errors.New("error")) + ib.nodeManager = &IndexNodeManager{ ctx: context.Background(), - nodeClients: map[UniqueID]types.IndexNode{ - nodeID: &indexnode.Mock{ - CallDropJobs: func(ctx context.Context, in *indexpb.DropJobsRequest) (*commonpb.Status, error) { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - }, errors.New("error") - }, - }, + nodeClients: map[UniqueID]types.IndexNodeClient{ + nodeID: ic, }, } ib.tasks[buildID] = indexTaskDone @@ -856,17 +886,16 @@ func TestIndexBuilder_Error(t *testing.T) { t.Run("drop job fail", func(t *testing.T) { ib.meta.buildID2SegmentIndex[buildID].NodeID = nodeID ib.meta.catalog = sc + ic := mocks.NewMockIndexNodeClient(t) + ic.EXPECT().DropJobs(mock.Anything, mock.Anything, mock.Anything).Return(&commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "mock fail", + }, nil) + ib.nodeManager = &IndexNodeManager{ ctx: context.Background(), - nodeClients: map[UniqueID]types.IndexNode{ - nodeID: &indexnode.Mock{ - CallDropJobs: func(ctx context.Context, in *indexpb.DropJobsRequest) (*commonpb.Status, error) { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "mock fail", - }, nil - }, - }, + nodeClients: map[UniqueID]types.IndexNodeClient{ + nodeID: ic, }, } ib.tasks[buildID] = indexTaskDone @@ -887,14 +916,12 @@ func TestIndexBuilder_Error(t *testing.T) { t.Run("get state error", func(t *testing.T) { ib.meta.buildID2SegmentIndex[buildID].NodeID = nodeID ib.meta.catalog = sc + ic := 
mocks.NewMockIndexNodeClient(t) + ic.EXPECT().QueryJobs(mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("error")) ib.nodeManager = &IndexNodeManager{ ctx: context.Background(), - nodeClients: map[UniqueID]types.IndexNode{ - nodeID: &indexnode.Mock{ - CallQueryJobs: func(ctx context.Context, in *indexpb.QueryJobsRequest) (*indexpb.QueryJobsResponse, error) { - return nil, errors.New("error") - }, - }, + nodeClients: map[UniqueID]types.IndexNodeClient{ + nodeID: ic, }, } @@ -909,19 +936,17 @@ func TestIndexBuilder_Error(t *testing.T) { t.Run("get state fail", func(t *testing.T) { ib.meta.buildID2SegmentIndex[buildID].NodeID = nodeID ib.meta.catalog = sc + ic := mocks.NewMockIndexNodeClient(t) + ic.EXPECT().QueryJobs(mock.Anything, mock.Anything, mock.Anything).Return(&indexpb.QueryJobsResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_BuildIndexError, + Reason: "mock fail", + }, + }, nil) ib.nodeManager = &IndexNodeManager{ ctx: context.Background(), - nodeClients: map[UniqueID]types.IndexNode{ - nodeID: &indexnode.Mock{ - CallQueryJobs: func(ctx context.Context, in *indexpb.QueryJobsRequest) (*indexpb.QueryJobsResponse, error) { - return &indexpb.QueryJobsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_BuildIndexError, - Reason: "mock fail", - }, - }, nil - }, - }, + nodeClients: map[UniqueID]types.IndexNodeClient{ + nodeID: ic, }, } @@ -936,28 +961,24 @@ func TestIndexBuilder_Error(t *testing.T) { t.Run("finish task fail", func(t *testing.T) { ib.meta.buildID2SegmentIndex[buildID].NodeID = nodeID ib.meta.catalog = ec + ic := mocks.NewMockIndexNodeClient(t) + ic.EXPECT().QueryJobs(mock.Anything, mock.Anything, mock.Anything).Return(&indexpb.QueryJobsResponse{ + Status: merr.Success(), + IndexInfos: []*indexpb.IndexTaskInfo{ + { + BuildID: buildID, + State: commonpb.IndexState_Finished, + IndexFileKeys: []string{"file1", "file2"}, + SerializedSize: 1024, + FailReason: "", + }, + }, + }, nil) + 
ib.nodeManager = &IndexNodeManager{ ctx: context.Background(), - nodeClients: map[UniqueID]types.IndexNode{ - nodeID: &indexnode.Mock{ - CallQueryJobs: func(ctx context.Context, in *indexpb.QueryJobsRequest) (*indexpb.QueryJobsResponse, error) { - return &indexpb.QueryJobsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, - IndexInfos: []*indexpb.IndexTaskInfo{ - { - BuildID: buildID, - State: commonpb.IndexState_Finished, - IndexFileKeys: []string{"file1", "file2"}, - SerializedSize: 1024, - FailReason: "", - }, - }, - }, nil - }, - }, + nodeClients: map[UniqueID]types.IndexNodeClient{ + nodeID: ic, }, } @@ -972,28 +993,24 @@ func TestIndexBuilder_Error(t *testing.T) { t.Run("task still in progress", func(t *testing.T) { ib.meta.buildID2SegmentIndex[buildID].NodeID = nodeID ib.meta.catalog = ec + ic := mocks.NewMockIndexNodeClient(t) + ic.EXPECT().QueryJobs(mock.Anything, mock.Anything, mock.Anything).Return(&indexpb.QueryJobsResponse{ + Status: merr.Success(), + IndexInfos: []*indexpb.IndexTaskInfo{ + { + BuildID: buildID, + State: commonpb.IndexState_InProgress, + IndexFileKeys: nil, + SerializedSize: 0, + FailReason: "", + }, + }, + }, nil) + ib.nodeManager = &IndexNodeManager{ ctx: context.Background(), - nodeClients: map[UniqueID]types.IndexNode{ - nodeID: &indexnode.Mock{ - CallQueryJobs: func(ctx context.Context, in *indexpb.QueryJobsRequest) (*indexpb.QueryJobsResponse, error) { - return &indexpb.QueryJobsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, - IndexInfos: []*indexpb.IndexTaskInfo{ - { - BuildID: buildID, - State: commonpb.IndexState_InProgress, - IndexFileKeys: nil, - SerializedSize: 0, - FailReason: "", - }, - }, - }, nil - }, - }, + nodeClients: map[UniqueID]types.IndexNodeClient{ + nodeID: ic, }, } @@ -1008,20 +1025,15 @@ func TestIndexBuilder_Error(t *testing.T) { t.Run("indexNode has no task", func(t *testing.T) { 
ib.meta.buildID2SegmentIndex[buildID].NodeID = nodeID ib.meta.catalog = sc + ic := mocks.NewMockIndexNodeClient(t) + ic.EXPECT().QueryJobs(mock.Anything, mock.Anything, mock.Anything).Return(&indexpb.QueryJobsResponse{ + Status: merr.Success(), + IndexInfos: nil, + }, nil) ib.nodeManager = &IndexNodeManager{ ctx: context.Background(), - nodeClients: map[UniqueID]types.IndexNode{ - nodeID: &indexnode.Mock{ - CallQueryJobs: func(ctx context.Context, in *indexpb.QueryJobsRequest) (*indexpb.QueryJobsResponse, error) { - return &indexpb.QueryJobsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, - IndexInfos: nil, - }, nil - }, - }, + nodeClients: map[UniqueID]types.IndexNodeClient{ + nodeID: ic, }, } @@ -1038,7 +1050,7 @@ func TestIndexBuilder_Error(t *testing.T) { ib.meta.catalog = sc ib.nodeManager = &IndexNodeManager{ ctx: context.Background(), - nodeClients: map[UniqueID]types.IndexNode{}, + nodeClients: map[UniqueID]types.IndexNodeClient{}, } ib.tasks[buildID] = indexTaskInProgress diff --git a/internal/datacoord/index_engine_version_manager.go b/internal/datacoord/index_engine_version_manager.go new file mode 100644 index 0000000000000..82944705cdc97 --- /dev/null +++ b/internal/datacoord/index_engine_version_manager.go @@ -0,0 +1,95 @@ +package datacoord + +import ( + "math" + "sync" + + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/util/sessionutil" + "github.com/milvus-io/milvus/pkg/log" +) + +type IndexEngineVersionManager struct { + mu sync.Mutex + versions map[int64]sessionutil.IndexEngineVersion +} + +func newIndexEngineVersionManager() *IndexEngineVersionManager { + return &IndexEngineVersionManager{ + versions: map[int64]sessionutil.IndexEngineVersion{}, + } +} + +func (m *IndexEngineVersionManager) Startup(sessions map[string]*sessionutil.Session) { + m.mu.Lock() + defer m.mu.Unlock() + + for _, session := range sessions { + m.addOrUpdate(session) + } +} + +func (m *IndexEngineVersionManager) 
AddNode(session *sessionutil.Session) { + m.mu.Lock() + defer m.mu.Unlock() + + m.addOrUpdate(session) +} + +func (m *IndexEngineVersionManager) RemoveNode(session *sessionutil.Session) { + m.mu.Lock() + defer m.mu.Unlock() + + delete(m.versions, session.ServerID) +} + +func (m *IndexEngineVersionManager) Update(session *sessionutil.Session) { + m.mu.Lock() + defer m.mu.Unlock() + + m.addOrUpdate(session) +} + +func (m *IndexEngineVersionManager) addOrUpdate(session *sessionutil.Session) { + log.Info("addOrUpdate version", zap.Int64("nodeId", session.ServerID), zap.Int32("minimal", session.IndexEngineVersion.MinimalIndexVersion), zap.Int32("current", session.IndexEngineVersion.CurrentIndexVersion)) + m.versions[session.ServerID] = session.IndexEngineVersion +} + +func (m *IndexEngineVersionManager) GetCurrentIndexEngineVersion() int32 { + m.mu.Lock() + defer m.mu.Unlock() + + if len(m.versions) == 0 { + log.Info("index versions is empty") + return 0 + } + + current := int32(math.MaxInt32) + for _, version := range m.versions { + if version.CurrentIndexVersion < current { + current = version.CurrentIndexVersion + } + } + log.Info("Merged current version", zap.Int32("current", current)) + return current +} + +func (m *IndexEngineVersionManager) GetMinimalIndexEngineVersion() int32 { + m.mu.Lock() + defer m.mu.Unlock() + + if len(m.versions) == 0 { + log.Info("index versions is empty") + return 0 + } + + minimal := int32(0) + for _, version := range m.versions { + if version.MinimalIndexVersion > minimal { + minimal = version.MinimalIndexVersion + } + } + log.Info("Merged minimal version", zap.Int32("minimal", minimal)) + return minimal +} diff --git a/internal/datacoord/index_engine_version_manager_test.go b/internal/datacoord/index_engine_version_manager_test.go new file mode 100644 index 0000000000000..d544b4e2c754f --- /dev/null +++ b/internal/datacoord/index_engine_version_manager_test.go @@ -0,0 +1,58 @@ +package datacoord + +import ( + "testing" + + 
"github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/internal/util/sessionutil" +) + +func Test_IndexEngineVersionManager_GetMergedIndexVersion(t *testing.T) { + m := newIndexEngineVersionManager() + + // empty + assert.Zero(t, m.GetCurrentIndexEngineVersion()) + + // startup + m.Startup(map[string]*sessionutil.Session{ + "1": { + SessionRaw: sessionutil.SessionRaw{ + ServerID: 1, + IndexEngineVersion: sessionutil.IndexEngineVersion{CurrentIndexVersion: 20, MinimalIndexVersion: 0}, + }, + }, + }) + assert.Equal(t, int32(20), m.GetCurrentIndexEngineVersion()) + assert.Equal(t, int32(0), m.GetMinimalIndexEngineVersion()) + + // add node + m.AddNode(&sessionutil.Session{ + SessionRaw: sessionutil.SessionRaw{ + ServerID: 2, + IndexEngineVersion: sessionutil.IndexEngineVersion{CurrentIndexVersion: 10, MinimalIndexVersion: 5}, + }, + }) + assert.Equal(t, int32(10), m.GetCurrentIndexEngineVersion()) + assert.Equal(t, int32(5), m.GetMinimalIndexEngineVersion()) + + // update + m.Update(&sessionutil.Session{ + SessionRaw: sessionutil.SessionRaw{ + ServerID: 2, + IndexEngineVersion: sessionutil.IndexEngineVersion{CurrentIndexVersion: 5, MinimalIndexVersion: 2}, + }, + }) + assert.Equal(t, int32(5), m.GetCurrentIndexEngineVersion()) + assert.Equal(t, int32(2), m.GetMinimalIndexEngineVersion()) + + // remove + m.RemoveNode(&sessionutil.Session{ + SessionRaw: sessionutil.SessionRaw{ + ServerID: 2, + IndexEngineVersion: sessionutil.IndexEngineVersion{CurrentIndexVersion: 5, MinimalIndexVersion: 3}, + }, + }) + assert.Equal(t, int32(20), m.GetCurrentIndexEngineVersion()) + assert.Equal(t, int32(0), m.GetMinimalIndexEngineVersion()) +} diff --git a/internal/datacoord/index_meta.go b/internal/datacoord/index_meta.go index b8cf2cd3f7be3..05d2e57f0b651 100644 --- a/internal/datacoord/index_meta.go +++ b/internal/datacoord/index_meta.go @@ -22,6 +22,7 @@ import ( "strconv" "github.com/golang/protobuf/proto" + "github.com/prometheus/client_golang/prometheus" 
"go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" @@ -30,7 +31,6 @@ import ( "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" - "github.com/prometheus/client_golang/prometheus" ) func (m *meta) updateCollectionIndex(index *model.Index) { @@ -117,12 +117,12 @@ func checkParams(fieldIndex *model.Index, req *indexpb.CreateIndexRequest) bool if notEq { return false } - if len(fieldIndex.IndexParams) != len(req.IndexParams) { + if len(fieldIndex.UserIndexParams) != len(req.GetUserIndexParams()) { return false } - for _, param1 := range fieldIndex.IndexParams { + for _, param1 := range fieldIndex.UserIndexParams { exist := false - for _, param2 := range req.IndexParams { + for _, param2 := range req.GetUserIndexParams() { if param2.Key == param1.Key && param2.Value == param1.Value { exist = true } @@ -521,16 +521,17 @@ func (m *meta) FinishTask(taskInfo *indexpb.IndexTaskInfo) error { m.Lock() defer m.Unlock() - segIdx, ok := m.buildID2SegmentIndex[taskInfo.BuildID] + segIdx, ok := m.buildID2SegmentIndex[taskInfo.GetBuildID()] if !ok { - log.Warn("there is no index with buildID", zap.Int64("buildID", taskInfo.BuildID)) + log.Warn("there is no index with buildID", zap.Int64("buildID", taskInfo.GetBuildID())) return nil } updateFunc := func(segIdx *model.SegmentIndex) error { - segIdx.IndexState = taskInfo.State - segIdx.IndexFileKeys = common.CloneStringList(taskInfo.IndexFileKeys) - segIdx.FailReason = taskInfo.FailReason - segIdx.IndexSize = taskInfo.SerializedSize + segIdx.IndexState = taskInfo.GetState() + segIdx.IndexFileKeys = common.CloneStringList(taskInfo.GetIndexFileKeys()) + segIdx.FailReason = taskInfo.GetFailReason() + segIdx.IndexSize = taskInfo.GetSerializedSize() + segIdx.CurrentIndexVersion = taskInfo.GetCurrentIndexVersion() return m.alterSegmentIndexes([]*model.SegmentIndex{segIdx}) } @@ -538,10 +539,12 @@ func (m *meta) FinishTask(taskInfo 
*indexpb.IndexTaskInfo) error { return err } - log.Info("finish index task success", zap.Int64("buildID", taskInfo.BuildID), - zap.String("state", taskInfo.GetState().String()), zap.String("fail reason", taskInfo.GetFailReason())) + log.Info("finish index task success", zap.Int64("buildID", taskInfo.GetBuildID()), + zap.String("state", taskInfo.GetState().String()), zap.String("fail reason", taskInfo.GetFailReason()), + zap.Int32("current_index_version", taskInfo.GetCurrentIndexVersion()), + ) m.updateIndexTasksMetrics() - metrics.FlushedSegmentFileNum.WithLabelValues(metrics.IndexFileLabel).Observe(float64(len(taskInfo.IndexFileKeys))) + metrics.FlushedSegmentFileNum.WithLabelValues(metrics.IndexFileLabel).Observe(float64(len(taskInfo.GetIndexFileKeys()))) return nil } diff --git a/internal/datacoord/index_meta_test.go b/internal/datacoord/index_meta_test.go index bdd75e2966ddc..2a25c0bbf5b36 100644 --- a/internal/datacoord/index_meta_test.go +++ b/internal/datacoord/index_meta_test.go @@ -28,7 +28,6 @@ import ( "github.com/stretchr/testify/mock" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus/internal/kv/mocks" mockkv "github.com/milvus-io/milvus/internal/kv/mocks" "github.com/milvus-io/milvus/internal/metastore/kv/datacoord" catalogmocks "github.com/milvus-io/milvus/internal/metastore/mocks" @@ -41,7 +40,7 @@ import ( func TestMeta_CanCreateIndex(t *testing.T) { var ( collID = UniqueID(1) - //partID = UniqueID(2) + // partID = UniqueID(2) indexID = UniqueID(10) fieldID = UniqueID(100) indexName = "_default_idx" @@ -85,7 +84,7 @@ func TestMeta_CanCreateIndex(t *testing.T) { IndexParams: indexParams, Timestamp: 0, IsAutoIndex: false, - UserIndexParams: nil, + UserIndexParams: indexParams, } t.Run("can create index", func(t *testing.T) { @@ -103,7 +102,7 @@ func TestMeta_CanCreateIndex(t *testing.T) { TypeParams: typeParams, IndexParams: indexParams, IsAutoIndex: false, - UserIndexParams: nil, + UserIndexParams: indexParams, } 
err = m.CreateIndex(index) @@ -126,17 +125,19 @@ func TestMeta_CanCreateIndex(t *testing.T) { assert.Equal(t, int64(0), tmpIndexID) req.TypeParams = typeParams - req.IndexParams = append(req.IndexParams, &commonpb.KeyValuePair{Key: "metrics_type", Value: "L2"}) + req.UserIndexParams = append(indexParams, &commonpb.KeyValuePair{Key: "metrics_type", Value: "L2"}) tmpIndexID, err = m.CanCreateIndex(req) assert.Error(t, err) assert.Equal(t, int64(0), tmpIndexID) req.IndexParams = []*commonpb.KeyValuePair{{Key: common.IndexTypeKey, Value: "HNSW"}} + req.UserIndexParams = req.IndexParams tmpIndexID, err = m.CanCreateIndex(req) assert.Error(t, err) assert.Equal(t, int64(0), tmpIndexID) req.IndexParams = indexParams + req.UserIndexParams = indexParams req.FieldID++ tmpIndexID, err = m.CanCreateIndex(req) assert.Error(t, err) @@ -162,7 +163,7 @@ func TestMeta_CanCreateIndex(t *testing.T) { func TestMeta_HasSameReq(t *testing.T) { var ( collID = UniqueID(1) - //partID = UniqueID(2) + // partID = UniqueID(2) indexID = UniqueID(10) fieldID = UniqueID(100) indexName = "_default_idx" @@ -199,7 +200,7 @@ func TestMeta_HasSameReq(t *testing.T) { IndexParams: indexParams, Timestamp: 0, IsAutoIndex: false, - UserIndexParams: nil, + UserIndexParams: indexParams, } t.Run("no indexes", func(t *testing.T) { @@ -220,7 +221,7 @@ func TestMeta_HasSameReq(t *testing.T) { TypeParams: typeParams, IndexParams: indexParams, IsAutoIndex: false, - UserIndexParams: nil, + UserIndexParams: indexParams, }, } has, _ := m.HasSameReq(req) @@ -241,6 +242,12 @@ func TestMeta_HasSameReq(t *testing.T) { } func TestMeta_CreateIndex(t *testing.T) { + indexParams := []*commonpb.KeyValuePair{ + { + Key: common.IndexTypeKey, + Value: "FLAT", + }, + } index := &model.Index{ TenantID: "", CollectionID: 1, @@ -255,14 +262,9 @@ func TestMeta_CreateIndex(t *testing.T) { Value: "128", }, }, - IndexParams: []*commonpb.KeyValuePair{ - { - Key: common.IndexTypeKey, - Value: "FLAT", - }, - }, + IndexParams: indexParams, 
IsAutoIndex: false, - UserIndexParams: nil, + UserIndexParams: indexParams, } t.Run("success", func(t *testing.T) { @@ -353,7 +355,6 @@ func TestMeta_AddSegmentIndex(t *testing.T) { CreateTime: 12, IndexFileKeys: nil, IndexSize: 0, - WriteHandoff: false, } t.Run("save meta fail", func(t *testing.T) { @@ -371,7 +372,7 @@ func TestMeta_AddSegmentIndex(t *testing.T) { func TestMeta_GetIndexIDByName(t *testing.T) { var ( collID = UniqueID(1) - //partID = UniqueID(2) + // partID = UniqueID(2) indexID = UniqueID(10) fieldID = UniqueID(100) indexName = "_default_idx" @@ -418,14 +419,13 @@ func TestMeta_GetIndexIDByName(t *testing.T) { TypeParams: typeParams, IndexParams: indexParams, IsAutoIndex: false, - UserIndexParams: nil, + UserIndexParams: indexParams, }, } indexID2CreateTS := m.GetIndexIDByName(collID, indexName) assert.Contains(t, indexID2CreateTS, indexID) }) - } func TestMeta_GetSegmentIndexState(t *testing.T) { @@ -493,7 +493,7 @@ func TestMeta_GetSegmentIndexState(t *testing.T) { TypeParams: typeParams, IndexParams: indexParams, IsAutoIndex: false, - UserIndexParams: nil, + UserIndexParams: indexParams, }, } state := m.GetSegmentIndexState(collID, segID) @@ -521,7 +521,6 @@ func TestMeta_GetSegmentIndexState(t *testing.T) { CreateTime: 12, IndexFileKeys: nil, IndexSize: 0, - WriteHandoff: false, }) state := m.GetSegmentIndexState(collID, segID) @@ -544,7 +543,6 @@ func TestMeta_GetSegmentIndexState(t *testing.T) { CreateTime: 12, IndexFileKeys: nil, IndexSize: 0, - WriteHandoff: false, }) state := m.GetSegmentIndexState(collID, segID) @@ -599,7 +597,6 @@ func TestMeta_GetSegmentIndexStateOnField(t *testing.T) { CreateTime: 10, IndexFileKeys: nil, IndexSize: 0, - WriteHandoff: false, }, }, }, @@ -620,7 +617,7 @@ func TestMeta_GetSegmentIndexStateOnField(t *testing.T) { TypeParams: typeParams, IndexParams: indexParams, IsAutoIndex: false, - UserIndexParams: nil, + UserIndexParams: indexParams, }, }, }, @@ -640,7 +637,6 @@ func 
TestMeta_GetSegmentIndexStateOnField(t *testing.T) { CreateTime: 10, IndexFileKeys: nil, IndexSize: 0, - WriteHandoff: false, }, }, } @@ -732,7 +728,7 @@ func TestMeta_MarkIndexAsDeleted(t *testing.T) { } func TestMeta_GetSegmentIndexes(t *testing.T) { - m := createMetaTable(&datacoord.Catalog{MetaKv: mocks.NewMetaKv(t)}) + m := createMetaTable(&datacoord.Catalog{MetaKv: mockkv.NewMetaKv(t)}) t.Run("success", func(t *testing.T) { segIndexes := m.GetSegmentIndexes(segID) @@ -834,6 +830,12 @@ func TestMeta_GetIndexNameByID(t *testing.T) { } func TestMeta_GetTypeParams(t *testing.T) { + indexParams := []*commonpb.KeyValuePair{ + { + Key: common.IndexTypeKey, + Value: "HNSW", + }, + } m := &meta{ indexes: map[UniqueID]map[UniqueID]*model.Index{ collID: { @@ -851,14 +853,9 @@ func TestMeta_GetTypeParams(t *testing.T) { Value: "128", }, }, - IndexParams: []*commonpb.KeyValuePair{ - { - Key: common.IndexTypeKey, - Value: "HNSW", - }, - }, + IndexParams: indexParams, IsAutoIndex: false, - UserIndexParams: nil, + UserIndexParams: indexParams, }, }, }, @@ -879,6 +876,12 @@ func TestMeta_GetTypeParams(t *testing.T) { } func TestMeta_GetIndexParams(t *testing.T) { + indexParams := []*commonpb.KeyValuePair{ + { + Key: common.IndexTypeKey, + Value: "HNSW", + }, + } m := &meta{ indexes: map[UniqueID]map[UniqueID]*model.Index{ collID: { @@ -896,14 +899,9 @@ func TestMeta_GetIndexParams(t *testing.T) { Value: "128", }, }, - IndexParams: []*commonpb.KeyValuePair{ - { - Key: common.IndexTypeKey, - Value: "HNSW", - }, - }, + IndexParams: indexParams, IsAutoIndex: false, - UserIndexParams: nil, + UserIndexParams: indexParams, }, }, }, @@ -941,7 +939,6 @@ func TestMeta_GetIndexJob(t *testing.T) { CreateTime: 0, IndexFileKeys: nil, IndexSize: 0, - WriteHandoff: false, }, }, } @@ -1046,7 +1043,6 @@ func updateSegmentIndexMeta(t *testing.T) *meta { CreateTime: 0, IndexFileKeys: nil, IndexSize: 0, - WriteHandoff: false, }, }, }, @@ -1085,7 +1081,6 @@ func updateSegmentIndexMeta(t 
*testing.T) *meta { CreateTime: 0, IndexFileKeys: nil, IndexSize: 0, - WriteHandoff: false, }, }, } diff --git a/internal/datacoord/index_service.go b/internal/datacoord/index_service.go index f6d85ca764ff9..38abfd2568f36 100644 --- a/internal/datacoord/index_service.go +++ b/internal/datacoord/index_service.go @@ -138,46 +138,39 @@ func (s *Server) createIndexForSegmentLoop(ctx context.Context) { // indexBuilder will find this task and assign it to IndexNode for execution. func (s *Server) CreateIndex(ctx context.Context, req *indexpb.CreateIndexRequest) (*commonpb.Status, error) { log := log.Ctx(ctx).With( - zap.Int64("collectionID", req.CollectionID), + zap.Int64("collectionID", req.GetCollectionID()), ) log.Info("receive CreateIndex request", zap.String("IndexName", req.GetIndexName()), zap.Int64("fieldID", req.GetFieldID()), zap.Any("TypeParams", req.GetTypeParams()), - zap.Any("IndexParams", req.GetIndexParams())) - errResp := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "", - } - if s.isClosed() { - log.Warn(msgDataCoordIsUnhealthy(paramtable.GetNodeID())) - errResp.ErrorCode = commonpb.ErrorCode_DataCoordNA - errResp.Reason = msgDataCoordIsUnhealthy(paramtable.GetNodeID()) - return errResp, nil + zap.Any("IndexParams", req.GetIndexParams()), + ) + + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + log.Warn(msgDataCoordIsUnhealthy(paramtable.GetNodeID()), zap.Error(err)) + return merr.Status(err), nil } metrics.IndexRequestCounter.WithLabelValues(metrics.TotalLabel).Inc() indexID, err := s.meta.CanCreateIndex(req) if err != nil { - log.Error("CreateIndex failed", zap.Error(err)) - errResp.Reason = err.Error() metrics.IndexRequestCounter.WithLabelValues(metrics.FailLabel).Inc() - return errResp, nil + return merr.Status(err), nil } if indexID == 0 { indexID, err = s.allocator.allocID(ctx) if err != nil { log.Warn("failed to alloc indexID", zap.Error(err)) - errResp.Reason = "failed to alloc indexID" 
metrics.IndexRequestCounter.WithLabelValues(metrics.FailLabel).Inc() - return errResp, nil + return merr.Status(err), nil } if getIndexType(req.GetIndexParams()) == diskAnnIndex && !s.indexNodeManager.ClientSupportDisk() { errMsg := "all IndexNodes do not support disk indexes, please verify" log.Warn(errMsg) - errResp.Reason = errMsg + err = merr.WrapErrIndexNotSupported(diskAnnIndex) metrics.IndexRequestCounter.WithLabelValues(metrics.FailLabel).Inc() - return errResp, nil + return merr.Status(err), nil } } @@ -199,9 +192,8 @@ func (s *Server) CreateIndex(ctx context.Context, req *indexpb.CreateIndexReques if err != nil { log.Error("CreateIndex fail", zap.Int64("fieldID", req.GetFieldID()), zap.String("indexName", req.GetIndexName()), zap.Error(err)) - errResp.Reason = err.Error() metrics.IndexRequestCounter.WithLabelValues(metrics.FailLabel).Inc() - return errResp, nil + return merr.Status(err), nil } select { @@ -212,62 +204,47 @@ func (s *Server) CreateIndex(ctx context.Context, req *indexpb.CreateIndexReques log.Info("CreateIndex successfully", zap.String("IndexName", req.GetIndexName()), zap.Int64("fieldID", req.GetFieldID()), zap.Int64("IndexID", indexID)) - errResp.ErrorCode = commonpb.ErrorCode_Success metrics.IndexRequestCounter.WithLabelValues(metrics.SuccessLabel).Inc() - return errResp, nil + return merr.Success(), nil } // GetIndexState gets the index state of the index name in the request from Proxy. 
// Deprecated func (s *Server) GetIndexState(ctx context.Context, req *indexpb.GetIndexStateRequest) (*indexpb.GetIndexStateResponse, error) { log := log.Ctx(ctx).With( - zap.Int64("collectionID", req.CollectionID), + zap.Int64("collectionID", req.GetCollectionID()), + zap.String("indexName", req.GetIndexName()), ) - log.Info("receive GetIndexState request", - zap.String("indexName", req.IndexName)) + log.Info("receive GetIndexState request") - errResp := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "", - } - if s.isClosed() { - log.Warn(msgDataCoordIsUnhealthy(paramtable.GetNodeID())) - errResp.ErrorCode = commonpb.ErrorCode_DataCoordNA - errResp.Reason = msgDataCoordIsUnhealthy(paramtable.GetNodeID()) + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + log.Warn(msgDataCoordIsUnhealthy(paramtable.GetNodeID()), zap.Error(err)) return &indexpb.GetIndexStateResponse{ - Status: errResp, + Status: merr.Status(err), }, nil } indexes := s.meta.GetIndexesForCollection(req.GetCollectionID(), req.GetIndexName()) if len(indexes) == 0 { - errResp.ErrorCode = commonpb.ErrorCode_IndexNotExist - errResp.Reason = fmt.Sprintf("there is no index on collection: %d with the index name: %s", req.CollectionID, req.IndexName) - log.Error("GetIndexState fail", - zap.String("indexName", req.IndexName), zap.String("fail reason", errResp.Reason)) + err := merr.WrapErrIndexNotFound(req.GetIndexName()) + log.Warn("GetIndexState fail", zap.Error(err)) return &indexpb.GetIndexStateResponse{ - Status: errResp, + Status: merr.Status(err), }, nil } if len(indexes) > 1 { log.Warn(msgAmbiguousIndexName()) - errResp.ErrorCode = commonpb.ErrorCode_UnexpectedError - errResp.Reason = msgAmbiguousIndexName() + err := merr.WrapErrIndexDuplicate(req.GetIndexName()) return &indexpb.GetIndexStateResponse{ - Status: errResp, + Status: merr.Status(err), }, nil } ret := &indexpb.GetIndexStateResponse{ - Status: merr.Status(nil), + Status: merr.Success(), State: 
commonpb.IndexState_Finished, } - indexInfo := &indexpb.IndexInfo{ - IndexedRows: 0, - TotalRows: 0, - State: 0, - IndexStateFailReason: "", - } + indexInfo := &indexpb.IndexInfo{} s.completeIndexInfo(indexInfo, indexes[0], s.meta.SelectSegments(func(info *SegmentInfo) bool { return isFlush(info) && info.CollectionID == req.GetCollectionID() }), false, indexes[0].CreateTime) @@ -275,45 +252,40 @@ func (s *Server) GetIndexState(ctx context.Context, req *indexpb.GetIndexStateRe ret.FailReason = indexInfo.IndexStateFailReason log.Info("GetIndexState success", - zap.String("IndexName", req.GetIndexName()), zap.String("state", ret.GetState().String())) + zap.String("state", ret.GetState().String()), + ) return ret, nil } func (s *Server) GetSegmentIndexState(ctx context.Context, req *indexpb.GetSegmentIndexStateRequest) (*indexpb.GetSegmentIndexStateResponse, error) { log := log.Ctx(ctx).With( - zap.Int64("collectionID", req.CollectionID), + zap.Int64("collectionID", req.GetCollectionID()), ) log.Info("receive GetSegmentIndexState", - zap.String("IndexName", req.GetIndexName()), zap.Int64s("fieldID", req.GetSegmentIDs())) - errResp := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "", - } - if s.isClosed() { - log.Warn(msgDataCoordIsUnhealthy(paramtable.GetNodeID())) - errResp.ErrorCode = commonpb.ErrorCode_DataCoordNA - errResp.Reason = msgDataCoordIsUnhealthy(paramtable.GetNodeID()) + zap.String("IndexName", req.GetIndexName()), + zap.Int64s("fieldID", req.GetSegmentIDs()), + ) + + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + log.Warn(msgDataCoordIsUnhealthy(paramtable.GetNodeID()), zap.Error(err)) return &indexpb.GetSegmentIndexStateResponse{ - Status: errResp, + Status: merr.Status(err), }, nil } ret := &indexpb.GetSegmentIndexStateResponse{ - Status: merr.Status(nil), + Status: merr.Success(), States: make([]*indexpb.SegmentIndexState, 0), } indexID2CreateTs := s.meta.GetIndexIDByName(req.GetCollectionID(), 
req.GetIndexName()) if len(indexID2CreateTs) == 0 { - errMsg := fmt.Sprintf("there is no index on collection: %d with the index name: %s", req.CollectionID, req.GetIndexName()) - log.Warn("GetSegmentIndexState fail", zap.String("indexName", req.GetIndexName()), zap.String("fail reason", errMsg)) + err := merr.WrapErrIndexNotFound(req.GetIndexName()) + log.Warn("GetSegmentIndexState fail", zap.String("indexName", req.GetIndexName()), zap.Error(err)) return &indexpb.GetSegmentIndexStateResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_IndexNotExist, - Reason: errMsg, - }, + Status: merr.Status(err), }, nil } - for _, segID := range req.SegmentIDs { + for _, segID := range req.GetSegmentIDs() { state := s.meta.GetSegmentIndexState(req.GetCollectionID(), segID) ret.States = append(ret.States, &indexpb.SegmentIndexState{ SegmentID: segID, @@ -456,44 +428,35 @@ func (s *Server) completeIndexInfo(indexInfo *indexpb.IndexInfo, index *model.In // Deprecated func (s *Server) GetIndexBuildProgress(ctx context.Context, req *indexpb.GetIndexBuildProgressRequest) (*indexpb.GetIndexBuildProgressResponse, error) { log := log.Ctx(ctx).With( - zap.Int64("collectionID", req.CollectionID), + zap.Int64("collectionID", req.GetCollectionID()), ) log.Info("receive GetIndexBuildProgress request", zap.String("indexName", req.GetIndexName())) - errResp := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "", - } - if s.isClosed() { - log.Warn(msgDataCoordIsUnhealthy(paramtable.GetNodeID())) - errResp.ErrorCode = commonpb.ErrorCode_DataCoordNA - errResp.Reason = msgDataCoordIsUnhealthy(paramtable.GetNodeID()) + + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + log.Warn(msgDataCoordIsUnhealthy(paramtable.GetNodeID()), zap.Error(err)) return &indexpb.GetIndexBuildProgressResponse{ - Status: errResp, + Status: merr.Status(err), }, nil } indexes := s.meta.GetIndexesForCollection(req.GetCollectionID(), req.GetIndexName()) if len(indexes) 
== 0 { - errMsg := fmt.Sprintf("there is no index on collection: %d with the index name: %s", req.CollectionID, req.IndexName) - log.Warn("GetIndexBuildProgress fail", zap.String("indexName", req.IndexName), zap.String("fail reason", errMsg)) + err := merr.WrapErrIndexNotFound(req.GetIndexName()) + log.Warn("GetIndexBuildProgress fail", zap.String("indexName", req.IndexName), zap.Error(err)) return &indexpb.GetIndexBuildProgressResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_IndexNotExist, - Reason: errMsg, - }, + Status: merr.Status(err), }, nil } if len(indexes) > 1 { log.Warn(msgAmbiguousIndexName()) - errResp.ErrorCode = commonpb.ErrorCode_UnexpectedError - errResp.Reason = msgAmbiguousIndexName() + err := merr.WrapErrIndexDuplicate(req.GetIndexName()) return &indexpb.GetIndexBuildProgressResponse{ - Status: errResp, + Status: merr.Status(err), }, nil } indexInfo := &indexpb.IndexInfo{ - CollectionID: req.CollectionID, + CollectionID: req.GetCollectionID(), IndexID: indexes[0].IndexID, IndexedRows: 0, TotalRows: 0, @@ -506,7 +469,7 @@ func (s *Server) GetIndexBuildProgress(ctx context.Context, req *indexpb.GetInde log.Info("GetIndexBuildProgress success", zap.Int64("collectionID", req.GetCollectionID()), zap.String("indexName", req.GetIndexName())) return &indexpb.GetIndexBuildProgressResponse{ - Status: merr.Status(nil), + Status: merr.Success(), IndexedRows: indexInfo.IndexedRows, TotalRows: indexInfo.TotalRows, PendingIndexRows: indexInfo.PendingIndexRows, @@ -516,32 +479,26 @@ func (s *Server) GetIndexBuildProgress(ctx context.Context, req *indexpb.GetInde // DescribeIndex describe the index info of the collection. 
func (s *Server) DescribeIndex(ctx context.Context, req *indexpb.DescribeIndexRequest) (*indexpb.DescribeIndexResponse, error) { log := log.Ctx(ctx).With( - zap.Int64("collectionID", req.CollectionID), + zap.Int64("collectionID", req.GetCollectionID()), + zap.String("indexName", req.GetIndexName()), ) - log.Info("receive DescribeIndex request", zap.String("indexName", req.GetIndexName()), - zap.Uint64("timestamp", req.GetTimestamp())) - errResp := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "", - } - if s.isClosed() { - log.Warn(msgDataCoordIsUnhealthy(paramtable.GetNodeID())) - errResp.ErrorCode = commonpb.ErrorCode_DataCoordNA - errResp.Reason = msgDataCoordIsUnhealthy(paramtable.GetNodeID()) + log.Info("receive DescribeIndex request", + zap.Uint64("timestamp", req.GetTimestamp()), + ) + + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + log.Warn(msgDataCoordIsUnhealthy(paramtable.GetNodeID()), zap.Error(err)) return &indexpb.DescribeIndexResponse{ - Status: errResp, + Status: merr.Status(err), }, nil } indexes := s.meta.GetIndexesForCollection(req.GetCollectionID(), req.GetIndexName()) if len(indexes) == 0 { - errMsg := fmt.Sprintf("there is no index on collection: %d with the index name: %s", req.CollectionID, req.IndexName) - log.Warn("DescribeIndex fail", zap.String("indexName", req.IndexName), zap.String("fail reason", errMsg)) + err := merr.WrapErrIndexNotFound(req.GetIndexName()) + log.Warn("DescribeIndex fail", zap.Error(err)) return &indexpb.DescribeIndexResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_IndexNotExist, - Reason: fmt.Sprint("index doesn't exist, collectionID ", req.CollectionID), - }, + Status: merr.Status(err), }, nil } @@ -572,9 +529,9 @@ func (s *Server) DescribeIndex(ctx context.Context, req *indexpb.DescribeIndexRe s.completeIndexInfo(indexInfo, index, segments, false, createTs) indexInfos = append(indexInfos, indexInfo) } - log.Info("DescribeIndex success", 
zap.String("indexName", req.GetIndexName())) + log.Info("DescribeIndex success") return &indexpb.DescribeIndexResponse{ - Status: merr.Status(nil), + Status: merr.Success(), IndexInfos: indexInfos, }, nil } @@ -582,27 +539,24 @@ func (s *Server) DescribeIndex(ctx context.Context, req *indexpb.DescribeIndexRe // GetIndexStatistics get the statistics of the index. DescribeIndex doesn't contain statistics. func (s *Server) GetIndexStatistics(ctx context.Context, req *indexpb.GetIndexStatisticsRequest) (*indexpb.GetIndexStatisticsResponse, error) { log := log.Ctx(ctx).With( - zap.Int64("collectionID", req.CollectionID), + zap.Int64("collectionID", req.GetCollectionID()), ) log.Info("receive GetIndexStatistics request", zap.String("indexName", req.GetIndexName())) - if s.isClosed() { - log.Warn(msgDataCoordIsUnhealthy(s.serverID())) + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + log.Warn(msgDataCoordIsUnhealthy(paramtable.GetNodeID()), zap.Error(err)) return &indexpb.GetIndexStatisticsResponse{ - Status: s.UnhealthyStatus(), + Status: merr.Status(err), }, nil } indexes := s.meta.GetIndexesForCollection(req.GetCollectionID(), req.GetIndexName()) if len(indexes) == 0 { - errMsg := fmt.Sprintf("there is no index on collection: %d with the index name: %s", req.CollectionID, req.IndexName) + err := merr.WrapErrIndexNotFound(req.GetIndexName()) log.Warn("GetIndexStatistics fail", - zap.String("indexName", req.IndexName), - zap.String("fail reason", errMsg)) + zap.String("indexName", req.GetIndexName()), + zap.Error(err)) return &indexpb.GetIndexStatisticsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_IndexNotExist, - Reason: fmt.Sprint("index doesn't exist, collectionID ", req.CollectionID), - }, + Status: merr.Status(err), }, nil } @@ -632,7 +586,7 @@ func (s *Server) GetIndexStatistics(ctx context.Context, req *indexpb.GetIndexSt log.Debug("GetIndexStatisticsResponse success", zap.String("indexName", req.GetIndexName())) return 
&indexpb.GetIndexStatisticsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), IndexInfos: indexInfos, }, nil } @@ -642,82 +596,69 @@ func (s *Server) GetIndexStatistics(ctx context.Context, req *indexpb.GetIndexSt // index tasks. func (s *Server) DropIndex(ctx context.Context, req *indexpb.DropIndexRequest) (*commonpb.Status, error) { log := log.Ctx(ctx).With( - zap.Int64("collectionID", req.CollectionID), + zap.Int64("collectionID", req.GetCollectionID()), ) log.Info("receive DropIndex request", zap.Int64s("partitionIDs", req.GetPartitionIDs()), zap.String("indexName", req.GetIndexName()), zap.Bool("drop all indexes", req.GetDropAll())) - errResp := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "", - } - if s.isClosed() { - log.Warn(msgDataCoordIsUnhealthy(paramtable.GetNodeID())) - errResp.ErrorCode = commonpb.ErrorCode_DataCoordNA - errResp.Reason = msgDataCoordIsUnhealthy(paramtable.GetNodeID()) - return errResp, nil - } - ret := merr.Status(nil) + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + log.Warn(msgDataCoordIsUnhealthy(paramtable.GetNodeID()), zap.Error(err)) + return merr.Status(err), nil + } indexes := s.meta.GetIndexesForCollection(req.GetCollectionID(), req.GetIndexName()) if len(indexes) == 0 { log.Info(fmt.Sprintf("there is no index on collection: %d with the index name: %s", req.CollectionID, req.IndexName)) - return ret, nil + return merr.Success(), nil } if !req.GetDropAll() && len(indexes) > 1 { log.Warn(msgAmbiguousIndexName()) - ret.ErrorCode = commonpb.ErrorCode_UnexpectedError - ret.Reason = msgAmbiguousIndexName() - return ret, nil + err := merr.WrapErrIndexDuplicate(req.GetIndexName()) + return merr.Status(err), nil } indexIDs := make([]UniqueID, 0) for _, index := range indexes { indexIDs = append(indexIDs, index.IndexID) } + // Compatibility logic. 
To prevent the index on the corresponding segments + // from being dropped at the same time when dropping_partition in version 2.1 if len(req.GetPartitionIDs()) == 0 { // drop collection index - err := s.meta.MarkIndexAsDeleted(req.CollectionID, indexIDs) + err := s.meta.MarkIndexAsDeleted(req.GetCollectionID(), indexIDs) if err != nil { log.Warn("DropIndex fail", zap.String("indexName", req.IndexName), zap.Error(err)) - ret.ErrorCode = commonpb.ErrorCode_UnexpectedError - ret.Reason = err.Error() - return ret, nil + return merr.Status(err), nil } } - log.Debug("DropIndex success", zap.Int64s("partitionIDs", req.PartitionIDs), zap.String("indexName", req.IndexName), - zap.Int64s("indexIDs", indexIDs)) - return ret, nil + log.Debug("DropIndex success", zap.Int64s("partitionIDs", req.GetPartitionIDs()), + zap.String("indexName", req.GetIndexName()), zap.Int64s("indexIDs", indexIDs)) + return merr.Success(), nil } // GetIndexInfos gets the index file paths for segment from DataCoord. func (s *Server) GetIndexInfos(ctx context.Context, req *indexpb.GetIndexInfoRequest) (*indexpb.GetIndexInfoResponse, error) { log := log.Ctx(ctx).With( - zap.Int64("collectionID", req.CollectionID), + zap.Int64("collectionID", req.GetCollectionID()), ) - errResp := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "", - } - if s.isClosed() { - log.Warn(msgDataCoordIsUnhealthy(paramtable.GetNodeID())) - errResp.ErrorCode = commonpb.ErrorCode_DataCoordNA - errResp.Reason = msgDataCoordIsUnhealthy(paramtable.GetNodeID()) + + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + log.Warn(msgDataCoordIsUnhealthy(paramtable.GetNodeID()), zap.Error(err)) return &indexpb.GetIndexInfoResponse{ - Status: errResp, + Status: merr.Status(err), }, nil } ret := &indexpb.GetIndexInfoResponse{ - Status: merr.Status(nil), + Status: merr.Success(), SegmentInfo: map[int64]*indexpb.SegmentInfo{}, } - for _, segID := range req.SegmentIDs { + for _, segID := range 
req.GetSegmentIDs() { segIdxes := s.meta.GetSegmentIndexes(segID) ret.SegmentInfo[segID] = &indexpb.SegmentInfo{ - CollectionID: req.CollectionID, + CollectionID: req.GetCollectionID(), SegmentID: segID, EnableIndex: false, IndexInfos: make([]*indexpb.IndexFilePathInfo, 0), @@ -732,16 +673,17 @@ func (s *Server) GetIndexInfos(ctx context.Context, req *indexpb.GetIndexInfoReq indexParams = append(indexParams, s.meta.GetTypeParams(segIdx.CollectionID, segIdx.IndexID)...) ret.SegmentInfo[segID].IndexInfos = append(ret.SegmentInfo[segID].IndexInfos, &indexpb.IndexFilePathInfo{ - SegmentID: segID, - FieldID: s.meta.GetFieldIDByIndexID(segIdx.CollectionID, segIdx.IndexID), - IndexID: segIdx.IndexID, - BuildID: segIdx.BuildID, - IndexName: s.meta.GetIndexNameByID(segIdx.CollectionID, segIdx.IndexID), - IndexParams: indexParams, - IndexFilePaths: indexFilePaths, - SerializedSize: segIdx.IndexSize, - IndexVersion: segIdx.IndexVersion, - NumRows: segIdx.NumRows, + SegmentID: segID, + FieldID: s.meta.GetFieldIDByIndexID(segIdx.CollectionID, segIdx.IndexID), + IndexID: segIdx.IndexID, + BuildID: segIdx.BuildID, + IndexName: s.meta.GetIndexNameByID(segIdx.CollectionID, segIdx.IndexID), + IndexParams: indexParams, + IndexFilePaths: indexFilePaths, + SerializedSize: segIdx.IndexSize, + IndexVersion: segIdx.IndexVersion, + NumRows: segIdx.NumRows, + CurrentIndexVersion: segIdx.CurrentIndexVersion, }) } } @@ -752,9 +694,3 @@ func (s *Server) GetIndexInfos(ctx context.Context, req *indexpb.GetIndexInfoReq return ret, nil } - -func (s *Server) UnhealthyStatus() *commonpb.Status { - return merr.Status( - merr.WrapErrServiceNotReady( - fmt.Sprintf("datacoord %d is unhealthy", s.serverID()))) -} diff --git a/internal/datacoord/index_service_test.go b/internal/datacoord/index_service_test.go index 5dfc182996002..b75f292326aa0 100644 --- a/internal/datacoord/index_service_test.go +++ b/internal/datacoord/index_service_test.go @@ -27,7 +27,6 @@ import ( 
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" - "github.com/milvus-io/milvus/internal/kv/mocks" mockkv "github.com/milvus-io/milvus/internal/kv/mocks" "github.com/milvus-io/milvus/internal/metastore/kv/datacoord" catalogmocks "github.com/milvus-io/milvus/internal/metastore/mocks" @@ -37,10 +36,11 @@ import ( "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/merr" ) func TestServerId(t *testing.T) { - s := &Server{session: &sessionutil.Session{ServerID: 0}} + s := &Server{session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 0}}} assert.Equal(t, int64(0), s.serverID()) } @@ -48,7 +48,7 @@ func TestServer_CreateIndex(t *testing.T) { var ( collID = UniqueID(1) fieldID = UniqueID(10) - //indexID = UniqueID(100) + // indexID = UniqueID(100) indexName = "default_idx" typeParams = []*commonpb.KeyValuePair{ { @@ -100,7 +100,7 @@ func TestServer_CreateIndex(t *testing.T) { s.stateCode.Store(commonpb.StateCode_Abnormal) resp, err := s.CreateIndex(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_DataCoordNA, resp.GetErrorCode()) + assert.ErrorIs(t, merr.Error(resp), merr.ErrServiceNotReady) }) t.Run("index not consistent", func(t *testing.T) { @@ -182,7 +182,7 @@ func TestServer_GetIndexState(t *testing.T) { ) s := &Server{ meta: &meta{ - catalog: &datacoord.Catalog{MetaKv: mocks.NewMetaKv(t)}, + catalog: &datacoord.Catalog{MetaKv: mockkv.NewMetaKv(t)}, }, allocator: newMockAllocator(), notifyIndexChan: make(chan UniqueID, 1), @@ -192,18 +192,18 @@ func TestServer_GetIndexState(t *testing.T) { s.stateCode.Store(commonpb.StateCode_Initializing) resp, err := s.GetIndexState(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_DataCoordNA, resp.GetStatus().GetErrorCode()) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), 
merr.ErrServiceNotReady) }) s.stateCode.Store(commonpb.StateCode_Healthy) - t.Run("index not exist", func(t *testing.T) { + t.Run("index not found", func(t *testing.T) { resp, err := s.GetIndexState(ctx, req) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_IndexNotExist, resp.GetStatus().GetErrorCode()) }) s.meta = &meta{ - catalog: &datacoord.Catalog{MetaKv: mocks.NewMetaKv(t)}, + catalog: &datacoord.Catalog{MetaKv: mockkv.NewMetaKv(t)}, indexes: map[UniqueID]map[UniqueID]*model.Index{ collID: { indexID: { @@ -254,7 +254,7 @@ func TestServer_GetIndexState(t *testing.T) { }) s.meta = &meta{ - catalog: &datacoord.Catalog{MetaKv: mocks.NewMetaKv(t)}, + catalog: &datacoord.Catalog{MetaKv: mockkv.NewMetaKv(t)}, indexes: map[UniqueID]map[UniqueID]*model.Index{ collID: { indexID: { @@ -372,7 +372,7 @@ func TestServer_GetSegmentIndexState(t *testing.T) { ) s := &Server{ meta: &meta{ - catalog: &datacoord.Catalog{MetaKv: mocks.NewMetaKv(t)}, + catalog: &datacoord.Catalog{MetaKv: mockkv.NewMetaKv(t)}, indexes: map[UniqueID]map[UniqueID]*model.Index{}, segments: &SegmentsInfo{map[UniqueID]*SegmentInfo{}}, }, @@ -384,7 +384,7 @@ func TestServer_GetSegmentIndexState(t *testing.T) { s.stateCode.Store(commonpb.StateCode_Abnormal) resp, err := s.GetSegmentIndexState(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_DataCoordNA, resp.GetStatus().GetErrorCode()) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) }) t.Run("no indexes", func(t *testing.T) { @@ -507,7 +507,7 @@ func TestServer_GetIndexBuildProgress(t *testing.T) { s := &Server{ meta: &meta{ - catalog: &datacoord.Catalog{MetaKv: mocks.NewMetaKv(t)}, + catalog: &datacoord.Catalog{MetaKv: mockkv.NewMetaKv(t)}, indexes: map[UniqueID]map[UniqueID]*model.Index{}, segments: &SegmentsInfo{map[UniqueID]*SegmentInfo{}}, }, @@ -518,7 +518,7 @@ func TestServer_GetIndexBuildProgress(t *testing.T) { s.stateCode.Store(commonpb.StateCode_Initializing) resp, err := 
s.GetIndexBuildProgress(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_DataCoordNA, resp.GetStatus().GetErrorCode()) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) }) t.Run("no indexes", func(t *testing.T) { @@ -706,7 +706,7 @@ func TestServer_DescribeIndex(t *testing.T) { catalog: catalog, indexes: map[UniqueID]map[UniqueID]*model.Index{ collID: { - //finished + // finished indexID: { TenantID: "", CollectionID: collID, @@ -998,7 +998,7 @@ func TestServer_DescribeIndex(t *testing.T) { s.stateCode.Store(commonpb.StateCode_Initializing) resp, err := s.DescribeIndex(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_DataCoordNA, resp.GetStatus().GetErrorCode()) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) }) s.stateCode.Store(commonpb.StateCode_Healthy) @@ -1067,7 +1067,7 @@ func TestServer_GetIndexStatistics(t *testing.T) { catalog: catalog, indexes: map[UniqueID]map[UniqueID]*model.Index{ collID: { - //finished + // finished indexID: { TenantID: "", CollectionID: collID, @@ -1347,7 +1347,7 @@ func TestServer_DropIndex(t *testing.T) { catalog: catalog, indexes: map[UniqueID]map[UniqueID]*model.Index{ collID: { - //finished + // finished indexID: { TenantID: "", CollectionID: collID, @@ -1442,7 +1442,7 @@ func TestServer_DropIndex(t *testing.T) { s.stateCode.Store(commonpb.StateCode_Initializing) resp, err := s.DropIndex(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_DataCoordNA, resp.GetErrorCode()) + assert.ErrorIs(t, merr.Error(resp), merr.ErrServiceNotReady) }) s.stateCode.Store(commonpb.StateCode_Healthy) @@ -1539,10 +1539,10 @@ func TestServer_GetIndexInfos(t *testing.T) { s := &Server{ meta: &meta{ - catalog: &datacoord.Catalog{MetaKv: mocks.NewMetaKv(t)}, + catalog: &datacoord.Catalog{MetaKv: mockkv.NewMetaKv(t)}, indexes: map[UniqueID]map[UniqueID]*model.Index{ collID: { - //finished + // finished indexID: { TenantID: "", CollectionID: 
collID, @@ -1602,7 +1602,7 @@ func TestServer_GetIndexInfos(t *testing.T) { s.stateCode.Store(commonpb.StateCode_Initializing) resp, err := s.GetIndexInfos(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_DataCoordNA, resp.GetStatus().GetErrorCode()) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) }) s.stateCode.Store(commonpb.StateCode_Healthy) diff --git a/internal/datacoord/indexnode_manager.go b/internal/datacoord/indexnode_manager.go index ee676c975d170..7a6721f72b766 100644 --- a/internal/datacoord/indexnode_manager.go +++ b/internal/datacoord/indexnode_manager.go @@ -33,7 +33,7 @@ import ( // IndexNodeManager is used to manage the client of IndexNode. type IndexNodeManager struct { - nodeClients map[UniqueID]types.IndexNode + nodeClients map[UniqueID]types.IndexNodeClient stoppingNodes map[UniqueID]struct{} lock sync.RWMutex ctx context.Context @@ -43,7 +43,7 @@ type IndexNodeManager struct { // NewNodeManager is used to create a new IndexNodeManager. func NewNodeManager(ctx context.Context, indexNodeCreator indexNodeCreatorFunc) *IndexNodeManager { return &IndexNodeManager{ - nodeClients: make(map[UniqueID]types.IndexNode), + nodeClients: make(map[UniqueID]types.IndexNodeClient), stoppingNodes: make(map[UniqueID]struct{}), lock: sync.RWMutex{}, ctx: ctx, @@ -52,7 +52,7 @@ func NewNodeManager(ctx context.Context, indexNodeCreator indexNodeCreatorFunc) } // setClient sets IndexNode client to node manager. 
-func (nm *IndexNodeManager) setClient(nodeID UniqueID, client types.IndexNode) { +func (nm *IndexNodeManager) setClient(nodeID UniqueID, client types.IndexNodeClient) { log.Debug("set IndexNode client", zap.Int64("nodeID", nodeID)) nm.lock.Lock() defer nm.lock.Unlock() @@ -82,7 +82,7 @@ func (nm *IndexNodeManager) StoppingNode(nodeID UniqueID) { func (nm *IndexNodeManager) AddNode(nodeID UniqueID, address string) error { log.Debug("add IndexNode", zap.Int64("nodeID", nodeID), zap.String("node address", address)) var ( - nodeClient types.IndexNode + nodeClient types.IndexNodeClient err error ) @@ -97,7 +97,7 @@ func (nm *IndexNodeManager) AddNode(nodeID UniqueID, address string) error { } // PeekClient peeks the client with the least load. -func (nm *IndexNodeManager) PeekClient(meta *model.SegmentIndex) (UniqueID, types.IndexNode) { +func (nm *IndexNodeManager) PeekClient(meta *model.SegmentIndex) (UniqueID, types.IndexNodeClient) { allClients := nm.GetAllClients() if len(allClients) == 0 { log.Error("there is no IndexNode online") @@ -123,12 +123,12 @@ func (nm *IndexNodeManager) PeekClient(meta *model.SegmentIndex) (UniqueID, type log.Warn("get IndexNode slots failed", zap.Int64("nodeID", nodeID), zap.Error(err)) return } - if resp.Status.ErrorCode != commonpb.ErrorCode_Success { + if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { log.Warn("get IndexNode slots failed", zap.Int64("nodeID", nodeID), - zap.String("reason", resp.Status.Reason)) + zap.String("reason", resp.GetStatus().GetReason())) return } - if resp.TaskSlots > 0 { + if resp.GetTaskSlots() > 0 { nodeMutex.Lock() defer nodeMutex.Unlock() log.Info("peek client success", zap.Int64("nodeID", nodeID)) @@ -179,13 +179,13 @@ func (nm *IndexNodeManager) ClientSupportDisk() bool { log.Warn("get IndexNode slots failed", zap.Int64("nodeID", nodeID), zap.Error(err)) return } - if resp.Status.ErrorCode != commonpb.ErrorCode_Success { + if resp.GetStatus().GetErrorCode() != 
commonpb.ErrorCode_Success { log.Warn("get IndexNode slots failed", zap.Int64("nodeID", nodeID), - zap.String("reason", resp.Status.Reason)) + zap.String("reason", resp.GetStatus().GetReason())) return } - log.Debug("get job stats success", zap.Int64("nodeID", nodeID), zap.Bool("enable disk", resp.EnableDisk)) - if resp.EnableDisk { + log.Debug("get job stats success", zap.Int64("nodeID", nodeID), zap.Bool("enable disk", resp.GetEnableDisk())) + if resp.GetEnableDisk() { nodeMutex.Lock() defer nodeMutex.Unlock() cancel() @@ -207,11 +207,11 @@ func (nm *IndexNodeManager) ClientSupportDisk() bool { return false } -func (nm *IndexNodeManager) GetAllClients() map[UniqueID]types.IndexNode { +func (nm *IndexNodeManager) GetAllClients() map[UniqueID]types.IndexNodeClient { nm.lock.RLock() defer nm.lock.RUnlock() - allClients := make(map[UniqueID]types.IndexNode, len(nm.nodeClients)) + allClients := make(map[UniqueID]types.IndexNodeClient, len(nm.nodeClients)) for nodeID, client := range nm.nodeClients { if _, ok := nm.stoppingNodes[nodeID]; !ok { allClients[nodeID] = client @@ -221,7 +221,7 @@ func (nm *IndexNodeManager) GetAllClients() map[UniqueID]types.IndexNode { return allClients } -func (nm *IndexNodeManager) GetClientByID(nodeID UniqueID) (types.IndexNode, bool) { +func (nm *IndexNodeManager) GetClientByID(nodeID UniqueID) (types.IndexNodeClient, bool) { nm.lock.RLock() defer nm.lock.RUnlock() @@ -237,7 +237,7 @@ type indexNodeGetMetricsResponse struct { // getMetrics get metrics information of all IndexNode. 
func (nm *IndexNodeManager) getMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) []indexNodeGetMetricsResponse { - var clients []types.IndexNode + var clients []types.IndexNodeClient nm.lock.RLock() for _, node := range nm.nodeClients { clients = append(clients, node) diff --git a/internal/datacoord/indexnode_manager_test.go b/internal/datacoord/indexnode_manager_test.go index ddb82196bdbc8..6abad8c191106 100644 --- a/internal/datacoord/indexnode_manager_test.go +++ b/internal/datacoord/indexnode_manager_test.go @@ -22,13 +22,14 @@ import ( "testing" "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus/internal/indexnode" "github.com/milvus-io/milvus/internal/metastore/model" + "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/types" - "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus/pkg/util/merr" ) func TestIndexNodeManager_AddNode(t *testing.T) { @@ -49,98 +50,47 @@ func TestIndexNodeManager_AddNode(t *testing.T) { } func TestIndexNodeManager_PeekClient(t *testing.T) { + getMockedGetJobStatsClient := func(resp *indexpb.GetJobStatsResponse, err error) types.IndexNodeClient { + ic := mocks.NewMockIndexNodeClient(t) + ic.EXPECT().GetJobStats(mock.Anything, mock.Anything, mock.Anything).Return(resp, err) + return ic + } + + err := errors.New("error") + t.Run("multiple unavailable IndexNode", func(t *testing.T) { nm := &IndexNodeManager{ ctx: context.TODO(), - nodeClients: map[UniqueID]types.IndexNode{ - 1: &indexnode.Mock{ - CallGetJobStats: func(ctx context.Context, req *indexpb.GetJobStatsRequest) (*indexpb.GetJobStatsResponse, error) { - return &indexpb.GetJobStatsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - }, - }, errors.New("error") - }, - }, - 2: 
&indexnode.Mock{ - CallGetJobStats: func(ctx context.Context, req *indexpb.GetJobStatsRequest) (*indexpb.GetJobStatsResponse, error) { - return &indexpb.GetJobStatsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - }, - }, errors.New("error") - }, - }, - 3: &indexnode.Mock{ - CallGetJobStats: func(ctx context.Context, req *indexpb.GetJobStatsRequest) (*indexpb.GetJobStatsResponse, error) { - return &indexpb.GetJobStatsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - }, - }, errors.New("error") - }, - }, - 4: &indexnode.Mock{ - CallGetJobStats: func(ctx context.Context, req *indexpb.GetJobStatsRequest) (*indexpb.GetJobStatsResponse, error) { - return &indexpb.GetJobStatsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - }, - }, errors.New("error") - }, - }, - 5: &indexnode.Mock{ - CallGetJobStats: func(ctx context.Context, req *indexpb.GetJobStatsRequest) (*indexpb.GetJobStatsResponse, error) { - return &indexpb.GetJobStatsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "fail reason", - }, - }, nil - }, - }, - 6: &indexnode.Mock{ - CallGetJobStats: func(ctx context.Context, req *indexpb.GetJobStatsRequest) (*indexpb.GetJobStatsResponse, error) { - return &indexpb.GetJobStatsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "fail reason", - }, - }, nil - }, - }, - 7: &indexnode.Mock{ - CallGetJobStats: func(ctx context.Context, req *indexpb.GetJobStatsRequest) (*indexpb.GetJobStatsResponse, error) { - return &indexpb.GetJobStatsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "fail reason", - }, - }, nil - }, - }, - 8: &indexnode.Mock{ - CallGetJobStats: func(ctx context.Context, req *indexpb.GetJobStatsRequest) (*indexpb.GetJobStatsResponse, error) { - return &indexpb.GetJobStatsResponse{ - TaskSlots: 1, - 
Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, - }, nil - }, - }, - 9: &indexnode.Mock{ - CallGetJobStats: func(ctx context.Context, req *indexpb.GetJobStatsRequest) (*indexpb.GetJobStatsResponse, error) { - return &indexpb.GetJobStatsResponse{ - TaskSlots: 10, - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, - }, nil - }, - }, + nodeClients: map[UniqueID]types.IndexNodeClient{ + 1: getMockedGetJobStatsClient(&indexpb.GetJobStatsResponse{ + Status: merr.Status(err), + }, err), + 2: getMockedGetJobStatsClient(&indexpb.GetJobStatsResponse{ + Status: merr.Status(err), + }, err), + 3: getMockedGetJobStatsClient(&indexpb.GetJobStatsResponse{ + Status: merr.Status(err), + }, err), + 4: getMockedGetJobStatsClient(&indexpb.GetJobStatsResponse{ + Status: merr.Status(err), + }, err), + 5: getMockedGetJobStatsClient(&indexpb.GetJobStatsResponse{ + Status: merr.Status(err), + }, nil), + 6: getMockedGetJobStatsClient(&indexpb.GetJobStatsResponse{ + Status: merr.Status(err), + }, nil), + 7: getMockedGetJobStatsClient(&indexpb.GetJobStatsResponse{ + Status: merr.Status(err), + }, nil), + 8: getMockedGetJobStatsClient(&indexpb.GetJobStatsResponse{ + TaskSlots: 1, + Status: merr.Success(), + }, nil), + 9: getMockedGetJobStatsClient(&indexpb.GetJobStatsResponse{ + TaskSlots: 10, + Status: merr.Success(), + }, nil), }, } @@ -151,24 +101,25 @@ func TestIndexNodeManager_PeekClient(t *testing.T) { } func TestIndexNodeManager_ClientSupportDisk(t *testing.T) { + getMockedGetJobStatsClient := func(resp *indexpb.GetJobStatsResponse, err error) types.IndexNodeClient { + ic := mocks.NewMockIndexNodeClient(t) + ic.EXPECT().GetJobStats(mock.Anything, mock.Anything, mock.Anything).Return(resp, err) + return ic + } + + err := errors.New("error") + t.Run("support", func(t *testing.T) { nm := &IndexNodeManager{ ctx: context.Background(), lock: sync.RWMutex{}, - nodeClients: map[UniqueID]types.IndexNode{ - 1: 
&indexnode.Mock{ - CallGetJobStats: func(ctx context.Context, in *indexpb.GetJobStatsRequest) (*indexpb.GetJobStatsResponse, error) { - return &indexpb.GetJobStatsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, - TaskSlots: 1, - JobInfos: nil, - EnableDisk: true, - }, nil - }, - }, + nodeClients: map[UniqueID]types.IndexNodeClient{ + 1: getMockedGetJobStatsClient(&indexpb.GetJobStatsResponse{ + Status: merr.Success(), + TaskSlots: 1, + JobInfos: nil, + EnableDisk: true, + }, nil), }, } @@ -180,20 +131,13 @@ func TestIndexNodeManager_ClientSupportDisk(t *testing.T) { nm := &IndexNodeManager{ ctx: context.Background(), lock: sync.RWMutex{}, - nodeClients: map[UniqueID]types.IndexNode{ - 1: &indexnode.Mock{ - CallGetJobStats: func(ctx context.Context, in *indexpb.GetJobStatsRequest) (*indexpb.GetJobStatsResponse, error) { - return &indexpb.GetJobStatsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, - TaskSlots: 1, - JobInfos: nil, - EnableDisk: false, - }, nil - }, - }, + nodeClients: map[UniqueID]types.IndexNodeClient{ + 1: getMockedGetJobStatsClient(&indexpb.GetJobStatsResponse{ + Status: merr.Success(), + TaskSlots: 1, + JobInfos: nil, + EnableDisk: false, + }, nil), }, } @@ -205,7 +149,7 @@ func TestIndexNodeManager_ClientSupportDisk(t *testing.T) { nm := &IndexNodeManager{ ctx: context.Background(), lock: sync.RWMutex{}, - nodeClients: map[UniqueID]types.IndexNode{}, + nodeClients: map[UniqueID]types.IndexNodeClient{}, } support := nm.ClientSupportDisk() @@ -216,12 +160,8 @@ func TestIndexNodeManager_ClientSupportDisk(t *testing.T) { nm := &IndexNodeManager{ ctx: context.Background(), lock: sync.RWMutex{}, - nodeClients: map[UniqueID]types.IndexNode{ - 1: &indexnode.Mock{ - CallGetJobStats: func(ctx context.Context, in *indexpb.GetJobStatsRequest) (*indexpb.GetJobStatsResponse, error) { - return nil, errors.New("error") - }, - }, + nodeClients: 
map[UniqueID]types.IndexNodeClient{ + 1: getMockedGetJobStatsClient(nil, err), }, } @@ -233,20 +173,13 @@ func TestIndexNodeManager_ClientSupportDisk(t *testing.T) { nm := &IndexNodeManager{ ctx: context.Background(), lock: sync.RWMutex{}, - nodeClients: map[UniqueID]types.IndexNode{ - 1: &indexnode.Mock{ - CallGetJobStats: func(ctx context.Context, in *indexpb.GetJobStatsRequest) (*indexpb.GetJobStatsResponse, error) { - return &indexpb.GetJobStatsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "fail reason", - }, - TaskSlots: 0, - JobInfos: nil, - EnableDisk: false, - }, nil - }, - }, + nodeClients: map[UniqueID]types.IndexNodeClient{ + 1: getMockedGetJobStatsClient(&indexpb.GetJobStatsResponse{ + Status: merr.Status(err), + TaskSlots: 0, + JobInfos: nil, + EnableDisk: false, + }, nil), }, } diff --git a/internal/datacoord/meta.go b/internal/datacoord/meta.go index 70ca3ef056ba9..7e078b6aa844f 100644 --- a/internal/datacoord/meta.go +++ b/internal/datacoord/meta.go @@ -24,8 +24,8 @@ import ( "sync" "time" + "github.com/cockroachdb/errors" "github.com/golang/protobuf/proto" - "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/samber/lo" "go.uber.org/zap" "golang.org/x/exp/maps" @@ -43,6 +43,7 @@ import ( "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/metautil" "github.com/milvus-io/milvus/pkg/util/timerecord" + "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -297,8 +298,9 @@ func (m *meta) GetCollectionBinlogSize() (int64, map[UniqueID]int64) { } // AddSegment records segment info, persisting info into kv store -func (m *meta) AddSegment(segment *SegmentInfo) error { - log.Debug("meta update: adding segment", zap.Int64("segmentID", segment.GetID())) +func (m *meta) AddSegment(ctx context.Context, segment *SegmentInfo) error { + log := log.Ctx(ctx) + log.Info("meta update: adding segment - Start", zap.Int64("segmentID", 
segment.GetID())) m.Lock() defer m.Unlock() if err := m.catalog.AddSegment(m.ctx, segment.SegmentInfo); err != nil { @@ -627,7 +629,7 @@ func (m *meta) UpdateFlushSegmentsInfo( } // TODO add diff encoding and compression currBinlogs := clonedSegment.GetBinlogs() - var getFieldBinlogs = func(id UniqueID, binlogs []*datapb.FieldBinlog) *datapb.FieldBinlog { + getFieldBinlogs := func(id UniqueID, binlogs []*datapb.FieldBinlog) *datapb.FieldBinlog { for _, binlog := range binlogs { if id == binlog.GetFieldID() { return binlog @@ -668,7 +670,7 @@ func (m *meta) UpdateFlushSegmentsInfo( } clonedSegment.Deltalogs = currDeltaLogs modSegments[segmentID] = clonedSegment - var getClonedSegment = func(segmentID UniqueID) *SegmentInfo { + getClonedSegment := func(segmentID UniqueID) *SegmentInfo { if s, ok := modSegments[segmentID]; ok { return s } @@ -822,7 +824,7 @@ func (m *meta) mergeDropSegment(seg2Drop *SegmentInfo) (*SegmentInfo, *segMetric currBinlogs := clonedSegment.GetBinlogs() - var getFieldBinlogs = func(id UniqueID, binlogs []*datapb.FieldBinlog) *datapb.FieldBinlog { + getFieldBinlogs := func(id UniqueID, binlogs []*datapb.FieldBinlog) *datapb.FieldBinlog { for _, binlog := range binlogs { if id == binlog.GetFieldID() { return binlog @@ -1061,8 +1063,8 @@ func (m *meta) AddAllocation(segmentID UniqueID, allocation *Allocation) error { curSegInfo := m.segments.GetSegment(segmentID) if curSegInfo == nil { // TODO: Error handling. - log.Warn("meta update: add allocation failed - segment not found", zap.Int64("segmentID", segmentID)) - return nil + log.Error("meta update: add allocation failed - segment not found", zap.Int64("segmentID", segmentID)) + return errors.New("meta update: add allocation failed - segment not found") } // As we use global segment lastExpire to guarantee data correctness after restart // there is no need to persist allocation to meta store, only update allocation in-memory meta. 
@@ -1119,7 +1121,8 @@ func (m *meta) SetSegmentCompacting(segmentID UniqueID, compacting bool) { // - the segment info of compactedTo segment after compaction to add // The compactedTo segment could contain 0 numRows func (m *meta) PrepareCompleteCompactionMutation(plan *datapb.CompactionPlan, - result *datapb.CompactionResult) ([]*SegmentInfo, []*SegmentInfo, *SegmentInfo, *segMetricMutation, error) { + result *datapb.CompactionResult, +) ([]*SegmentInfo, []*SegmentInfo, *SegmentInfo, *segMetricMutation, error) { log.Info("meta update: prepare for complete compaction mutation") compactionLogs := plan.GetSegmentBinlogs() m.Lock() diff --git a/internal/datacoord/meta_test.go b/internal/datacoord/meta_test.go index 6f6b016333586..87b122c3dfced 100644 --- a/internal/datacoord/meta_test.go +++ b/internal/datacoord/meta_test.go @@ -31,6 +31,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/internal/kv" + mockkv "github.com/milvus-io/milvus/internal/kv/mocks" "github.com/milvus-io/milvus/internal/metastore/kv/datacoord" "github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/internal/mocks" @@ -39,8 +40,6 @@ import ( "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/testutils" - - mockkv "github.com/milvus-io/milvus/internal/kv/mocks" ) // MetaReloadSuite tests meta reload & meta creation related logic @@ -243,11 +242,11 @@ func TestMeta_Basic(t *testing.T) { segInfo1_1 := buildSegment(collID, partID1, segID1_1, channelName, false) // check AddSegment - err = meta.AddSegment(segInfo0_0) + err = meta.AddSegment(context.TODO(), segInfo0_0) assert.NoError(t, err) - err = meta.AddSegment(segInfo1_0) + err = meta.AddSegment(context.TODO(), segInfo1_0) assert.NoError(t, err) - err = meta.AddSegment(segInfo1_1) + err = meta.AddSegment(context.TODO(), segInfo1_1) 
assert.NoError(t, err) // check GetSegment @@ -311,7 +310,6 @@ func TestMeta_Basic(t *testing.T) { info1_1 = meta.GetHealthySegment(segID1_1) assert.NotNil(t, info1_1) assert.Equal(t, false, info1_1.GetIsImporting()) - }) t.Run("Test segment with kv fails", func(t *testing.T) { @@ -325,7 +323,7 @@ func TestMeta_Basic(t *testing.T) { meta, err := newMeta(context.TODO(), catalog, nil) assert.NoError(t, err) - err = meta.AddSegment(NewSegmentInfo(&datapb.SegmentInfo{})) + err = meta.AddSegment(context.TODO(), NewSegmentInfo(&datapb.SegmentInfo{})) assert.Error(t, err) metakv2 := mockkv.NewMetaKv(t) @@ -342,7 +340,7 @@ func TestMeta_Basic(t *testing.T) { err = meta.DropSegment(0) assert.NoError(t, err) // nil, since Save error not injected - err = meta.AddSegment(NewSegmentInfo(&datapb.SegmentInfo{})) + err = meta.AddSegment(context.TODO(), NewSegmentInfo(&datapb.SegmentInfo{})) assert.NoError(t, err) // error injected err = meta.DropSegment(0) @@ -351,6 +349,7 @@ func TestMeta_Basic(t *testing.T) { catalog = datacoord.NewCatalog(metakv, "", "") meta, err = newMeta(context.TODO(), catalog, nil) assert.NoError(t, err) + assert.NotNil(t, meta) }) t.Run("Test GetCount", func(t *testing.T) { @@ -366,7 +365,7 @@ func TestMeta_Basic(t *testing.T) { assert.NoError(t, err) segInfo0 := buildSegment(collID, partID0, segID0, channelName, false) segInfo0.NumOfRows = rowCount0 - err = meta.AddSegment(segInfo0) + err = meta.AddSegment(context.TODO(), segInfo0) assert.NoError(t, err) // add seg2 with 300 rows @@ -374,7 +373,7 @@ func TestMeta_Basic(t *testing.T) { assert.NoError(t, err) segInfo1 := buildSegment(collID, partID0, segID1, channelName, false) segInfo1.NumOfRows = rowCount1 - err = meta.AddSegment(segInfo1) + err = meta.AddSegment(context.TODO(), segInfo1) assert.NoError(t, err) // check partition/collection statistics @@ -432,7 +431,7 @@ func TestMeta_Basic(t *testing.T) { assert.NoError(t, err) segInfo0 := buildSegment(collID, partID0, segID0, channelName, false) 
segInfo0.size.Store(size0) - err = meta.AddSegment(segInfo0) + err = meta.AddSegment(context.TODO(), segInfo0) assert.NoError(t, err) // add seg1 with size1 @@ -440,7 +439,7 @@ func TestMeta_Basic(t *testing.T) { assert.NoError(t, err) segInfo1 := buildSegment(collID, partID0, segID1, channelName, false) segInfo1.size.Store(size1) - err = meta.AddSegment(segInfo1) + err = meta.AddSegment(context.TODO(), segInfo1) assert.NoError(t, err) // check TotalBinlogSize @@ -449,6 +448,16 @@ func TestMeta_Basic(t *testing.T) { assert.Equal(t, int64(size0+size1), collectionBinlogSize[collID]) assert.Equal(t, int64(size0+size1), total) }) + + t.Run("Test AddAllocation", func(t *testing.T) { + meta, _ := newMemoryMeta() + err := meta.AddAllocation(1, &Allocation{ + SegmentID: 1, + NumOfRows: 1, + ExpireTime: 0, + }) + assert.Error(t, err) + }) } func TestGetUnFlushedSegments(t *testing.T) { @@ -460,7 +469,7 @@ func TestGetUnFlushedSegments(t *testing.T) { PartitionID: 0, State: commonpb.SegmentState_Growing, } - err = meta.AddSegment(NewSegmentInfo(s1)) + err = meta.AddSegment(context.TODO(), NewSegmentInfo(s1)) assert.NoError(t, err) s2 := &datapb.SegmentInfo{ ID: 1, @@ -468,7 +477,7 @@ func TestGetUnFlushedSegments(t *testing.T) { PartitionID: 0, State: commonpb.SegmentState_Flushed, } - err = meta.AddSegment(NewSegmentInfo(s2)) + err = meta.AddSegment(context.TODO(), NewSegmentInfo(s2)) assert.NoError(t, err) segments := meta.GetUnFlushedSegments() @@ -484,9 +493,11 @@ func TestUpdateFlushSegmentsInfo(t *testing.T) { meta, err := newMemoryMeta() assert.NoError(t, err) - segment1 := &SegmentInfo{SegmentInfo: &datapb.SegmentInfo{ID: 1, State: commonpb.SegmentState_Growing, Binlogs: []*datapb.FieldBinlog{getFieldBinlogPaths(1, getInsertLogPath("binlog0", 1))}, - Statslogs: []*datapb.FieldBinlog{getFieldBinlogPaths(1, getStatsLogPath("statslog0", 1))}}} - err = meta.AddSegment(segment1) + segment1 := &SegmentInfo{SegmentInfo: &datapb.SegmentInfo{ + ID: 1, State: 
commonpb.SegmentState_Growing, Binlogs: []*datapb.FieldBinlog{getFieldBinlogPaths(1, getInsertLogPath("binlog0", 1))}, + Statslogs: []*datapb.FieldBinlog{getFieldBinlogPaths(1, getStatsLogPath("statslog0", 1))}, + }} + err = meta.AddSegment(context.TODO(), segment1) assert.NoError(t, err) err = meta.UpdateFlushSegmentsInfo(1, true, false, false, []*datapb.FieldBinlog{getFieldBinlogPathsWithEntry(1, 10, getInsertLogPath("binlog1", 1))}, @@ -513,7 +524,6 @@ func TestUpdateFlushSegmentsInfo(t *testing.T) { assert.Equal(t, updated.State, expected.State) assert.Equal(t, updated.size.Load(), expected.size.Load()) assert.Equal(t, updated.NumOfRows, expected.NumOfRows) - }) t.Run("update non-existed segment", func(t *testing.T) { @@ -529,7 +539,7 @@ func TestUpdateFlushSegmentsInfo(t *testing.T) { assert.NoError(t, err) segment1 := &SegmentInfo{SegmentInfo: &datapb.SegmentInfo{ID: 1, State: commonpb.SegmentState_Growing}} - err = meta.AddSegment(segment1) + err = meta.AddSegment(context.TODO(), segment1) assert.NoError(t, err) err = meta.UpdateFlushSegmentsInfo(1, false, false, false, nil, nil, nil, []*datapb.CheckPoint{{SegmentID: 2, NumOfRows: 10}}, diff --git a/internal/datacoord/metrics_info.go b/internal/datacoord/metrics_info.go index 6d135a6595fcd..fb74df46e2f4b 100644 --- a/internal/datacoord/metrics_info.go +++ b/internal/datacoord/metrics_info.go @@ -20,15 +20,14 @@ import ( "context" "github.com/cockroachdb/errors" - - "github.com/milvus-io/milvus/internal/types" - "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/hardware" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -90,20 +89,16 @@ func (s *Server) 
getSystemInfoMetrics( } resp := &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - }, - Response: "", + Status: merr.Success(), ComponentName: metricsinfo.ConstructComponentName(typeutil.DataCoordRole, paramtable.GetNodeID()), } var err error resp.Response, err = metricsinfo.MarshalTopology(coordTopology) if err != nil { - resp.Status.Reason = err.Error() + resp.Status = merr.Status(err) return resp, nil } - resp.Status.ErrorCode = commonpb.ErrorCode_Success return resp, nil } @@ -168,8 +163,8 @@ func (s *Server) getDataNodeMetrics(ctx context.Context, req *milvuspb.GetMetric if metrics.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { log.Warn("invalid metrics of DataNode was found", - zap.Any("error_code", metrics.Status.ErrorCode), - zap.Any("error_reason", metrics.Status.Reason)) + zap.Any("error_code", metrics.GetStatus().GetErrorCode()), + zap.Any("error_reason", metrics.GetStatus().GetReason())) infos.BaseComponentInfos.ErrorReason = metrics.GetStatus().GetReason() return infos, nil } @@ -185,7 +180,7 @@ func (s *Server) getDataNodeMetrics(ctx context.Context, req *milvuspb.GetMetric return infos, nil } -func (s *Server) getIndexNodeMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest, node types.IndexNode) (metricsinfo.IndexNodeInfos, error) { +func (s *Server) getIndexNodeMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest, node types.IndexNodeClient) (metricsinfo.IndexNodeInfos, error) { infos := metricsinfo.IndexNodeInfos{ BaseComponentInfos: metricsinfo.BaseComponentInfos{ HasError: true, @@ -208,8 +203,8 @@ func (s *Server) getIndexNodeMetrics(ctx context.Context, req *milvuspb.GetMetri if metrics.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { log.Warn("invalid metrics of DataNode was found", - zap.Any("error_code", metrics.Status.ErrorCode), - zap.Any("error_reason", metrics.Status.Reason)) + zap.Any("error_code", metrics.GetStatus().GetErrorCode()), + 
zap.Any("error_reason", metrics.GetStatus().GetReason())) infos.BaseComponentInfos.ErrorReason = metrics.GetStatus().GetReason() return infos, nil } diff --git a/internal/datacoord/metrics_info_test.go b/internal/datacoord/metrics_info_test.go index 2c2cc73ee0f27..5e0e01140ca2c 100644 --- a/internal/datacoord/metrics_info_test.go +++ b/internal/datacoord/metrics_info_test.go @@ -21,35 +21,36 @@ import ( "testing" "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/typeutil" - "github.com/stretchr/testify/assert" ) type mockMetricDataNodeClient struct { - types.DataNode + types.DataNodeClient mock func() (*milvuspb.GetMetricsResponse, error) } -func (c *mockMetricDataNodeClient) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { +func (c *mockMetricDataNodeClient) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { if c.mock == nil { - return c.DataNode.GetMetrics(ctx, req) + return c.DataNodeClient.GetMetrics(ctx, req) } return c.mock() } type mockMetricIndexNodeClient struct { - types.IndexNode + types.IndexNodeClient mock func() (*milvuspb.GetMetricsResponse, error) } -func (m *mockMetricIndexNodeClient) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { +func (m *mockMetricIndexNodeClient) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { if m.mock == nil { - return m.IndexNode.GetMetrics(ctx, req) + return m.IndexNodeClient.GetMetrics(ctx, req) } return m.mock() } @@ 
-68,7 +69,7 @@ func TestGetDataNodeMetrics(t *testing.T) { _, err = svr.getDataNodeMetrics(ctx, req, NewSession(&NodeInfo{}, nil)) assert.Error(t, err) - creator := func(ctx context.Context, addr string, nodeID int64) (types.DataNode, error) { + creator := func(ctx context.Context, addr string, nodeID int64) (types.DataNodeClient, error) { return newMockDataNodeClient(100, nil) } @@ -80,10 +81,10 @@ func TestGetDataNodeMetrics(t *testing.T) { assert.Equal(t, metricsinfo.ConstructComponentName(typeutil.DataNodeRole, 100), info.BaseComponentInfos.Name) getMockFailedClientCreator := func(mockFunc func() (*milvuspb.GetMetricsResponse, error)) dataNodeCreatorFunc { - return func(ctx context.Context, addr string, nodeID int64) (types.DataNode, error) { + return func(ctx context.Context, addr string, nodeID int64) (types.DataNodeClient, error) { cli, err := creator(ctx, addr, nodeID) assert.NoError(t, err) - return &mockMetricDataNodeClient{DataNode: cli, mock: mockFunc}, nil + return &mockMetricDataNodeClient{DataNodeClient: cli, mock: mockFunc}, nil } } @@ -95,13 +96,11 @@ func TestGetDataNodeMetrics(t *testing.T) { assert.NoError(t, err) assert.True(t, info.HasError) + mockErr := errors.New("mocked error") // mock status not success mockFailClientCreator = getMockFailedClientCreator(func() (*milvuspb.GetMetricsResponse, error) { return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "mocked error", - }, + Status: merr.Status(mockErr), }, nil }) @@ -113,9 +112,7 @@ func TestGetDataNodeMetrics(t *testing.T) { // mock parse error mockFailClientCreator = getMockFailedClientCreator(func() (*milvuspb.GetMetricsResponse, error) { return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), Response: `{"error_reason": 1}`, }, nil }) @@ -123,7 +120,6 @@ func TestGetDataNodeMetrics(t *testing.T) { info, err = svr.getDataNodeMetrics(ctx, 
req, NewSession(&NodeInfo{}, mockFailClientCreator)) assert.NoError(t, err) assert.True(t, info.HasError) - } func TestGetIndexNodeMetrics(t *testing.T) { @@ -144,14 +140,11 @@ func TestGetIndexNodeMetrics(t *testing.T) { assert.True(t, info.HasError) // failed + mockErr := errors.New("mocked error") info, err = svr.getIndexNodeMetrics(ctx, req, &mockMetricIndexNodeClient{ mock: func() (*milvuspb.GetMetricsResponse, error) { return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "mock fail", - }, - Response: "", + Status: merr.Status(mockErr), ComponentName: "indexnode100", }, nil }, @@ -164,10 +157,7 @@ func TestGetIndexNodeMetrics(t *testing.T) { info, err = svr.getIndexNodeMetrics(ctx, req, &mockMetricIndexNodeClient{ mock: func() (*milvuspb.GetMetricsResponse, error) { return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), Response: "XXXXXXXXXXXXX", ComponentName: "indexnode100", }, nil @@ -191,20 +181,13 @@ func TestGetIndexNodeMetrics(t *testing.T) { resp, err := metricsinfo.MarshalComponentInfos(nodeInfos) if err != nil { return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, - Response: "", + Status: merr.Status(err), ComponentName: metricsinfo.ConstructComponentName(typeutil.IndexNodeRole, nodeID), }, nil } return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), Response: resp, ComponentName: metricsinfo.ConstructComponentName(typeutil.IndexNodeRole, nodeID), }, nil diff --git a/internal/datacoord/mock_test.go b/internal/datacoord/mock_test.go index e02fafd795943..b477b7db6efe8 100644 --- a/internal/datacoord/mock_test.go +++ b/internal/datacoord/mock_test.go @@ -22,9 +22,8 @@ import ( "time" 
"github.com/cockroachdb/errors" - "github.com/milvus-io/milvus/pkg/util/merr" - "github.com/milvus-io/milvus/pkg/util/tsoutil" clientv3 "go.etcd.io/etcd/client/v3" + "google.golang.org/grpc" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" @@ -36,7 +35,9 @@ import ( "github.com/milvus-io/milvus/internal/proto/proxypb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" + "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -103,8 +104,7 @@ func (m *MockAllocator) allocID(ctx context.Context) (UniqueID, error) { return val, nil } -type MockAllocator0 struct { -} +type MockAllocator0 struct{} func (m *MockAllocator0) allocTimestamp(ctx context.Context) (Timestamp, error) { return Timestamp(0), nil @@ -167,9 +167,7 @@ func newMockDataNodeClient(id int64, ch chan interface{}) (*mockDataNodeClient, state: commonpb.StateCode_Initializing, ch: ch, addImportSegmentResp: &datapb.AddImportSegmentResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, }, nil } @@ -186,20 +184,11 @@ func newMockIndexNodeClient(id int64) (*mockIndexNodeClient, error) { }, nil } -func (c *mockDataNodeClient) Init() error { - return nil -} - -func (c *mockDataNodeClient) Start() error { - c.state = commonpb.StateCode_Healthy +func (c *mockDataNodeClient) Close() error { return nil } -func (c *mockDataNodeClient) Register() error { - return nil -} - -func (c *mockDataNodeClient) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { +func (c *mockDataNodeClient) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { return &milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{ 
NodeID: c.id, @@ -208,40 +197,34 @@ func (c *mockDataNodeClient) GetComponentStates(ctx context.Context) (*milvuspb. }, nil } -func (c *mockDataNodeClient) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (c *mockDataNodeClient) GetStatisticsChannel(ctx context.Context, req *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { return nil, nil } -func (c *mockDataNodeClient) WatchDmChannels(ctx context.Context, in *datapb.WatchDmChannelsRequest) (*commonpb.Status, error) { +func (c *mockDataNodeClient) WatchDmChannels(ctx context.Context, in *datapb.WatchDmChannelsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil } -func (c *mockDataNodeClient) FlushSegments(ctx context.Context, in *datapb.FlushSegmentsRequest) (*commonpb.Status, error) { +func (c *mockDataNodeClient) FlushSegments(ctx context.Context, in *datapb.FlushSegmentsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { if c.ch != nil { c.ch <- in } return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil } -func (c *mockDataNodeClient) ResendSegmentStats(ctx context.Context, req *datapb.ResendSegmentStatsRequest) (*datapb.ResendSegmentStatsResponse, error) { +func (c *mockDataNodeClient) ResendSegmentStats(ctx context.Context, req *datapb.ResendSegmentStatsRequest, opts ...grpc.CallOption) (*datapb.ResendSegmentStatsResponse, error) { return &datapb.ResendSegmentStatsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), }, nil } -func (c *mockDataNodeClient) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { +func (c *mockDataNodeClient) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest, opts ...grpc.CallOption) 
(*internalpb.ShowConfigurationsResponse, error) { return &internalpb.ShowConfigurationsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), }, nil } -func (c *mockDataNodeClient) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { +func (c *mockDataNodeClient) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { // TODO(dragondriver): change the id, though it's not important in ut nodeID := c.id @@ -264,16 +247,13 @@ func (c *mockDataNodeClient) GetMetrics(ctx context.Context, req *milvuspb.GetMe } return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), Response: resp, ComponentName: metricsinfo.ConstructComponentName(typeutil.DataNodeRole, nodeID), }, nil } -func (c *mockDataNodeClient) Compaction(ctx context.Context, req *datapb.CompactionPlan) (*commonpb.Status, error) { +func (c *mockDataNodeClient) Compaction(ctx context.Context, req *datapb.CompactionPlan, opts ...grpc.CallOption) (*commonpb.Status, error) { if c.ch != nil { c.ch <- struct{}{} if c.compactionResp != nil { @@ -287,79 +267,88 @@ func (c *mockDataNodeClient) Compaction(ctx context.Context, req *datapb.Compact return &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: "not implemented"}, nil } -func (c *mockDataNodeClient) GetCompactionState(ctx context.Context, req *datapb.CompactionStateRequest) (*datapb.CompactionStateResponse, error) { +func (c *mockDataNodeClient) GetCompactionState(ctx context.Context, req *datapb.CompactionStateRequest, opts ...grpc.CallOption) (*datapb.CompactionStateResponse, error) { return c.compactionStateResp, nil } -func (c *mockDataNodeClient) Import(ctx context.Context, in *datapb.ImportTaskRequest) (*commonpb.Status, error) { +func (c 
*mockDataNodeClient) Import(ctx context.Context, in *datapb.ImportTaskRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil } -func (c *mockDataNodeClient) AddImportSegment(ctx context.Context, req *datapb.AddImportSegmentRequest) (*datapb.AddImportSegmentResponse, error) { +func (c *mockDataNodeClient) AddImportSegment(ctx context.Context, req *datapb.AddImportSegmentRequest, opts ...grpc.CallOption) (*datapb.AddImportSegmentResponse, error) { return c.addImportSegmentResp, nil } -func (c *mockDataNodeClient) SyncSegments(ctx context.Context, req *datapb.SyncSegmentsRequest) (*commonpb.Status, error) { +func (c *mockDataNodeClient) SyncSegments(ctx context.Context, req *datapb.SyncSegmentsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil } +func (c *mockDataNodeClient) FlushChannels(ctx context.Context, req *datapb.FlushChannelsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil +} + +func (c *mockDataNodeClient) NotifyChannelOperation(ctx context.Context, req *datapb.ChannelOperationsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return merr.Success(), nil +} + +func (c *mockDataNodeClient) CheckChannelOperationProgress(ctx context.Context, req *datapb.ChannelWatchInfo, opts ...grpc.CallOption) (*datapb.ChannelOperationProgressResponse, error) { + return &datapb.ChannelOperationProgressResponse{Status: merr.Success()}, nil +} + func (c *mockDataNodeClient) Stop() error { c.state = commonpb.StateCode_Abnormal return nil } -type mockRootCoordService struct { +type mockRootCoordClient struct { state commonpb.StateCode cnt int64 } -func (m *mockRootCoordService) RenameCollection(ctx context.Context, req *milvuspb.RenameCollectionRequest) (*commonpb.Status, error) { - //TODO implement me +func (m *mockRootCoordClient) 
Close() error { + // TODO implement me panic("implement me") } -func (m *mockRootCoordService) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { +func (m *mockRootCoordClient) RenameCollection(ctx context.Context, req *milvuspb.RenameCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + // TODO implement me panic("implement me") } -func (m *mockRootCoordService) CreateAlias(ctx context.Context, req *milvuspb.CreateAliasRequest) (*commonpb.Status, error) { +func (m *mockRootCoordClient) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest, opts ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error) { panic("implement me") } -func (m *mockRootCoordService) DropAlias(ctx context.Context, req *milvuspb.DropAliasRequest) (*commonpb.Status, error) { +func (m *mockRootCoordClient) CreateAlias(ctx context.Context, req *milvuspb.CreateAliasRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { panic("implement me") } -func (m *mockRootCoordService) AlterAlias(ctx context.Context, req *milvuspb.AlterAliasRequest) (*commonpb.Status, error) { +func (m *mockRootCoordClient) DropAlias(ctx context.Context, req *milvuspb.DropAliasRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { panic("implement me") } -func newMockRootCoordService() *mockRootCoordService { - return &mockRootCoordService{state: commonpb.StateCode_Healthy} -} - -func (m *mockRootCoordService) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - return nil, nil +func (m *mockRootCoordClient) AlterAlias(ctx context.Context, req *milvuspb.AlterAliasRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + panic("implement me") } -func (m *mockRootCoordService) Init() error { - return nil +func newMockRootCoordClient() *mockRootCoordClient { + return &mockRootCoordClient{state: commonpb.StateCode_Healthy} } -func (m *mockRootCoordService) Start() error { - return nil +func (m 
*mockRootCoordClient) GetTimeTickChannel(ctx context.Context, req *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + return nil, nil } -func (m *mockRootCoordService) Stop() error { +func (m *mockRootCoordClient) Stop() error { m.state = commonpb.StateCode_Abnormal return nil } -func (m *mockRootCoordService) Register() error { +func (m *mockRootCoordClient) Register() error { return nil } -func (m *mockRootCoordService) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { +func (m *mockRootCoordClient) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { return &milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{ NodeID: 0, @@ -368,31 +357,28 @@ func (m *mockRootCoordService) GetComponentStates(ctx context.Context) (*milvusp ExtraInfo: []*commonpb.KeyValuePair{}, }, SubcomponentStates: []*milvuspb.ComponentInfo{}, - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), }, nil } -func (m *mockRootCoordService) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (m *mockRootCoordClient) GetStatisticsChannel(ctx context.Context, req *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { panic("not implemented") // TODO: Implement } // DDL request -func (m *mockRootCoordService) CreateCollection(ctx context.Context, req *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) { +func (m *mockRootCoordClient) CreateCollection(ctx context.Context, req *milvuspb.CreateCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { panic("not implemented") // TODO: Implement } -func (m *mockRootCoordService) DropCollection(ctx context.Context, req *milvuspb.DropCollectionRequest) (*commonpb.Status, error) { +func (m *mockRootCoordClient) DropCollection(ctx 
context.Context, req *milvuspb.DropCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { panic("not implemented") // TODO: Implement } -func (m *mockRootCoordService) HasCollection(ctx context.Context, req *milvuspb.HasCollectionRequest) (*milvuspb.BoolResponse, error) { +func (m *mockRootCoordClient) HasCollection(ctx context.Context, req *milvuspb.HasCollectionRequest, opts ...grpc.CallOption) (*milvuspb.BoolResponse, error) { panic("not implemented") // TODO: Implement } -func (m *mockRootCoordService) DescribeCollection(ctx context.Context, req *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { +func (m *mockRootCoordClient) DescribeCollection(ctx context.Context, req *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) { // return not exist if req.CollectionID == -1 { err := merr.WrapErrCollectionNotFound(req.GetCollectionID()) @@ -401,10 +387,7 @@ func (m *mockRootCoordService) DescribeCollection(ctx context.Context, req *milv }, nil } return &milvuspb.DescribeCollectionResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), Schema: &schemapb.CollectionSchema{ Name: "test", }, @@ -413,65 +396,59 @@ func (m *mockRootCoordService) DescribeCollection(ctx context.Context, req *milv }, nil } -func (m *mockRootCoordService) DescribeCollectionInternal(ctx context.Context, req *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { +func (m *mockRootCoordClient) DescribeCollectionInternal(ctx context.Context, req *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) { return m.DescribeCollection(ctx, req) } -func (m *mockRootCoordService) ShowCollections(ctx context.Context, req *milvuspb.ShowCollectionsRequest) (*milvuspb.ShowCollectionsResponse, error) { +func (m *mockRootCoordClient) ShowCollections(ctx 
context.Context, req *milvuspb.ShowCollectionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowCollectionsResponse, error) { return &milvuspb.ShowCollectionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), CollectionNames: []string{"test"}, }, nil } -func (m *mockRootCoordService) CreateDatabase(ctx context.Context, in *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) { +func (m *mockRootCoordClient) CreateDatabase(ctx context.Context, in *milvuspb.CreateDatabaseRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { panic("not implemented") // TODO: Implement } -func (m *mockRootCoordService) DropDatabase(ctx context.Context, in *milvuspb.DropDatabaseRequest) (*commonpb.Status, error) { +func (m *mockRootCoordClient) DropDatabase(ctx context.Context, in *milvuspb.DropDatabaseRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { panic("not implemented") // TODO: Implement } -func (m *mockRootCoordService) ListDatabases(ctx context.Context, in *milvuspb.ListDatabasesRequest) (*milvuspb.ListDatabasesResponse, error) { +func (m *mockRootCoordClient) ListDatabases(ctx context.Context, in *milvuspb.ListDatabasesRequest, opts ...grpc.CallOption) (*milvuspb.ListDatabasesResponse, error) { panic("not implemented") // TODO: Implement } -func (m *mockRootCoordService) AlterCollection(ctx context.Context, request *milvuspb.AlterCollectionRequest) (*commonpb.Status, error) { +func (m *mockRootCoordClient) AlterCollection(ctx context.Context, request *milvuspb.AlterCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { panic("not implemented") // TODO: Implement } -func (m *mockRootCoordService) CreatePartition(ctx context.Context, req *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) { +func (m *mockRootCoordClient) CreatePartition(ctx context.Context, req *milvuspb.CreatePartitionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { panic("not 
implemented") // TODO: Implement } -func (m *mockRootCoordService) DropPartition(ctx context.Context, req *milvuspb.DropPartitionRequest) (*commonpb.Status, error) { +func (m *mockRootCoordClient) DropPartition(ctx context.Context, req *milvuspb.DropPartitionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { panic("not implemented") // TODO: Implement } -func (m *mockRootCoordService) HasPartition(ctx context.Context, req *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error) { +func (m *mockRootCoordClient) HasPartition(ctx context.Context, req *milvuspb.HasPartitionRequest, opts ...grpc.CallOption) (*milvuspb.BoolResponse, error) { panic("not implemented") // TODO: Implement } -func (m *mockRootCoordService) ShowPartitions(ctx context.Context, req *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) { +func (m *mockRootCoordClient) ShowPartitions(ctx context.Context, req *milvuspb.ShowPartitionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowPartitionsResponse, error) { return &milvuspb.ShowPartitionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), PartitionNames: []string{"_default"}, PartitionIDs: []int64{0}, }, nil } -func (m *mockRootCoordService) ShowPartitionsInternal(ctx context.Context, req *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) { +func (m *mockRootCoordClient) ShowPartitionsInternal(ctx context.Context, req *milvuspb.ShowPartitionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowPartitionsResponse, error) { return m.ShowPartitions(ctx, req) } // global timestamp allocator -func (m *mockRootCoordService) AllocTimestamp(ctx context.Context, req *rootcoordpb.AllocTimestampRequest) (*rootcoordpb.AllocTimestampResponse, error) { +func (m *mockRootCoordClient) AllocTimestamp(ctx context.Context, req *rootcoordpb.AllocTimestampRequest, opts ...grpc.CallOption) (*rootcoordpb.AllocTimestampResponse, error) { if 
m.state != commonpb.StateCode_Healthy { return &rootcoordpb.AllocTimestampResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}}, nil } @@ -480,79 +457,67 @@ func (m *mockRootCoordService) AllocTimestamp(ctx context.Context, req *rootcoor phy := time.Now().UnixNano() / int64(time.Millisecond) ts := tsoutil.ComposeTS(phy, val) return &rootcoordpb.AllocTimestampResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), Timestamp: ts, Count: req.Count, }, nil } -func (m *mockRootCoordService) AllocID(ctx context.Context, req *rootcoordpb.AllocIDRequest) (*rootcoordpb.AllocIDResponse, error) { +func (m *mockRootCoordClient) AllocID(ctx context.Context, req *rootcoordpb.AllocIDRequest, opts ...grpc.CallOption) (*rootcoordpb.AllocIDResponse, error) { if m.state != commonpb.StateCode_Healthy { return &rootcoordpb.AllocIDResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}}, nil } val := atomic.AddInt64(&m.cnt, int64(req.Count)) return &rootcoordpb.AllocIDResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, - ID: val, - Count: req.Count, + Status: merr.Success(), + ID: val, + Count: req.Count, }, nil } // segment -func (m *mockRootCoordService) DescribeSegment(ctx context.Context, req *milvuspb.DescribeSegmentRequest) (*milvuspb.DescribeSegmentResponse, error) { +func (m *mockRootCoordClient) DescribeSegment(ctx context.Context, req *milvuspb.DescribeSegmentRequest, opts ...grpc.CallOption) (*milvuspb.DescribeSegmentResponse, error) { panic("not implemented") // TODO: Implement } -func (m *mockRootCoordService) ShowSegments(ctx context.Context, req *milvuspb.ShowSegmentsRequest) (*milvuspb.ShowSegmentsResponse, error) { +func (m *mockRootCoordClient) ShowSegments(ctx context.Context, req *milvuspb.ShowSegmentsRequest, opts ...grpc.CallOption) (*milvuspb.ShowSegmentsResponse, error) { panic("not implemented") // 
TODO: Implement } -func (m *mockRootCoordService) DescribeSegments(ctx context.Context, req *rootcoordpb.DescribeSegmentsRequest) (*rootcoordpb.DescribeSegmentsResponse, error) { +func (m *mockRootCoordClient) DescribeSegments(ctx context.Context, req *rootcoordpb.DescribeSegmentsRequest, opts ...grpc.CallOption) (*rootcoordpb.DescribeSegmentsResponse, error) { panic("implement me") } -func (m *mockRootCoordService) GetDdChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (m *mockRootCoordClient) GetDdChannel(ctx context.Context, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { return &milvuspb.StringResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, - Value: "ddchannel", + Status: merr.Success(), + Value: "ddchannel", }, nil } -func (m *mockRootCoordService) UpdateChannelTimeTick(ctx context.Context, req *internalpb.ChannelTimeTickMsg) (*commonpb.Status, error) { +func (m *mockRootCoordClient) UpdateChannelTimeTick(ctx context.Context, req *internalpb.ChannelTimeTickMsg, opts ...grpc.CallOption) (*commonpb.Status, error) { panic("not implemented") // TODO: Implement } -func (m *mockRootCoordService) InvalidateCollectionMetaCache(ctx context.Context, req *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { +func (m *mockRootCoordClient) InvalidateCollectionMetaCache(ctx context.Context, req *proxypb.InvalidateCollMetaCacheRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { panic("not implemented") // TODO: Implement } -func (m *mockRootCoordService) SegmentFlushCompleted(ctx context.Context, in *datapb.SegmentFlushCompletedMsg) (*commonpb.Status, error) { +func (m *mockRootCoordClient) SegmentFlushCompleted(ctx context.Context, in *datapb.SegmentFlushCompletedMsg, opts ...grpc.CallOption) (*commonpb.Status, error) { return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil } -func (m *mockRootCoordService) AddNewSegment(ctx context.Context, in 
*datapb.SegmentMsg) (*commonpb.Status, error) { +func (m *mockRootCoordClient) AddNewSegment(ctx context.Context, in *datapb.SegmentMsg, opts ...grpc.CallOption) (*commonpb.Status, error) { panic("not implemented") // TODO: Implement } -func (m *mockRootCoordService) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { +func (m *mockRootCoordClient) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest, opts ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error) { return &internalpb.ShowConfigurationsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), }, nil } -func (m *mockRootCoordService) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { +func (m *mockRootCoordClient) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { // TODO(dragondriver): change the id, though it's not important in ut nodeID := UniqueID(20210901) @@ -583,33 +548,28 @@ func (m *mockRootCoordService) GetMetrics(ctx context.Context, req *milvuspb.Get } return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), Response: resp, ComponentName: metricsinfo.ConstructComponentName(typeutil.RootCoordRole, nodeID), }, nil } -func (m *mockRootCoordService) Import(ctx context.Context, req *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) { +func (m *mockRootCoordClient) Import(ctx context.Context, req *milvuspb.ImportRequest, opts ...grpc.CallOption) (*milvuspb.ImportResponse, error) { panic("not implemented") // TODO: Implement } // Check import task state from datanode -func (m *mockRootCoordService) GetImportState(ctx context.Context, req *milvuspb.GetImportStateRequest) 
(*milvuspb.GetImportStateResponse, error) { +func (m *mockRootCoordClient) GetImportState(ctx context.Context, req *milvuspb.GetImportStateRequest, opts ...grpc.CallOption) (*milvuspb.GetImportStateResponse, error) { panic("not implemented") // TODO: Implement } // Returns id array of all import tasks -func (m *mockRootCoordService) ListImportTasks(ctx context.Context, in *milvuspb.ListImportTasksRequest) (*milvuspb.ListImportTasksResponse, error) { +func (m *mockRootCoordClient) ListImportTasks(ctx context.Context, in *milvuspb.ListImportTasksRequest, opts ...grpc.CallOption) (*milvuspb.ListImportTasksResponse, error) { panic("not implemented") // TODO: Implement } -func (m *mockRootCoordService) ReportImport(ctx context.Context, req *rootcoordpb.ImportResult) (*commonpb.Status, error) { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil +func (m *mockRootCoordClient) ReportImport(ctx context.Context, req *rootcoordpb.ImportResult, opts ...grpc.CallOption) (*commonpb.Status, error) { + return merr.Success(), nil } type mockCompactionHandler struct { @@ -750,55 +710,55 @@ func (t *mockCompactionTrigger) stop() { panic("not implemented") } -func (m *mockRootCoordService) CreateCredential(ctx context.Context, req *internalpb.CredentialInfo) (*commonpb.Status, error) { +func (m *mockRootCoordClient) CreateCredential(ctx context.Context, req *internalpb.CredentialInfo, opts ...grpc.CallOption) (*commonpb.Status, error) { panic("implement me") } -func (m *mockRootCoordService) UpdateCredential(ctx context.Context, req *internalpb.CredentialInfo) (*commonpb.Status, error) { +func (m *mockRootCoordClient) UpdateCredential(ctx context.Context, req *internalpb.CredentialInfo, opts ...grpc.CallOption) (*commonpb.Status, error) { panic("implement me") } -func (m *mockRootCoordService) DeleteCredential(ctx context.Context, req *milvuspb.DeleteCredentialRequest) (*commonpb.Status, error) { +func (m *mockRootCoordClient) DeleteCredential(ctx 
context.Context, req *milvuspb.DeleteCredentialRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { panic("implement me") } -func (m *mockRootCoordService) ListCredUsers(ctx context.Context, req *milvuspb.ListCredUsersRequest) (*milvuspb.ListCredUsersResponse, error) { +func (m *mockRootCoordClient) ListCredUsers(ctx context.Context, req *milvuspb.ListCredUsersRequest, opts ...grpc.CallOption) (*milvuspb.ListCredUsersResponse, error) { panic("implement me") } -func (m *mockRootCoordService) GetCredential(ctx context.Context, req *rootcoordpb.GetCredentialRequest) (*rootcoordpb.GetCredentialResponse, error) { +func (m *mockRootCoordClient) GetCredential(ctx context.Context, req *rootcoordpb.GetCredentialRequest, opts ...grpc.CallOption) (*rootcoordpb.GetCredentialResponse, error) { panic("implement me") } -func (m *mockRootCoordService) CreateRole(ctx context.Context, req *milvuspb.CreateRoleRequest) (*commonpb.Status, error) { +func (m *mockRootCoordClient) CreateRole(ctx context.Context, req *milvuspb.CreateRoleRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { panic("implement me") } -func (m *mockRootCoordService) DropRole(ctx context.Context, req *milvuspb.DropRoleRequest) (*commonpb.Status, error) { +func (m *mockRootCoordClient) DropRole(ctx context.Context, req *milvuspb.DropRoleRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { panic("implement me") } -func (m *mockRootCoordService) OperateUserRole(ctx context.Context, req *milvuspb.OperateUserRoleRequest) (*commonpb.Status, error) { +func (m *mockRootCoordClient) OperateUserRole(ctx context.Context, req *milvuspb.OperateUserRoleRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { panic("implement me") } -func (m *mockRootCoordService) SelectRole(ctx context.Context, req *milvuspb.SelectRoleRequest) (*milvuspb.SelectRoleResponse, error) { +func (m *mockRootCoordClient) SelectRole(ctx context.Context, req *milvuspb.SelectRoleRequest, opts ...grpc.CallOption) 
(*milvuspb.SelectRoleResponse, error) { panic("implement me") } -func (m *mockRootCoordService) SelectUser(ctx context.Context, req *milvuspb.SelectUserRequest) (*milvuspb.SelectUserResponse, error) { +func (m *mockRootCoordClient) SelectUser(ctx context.Context, req *milvuspb.SelectUserRequest, opts ...grpc.CallOption) (*milvuspb.SelectUserResponse, error) { panic("implement me") } -func (m *mockRootCoordService) OperatePrivilege(ctx context.Context, req *milvuspb.OperatePrivilegeRequest) (*commonpb.Status, error) { +func (m *mockRootCoordClient) OperatePrivilege(ctx context.Context, req *milvuspb.OperatePrivilegeRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { panic("implement me") } -func (m *mockRootCoordService) SelectGrant(ctx context.Context, req *milvuspb.SelectGrantRequest) (*milvuspb.SelectGrantResponse, error) { +func (m *mockRootCoordClient) SelectGrant(ctx context.Context, req *milvuspb.SelectGrantRequest, opts ...grpc.CallOption) (*milvuspb.SelectGrantResponse, error) { panic("implement me") } -func (m *mockRootCoordService) ListPolicy(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) { +func (m *mockRootCoordClient) ListPolicy(ctx context.Context, in *internalpb.ListPolicyRequest, opts ...grpc.CallOption) (*internalpb.ListPolicyResponse, error) { return &internalpb.ListPolicyResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}}, nil } diff --git a/internal/datacoord/policy.go b/internal/datacoord/policy.go index ca58340b77c42..e675821c14be2 100644 --- a/internal/datacoord/policy.go +++ b/internal/datacoord/policy.go @@ -23,10 +23,11 @@ import ( "strconv" "time" - "github.com/milvus-io/milvus/pkg/log" "go.uber.org/zap" "go.uber.org/zap/zapcore" "stathat.com/c/consistent" + + "github.com/milvus-io/milvus/pkg/log" ) // RegisterPolicy decides the channels mapping after registering the nodeID @@ -443,7 +444,6 @@ func RoundRobinReassignPolicy(store ROChannelStore, reassigns 
[]*NodeChannelInfo } else { addUpdates[targetID].Channels = append(addUpdates[targetID].Channels, ch) } - } } for _, update := range addUpdates { diff --git a/internal/datacoord/policy_test.go b/internal/datacoord/policy_test.go index 343b4ab0f0d40..17db93a16d329 100644 --- a/internal/datacoord/policy_test.go +++ b/internal/datacoord/policy_test.go @@ -387,7 +387,7 @@ func TestBgCheckForChannelBalance(t *testing.T) { }, time.Now(), }, - //there should be no reallocate + // there should be no reallocate []*NodeChannelInfo{}, nil, }, @@ -409,8 +409,11 @@ func TestBgCheckForChannelBalance(t *testing.T) { "test uneven with zero", args{ []*NodeChannelInfo{ - {1, []*channel{{Name: "chan1", CollectionID: 1}, {Name: "chan2", CollectionID: 1}, - {Name: "chan3", CollectionID: 1}}}, + {1, []*channel{ + {Name: "chan1", CollectionID: 1}, + {Name: "chan2", CollectionID: 1}, + {Name: "chan3", CollectionID: 1}, + }}, {2, []*channel{}}, }, time.Now(), @@ -450,7 +453,7 @@ func TestAvgReassignPolicy(t *testing.T) { }, []*NodeChannelInfo{{1, []*channel{{Name: "chan1", CollectionID: 1}}}}, }, - //as there's no available nodes except the input node, there's no reassign plan generated + // as there's no available nodes except the input node, there's no reassign plan generated []*ChannelOp{}, }, { @@ -468,10 +471,11 @@ func TestAvgReassignPolicy(t *testing.T) { []*NodeChannelInfo{{1, []*channel{{Name: "chan1", CollectionID: 1}}}}, }, []*ChannelOp{ - //as we use ceil to calculate the wanted average number, there should be one reassign - //though the average num less than 1 + // as we use ceil to calculate the wanted average number, there should be one reassign + // though the average num less than 1 {Delete, 1, []*channel{{Name: "chan1", CollectionID: 1}}, nil}, - {Add, 2, []*channel{{Name: "chan1", CollectionID: 1}}, nil}}, + {Add, 2, []*channel{{Name: "chan1", CollectionID: 1}}, nil}, + }, }, { "test_normal_reassigning_for_one_available_nodes", @@ -487,7 +491,8 @@ func 
TestAvgReassignPolicy(t *testing.T) { }, []*ChannelOp{ {Delete, 1, []*channel{{Name: "chan1", CollectionID: 1}, {Name: "chan2", CollectionID: 1}}, nil}, - {Add, 2, []*channel{{Name: "chan1", CollectionID: 1}, {Name: "chan2", CollectionID: 1}}, nil}}, + {Add, 2, []*channel{{Name: "chan1", CollectionID: 1}, {Name: "chan2", CollectionID: 1}}, nil}, + }, }, { "test_normal_reassigning_for_multiple_available_nodes", @@ -499,7 +504,8 @@ func TestAvgReassignPolicy(t *testing.T) { {Name: "chan1", CollectionID: 1}, {Name: "chan2", CollectionID: 1}, {Name: "chan3", CollectionID: 1}, - {Name: "chan4", CollectionID: 1}}}, + {Name: "chan4", CollectionID: 1}, + }}, 2: {2, []*channel{}}, 3: {3, []*channel{}}, 4: {4, []*channel{}}, @@ -512,11 +518,15 @@ func TestAvgReassignPolicy(t *testing.T) { }}}, }, []*ChannelOp{ - {Delete, 1, []*channel{ - {Name: "chan1", CollectionID: 1}, - {Name: "chan2", CollectionID: 1}, - {Name: "chan3", CollectionID: 1}}, - nil}, + { + Delete, 1, + []*channel{ + {Name: "chan1", CollectionID: 1}, + {Name: "chan2", CollectionID: 1}, + {Name: "chan3", CollectionID: 1}, + }, + nil, + }, {Add, 2, []*channel{{Name: "chan1", CollectionID: 1}}, nil}, {Add, 3, []*channel{{Name: "chan2", CollectionID: 1}}, nil}, {Add, 4, []*channel{{Name: "chan3", CollectionID: 1}}, nil}, @@ -529,12 +539,18 @@ func TestAvgReassignPolicy(t *testing.T) { memkv.NewMemoryKV(), map[int64]*NodeChannelInfo{ 1: {1, []*channel{ - {Name: "chan1", CollectionID: 1}, {Name: "chan2", CollectionID: 1}, - {Name: "chan3", CollectionID: 1}, {Name: "chan4", CollectionID: 1}, - {Name: "chan5", CollectionID: 1}, {Name: "chan6", CollectionID: 1}, - {Name: "chan7", CollectionID: 1}, {Name: "chan8", CollectionID: 1}, - {Name: "chan9", CollectionID: 1}, {Name: "chan10", CollectionID: 1}, - {Name: "chan11", CollectionID: 1}, {Name: "chan12", CollectionID: 1}, + {Name: "chan1", CollectionID: 1}, + {Name: "chan2", CollectionID: 1}, + {Name: "chan3", CollectionID: 1}, + {Name: "chan4", CollectionID: 1}, + 
{Name: "chan5", CollectionID: 1}, + {Name: "chan6", CollectionID: 1}, + {Name: "chan7", CollectionID: 1}, + {Name: "chan8", CollectionID: 1}, + {Name: "chan9", CollectionID: 1}, + {Name: "chan10", CollectionID: 1}, + {Name: "chan11", CollectionID: 1}, + {Name: "chan12", CollectionID: 1}, }}, 2: {2, []*channel{ {Name: "chan13", CollectionID: 1}, {Name: "chan14", CollectionID: 1}, @@ -544,33 +560,51 @@ func TestAvgReassignPolicy(t *testing.T) { }, }, []*NodeChannelInfo{{1, []*channel{ - {Name: "chan1", CollectionID: 1}, {Name: "chan2", CollectionID: 1}, - {Name: "chan3", CollectionID: 1}, {Name: "chan4", CollectionID: 1}, - {Name: "chan5", CollectionID: 1}, {Name: "chan6", CollectionID: 1}, - {Name: "chan7", CollectionID: 1}, {Name: "chan8", CollectionID: 1}, - {Name: "chan9", CollectionID: 1}, {Name: "chan10", CollectionID: 1}, - {Name: "chan11", CollectionID: 1}, {Name: "chan12", CollectionID: 1}, + {Name: "chan1", CollectionID: 1}, + {Name: "chan2", CollectionID: 1}, + {Name: "chan3", CollectionID: 1}, + {Name: "chan4", CollectionID: 1}, + {Name: "chan5", CollectionID: 1}, + {Name: "chan6", CollectionID: 1}, + {Name: "chan7", CollectionID: 1}, + {Name: "chan8", CollectionID: 1}, + {Name: "chan9", CollectionID: 1}, + {Name: "chan10", CollectionID: 1}, + {Name: "chan11", CollectionID: 1}, + {Name: "chan12", CollectionID: 1}, }}}, }, []*ChannelOp{ {Delete, 1, []*channel{ - {Name: "chan1", CollectionID: 1}, {Name: "chan2", CollectionID: 1}, - {Name: "chan3", CollectionID: 1}, {Name: "chan4", CollectionID: 1}, - {Name: "chan5", CollectionID: 1}, {Name: "chan6", CollectionID: 1}, - {Name: "chan7", CollectionID: 1}, {Name: "chan8", CollectionID: 1}, - {Name: "chan9", CollectionID: 1}, {Name: "chan10", CollectionID: 1}, - {Name: "chan11", CollectionID: 1}, {Name: "chan12", CollectionID: 1}, + {Name: "chan1", CollectionID: 1}, + {Name: "chan2", CollectionID: 1}, + {Name: "chan3", CollectionID: 1}, + {Name: "chan4", CollectionID: 1}, + {Name: "chan5", CollectionID: 1}, + 
{Name: "chan6", CollectionID: 1}, + {Name: "chan7", CollectionID: 1}, + {Name: "chan8", CollectionID: 1}, + {Name: "chan9", CollectionID: 1}, + {Name: "chan10", CollectionID: 1}, + {Name: "chan11", CollectionID: 1}, + {Name: "chan12", CollectionID: 1}, }, nil}, {Add, 4, []*channel{ - {Name: "chan1", CollectionID: 1}, {Name: "chan2", CollectionID: 1}, - {Name: "chan3", CollectionID: 1}, {Name: "chan4", CollectionID: 1}, - {Name: "chan5", CollectionID: 1}}, nil}, + {Name: "chan1", CollectionID: 1}, + {Name: "chan2", CollectionID: 1}, + {Name: "chan3", CollectionID: 1}, + {Name: "chan4", CollectionID: 1}, + {Name: "chan5", CollectionID: 1}, + }, nil}, {Add, 3, []*channel{ - {Name: "chan6", CollectionID: 1}, {Name: "chan7", CollectionID: 1}, - {Name: "chan8", CollectionID: 1}, {Name: "chan9", CollectionID: 1}, + {Name: "chan6", CollectionID: 1}, + {Name: "chan7", CollectionID: 1}, + {Name: "chan8", CollectionID: 1}, + {Name: "chan9", CollectionID: 1}, }, nil}, {Add, 2, []*channel{ - {Name: "chan10", CollectionID: 1}, {Name: "chan11", CollectionID: 1}, + {Name: "chan10", CollectionID: 1}, + {Name: "chan11", CollectionID: 1}, {Name: "chan12", CollectionID: 1}, }, nil}, }, diff --git a/internal/datacoord/segment_allocation_policy.go b/internal/datacoord/segment_allocation_policy.go index 1071691273aef..0a0271c2db503 100644 --- a/internal/datacoord/segment_allocation_policy.go +++ b/internal/datacoord/segment_allocation_policy.go @@ -21,10 +21,10 @@ import ( "time" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -68,7 +68,8 @@ type AllocatePolicy func(segments []*SegmentInfo, count int64, // AllocatePolicyV1 v1 policy simple allocation policy using Greedy Algorithm func AllocatePolicyV1(segments []*SegmentInfo, count int64, - 
maxCountPerSegment int64) ([]*Allocation, []*Allocation) { + maxCountPerSegment int64, +) ([]*Allocation, []*Allocation) { newSegmentAllocations := make([]*Allocation, 0) existedSegmentAllocations := make([]*Allocation, 0) // create new segment if count >= max num diff --git a/internal/datacoord/segment_allocation_policy_test.go b/internal/datacoord/segment_allocation_policy_test.go index ad45216bf389b..250d4b55b3f20 100644 --- a/internal/datacoord/segment_allocation_policy_test.go +++ b/internal/datacoord/segment_allocation_policy_test.go @@ -21,13 +21,13 @@ import ( "testing" "time" - "github.com/milvus-io/milvus/pkg/common" - "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/tsoutil" ) func TestUpperLimitCalBySchema(t *testing.T) { diff --git a/internal/datacoord/segment_info.go b/internal/datacoord/segment_info.go index 95f3e86f740eb..6d630e51275c2 100644 --- a/internal/datacoord/segment_info.go +++ b/internal/datacoord/segment_info.go @@ -238,7 +238,7 @@ func (s *SegmentInfo) Clone(opts ...SegmentInfoOption) *SegmentInfo { allocations: s.allocations, lastFlushTime: s.lastFlushTime, isCompacting: s.isCompacting, - //cannot copy size, since binlog may be changed + // cannot copy size, since binlog may be changed lastWrittenTime: s.lastWrittenTime, } for _, opt := range opts { diff --git a/internal/datacoord/segment_manager.go b/internal/datacoord/segment_manager.go index 0d87fc21137fa..08b2206755e30 100644 --- a/internal/datacoord/segment_manager.go +++ b/internal/datacoord/segment_manager.go @@ -23,7 +23,6 @@ import ( "time" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus/pkg/util/tsoutil" "go.opentelemetry.io/otel" "go.uber.org/zap" @@ -31,17 +30,16 @@ import ( 
"github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/retry" + "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/milvus-io/milvus/pkg/util/typeutil" ) -var ( - // allocPool pool of Allocation, to reduce allocation of Allocation - allocPool = sync.Pool{ - New: func() interface{} { - return &Allocation{} - }, - } -) +// allocPool pool of Allocation, to reduce allocation of Allocation +var allocPool = sync.Pool{ + New: func() interface{} { + return &Allocation{} + }, +} // getAllocation unifies way to retrieve allocation struct func getAllocation(numOfRows int64) *Allocation { @@ -230,7 +228,7 @@ func (s *SegmentManager) loadSegmentsFromMeta() { } func (s *SegmentManager) maybeResetLastExpireForSegments() error { - //for all sealed and growing segments, need to reset last expire + // for all sealed and growing segments, need to reset last expire if len(s.segments) > 0 { var latestTs uint64 allocateErr := retry.Do(context.Background(), func() error { @@ -257,8 +255,13 @@ func (s *SegmentManager) maybeResetLastExpireForSegments() error { // AllocSegment allocate segment per request collcation, partication, channel and rows func (s *SegmentManager) AllocSegment(ctx context.Context, collectionID UniqueID, - partitionID UniqueID, channelName string, requestRows int64) ([]*Allocation, error) { - + partitionID UniqueID, channelName string, requestRows int64, +) ([]*Allocation, error) { + log := log.Ctx(ctx). + With(zap.Int64("collectionID", collectionID)). + With(zap.Int64("partitionID", partitionID)). + With(zap.String("channelName", channelName)). 
+ With(zap.Int64("requestRows", requestRows)) _, sp := otel.Tracer(typeutil.DataCoordRole).Start(ctx, "Alloc-Segment") defer sp.End() s.mu.Lock() @@ -269,7 +272,7 @@ func (s *SegmentManager) AllocSegment(ctx context.Context, collectionID UniqueID for _, segmentID := range s.segments { segment := s.meta.GetHealthySegment(segmentID) if segment == nil { - log.Warn("Failed to get seginfo from meta", zap.Int64("id", segmentID)) + log.Warn("Failed to get segment info from meta", zap.Int64("id", segmentID)) continue } if !satisfy(segment, collectionID, partitionID, channelName) || !isGrowing(segment) { @@ -294,6 +297,7 @@ func (s *SegmentManager) AllocSegment(ctx context.Context, collectionID UniqueID for _, allocation := range newSegmentAllocations { segment, err := s.openNewSegment(ctx, collectionID, partitionID, channelName, commonpb.SegmentState_Growing) if err != nil { + log.Error("Failed to open new segment for segment allocation") return nil, err } allocation.ExpireTime = expireTs @@ -306,6 +310,7 @@ func (s *SegmentManager) AllocSegment(ctx context.Context, collectionID UniqueID for _, allocation := range existedSegmentAllocations { allocation.ExpireTime = expireTs if err := s.meta.AddAllocation(allocation.SegmentID, allocation); err != nil { + log.Error("Failed to add allocation to existed segment", zap.Int64("segmentID", allocation.SegmentID)) return nil, err } } @@ -316,7 +321,8 @@ func (s *SegmentManager) AllocSegment(ctx context.Context, collectionID UniqueID // allocSegmentForImport allocates one segment allocation for bulk insert. 
func (s *SegmentManager) allocSegmentForImport(ctx context.Context, collectionID UniqueID, - partitionID UniqueID, channelName string, requestRows int64, importTaskID int64) (*Allocation, error) { + partitionID UniqueID, channelName string, requestRows int64, importTaskID int64, +) (*Allocation, error) { _, sp := otel.Tracer(typeutil.DataCoordRole).Start(ctx, "Alloc-ImportSegment") defer sp.End() s.mu.Lock() @@ -369,7 +375,9 @@ func (s *SegmentManager) genExpireTs(ctx context.Context, isImported bool) (Time } func (s *SegmentManager) openNewSegment(ctx context.Context, collectionID UniqueID, partitionID UniqueID, - channelName string, segmentState commonpb.SegmentState) (*SegmentInfo, error) { + channelName string, segmentState commonpb.SegmentState, +) (*SegmentInfo, error) { + log := log.Ctx(ctx) ctx, sp := otel.Tracer(typeutil.DataCoordRole).Start(ctx, "open-Segment") defer sp.End() id, err := s.allocator.allocID(ctx) @@ -397,7 +405,7 @@ func (s *SegmentManager) openNewSegment(ctx context.Context, collectionID Unique segmentInfo.IsImporting = true } segment := NewSegmentInfo(segmentInfo) - if err := s.meta.AddSegment(segment); err != nil { + if err := s.meta.AddSegment(ctx, segment); err != nil { log.Error("failed to add segment to DataCoord", zap.Error(err)) return nil, err } diff --git a/internal/datacoord/segment_manager_test.go b/internal/datacoord/segment_manager_test.go index 592aa21ac656c..b99e3004d2764 100644 --- a/internal/datacoord/segment_manager_test.go +++ b/internal/datacoord/segment_manager_test.go @@ -56,7 +56,7 @@ func TestManagerOptions(t *testing.T) { opt := withCalUpperLimitPolicy(defaultCalUpperLimitPolicy()) assert.NotNil(t, opt) - //manual set nil`` + // manual set nil`` segmentManager.estimatePolicy = nil opt.apply(segmentManager) assert.True(t, segmentManager.estimatePolicy != nil) @@ -144,12 +144,12 @@ func TestAllocSegment(t *testing.T) { } func TestLastExpireReset(t *testing.T) { - //set up meta on dc + // set up meta on dc ctx := 
context.Background() paramtable.Init() Params.Save(Params.DataCoordCfg.AllocLatestExpireAttempt.Key, "1") Params.Save(Params.DataCoordCfg.SegmentMaxSize.Key, "1") - mockAllocator := newRootCoordAllocator(newMockRootCoordService()) + mockAllocator := newRootCoordAllocator(newMockRootCoordClient()) etcdCli, _ := etcd.GetEtcdClient( Params.EtcdCfg.UseEmbedEtcd.GetAsBool(), Params.EtcdCfg.EtcdUseSSL.GetAsBool(), @@ -170,10 +170,20 @@ func TestLastExpireReset(t *testing.T) { collID, err := mockAllocator.allocID(ctx) assert.Nil(t, err) meta.AddCollection(&collectionInfo{ID: collID, Schema: schema}) + initSegment := &SegmentInfo{ + SegmentInfo: &datapb.SegmentInfo{ + ID: 1, + InsertChannel: "ch1", + State: commonpb.SegmentState_Growing, + }, + } + meta.AddSegment(context.TODO(), initSegment) - //assign segments, set max segment to only 1MB, equalling to 10485 rows + // assign segments, set max segment to only 1MB, equalling to 10485 rows var bigRows, smallRows int64 = 10000, 1000 segmentManager, _ := newSegmentManager(meta, mockAllocator) + initSegment.SegmentInfo.State = commonpb.SegmentState_Dropped + meta.segments.SetSegment(1, initSegment) allocs, _ := segmentManager.AllocSegment(context.Background(), collID, 0, channelName, bigRows) segmentID1, expire1 := allocs[0].SegmentID, allocs[0].ExpireTime time.Sleep(100 * time.Millisecond) @@ -183,7 +193,7 @@ func TestLastExpireReset(t *testing.T) { allocs, _ = segmentManager.AllocSegment(context.Background(), collID, 0, channelName, smallRows) segmentID3, expire3 := allocs[0].SegmentID, allocs[0].ExpireTime - //simulate handleTimeTick op on dataCoord + // simulate handleTimeTick op on dataCoord meta.SetCurrentRows(segmentID1, bigRows) meta.SetCurrentRows(segmentID2, bigRows) meta.SetCurrentRows(segmentID3, smallRows) @@ -192,11 +202,11 @@ func TestLastExpireReset(t *testing.T) { assert.Equal(t, commonpb.SegmentState_Sealed, meta.GetSegment(segmentID2).GetState()) assert.Equal(t, commonpb.SegmentState_Growing, 
meta.GetSegment(segmentID3).GetState()) - //pretend that dataCoord break down + // pretend that dataCoord break down metaKV.Close() etcdCli.Close() - //dataCoord restart + // dataCoord restart newEtcdCli, _ := etcd.GetEtcdClient(Params.EtcdCfg.UseEmbedEtcd.GetAsBool(), Params.EtcdCfg.EtcdUseSSL.GetAsBool(), Params.EtcdCfg.Endpoints.GetAsStrings(), Params.EtcdCfg.EtcdTLSCert.GetValue(), Params.EtcdCfg.EtcdTLSKey.GetValue(), Params.EtcdCfg.EtcdTLSCACert.GetValue(), Params.EtcdCfg.EtcdTLSMinVersion.GetValue()) @@ -207,14 +217,14 @@ func TestLastExpireReset(t *testing.T) { restartedMeta.AddCollection(&collectionInfo{ID: collID, Schema: schema}) assert.Nil(t, err) newSegmentManager, _ := newSegmentManager(restartedMeta, mockAllocator) - //reset row number to avoid being cleaned by empty segment + // reset row number to avoid being cleaned by empty segment restartedMeta.SetCurrentRows(segmentID1, bigRows) restartedMeta.SetCurrentRows(segmentID2, bigRows) restartedMeta.SetCurrentRows(segmentID3, smallRows) - //verify lastExpire of growing and sealed segments + // verify lastExpire of growing and sealed segments segment1, segment2, segment3 := restartedMeta.GetSegment(segmentID1), restartedMeta.GetSegment(segmentID2), restartedMeta.GetSegment(segmentID3) - //segmentState should not be altered but growing segment's lastExpire has been reset to the latest + // segmentState should not be altered but growing segment's lastExpire has been reset to the latest assert.Equal(t, commonpb.SegmentState_Sealed, segment1.GetState()) assert.Equal(t, commonpb.SegmentState_Sealed, segment2.GetState()) assert.Equal(t, commonpb.SegmentState_Growing, segment3.GetState()) @@ -308,11 +318,11 @@ func TestLoadSegmentsFromMeta(t *testing.T) { MaxRowNum: 100, LastExpireTime: 1000, } - err = meta.AddSegment(NewSegmentInfo(sealedSegment)) + err = meta.AddSegment(context.TODO(), NewSegmentInfo(sealedSegment)) assert.NoError(t, err) - err = meta.AddSegment(NewSegmentInfo(growingSegment)) + err = 
meta.AddSegment(context.TODO(), NewSegmentInfo(growingSegment)) assert.NoError(t, err) - err = meta.AddSegment(NewSegmentInfo(flushedSegment)) + err = meta.AddSegment(context.TODO(), NewSegmentInfo(flushedSegment)) assert.NoError(t, err) segmentManager, _ := newSegmentManager(meta, mockAllocator) @@ -398,7 +408,7 @@ func TestAllocRowsLargerThanOneSegment(t *testing.T) { assert.NoError(t, err) meta.AddCollection(&collectionInfo{ID: collID, Schema: schema}) - var mockPolicy = func(schema *schemapb.CollectionSchema) (int, error) { + mockPolicy := func(schema *schemapb.CollectionSchema) (int, error) { return 1, nil } segmentManager, _ := newSegmentManager(meta, mockAllocator, withCalUpperLimitPolicy(mockPolicy)) @@ -420,7 +430,7 @@ func TestExpireAllocation(t *testing.T) { assert.NoError(t, err) meta.AddCollection(&collectionInfo{ID: collID, Schema: schema}) - var mockPolicy = func(schema *schemapb.CollectionSchema) (int, error) { + mockPolicy := func(schema *schemapb.CollectionSchema) (int, error) { return 10000000, nil } segmentManager, _ := newSegmentManager(meta, mockAllocator, withCalUpperLimitPolicy(mockPolicy)) @@ -538,7 +548,7 @@ func TestTryToSealSegment(t *testing.T) { collID, err := mockAllocator.allocID(context.Background()) assert.NoError(t, err) meta.AddCollection(&collectionInfo{ID: collID, Schema: schema}) - segmentManager, _ := newSegmentManager(meta, mockAllocator, withSegmentSealPolices(sealByLifetimePolicy(math.MinInt64))) //always seal + segmentManager, _ := newSegmentManager(meta, mockAllocator, withSegmentSealPolices(sealByLifetimePolicy(math.MinInt64))) // always seal allocations, err := segmentManager.AllocSegment(context.TODO(), collID, 0, "c1", 2) assert.NoError(t, err) assert.EqualValues(t, 1, len(allocations)) @@ -563,7 +573,7 @@ func TestTryToSealSegment(t *testing.T) { collID, err := mockAllocator.allocID(context.Background()) assert.NoError(t, err) meta.AddCollection(&collectionInfo{ID: collID, Schema: schema}) - segmentManager, _ := 
newSegmentManager(meta, mockAllocator, withChannelSealPolices(getChannelOpenSegCapacityPolicy(-1))) //always seal + segmentManager, _ := newSegmentManager(meta, mockAllocator, withChannelSealPolices(getChannelOpenSegCapacityPolicy(-1))) // always seal allocations, err := segmentManager.AllocSegment(context.TODO(), collID, 0, "c1", 2) assert.NoError(t, err) assert.EqualValues(t, 1, len(allocations)) @@ -590,7 +600,7 @@ func TestTryToSealSegment(t *testing.T) { meta.AddCollection(&collectionInfo{ID: collID, Schema: schema}) segmentManager, _ := newSegmentManager(meta, mockAllocator, withSegmentSealPolices(sealByLifetimePolicy(math.MinInt64)), - withChannelSealPolices(getChannelOpenSegCapacityPolicy(-1))) //always seal + withChannelSealPolices(getChannelOpenSegCapacityPolicy(-1))) // always seal allocations, err := segmentManager.AllocSegment(context.TODO(), collID, 0, "c1", 2) assert.NoError(t, err) assert.EqualValues(t, 1, len(allocations)) @@ -702,7 +712,7 @@ func TestTryToSealSegment(t *testing.T) { collID, err := mockAllocator.allocID(context.Background()) assert.NoError(t, err) meta.AddCollection(&collectionInfo{ID: collID, Schema: schema}) - segmentManager, _ := newSegmentManager(meta, mockAllocator, withSegmentSealPolices(sealByLifetimePolicy(math.MinInt64))) //always seal + segmentManager, _ := newSegmentManager(meta, mockAllocator, withSegmentSealPolices(sealByLifetimePolicy(math.MinInt64))) // always seal allocations, err := segmentManager.AllocSegment(context.TODO(), collID, 0, "c1", 2) assert.NoError(t, err) assert.EqualValues(t, 1, len(allocations)) @@ -731,7 +741,7 @@ func TestTryToSealSegment(t *testing.T) { collID, err := mockAllocator.allocID(context.Background()) assert.NoError(t, err) meta.AddCollection(&collectionInfo{ID: collID, Schema: schema}) - segmentManager, _ := newSegmentManager(meta, mockAllocator, withChannelSealPolices(getChannelOpenSegCapacityPolicy(-1))) //always seal + segmentManager, _ := newSegmentManager(meta, mockAllocator, 
withChannelSealPolices(getChannelOpenSegCapacityPolicy(-1))) // always seal allocations, err := segmentManager.AllocSegment(context.TODO(), collID, 0, "c1", 2) assert.NoError(t, err) assert.EqualValues(t, 1, len(allocations)) @@ -790,7 +800,6 @@ func TestAllocationPool(t *testing.T) { assert.EqualValues(t, 100, allo.NumOfRows) assert.EqualValues(t, 0, allo.ExpireTime) assert.EqualValues(t, 0, allo.SegmentID) - }) } diff --git a/internal/datacoord/server.go b/internal/datacoord/server.go index 023b12637242b..7cde094158f49 100644 --- a/internal/datacoord/server.go +++ b/internal/datacoord/server.go @@ -28,6 +28,7 @@ import ( "github.com/blang/semver/v4" "github.com/cockroachdb/errors" + "github.com/tikv/client-go/v2/txnkv" clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/zap" @@ -37,6 +38,7 @@ import ( rootcoordclient "github.com/milvus-io/milvus/internal/distributed/rootcoord/client" "github.com/milvus-io/milvus/internal/kv" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" + "github.com/milvus-io/milvus/internal/kv/tikv" "github.com/milvus-io/milvus/internal/metastore/kv/datacoord" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/storage" @@ -47,6 +49,7 @@ import ( "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" + "github.com/milvus-io/milvus/pkg/util" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/logutil" "github.com/milvus-io/milvus/pkg/util/merr" @@ -59,7 +62,7 @@ import ( ) const ( - connEtcdMaxRetryTime = 100 + connMetaMaxRetryTime = 100 allPartitionID = 0 // partitionID means no filtering ) @@ -79,11 +82,11 @@ type ( Timestamp = typeutil.Timestamp ) -type dataNodeCreatorFunc func(ctx context.Context, addr string, nodeID int64) (types.DataNode, error) +type dataNodeCreatorFunc func(ctx context.Context, addr string, nodeID int64) (types.DataNodeClient, error) -type 
indexNodeCreatorFunc func(ctx context.Context, addr string, nodeID int64) (types.IndexNode, error) +type indexNodeCreatorFunc func(ctx context.Context, addr string, nodeID int64) (types.IndexNodeClient, error) -type rootCoordCreatorFunc func(ctx context.Context, metaRootPath string, etcdClient *clientv3.Client) (types.RootCoord, error) +type rootCoordCreatorFunc func(ctx context.Context, metaRootPath string, etcdClient *clientv3.Client) (types.RootCoordClient, error) // makes sure Server implements `DataCoord` var _ types.DataCoord = (*Server)(nil) @@ -102,15 +105,17 @@ type Server struct { helper ServerHelper etcdCli *clientv3.Client + tikvCli *txnkv.Client address string - kvClient kv.WatchKV + watchClient kv.WatchKV + kv kv.MetaKv meta *meta segmentManager Manager allocator allocator cluster *Cluster sessionManager *SessionManager channelManager *ChannelManager - rootCoordClient types.RootCoord + rootCoordClient types.RootCoordClient garbageCollector *garbageCollector gcOpt GcOption handler Handler @@ -129,7 +134,8 @@ type Server struct { icSession *sessionutil.Session dnEventCh <-chan *sessionutil.SessionEvent inEventCh <-chan *sessionutil.SessionEvent - //qcEventCh <-chan *sessionutil.SessionEvent + // qcEventCh <-chan *sessionutil.SessionEvent + qnEventCh <-chan *sessionutil.SessionEvent enableActiveStandBy bool activateFunc func() error @@ -137,11 +143,12 @@ type Server struct { dataNodeCreator dataNodeCreatorFunc indexNodeCreator indexNodeCreatorFunc rootCoordClientCreator rootCoordCreatorFunc - //indexCoord types.IndexCoord + // indexCoord types.IndexCoord - //segReferManager *SegmentReferenceManager - indexBuilder *indexBuilder - indexNodeManager *IndexNodeManager + // segReferManager *SegmentReferenceManager + indexBuilder *indexBuilder + indexNodeManager *IndexNodeManager + indexEngineVersionManager *IndexEngineVersionManager // manage ways that data coord access other coord broker Broker @@ -220,15 +227,15 @@ func CreateServer(ctx context.Context, 
factory dependency.Factory, opts ...Optio return s } -func defaultDataNodeCreatorFunc(ctx context.Context, addr string, nodeID int64) (types.DataNode, error) { +func defaultDataNodeCreatorFunc(ctx context.Context, addr string, nodeID int64) (types.DataNodeClient, error) { return datanodeclient.NewClient(ctx, addr, nodeID) } -func defaultIndexNodeCreatorFunc(ctx context.Context, addr string, nodeID int64) (types.IndexNode, error) { +func defaultIndexNodeCreatorFunc(ctx context.Context, addr string, nodeID int64) (types.IndexNodeClient, error) { return indexnodeclient.NewClient(ctx, addr, nodeID, Params.DataCoordCfg.WithCredential.GetAsBool()) } -func defaultRootCoordCreatorFunc(ctx context.Context, metaRootPath string, client *clientv3.Client) (types.RootCoord, error) { +func defaultRootCoordCreatorFunc(ctx context.Context, metaRootPath string, client *clientv3.Client) (types.RootCoordClient, error) { return rootcoordclient.NewClient(ctx, metaRootPath, client) } @@ -387,13 +394,8 @@ func (s *Server) startDataCoord() { s.compactionTrigger.start() } s.startServerLoop() - // DataCoord (re)starts successfully and starts to collection segment stats - // data from all DataNode. - // This will prevent DataCoord from missing out any important segment stats - // data while offline. 
- log.Info("DataCoord (re)starts successfully and re-collecting segment stats from DataNodes") - s.reCollectSegmentStats(s.ctx) s.stateCode.Store(commonpb.StateCode_Healthy) + sessionutil.SaveServerInfo(typeutil.DataCoordRole, s.session.ServerID) } func (s *Server) initCluster() error { @@ -402,7 +404,7 @@ func (s *Server) initCluster() error { } var err error - s.channelManager, err = NewChannelManager(s.kvClient, s.handler, withMsgstreamFactory(s.factory), + s.channelManager, err = NewChannelManager(s.watchClient, s.handler, withMsgstreamFactory(s.factory), withStateChecker(), withBgChecker()) if err != nil { return err @@ -421,15 +423,19 @@ func (s *Server) SetEtcdClient(client *clientv3.Client) { s.etcdCli = client } -func (s *Server) SetRootCoord(rootCoord types.RootCoord) { +func (s *Server) SetTiKVClient(client *txnkv.Client) { + s.tikvCli = client +} + +func (s *Server) SetRootCoordClient(rootCoord types.RootCoordClient) { s.rootCoordClient = rootCoord } -func (s *Server) SetDataNodeCreator(f func(context.Context, string, int64) (types.DataNode, error)) { +func (s *Server) SetDataNodeCreator(f func(context.Context, string, int64) (types.DataNodeClient, error)) { s.dataNodeCreator = f } -func (s *Server) SetIndexNodeCreator(f func(context.Context, string, int64) (types.IndexNode, error)) { +func (s *Server) SetIndexNodeCreator(f func(context.Context, string, int64) (types.IndexNodeClient, error)) { s.indexNodeCreator = f } @@ -514,6 +520,15 @@ func (s *Server) initServiceDiscovery() error { } s.inEventCh = s.session.WatchServices(typeutil.IndexNodeRole, inRevision+1, nil) + s.indexEngineVersionManager = newIndexEngineVersionManager() + qnSessions, qnRevision, err := s.session.GetSessions(typeutil.QueryNodeRole) + if err != nil { + log.Warn("DataCoord get QueryNode sessions failed", zap.Error(err)) + return err + } + s.indexEngineVersionManager.Startup(qnSessions) + s.qnEventCh = s.session.WatchServicesWithVersionRange(typeutil.QueryNodeRole, r, qnRevision+1, 
nil) + return nil } @@ -532,24 +547,33 @@ func (s *Server) initMeta(chunkManager storage.ChunkManager) error { if s.meta != nil { return nil } - etcdKV := etcdkv.NewEtcdKV(s.etcdCli, Params.EtcdCfg.MetaRootPath.GetValue()) + s.watchClient = etcdkv.NewEtcdKV(s.etcdCli, Params.EtcdCfg.MetaRootPath.GetValue()) + metaType := Params.MetaStoreCfg.MetaStoreType.GetValue() + log.Info("data coordinator connecting to metadata store", zap.String("metaType", metaType)) + if metaType == util.MetaStoreTypeTiKV { + s.kv = tikv.NewTiKV(s.tikvCli, Params.TiKVCfg.MetaRootPath.GetValue()) + } else if metaType == util.MetaStoreTypeEtcd { + s.kv = etcdkv.NewEtcdKV(s.etcdCli, Params.EtcdCfg.MetaRootPath.GetValue()) + } else { + return retry.Unrecoverable(fmt.Errorf("not supported meta store: %s", metaType)) + } + log.Info("data coordinator successfully connected to metadata store", zap.String("metaType", metaType)) - s.kvClient = etcdKV reloadEtcdFn := func() error { var err error - catalog := datacoord.NewCatalog(etcdKV, chunkManager.RootPath(), Params.EtcdCfg.MetaRootPath.GetValue()) + catalog := datacoord.NewCatalog(s.kv, chunkManager.RootPath(), Params.EtcdCfg.MetaRootPath.GetValue()) s.meta, err = newMeta(s.ctx, catalog, chunkManager) if err != nil { return err } return nil } - return retry.Do(s.ctx, reloadEtcdFn, retry.Attempts(connEtcdMaxRetryTime)) + return retry.Do(s.ctx, reloadEtcdFn, retry.Attempts(connMetaMaxRetryTime)) } func (s *Server) initIndexBuilder(manager storage.ChunkManager) { if s.indexBuilder == nil { - s.indexBuilder = newIndexBuilder(s.ctx, s.meta, s.indexNodeManager, manager) + s.indexBuilder = newIndexBuilder(s.ctx, s.meta, s.indexNodeManager, manager, s.indexEngineVersionManager) } } @@ -586,7 +610,7 @@ func (s *Server) startDataNodeTtLoop(ctx context.Context) { } subName := fmt.Sprintf("%s-%d-datanodeTl", Params.CommonCfg.DataCoordSubName.GetValue(), paramtable.GetNodeID()) - ttMsgStream.AsConsumer([]string{timeTickChannel}, subName, 
mqwrapper.SubscriptionPositionLatest) + ttMsgStream.AsConsumer(context.TODO(), []string{timeTickChannel}, subName, mqwrapper.SubscriptionPositionLatest) log.Info("DataCoord creates the timetick channel consumer", zap.String("timeTickChannel", timeTickChannel), zap.String("subscription", subName)) @@ -800,6 +824,19 @@ func (s *Server) watchService(ctx context.Context) { }() return } + case event, ok := <-s.qnEventCh: + if !ok { + s.stopServiceWatch() + return + } + if err := s.handleSessionEvent(ctx, typeutil.QueryNodeRole, event); err != nil { + go func() { + if err := s.Stop(); err != nil { + log.Warn("DataCoord server stop error", zap.Error(err)) + } + }() + return + } } } } @@ -863,6 +900,26 @@ func (s *Server) handleSessionEvent(ctx context.Context, role string, event *ses log.Warn("receive unknown service event type", zap.Any("type", event.EventType)) } + case typeutil.QueryNodeRole: + switch event.EventType { + case sessionutil.SessionAddEvent: + log.Info("received querynode register", + zap.String("address", event.Session.Address), + zap.Int64("serverID", event.Session.ServerID)) + s.indexEngineVersionManager.AddNode(event.Session) + case sessionutil.SessionDelEvent: + log.Info("received querynode unregister", + zap.String("address", event.Session.Address), + zap.Int64("serverID", event.Session.ServerID)) + s.indexEngineVersionManager.RemoveNode(event.Session) + case sessionutil.SessionUpdateEvent: + serverID := event.Session.ServerID + log.Info("received querynode SessionUpdateEvent", zap.Int64("serverID", serverID)) + s.indexEngineVersionManager.Update(event.Session) + default: + log.Warn("receive unknown service event type", + zap.Any("type", event.EventType)) + } } return nil @@ -884,7 +941,7 @@ func (s *Server) startFlushLoop(ctx context.Context) { logutil.Logger(s.ctx).Info("flush loop shutdown") return case segmentID := <-s.flushCh: - //Ignore return error + // Ignore return error log.Info("flush successfully", zap.Any("segmentID", segmentID)) err := 
s.postFlush(ctx, segmentID) if err != nil { @@ -952,10 +1009,7 @@ func (s *Server) initRootCoordClient() error { return err } } - if err = s.rootCoordClient.Init(); err != nil { - return err - } - return s.rootCoordClient.Start() + return nil } // Stop do the Server finalize processes @@ -991,8 +1045,17 @@ func (s *Server) Stop() error { // CleanMeta only for test func (s *Server) CleanMeta() error { - log.Debug("clean meta", zap.Any("kv", s.kvClient)) - return s.kvClient.RemoveWithPrefix("") + log.Debug("clean meta", zap.Any("kv", s.kv)) + err := s.kv.RemoveWithPrefix("") + err2 := s.watchClient.RemoveWithPrefix("") + if err2 != nil { + if err != nil { + err = fmt.Errorf("Failed to CleanMeta[metadata cleanup error: %w][watchdata cleanup error: %v]", err, err2) + } else { + err = err2 + } + } + return err } func (s *Server) stopServerLoop() { @@ -1000,7 +1063,7 @@ func (s *Server) stopServerLoop() { s.serverLoopWg.Wait() } -//func (s *Server) validateAllocRequest(collID UniqueID, partID UniqueID, channelName string) error { +// func (s *Server) validateAllocRequest(collID UniqueID, partID UniqueID, channelName string) error { // if !s.meta.HasCollection(collID) { // return fmt.Errorf("can not find collection %d", collID) // } @@ -1013,7 +1076,7 @@ func (s *Server) stopServerLoop() { // } // } // return fmt.Errorf("can not find channel %s", channelName) -//} +// } // loadCollectionFromRootCoord communicates with RootCoord and asks for collection information. // collection information will be added to server meta info. 
@@ -1043,25 +1106,3 @@ func (s *Server) loadCollectionFromRootCoord(ctx context.Context, collectionID i s.meta.AddCollection(collInfo) return nil } - -func (s *Server) reCollectSegmentStats(ctx context.Context) { - if s.channelManager == nil { - log.Error("null channel manager found, which should NOT happen in non-testing environment") - return - } - nodes := s.sessionManager.getLiveNodeIDs() - log.Info("re-collecting segment stats from DataNodes", - zap.Int64s("DataNode IDs", nodes)) - - reCollectFunc := func() error { - err := s.cluster.ReCollectSegmentStats(ctx) - if err != nil { - return err - } - return nil - } - - if err := retry.Do(ctx, reCollectFunc, retry.Attempts(20), retry.Sleep(time.Millisecond*100), retry.MaxSleepTime(5*time.Second)); err != nil { - panic(err) - } -} diff --git a/internal/datacoord/server_test.go b/internal/datacoord/server_test.go index 1696b99aa4907..233fc45cb3cec 100644 --- a/internal/datacoord/server_test.go +++ b/internal/datacoord/server_test.go @@ -36,12 +36,14 @@ import ( "github.com/stretchr/testify/require" clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + grpcStatus "google.golang.org/grpc/status" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/internal/indexnode" "github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/datapb" @@ -51,15 +53,18 @@ import ( "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/dependency" + grpcmock "github.com/milvus-io/milvus/internal/util/mock" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/common" 
"github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metautil" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/tikv" "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -79,16 +84,22 @@ func TestMain(m *testing.M) { paramtable.Get().Save(Params.EtcdCfg.Endpoints.Key, strings.Join(addrs, ",")) rand.Seed(time.Now().UnixNano()) - os.Exit(m.Run()) + parameters := []string{"tikv", "etcd"} + var code int + for _, v := range parameters { + paramtable.Get().Save(paramtable.Get().MetaStoreCfg.MetaStoreType.Key, v) + code = m.Run() + } + os.Exit(code) } func TestGetSegmentInfoChannel(t *testing.T) { svr := newTestServer(t, nil) defer closeTestServer(t, svr) t.Run("get segment info channel", func(t *testing.T) { - resp, err := svr.GetSegmentInfoChannel(context.TODO()) + resp, err := svr.GetSegmentInfoChannel(context.TODO(), nil) assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.EqualValues(t, Params.CommonCfg.DataCoordSegmentInfo.GetValue(), resp.Value) }) } @@ -123,7 +134,7 @@ func TestAssignSegmentID(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, len(resp.SegIDAssignments)) assign := resp.SegIDAssignments[0] - assert.EqualValues(t, commonpb.ErrorCode_Success, assign.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_Success, assign.GetStatus().GetErrorCode()) assert.EqualValues(t, collID, assign.CollectionID) assert.EqualValues(t, partID, assign.PartitionID) assert.EqualValues(t, channel0, assign.ChannelName) @@ -155,7 +166,7 @@ func TestAssignSegmentID(t *testing.T) 
{ assert.NoError(t, err) assert.EqualValues(t, 1, len(resp.SegIDAssignments)) assign := resp.SegIDAssignments[0] - assert.EqualValues(t, commonpb.ErrorCode_Success, assign.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_Success, assign.GetStatus().GetErrorCode()) assert.EqualValues(t, collID, assign.CollectionID) assert.EqualValues(t, partID, assign.PartitionID) assert.EqualValues(t, channel0, assign.ChannelName) @@ -177,16 +188,15 @@ func TestAssignSegmentID(t *testing.T) { SegmentIDRequests: []*datapb.SegmentIDRequest{req}, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetStatus().GetErrorCode()) - assert.Equal(t, serverNotServingErrMsg, resp.GetStatus().GetReason()) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) }) t.Run("assign segment with invalid collection", func(t *testing.T) { svr := newTestServer(t, nil) defer closeTestServer(t, svr) svr.rootCoordClient = &mockRootCoord{ - RootCoord: svr.rootCoordClient, - collID: collID, + RootCoordClient: svr.rootCoordClient, + collID: collID, } schema := newTestSchema() @@ -213,11 +223,11 @@ func TestAssignSegmentID(t *testing.T) { } type mockRootCoord struct { - types.RootCoord + types.RootCoordClient collID UniqueID } -func (r *mockRootCoord) DescribeCollectionInternal(ctx context.Context, req *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { +func (r *mockRootCoord) DescribeCollectionInternal(ctx context.Context, req *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) { if req.CollectionID != r.collID { return &milvuspb.DescribeCollectionResponse{ Status: &commonpb.Status{ @@ -226,10 +236,10 @@ func (r *mockRootCoord) DescribeCollectionInternal(ctx context.Context, req *mil }, }, nil } - return r.RootCoord.DescribeCollection(ctx, req) + return r.RootCoordClient.DescribeCollection(ctx, req) } -func (r *mockRootCoord) ReportImport(context.Context, 
*rootcoordpb.ImportResult) (*commonpb.Status, error) { +func (r *mockRootCoord) ReportImport(ctx context.Context, req *rootcoordpb.ImportResult, opts ...grpc.CallOption) (*commonpb.Status, error) { return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: "something bad", @@ -260,7 +270,7 @@ func TestFlush(t *testing.T) { resp, err := svr.Flush(context.TODO(), req) assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) svr.meta.SetCurrentRows(segID, 1) ids, err := svr.segmentManager.GetFlushableSegments(context.TODO(), "channel-1", expireTs) @@ -282,7 +292,7 @@ func TestFlush(t *testing.T) { resp, err := svr.Flush(context.TODO(), req) assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.EqualValues(t, 0, len(resp.SegmentIDs)) // should not flush anything since this is a normal flush svr.meta.SetCurrentRows(segID, 1) @@ -304,7 +314,7 @@ func TestFlush(t *testing.T) { resp, err = svr.Flush(context.TODO(), req) assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.EqualValues(t, 1, len(resp.SegmentIDs)) ids, err = svr.segmentManager.GetFlushableSegments(context.TODO(), "channel-1", expireTs) @@ -318,19 +328,52 @@ func TestFlush(t *testing.T) { closeTestServer(t, svr) resp, err := svr.Flush(context.Background(), req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetStatus().GetErrorCode()) - assert.Equal(t, serverNotServingErrMsg, resp.GetStatus().GetReason()) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) + }) + + t.Run("test rolling upgrade", func(t *testing.T) { + svr := newTestServer(t, 
nil) + closeTestServer(t, svr) + svr.stateCode.Store(commonpb.StateCode_Healthy) + sm := NewSessionManager() + + datanodeClient := mocks.NewMockDataNodeClient(t) + datanodeClient.EXPECT().FlushChannels(mock.Anything, mock.Anything).Return(nil, + merr.WrapErrServiceUnimplemented(grpcStatus.Error(codes.Unimplemented, "mock grpc unimplemented error"))) + + sm.sessions = struct { + sync.RWMutex + data map[int64]*Session + }{data: map[int64]*Session{1: { + client: datanodeClient, + clientCreator: func(ctx context.Context, addr string, nodeID int64) (types.DataNodeClient, error) { + return datanodeClient, nil + }, + }}} + + svr.sessionManager = sm + svr.cluster.sessionManager = sm + + err := svr.channelManager.AddNode(1) + assert.NoError(t, err) + err = svr.channelManager.Watch(context.TODO(), &channel{Name: "ch1", CollectionID: 0}) + assert.NoError(t, err) + + resp, err := svr.Flush(context.TODO(), req) + assert.NoError(t, err) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) + assert.Equal(t, Timestamp(0), resp.GetFlushTs()) }) } -//func TestGetComponentStates(t *testing.T) { -//svr := newTestServer(t) -//defer closeTestServer(t, svr) -//cli := newMockDataNodeClient(1) -//err := cli.Init() -//assert.NoError(t, err) -//err = cli.Start() -//assert.NoError(t, err) +// func TestGetComponentStates(t *testing.T) { +// svr := newTestServer(t) +// defer closeTestServer(t, svr) +// cli := newMockDataNodeClient(1) +// err := cli.Init() +// assert.NoError(t, err) +// err = cli.Start() +// assert.NoError(t, err) //err = svr.cluster.Register(&dataNode{ //id: 1, @@ -348,7 +391,7 @@ func TestFlush(t *testing.T) { //resp, err := svr.GetComponentStates(context.TODO()) //assert.NoError(t, err) -//assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) +//assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) //assert.EqualValues(t, commonpb.StateCode_Healthy, resp.State.StateCode) //assert.EqualValues(t, 1, 
len(resp.SubcomponentStates)) //assert.EqualValues(t, commonpb.StateCode_Healthy, resp.SubcomponentStates[0].StateCode) @@ -357,9 +400,9 @@ func TestFlush(t *testing.T) { func TestGetTimeTickChannel(t *testing.T) { svr := newTestServer(t, nil) defer closeTestServer(t, svr) - resp, err := svr.GetTimeTickChannel(context.TODO()) + resp, err := svr.GetTimeTickChannel(context.TODO(), nil) assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.EqualValues(t, Params.CommonCfg.DataCoordTimeTick.GetValue(), resp.Value) } @@ -381,7 +424,7 @@ func TestGetSegmentStates(t *testing.T) { Timestamp: 0, }, } - err := svr.meta.AddSegment(NewSegmentInfo(segment)) + err := svr.meta.AddSegment(context.TODO(), NewSegmentInfo(segment)) assert.NoError(t, err) cases := []struct { @@ -406,7 +449,7 @@ func TestGetSegmentStates(t *testing.T) { SegmentIDs: []int64{test.id}, }) assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.EqualValues(t, 1, len(resp.States)) if test.expected { assert.EqualValues(t, test.expectedState, resp.States[0].State) @@ -428,8 +471,7 @@ func TestGetSegmentStates(t *testing.T) { SegmentIDs: []int64{0}, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetStatus().GetErrorCode()) - assert.Equal(t, serverNotServingErrMsg, resp.GetStatus().GetReason()) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) }) } @@ -455,14 +497,14 @@ func TestGetInsertBinlogPaths(t *testing.T) { }, State: commonpb.SegmentState_Growing, } - err := svr.meta.AddSegment(NewSegmentInfo(info)) + err := svr.meta.AddSegment(context.TODO(), NewSegmentInfo(info)) assert.NoError(t, err) req := &datapb.GetInsertBinlogPathsRequest{ SegmentID: 0, } resp, err := 
svr.GetInsertBinlogPaths(svr.ctx, req) assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) t.Run("with invalid segmentID", func(t *testing.T) { @@ -487,15 +529,14 @@ func TestGetInsertBinlogPaths(t *testing.T) { State: commonpb.SegmentState_Growing, } - err := svr.meta.AddSegment(NewSegmentInfo(info)) + err := svr.meta.AddSegment(context.TODO(), NewSegmentInfo(info)) assert.NoError(t, err) req := &datapb.GetInsertBinlogPathsRequest{ SegmentID: 1, } resp, err := svr.GetInsertBinlogPaths(svr.ctx, req) assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_UnexpectedError, resp.GetStatus().GetErrorCode()) - + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrSegmentNotFound) }) t.Run("with closed server", func(t *testing.T) { @@ -505,8 +546,7 @@ func TestGetInsertBinlogPaths(t *testing.T) { SegmentID: 0, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetStatus().GetErrorCode()) - assert.Equal(t, serverNotServingErrMsg, resp.GetStatus().GetReason()) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) }) } @@ -520,8 +560,7 @@ func TestGetCollectionStatistics(t *testing.T) { } resp, err := svr.GetCollectionStatistics(svr.ctx, req) assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) - + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) t.Run("with closed server", func(t *testing.T) { svr := newTestServer(t, nil) @@ -530,8 +569,7 @@ func TestGetCollectionStatistics(t *testing.T) { CollectionID: 0, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetStatus().GetErrorCode()) - assert.Equal(t, serverNotServingErrMsg, resp.GetStatus().GetReason()) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) }) } @@ -546,15 +584,14 @@ func 
TestGetPartitionStatistics(t *testing.T) { } resp, err := svr.GetPartitionStatistics(context.Background(), req) assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) t.Run("with closed server", func(t *testing.T) { svr := newTestServer(t, nil) closeTestServer(t, svr) resp, err := svr.GetPartitionStatistics(context.Background(), &datapb.GetPartitionStatisticsRequest{}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetStatus().GetErrorCode()) - assert.Equal(t, serverNotServingErrMsg, resp.GetStatus().GetReason()) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) }) } @@ -587,7 +624,7 @@ func TestGetSegmentInfo(t *testing.T) { }, }, } - err := svr.meta.AddSegment(NewSegmentInfo(segInfo)) + err := svr.meta.AddSegment(context.TODO(), NewSegmentInfo(segInfo)) assert.NoError(t, err) req := &datapb.GetSegmentInfoRequest{ @@ -598,7 +635,7 @@ func TestGetSegmentInfo(t *testing.T) { // Check that # of rows is corrected from 100 to 60. 
assert.EqualValues(t, 60, resp.GetInfos()[0].GetNumOfRows()) assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) t.Run("with wrong segmentID", func(t *testing.T) { svr := newTestServer(t, nil) @@ -608,7 +645,7 @@ func TestGetSegmentInfo(t *testing.T) { ID: 0, State: commonpb.SegmentState_Flushed, } - err := svr.meta.AddSegment(NewSegmentInfo(segInfo)) + err := svr.meta.AddSegment(context.TODO(), NewSegmentInfo(segInfo)) assert.NoError(t, err) req := &datapb.GetSegmentInfoRequest{ @@ -616,7 +653,7 @@ func TestGetSegmentInfo(t *testing.T) { } resp, err := svr.GetSegmentInfo(svr.ctx, req) assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_UnexpectedError, resp.Status.ErrorCode) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrSegmentNotFound) }) t.Run("with closed server", func(t *testing.T) { svr := newTestServer(t, nil) @@ -625,8 +662,7 @@ func TestGetSegmentInfo(t *testing.T) { SegmentIDs: []int64{}, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetStatus().GetErrorCode()) - assert.Equal(t, serverNotServingErrMsg, resp.GetStatus().GetReason()) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) }) t.Run("with dropped segment", func(t *testing.T) { svr := newTestServer(t, nil) @@ -636,7 +672,7 @@ func TestGetSegmentInfo(t *testing.T) { ID: 0, State: commonpb.SegmentState_Dropped, } - err := svr.meta.AddSegment(NewSegmentInfo(segInfo)) + err := svr.meta.AddSegment(context.TODO(), NewSegmentInfo(segInfo)) assert.NoError(t, err) req := &datapb.GetSegmentInfoRequest{ @@ -673,7 +709,7 @@ func TestGetSegmentInfo(t *testing.T) { ID: 0, State: commonpb.SegmentState_Flushed, } - err := svr.meta.AddSegment(NewSegmentInfo(segInfo)) + err := svr.meta.AddSegment(context.TODO(), NewSegmentInfo(segInfo)) assert.NoError(t, err) req := &datapb.GetSegmentInfoRequest{ 
@@ -682,7 +718,7 @@ func TestGetSegmentInfo(t *testing.T) { // no channel checkpoint resp, err := svr.GetSegmentInfo(svr.ctx, req) assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, 0, len(resp.GetChannelCheckpoint())) // with nil insert channel of segment @@ -690,18 +726,18 @@ func TestGetSegmentInfo(t *testing.T) { assert.NoError(t, err) resp, err = svr.GetSegmentInfo(svr.ctx, req) assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, 0, len(resp.GetChannelCheckpoint())) // normal test segInfo.InsertChannel = mockVChannel segInfo.ID = 2 req.SegmentIDs = []int64{2} - err = svr.meta.AddSegment(NewSegmentInfo(segInfo)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(segInfo)) assert.NoError(t, err) resp, err = svr.GetSegmentInfo(svr.ctx, req) assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, 1, len(resp.GetChannelCheckpoint())) assert.Equal(t, mockPChannel, resp.ChannelCheckpoint[mockVChannel].ChannelName) assert.Equal(t, Timestamp(1000), resp.ChannelCheckpoint[mockVChannel].Timestamp) @@ -710,7 +746,7 @@ func TestGetSegmentInfo(t *testing.T) { func TestGetComponentStates(t *testing.T) { svr := &Server{} - resp, err := svr.GetComponentStates(context.Background()) + resp, err := svr.GetComponentStates(context.Background(), nil) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, common.NotRegisteredID, resp.State.NodeID) @@ -727,7 +763,7 @@ func TestGetComponentStates(t *testing.T) { } for _, tc := range cases { svr.stateCode.Store(tc.state) - resp, err := 
svr.GetComponentStates(context.Background()) + resp, err := svr.GetComponentStates(context.Background(), nil) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, tc.code, resp.GetState().GetStateCode()) @@ -790,7 +826,7 @@ func TestGetFlushedSegments(t *testing.T) { PartitionID: tc.partID, State: commonpb.SegmentState_Flushed, } - assert.Nil(t, svr.meta.AddSegment(NewSegmentInfo(segInfo))) + assert.Nil(t, svr.meta.AddSegment(context.TODO(), NewSegmentInfo(segInfo))) } for _, us := range tc.unflushedSegments { segInfo := &datapb.SegmentInfo{ @@ -799,7 +835,7 @@ func TestGetFlushedSegments(t *testing.T) { PartitionID: tc.partID, State: commonpb.SegmentState_Growing, } - assert.Nil(t, svr.meta.AddSegment(NewSegmentInfo(segInfo))) + assert.Nil(t, svr.meta.AddSegment(context.TODO(), NewSegmentInfo(segInfo))) } resp, err := svr.GetFlushedSegments(context.Background(), &datapb.GetFlushedSegmentsRequest{ @@ -819,8 +855,7 @@ func TestGetFlushedSegments(t *testing.T) { closeTestServer(t, svr) resp, err := svr.GetFlushedSegments(context.Background(), &datapb.GetFlushedSegmentsRequest{}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetStatus().GetErrorCode()) - assert.Equal(t, serverNotServingErrMsg, resp.GetStatus().GetReason()) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) }) }) } @@ -885,7 +920,7 @@ func TestGetSegmentsByStates(t *testing.T) { PartitionID: tc.partID, State: commonpb.SegmentState_Flushed, } - assert.Nil(t, svr.meta.AddSegment(NewSegmentInfo(segInfo))) + assert.Nil(t, svr.meta.AddSegment(context.TODO(), NewSegmentInfo(segInfo))) } for _, us := range tc.sealedSegments { segInfo := &datapb.SegmentInfo{ @@ -894,7 +929,7 @@ func TestGetSegmentsByStates(t *testing.T) { PartitionID: tc.partID, State: commonpb.SegmentState_Sealed, } - assert.Nil(t, svr.meta.AddSegment(NewSegmentInfo(segInfo))) + assert.Nil(t, 
svr.meta.AddSegment(context.TODO(), NewSegmentInfo(segInfo))) } for _, us := range tc.growingSegments { segInfo := &datapb.SegmentInfo{ @@ -903,7 +938,7 @@ func TestGetSegmentsByStates(t *testing.T) { PartitionID: tc.partID, State: commonpb.SegmentState_Growing, } - assert.Nil(t, svr.meta.AddSegment(NewSegmentInfo(segInfo))) + assert.Nil(t, svr.meta.AddSegment(context.TODO(), NewSegmentInfo(segInfo))) } resp, err := svr.GetSegmentsByStates(context.Background(), &datapb.GetSegmentsByStatesRequest{ @@ -924,8 +959,7 @@ func TestGetSegmentsByStates(t *testing.T) { closeTestServer(t, svr) resp, err := svr.GetSegmentsByStates(context.Background(), &datapb.GetSegmentsByStatesRequest{}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetStatus().GetErrorCode()) - assert.Equal(t, serverNotServingErrMsg, resp.GetStatus().GetReason()) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) }) }) } @@ -937,7 +971,7 @@ func TestService_WatchServices(t *testing.T) { factory := dependency.NewDefaultFactory(true) svr := CreateServer(context.TODO(), factory) svr.session = &sessionutil.Session{ - TriggerKill: true, + SessionRaw: sessionutil.SessionRaw{TriggerKill: true}, } svr.serverLoopWg.Add(1) @@ -1130,14 +1164,14 @@ func TestServer_ShowConfigurations(t *testing.T) { svr.stateCode.Store(commonpb.StateCode_Initializing) resp, err := svr.ShowConfigurations(svr.ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.Status.ErrorCode) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) // normal case svr.stateCode.Store(stateSave) resp, err = svr.ShowConfigurations(svr.ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, 1, len(resp.Configuations)) assert.Equal(t, "datacoord.port", resp.Configuations[0].Key) } @@ -1153,7 +1187,7 @@ 
func TestServer_GetMetrics(t *testing.T) { svr.stateCode.Store(commonpb.StateCode_Initializing) resp, err := svr.GetMetrics(svr.ctx, &milvuspb.GetMetricsRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) svr.stateCode.Store(stateSave) // failed to parse metric type @@ -1162,7 +1196,7 @@ func TestServer_GetMetrics(t *testing.T) { Request: invalidRequest, }) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) // unsupported metric type unsupportedMetricType := "unsupported" @@ -1170,14 +1204,14 @@ func TestServer_GetMetrics(t *testing.T) { assert.NoError(t, err) resp, err = svr.GetMetrics(svr.ctx, req) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) // normal case req, err = metricsinfo.ConstructRequestByMetricType(metricsinfo.SystemInfoMetrics) assert.NoError(t, err) resp, err = svr.GetMetrics(svr.ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) log.Info("TestServer_GetMetrics", zap.String("name", resp.ComponentName), zap.String("response", resp.Response)) @@ -1268,16 +1302,17 @@ func TestSaveBinlogPaths(t *testing.T) { InsertChannel: "ch1", State: commonpb.SegmentState_Growing, } - err := svr.meta.AddSegment(NewSegmentInfo(s)) + err := svr.meta.AddSegment(context.TODO(), NewSegmentInfo(s)) assert.NoError(t, err) } + ctx := context.Background() + err := svr.channelManager.AddNode(0) assert.NoError(t, err) - err = svr.channelManager.Watch(&channel{Name: "ch1", CollectionID: 0}) + err = svr.channelManager.Watch(ctx, &channel{Name: "ch1", CollectionID: 0}) 
assert.NoError(t, err) - ctx := context.Background() resp, err := svr.SaveBinlogPaths(ctx, &datapb.SaveBinlogPathsRequest{ Base: &commonpb.MsgBase{ Timestamp: uint64(time.Now().Unix()), @@ -1355,16 +1390,16 @@ func TestSaveBinlogPaths(t *testing.T) { InsertChannel: "ch1", State: commonpb.SegmentState_Dropped, } - err := svr.meta.AddSegment(NewSegmentInfo(s)) + err := svr.meta.AddSegment(context.TODO(), NewSegmentInfo(s)) assert.NoError(t, err) } + ctx := context.Background() err := svr.channelManager.AddNode(0) assert.NoError(t, err) - err = svr.channelManager.Watch(&channel{Name: "ch1", CollectionID: 0}) + err = svr.channelManager.Watch(ctx, &channel{Name: "ch1", CollectionID: 0}) assert.NoError(t, err) - ctx := context.Background() resp, err := svr.SaveBinlogPaths(ctx, &datapb.SaveBinlogPathsRequest{ Base: &commonpb.MsgBase{ Timestamp: uint64(time.Now().Unix()), @@ -1433,16 +1468,16 @@ func TestSaveBinlogPaths(t *testing.T) { InsertChannel: "ch1", State: commonpb.SegmentState_NotExist, } - err := svr.meta.AddSegment(NewSegmentInfo(s)) + err := svr.meta.AddSegment(context.TODO(), NewSegmentInfo(s)) assert.NoError(t, err) } + ctx := context.Background() err := svr.channelManager.AddNode(0) assert.NoError(t, err) - err = svr.channelManager.Watch(&channel{Name: "ch1", CollectionID: 0}) + err = svr.channelManager.Watch(ctx, &channel{Name: "ch1", CollectionID: 0}) assert.NoError(t, err) - ctx := context.Background() resp, err := svr.SaveBinlogPaths(ctx, &datapb.SaveBinlogPathsRequest{ Base: &commonpb.MsgBase{ Timestamp: uint64(time.Now().Unix()), @@ -1479,7 +1514,7 @@ func TestSaveBinlogPaths(t *testing.T) { Flushed: false, }) assert.NoError(t, err) - assert.EqualValues(t, resp.ErrorCode, commonpb.ErrorCode_SegmentNotFound) + assert.ErrorIs(t, merr.Error(resp), merr.ErrSegmentNotFound) }) t.Run("SaveNotExistSegment", func(t *testing.T) { @@ -1491,12 +1526,12 @@ func TestSaveBinlogPaths(t *testing.T) { ID: 0, }) + ctx := context.Background() err := 
svr.channelManager.AddNode(0) assert.NoError(t, err) - err = svr.channelManager.Watch(&channel{Name: "ch1", CollectionID: 0}) + err = svr.channelManager.Watch(ctx, &channel{Name: "ch1", CollectionID: 0}) assert.NoError(t, err) - ctx := context.Background() resp, err := svr.SaveBinlogPaths(ctx, &datapb.SaveBinlogPathsRequest{ Base: &commonpb.MsgBase{ Timestamp: uint64(time.Now().Unix()), @@ -1533,7 +1568,7 @@ func TestSaveBinlogPaths(t *testing.T) { Flushed: false, }) assert.NoError(t, err) - assert.EqualValues(t, resp.ErrorCode, commonpb.ErrorCode_SegmentNotFound) + assert.ErrorIs(t, merr.Error(resp), merr.ErrSegmentNotFound) }) t.Run("with channel not matched", func(t *testing.T) { @@ -1541,21 +1576,21 @@ func TestSaveBinlogPaths(t *testing.T) { defer closeTestServer(t, svr) err := svr.channelManager.AddNode(0) require.Nil(t, err) - err = svr.channelManager.Watch(&channel{Name: "ch1", CollectionID: 0}) + err = svr.channelManager.Watch(context.TODO(), &channel{Name: "ch1", CollectionID: 0}) require.Nil(t, err) s := &datapb.SegmentInfo{ ID: 1, InsertChannel: "ch2", State: commonpb.SegmentState_Growing, } - svr.meta.AddSegment(NewSegmentInfo(s)) + svr.meta.AddSegment(context.TODO(), NewSegmentInfo(s)) resp, err := svr.SaveBinlogPaths(context.Background(), &datapb.SaveBinlogPathsRequest{ SegmentID: 1, Channel: "test", }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_MetaFailed, resp.GetErrorCode()) + assert.ErrorIs(t, merr.Error(resp), merr.ErrChannelNotFound) }) t.Run("with closed server", func(t *testing.T) { @@ -1563,8 +1598,7 @@ func TestSaveBinlogPaths(t *testing.T) { closeTestServer(t, svr) resp, err := svr.SaveBinlogPaths(context.Background(), &datapb.SaveBinlogPathsRequest{}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetErrorCode()) - assert.Equal(t, serverNotServingErrMsg, resp.GetReason()) + assert.ErrorIs(t, merr.Error(resp), merr.ErrServiceNotReady) }) /* t.Run("test save dropped segment and remove 
channel", func(t *testing.T) { @@ -1643,7 +1677,7 @@ func TestDropVirtualChannel(t *testing.T) { {FieldID: 1}, } } - err := svr.meta.AddSegment(NewSegmentInfo(s)) + err := svr.meta.AddSegment(context.TODO(), NewSegmentInfo(s)) assert.NoError(t, err) } // add non matched segments @@ -1655,14 +1689,14 @@ func TestDropVirtualChannel(t *testing.T) { State: commonpb.SegmentState_Growing, } - svr.meta.AddSegment(NewSegmentInfo(os)) + svr.meta.AddSegment(context.TODO(), NewSegmentInfo(os)) + ctx := context.Background() err := svr.channelManager.AddNode(0) require.Nil(t, err) - err = svr.channelManager.Watch(&channel{Name: "ch1", CollectionID: 0}) + err = svr.channelManager.Watch(ctx, &channel{Name: "ch1", CollectionID: 0}) require.Nil(t, err) - ctx := context.Background() req := &datapb.DropVirtualChannelRequest{ Base: &commonpb.MsgBase{ Timestamp: uint64(time.Now().Unix()), @@ -1732,14 +1766,13 @@ func TestDropVirtualChannel(t *testing.T) { <-spyCh - err = svr.channelManager.Watch(&channel{Name: "ch1", CollectionID: 0}) + err = svr.channelManager.Watch(ctx, &channel{Name: "ch1", CollectionID: 0}) require.Nil(t, err) - //resend + // resend resp, err = svr.DropVirtualChannel(ctx, req) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) - }) t.Run("with channel not matched", func(t *testing.T) { @@ -1747,14 +1780,14 @@ func TestDropVirtualChannel(t *testing.T) { defer closeTestServer(t, svr) err := svr.channelManager.AddNode(0) require.Nil(t, err) - err = svr.channelManager.Watch(&channel{Name: "ch1", CollectionID: 0}) + err = svr.channelManager.Watch(context.TODO(), &channel{Name: "ch1", CollectionID: 0}) require.Nil(t, err) resp, err := svr.DropVirtualChannel(context.Background(), &datapb.DropVirtualChannelRequest{ ChannelName: "ch2", }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_MetaFailed, resp.GetStatus().GetErrorCode()) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrChannelNotFound) }) 
t.Run("with closed server", func(t *testing.T) { @@ -1762,8 +1795,7 @@ func TestDropVirtualChannel(t *testing.T) { closeTestServer(t, svr) resp, err := svr.DropVirtualChannel(context.Background(), &datapb.DropVirtualChannelRequest{}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetStatus().GetErrorCode()) - assert.Equal(t, serverNotServingErrMsg, resp.GetStatus().GetReason()) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) }) } @@ -1790,35 +1822,45 @@ func TestGetChannelSeekPosition(t *testing.T) { channelName string expectedPos *msgpb.MsgPosition }{ - {"test-with-channelCP", + { + "test-with-channelCP", &msgpb.MsgPosition{ChannelName: "ch1", Timestamp: 100, MsgID: msgID}, []*msgpb.MsgPosition{{ChannelName: "ch1", Timestamp: 50, MsgID: msgID}, {ChannelName: "ch1", Timestamp: 200, MsgID: msgID}}, startPos1, - "ch1", &msgpb.MsgPosition{ChannelName: "ch1", Timestamp: 100, MsgID: msgID}}, + "ch1", &msgpb.MsgPosition{ChannelName: "ch1", Timestamp: 100, MsgID: msgID}, + }, - {"test-with-segmentDMLPos", + { + "test-with-segmentDMLPos", nil, []*msgpb.MsgPosition{{ChannelName: "ch1", Timestamp: 50, MsgID: msgID}, {ChannelName: "ch1", Timestamp: 200, MsgID: msgID}}, startPos1, - "ch1", &msgpb.MsgPosition{ChannelName: "ch1", Timestamp: 50, MsgID: msgID}}, + "ch1", &msgpb.MsgPosition{ChannelName: "ch1", Timestamp: 50, MsgID: msgID}, + }, - {"test-with-collStartPos", + { + "test-with-collStartPos", nil, nil, startPos1, - "ch1", &msgpb.MsgPosition{ChannelName: "ch1", MsgID: startPos1[0].Data}}, + "ch1", &msgpb.MsgPosition{ChannelName: "ch1", MsgID: startPos1[0].Data}, + }, - {"test-non-exist-channel-1", + { + "test-non-exist-channel-1", nil, nil, startPosNonExist, - "ch1", nil}, + "ch1", nil, + }, - {"test-non-exist-channel-2", + { + "test-non-exist-channel-2", nil, nil, nil, - "ch1", nil}, + "ch1", nil, + }, } for _, test := range tests { t.Run(test.testName, func(t *testing.T) { @@ -1840,7 +1882,7 @@ func 
TestGetChannelSeekPosition(t *testing.T) { DmlPosition: segPos, InsertChannel: "ch1", } - err := svr.meta.AddSegment(NewSegmentInfo(seg)) + err := svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg)) assert.NoError(t, err) } if test.channelCP != nil { @@ -1850,7 +1892,8 @@ func TestGetChannelSeekPosition(t *testing.T) { seekPos := svr.handler.(*ServerHandler).GetChannelSeekPosition(&channel{ Name: test.channelName, - CollectionID: 0}, allPartitionID) + CollectionID: 0, + }, allPartitionID) if test.expectedPos == nil { assert.True(t, seekPos == nil) } else { @@ -1916,7 +1959,7 @@ func TestGetDataVChanPositions(t *testing.T) { MsgID: []byte{1, 2, 3}, }, } - err := svr.meta.AddSegment(NewSegmentInfo(s1)) + err := svr.meta.AddSegment(context.TODO(), NewSegmentInfo(s1)) require.Nil(t, err) s2 := &datapb.SegmentInfo{ ID: 2, @@ -1934,7 +1977,7 @@ func TestGetDataVChanPositions(t *testing.T) { Timestamp: 1, }, } - err = svr.meta.AddSegment(NewSegmentInfo(s2)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(s2)) require.Nil(t, err) s3 := &datapb.SegmentInfo{ ID: 3, @@ -1952,7 +1995,7 @@ func TestGetDataVChanPositions(t *testing.T) { Timestamp: 2, }, } - err = svr.meta.AddSegment(NewSegmentInfo(s3)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(s3)) require.Nil(t, err) t.Run("get unexisted channel", func(t *testing.T) { @@ -2043,7 +2086,7 @@ func TestGetQueryVChanPositions(t *testing.T) { }, NumOfRows: 2048, } - err = svr.meta.AddSegment(NewSegmentInfo(s1)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(s1)) assert.NoError(t, err) err = svr.meta.AddSegmentIndex(&model.SegmentIndex{ SegmentID: 1, @@ -2074,7 +2117,7 @@ func TestGetQueryVChanPositions(t *testing.T) { Timestamp: 1, }, } - err = svr.meta.AddSegment(NewSegmentInfo(s2)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(s2)) assert.NoError(t, err) s3 := &datapb.SegmentInfo{ ID: 3, @@ -2094,7 +2137,7 @@ func TestGetQueryVChanPositions(t *testing.T) { Timestamp: 2, }, } 
- err = svr.meta.AddSegment(NewSegmentInfo(s3)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(s3)) assert.NoError(t, err) //mockResp := &indexpb.GetIndexInfoResponse{ // Status: &commonpb.Status{}, @@ -2183,7 +2226,7 @@ func TestGetQueryVChanPositions_Retrieve_unIndexed(t *testing.T) { }, CompactionFrom: []int64{99, 100}, // a, b which have been GC-ed } - err = svr.meta.AddSegment(NewSegmentInfo(c)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(c)) assert.NoError(t, err) d := &datapb.SegmentInfo{ ID: 2, @@ -2198,7 +2241,7 @@ func TestGetQueryVChanPositions_Retrieve_unIndexed(t *testing.T) { Timestamp: 1, }, } - err = svr.meta.AddSegment(NewSegmentInfo(d)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(d)) assert.NoError(t, err) e := &datapb.SegmentInfo{ ID: 3, @@ -2216,7 +2259,7 @@ func TestGetQueryVChanPositions_Retrieve_unIndexed(t *testing.T) { NumOfRows: 2048, } - err = svr.meta.AddSegment(NewSegmentInfo(e)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(e)) assert.NoError(t, err) vchan := svr.handler.GetQueryVChanPositions(&channel{Name: "ch1", CollectionID: 0}, allPartitionID) assert.EqualValues(t, 2, len(vchan.FlushedSegmentIds)) @@ -2252,7 +2295,7 @@ func TestGetQueryVChanPositions_Retrieve_unIndexed(t *testing.T) { Timestamp: 1, }, } - err = svr.meta.AddSegment(NewSegmentInfo(a)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(a)) assert.NoError(t, err) c := &datapb.SegmentInfo{ @@ -2269,7 +2312,7 @@ func TestGetQueryVChanPositions_Retrieve_unIndexed(t *testing.T) { }, CompactionFrom: []int64{99, 100}, // a, b which have been GC-ed } - err = svr.meta.AddSegment(NewSegmentInfo(c)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(c)) assert.NoError(t, err) d := &datapb.SegmentInfo{ ID: 2, @@ -2284,7 +2327,7 @@ func TestGetQueryVChanPositions_Retrieve_unIndexed(t *testing.T) { Timestamp: 1, }, } - err = svr.meta.AddSegment(NewSegmentInfo(d)) + err = svr.meta.AddSegment(context.TODO(), 
NewSegmentInfo(d)) assert.NoError(t, err) e := &datapb.SegmentInfo{ ID: 3, @@ -2302,7 +2345,7 @@ func TestGetQueryVChanPositions_Retrieve_unIndexed(t *testing.T) { NumOfRows: 2048, } - err = svr.meta.AddSegment(NewSegmentInfo(e)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(e)) assert.NoError(t, err) vchan := svr.handler.GetQueryVChanPositions(&channel{Name: "ch1", CollectionID: 0}, allPartitionID) assert.EqualValues(t, 2, len(vchan.FlushedSegmentIds)) @@ -2339,7 +2382,7 @@ func TestGetQueryVChanPositions_Retrieve_unIndexed(t *testing.T) { }, CompactionFrom: []int64{99, 100}, // a, b which have been GC-ed } - err = svr.meta.AddSegment(NewSegmentInfo(c)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(c)) assert.NoError(t, err) d := &datapb.SegmentInfo{ ID: 2, @@ -2354,7 +2397,7 @@ func TestGetQueryVChanPositions_Retrieve_unIndexed(t *testing.T) { Timestamp: 1, }, } - err = svr.meta.AddSegment(NewSegmentInfo(d)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(d)) assert.NoError(t, err) err = svr.meta.AddSegmentIndex(&model.SegmentIndex{ SegmentID: 2, @@ -2382,7 +2425,7 @@ func TestGetQueryVChanPositions_Retrieve_unIndexed(t *testing.T) { CompactionFrom: []int64{1, 2}, // c, d NumOfRows: 2048, } - err = svr.meta.AddSegment(NewSegmentInfo(e)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(e)) assert.NoError(t, err) err = svr.meta.AddSegmentIndex(&model.SegmentIndex{ SegmentID: 3, @@ -2405,24 +2448,22 @@ func TestGetQueryVChanPositions_Retrieve_unIndexed(t *testing.T) { func TestShouldDropChannel(t *testing.T) { type myRootCoord struct { - mocks.RootCoord + mocks.MockRootCoordClient } myRoot := &myRootCoord{} - myRoot.EXPECT().Init().Return(nil) - myRoot.EXPECT().Start().Return(nil) myRoot.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocTimestampResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Status: merr.Success(), Timestamp: 
tsoutil.ComposeTSByTime(time.Now(), 0), Count: 1, }, nil) myRoot.EXPECT().AllocID(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocIDResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Status: merr.Success(), ID: int64(tsoutil.ComposeTSByTime(time.Now(), 0)), Count: 1, }, nil) - var crt rootCoordCreatorFunc = func(ctx context.Context, metaRoot string, etcdClient *clientv3.Client) (types.RootCoord, error) { + var crt rootCoordCreatorFunc = func(ctx context.Context, metaRoot string, etcdClient *clientv3.Client) (types.RootCoordClient, error) { return myRoot, nil } @@ -2452,7 +2493,7 @@ func TestShouldDropChannel(t *testing.T) { }) t.Run("channel name not in kv, collection not exist", func(t *testing.T) { - //myRoot.code = commonpb.ErrorCode_CollectionNotExists + // myRoot.code = commonpb.ErrorCode_CollectionNotExists myRoot.EXPECT().DescribeCollection(mock.Anything, mock.Anything). Return(&milvuspb.DescribeCollectionResponse{ Status: merr.Status(merr.WrapErrCollectionNotFound(-1)), @@ -2464,7 +2505,7 @@ func TestShouldDropChannel(t *testing.T) { t.Run("channel name not in kv, collection exist", func(t *testing.T) { myRoot.EXPECT().DescribeCollection(mock.Anything, mock.Anything). Return(&milvuspb.DescribeCollectionResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Status: merr.Success(), CollectionID: 0, }, nil).Once() assert.False(t, svr.handler.CheckShouldDropChannel("ch99", 0)) @@ -2473,7 +2514,7 @@ func TestShouldDropChannel(t *testing.T) { t.Run("collection name in kv, collection exist", func(t *testing.T) { myRoot.EXPECT().DescribeCollection(mock.Anything, mock.Anything). 
Return(&milvuspb.DescribeCollectionResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Status: merr.Success(), CollectionID: 0, }, nil).Once() assert.False(t, svr.handler.CheckShouldDropChannel("ch1", 0)) @@ -2493,7 +2534,7 @@ func TestShouldDropChannel(t *testing.T) { require.NoError(t, err) myRoot.EXPECT().DescribeCollection(mock.Anything, mock.Anything). Return(&milvuspb.DescribeCollectionResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Status: merr.Success(), CollectionID: 0, }, nil).Once() assert.True(t, svr.handler.CheckShouldDropChannel("ch1", 0)) @@ -2501,13 +2542,12 @@ func TestShouldDropChannel(t *testing.T) { } func TestGetRecoveryInfo(t *testing.T) { - t.Run("test get recovery info with no segments", func(t *testing.T) { svr := newTestServer(t, nil) defer closeTestServer(t, svr) - svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoord, error) { - return newMockRootCoordService(), nil + svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoordClient, error) { + return newMockRootCoordClient(), nil } req := &datapb.GetRecoveryInfoRequest{ @@ -2516,14 +2556,15 @@ func TestGetRecoveryInfo(t *testing.T) { } resp, err := svr.GetRecoveryInfo(context.TODO(), req) assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.EqualValues(t, 0, len(resp.GetBinlogs())) assert.EqualValues(t, 1, len(resp.GetChannels())) assert.Nil(t, resp.GetChannels()[0].SeekPosition) }) createSegment := func(id, collectionID, partitionID, numOfRows int64, posTs uint64, - channel string, state commonpb.SegmentState) *datapb.SegmentInfo { + channel string, state commonpb.SegmentState, + ) *datapb.SegmentInfo { return &datapb.SegmentInfo{ ID: id, CollectionID: collectionID, @@ 
-2549,8 +2590,8 @@ func TestGetRecoveryInfo(t *testing.T) { svr := newTestServer(t, nil) defer closeTestServer(t, svr) - svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoord, error) { - return newMockRootCoordService(), nil + svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoordClient, error) { + return newMockRootCoordClient(), nil } svr.meta.AddCollection(&collectionInfo{ @@ -2609,9 +2650,9 @@ func TestGetRecoveryInfo(t *testing.T) { }, }, } - err = svr.meta.AddSegment(NewSegmentInfo(seg1)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg1)) assert.NoError(t, err) - err = svr.meta.AddSegment(NewSegmentInfo(seg2)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg2)) assert.NoError(t, err) err = svr.meta.AddSegmentIndex(&model.SegmentIndex{ SegmentID: seg1.ID, @@ -2640,7 +2681,7 @@ func TestGetRecoveryInfo(t *testing.T) { } resp, err := svr.GetRecoveryInfo(context.TODO(), req) assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.EqualValues(t, 1, len(resp.GetChannels())) assert.EqualValues(t, 0, len(resp.GetChannels()[0].GetUnflushedSegmentIds())) assert.ElementsMatch(t, []int64{0, 1}, resp.GetChannels()[0].GetFlushedSegmentIds()) @@ -2654,8 +2695,8 @@ func TestGetRecoveryInfo(t *testing.T) { svr := newTestServer(t, nil) defer closeTestServer(t, svr) - svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoord, error) { - return newMockRootCoordService(), nil + svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoordClient, error) { + return newMockRootCoordClient(), nil } svr.meta.AddCollection(&collectionInfo{ @@ -2706,11 +2747,11 @@ func 
TestGetRecoveryInfo(t *testing.T) { }, }, } - err = svr.meta.AddSegment(NewSegmentInfo(seg1)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg1)) assert.NoError(t, err) - err = svr.meta.AddSegment(NewSegmentInfo(seg2)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg2)) assert.NoError(t, err) - //svr.indexCoord.(*mocks.MockIndexCoord).EXPECT().GetIndexInfos(mock.Anything, mock.Anything).Return(nil, nil) + // svr.indexCoord.(*mocks.MockIndexCoord).EXPECT().GetIndexInfos(mock.Anything, mock.Anything).Return(nil, nil) req := &datapb.GetRecoveryInfoRequest{ CollectionID: 0, @@ -2718,7 +2759,7 @@ func TestGetRecoveryInfo(t *testing.T) { } resp, err := svr.GetRecoveryInfo(context.TODO(), req) assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.EqualValues(t, 0, len(resp.GetBinlogs())) assert.EqualValues(t, 1, len(resp.GetChannels())) assert.NotNil(t, resp.GetChannels()[0].SeekPosition) @@ -2733,8 +2774,8 @@ func TestGetRecoveryInfo(t *testing.T) { Schema: newTestSchema(), }) - svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoord, error) { - return newMockRootCoordService(), nil + svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoordClient, error) { + return newMockRootCoordClient(), nil } binlogReq := &datapb.SaveBinlogPathsRequest{ @@ -2780,7 +2821,7 @@ func TestGetRecoveryInfo(t *testing.T) { }, } segment := createSegment(0, 0, 1, 100, 10, "vchan1", commonpb.SegmentState_Flushed) - err := svr.meta.AddSegment(NewSegmentInfo(segment)) + err := svr.meta.AddSegment(context.TODO(), NewSegmentInfo(segment)) assert.NoError(t, err) err = svr.meta.CreateIndex(&model.Index{ @@ -2804,7 +2845,7 @@ func TestGetRecoveryInfo(t *testing.T) { err = svr.channelManager.AddNode(0) 
assert.NoError(t, err) - err = svr.channelManager.Watch(&channel{Name: "vchan1", CollectionID: 0}) + err = svr.channelManager.Watch(context.TODO(), &channel{Name: "vchan1", CollectionID: 0}) assert.NoError(t, err) sResp, err := svr.SaveBinlogPaths(context.TODO(), binlogReq) @@ -2817,7 +2858,7 @@ func TestGetRecoveryInfo(t *testing.T) { } resp, err := svr.GetRecoveryInfo(context.TODO(), req) assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.EqualValues(t, 1, len(resp.GetBinlogs())) assert.EqualValues(t, 0, resp.GetBinlogs()[0].GetSegmentID()) assert.EqualValues(t, 1, len(resp.GetBinlogs()[0].GetFieldBinlogs())) @@ -2830,8 +2871,8 @@ func TestGetRecoveryInfo(t *testing.T) { svr := newTestServer(t, nil) defer closeTestServer(t, svr) - svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoord, error) { - return newMockRootCoordService(), nil + svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoordClient, error) { + return newMockRootCoordClient(), nil } svr.meta.AddCollection(&collectionInfo{ @@ -2848,9 +2889,9 @@ func TestGetRecoveryInfo(t *testing.T) { seg1 := createSegment(7, 0, 0, 100, 30, "vchan1", commonpb.SegmentState_Growing) seg2 := createSegment(8, 0, 0, 100, 40, "vchan1", commonpb.SegmentState_Dropped) - err = svr.meta.AddSegment(NewSegmentInfo(seg1)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg1)) assert.NoError(t, err) - err = svr.meta.AddSegment(NewSegmentInfo(seg2)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg2)) assert.NoError(t, err) req := &datapb.GetRecoveryInfoRequest{ @@ -2859,7 +2900,7 @@ func TestGetRecoveryInfo(t *testing.T) { } resp, err := svr.GetRecoveryInfo(context.TODO(), req) assert.NoError(t, err) - assert.EqualValues(t, 
commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.EqualValues(t, 0, len(resp.GetBinlogs())) assert.EqualValues(t, 1, len(resp.GetChannels())) assert.NotNil(t, resp.GetChannels()[0].SeekPosition) @@ -2872,8 +2913,8 @@ func TestGetRecoveryInfo(t *testing.T) { svr := newTestServer(t, nil) defer closeTestServer(t, svr) - svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoord, error) { - return newMockRootCoordService(), nil + svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoordClient, error) { + return newMockRootCoordClient(), nil } svr.meta.AddCollection(&collectionInfo{ @@ -2891,9 +2932,9 @@ func TestGetRecoveryInfo(t *testing.T) { seg1 := createSegment(7, 0, 0, 100, 30, "vchan1", commonpb.SegmentState_Growing) seg2 := createSegment(8, 0, 0, 100, 40, "vchan1", commonpb.SegmentState_Flushed) seg2.IsFake = true - err = svr.meta.AddSegment(NewSegmentInfo(seg1)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg1)) assert.NoError(t, err) - err = svr.meta.AddSegment(NewSegmentInfo(seg2)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg2)) assert.NoError(t, err) req := &datapb.GetRecoveryInfoRequest{ @@ -2902,7 +2943,7 @@ func TestGetRecoveryInfo(t *testing.T) { } resp, err := svr.GetRecoveryInfo(context.TODO(), req) assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.EqualValues(t, 0, len(resp.GetBinlogs())) assert.EqualValues(t, 1, len(resp.GetChannels())) assert.NotNil(t, resp.GetChannels()[0].SeekPosition) @@ -2913,8 +2954,8 @@ func TestGetRecoveryInfo(t *testing.T) { svr := newTestServer(t, nil) defer closeTestServer(t, svr) - svr.rootCoordClientCreator = func(ctx context.Context, 
metaRootPath string, etcdCli *clientv3.Client) (types.RootCoord, error) { - return newMockRootCoordService(), nil + svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoordClient, error) { + return newMockRootCoordClient(), nil } svr.meta.AddCollection(&collectionInfo{ @@ -2936,15 +2977,15 @@ func TestGetRecoveryInfo(t *testing.T) { seg4 := createSegment(12, 0, 0, 2048, 40, "vchan1", commonpb.SegmentState_Dropped) seg5 := createSegment(13, 0, 0, 2048, 40, "vchan1", commonpb.SegmentState_Flushed) seg5.CompactionFrom = []int64{11, 12} - err = svr.meta.AddSegment(NewSegmentInfo(seg1)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg1)) assert.NoError(t, err) - err = svr.meta.AddSegment(NewSegmentInfo(seg2)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg2)) assert.NoError(t, err) - err = svr.meta.AddSegment(NewSegmentInfo(seg3)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg3)) assert.NoError(t, err) - err = svr.meta.AddSegment(NewSegmentInfo(seg4)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg4)) assert.NoError(t, err) - err = svr.meta.AddSegment(NewSegmentInfo(seg5)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg5)) assert.NoError(t, err) err = svr.meta.CreateIndex(&model.Index{ TenantID: "", @@ -2983,7 +3024,7 @@ func TestGetRecoveryInfo(t *testing.T) { } resp, err := svr.GetRecoveryInfo(context.TODO(), req) assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.NotNil(t, resp.GetChannels()[0].SeekPosition) assert.NotEqual(t, 0, resp.GetChannels()[0].GetSeekPosition().GetTimestamp()) assert.Len(t, resp.GetChannels()[0].GetDroppedSegmentIds(), 0) @@ -2996,8 +3037,7 @@ func TestGetRecoveryInfo(t *testing.T) { closeTestServer(t, svr) resp, err := svr.GetRecoveryInfo(context.TODO(), 
&datapb.GetRecoveryInfoRequest{}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetStatus().GetErrorCode()) - assert.Equal(t, serverNotServingErrMsg, resp.GetStatus().GetReason()) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) }) } @@ -3062,8 +3102,7 @@ func TestGetCompactionState(t *testing.T) { resp, err := svr.GetCompactionState(context.Background(), &milvuspb.GetCompactionStateRequest{}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetStatus().GetErrorCode()) - assert.Equal(t, msgDataCoordIsUnhealthy(paramtable.GetNodeID()), resp.GetStatus().GetReason()) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) }) } @@ -3086,7 +3125,7 @@ func TestManualCompaction(t *testing.T) { Timetravel: 1, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) t.Run("test manual compaction failure", func(t *testing.T) { @@ -3105,7 +3144,7 @@ func TestManualCompaction(t *testing.T) { Timetravel: 1, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetStatus().GetErrorCode()) }) t.Run("test manual compaction with closed server", func(t *testing.T) { @@ -3124,8 +3163,7 @@ func TestManualCompaction(t *testing.T) { Timetravel: 1, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.Status.ErrorCode) - assert.Equal(t, msgDataCoordIsUnhealthy(paramtable.GetNodeID()), resp.Status.Reason) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) }) } @@ -3151,7 +3189,7 @@ func TestGetCompactionStateWithPlans(t *testing.T) { CompactionID: 1, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, 
resp.GetStatus().GetErrorCode()) assert.Equal(t, commonpb.CompactionState_Executing, resp.State) }) @@ -3175,8 +3213,7 @@ func TestGetCompactionStateWithPlans(t *testing.T) { CompactionID: 1, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.Status.ErrorCode) - assert.Equal(t, msgDataCoordIsUnhealthy(paramtable.GetNodeID()), resp.Status.Reason) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) }) } @@ -3190,7 +3227,7 @@ func TestOptions(t *testing.T) { t.Run("WithRootCoordCreator", func(t *testing.T) { svr := newTestServer(t, nil) defer closeTestServer(t, svr) - var crt rootCoordCreatorFunc = func(ctx context.Context, metaRoot string, etcdClient *clientv3.Client) (types.RootCoord, error) { + var crt rootCoordCreatorFunc = func(ctx context.Context, metaRoot string, etcdClient *clientv3.Client) (types.RootCoordClient, error) { return nil, errors.New("dummy") } opt := WithRootCoordCreator(crt) @@ -3220,8 +3257,8 @@ func TestOptions(t *testing.T) { }) t.Run("WithDataNodeCreator", func(t *testing.T) { var target int64 - var val = rand.Int63() - opt := WithDataNodeCreator(func(context.Context, string, int64) (types.DataNode, error) { + val := rand.Int63() + opt := WithDataNodeCreator(func(context.Context, string, int64) (types.DataNodeClient, error) { target = val return nil, nil }) @@ -3277,10 +3314,12 @@ func TestHandleSessionEvent(t *testing.T) { evt := &sessionutil.SessionEvent{ EventType: sessionutil.SessionNoneEvent, Session: &sessionutil.Session{ - ServerID: 0, - ServerName: "", - Address: "", - Exclusive: false, + SessionRaw: sessionutil.SessionRaw{ + ServerID: 0, + ServerName: "", + Address: "", + Exclusive: false, + }, }, } err = svr.handleSessionEvent(context.Background(), typeutil.DataNodeRole, evt) @@ -3289,10 +3328,12 @@ func TestHandleSessionEvent(t *testing.T) { evt = &sessionutil.SessionEvent{ EventType: sessionutil.SessionAddEvent, Session: &sessionutil.Session{ - ServerID: 101, - ServerName: 
"DN101", - Address: "DN127.0.0.101", - Exclusive: false, + SessionRaw: sessionutil.SessionRaw{ + ServerID: 101, + ServerName: "DN101", + Address: "DN127.0.0.101", + Exclusive: false, + }, }, } err = svr.handleSessionEvent(context.Background(), typeutil.DataNodeRole, evt) @@ -3304,10 +3345,12 @@ func TestHandleSessionEvent(t *testing.T) { evt = &sessionutil.SessionEvent{ EventType: sessionutil.SessionDelEvent, Session: &sessionutil.Session{ - ServerID: 101, - ServerName: "DN101", - Address: "DN127.0.0.101", - Exclusive: false, + SessionRaw: sessionutil.SessionRaw{ + ServerID: 101, + ServerName: "DN101", + Address: "DN127.0.0.101", + Exclusive: false, + }, }, } err = svr.handleSessionEvent(context.Background(), typeutil.DataNodeRole, evt) @@ -3325,14 +3368,14 @@ func TestHandleSessionEvent(t *testing.T) { } type rootCoordSegFlushComplete struct { - mockRootCoordService + mockRootCoordClient flag bool } // SegmentFlushCompleted, override default behavior func (rc *rootCoordSegFlushComplete) SegmentFlushCompleted(ctx context.Context, req *datapb.SegmentFlushCompletedMsg) (*commonpb.Status, error) { if rc.flag { - return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil + return merr.Success(), nil } return &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, nil } @@ -3351,7 +3394,7 @@ func TestPostFlush(t *testing.T) { defer closeTestServer(t, svr) svr.rootCoordClient = &rootCoordSegFlushComplete{flag: true} - err := svr.meta.AddSegment(NewSegmentInfo(&datapb.SegmentInfo{ + err := svr.meta.AddSegment(context.TODO(), NewSegmentInfo(&datapb.SegmentInfo{ ID: 1, CollectionID: 1, PartitionID: 1, @@ -3367,93 +3410,236 @@ func TestPostFlush(t *testing.T) { func TestGetFlushState(t *testing.T) { t.Run("get flush state with all flushed segments", func(t *testing.T) { - svr := &Server{ - meta: &meta{ - segments: &SegmentsInfo{ - segments: map[int64]*SegmentInfo{ - 1: { - SegmentInfo: &datapb.SegmentInfo{ - ID: 1, - State: commonpb.SegmentState_Flushed, - 
}, - }, - 2: { - SegmentInfo: &datapb.SegmentInfo{ - ID: 2, - State: commonpb.SegmentState_Flushed, - }, - }, - }, + meta, err := newMemoryMeta() + assert.NoError(t, err) + svr := newTestServerWithMeta(t, nil, meta) + defer closeTestServer(t, svr) + + err = meta.AddSegment(context.TODO(), &SegmentInfo{ + SegmentInfo: &datapb.SegmentInfo{ + ID: 1, + State: commonpb.SegmentState_Flushed, + }, + }) + assert.NoError(t, err) + err = meta.AddSegment(context.TODO(), &SegmentInfo{ + SegmentInfo: &datapb.SegmentInfo{ + ID: 2, + State: commonpb.SegmentState_Flushed, + }, + }) + assert.NoError(t, err) + + var ( + vchannel = "ch1" + collection = int64(0) + ) + + svr.channelManager = &ChannelManager{ + store: &ChannelStore{ + channelsInfo: map[int64]*NodeChannelInfo{ + 1: {NodeID: 1, Channels: []*channel{{Name: vchannel, CollectionID: collection}}}, }, }, } - svr.stateCode.Store(commonpb.StateCode_Healthy) - resp, err := svr.GetFlushState(context.TODO(), &milvuspb.GetFlushStateRequest{SegmentIDs: []int64{1, 2}}) + + err = svr.meta.UpdateChannelCheckpoint(vchannel, &msgpb.MsgPosition{ + MsgID: []byte{1}, + Timestamp: 12, + }) + assert.NoError(t, err) + + resp, err := svr.GetFlushState(context.TODO(), &datapb.GetFlushStateRequest{SegmentIDs: []int64{1, 2}}) assert.NoError(t, err) assert.EqualValues(t, &milvuspb.GetFlushStateResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Status: merr.Success(), Flushed: true, }, resp) }) t.Run("get flush state with unflushed segments", func(t *testing.T) { - svr := &Server{ - meta: &meta{ - segments: &SegmentsInfo{ - segments: map[int64]*SegmentInfo{ - 1: { - SegmentInfo: &datapb.SegmentInfo{ - ID: 1, - State: commonpb.SegmentState_Flushed, - }, - }, - 2: { - SegmentInfo: &datapb.SegmentInfo{ - ID: 2, - State: commonpb.SegmentState_Sealed, - }, - }, - }, + meta, err := newMemoryMeta() + assert.NoError(t, err) + svr := newTestServerWithMeta(t, nil, meta) + defer closeTestServer(t, svr) + + err = 
meta.AddSegment(context.TODO(), &SegmentInfo{ + SegmentInfo: &datapb.SegmentInfo{ + ID: 1, + State: commonpb.SegmentState_Flushed, + }, + }) + assert.NoError(t, err) + err = meta.AddSegment(context.TODO(), &SegmentInfo{ + SegmentInfo: &datapb.SegmentInfo{ + ID: 2, + State: commonpb.SegmentState_Sealed, + }, + }) + assert.NoError(t, err) + + var ( + vchannel = "ch1" + collection = int64(0) + ) + + svr.channelManager = &ChannelManager{ + store: &ChannelStore{ + channelsInfo: map[int64]*NodeChannelInfo{ + 1: {NodeID: 1, Channels: []*channel{{Name: vchannel, CollectionID: collection}}}, }, }, } - svr.stateCode.Store(commonpb.StateCode_Healthy) - resp, err := svr.GetFlushState(context.TODO(), &milvuspb.GetFlushStateRequest{SegmentIDs: []int64{1, 2}}) + err = svr.meta.UpdateChannelCheckpoint(vchannel, &msgpb.MsgPosition{ + MsgID: []byte{1}, + Timestamp: 12, + }) + assert.NoError(t, err) + + resp, err := svr.GetFlushState(context.TODO(), &datapb.GetFlushStateRequest{SegmentIDs: []int64{1, 2}}) assert.NoError(t, err) assert.EqualValues(t, &milvuspb.GetFlushStateResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Status: merr.Success(), Flushed: false, }, resp) }) t.Run("get flush state with compacted segments", func(t *testing.T) { - svr := &Server{ - meta: &meta{ - segments: &SegmentsInfo{ - segments: map[int64]*SegmentInfo{ - 1: { - SegmentInfo: &datapb.SegmentInfo{ - ID: 1, - State: commonpb.SegmentState_Flushed, - }, - }, - 2: { - SegmentInfo: &datapb.SegmentInfo{ - ID: 2, - State: commonpb.SegmentState_Dropped, - }, - }, - }, + meta, err := newMemoryMeta() + assert.NoError(t, err) + svr := newTestServerWithMeta(t, nil, meta) + defer closeTestServer(t, svr) + + err = meta.AddSegment(context.TODO(), &SegmentInfo{ + SegmentInfo: &datapb.SegmentInfo{ + ID: 1, + State: commonpb.SegmentState_Flushed, + }, + }) + assert.NoError(t, err) + err = meta.AddSegment(context.TODO(), &SegmentInfo{ + SegmentInfo: &datapb.SegmentInfo{ + ID: 2, + State: 
commonpb.SegmentState_Dropped, + }, + }) + assert.NoError(t, err) + + var ( + vchannel = "ch1" + collection = int64(0) + ) + + svr.channelManager = &ChannelManager{ + store: &ChannelStore{ + channelsInfo: map[int64]*NodeChannelInfo{ + 1: {NodeID: 1, Channels: []*channel{{Name: vchannel, CollectionID: collection}}}, }, }, } - svr.stateCode.Store(commonpb.StateCode_Healthy) - resp, err := svr.GetFlushState(context.TODO(), &milvuspb.GetFlushStateRequest{SegmentIDs: []int64{1, 2}}) + err = svr.meta.UpdateChannelCheckpoint(vchannel, &msgpb.MsgPosition{ + MsgID: []byte{1}, + Timestamp: 12, + }) + assert.NoError(t, err) + + resp, err := svr.GetFlushState(context.TODO(), &datapb.GetFlushStateRequest{SegmentIDs: []int64{1, 2}}) + assert.NoError(t, err) + assert.EqualValues(t, &milvuspb.GetFlushStateResponse{ + Status: merr.Success(), + Flushed: true, + }, resp) + }) + + t.Run("channel flushed", func(t *testing.T) { + meta, err := newMemoryMeta() + assert.NoError(t, err) + svr := newTestServerWithMeta(t, nil, meta) + defer closeTestServer(t, svr) + + var ( + vchannel = "ch1" + collection = int64(0) + ) + + svr.channelManager = &ChannelManager{ + store: &ChannelStore{ + channelsInfo: map[int64]*NodeChannelInfo{ + 1: {NodeID: 1, Channels: []*channel{{Name: vchannel, CollectionID: collection}}}, + }, + }, + } + + err = svr.meta.UpdateChannelCheckpoint(vchannel, &msgpb.MsgPosition{ + MsgID: []byte{1}, + Timestamp: 12, + }) + assert.NoError(t, err) + + resp, err := svr.GetFlushState(context.Background(), &datapb.GetFlushStateRequest{ + FlushTs: 11, + CollectionID: collection, + }) assert.NoError(t, err) assert.EqualValues(t, &milvuspb.GetFlushStateResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Status: merr.Success(), + Flushed: true, + }, resp) + }) + + t.Run("channel unflushed", func(t *testing.T) { + meta, err := newMemoryMeta() + assert.NoError(t, err) + svr := newTestServerWithMeta(t, nil, meta) + defer closeTestServer(t, svr) + + var ( + 
vchannel = "ch1" + collection = int64(0) + ) + + svr.channelManager = &ChannelManager{ + store: &ChannelStore{ + channelsInfo: map[int64]*NodeChannelInfo{ + 1: {NodeID: 1, Channels: []*channel{{Name: vchannel, CollectionID: collection}}}, + }, + }, + } + + err = svr.meta.UpdateChannelCheckpoint(vchannel, &msgpb.MsgPosition{ + MsgID: []byte{1}, + Timestamp: 10, + }) + assert.NoError(t, err) + + resp, err := svr.GetFlushState(context.Background(), &datapb.GetFlushStateRequest{ + FlushTs: 11, + CollectionID: collection, + }) + assert.NoError(t, err) + assert.EqualValues(t, &milvuspb.GetFlushStateResponse{ + Status: merr.Success(), + Flushed: false, + }, resp) + }) + + t.Run("no channels", func(t *testing.T) { + meta, err := newMemoryMeta() + assert.NoError(t, err) + svr := newTestServerWithMeta(t, nil, meta) + defer closeTestServer(t, svr) + + collection := int64(0) + + resp, err := svr.GetFlushState(context.Background(), &datapb.GetFlushStateRequest{ + FlushTs: 11, + CollectionID: collection, + }) + assert.NoError(t, err) + assert.EqualValues(t, &milvuspb.GetFlushStateResponse{ + Status: merr.Success(), Flushed: true, }, resp) }) @@ -3471,18 +3657,34 @@ func TestGetFlushAllState(t *testing.T) { ExpectedSuccess bool ExpectedFlushed bool }{ - {"test FlushAll flushed", []Timestamp{100, 200}, 99, - true, false, false, false, true, true}, - {"test FlushAll not flushed", []Timestamp{100, 200}, 150, - true, false, false, false, true, false}, - {"test Sever is not healthy", nil, 0, - false, false, false, false, false, false}, - {"test ListDatabase failed", nil, 0, - true, true, false, false, false, false}, - {"test ShowCollections failed", nil, 0, - true, false, true, false, false, false}, - {"test DescribeCollection failed", nil, 0, - true, false, false, true, false, false}, + { + "test FlushAll flushed", + []Timestamp{100, 200}, + 99, + true, false, false, false, true, true, + }, + { + "test FlushAll not flushed", + []Timestamp{100, 200}, + 150, + true, false, false, 
false, true, false, + }, + { + "test Sever is not healthy", nil, 0, + false, false, false, false, false, false, + }, + { + "test ListDatabase failed", nil, 0, + true, true, false, false, false, false, + }, + { + "test ShowCollections failed", nil, 0, + true, false, true, false, false, false, + }, + { + "test DescribeCollection failed", nil, 0, + true, false, false, true, false, false, + }, } for _, test := range tests { t.Run(test.testName, func(t *testing.T) { @@ -3495,43 +3697,43 @@ func TestGetFlushAllState(t *testing.T) { } var err error svr.meta = &meta{} - svr.rootCoordClient = mocks.NewRootCoord(t) + svr.rootCoordClient = mocks.NewMockRootCoordClient(t) svr.broker = NewCoordinatorBroker(svr.rootCoordClient) if test.ListDatabaseFailed { - svr.rootCoordClient.(*mocks.RootCoord).EXPECT().ListDatabases(mock.Anything, mock.Anything). + svr.rootCoordClient.(*mocks.MockRootCoordClient).EXPECT().ListDatabases(mock.Anything, mock.Anything). Return(&milvuspb.ListDatabasesResponse{ Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, }, nil).Maybe() } else { - svr.rootCoordClient.(*mocks.RootCoord).EXPECT().ListDatabases(mock.Anything, mock.Anything). + svr.rootCoordClient.(*mocks.MockRootCoordClient).EXPECT().ListDatabases(mock.Anything, mock.Anything). Return(&milvuspb.ListDatabasesResponse{ DbNames: []string{"db1"}, - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Status: merr.Success(), }, nil).Maybe() } if test.ShowCollectionFailed { - svr.rootCoordClient.(*mocks.RootCoord).EXPECT().ShowCollections(mock.Anything, mock.Anything). + svr.rootCoordClient.(*mocks.MockRootCoordClient).EXPECT().ShowCollections(mock.Anything, mock.Anything). Return(&milvuspb.ShowCollectionsResponse{ Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, }, nil).Maybe() } else { - svr.rootCoordClient.(*mocks.RootCoord).EXPECT().ShowCollections(mock.Anything, mock.Anything). 
+ svr.rootCoordClient.(*mocks.MockRootCoordClient).EXPECT().ShowCollections(mock.Anything, mock.Anything). Return(&milvuspb.ShowCollectionsResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Status: merr.Success(), CollectionIds: []int64{collection}, }, nil).Maybe() } if test.DescribeCollectionFailed { - svr.rootCoordClient.(*mocks.RootCoord).EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything). + svr.rootCoordClient.(*mocks.MockRootCoordClient).EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything). Return(&milvuspb.DescribeCollectionResponse{ Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, }, nil).Maybe() } else { - svr.rootCoordClient.(*mocks.RootCoord).EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything). + svr.rootCoordClient.(*mocks.MockRootCoordClient).EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything). Return(&milvuspb.DescribeCollectionResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Status: merr.Success(), VirtualChannelNames: vchannels, }, nil).Maybe() } @@ -3549,8 +3751,10 @@ func TestGetFlushAllState(t *testing.T) { assert.NoError(t, err) if test.ExpectedSuccess { assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) - } else { + } else if test.ServerIsHealthy { assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetStatus().GetErrorCode()) + } else { + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) } assert.Equal(t, test.ExpectedFlushed, resp.GetFlushed()) }) @@ -3579,32 +3783,32 @@ func TestGetFlushAllStateWithDB(t *testing.T) { svr.stateCode.Store(commonpb.StateCode_Healthy) var err error svr.meta = &meta{} - svr.rootCoordClient = mocks.NewRootCoord(t) + svr.rootCoordClient = mocks.NewMockRootCoordClient(t) svr.broker = NewCoordinatorBroker(svr.rootCoordClient) if test.DbExist { - svr.rootCoordClient.(*mocks.RootCoord).EXPECT().ListDatabases(mock.Anything, 
mock.Anything). + svr.rootCoordClient.(*mocks.MockRootCoordClient).EXPECT().ListDatabases(mock.Anything, mock.Anything). Return(&milvuspb.ListDatabasesResponse{ DbNames: []string{dbName}, - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Status: merr.Success(), }, nil).Maybe() } else { - svr.rootCoordClient.(*mocks.RootCoord).EXPECT().ListDatabases(mock.Anything, mock.Anything). + svr.rootCoordClient.(*mocks.MockRootCoordClient).EXPECT().ListDatabases(mock.Anything, mock.Anything). Return(&milvuspb.ListDatabasesResponse{ DbNames: []string{}, - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Status: merr.Success(), }, nil).Maybe() } - svr.rootCoordClient.(*mocks.RootCoord).EXPECT().ShowCollections(mock.Anything, mock.Anything). + svr.rootCoordClient.(*mocks.MockRootCoordClient).EXPECT().ShowCollections(mock.Anything, mock.Anything). Return(&milvuspb.ShowCollectionsResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Status: merr.Success(), CollectionIds: []int64{collectionID}, }, nil).Maybe() - svr.rootCoordClient.(*mocks.RootCoord).EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything). + svr.rootCoordClient.(*mocks.MockRootCoordClient).EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything). Return(&milvuspb.DescribeCollectionResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Status: merr.Success(), VirtualChannelNames: vchannels, CollectionID: collectionID, CollectionName: collectionName, @@ -3651,7 +3855,7 @@ func TestDataCoordServer_SetSegmentState(t *testing.T) { Timestamp: 0, }, } - err := svr.meta.AddSegment(NewSegmentInfo(segment)) + err := svr.meta.AddSegment(context.TODO(), NewSegmentInfo(segment)) assert.NoError(t, err) // Set segment state. 
svr.SetSegmentState(context.TODO(), &datapb.SetSegmentStateRequest{ @@ -3669,7 +3873,7 @@ func TestDataCoordServer_SetSegmentState(t *testing.T) { SegmentIDs: []int64{1000}, }) assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.EqualValues(t, 1, len(resp.States)) assert.EqualValues(t, commonpb.SegmentState_Flushed, resp.States[0].State) }) @@ -3695,7 +3899,7 @@ func TestDataCoordServer_SetSegmentState(t *testing.T) { SegmentIDs: []int64{1000}, }) assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.EqualValues(t, 1, len(resp.States)) assert.EqualValues(t, commonpb.SegmentState_NotExist, resp.States[0].State) }) @@ -3708,8 +3912,7 @@ func TestDataCoordServer_SetSegmentState(t *testing.T) { NewState: commonpb.SegmentState_Flushed, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetStatus().GetErrorCode()) - assert.Equal(t, serverNotServingErrMsg, resp.GetStatus().GetReason()) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) }) } @@ -3724,7 +3927,7 @@ func TestDataCoord_Import(t *testing.T) { }) err := svr.channelManager.AddNode(0) assert.NoError(t, err) - err = svr.channelManager.Watch(&channel{Name: "ch1", CollectionID: 0}) + err = svr.channelManager.Watch(svr.ctx, &channel{Name: "ch1", CollectionID: 0}) assert.NoError(t, err) resp, err := svr.Import(svr.ctx, &datapb.ImportTaskRequest{ @@ -3743,7 +3946,7 @@ func TestDataCoord_Import(t *testing.T) { err := svr.channelManager.AddNode(0) assert.NoError(t, err) - err = svr.channelManager.Watch(&channel{Name: "ch1", CollectionID: 0}) + err = svr.channelManager.Watch(svr.ctx, &channel{Name: "ch1", CollectionID: 0}) assert.NoError(t, err) resp, err := svr.Import(svr.ctx, 
&datapb.ImportTaskRequest{ @@ -3784,8 +3987,7 @@ func TestDataCoord_Import(t *testing.T) { }, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.Status.GetErrorCode()) - assert.Equal(t, msgDataCoordIsUnhealthy(paramtable.GetNodeID()), resp.Status.GetReason()) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) }) t.Run("test update segment stat", func(t *testing.T) { @@ -3813,7 +4015,7 @@ func TestDataCoord_Import(t *testing.T) { }}, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.GetErrorCode()) + assert.ErrorIs(t, merr.Error(status), merr.ErrServiceNotReady) }) } @@ -3830,7 +4032,7 @@ func TestDataCoord_SegmentStatistics(t *testing.T) { } info := NewSegmentInfo(seg1) - svr.meta.AddSegment(info) + svr.meta.AddSegment(context.TODO(), info) status, err := svr.UpdateSegmentStatistics(context.TODO(), &datapb.UpdateSegmentStatisticsRequest{ Stats: []*commonpb.SegmentStats{{ @@ -3857,7 +4059,7 @@ func TestDataCoord_SegmentStatistics(t *testing.T) { } info := NewSegmentInfo(seg1) - svr.meta.AddSegment(info) + svr.meta.AddSegment(context.TODO(), info) status, err := svr.UpdateSegmentStatistics(context.TODO(), &datapb.UpdateSegmentStatisticsRequest{ Stats: []*commonpb.SegmentStats{{ @@ -3881,14 +4083,14 @@ func TestDataCoord_SaveImportSegment(t *testing.T) { ID: 100, }) seg := buildSegment(100, 100, 100, "ch1", false) - svr.meta.AddSegment(seg) + svr.meta.AddSegment(context.TODO(), seg) svr.sessionManager.AddSession(&NodeInfo{ NodeID: 110, Address: "localhost:8080", }) err := svr.channelManager.AddNode(110) assert.NoError(t, err) - err = svr.channelManager.Watch(&channel{Name: "ch1", CollectionID: 100}) + err = svr.channelManager.Watch(context.TODO(), &channel{Name: "ch1", CollectionID: 100}) assert.NoError(t, err) status, err := svr.SaveImportSegment(context.TODO(), &datapb.SaveImportSegmentRequest{ @@ -3925,7 +4127,7 @@ func TestDataCoord_SaveImportSegment(t *testing.T) { err 
:= svr.channelManager.AddNode(110) assert.NoError(t, err) - err = svr.channelManager.Watch(&channel{Name: "ch1", CollectionID: 100}) + err = svr.channelManager.Watch(context.TODO(), &channel{Name: "ch1", CollectionID: 100}) assert.NoError(t, err) status, err := svr.SaveImportSegment(context.TODO(), &datapb.SaveImportSegmentRequest{ @@ -3945,7 +4147,7 @@ func TestDataCoord_SaveImportSegment(t *testing.T) { status, err := svr.SaveImportSegment(context.TODO(), &datapb.SaveImportSegmentRequest{}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_DataCoordNA, status.GetErrorCode()) + assert.ErrorIs(t, merr.Error(status), merr.ErrServiceNotReady) }) } @@ -3954,7 +4156,7 @@ func TestDataCoord_UnsetIsImportingState(t *testing.T) { svr := newTestServer(t, nil) defer closeTestServer(t, svr) seg := buildSegment(100, 100, 100, "ch1", false) - svr.meta.AddSegment(seg) + svr.meta.AddSegment(context.TODO(), seg) status, err := svr.UnsetIsImportingState(context.Background(), &datapb.UnsetIsImportingStateRequest{ SegmentIds: []int64{100}, @@ -4002,9 +4204,12 @@ func TestDataCoordServer_UpdateChannelCheckpoint(t *testing.T) { }) } +var globalTestTikv = tikv.SetupLocalTxn() + func newTestServer(t *testing.T, receiveCh chan any, opts ...Option) *Server { var err error paramtable.Get().Save(Params.CommonCfg.DataCoordTimeTick.Key, Params.CommonCfg.DataCoordTimeTick.GetValue()+strconv.Itoa(rand.Int())) + paramtable.Get().Save(Params.RocksmqCfg.CompressionTypes.Key, "0,0,0,0,0") factory := dependency.NewDefaultFactory(true) etcdCli, err := etcd.GetEtcdClient( Params.EtcdCfg.UseEmbedEtcd.GetAsBool(), @@ -4021,11 +4226,13 @@ func newTestServer(t *testing.T, receiveCh chan any, opts ...Option) *Server { svr := CreateServer(context.TODO(), factory) svr.SetEtcdClient(etcdCli) - svr.dataNodeCreator = func(ctx context.Context, addr string, nodeID int64) (types.DataNode, error) { + svr.SetTiKVClient(globalTestTikv) + + svr.dataNodeCreator = func(ctx context.Context, addr string, nodeID 
int64) (types.DataNodeClient, error) { return newMockDataNodeClient(0, receiveCh) } - svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoord, error) { - return newMockRootCoordService(), nil + svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoordClient, error) { + return newMockRootCoordClient(), nil } for _, opt := range opts { @@ -4073,15 +4280,17 @@ func newTestServerWithMeta(t *testing.T, receiveCh chan any, meta *meta, opts .. svr := CreateServer(context.TODO(), factory, opts...) svr.SetEtcdClient(etcdCli) - svr.dataNodeCreator = func(ctx context.Context, addr string, nodeID int64) (types.DataNode, error) { + svr.SetTiKVClient(globalTestTikv) + + svr.dataNodeCreator = func(ctx context.Context, addr string, nodeID int64) (types.DataNodeClient, error) { return newMockDataNodeClient(0, receiveCh) } - svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoord, error) { - return newMockRootCoordService(), nil + svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoordClient, error) { + return newMockRootCoordClient(), nil } - //indexCoord := mocks.NewMockIndexCoord(t) - //indexCoord.EXPECT().GetIndexInfos(mock.Anything, mock.Anything).Return(nil, nil).Maybe() - //svr.indexCoord = indexCoord + // indexCoord := mocks.NewMockIndexCoord(t) + // indexCoord.EXPECT().GetIndexInfos(mock.Anything, mock.Anything).Return(nil, nil).Maybe() + // svr.indexCoord = indexCoord err = svr.Init() assert.NoError(t, err) @@ -4128,11 +4337,13 @@ func newTestServer2(t *testing.T, receiveCh chan any, opts ...Option) *Server { svr := CreateServer(context.TODO(), factory, opts...) 
svr.SetEtcdClient(etcdCli) - svr.dataNodeCreator = func(ctx context.Context, addr string, nodeID int64) (types.DataNode, error) { + svr.SetTiKVClient(globalTestTikv) + + svr.dataNodeCreator = func(ctx context.Context, addr string, nodeID int64) (types.DataNodeClient, error) { return newMockDataNodeClient(0, receiveCh) } - svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoord, error) { - return newMockRootCoordService(), nil + svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoordClient, error) { + return newMockRootCoordClient(), nil } err = svr.Init() @@ -4153,7 +4364,7 @@ func newTestServer2(t *testing.T, receiveCh chan any, opts ...Option) *Server { func Test_CheckHealth(t *testing.T) { t.Run("not healthy", func(t *testing.T) { ctx := context.Background() - s := &Server{session: &sessionutil.Session{ServerID: 1}} + s := &Server{session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}} s.stateCode.Store(commonpb.StateCode_Abnormal) resp, err := s.CheckHealth(ctx, &milvuspb.CheckHealthRequest{}) assert.NoError(t, err) @@ -4162,18 +4373,19 @@ func Test_CheckHealth(t *testing.T) { }) t.Run("data node health check is ok", func(t *testing.T) { - svr := &Server{session: &sessionutil.Session{ServerID: 1}} + svr := &Server{session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}} svr.stateCode.Store(commonpb.StateCode_Healthy) healthClient := &mockDataNodeClient{ id: 1, - state: commonpb.StateCode_Healthy} + state: commonpb.StateCode_Healthy, + } sm := NewSessionManager() sm.sessions = struct { sync.RWMutex data map[int64]*Session }{data: map[int64]*Session{1: { client: healthClient, - clientCreator: func(ctx context.Context, addr string, nodeID int64) (types.DataNode, error) { + clientCreator: func(ctx context.Context, addr string, nodeID int64) (types.DataNodeClient, error) { return healthClient, 
nil }, }}} @@ -4187,18 +4399,19 @@ func Test_CheckHealth(t *testing.T) { }) t.Run("data node health check is fail", func(t *testing.T) { - svr := &Server{session: &sessionutil.Session{ServerID: 1}} + svr := &Server{session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}} svr.stateCode.Store(commonpb.StateCode_Healthy) unhealthClient := &mockDataNodeClient{ id: 1, - state: commonpb.StateCode_Abnormal} + state: commonpb.StateCode_Abnormal, + } sm := NewSessionManager() sm.sessions = struct { sync.RWMutex data map[int64]*Session }{data: map[int64]*Session{1: { client: unhealthClient, - clientCreator: func(ctx context.Context, addr string, nodeID int64) (types.DataNode, error) { + clientCreator: func(ctx context.Context, addr string, nodeID int64) (types.DataNodeClient, error) { return unhealthClient, nil }, }}} @@ -4320,13 +4533,15 @@ func testDataCoordBase(t *testing.T, opts ...Option) *Server { svr := CreateServer(ctx, factory, opts...) svr.SetEtcdClient(etcdCli) - svr.SetDataNodeCreator(func(ctx context.Context, addr string, nodeID int64) (types.DataNode, error) { + svr.SetTiKVClient(globalTestTikv) + + svr.SetDataNodeCreator(func(ctx context.Context, addr string, nodeID int64) (types.DataNodeClient, error) { return newMockDataNodeClient(0, nil) }) - svr.SetIndexNodeCreator(func(ctx context.Context, addr string, nodeID int64) (types.IndexNode, error) { - return indexnode.NewMockIndexNodeComponent(ctx) + svr.SetIndexNodeCreator(func(ctx context.Context, addr string, nodeID int64) (types.IndexNodeClient, error) { + return &grpcmock.GrpcIndexNodeClient{Err: nil}, nil }) - svr.SetRootCoord(newMockRootCoordService()) + svr.SetRootCoordClient(newMockRootCoordClient()) err = svr.Init() assert.NoError(t, err) @@ -4335,7 +4550,7 @@ func testDataCoordBase(t *testing.T, opts ...Option) *Server { err = svr.Register() assert.NoError(t, err) - resp, err := svr.GetComponentStates(context.Background()) + resp, err := 
svr.GetComponentStates(context.Background(), nil) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, commonpb.StateCode_Healthy, resp.GetState().GetStateCode()) @@ -4361,3 +4576,225 @@ func TestDataCoord_EnableActiveStandby(t *testing.T) { svr := testDataCoordBase(t) defer closeTestServer(t, svr) } + +func TestDataNodeTtChannel(t *testing.T) { + paramtable.Get().Save(Params.DataNodeCfg.DataNodeTimeTickByRPC.Key, "false") + defer paramtable.Get().Reset(Params.DataNodeCfg.DataNodeTimeTickByRPC.Key) + genMsg := func(msgType commonpb.MsgType, ch string, t Timestamp) *msgstream.DataNodeTtMsg { + return &msgstream.DataNodeTtMsg{ + BaseMsg: msgstream.BaseMsg{ + HashValues: []uint32{0}, + }, + DataNodeTtMsg: msgpb.DataNodeTtMsg{ + Base: &commonpb.MsgBase{ + MsgType: msgType, + MsgID: 0, + Timestamp: t, + SourceID: 0, + }, + ChannelName: ch, + Timestamp: t, + }, + } + } + t.Run("Test segment flush after tt", func(t *testing.T) { + ch := make(chan any, 1) + svr := newTestServer(t, ch) + defer closeTestServer(t, svr) + + svr.meta.AddCollection(&collectionInfo{ + ID: 0, + Schema: newTestSchema(), + Partitions: []int64{0}, + }) + + ttMsgStream, err := svr.factory.NewMsgStream(context.TODO()) + assert.NoError(t, err) + ttMsgStream.AsProducer([]string{Params.CommonCfg.DataCoordTimeTick.GetValue()}) + defer ttMsgStream.Close() + info := &NodeInfo{ + Address: "localhost:7777", + NodeID: 0, + } + err = svr.cluster.Register(info) + assert.NoError(t, err) + + resp, err := svr.AssignSegmentID(context.TODO(), &datapb.AssignSegmentIDRequest{ + NodeID: 0, + PeerRole: "", + SegmentIDRequests: []*datapb.SegmentIDRequest{ + { + CollectionID: 0, + PartitionID: 0, + ChannelName: "ch-1", + Count: 100, + }, + }, + }) + + assert.NoError(t, err) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) + assert.EqualValues(t, 1, len(resp.SegIDAssignments)) + assign := resp.SegIDAssignments[0] + + resp2, err 
:= svr.Flush(context.TODO(), &datapb.FlushRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Flush, + MsgID: 0, + Timestamp: 0, + SourceID: 0, + }, + DbID: 0, + CollectionID: 0, + }) + assert.NoError(t, err) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp2.GetStatus().GetErrorCode()) + + msgPack := msgstream.MsgPack{} + msg := genMsg(commonpb.MsgType_DataNodeTt, "ch-1", assign.ExpireTime) + msg.SegmentsStats = append(msg.SegmentsStats, &commonpb.SegmentStats{ + SegmentID: assign.GetSegID(), + NumRows: 1, + }) + msgPack.Msgs = append(msgPack.Msgs, msg) + err = ttMsgStream.Produce(&msgPack) + assert.NoError(t, err) + + flushMsg := <-ch + flushReq := flushMsg.(*datapb.FlushSegmentsRequest) + assert.EqualValues(t, 1, len(flushReq.SegmentIDs)) + assert.EqualValues(t, assign.SegID, flushReq.SegmentIDs[0]) + }) + + t.Run("flush segment with different channels", func(t *testing.T) { + ch := make(chan any, 1) + svr := newTestServer(t, ch) + defer closeTestServer(t, svr) + svr.meta.AddCollection(&collectionInfo{ + ID: 0, + Schema: newTestSchema(), + Partitions: []int64{0}, + }) + ttMsgStream, err := svr.factory.NewMsgStream(context.TODO()) + assert.NoError(t, err) + ttMsgStream.AsProducer([]string{Params.CommonCfg.DataCoordTimeTick.GetValue()}) + defer ttMsgStream.Close() + info := &NodeInfo{ + Address: "localhost:7777", + NodeID: 0, + } + err = svr.cluster.Register(info) + assert.NoError(t, err) + resp, err := svr.AssignSegmentID(context.TODO(), &datapb.AssignSegmentIDRequest{ + NodeID: 0, + PeerRole: "", + SegmentIDRequests: []*datapb.SegmentIDRequest{ + { + CollectionID: 0, + PartitionID: 0, + ChannelName: "ch-1", + Count: 100, + }, + { + CollectionID: 0, + PartitionID: 0, + ChannelName: "ch-2", + Count: 100, + }, + }, + }) + assert.NoError(t, err) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) + assert.EqualValues(t, 2, len(resp.SegIDAssignments)) + var assign *datapb.SegmentIDAssignment + for _, segment := 
range resp.SegIDAssignments { + if segment.GetChannelName() == "ch-1" { + assign = segment + break + } + } + assert.NotNil(t, assign) + resp2, err := svr.Flush(context.TODO(), &datapb.FlushRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Flush, + MsgID: 0, + Timestamp: 0, + SourceID: 0, + }, + DbID: 0, + CollectionID: 0, + }) + assert.NoError(t, err) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp2.GetStatus().GetErrorCode()) + + msgPack := msgstream.MsgPack{} + msg := genMsg(commonpb.MsgType_DataNodeTt, "ch-1", assign.ExpireTime) + msg.SegmentsStats = append(msg.SegmentsStats, &commonpb.SegmentStats{ + SegmentID: assign.GetSegID(), + NumRows: 1, + }) + msgPack.Msgs = append(msgPack.Msgs, msg) + err = ttMsgStream.Produce(&msgPack) + assert.NoError(t, err) + flushMsg := <-ch + flushReq := flushMsg.(*datapb.FlushSegmentsRequest) + assert.EqualValues(t, 1, len(flushReq.SegmentIDs)) + assert.EqualValues(t, assign.SegID, flushReq.SegmentIDs[0]) + }) + + t.Run("test expire allocation after receiving tt msg", func(t *testing.T) { + ch := make(chan any, 1) + helper := ServerHelper{ + eventAfterHandleDataNodeTt: func() { ch <- struct{}{} }, + } + svr := newTestServer(t, nil, WithServerHelper(helper)) + defer closeTestServer(t, svr) + + svr.meta.AddCollection(&collectionInfo{ + ID: 0, + Schema: newTestSchema(), + Partitions: []int64{0}, + }) + + ttMsgStream, err := svr.factory.NewMsgStream(context.TODO()) + assert.NoError(t, err) + ttMsgStream.AsProducer([]string{Params.CommonCfg.DataCoordTimeTick.GetValue()}) + defer ttMsgStream.Close() + node := &NodeInfo{ + NodeID: 0, + Address: "localhost:7777", + } + err = svr.cluster.Register(node) + assert.NoError(t, err) + + resp, err := svr.AssignSegmentID(context.TODO(), &datapb.AssignSegmentIDRequest{ + NodeID: 0, + PeerRole: "", + SegmentIDRequests: []*datapb.SegmentIDRequest{ + { + CollectionID: 0, + PartitionID: 0, + ChannelName: "ch-1", + Count: 100, + }, + }, + }) + assert.NoError(t, err) + 
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) + assert.EqualValues(t, 1, len(resp.SegIDAssignments)) + + assignedSegmentID := resp.SegIDAssignments[0].SegID + segment := svr.meta.GetHealthySegment(assignedSegmentID) + assert.EqualValues(t, 1, len(segment.allocations)) + + msgPack := msgstream.MsgPack{} + msg := genMsg(commonpb.MsgType_DataNodeTt, "ch-1", resp.SegIDAssignments[0].ExpireTime) + msgPack.Msgs = append(msgPack.Msgs, msg) + err = ttMsgStream.Produce(&msgPack) + assert.NoError(t, err) + + <-ch + segment = svr.meta.GetHealthySegment(assignedSegmentID) + assert.EqualValues(t, 0, len(segment.allocations)) + }) +} diff --git a/internal/datacoord/services.go b/internal/datacoord/services.go index 02bda5624a4aa..ef282d01f00a3 100644 --- a/internal/datacoord/services.go +++ b/internal/datacoord/services.go @@ -23,6 +23,7 @@ import ( "strconv" "sync" + "github.com/cockroachdb/errors" "github.com/samber/lo" "go.opentelemetry.io/otel" "go.uber.org/zap" @@ -31,6 +32,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus/internal/metastore/kv/datacoord" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/util/segmentutil" @@ -38,34 +40,26 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/commonpbutil" - "github.com/milvus-io/milvus/pkg/util/errorutil" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/retry" "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/milvus-io/milvus/pkg/util/typeutil" ) -// checks whether server in Healthy State -func (s *Server) isClosed() bool 
{ - return s.stateCode.Load() != commonpb.StateCode_Healthy -} - // GetTimeTickChannel legacy API, returns time tick channel name -func (s *Server) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (s *Server) GetTimeTickChannel(ctx context.Context, req *internalpb.GetTimeTickChannelRequest) (*milvuspb.StringResponse, error) { return &milvuspb.StringResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Value: Params.CommonCfg.DataCoordTimeTick.GetValue(), }, nil } // GetStatisticsChannel legacy API, returns statistics channel name -func (s *Server) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (s *Server) GetStatisticsChannel(ctx context.Context, req *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error) { return &milvuspb.StringResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "no statistics channel", - }, + Status: merr.Status(merr.WrapErrChannelNotFound("no statistics channel")), }, nil } @@ -73,38 +67,36 @@ func (s *Server) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResp // this api only guarantees all the segments requested is sealed // these segments will be flushed only after the Flush policy is fulfilled func (s *Server) Flush(ctx context.Context, req *datapb.FlushRequest) (*datapb.FlushResponse, error) { - log := log.Ctx(ctx) - log.Info("receive flush request", + log := log.Ctx(ctx).With( zap.Int64("dbID", req.GetDbID()), zap.Int64("collectionID", req.GetCollectionID()), zap.Bool("isImporting", req.GetIsImport())) + log.Info("receive flush request") ctx, sp := otel.Tracer(typeutil.DataCoordRole).Start(ctx, "DataCoord-Flush") defer sp.End() - resp := &datapb.FlushResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "", - }, - DbID: 0, - CollectionID: 0, - SegmentIDs: []int64{}, - } - if s.isClosed() { - resp.Status.Reason = serverNotServingErrMsg - 
return resp, nil + + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + return &datapb.FlushResponse{ + Status: merr.Status(err), + }, nil } // generate a timestamp timeOfSeal, all data before timeOfSeal is guaranteed to be sealed or flushed ts, err := s.allocator.allocTimestamp(ctx) if err != nil { log.Warn("unable to alloc timestamp", zap.Error(err)) + return &datapb.FlushResponse{ + Status: merr.Status(err), + }, nil } timeOfSeal, _ := tsoutil.ParseTS(ts) sealedSegmentIDs, err := s.segmentManager.SealAllSegments(ctx, req.GetCollectionID(), req.GetSegmentIDs(), req.GetIsImport()) if err != nil { - resp.Status.Reason = fmt.Sprintf("failed to flush %d, %s", req.CollectionID, err) - return resp, nil + return &datapb.FlushResponse{ + Status: merr.Status(errors.Wrapf(err, "failed to flush collection %d", + req.GetCollectionID())), + }, nil } sealedSegmentsIDDict := make(map[UniqueID]bool) @@ -116,36 +108,70 @@ func (s *Server) Flush(ctx context.Context, req *datapb.FlushRequest) (*datapb.F flushSegmentIDs := make([]UniqueID, 0, len(segments)) for _, segment := range segments { if segment != nil && - (segment.GetState() == commonpb.SegmentState_Flushed || - segment.GetState() == commonpb.SegmentState_Flushing) && + (isFlushState(segment.GetState())) && !sealedSegmentsIDDict[segment.GetID()] { flushSegmentIDs = append(flushSegmentIDs, segment.GetID()) } } + var isUnimplemented bool + err = retry.Do(ctx, func() error { + for _, channelInfo := range s.channelManager.GetAssignedChannels() { + nodeID := channelInfo.NodeID + channels := lo.Filter(channelInfo.Channels, func(channel *channel, _ int) bool { + return channel.CollectionID == req.GetCollectionID() + }) + channelNames := lo.Map(channels, func(channel *channel, _ int) string { + return channel.Name + }) + err = s.cluster.FlushChannels(ctx, nodeID, ts, channelNames) + if err != nil && errors.Is(err, merr.ErrServiceUnimplemented) { + isUnimplemented = true + return nil + } + if err != nil { + return err + } 
+ } + return nil + }, retry.Attempts(60)) // about 3min + if err != nil { + return &datapb.FlushResponse{ + Status: merr.Status(err), + }, nil + } + + if isUnimplemented { + // For compatible with rolling upgrade from version 2.2.x, + // fall back to the flush logic of version 2.2.x; + log.Warn("DataNode FlushChannels unimplemented", zap.Error(err)) + ts = 0 + } + log.Info("flush response with segments", zap.Int64("collectionID", req.GetCollectionID()), zap.Int64s("sealSegments", sealedSegmentIDs), zap.Int64s("flushSegments", flushSegmentIDs), - zap.Time("timeOfSeal", timeOfSeal)) - resp.Status.ErrorCode = commonpb.ErrorCode_Success - resp.DbID = req.GetDbID() - resp.CollectionID = req.GetCollectionID() - resp.SegmentIDs = sealedSegmentIDs - resp.TimeOfSeal = timeOfSeal.Unix() - resp.FlushSegmentIDs = flushSegmentIDs - return resp, nil + zap.Time("timeOfSeal", timeOfSeal), + zap.Time("flushTs", tsoutil.PhysicalTime(ts))) + + return &datapb.FlushResponse{ + Status: merr.Success(), + DbID: req.GetDbID(), + CollectionID: req.GetCollectionID(), + SegmentIDs: sealedSegmentIDs, + TimeOfSeal: timeOfSeal.Unix(), + FlushSegmentIDs: flushSegmentIDs, + FlushTs: ts, + }, nil } // AssignSegmentID applies for segment ids and make allocation for records. func (s *Server) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest) (*datapb.AssignSegmentIDResponse, error) { log := log.Ctx(ctx) - if s.isClosed() { + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { return &datapb.AssignSegmentIDResponse{ - Status: &commonpb.Status{ - Reason: serverNotServingErrMsg, - ErrorCode: commonpb.ErrorCode_UnexpectedError, - }, + Status: merr.Status(err), }, nil } @@ -168,7 +194,7 @@ func (s *Server) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentI } // Add the channel to cluster for watching. 
- s.cluster.Watch(r.ChannelName, r.CollectionID) + s.cluster.Watch(ctx, r.ChannelName, r.CollectionID) segmentAllocations := make([]*Allocation, 0) if r.GetIsImport() { @@ -201,29 +227,28 @@ func (s *Server) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentI CollectionID: r.CollectionID, PartitionID: r.PartitionID, ExpireTime: allocation.ExpireTime, - Status: merr.Status(nil), + Status: merr.Success(), } assigns = append(assigns, result) } } return &datapb.AssignSegmentIDResponse{ - Status: merr.Status(nil), + Status: merr.Success(), SegIDAssignments: assigns, }, nil } // GetSegmentStates returns segments state func (s *Server) GetSegmentStates(ctx context.Context, req *datapb.GetSegmentStatesRequest) (*datapb.GetSegmentStatesResponse, error) { - resp := &datapb.GetSegmentStatesResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - }, - } - if s.isClosed() { - resp.Status.Reason = serverNotServingErrMsg - return resp, nil + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + return &datapb.GetSegmentStatesResponse{ + Status: merr.Status(err), + }, nil } + resp := &datapb.GetSegmentStatesResponse{ + Status: merr.Success(), + } for _, segmentID := range req.SegmentIDs { state := &datapb.SegmentStateInfo{ SegmentID: segmentID, @@ -237,26 +262,26 @@ func (s *Server) GetSegmentStates(ctx context.Context, req *datapb.GetSegmentSta } resp.States = append(resp.States, state) } - resp.Status.ErrorCode = commonpb.ErrorCode_Success return resp, nil } // GetInsertBinlogPaths returns binlog paths info for requested segments func (s *Server) GetInsertBinlogPaths(ctx context.Context, req *datapb.GetInsertBinlogPathsRequest) (*datapb.GetInsertBinlogPathsResponse, error) { - resp := &datapb.GetInsertBinlogPathsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - }, - } - if s.isClosed() { - resp.Status.Reason = serverNotServingErrMsg - return resp, nil + if err := 
merr.CheckHealthy(s.GetStateCode()); err != nil { + return &datapb.GetInsertBinlogPathsResponse{ + Status: merr.Status(err), + }, nil } segment := s.meta.GetHealthySegment(req.GetSegmentID()) if segment == nil { - resp.Status.Reason = "segment not found" - return resp, nil + return &datapb.GetInsertBinlogPathsResponse{ + Status: merr.Status(merr.WrapErrSegmentNotFound(req.GetSegmentID())), + }, nil + } + + resp := &datapb.GetInsertBinlogPathsResponse{ + Status: merr.Success(), } binlogs := segment.GetBinlogs() fids := make([]UniqueID, 0, len(binlogs)) @@ -270,7 +295,6 @@ func (s *Server) GetInsertBinlogPaths(ctx context.Context, req *datapb.GetInsert } paths = append(paths, &internalpb.StringList{Values: p}) } - resp.Status.ErrorCode = commonpb.ErrorCode_Success resp.FieldIDs = fids resp.Paths = paths return resp, nil @@ -283,17 +307,16 @@ func (s *Server) GetCollectionStatistics(ctx context.Context, req *datapb.GetCol zap.Int64("collectionID", req.GetCollectionID()), ) log.Info("received request to get collection statistics") - resp := &datapb.GetCollectionStatisticsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - }, + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + return &datapb.GetCollectionStatisticsResponse{ + Status: merr.Status(err), + }, nil } - if s.isClosed() { - resp.Status.Reason = serverNotServingErrMsg - return resp, nil + + resp := &datapb.GetCollectionStatisticsResponse{ + Status: merr.Success(), } nums := s.meta.GetNumRowsOfCollection(req.CollectionID) - resp.Status.ErrorCode = commonpb.ErrorCode_Success resp.Stats = append(resp.Stats, &commonpb.KeyValuePair{Key: "row_count", Value: strconv.FormatInt(nums, 10)}) log.Info("success to get collection statistics", zap.Any("response", resp)) return resp, nil @@ -308,13 +331,12 @@ func (s *Server) GetPartitionStatistics(ctx context.Context, req *datapb.GetPart zap.Int64s("partitionIDs", req.GetPartitionIDs()), ) resp := 
&datapb.GetPartitionStatisticsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - }, + Status: merr.Success(), } - if s.isClosed() { - resp.Status.Reason = serverNotServingErrMsg - return resp, nil + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + return &datapb.GetPartitionStatisticsResponse{ + Status: merr.Status(err), + }, nil } nums := int64(0) if len(req.GetPartitionIDs()) == 0 { @@ -324,16 +346,15 @@ func (s *Server) GetPartitionStatistics(ctx context.Context, req *datapb.GetPart num := s.meta.GetNumRowsOfPartition(req.CollectionID, partID) nums += num } - resp.Status.ErrorCode = commonpb.ErrorCode_Success resp.Stats = append(resp.Stats, &commonpb.KeyValuePair{Key: "row_count", Value: strconv.FormatInt(nums, 10)}) log.Info("success to get partition statistics", zap.Any("response", resp)) return resp, nil } // GetSegmentInfoChannel legacy API, returns segment info statistics channel -func (s *Server) GetSegmentInfoChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (s *Server) GetSegmentInfoChannel(ctx context.Context, req *datapb.GetSegmentInfoChannelRequest) (*milvuspb.StringResponse, error) { return &milvuspb.StringResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Value: Params.CommonCfg.DataCoordSegmentInfo.GetValue(), }, nil } @@ -343,13 +364,12 @@ func (s *Server) GetSegmentInfoChannel(ctx context.Context) (*milvuspb.StringRes func (s *Server) GetSegmentInfo(ctx context.Context, req *datapb.GetSegmentInfoRequest) (*datapb.GetSegmentInfoResponse, error) { log := log.Ctx(ctx) resp := &datapb.GetSegmentInfoResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - }, + Status: merr.Success(), } - if s.isClosed() { - resp.Status.Reason = serverNotServingErrMsg - return resp, nil + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + return &datapb.GetSegmentInfoResponse{ + Status: merr.Status(err), + }, nil } infos := 
make([]*datapb.SegmentInfo, 0, len(req.GetSegmentIDs())) channelCPs := make(map[string]*msgpb.MsgPosition) @@ -360,7 +380,8 @@ func (s *Server) GetSegmentInfo(ctx context.Context, req *datapb.GetSegmentInfoR if info == nil { log.Warn("failed to get segment, this may have been cleaned", zap.Int64("segmentID", id)) - resp.Status.Reason = msgSegmentNotFound(id) + err := merr.WrapErrSegmentNotFound(id) + resp.Status = merr.Status(err) return resp, nil } @@ -375,7 +396,8 @@ func (s *Server) GetSegmentInfo(ctx context.Context, req *datapb.GetSegmentInfoR } else { info = s.meta.GetHealthySegment(id) if info == nil { - resp.Status.Reason = msgSegmentNotFound(id) + err := merr.WrapErrSegmentNotFound(id) + resp.Status = merr.Status(err) return resp, nil } clonedInfo := info.Clone() @@ -387,7 +409,6 @@ func (s *Server) GetSegmentInfo(ctx context.Context, req *datapb.GetSegmentInfoR channelCPs[vchannel] = s.meta.GetChannelCheckpoint(vchannel) } } - resp.Status.ErrorCode = commonpb.ErrorCode_Success resp.Infos = infos resp.ChannelCheckpoint = channelCPs return resp, nil @@ -396,11 +417,8 @@ func (s *Server) GetSegmentInfo(ctx context.Context, req *datapb.GetSegmentInfoR // SaveBinlogPaths updates segment related binlog path // works for Checkpoints and Flush func (s *Server) SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPathsRequest) (*commonpb.Status, error) { - resp := &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError} - - if s.isClosed() { - resp.Reason = serverNotServingErrMsg - return resp, nil + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + return merr.Status(err), nil } log := log.Ctx(ctx).With( @@ -423,10 +441,9 @@ func (s *Server) SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPath // Also avoid to handle segment not found error if not the owner of shard if !req.GetImporting() && len(channelName) != 0 { if !s.channelManager.Match(nodeID, channelName) { - failResponse(resp, fmt.Sprintf("channel %s is not watched on 
node %d", channelName, nodeID)) - resp.ErrorCode = commonpb.ErrorCode_MetaFailed - log.Warn("node is not matched with channel", zap.String("channel", channelName)) - return resp, nil + err := merr.WrapErrChannelNotFound(channelName, fmt.Sprintf("for node %d", nodeID)) + log.Warn("node is not matched with channel", zap.String("channel", channelName), zap.Error(err)) + return merr.Status(err), nil } } @@ -435,19 +452,18 @@ func (s *Server) SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPath segment := s.meta.GetSegment(segmentID) if segment == nil { - log.Error("failed to get segment") - failResponseWithCode(resp, commonpb.ErrorCode_SegmentNotFound, fmt.Sprintf("failed to get segment %d", segmentID)) - return resp, nil + err := merr.WrapErrSegmentNotFound(segmentID) + log.Warn("failed to get segment", zap.Error(err)) + return merr.Status(err), nil } if segment.State == commonpb.SegmentState_Dropped { log.Info("save to dropped segment, ignore this request") - resp.ErrorCode = commonpb.ErrorCode_Success - return resp, nil + return merr.Success(), nil } else if !isSegmentHealthy(segment) { - log.Error("failed to get segment") - failResponseWithCode(resp, commonpb.ErrorCode_SegmentNotFound, fmt.Sprintf("failed to get segment %d", segmentID)) - return resp, nil + err := merr.WrapErrSegmentNotFound(segmentID) + log.Warn("failed to get segment, the segment not healthy", zap.Error(err)) + return merr.Status(err), nil } if req.GetDropped() { @@ -484,8 +500,7 @@ func (s *Server) SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPath } if err != nil { log.Error("save binlog and checkpoints failed", zap.Error(err)) - resp.Reason = err.Error() - return resp, nil + return merr.Status(err), nil } log.Info("flush segment with meta", zap.Any("meta", req.GetField2BinlogPaths())) @@ -504,8 +519,7 @@ func (s *Server) SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPath } } } - resp.ErrorCode = commonpb.ErrorCode_Success - return resp, nil + return 
merr.Success(), nil } // DropVirtualChannel notifies vchannel dropped @@ -513,13 +527,12 @@ func (s *Server) SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPath func (s *Server) DropVirtualChannel(ctx context.Context, req *datapb.DropVirtualChannelRequest) (*datapb.DropVirtualChannelResponse, error) { log := log.Ctx(ctx) resp := &datapb.DropVirtualChannelResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - }, + Status: merr.Success(), } - if s.isClosed() { - resp.Status.Reason = serverNotServingErrMsg - return resp, nil + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + return &datapb.DropVirtualChannelResponse{ + Status: merr.Status(err), + }, nil } channel := req.GetChannelName() @@ -529,8 +542,8 @@ func (s *Server) DropVirtualChannel(ctx context.Context, req *datapb.DropVirtual // validate nodeID := req.GetBase().GetSourceID() if !s.channelManager.Match(nodeID, channel) { - failResponse(resp.Status, fmt.Sprintf("channel %s is not watched on node %d", channel, nodeID)) - resp.Status.ErrorCode = commonpb.ErrorCode_MetaFailed + err := merr.WrapErrChannelNotFound(channel, fmt.Sprintf("for node %d", nodeID)) + resp.Status = merr.Status(err) log.Warn("node is not matched with channel", zap.String("channel", channel), zap.Int64("nodeID", nodeID)) return resp, nil } @@ -557,7 +570,7 @@ func (s *Server) DropVirtualChannel(ctx context.Context, req *datapb.DropVirtual err := s.meta.UpdateDropChannelSegmentInfo(channel, segments) if err != nil { log.Error("Update Drop Channel segment info failed", zap.String("channel", channel), zap.Error(err)) - resp.Status.Reason = err.Error() + resp.Status = merr.Status(err) return resp, nil } @@ -571,35 +584,28 @@ func (s *Server) DropVirtualChannel(ctx context.Context, req *datapb.DropVirtual metrics.CleanupDataCoordNumStoredRows(collectionID) // no compaction triggered in Drop procedure - resp.Status.ErrorCode = commonpb.ErrorCode_Success return resp, nil } // 
SetSegmentState reset the state of the given segment. func (s *Server) SetSegmentState(ctx context.Context, req *datapb.SetSegmentStateRequest) (*datapb.SetSegmentStateResponse, error) { log := log.Ctx(ctx) - if s.isClosed() { + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { return &datapb.SetSegmentStateResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: serverNotServingErrMsg, - }, + Status: merr.Status(err), }, nil } err := s.meta.SetState(req.GetSegmentId(), req.GetNewState()) if err != nil { log.Error("failed to updated segment state in dataCoord meta", zap.Int64("segmentID", req.SegmentId), - zap.String("to state", req.GetNewState().String())) + zap.String("newState", req.GetNewState().String())) return &datapb.SetSegmentStateResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } return &datapb.SetSegmentStateResponse{ - Status: merr.Status(nil), + Status: merr.Success(), }, nil } @@ -612,7 +618,7 @@ func (s *Server) GetStateCode() commonpb.StateCode { } // GetComponentStates returns DataCoord's current state -func (s *Server) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { +func (s *Server) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) { code := s.GetStateCode() nodeID := common.NotRegisteredID if s.session != nil && s.session.Registered() { @@ -625,7 +631,7 @@ func (s *Server) GetComponentStates(ctx context.Context) (*milvuspb.ComponentSta Role: "datacoord", StateCode: code, }, - Status: merr.Status(nil), + Status: merr.Success(), } return resp, nil } @@ -642,20 +648,19 @@ func (s *Server) GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInf ) log.Info("get recovery info request received") resp := &datapb.GetRecoveryInfoResponse{ - Status: &commonpb.Status{ - ErrorCode: 
commonpb.ErrorCode_UnexpectedError, - }, + Status: merr.Success(), } - if s.isClosed() { - resp.Status.Reason = serverNotServingErrMsg - return resp, nil + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + return &datapb.GetRecoveryInfoResponse{ + Status: merr.Status(err), + }, nil } dresp, err := s.broker.DescribeCollectionInternal(s.ctx, collectionID) if err != nil { log.Error("get collection info from rootcoord failed", zap.Error(err)) - resp.Status.Reason = err.Error() + resp.Status = merr.Status(err) return resp, nil } channels := dresp.GetVirtualChannelNames() @@ -682,9 +687,9 @@ func (s *Server) GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInf for id := range flushedIDs { segment := s.meta.GetSegment(id) if segment == nil { - errMsg := fmt.Sprintf("failed to get segment %d", id) - log.Error(errMsg) - resp.Status.Reason = errMsg + err := merr.WrapErrSegmentNotFound(id) + log.Warn("failed to get segment", zap.Int64("segmentID", id)) + resp.Status = merr.Status(err) return resp, nil } // Skip non-flushing, non-flushed and dropped segments. 
@@ -760,7 +765,6 @@ func (s *Server) GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInf resp.Channels = channelInfos resp.Binlogs = binlogs - resp.Status.ErrorCode = commonpb.ErrorCode_Success return resp, nil } @@ -776,28 +780,18 @@ func (s *Server) GetRecoveryInfoV2(ctx context.Context, req *datapb.GetRecoveryI ) log.Info("get recovery info request received") resp := &datapb.GetRecoveryInfoResponseV2{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - }, + Status: merr.Success(), } - if s.isClosed() { - resp.Status.Reason = serverNotServingErrMsg - return resp, nil - } - - dresp, err := s.broker.DescribeCollectionInternal(s.ctx, collectionID) - if err != nil { - log.Error("get collection info from rootcoord failed", - zap.Error(err)) - - resp.Status.Reason = err.Error() - return resp, nil + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + return &datapb.GetRecoveryInfoResponseV2{ + Status: merr.Status(err), + }, nil } - channels := dresp.GetVirtualChannelNames() + channels := s.channelManager.GetChannelsByCollectionID(collectionID) channelInfos := make([]*datapb.VchannelInfo, 0, len(channels)) flushedIDs := make(typeutil.UniqueSet) - for _, c := range channels { - channelInfo := s.handler.GetQueryVChanPositions(&channel{Name: c, CollectionID: collectionID}, partitionIDs...) + for _, ch := range channels { + channelInfo := s.handler.GetQueryVChanPositions(ch, partitionIDs...) 
channelInfos = append(channelInfos, channelInfo) log.Info("datacoord append channelInfo in GetRecoveryInfo", zap.String("channel", channelInfo.GetChannelName()), @@ -813,9 +807,9 @@ func (s *Server) GetRecoveryInfoV2(ctx context.Context, req *datapb.GetRecoveryI for id := range flushedIDs { segment := s.meta.GetSegment(id) if segment == nil { - errMsg := fmt.Sprintf("failed to get segment %d", id) - log.Error(errMsg) - resp.Status.Reason = errMsg + err := merr.WrapErrSegmentNotFound(id) + log.Warn("failed to get segment", zap.Int64("segmentID", id)) + resp.Status = merr.Status(err) return resp, nil } // Skip non-flushing, non-flushed and dropped segments. @@ -841,21 +835,42 @@ func (s *Server) GetRecoveryInfoV2(ctx context.Context, req *datapb.GetRecoveryI rowCount = segment.NumOfRows } + // save the traffic of sending + binLogs, err := datacoord.CompressBinLog(segment.Binlogs) + if err != nil { + log.Warn("failed to compress segment", zap.Int64("segmentID", id), zap.Error(err)) + resp.Status = merr.Status(err) + return resp, nil + } + + deltaLogs, err := datacoord.CompressBinLog(segment.Deltalogs) + if err != nil { + log.Warn("failed to compress segment", zap.Int64("segmentID", id), zap.Error(err)) + resp.Status = merr.Status(err) + return resp, nil + } + + statLogs, err := datacoord.CompressBinLog(segment.Statslogs) + if err != nil { + log.Warn("failed to compress segment", zap.Int64("segmentID", id), zap.Error(err)) + resp.Status = merr.Status(err) + return resp, nil + } + segmentInfos = append(segmentInfos, &datapb.SegmentInfo{ ID: segment.ID, PartitionID: segment.PartitionID, CollectionID: segment.CollectionID, InsertChannel: segment.InsertChannel, NumOfRows: rowCount, - Binlogs: segment.Binlogs, - Statslogs: segment.Statslogs, - Deltalogs: segment.Deltalogs, + Binlogs: binLogs, + Statslogs: statLogs, + Deltalogs: deltaLogs, }) } resp.Channels = channelInfos resp.Segments = segmentInfos - resp.Status.ErrorCode = commonpb.ErrorCode_Success return resp, nil } @@ 
-864,9 +879,7 @@ func (s *Server) GetRecoveryInfoV2(ctx context.Context, req *datapb.GetRecoveryI func (s *Server) GetFlushedSegments(ctx context.Context, req *datapb.GetFlushedSegmentsRequest) (*datapb.GetFlushedSegmentsResponse, error) { log := log.Ctx(ctx) resp := &datapb.GetFlushedSegmentsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - }, + Status: merr.Success(), } collectionID := req.GetCollectionID() partitionID := req.GetPartitionID() @@ -874,9 +887,10 @@ func (s *Server) GetFlushedSegments(ctx context.Context, req *datapb.GetFlushedS zap.Int64("collectionID", collectionID), zap.Int64("partitionID", partitionID), ) - if s.isClosed() { - resp.Status.Reason = serverNotServingErrMsg - return resp, nil + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + return &datapb.GetFlushedSegmentsResponse{ + Status: merr.Status(err), + }, nil } var segmentIDs []UniqueID if partitionID < 0 { @@ -901,7 +915,6 @@ func (s *Server) GetFlushedSegments(ctx context.Context, req *datapb.GetFlushedS } resp.Segments = ret - resp.Status.ErrorCode = commonpb.ErrorCode_Success return resp, nil } @@ -910,9 +923,7 @@ func (s *Server) GetFlushedSegments(ctx context.Context, req *datapb.GetFlushedS func (s *Server) GetSegmentsByStates(ctx context.Context, req *datapb.GetSegmentsByStatesRequest) (*datapb.GetSegmentsByStatesResponse, error) { log := log.Ctx(ctx) resp := &datapb.GetSegmentsByStatesResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - }, + Status: merr.Success(), } collectionID := req.GetCollectionID() partitionID := req.GetPartitionID() @@ -921,9 +932,10 @@ func (s *Server) GetSegmentsByStates(ctx context.Context, req *datapb.GetSegment zap.Int64("collectionID", collectionID), zap.Int64("partitionID", partitionID), zap.Any("states", states)) - if s.isClosed() { - resp.Status.Reason = serverNotServingErrMsg - return resp, nil + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + 
return &datapb.GetSegmentsByStatesResponse{ + Status: merr.Status(err), + }, nil } var segmentIDs []UniqueID if partitionID < 0 { @@ -945,25 +957,14 @@ func (s *Server) GetSegmentsByStates(ctx context.Context, req *datapb.GetSegment } resp.Segments = ret - resp.Status.ErrorCode = commonpb.ErrorCode_Success return resp, nil } // ShowConfigurations returns the configurations of DataCoord matching req.Pattern func (s *Server) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { - log := log.Ctx(ctx) - if s.isClosed() { - log.Warn("DataCoord.ShowConfigurations failed", - zap.Int64("nodeId", paramtable.GetNodeID()), - zap.String("req", req.Pattern), - zap.Error(errDataCoordIsUnhealthy(paramtable.GetNodeID()))) - + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { return &internalpb.ShowConfigurationsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: msgDataCoordIsUnhealthy(paramtable.GetNodeID()), - }, - Configuations: nil, + Status: merr.Status(err), }, nil } @@ -977,7 +978,7 @@ func (s *Server) ShowConfigurations(ctx context.Context, req *internalpb.ShowCon } return &internalpb.ShowConfigurationsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Configuations: configList, }, nil } @@ -986,19 +987,9 @@ func (s *Server) ShowConfigurations(ctx context.Context, req *internalpb.ShowCon // it may include SystemMetrics, Topology metrics, etc. 
func (s *Server) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { log := log.Ctx(ctx) - if s.isClosed() { - log.Warn("DataCoord.GetMetrics failed", - zap.Int64("nodeID", paramtable.GetNodeID()), - zap.String("req", req.Request), - zap.Error(errDataCoordIsUnhealthy(paramtable.GetNodeID()))) - + if err := merr.CheckHealthyStandby(s.GetStateCode()); err != nil { return &milvuspb.GetMetricsResponse{ - ComponentName: metricsinfo.ConstructComponentName(typeutil.DataCoordRole, paramtable.GetNodeID()), - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: msgDataCoordIsUnhealthy(paramtable.GetNodeID()), - }, - Response: "", + Status: merr.Status(err), }, nil } @@ -1007,15 +998,12 @@ func (s *Server) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest log.Warn("DataCoord.GetMetrics failed to parse metric type", zap.Int64("nodeID", paramtable.GetNodeID()), zap.String("req", req.Request), - zap.Error(err)) + zap.Error(err), + ) return &milvuspb.GetMetricsResponse{ ComponentName: metricsinfo.ConstructComponentName(typeutil.DataCoordRole, paramtable.GetNodeID()), - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, - Response: "", + Status: merr.Status(err), }, nil } @@ -1024,10 +1012,7 @@ func (s *Server) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest if err != nil { log.Warn("DataCoord GetMetrics failed", zap.Int64("nodeID", paramtable.GetNodeID()), zap.Error(err)) return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -1048,11 +1033,7 @@ func (s *Server) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest return &milvuspb.GetMetricsResponse{ ComponentName: metricsinfo.ConstructComponentName(typeutil.DataCoordRole, paramtable.GetNodeID()), - Status: &commonpb.Status{ - 
ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: metricsinfo.MsgUnimplementedMetric, - }, - Response: "", + Status: merr.Status(merr.WrapErrMetricNotFound(metricType)), }, nil } @@ -1064,31 +1045,28 @@ func (s *Server) ManualCompaction(ctx context.Context, req *milvuspb.ManualCompa log.Info("received manual compaction") resp := &milvuspb.ManualCompactionResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - }, + Status: merr.Success(), } - if s.isClosed() { - log.Warn("failed to execute manual compaction", zap.Error(errDataCoordIsUnhealthy(paramtable.GetNodeID()))) - resp.Status.Reason = msgDataCoordIsUnhealthy(paramtable.GetNodeID()) - return resp, nil + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + return &milvuspb.ManualCompactionResponse{ + Status: merr.Status(err), + }, nil } if !Params.DataCoordCfg.EnableCompaction.GetAsBool() { - resp.Status.Reason = "compaction disabled" + resp.Status = merr.Status(merr.WrapErrServiceUnavailable("compaction disabled")) return resp, nil } id, err := s.compactionTrigger.forceTriggerCompaction(req.CollectionID) if err != nil { log.Error("failed to trigger manual compaction", zap.Error(err)) - resp.Status.Reason = err.Error() + resp.Status = merr.Status(err) return resp, nil } log.Info("success to trigger manual compaction", zap.Int64("compactionID", id)) - resp.Status.ErrorCode = commonpb.ErrorCode_Success resp.CompactionID = id return resp, nil } @@ -1100,19 +1078,17 @@ func (s *Server) GetCompactionState(ctx context.Context, req *milvuspb.GetCompac ) log.Info("received get compaction state request") resp := &milvuspb.GetCompactionStateResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - }, + Status: merr.Success(), } - if s.isClosed() { - log.Warn("failed to get compaction state", zap.Error(errDataCoordIsUnhealthy(paramtable.GetNodeID()))) - resp.Status.Reason = msgDataCoordIsUnhealthy(paramtable.GetNodeID()) - return resp, nil + if 
err := merr.CheckHealthy(s.GetStateCode()); err != nil { + return &milvuspb.GetCompactionStateResponse{ + Status: merr.Status(err), + }, nil } if !Params.DataCoordCfg.EnableCompaction.GetAsBool() { - resp.Status.Reason = "compaction disabled" + resp.Status = merr.Status(merr.WrapErrServiceUnavailable("compaction disabled")) return resp, nil } @@ -1124,7 +1100,6 @@ func (s *Server) GetCompactionState(ctx context.Context, req *milvuspb.GetCompac resp.CompletedPlanNo = int64(completedCnt) resp.TimeoutPlanNo = int64(timeoutCnt) resp.FailedPlanNo = int64(failedCnt) - resp.Status.ErrorCode = commonpb.ErrorCode_Success log.Info("success to get compaction state", zap.Any("state", state), zap.Int("executing", executingCnt), zap.Int("completed", completedCnt), zap.Int("failed", failedCnt), zap.Int("timeout", timeoutCnt), zap.Int64s("plans", lo.Map(tasks, func(t *compactionTask, _ int) int64 { @@ -1143,18 +1118,17 @@ func (s *Server) GetCompactionStateWithPlans(ctx context.Context, req *milvuspb. ) log.Info("received the request to get compaction state with plans") - resp := &milvuspb.GetCompactionPlansResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + return &milvuspb.GetCompactionPlansResponse{ + Status: merr.Status(err), + }, nil } - if s.isClosed() { - log.Warn("failed to get compaction state with plans", zap.Error(errDataCoordIsUnhealthy(paramtable.GetNodeID()))) - resp.Status.Reason = msgDataCoordIsUnhealthy(paramtable.GetNodeID()) - return resp, nil + resp := &milvuspb.GetCompactionPlansResponse{ + Status: merr.Success(), } - if !Params.DataCoordCfg.EnableCompaction.GetAsBool() { - resp.Status.Reason = "compaction disabled" + resp.Status = merr.Status(merr.WrapErrServiceUnavailable("compaction disabled")) return resp, nil } @@ -1165,7 +1139,6 @@ func (s *Server) GetCompactionStateWithPlans(ctx context.Context, req *milvuspb. 
state, _, _, _, _ := getCompactionState(tasks) - resp.Status.ErrorCode = commonpb.ErrorCode_Success resp.State = state log.Info("success to get state with plans", zap.Any("state", state), zap.Any("merge infos", resp.MergeInfos), zap.Int64s("plans", lo.Map(tasks, func(t *compactionTask, _ int) int64 { @@ -1226,15 +1199,13 @@ func (s *Server) WatchChannels(ctx context.Context, req *datapb.WatchChannelsReq ) log.Info("receive watch channels request") resp := &datapb.WatchChannelsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - }, + Status: merr.Success(), } - if s.isClosed() { - log.Warn("failed to watch channels request", zap.Error(errDataCoordIsUnhealthy(paramtable.GetNodeID()))) - resp.Status.Reason = msgDataCoordIsUnhealthy(paramtable.GetNodeID()) - return resp, nil + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + return &datapb.WatchChannelsResponse{ + Status: merr.Status(err), + }, nil } for _, channelName := range req.GetChannelNames() { ch := &channel{ @@ -1244,71 +1215,104 @@ func (s *Server) WatchChannels(ctx context.Context, req *datapb.WatchChannelsReq Schema: req.GetSchema(), CreateTimestamp: req.GetCreateTimestamp(), } - err := s.channelManager.Watch(ch) + err := s.channelManager.Watch(ctx, ch) if err != nil { log.Warn("fail to watch channelName", zap.Error(err)) - resp.Status.Reason = err.Error() + resp.Status = merr.Status(err) return resp, nil } if err := s.meta.catalog.MarkChannelAdded(ctx, ch.Name); err != nil { // TODO: add background task to periodically cleanup the orphaned channel add marks. 
log.Error("failed to mark channel added", zap.Error(err)) - resp.Status.Reason = err.Error() + resp.Status = merr.Status(err) return resp, nil } } - resp.Status.ErrorCode = commonpb.ErrorCode_Success return resp, nil } -// GetFlushState gets the flush state of multiple segments -func (s *Server) GetFlushState(ctx context.Context, req *milvuspb.GetFlushStateRequest) (*milvuspb.GetFlushStateResponse, error) { - log := log.Ctx(ctx).WithRateGroup("dc.GetFlushState", 1, 60) - resp := &milvuspb.GetFlushStateResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}} - if s.isClosed() { - log.Warn("DataCoord receive GetFlushState request, server closed", - zap.Int64s("segmentIDs", req.GetSegmentIDs()), zap.Int("len", len(req.GetSegmentIDs()))) - resp.Status.Reason = msgDataCoordIsUnhealthy(paramtable.GetNodeID()) - return resp, nil +// GetFlushState gets the flush state of the collection based on the provided flush ts and segment IDs. +func (s *Server) GetFlushState(ctx context.Context, req *datapb.GetFlushStateRequest) (*milvuspb.GetFlushStateResponse, error) { + log := log.Ctx(ctx).With(zap.Int64("collection", req.GetCollectionID()), + zap.Time("flushTs", tsoutil.PhysicalTime(req.GetFlushTs()))). 
+ WithRateGroup("dc.GetFlushState", 1, 60) + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + return &milvuspb.GetFlushStateResponse{ + Status: merr.Status(err), + }, nil } - var unflushed []UniqueID - for _, sid := range req.GetSegmentIDs() { - segment := s.meta.GetHealthySegment(sid) - // segment is nil if it was compacted or it's a empty segment and is set to dropped - if segment == nil || segment.GetState() == commonpb.SegmentState_Flushing || - segment.GetState() == commonpb.SegmentState_Flushed { - continue + resp := &milvuspb.GetFlushStateResponse{Status: merr.Success()} + if len(req.GetSegmentIDs()) > 0 { + var unflushed []UniqueID + for _, sid := range req.GetSegmentIDs() { + segment := s.meta.GetHealthySegment(sid) + // segment is nil if it was compacted, or it's an empty segment and is set to dropped + if segment == nil || isFlushState(segment.GetState()) { + continue + } + unflushed = append(unflushed, sid) + } + if len(unflushed) != 0 { + log.RatedInfo(10, "DataCoord receive GetFlushState request, Flushed is false", zap.Int64s("unflushed", unflushed), zap.Int("len", len(unflushed))) + resp.Flushed = false + + return resp, nil } - unflushed = append(unflushed, sid) } - if len(unflushed) != 0 { - log.RatedInfo(10, "DataCoord receive GetFlushState request, Flushed is false", zap.Int64s("unflushed", unflushed), zap.Int("len", len(unflushed))) - resp.Flushed = false - } else { - log.Info("DataCoord receive GetFlushState request, Flushed is true", zap.Int64s("segmentIDs", req.GetSegmentIDs()), zap.Int("len", len(req.GetSegmentIDs()))) + channels := make([]string, 0) + for _, channelInfo := range s.channelManager.GetAssignedChannels() { + filtered := lo.Filter(channelInfo.Channels, func(channel *channel, _ int) bool { + return channel.CollectionID == req.GetCollectionID() + }) + channelNames := lo.Map(filtered, func(channel *channel, _ int) string { + return channel.Name + }) + channels = append(channels, channelNames...) 
+ } + + if len(channels) == 0 { // For compatibility with old client resp.Flushed = true + + log.Info("GetFlushState all flushed without checking flush ts") + return resp, nil + } + + for _, channel := range channels { + cp := s.meta.GetChannelCheckpoint(channel) + if cp == nil || cp.GetTimestamp() < req.GetFlushTs() { + resp.Flushed = false + + log.RatedInfo(10, "GetFlushState failed, channel unflushed", zap.String("channel", channel), + zap.Time("CP", tsoutil.PhysicalTime(cp.GetTimestamp())), + zap.Duration("lag", tsoutil.PhysicalTime(req.GetFlushTs()).Sub(tsoutil.PhysicalTime(cp.GetTimestamp())))) + return resp, nil + } } - resp.Status.ErrorCode = commonpb.ErrorCode_Success + + resp.Flushed = true + log.Info("GetFlushState all flushed") + return resp, nil } // GetFlushAllState checks if all DML messages before `FlushAllTs` have been flushed. func (s *Server) GetFlushAllState(ctx context.Context, req *milvuspb.GetFlushAllStateRequest) (*milvuspb.GetFlushAllStateResponse, error) { log := log.Ctx(ctx) - resp := &milvuspb.GetFlushAllStateResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}} - if s.isClosed() { - log.Warn("DataCoord receive GetFlushAllState request, server closed") - resp.Status.Reason = msgDataCoordIsUnhealthy(paramtable.GetNodeID()) - return resp, nil + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + return &milvuspb.GetFlushAllStateResponse{ + Status: merr.Status(err), + }, nil } + resp := &milvuspb.GetFlushAllStateResponse{Status: merr.Success()} + dbsRsp, err := s.broker.ListDatabases(ctx) if err != nil { log.Warn("failed to ListDatabases", zap.Error(err)) - resp.Status.Reason = err.Error() + resp.Status = merr.Status(err) return resp, nil } dbNames := dbsRsp.DbNames @@ -1317,7 +1321,7 @@ func (s *Server) GetFlushAllState(ctx context.Context, req *milvuspb.GetFlushAll return dbName == req.GetDbName() }) if len(dbNames) == 0 { - resp.Status.Reason = merr.WrapErrDatabaseNotFound(req.GetDbName()).Error() + 
resp.Status = merr.Status(merr.WrapErrDatabaseNotFound(req.GetDbName())) return resp, nil } } @@ -1326,7 +1330,7 @@ func (s *Server) GetFlushAllState(ctx context.Context, req *milvuspb.GetFlushAll showColRsp, err := s.broker.ShowCollections(ctx, dbName) if err != nil { log.Warn("failed to ShowCollections", zap.Error(err)) - resp.Status.Reason = err.Error() + resp.Status = merr.Status(err) return resp, nil } @@ -1334,113 +1338,95 @@ func (s *Server) GetFlushAllState(ctx context.Context, req *milvuspb.GetFlushAll describeColRsp, err := s.broker.DescribeCollectionInternal(ctx, collection) if err != nil { log.Warn("failed to DescribeCollectionInternal", zap.Error(err)) - resp.Status.Reason = err.Error() + resp.Status = merr.Status(err) return resp, nil } for _, channel := range describeColRsp.GetVirtualChannelNames() { channelCP := s.meta.GetChannelCheckpoint(channel) if channelCP == nil || channelCP.GetTimestamp() < req.GetFlushAllTs() { resp.Flushed = false - resp.Status.ErrorCode = commonpb.ErrorCode_Success + return resp, nil } } } } resp.Flushed = true - resp.Status.ErrorCode = commonpb.ErrorCode_Success return resp, nil } -// Import distributes the import tasks to dataNodes. -// It returns a failed status if no dataNode is available or if any error occurs. -func (s *Server) Import(ctx context.Context, itr *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { +// Import distributes the import tasks to DataNodes. +// It returns a failed status if no DataNode is available or if any error occurs. 
+func (s *Server) Import(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { log := log.Ctx(ctx) - log.Info("DataCoord receives import request", zap.Any("import task request", itr)) + log.Info("DataCoord receives import request", zap.Any("req", req)) resp := &datapb.ImportTaskResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - }, + Status: merr.Success(), } - if s.isClosed() { - log.Error("failed to import for closed DataCoord service") - resp.Status.Reason = msgDataCoordIsUnhealthy(paramtable.GetNodeID()) - return resp, nil + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + return &datapb.ImportTaskResponse{ + Status: merr.Status(err), + }, nil } nodes := s.sessionManager.getLiveNodeIDs() if len(nodes) == 0 { - log.Error("import failed as all DataNodes are offline") - resp.Status.Reason = "no data node available" + log.Warn("import failed as all DataNodes are offline") + resp.Status = merr.Status(merr.WrapErrNodeLackAny("no live DataNode")) return resp, nil } - log.Info("available DataNodes are", zap.Int64s("node ID", nodes)) + log.Info("available DataNodes are", zap.Int64s("nodeIDs", nodes)) - avaNodes := getDiff(nodes, itr.GetWorkingNodes()) + avaNodes := getDiff(nodes, req.GetWorkingNodes()) if len(avaNodes) > 0 { // If there exists available DataNodes, pick one at random. resp.DatanodeId = avaNodes[rand.Intn(len(avaNodes))] - log.Info("picking a free dataNode", - zap.Any("all dataNodes", nodes), - zap.Int64("picking free dataNode with ID", resp.GetDatanodeId())) - s.cluster.Import(s.ctx, resp.GetDatanodeId(), itr) + log.Info("picking a free DataNode", + zap.Any("all DataNodes", nodes), + zap.Int64("picking free DataNode with ID", resp.GetDatanodeId())) + s.cluster.Import(s.ctx, resp.GetDatanodeId(), req) } else { - // No dataNode is available, reject the import request. 
- msg := "all DataNodes are busy working on data import, the task has been rejected and wait for idle datanode" - log.Info(msg, zap.Int64("task ID", itr.GetImportTask().GetTaskId())) - resp.Status.Reason = msg + // No DataNode is available, reject the import request. + msg := "all DataNodes are busy working on data import, the task has been rejected and wait for idle DataNode" + log.Info(msg, zap.Int64("taskID", req.GetImportTask().GetTaskId())) + resp.Status = merr.Status(merr.WrapErrNodeLackAny("no available DataNode")) return resp, nil } - resp.Status.ErrorCode = commonpb.ErrorCode_Success return resp, nil } // UpdateSegmentStatistics updates a segment's stats. func (s *Server) UpdateSegmentStatistics(ctx context.Context, req *datapb.UpdateSegmentStatisticsRequest) (*commonpb.Status, error) { - log := log.Ctx(ctx) - resp := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "", - } - if s.isClosed() { - log.Warn("failed to update segment stat for closed server") - resp.Reason = msgDataCoordIsUnhealthy(paramtable.GetNodeID()) - return resp, nil + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + return merr.Status(err), nil } s.updateSegmentStatistics(req.GetStats()) - return merr.Status(nil), nil + return merr.Success(), nil } // UpdateChannelCheckpoint updates channel checkpoint in dataCoord. 
func (s *Server) UpdateChannelCheckpoint(ctx context.Context, req *datapb.UpdateChannelCheckpointRequest) (*commonpb.Status, error) { log := log.Ctx(ctx) - resp := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - } - if s.isClosed() { - log.Warn("failed to update channel position for closed server") - resp.Reason = msgDataCoordIsUnhealthy(paramtable.GetNodeID()) - return resp, nil + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + return merr.Status(err), nil } err := s.meta.UpdateChannelCheckpoint(req.GetVChannel(), req.GetPosition()) if err != nil { log.Warn("failed to UpdateChannelCheckpoint", zap.String("vChannel", req.GetVChannel()), zap.Error(err)) - resp.Reason = err.Error() - return resp, nil + return merr.Status(err), nil } - return merr.Status(nil), nil + return merr.Success(), nil } // ReportDataNodeTtMsgs send datenode timetick messages to dataCoord. func (s *Server) ReportDataNodeTtMsgs(ctx context.Context, req *datapb.ReportDataNodeTtMsgsRequest) (*commonpb.Status, error) { log := log.Ctx(ctx) - if s.isClosed() { - log.Warn("failed to report dataNode ttMsgs on closed server") - return merr.Status(merr.WrapErrServiceUnavailable("Datacoord not ready")), nil + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + return merr.Status(err), nil } for _, ttMsg := range req.GetMsgs() { @@ -1458,7 +1444,7 @@ func (s *Server) ReportDataNodeTtMsgs(ctx context.Context, req *datapb.ReportDat } } - return merr.Status(nil), nil + return merr.Success(), nil } func (s *Server) handleRPCTimetickMessage(ctx context.Context, ttMsg *msgpb.DataNodeTtMsg) error { @@ -1535,22 +1521,15 @@ func (s *Server) SaveImportSegment(ctx context.Context, req *datapb.SaveImportSe zap.Int64("partitionID", req.GetPartitionId()), zap.String("channelName", req.GetChannelName()), zap.Int64("# of rows", req.GetRowNum())) - errResp := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "", - } - if s.isClosed() { - 
log.Warn("failed to add segment for closed server") - errResp.ErrorCode = commonpb.ErrorCode_DataCoordNA - errResp.Reason = msgDataCoordIsUnhealthy(paramtable.GetNodeID()) - return errResp, nil + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + return merr.Status(err), nil } // Look for the DataNode that watches the channel. ok, nodeID := s.channelManager.getNodeIDByChannelName(req.GetChannelName()) if !ok { - log.Error("no DataNode found for channel", zap.String("channelName", req.GetChannelName())) - errResp.Reason = fmt.Sprint("no DataNode found for channel ", req.GetChannelName()) - return errResp, nil + err := merr.WrapErrChannelNotFound(req.GetChannelName(), "no DataNode watches this channel") + log.Error("no DataNode found for channel", zap.String("channelName", req.GetChannelName()), zap.Error(err)) + return merr.Status(err), nil } // Call DataNode to add the new segment to its own flow graph. cli, err := s.sessionManager.getClient(ctx, nodeID) @@ -1558,10 +1537,7 @@ func (s *Server) SaveImportSegment(ctx context.Context, req *datapb.SaveImportSe log.Error("failed to get DataNode client for SaveImportSegment", zap.Int64("DataNode ID", nodeID), zap.Error(err)) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } resp, err := cli.AddImportSegment(ctx, &datapb.AddImportSegmentRequest{ @@ -1578,10 +1554,7 @@ func (s *Server) SaveImportSegment(ctx context.Context, req *datapb.SaveImportSe }) if err := VerifyResponse(resp.GetStatus(), err); err != nil { log.Error("failed to add segment", zap.Int64("DataNode ID", nodeID), zap.Error(err)) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } log.Info("succeed to add segment", zap.Int64("DataNode ID", nodeID), zap.Any("add segment req", req)) // Fill in start position message ID. 
@@ -1591,12 +1564,9 @@ func (s *Server) SaveImportSegment(ctx context.Context, req *datapb.SaveImportSe rsp, err := s.SaveBinlogPaths(context.Background(), req.GetSaveBinlogPathReq()) if err := VerifyResponse(rsp, err); err != nil { log.Error("failed to SaveBinlogPaths", zap.Error(err)) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } - return merr.Status(nil), nil + return merr.Success(), nil } // UnsetIsImportingState unsets the isImporting states of the given segments. @@ -1609,17 +1579,14 @@ func (s *Server) UnsetIsImportingState(ctx context.Context, req *datapb.UnsetIsI for _, segID := range req.GetSegmentIds() { if err := s.meta.UnsetIsImporting(segID); err != nil { // Fail-open. - log.Error("failed to unset segment is importing state", zap.Int64("segmentID", segID)) + log.Error("failed to unset segment is importing state", + zap.Int64("segmentID", segID), + ) reportErr = err } } - if reportErr != nil { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: reportErr.Error(), - }, nil - } - return merr.Status(nil), nil + + return merr.Status(reportErr), nil } // MarkSegmentsDropped marks the given segments as `Dropped`. @@ -1627,33 +1594,20 @@ func (s *Server) UnsetIsImportingState(ctx context.Context, req *datapb.UnsetIsI // Deprecated, do not use it func (s *Server) MarkSegmentsDropped(ctx context.Context, req *datapb.MarkSegmentsDroppedRequest) (*commonpb.Status, error) { log.Info("marking segments dropped", zap.Int64s("segments", req.GetSegmentIds())) - failure := false + var err error for _, segID := range req.GetSegmentIds() { - if err := s.meta.SetState(segID, commonpb.SegmentState_Dropped); err != nil { + if err = s.meta.SetState(segID, commonpb.SegmentState_Dropped); err != nil { // Fail-open. 
log.Error("failed to set segment state as dropped", zap.Int64("segmentID", segID)) - failure = true + break } } - if failure { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - }, nil - } - return merr.Status(nil), nil + return merr.Status(err), nil } func (s *Server) BroadcastAlteredCollection(ctx context.Context, req *datapb.AlterCollectionRequest) (*commonpb.Status, error) { - log := log.Ctx(ctx) - errResp := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "", - } - - if s.isClosed() { - log.Warn("failed to broadcast collection information for closed server") - errResp.Reason = msgDataCoordIsUnhealthy(paramtable.GetNodeID()) - return errResp, nil + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + return merr.Status(err), nil } // get collection info from cache @@ -1674,18 +1628,20 @@ func (s *Server) BroadcastAlteredCollection(ctx context.Context, req *datapb.Alt Properties: properties, } s.meta.AddCollection(collInfo) - return merr.Status(nil), nil + return merr.Success(), nil } clonedColl.Properties = properties s.meta.AddCollection(clonedColl) - return merr.Status(nil), nil + return merr.Success(), nil } func (s *Server) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { - if s.isClosed() { - reason := errorutil.UnHealthReason("datacoord", paramtable.GetNodeID(), "datacoord is closed") - return &milvuspb.CheckHealthResponse{IsHealthy: false, Reasons: []string{reason}}, nil + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + return &milvuspb.CheckHealthResponse{ + Status: merr.Status(err), + Reasons: []string{err.Error()}, + }, nil } mu := &sync.Mutex{} @@ -1700,43 +1656,42 @@ func (s *Server) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthReque if err != nil { mu.Lock() defer mu.Unlock() - errReasons = append(errReasons, errorutil.UnHealthReason("datanode", nodeID, err.Error())) + errReasons = append(errReasons, 
fmt.Sprintf("failed to get DataNode %d: %v", nodeID, err)) return err } - sta, err := cli.GetComponentStates(ctx) - isHealthy, reason := errorutil.UnHealthReasonWithComponentStatesOrErr("datanode", nodeID, sta, err) - if !isHealthy { + sta, err := cli.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) + if err != nil { + return err + } + err = merr.AnalyzeState("DataNode", nodeID, sta) + if err != nil { mu.Lock() defer mu.Unlock() - errReasons = append(errReasons, reason) + errReasons = append(errReasons, err.Error()) } - return err + return nil }) } err := group.Wait() if err != nil || len(errReasons) != 0 { - return &milvuspb.CheckHealthResponse{IsHealthy: false, Reasons: errReasons}, nil + return &milvuspb.CheckHealthResponse{Status: merr.Success(), IsHealthy: false, Reasons: errReasons}, nil } - return &milvuspb.CheckHealthResponse{IsHealthy: true, Reasons: errReasons}, nil + return &milvuspb.CheckHealthResponse{Status: merr.Success(), IsHealthy: true, Reasons: errReasons}, nil } func (s *Server) GcConfirm(ctx context.Context, request *datapb.GcConfirmRequest) (*datapb.GcConfirmResponse, error) { - resp := &datapb.GcConfirmResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - }, - GcFinished: false, + if err := merr.CheckHealthy(s.GetStateCode()); err != nil { + return &datapb.GcConfirmResponse{ + Status: merr.Status(err), + }, nil } - if s.isClosed() { - resp.Status.Reason = msgDataCoordIsUnhealthy(paramtable.GetNodeID()) - return resp, nil + resp := &datapb.GcConfirmResponse{ + Status: merr.Success(), } - resp.GcFinished = s.meta.GcConfirm(ctx, request.GetCollectionId(), request.GetPartitionId()) - resp.Status.ErrorCode = commonpb.ErrorCode_Success return resp, nil } diff --git a/internal/datacoord/services_test.go b/internal/datacoord/services_test.go index 05b7031806d59..faaf503dda59d 100644 --- a/internal/datacoord/services_test.go +++ b/internal/datacoord/services_test.go @@ -2,7 +2,6 @@ package datacoord 
import ( "context" - "fmt" "testing" "github.com/stretchr/testify/assert" @@ -17,6 +16,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metautil" ) @@ -98,13 +98,12 @@ func TestServer_GcConfirm(t *testing.T) { } func TestGetRecoveryInfoV2(t *testing.T) { - t.Run("test get recovery info with no segments", func(t *testing.T) { svr := newTestServer(t, nil) defer closeTestServer(t, svr) - svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoord, error) { - return newMockRootCoordService(), nil + svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoordClient, error) { + return newMockRootCoordClient(), nil } req := &datapb.GetRecoveryInfoRequestV2{ @@ -112,14 +111,14 @@ func TestGetRecoveryInfoV2(t *testing.T) { } resp, err := svr.GetRecoveryInfoV2(context.TODO(), req) assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.EqualValues(t, 0, len(resp.GetSegments())) - assert.EqualValues(t, 1, len(resp.GetChannels())) - assert.Nil(t, resp.GetChannels()[0].SeekPosition) + assert.EqualValues(t, 0, len(resp.GetChannels())) }) createSegment := func(id, collectionID, partitionID, numOfRows int64, posTs uint64, - channel string, state commonpb.SegmentState) *datapb.SegmentInfo { + channel string, state commonpb.SegmentState, + ) *datapb.SegmentInfo { return &datapb.SegmentInfo{ ID: id, CollectionID: collectionID, @@ -145,8 +144,8 @@ func TestGetRecoveryInfoV2(t *testing.T) { svr := newTestServer(t, nil) defer closeTestServer(t, svr) - svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli 
*clientv3.Client) (types.RootCoord, error) { - return newMockRootCoordService(), nil + svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoordClient, error) { + return newMockRootCoordClient(), nil } svr.meta.AddCollection(&collectionInfo{ @@ -205,9 +204,9 @@ func TestGetRecoveryInfoV2(t *testing.T) { }, }, } - err = svr.meta.AddSegment(NewSegmentInfo(seg1)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg1)) assert.NoError(t, err) - err = svr.meta.AddSegment(NewSegmentInfo(seg2)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg2)) assert.NoError(t, err) err = svr.meta.AddSegmentIndex(&model.SegmentIndex{ SegmentID: seg1.ID, @@ -230,12 +229,16 @@ func TestGetRecoveryInfoV2(t *testing.T) { }) assert.NoError(t, err) + ch := &channel{Name: "vchan1", CollectionID: 0} + svr.channelManager.AddNode(0) + svr.channelManager.Watch(context.Background(), ch) + req := &datapb.GetRecoveryInfoRequestV2{ CollectionID: 0, } resp, err := svr.GetRecoveryInfoV2(context.TODO(), req) assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.EqualValues(t, 1, len(resp.GetChannels())) assert.EqualValues(t, 0, len(resp.GetChannels()[0].GetUnflushedSegmentIds())) assert.ElementsMatch(t, []int64{0, 1}, resp.GetChannels()[0].GetFlushedSegmentIds()) @@ -249,8 +252,8 @@ func TestGetRecoveryInfoV2(t *testing.T) { svr := newTestServer(t, nil) defer closeTestServer(t, svr) - svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoord, error) { - return newMockRootCoordService(), nil + svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoordClient, error) { + return newMockRootCoordClient(), nil } svr.meta.AddCollection(&collectionInfo{ @@ -301,17 +304,21 @@ 
func TestGetRecoveryInfoV2(t *testing.T) { }, }, } - err = svr.meta.AddSegment(NewSegmentInfo(seg1)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg1)) assert.NoError(t, err) - err = svr.meta.AddSegment(NewSegmentInfo(seg2)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg2)) assert.NoError(t, err) + ch := &channel{Name: "vchan1", CollectionID: 0} + svr.channelManager.AddNode(0) + svr.channelManager.Watch(context.Background(), ch) + req := &datapb.GetRecoveryInfoRequestV2{ CollectionID: 0, } resp, err := svr.GetRecoveryInfoV2(context.TODO(), req) assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.EqualValues(t, 0, len(resp.GetSegments())) assert.EqualValues(t, 1, len(resp.GetChannels())) assert.NotNil(t, resp.GetChannels()[0].SeekPosition) @@ -326,8 +333,8 @@ func TestGetRecoveryInfoV2(t *testing.T) { Schema: newTestSchema(), }) - svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoord, error) { - return newMockRootCoordService(), nil + svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoordClient, error) { + return newMockRootCoordClient(), nil } binlogReq := &datapb.SaveBinlogPathsRequest{ @@ -338,10 +345,10 @@ func TestGetRecoveryInfoV2(t *testing.T) { FieldID: 1, Binlogs: []*datapb.Binlog{ { - LogPath: "/binlog/file1", + LogPath: metautil.BuildInsertLogPath("a", 0, 100, 0, 1, 801), }, { - LogPath: "/binlog/file2", + LogPath: metautil.BuildInsertLogPath("a", 0, 100, 0, 1, 801), }, }, }, @@ -351,10 +358,10 @@ func TestGetRecoveryInfoV2(t *testing.T) { FieldID: 1, Binlogs: []*datapb.Binlog{ { - LogPath: "/stats_log/file1", + LogPath: metautil.BuildStatsLogPath("a", 0, 100, 0, 1000, 10000), }, { - LogPath: "/stats_log/file2", + LogPath: metautil.BuildStatsLogPath("a", 0, 100, 
0, 1000, 10000), }, }, }, @@ -365,7 +372,7 @@ func TestGetRecoveryInfoV2(t *testing.T) { { TimestampFrom: 0, TimestampTo: 1, - LogPath: "/stats_log/file1", + LogPath: metautil.BuildDeltaLogPath("a", 0, 100, 0, 100000), LogSize: 1, }, }, @@ -373,7 +380,7 @@ func TestGetRecoveryInfoV2(t *testing.T) { }, } segment := createSegment(0, 0, 1, 100, 10, "vchan1", commonpb.SegmentState_Flushed) - err := svr.meta.AddSegment(NewSegmentInfo(segment)) + err := svr.meta.AddSegment(context.TODO(), NewSegmentInfo(segment)) assert.NoError(t, err) err = svr.meta.CreateIndex(&model.Index{ @@ -397,7 +404,7 @@ func TestGetRecoveryInfoV2(t *testing.T) { err = svr.channelManager.AddNode(0) assert.NoError(t, err) - err = svr.channelManager.Watch(&channel{Name: "vchan1", CollectionID: 0}) + err = svr.channelManager.Watch(context.Background(), &channel{Name: "vchan1", CollectionID: 0}) assert.NoError(t, err) sResp, err := svr.SaveBinlogPaths(context.TODO(), binlogReq) @@ -410,21 +417,31 @@ func TestGetRecoveryInfoV2(t *testing.T) { } resp, err := svr.GetRecoveryInfoV2(context.TODO(), req) assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NoError(t, merr.Error(resp.Status)) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.EqualValues(t, 1, len(resp.GetSegments())) assert.EqualValues(t, 0, resp.GetSegments()[0].GetID()) assert.EqualValues(t, 1, len(resp.GetSegments()[0].GetBinlogs())) assert.EqualValues(t, 1, resp.GetSegments()[0].GetBinlogs()[0].GetFieldID()) - for i, binlog := range resp.GetSegments()[0].GetBinlogs()[0].GetBinlogs() { - assert.Equal(t, fmt.Sprintf("/binlog/file%d", i+1), binlog.GetLogPath()) + for _, binlog := range resp.GetSegments()[0].GetBinlogs()[0].GetBinlogs() { + assert.Equal(t, "", binlog.GetLogPath()) + assert.Equal(t, int64(801), binlog.GetLogID()) + } + for _, binlog := range resp.GetSegments()[0].GetStatslogs()[0].GetBinlogs() { + assert.Equal(t, "", 
binlog.GetLogPath()) + assert.Equal(t, int64(10000), binlog.GetLogID()) + } + for _, binlog := range resp.GetSegments()[0].GetDeltalogs()[0].GetBinlogs() { + assert.Equal(t, "", binlog.GetLogPath()) + assert.Equal(t, int64(100000), binlog.GetLogID()) } }) t.Run("with dropped segments", func(t *testing.T) { svr := newTestServer(t, nil) defer closeTestServer(t, svr) - svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoord, error) { - return newMockRootCoordService(), nil + svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoordClient, error) { + return newMockRootCoordClient(), nil } svr.meta.AddCollection(&collectionInfo{ @@ -441,17 +458,21 @@ func TestGetRecoveryInfoV2(t *testing.T) { seg1 := createSegment(7, 0, 0, 100, 30, "vchan1", commonpb.SegmentState_Growing) seg2 := createSegment(8, 0, 0, 100, 40, "vchan1", commonpb.SegmentState_Dropped) - err = svr.meta.AddSegment(NewSegmentInfo(seg1)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg1)) assert.NoError(t, err) - err = svr.meta.AddSegment(NewSegmentInfo(seg2)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg2)) assert.NoError(t, err) + ch := &channel{Name: "vchan1", CollectionID: 0} + svr.channelManager.AddNode(0) + svr.channelManager.Watch(context.Background(), ch) + req := &datapb.GetRecoveryInfoRequestV2{ CollectionID: 0, } resp, err := svr.GetRecoveryInfoV2(context.TODO(), req) assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.EqualValues(t, 0, len(resp.GetSegments())) assert.EqualValues(t, 1, len(resp.GetChannels())) assert.NotNil(t, resp.GetChannels()[0].SeekPosition) @@ -464,8 +485,8 @@ func TestGetRecoveryInfoV2(t *testing.T) { svr := newTestServer(t, nil) defer closeTestServer(t, svr) - 
svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoord, error) { - return newMockRootCoordService(), nil + svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoordClient, error) { + return newMockRootCoordClient(), nil } svr.meta.AddCollection(&collectionInfo{ @@ -483,29 +504,224 @@ func TestGetRecoveryInfoV2(t *testing.T) { seg1 := createSegment(7, 0, 0, 100, 30, "vchan1", commonpb.SegmentState_Growing) seg2 := createSegment(8, 0, 0, 100, 40, "vchan1", commonpb.SegmentState_Flushed) seg2.IsFake = true - err = svr.meta.AddSegment(NewSegmentInfo(seg1)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg1)) assert.NoError(t, err) - err = svr.meta.AddSegment(NewSegmentInfo(seg2)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg2)) assert.NoError(t, err) + ch := &channel{Name: "vchan1", CollectionID: 0} + svr.channelManager.AddNode(0) + svr.channelManager.Watch(context.Background(), ch) + req := &datapb.GetRecoveryInfoRequestV2{ CollectionID: 0, } resp, err := svr.GetRecoveryInfoV2(context.TODO(), req) assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.EqualValues(t, 0, len(resp.GetSegments())) assert.EqualValues(t, 1, len(resp.GetChannels())) assert.NotNil(t, resp.GetChannels()[0].SeekPosition) assert.NotEqual(t, 0, resp.GetChannels()[0].GetSeekPosition().GetTimestamp()) }) + t.Run("with failed compress", func(t *testing.T) { + svr := newTestServer(t, nil) + defer closeTestServer(t, svr) + + svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoordClient, error) { + return newMockRootCoordClient(), nil + } + + svr.meta.AddCollection(&collectionInfo{ + ID: 0, + Schema: newTestSchema(), + }) + + err := 
svr.meta.UpdateChannelCheckpoint("vchan1", &msgpb.MsgPosition{ + ChannelName: "vchan1", + Timestamp: 0, + MsgID: []byte{0, 0, 0, 0, 0, 0, 0, 0}, + }) + assert.NoError(t, err) + + svr.meta.AddCollection(&collectionInfo{ + ID: 1, + Schema: newTestSchema(), + }) + + err = svr.meta.UpdateChannelCheckpoint("vchan2", &msgpb.MsgPosition{ + ChannelName: "vchan2", + Timestamp: 0, + MsgID: []byte{0, 0, 0, 0, 0, 0, 0, 0}, + }) + assert.NoError(t, err) + + svr.meta.AddCollection(&collectionInfo{ + ID: 2, + Schema: newTestSchema(), + }) + + err = svr.meta.UpdateChannelCheckpoint("vchan3", &msgpb.MsgPosition{ + ChannelName: "vchan3", + Timestamp: 0, + MsgID: []byte{0, 0, 0, 0, 0, 0, 0, 0}, + }) + assert.NoError(t, err) + + svr.channelManager.AddNode(0) + ch := &channel{ + Name: "vchan1", + CollectionID: 0, + } + err = svr.channelManager.Watch(context.TODO(), ch) + assert.NoError(t, err) + + ch = &channel{ + Name: "vchan2", + CollectionID: 1, + } + err = svr.channelManager.Watch(context.TODO(), ch) + assert.NoError(t, err) + + ch = &channel{ + Name: "vchan3", + CollectionID: 2, + } + err = svr.channelManager.Watch(context.TODO(), ch) + assert.NoError(t, err) + + seg := createSegment(8, 0, 0, 100, 40, "vchan1", commonpb.SegmentState_Flushed) + binLogPaths := make([]*datapb.Binlog, 1) + // miss one field + path := metautil.JoinIDPath(0, 0, 8, fieldID) + path = path + "/mock" + binLogPaths[0] = &datapb.Binlog{ + EntriesNum: 10000, + LogPath: path, + } + + seg.Statslogs = append(seg.Statslogs, &datapb.FieldBinlog{ + FieldID: fieldID, + Binlogs: binLogPaths, + }) + + binLogPaths2 := make([]*datapb.Binlog, 1) + pathCorrect := metautil.JoinIDPath(0, 0, 8, fieldID, 1) + binLogPaths2[0] = &datapb.Binlog{ + EntriesNum: 10000, + LogPath: pathCorrect, + } + + seg.Binlogs = append(seg.Binlogs, &datapb.FieldBinlog{ + FieldID: fieldID, + Binlogs: binLogPaths2, + }) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg)) + assert.NoError(t, err) + + // make sure collection is indexed + 
err = svr.meta.CreateIndex(&model.Index{ + TenantID: "", + CollectionID: 0, + FieldID: 2, + IndexID: 0, + IndexName: "_default_idx_1", + IsDeleted: false, + CreateTime: 0, + TypeParams: nil, + IndexParams: nil, + IsAutoIndex: false, + UserIndexParams: nil, + }) + assert.NoError(t, err) + + svr.meta.segments.SetSegmentIndex(seg.ID, &model.SegmentIndex{ + SegmentID: seg.ID, + CollectionID: 0, + PartitionID: 0, + NumRows: 100, + IndexID: 0, + BuildID: 0, + NodeID: 0, + IndexVersion: 1, + IndexState: commonpb.IndexState_Finished, + FailReason: "", + IsDeleted: false, + CreateTime: 0, + IndexFileKeys: nil, + IndexSize: 0, + }) + + req := &datapb.GetRecoveryInfoRequestV2{ + CollectionID: 0, + } + resp, err := svr.GetRecoveryInfoV2(context.TODO(), req) + assert.NoError(t, err) + assert.True(t, resp.Status.ErrorCode == commonpb.ErrorCode_UnexpectedError) + + // test bin log + path = metautil.JoinIDPath(0, 0, 9, fieldID) + path = path + "/mock" + binLogPaths[0] = &datapb.Binlog{ + EntriesNum: 10000, + LogPath: path, + } + + seg2 := createSegment(9, 1, 0, 100, 40, "vchan2", commonpb.SegmentState_Flushed) + seg2.Binlogs = append(seg2.Binlogs, &datapb.FieldBinlog{ + FieldID: fieldID, + Binlogs: binLogPaths, + }) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg2)) + assert.NoError(t, err) + + // make sure collection is indexed + err = svr.meta.CreateIndex(&model.Index{ + TenantID: "", + CollectionID: 1, + FieldID: 2, + IndexID: 1, + IndexName: "_default_idx_2", + IsDeleted: false, + CreateTime: 0, + TypeParams: nil, + IndexParams: nil, + IsAutoIndex: false, + UserIndexParams: nil, + }) + assert.NoError(t, err) + + svr.meta.segments.SetSegmentIndex(seg2.ID, &model.SegmentIndex{ + SegmentID: seg2.ID, + CollectionID: 1, + PartitionID: 0, + NumRows: 100, + IndexID: 1, + BuildID: 0, + NodeID: 0, + IndexVersion: 1, + IndexState: commonpb.IndexState_Finished, + FailReason: "", + IsDeleted: false, + CreateTime: 0, + IndexFileKeys: nil, + IndexSize: 0, + }) + req = 
&datapb.GetRecoveryInfoRequestV2{ + CollectionID: 1, + } + resp, err = svr.GetRecoveryInfoV2(context.TODO(), req) + assert.NoError(t, err) + assert.True(t, resp.Status.ErrorCode == commonpb.ErrorCode_UnexpectedError) + }) + t.Run("with continuous compaction", func(t *testing.T) { svr := newTestServer(t, nil) defer closeTestServer(t, svr) - svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoord, error) { - return newMockRootCoordService(), nil + svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoordClient, error) { + return newMockRootCoordClient(), nil } svr.meta.AddCollection(&collectionInfo{ @@ -527,15 +743,15 @@ func TestGetRecoveryInfoV2(t *testing.T) { seg4 := createSegment(12, 0, 0, 2048, 40, "vchan1", commonpb.SegmentState_Dropped) seg5 := createSegment(13, 0, 0, 2048, 40, "vchan1", commonpb.SegmentState_Flushed) seg5.CompactionFrom = []int64{11, 12} - err = svr.meta.AddSegment(NewSegmentInfo(seg1)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg1)) assert.NoError(t, err) - err = svr.meta.AddSegment(NewSegmentInfo(seg2)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg2)) assert.NoError(t, err) - err = svr.meta.AddSegment(NewSegmentInfo(seg3)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg3)) assert.NoError(t, err) - err = svr.meta.AddSegment(NewSegmentInfo(seg4)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg4)) assert.NoError(t, err) - err = svr.meta.AddSegment(NewSegmentInfo(seg5)) + err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg5)) assert.NoError(t, err) err = svr.meta.CreateIndex(&model.Index{ TenantID: "", @@ -568,12 +784,16 @@ func TestGetRecoveryInfoV2(t *testing.T) { IndexSize: 0, }) + ch := &channel{Name: "vchan1", CollectionID: 0} + svr.channelManager.AddNode(0) + svr.channelManager.Watch(context.Background(), ch) + req := 
&datapb.GetRecoveryInfoRequestV2{ CollectionID: 0, } resp, err := svr.GetRecoveryInfoV2(context.TODO(), req) assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.NotNil(t, resp.GetChannels()[0].SeekPosition) assert.NotEqual(t, 0, resp.GetChannels()[0].GetSeekPosition().GetTimestamp()) assert.Len(t, resp.GetChannels()[0].GetDroppedSegmentIds(), 0) @@ -586,7 +806,7 @@ func TestGetRecoveryInfoV2(t *testing.T) { closeTestServer(t, svr) resp, err := svr.GetRecoveryInfoV2(context.TODO(), &datapb.GetRecoveryInfoRequestV2{}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetStatus().GetErrorCode()) - assert.Equal(t, serverNotServingErrMsg, resp.GetStatus().GetReason()) + err = merr.Error(resp.GetStatus()) + assert.ErrorIs(t, err, merr.ErrServiceNotReady) }) } diff --git a/internal/datacoord/session.go b/internal/datacoord/session.go index b8bdae06604e5..115209bc40789 100644 --- a/internal/datacoord/session.go +++ b/internal/datacoord/session.go @@ -38,7 +38,7 @@ type NodeInfo struct { type Session struct { sync.Mutex info *NodeInfo - client types.DataNode + client types.DataNodeClient clientCreator dataNodeCreatorFunc isDisposed bool } @@ -52,7 +52,7 @@ func NewSession(info *NodeInfo, creator dataNodeCreatorFunc) *Session { } // GetOrCreateClient gets or creates a new client for session -func (n *Session) GetOrCreateClient(ctx context.Context) (types.DataNode, error) { +func (n *Session) GetOrCreateClient(ctx context.Context) (types.DataNodeClient, error) { n.Lock() defer n.Unlock() @@ -76,10 +76,7 @@ func (n *Session) initClient(ctx context.Context) (err error) { if n.client, err = n.clientCreator(ctx, n.info.Address, n.info.NodeID); err != nil { return } - if err = n.client.Init(); err != nil { - return - } - return n.client.Start() + return nil } // Dispose releases client connection @@ -88,7 +85,7 @@ 
func (n *Session) Dispose() { defer n.Unlock() if n.client != nil { - n.client.Stop() + n.client.Close() n.client = nil } n.isDisposed = true diff --git a/internal/datacoord/session_manager.go b/internal/datacoord/session_manager.go index 54bcf850ef3af..e4449433a2cc5 100644 --- a/internal/datacoord/session_manager.go +++ b/internal/datacoord/session_manager.go @@ -22,6 +22,8 @@ import ( "sync" "time" + "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" grpcdatanodeclient "github.com/milvus-io/milvus/internal/distributed/datanode/client" "github.com/milvus-io/milvus/internal/proto/datapb" @@ -31,15 +33,14 @@ import ( "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/retry" + "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/milvus-io/milvus/pkg/util/typeutil" - "go.uber.org/zap" ) const ( flushTimeout = 15 * time.Second // TODO: evaluate and update import timeout. - importTimeout = 3 * time.Hour - reCollectTimeout = 5 * time.Second + importTimeout = 3 * time.Hour ) // SessionManager provides the grpc interfaces of cluster @@ -59,7 +60,7 @@ func withSessionCreator(creator dataNodeCreatorFunc) SessionOpt { } func defaultSessionCreator() dataNodeCreatorFunc { - return func(ctx context.Context, addr string, nodeID int64) (types.DataNode, error) { + return func(ctx context.Context, addr string, nodeID int64) (types.DataNodeClient, error) { return grpcdatanodeclient.NewClient(ctx, addr, nodeID) } } @@ -225,32 +226,6 @@ func (c *SessionManager) execImport(ctx context.Context, nodeID int64, itr *data log.Info("success to import", zap.Int64("node", nodeID), zap.Any("import task", itr)) } -// ReCollectSegmentStats collects segment stats info from DataNodes, after DataCoord reboots. 
-func (c *SessionManager) ReCollectSegmentStats(ctx context.Context, nodeID int64) error { - cli, err := c.getClient(ctx, nodeID) - if err != nil { - log.Warn("failed to get dataNode client", zap.Int64("DataNode ID", nodeID), zap.Error(err)) - return err - } - ctx, cancel := context.WithTimeout(ctx, reCollectTimeout) - defer cancel() - resp, err := cli.ResendSegmentStats(ctx, &datapb.ResendSegmentStatsRequest{ - Base: commonpbutil.NewMsgBase( - commonpbutil.WithMsgType(commonpb.MsgType_ResendSegmentStats), - commonpbutil.WithSourceID(paramtable.GetNodeID()), - ), - }) - if err := VerifyResponse(resp, err); err != nil { - log.Warn("re-collect segment stats call failed", - zap.Int64("DataNode ID", nodeID), zap.Error(err)) - return err - } - log.Info("re-collect segment stats call succeeded", - zap.Int64("DataNode ID", nodeID), - zap.Int64s("segment stat collected", resp.GetSegResent())) - return nil -} - func (c *SessionManager) GetCompactionState() map[int64]*datapb.CompactionStateResult { wg := sync.WaitGroup{} ctx := context.Background() @@ -300,7 +275,28 @@ func (c *SessionManager) GetCompactionState() map[int64]*datapb.CompactionStateR return rst } -func (c *SessionManager) getClient(ctx context.Context, nodeID int64) (types.DataNode, error) { +func (c *SessionManager) FlushChannels(ctx context.Context, nodeID int64, req *datapb.FlushChannelsRequest) error { + log := log.Ctx(ctx).With(zap.Int64("nodeID", nodeID), + zap.Time("flushTs", tsoutil.PhysicalTime(req.GetFlushTs())), + zap.Strings("channels", req.GetChannels())) + cli, err := c.getClient(ctx, nodeID) + if err != nil { + log.Warn("failed to get client", zap.Error(err)) + return err + } + + log.Info("SessionManager.FlushChannels start") + resp, err := cli.FlushChannels(ctx, req) + err = VerifyResponse(resp, err) + if err != nil { + log.Warn("SessionManager.FlushChannels failed", zap.Error(err)) + return err + } + log.Info("SessionManager.FlushChannels successfully") + return nil +} + +func (c 
*SessionManager) getClient(ctx context.Context, nodeID int64) (types.DataNodeClient, error) { c.sessions.RLock() session, ok := c.sessions.data[nodeID] c.sessions.RUnlock() diff --git a/internal/datacoord/util.go b/internal/datacoord/util.go index 8fcb82f43faa3..37bafea790ff4 100644 --- a/internal/datacoord/util.go +++ b/internal/datacoord/util.go @@ -22,7 +22,6 @@ import ( "strings" "time" - "github.com/cockroachdb/errors" "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" @@ -30,6 +29,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/merr" ) // Response response interface for verification @@ -51,32 +51,16 @@ func VerifyResponse(response interface{}, err error) error { if resp.GetStatus() == nil { return errNilStatusResponse } - if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { - return errors.New(resp.GetStatus().GetReason()) - } + return merr.Error(resp.GetStatus()) + case *commonpb.Status: if resp == nil { return errNilResponse } - if resp.ErrorCode != commonpb.ErrorCode_Success { - return errors.New(resp.GetReason()) - } + return merr.Error(resp) default: return errUnknownResponseType } - return nil -} - -// failResponse sets status to failed with unexpected error and reason. -func failResponse(status *commonpb.Status, reason string) { - status.ErrorCode = commonpb.ErrorCode_UnexpectedError - status.Reason = reason -} - -// failResponseWithCode sets status to failed with error code and reason. 
-func failResponseWithCode(status *commonpb.Status, errCode commonpb.ErrorCode, reason string) { - status.ErrorCode = errCode - status.Reason = reason } func FilterInIndexedSegments(handler Handler, mt *meta, segments ...*SegmentInfo) []*SegmentInfo { @@ -104,7 +88,8 @@ func FilterInIndexedSegments(handler Handler, mt *meta, segments ...*SegmentInfo } for _, field := range coll.Schema.GetFields() { if field.GetDataType() == schemapb.DataType_BinaryVector || - field.GetDataType() == schemapb.DataType_FloatVector { + field.GetDataType() == schemapb.DataType_FloatVector || + field.GetDataType() == schemapb.DataType_Float16Vector { vecFieldID[collection] = field.GetFieldID() break } diff --git a/internal/datacoord/util_test.go b/internal/datacoord/util_test.go index 16b59fcfa50dc..0b9b564a5d303 100644 --- a/internal/datacoord/util_test.go +++ b/internal/datacoord/util_test.go @@ -110,7 +110,7 @@ func (suite *UtilSuite) TestVerifyResponse() { for _, c := range cases { r := VerifyResponse(c.resp, c.err) if c.equalValue { - suite.EqualValues(c.expected.Error(), r.Error()) + suite.Contains(r.Error(), c.expected.Error()) } else { suite.Equal(c.expected, r) } diff --git a/internal/datanode/allocator/allocator.go b/internal/datanode/allocator/allocator.go index 0ec6a9c5a967c..df3408add79bb 100644 --- a/internal/datanode/allocator/allocator.go +++ b/internal/datanode/allocator/allocator.go @@ -45,7 +45,7 @@ type Impl struct { *gAllocator.IDAllocator } -func New(ctx context.Context, rootCoord types.RootCoord, peerID UniqueID) (Allocator, error) { +func New(ctx context.Context, rootCoord types.RootCoordClient, peerID UniqueID) (Allocator, error) { idAlloc, err := gAllocator.NewIDAllocator(ctx, rootCoord, peerID) if err != nil { return nil, err @@ -58,7 +58,6 @@ func (a *Impl) GetIDAlloactor() *gAllocator.IDAllocator { } func (a *Impl) GetGenerator(count int, done <-chan struct{}) (<-chan UniqueID, error) { - idStart, _, err := a.Alloc(uint32(count)) if err != nil { return nil, 
err diff --git a/internal/datanode/allocator/allocator_test.go b/internal/datanode/allocator/allocator_test.go index a78414d628aee..63d48690c962a 100644 --- a/internal/datanode/allocator/allocator_test.go +++ b/internal/datanode/allocator/allocator_test.go @@ -22,10 +22,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/grpc" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/util/merr" ) func TestGetGenerator(t *testing.T) { @@ -77,16 +78,15 @@ func TestGetGenerator(t *testing.T) { } type RootCoordFactory struct { - types.RootCoord + types.RootCoordClient ID UniqueID } -func (m *RootCoordFactory) AllocID(ctx context.Context, in *rootcoordpb.AllocIDRequest) (*rootcoordpb.AllocIDResponse, error) { +func (m *RootCoordFactory) AllocID(ctx context.Context, in *rootcoordpb.AllocIDRequest, opts ...grpc.CallOption) (*rootcoordpb.AllocIDResponse, error) { resp := &rootcoordpb.AllocIDResponse{ - ID: m.ID, - Count: in.GetCount(), - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }} + ID: m.ID, + Count: in.GetCount(), + Status: merr.Success(), + } return resp, nil } diff --git a/internal/datanode/binlog_io.go b/internal/datanode/binlog_io.go index eede6ada63f15..aa92e6afffe60 100644 --- a/internal/datanode/binlog_io.go +++ b/internal/datanode/binlog_io.go @@ -23,6 +23,7 @@ import ( "time" "github.com/cockroachdb/errors" + "go.uber.org/zap" "github.com/milvus-io/milvus/internal/datanode/allocator" "github.com/milvus-io/milvus/internal/proto/datapb" @@ -30,9 +31,8 @@ import ( "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/conc" "github.com/milvus-io/milvus/pkg/util/metautil" - "go.uber.org/zap" - "golang.org/x/sync/errgroup" ) var ( @@ 
-64,70 +64,89 @@ type binlogIO struct { allocator.Allocator } -var _ downloader = (*binlogIO)(nil) -var _ uploader = (*binlogIO)(nil) +var ( + _ downloader = (*binlogIO)(nil) + _ uploader = (*binlogIO)(nil) +) func (b *binlogIO) download(ctx context.Context, paths []string) ([]*Blob, error) { - var ( - err = errStart - vs [][]byte - ) - log.Debug("down load", zap.Strings("path", paths)) - g, gCtx := errgroup.WithContext(ctx) - g.Go(func() error { - for err != nil { - select { - case <-gCtx.Done(): - log.Warn("ctx done when downloading kvs from blob storage", zap.Strings("paths", paths)) - return errDownloadFromBlobStorage - - default: - if err != errStart { - log.Warn("downloading failed, retry in 50ms", zap.Strings("paths", paths)) - time.Sleep(50 * time.Millisecond) + resp := make([]*Blob, len(paths)) + if len(paths) == 0 { + return resp, nil + } + futures := make([]*conc.Future[any], len(paths)) + for i, path := range paths { + localPath := path + future := getMultiReadPool().Submit(func() (any, error) { + var vs []byte + err := errStart + for err != nil { + select { + case <-ctx.Done(): + log.Warn("ctx done when downloading kvs from blob storage", zap.Strings("paths", paths)) + return nil, errDownloadFromBlobStorage + default: + if err != errStart { + time.Sleep(50 * time.Millisecond) + } + vs, err = b.Read(ctx, localPath) } - vs, err = b.MultiRead(ctx, paths) } - } - return nil - }) - - if err := g.Wait(); err != nil { - return nil, err + return vs, nil + }) + futures[i] = future } - rst := make([]*Blob, len(vs)) - for i := range rst { - rst[i] = &Blob{Value: vs[i]} + for i := range futures { + if !futures[i].OK() { + return nil, futures[i].Err() + } + resp[i] = &Blob{Value: futures[i].Value().([]byte)} } - return rst, nil + return resp, nil } func (b *binlogIO) uploadSegmentFiles( ctx context.Context, CollectionID UniqueID, segID UniqueID, - kvs map[string][]byte) error { - var err = errStart - for err != nil { - select { - case <-ctx.Done(): - log.Warn("ctx 
done when saving kvs to blob storage", - zap.Int64("collectionID", CollectionID), - zap.Int64("segmentID", segID), - zap.Int("number of kvs", len(kvs))) - return errUploadToBlobStorage - default: - if err != errStart { - log.Warn("save binlog failed, retry in 50ms", - zap.Int64("collectionID", CollectionID), - zap.Int64("segmentID", segID)) - time.Sleep(50 * time.Millisecond) + kvs map[string][]byte, +) error { + log.Debug("update", zap.Int64("collectionID", CollectionID), zap.Int64("segmentID", segID)) + if len(kvs) == 0 { + return nil + } + futures := make([]*conc.Future[any], 0) + for key, val := range kvs { + localPath := key + localVal := val + future := getMultiReadPool().Submit(func() (any, error) { + err := errStart + for err != nil { + select { + case <-ctx.Done(): + log.Warn("ctx done when saving kvs to blob storage", + zap.Int64("collectionID", CollectionID), + zap.Int64("segmentID", segID), + zap.Int("number of kvs", len(kvs))) + return nil, errUploadToBlobStorage + default: + if err != errStart { + time.Sleep(50 * time.Millisecond) + } + err = b.Write(ctx, localPath, localVal) + } } - err = b.MultiWrite(ctx, kvs) - } + return nil, nil + }) + futures = append(futures, future) + } + + err := conc.AwaitAll(futures...) 
+ if err != nil { + return err } return nil } @@ -225,7 +244,8 @@ func (b *binlogIO) uploadStatsLog( iData *InsertData, stats *storage.PrimaryKeyStats, totRows int64, - meta *etcdpb.CollectionMeta) (map[UniqueID]*datapb.FieldBinlog, map[UniqueID]*datapb.FieldBinlog, error) { + meta *etcdpb.CollectionMeta, +) (map[UniqueID]*datapb.FieldBinlog, map[UniqueID]*datapb.FieldBinlog, error) { var inPaths map[int64]*datapb.FieldBinlog var err error @@ -261,8 +281,8 @@ func (b *binlogIO) uploadInsertLog( segID UniqueID, partID UniqueID, iData *InsertData, - meta *etcdpb.CollectionMeta) (map[UniqueID]*datapb.FieldBinlog, error) { - + meta *etcdpb.CollectionMeta, +) (map[UniqueID]*datapb.FieldBinlog, error) { iCodec := storage.NewInsertCodecWithSchema(meta) kvs := make(map[string][]byte) @@ -292,7 +312,8 @@ func (b *binlogIO) uploadDeltaLog( segID UniqueID, partID UniqueID, dData *DeleteData, - meta *etcdpb.CollectionMeta) ([]*datapb.FieldBinlog, error) { + meta *etcdpb.CollectionMeta, +) ([]*datapb.FieldBinlog, error) { var ( deltaInfo = make([]*datapb.FieldBinlog, 0) kvs = make(map[string][]byte) diff --git a/internal/datanode/binlog_io_test.go b/internal/datanode/binlog_io_test.go index a51f5f481a386..3df83685ef134 100644 --- a/internal/datanode/binlog_io_test.go +++ b/internal/datanode/binlog_io_test.go @@ -24,16 +24,16 @@ import ( "time" "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/datanode/allocator" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" - "go.uber.org/zap" ) var binlogTestDir = "/tmp/milvus_test/test_binlog_io" @@ -89,7 +89,7 @@ func 
TestBinlogIOInterfaceMethods(t *testing.T) { ctx, cancel := context.WithCancel(test.inctx) cancel() - _, err := b.download(ctx, nil) + _, err := b.download(ctx, []string{"test"}) assert.EqualError(t, err, errDownloadFromBlobStorage.Error()) } }) @@ -97,7 +97,7 @@ func TestBinlogIOInterfaceMethods(t *testing.T) { }) t.Run("Test download twice", func(t *testing.T) { - mkc := &mockCm{errMultiLoad: true} + mkc := &mockCm{errRead: true} alloc := allocator.NewMockAllocator(t) b := &binlogIO{mkc, alloc} @@ -145,7 +145,7 @@ func TestBinlogIOInterfaceMethods(t *testing.T) { }) t.Run("upload failed", func(t *testing.T) { - mkc := &mockCm{errMultiLoad: true, errMultiSave: true} + mkc := &mockCm{errRead: true, errSave: true} alloc := allocator.NewMockAllocator(t) b := binlogIO{mkc, alloc} @@ -201,7 +201,6 @@ func TestBinlogIOInnerMethods(t *testing.T) { for _, test := range tests { t.Run(test.description, func(t *testing.T) { if test.isvalid { - k, v, err := b.genDeltaBlobs(&DeleteData{ Pks: []primaryKey{test.deletepk}, Tss: []uint64{test.ts}, @@ -237,7 +236,6 @@ func TestBinlogIOInnerMethods(t *testing.T) { assert.Error(t, err) assert.Empty(t, k) assert.Empty(t, v) - }) }) @@ -360,8 +358,8 @@ func TestBinlogIOInnerMethods(t *testing.T) { type mockCm struct { storage.ChunkManager - errMultiLoad bool - errMultiSave bool + errRead bool + errSave bool MultiReadReturn [][]byte ReadReturn []byte } @@ -373,25 +371,24 @@ func (mk *mockCm) RootPath() string { } func (mk *mockCm) Write(ctx context.Context, filePath string, content []byte) error { + if mk.errSave { + return errors.New("mockKv save error") + } return nil } func (mk *mockCm) MultiWrite(ctx context.Context, contents map[string][]byte) error { - if mk.errMultiSave { - return errors.New("mockKv multisave error") - } return nil } func (mk *mockCm) Read(ctx context.Context, filePath string) ([]byte, error) { + if mk.errRead { + return nil, errors.New("mockKv read error") + } return mk.ReadReturn, nil } func (mk *mockCm) 
MultiRead(ctx context.Context, filePaths []string) ([][]byte, error) { - if mk.errMultiLoad { - return nil, errors.New("mockKv multiload error") - } - if mk.MultiReadReturn != nil { return mk.MultiReadReturn, nil } diff --git a/internal/datanode/broker/broker.go b/internal/datanode/broker/broker.go new file mode 100644 index 0000000000000..456bce8782fa4 --- /dev/null +++ b/internal/datanode/broker/broker.go @@ -0,0 +1,54 @@ +package broker + +import ( + "context" + + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +// Broker is the interface for datanode to interact with other components. +type Broker interface { + RootCoord + DataCoord +} + +type coordBroker struct { + *rootCoordBroker + *dataCoordBroker +} + +func NewCoordBroker(rc types.RootCoordClient, dc types.DataCoordClient) Broker { + return &coordBroker{ + rootCoordBroker: &rootCoordBroker{ + client: rc, + }, + dataCoordBroker: &dataCoordBroker{ + client: dc, + }, + } +} + +// RootCoord is the interface wraps `RootCoord` grpc call +type RootCoord interface { + DescribeCollection(ctx context.Context, collectionID typeutil.UniqueID, ts typeutil.Timestamp) (*milvuspb.DescribeCollectionResponse, error) + ShowPartitions(ctx context.Context, dbName, collectionName string) (map[string]int64, error) + ReportImport(ctx context.Context, req *rootcoordpb.ImportResult) error + AllocTimestamp(ctx context.Context, num uint32) (ts uint64, count uint32, err error) +} + +// DataCoord is the interface wraps `DataCoord` grpc call +type DataCoord interface { + AssignSegmentID(ctx context.Context, reqs ...*datapb.SegmentIDRequest) ([]typeutil.UniqueID, error) + ReportTimeTick(ctx context.Context, msgs []*msgpb.DataNodeTtMsg) error + GetSegmentInfo(ctx 
context.Context, segmentIDs []int64) ([]*datapb.SegmentInfo, error) + UpdateChannelCheckpoint(ctx context.Context, channelName string, cp *msgpb.MsgPosition) error + SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPathsRequest) error + DropVirtualChannel(ctx context.Context, req *datapb.DropVirtualChannelRequest) error + UpdateSegmentStatistics(ctx context.Context, req *datapb.UpdateSegmentStatisticsRequest) error + SaveImportSegment(ctx context.Context, req *datapb.SaveImportSegmentRequest) error +} diff --git a/internal/datanode/broker/datacoord.go b/internal/datanode/broker/datacoord.go new file mode 100644 index 0000000000000..7814ccb8282fa --- /dev/null +++ b/internal/datanode/broker/datacoord.go @@ -0,0 +1,157 @@ +package broker + +import ( + "context" + + "github.com/samber/lo" + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/commonpbutil" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/tsoutil" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type dataCoordBroker struct { + client types.DataCoordClient +} + +func (dc *dataCoordBroker) AssignSegmentID(ctx context.Context, reqs ...*datapb.SegmentIDRequest) ([]typeutil.UniqueID, error) { + req := &datapb.AssignSegmentIDRequest{ + NodeID: paramtable.GetNodeID(), + PeerRole: typeutil.ProxyRole, + SegmentIDRequests: reqs, + } + + resp, err := dc.client.AssignSegmentID(ctx, req) + + if err := merr.CheckRPCCall(resp, err); err != nil { + log.Warn("failed to call datacoord AssignSegmentID", zap.Error(err)) + return nil, err + } + + return lo.Map(resp.GetSegIDAssignments(), func(result *datapb.SegmentIDAssignment, _ int) typeutil.UniqueID { + return 
result.GetSegID() + }), nil +} + +func (dc *dataCoordBroker) ReportTimeTick(ctx context.Context, msgs []*msgpb.DataNodeTtMsg) error { + log := log.Ctx(ctx) + + req := &datapb.ReportDataNodeTtMsgsRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_DataNodeTt), + commonpbutil.WithSourceID(paramtable.GetNodeID()), + ), + Msgs: msgs, + } + + resp, err := dc.client.ReportDataNodeTtMsgs(ctx, req) + if err := merr.CheckRPCCall(resp, err); err != nil { + log.Warn("failed to report datanodeTtMsgs", zap.Error(err)) + return err + } + return nil +} + +func (dc *dataCoordBroker) GetSegmentInfo(ctx context.Context, segmentIDs []int64) ([]*datapb.SegmentInfo, error) { + log := log.Ctx(ctx).With( + zap.Int64s("segmentIDs", segmentIDs), + ) + + infoResp, err := dc.client.GetSegmentInfo(ctx, &datapb.GetSegmentInfoRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_SegmentInfo), + commonpbutil.WithMsgID(0), + commonpbutil.WithSourceID(paramtable.GetNodeID()), + ), + SegmentIDs: segmentIDs, + IncludeUnHealthy: true, + }) + if err := merr.CheckRPCCall(infoResp, err); err != nil { + log.Warn("Fail to get SegmentInfo by ids from datacoord", zap.Error(err)) + return nil, err + } + + return infoResp.Infos, nil +} + +func (dc *dataCoordBroker) UpdateChannelCheckpoint(ctx context.Context, channelName string, cp *msgpb.MsgPosition) error { + channelCPTs, _ := tsoutil.ParseTS(cp.GetTimestamp()) + log := log.Ctx(ctx).With( + zap.String("channelName", channelName), + zap.Time("channelCheckpointTime", channelCPTs), + ) + + req := &datapb.UpdateChannelCheckpointRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithSourceID(paramtable.GetNodeID()), + ), + VChannel: channelName, + Position: cp, + } + + resp, err := dc.client.UpdateChannelCheckpoint(ctx, req) + if err := merr.CheckRPCCall(resp, err); err != nil { + log.Warn("failed to update channel checkpoint", zap.Error(err)) + return err + } + return nil +} + +func (dc 
*dataCoordBroker) SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPathsRequest) error { + log := log.Ctx(ctx) + + resp, err := dc.client.SaveBinlogPaths(ctx, req) + if err := merr.CheckRPCCall(resp, err); err != nil { + log.Warn("failed to SaveBinlogPaths", zap.Error(err)) + return err + } + + return nil +} + +func (dc *dataCoordBroker) DropVirtualChannel(ctx context.Context, req *datapb.DropVirtualChannelRequest) error { + log := log.Ctx(ctx) + + resp, err := dc.client.DropVirtualChannel(ctx, req) + if err := merr.CheckRPCCall(resp, err); err != nil { + if resp.GetStatus().GetErrorCode() == commonpb.ErrorCode_MetaFailed { + err = merr.WrapErrChannelNotFound(req.GetChannelName()) + } + log.Warn("failed to SaveBinlogPaths", zap.Error(err)) + return err + } + + return nil +} + +func (dc *dataCoordBroker) UpdateSegmentStatistics(ctx context.Context, req *datapb.UpdateSegmentStatisticsRequest) error { + log := log.Ctx(ctx) + + resp, err := dc.client.UpdateSegmentStatistics(ctx, req) + if err := merr.CheckRPCCall(resp, err); err != nil { + log.Warn("failed to UpdateSegmentStatistics", zap.Error(err)) + return err + } + + return nil +} + +func (dc *dataCoordBroker) SaveImportSegment(ctx context.Context, req *datapb.SaveImportSegmentRequest) error { + log := log.Ctx(ctx) + + resp, err := dc.client.SaveImportSegment(ctx, req) + if err := merr.CheckRPCCall(resp, err); err != nil { + log.Warn("failed to UpdateSegmentStatistics", zap.Error(err)) + return err + } + + return nil +} diff --git a/internal/datanode/broker/datacoord_test.go b/internal/datanode/broker/datacoord_test.go new file mode 100644 index 0000000000000..bb5dd5f4a8c82 --- /dev/null +++ b/internal/datanode/broker/datacoord_test.go @@ -0,0 +1,375 @@ +package broker + +import ( + "context" + "testing" + "time" + + "github.com/cockroachdb/errors" + "github.com/samber/lo" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "google.golang.org/grpc" + + 
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/tsoutil" +) + +type dataCoordSuite struct { + suite.Suite + + dc *mocks.MockDataCoordClient + broker Broker +} + +func (s *dataCoordSuite) SetupSuite() { + paramtable.Init() +} + +func (s *dataCoordSuite) SetupTest() { + s.dc = mocks.NewMockDataCoordClient(s.T()) + s.broker = NewCoordBroker(nil, s.dc) +} + +func (s *dataCoordSuite) resetMock() { + s.dc.AssertExpectations(s.T()) + s.dc.ExpectedCalls = nil +} + +func (s *dataCoordSuite) TestAssignSegmentID() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + reqs := []*datapb.SegmentIDRequest{ + {CollectionID: 100, Count: 1000}, + {CollectionID: 100, Count: 2000}, + } + + s.Run("normal_case", func() { + s.dc.EXPECT().AssignSegmentID(mock.Anything, mock.Anything). + Return(&datapb.AssignSegmentIDResponse{ + Status: merr.Status(nil), + SegIDAssignments: lo.Map(reqs, func(req *datapb.SegmentIDRequest, _ int) *datapb.SegmentIDAssignment { + return &datapb.SegmentIDAssignment{ + Status: merr.Status(nil), + SegID: 10001, + Count: req.GetCount(), + } + }), + }, nil) + + segmentIDs, err := s.broker.AssignSegmentID(ctx, reqs...) + s.NoError(err) + s.Equal(len(segmentIDs), len(reqs)) + s.resetMock() + }) + + s.Run("datacoord_return_error", func() { + s.dc.EXPECT().AssignSegmentID(mock.Anything, mock.Anything). + Return(nil, errors.New("mock")) + + _, err := s.broker.AssignSegmentID(ctx, reqs...) + s.Error(err) + s.resetMock() + }) + + s.Run("datacoord_return_failure_status", func() { + s.dc.EXPECT().AssignSegmentID(mock.Anything, mock.Anything). 
+ Return(&datapb.AssignSegmentIDResponse{ + Status: merr.Status(errors.New("mock")), + }, nil) + + _, err := s.broker.AssignSegmentID(ctx, reqs...) + s.Error(err) + s.resetMock() + }) +} + +func (s *dataCoordSuite) TestReportTimeTick() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + msgs := []*msgpb.DataNodeTtMsg{ + {Timestamp: 1000, ChannelName: "dml_0"}, + {Timestamp: 2000, ChannelName: "dml_1"}, + } + + s.Run("normal_case", func() { + s.dc.EXPECT().ReportDataNodeTtMsgs(mock.Anything, mock.Anything). + Run(func(_ context.Context, req *datapb.ReportDataNodeTtMsgsRequest, _ ...grpc.CallOption) { + s.Equal(msgs, req.GetMsgs()) + }). + Return(merr.Status(nil), nil) + + err := s.broker.ReportTimeTick(ctx, msgs) + s.NoError(err) + s.resetMock() + }) + + s.Run("datacoord_return_error", func() { + s.dc.EXPECT().ReportDataNodeTtMsgs(mock.Anything, mock.Anything). + Return(merr.Status(errors.New("mock")), nil) + + err := s.broker.ReportTimeTick(ctx, msgs) + s.Error(err) + s.resetMock() + }) +} + +func (s *dataCoordSuite) TestGetSegmentInfo() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + segmentIDs := []int64{1, 2, 3} + + s.Run("normal_case", func() { + s.dc.EXPECT().GetSegmentInfo(mock.Anything, mock.Anything). + Run(func(_ context.Context, req *datapb.GetSegmentInfoRequest, _ ...grpc.CallOption) { + s.ElementsMatch(segmentIDs, req.GetSegmentIDs()) + s.True(req.GetIncludeUnHealthy()) + }). + Return(&datapb.GetSegmentInfoResponse{ + Status: merr.Status(nil), + Infos: lo.Map(segmentIDs, func(id int64, _ int) *datapb.SegmentInfo { + return &datapb.SegmentInfo{ID: id} + }), + }, nil) + infos, err := s.broker.GetSegmentInfo(ctx, segmentIDs) + s.NoError(err) + s.ElementsMatch(segmentIDs, lo.Map(infos, func(info *datapb.SegmentInfo, _ int) int64 { return info.GetID() })) + s.resetMock() + }) + + s.Run("datacoord_return_error", func() { + s.dc.EXPECT().GetSegmentInfo(mock.Anything, mock.Anything). 
+ Return(nil, errors.New("mock")) + _, err := s.broker.GetSegmentInfo(ctx, segmentIDs) + s.Error(err) + s.resetMock() + }) + + s.Run("datacoord_return_failure_status", func() { + s.dc.EXPECT().GetSegmentInfo(mock.Anything, mock.Anything). + Return(&datapb.GetSegmentInfoResponse{ + Status: merr.Status(errors.New("mock")), + }, nil) + _, err := s.broker.GetSegmentInfo(ctx, segmentIDs) + s.Error(err) + s.resetMock() + }) +} + +func (s *dataCoordSuite) TestUpdateChannelCheckpoint() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + channelName := "dml_0" + checkpoint := &msgpb.MsgPosition{ + ChannelName: channelName, + MsgID: []byte{1, 2, 3}, + Timestamp: tsoutil.ComposeTSByTime(time.Now(), 0), + } + + s.Run("normal_case", func() { + s.dc.EXPECT().UpdateChannelCheckpoint(mock.Anything, mock.Anything). + Run(func(_ context.Context, req *datapb.UpdateChannelCheckpointRequest, _ ...grpc.CallOption) { + s.Equal(channelName, req.GetVChannel()) + cp := req.GetPosition() + s.Equal(checkpoint.MsgID, cp.GetMsgID()) + s.Equal(checkpoint.ChannelName, cp.GetChannelName()) + s.Equal(checkpoint.Timestamp, cp.GetTimestamp()) + }). + Return(merr.Status(nil), nil) + + err := s.broker.UpdateChannelCheckpoint(ctx, channelName, checkpoint) + s.NoError(err) + s.resetMock() + }) + + s.Run("datacoord_return_error", func() { + s.dc.EXPECT().UpdateChannelCheckpoint(mock.Anything, mock.Anything). + Return(nil, errors.New("mock")) + + err := s.broker.UpdateChannelCheckpoint(ctx, channelName, checkpoint) + s.Error(err) + s.resetMock() + }) + + s.Run("datacoord_return_failure_status", func() { + s.dc.EXPECT().UpdateChannelCheckpoint(mock.Anything, mock.Anything). 
+ Return(merr.Status(errors.New("mock")), nil) + + err := s.broker.UpdateChannelCheckpoint(ctx, channelName, checkpoint) + s.Error(err) + s.resetMock() + }) +} + +func (s *dataCoordSuite) TestSaveBinlogPaths() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + req := &datapb.SaveBinlogPathsRequest{ + Channel: "dml_0", + } + + s.Run("normal_case", func() { + s.dc.EXPECT().SaveBinlogPaths(mock.Anything, mock.Anything). + Run(func(_ context.Context, req *datapb.SaveBinlogPathsRequest, _ ...grpc.CallOption) { + s.Equal("dml_0", req.GetChannel()) + }). + Return(merr.Status(nil), nil) + err := s.broker.SaveBinlogPaths(ctx, req) + s.NoError(err) + s.resetMock() + }) + + s.Run("datacoord_return_error", func() { + s.dc.EXPECT().SaveBinlogPaths(mock.Anything, mock.Anything). + Return(nil, errors.New("mock")) + err := s.broker.SaveBinlogPaths(ctx, req) + s.Error(err) + s.resetMock() + }) + + s.Run("datacoord_return_failure_status", func() { + s.dc.EXPECT().SaveBinlogPaths(mock.Anything, mock.Anything). + Return(merr.Status(errors.New("mock")), nil) + err := s.broker.SaveBinlogPaths(ctx, req) + s.Error(err) + s.resetMock() + }) +} + +func (s *dataCoordSuite) TestDropVirtualChannel() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + req := &datapb.DropVirtualChannelRequest{ + ChannelName: "dml_0", + } + + s.Run("normal_case", func() { + s.dc.EXPECT().DropVirtualChannel(mock.Anything, mock.Anything). + Run(func(_ context.Context, req *datapb.DropVirtualChannelRequest, _ ...grpc.CallOption) { + s.Equal("dml_0", req.GetChannelName()) + }). + Return(&datapb.DropVirtualChannelResponse{Status: merr.Status(nil)}, nil) + err := s.broker.DropVirtualChannel(ctx, req) + s.NoError(err) + s.resetMock() + }) + + s.Run("datacoord_return_error", func() { + s.dc.EXPECT().DropVirtualChannel(mock.Anything, mock.Anything). 
+ Return(nil, errors.New("mock")) + err := s.broker.DropVirtualChannel(ctx, req) + s.Error(err) + s.resetMock() + }) + + s.Run("datacoord_return_failure_status", func() { + s.dc.EXPECT().DropVirtualChannel(mock.Anything, mock.Anything). + Return(&datapb.DropVirtualChannelResponse{Status: merr.Status(errors.New("mock"))}, nil) + err := s.broker.DropVirtualChannel(ctx, req) + s.Error(err) + s.resetMock() + }) + + s.Run("datacoord_return_legacy_MetaFailed", func() { + s.dc.EXPECT().DropVirtualChannel(mock.Anything, mock.Anything). + Return(&datapb.DropVirtualChannelResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_MetaFailed}}, nil) + err := s.broker.DropVirtualChannel(ctx, req) + s.Error(err) + s.ErrorIs(err, merr.ErrChannelNotFound) + s.resetMock() + }) +} + +func (s *dataCoordSuite) TestUpdateSegmentStatistics() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + req := &datapb.UpdateSegmentStatisticsRequest{ + Stats: []*commonpb.SegmentStats{ + {}, {}, {}, + }, + } + + s.Run("normal_case", func() { + s.dc.EXPECT().UpdateSegmentStatistics(mock.Anything, mock.Anything). + Run(func(_ context.Context, r *datapb.UpdateSegmentStatisticsRequest, _ ...grpc.CallOption) { + s.Equal(len(req.GetStats()), len(r.GetStats())) + }). + Return(merr.Status(nil), nil) + err := s.broker.UpdateSegmentStatistics(ctx, req) + s.NoError(err) + s.resetMock() + }) + + s.Run("datacoord_return_failure_status", func() { + s.dc.EXPECT().UpdateSegmentStatistics(mock.Anything, mock.Anything). + Return(nil, errors.New("mock")) + err := s.broker.UpdateSegmentStatistics(ctx, req) + s.Error(err) + s.resetMock() + }) + + s.Run("datacoord_return_failure_status", func() { + s.dc.EXPECT().UpdateSegmentStatistics(mock.Anything, mock.Anything). 
+ Return(merr.Status(errors.New("mock")), nil) + err := s.broker.UpdateSegmentStatistics(ctx, req) + s.Error(err) + s.resetMock() + }) +} + +func (s *dataCoordSuite) TestSaveImportSegment() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + segmentID := int64(1001) + collectionID := int64(100) + + req := &datapb.SaveImportSegmentRequest{ + SegmentId: segmentID, + CollectionId: collectionID, + } + + s.Run("normal_case", func() { + s.dc.EXPECT().SaveImportSegment(mock.Anything, mock.Anything). + Run(func(_ context.Context, r *datapb.SaveImportSegmentRequest, _ ...grpc.CallOption) { + s.Equal(collectionID, req.GetCollectionId()) + s.Equal(segmentID, req.GetSegmentId()) + }). + Return(merr.Status(nil), nil) + err := s.broker.SaveImportSegment(ctx, req) + s.NoError(err) + s.resetMock() + }) + + s.Run("datacoord_return_failure_status", func() { + s.dc.EXPECT().SaveImportSegment(mock.Anything, mock.Anything). + Return(nil, errors.New("mock")) + err := s.broker.SaveImportSegment(ctx, req) + s.Error(err) + s.resetMock() + }) + + s.Run("datacoord_return_failure_status", func() { + s.dc.EXPECT().SaveImportSegment(mock.Anything, mock.Anything). + Return(merr.Status(errors.New("mock")), nil) + err := s.broker.SaveImportSegment(ctx, req) + s.Error(err) + s.resetMock() + }) +} + +func TestDataCoordBroker(t *testing.T) { + suite.Run(t, new(dataCoordSuite)) +} diff --git a/internal/datanode/broker/mock_broker.go b/internal/datanode/broker/mock_broker.go new file mode 100644 index 0000000000000..15b86a99545d7 --- /dev/null +++ b/internal/datanode/broker/mock_broker.go @@ -0,0 +1,641 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. 
+ +package broker + +import ( + context "context" + + milvuspb "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + datapb "github.com/milvus-io/milvus/internal/proto/datapb" + + mock "github.com/stretchr/testify/mock" + + msgpb "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + + rootcoordpb "github.com/milvus-io/milvus/internal/proto/rootcoordpb" +) + +// MockBroker is an autogenerated mock type for the Broker type +type MockBroker struct { + mock.Mock +} + +type MockBroker_Expecter struct { + mock *mock.Mock +} + +func (_m *MockBroker) EXPECT() *MockBroker_Expecter { + return &MockBroker_Expecter{mock: &_m.Mock} +} + +// AllocTimestamp provides a mock function with given fields: ctx, num +func (_m *MockBroker) AllocTimestamp(ctx context.Context, num uint32) (uint64, uint32, error) { + ret := _m.Called(ctx, num) + + var r0 uint64 + var r1 uint32 + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, uint32) (uint64, uint32, error)); ok { + return rf(ctx, num) + } + if rf, ok := ret.Get(0).(func(context.Context, uint32) uint64); ok { + r0 = rf(ctx, num) + } else { + r0 = ret.Get(0).(uint64) + } + + if rf, ok := ret.Get(1).(func(context.Context, uint32) uint32); ok { + r1 = rf(ctx, num) + } else { + r1 = ret.Get(1).(uint32) + } + + if rf, ok := ret.Get(2).(func(context.Context, uint32) error); ok { + r2 = rf(ctx, num) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockBroker_AllocTimestamp_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AllocTimestamp' +type MockBroker_AllocTimestamp_Call struct { + *mock.Call +} + +// AllocTimestamp is a helper method to define mock.On call +// - ctx context.Context +// - num uint32 +func (_e *MockBroker_Expecter) AllocTimestamp(ctx interface{}, num interface{}) *MockBroker_AllocTimestamp_Call { + return &MockBroker_AllocTimestamp_Call{Call: _e.mock.On("AllocTimestamp", ctx, num)} +} + +func (_c *MockBroker_AllocTimestamp_Call) Run(run func(ctx 
context.Context, num uint32)) *MockBroker_AllocTimestamp_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(uint32)) + }) + return _c +} + +func (_c *MockBroker_AllocTimestamp_Call) Return(ts uint64, count uint32, err error) *MockBroker_AllocTimestamp_Call { + _c.Call.Return(ts, count, err) + return _c +} + +func (_c *MockBroker_AllocTimestamp_Call) RunAndReturn(run func(context.Context, uint32) (uint64, uint32, error)) *MockBroker_AllocTimestamp_Call { + _c.Call.Return(run) + return _c +} + +// AssignSegmentID provides a mock function with given fields: ctx, reqs +func (_m *MockBroker) AssignSegmentID(ctx context.Context, reqs ...*datapb.SegmentIDRequest) ([]int64, error) { + _va := make([]interface{}, len(reqs)) + for _i := range reqs { + _va[_i] = reqs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 []int64 + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, ...*datapb.SegmentIDRequest) ([]int64, error)); ok { + return rf(ctx, reqs...) + } + if rf, ok := ret.Get(0).(func(context.Context, ...*datapb.SegmentIDRequest) []int64); ok { + r0 = rf(ctx, reqs...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]int64) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, ...*datapb.SegmentIDRequest) error); ok { + r1 = rf(ctx, reqs...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockBroker_AssignSegmentID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AssignSegmentID' +type MockBroker_AssignSegmentID_Call struct { + *mock.Call +} + +// AssignSegmentID is a helper method to define mock.On call +// - ctx context.Context +// - reqs ...*datapb.SegmentIDRequest +func (_e *MockBroker_Expecter) AssignSegmentID(ctx interface{}, reqs ...interface{}) *MockBroker_AssignSegmentID_Call { + return &MockBroker_AssignSegmentID_Call{Call: _e.mock.On("AssignSegmentID", + append([]interface{}{ctx}, reqs...)...)} +} + +func (_c *MockBroker_AssignSegmentID_Call) Run(run func(ctx context.Context, reqs ...*datapb.SegmentIDRequest)) *MockBroker_AssignSegmentID_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]*datapb.SegmentIDRequest, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(*datapb.SegmentIDRequest) + } + } + run(args[0].(context.Context), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockBroker_AssignSegmentID_Call) Return(_a0 []int64, _a1 error) *MockBroker_AssignSegmentID_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockBroker_AssignSegmentID_Call) RunAndReturn(run func(context.Context, ...*datapb.SegmentIDRequest) ([]int64, error)) *MockBroker_AssignSegmentID_Call { + _c.Call.Return(run) + return _c +} + +// DescribeCollection provides a mock function with given fields: ctx, collectionID, ts +func (_m *MockBroker) DescribeCollection(ctx context.Context, collectionID int64, ts uint64) (*milvuspb.DescribeCollectionResponse, error) { + ret := _m.Called(ctx, collectionID, ts) + + var r0 *milvuspb.DescribeCollectionResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, int64, uint64) (*milvuspb.DescribeCollectionResponse, error)); ok { + return rf(ctx, collectionID, ts) + } + if rf, ok := ret.Get(0).(func(context.Context, int64, uint64) *milvuspb.DescribeCollectionResponse); ok { + r0 = rf(ctx, collectionID, ts) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.DescribeCollectionResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, int64, uint64) error); ok { + r1 = rf(ctx, collectionID, ts) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockBroker_DescribeCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DescribeCollection' +type MockBroker_DescribeCollection_Call struct { + *mock.Call +} + +// DescribeCollection is a helper method to define mock.On call +// - ctx context.Context +// - collectionID int64 +// - ts uint64 +func (_e *MockBroker_Expecter) DescribeCollection(ctx interface{}, collectionID interface{}, ts interface{}) *MockBroker_DescribeCollection_Call { + return &MockBroker_DescribeCollection_Call{Call: _e.mock.On("DescribeCollection", ctx, collectionID, ts)} +} + +func (_c *MockBroker_DescribeCollection_Call) Run(run func(ctx context.Context, collectionID int64, ts 
uint64)) *MockBroker_DescribeCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int64), args[2].(uint64)) + }) + return _c +} + +func (_c *MockBroker_DescribeCollection_Call) Return(_a0 *milvuspb.DescribeCollectionResponse, _a1 error) *MockBroker_DescribeCollection_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockBroker_DescribeCollection_Call) RunAndReturn(run func(context.Context, int64, uint64) (*milvuspb.DescribeCollectionResponse, error)) *MockBroker_DescribeCollection_Call { + _c.Call.Return(run) + return _c +} + +// DropVirtualChannel provides a mock function with given fields: ctx, req +func (_m *MockBroker) DropVirtualChannel(ctx context.Context, req *datapb.DropVirtualChannelRequest) error { + ret := _m.Called(ctx, req) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.DropVirtualChannelRequest) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockBroker_DropVirtualChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DropVirtualChannel' +type MockBroker_DropVirtualChannel_Call struct { + *mock.Call +} + +// DropVirtualChannel is a helper method to define mock.On call +// - ctx context.Context +// - req *datapb.DropVirtualChannelRequest +func (_e *MockBroker_Expecter) DropVirtualChannel(ctx interface{}, req interface{}) *MockBroker_DropVirtualChannel_Call { + return &MockBroker_DropVirtualChannel_Call{Call: _e.mock.On("DropVirtualChannel", ctx, req)} +} + +func (_c *MockBroker_DropVirtualChannel_Call) Run(run func(ctx context.Context, req *datapb.DropVirtualChannelRequest)) *MockBroker_DropVirtualChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*datapb.DropVirtualChannelRequest)) + }) + return _c +} + +func (_c *MockBroker_DropVirtualChannel_Call) Return(_a0 error) *MockBroker_DropVirtualChannel_Call { + _c.Call.Return(_a0) + 
return _c +} + +func (_c *MockBroker_DropVirtualChannel_Call) RunAndReturn(run func(context.Context, *datapb.DropVirtualChannelRequest) error) *MockBroker_DropVirtualChannel_Call { + _c.Call.Return(run) + return _c +} + +// GetSegmentInfo provides a mock function with given fields: ctx, segmentIDs +func (_m *MockBroker) GetSegmentInfo(ctx context.Context, segmentIDs []int64) ([]*datapb.SegmentInfo, error) { + ret := _m.Called(ctx, segmentIDs) + + var r0 []*datapb.SegmentInfo + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, []int64) ([]*datapb.SegmentInfo, error)); ok { + return rf(ctx, segmentIDs) + } + if rf, ok := ret.Get(0).(func(context.Context, []int64) []*datapb.SegmentInfo); ok { + r0 = rf(ctx, segmentIDs) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*datapb.SegmentInfo) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, []int64) error); ok { + r1 = rf(ctx, segmentIDs) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockBroker_GetSegmentInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSegmentInfo' +type MockBroker_GetSegmentInfo_Call struct { + *mock.Call +} + +// GetSegmentInfo is a helper method to define mock.On call +// - ctx context.Context +// - segmentIDs []int64 +func (_e *MockBroker_Expecter) GetSegmentInfo(ctx interface{}, segmentIDs interface{}) *MockBroker_GetSegmentInfo_Call { + return &MockBroker_GetSegmentInfo_Call{Call: _e.mock.On("GetSegmentInfo", ctx, segmentIDs)} +} + +func (_c *MockBroker_GetSegmentInfo_Call) Run(run func(ctx context.Context, segmentIDs []int64)) *MockBroker_GetSegmentInfo_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].([]int64)) + }) + return _c +} + +func (_c *MockBroker_GetSegmentInfo_Call) Return(_a0 []*datapb.SegmentInfo, _a1 error) *MockBroker_GetSegmentInfo_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockBroker_GetSegmentInfo_Call) RunAndReturn(run 
func(context.Context, []int64) ([]*datapb.SegmentInfo, error)) *MockBroker_GetSegmentInfo_Call { + _c.Call.Return(run) + return _c +} + +// ReportImport provides a mock function with given fields: ctx, req +func (_m *MockBroker) ReportImport(ctx context.Context, req *rootcoordpb.ImportResult) error { + ret := _m.Called(ctx, req) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.ImportResult) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockBroker_ReportImport_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReportImport' +type MockBroker_ReportImport_Call struct { + *mock.Call +} + +// ReportImport is a helper method to define mock.On call +// - ctx context.Context +// - req *rootcoordpb.ImportResult +func (_e *MockBroker_Expecter) ReportImport(ctx interface{}, req interface{}) *MockBroker_ReportImport_Call { + return &MockBroker_ReportImport_Call{Call: _e.mock.On("ReportImport", ctx, req)} +} + +func (_c *MockBroker_ReportImport_Call) Run(run func(ctx context.Context, req *rootcoordpb.ImportResult)) *MockBroker_ReportImport_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*rootcoordpb.ImportResult)) + }) + return _c +} + +func (_c *MockBroker_ReportImport_Call) Return(_a0 error) *MockBroker_ReportImport_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockBroker_ReportImport_Call) RunAndReturn(run func(context.Context, *rootcoordpb.ImportResult) error) *MockBroker_ReportImport_Call { + _c.Call.Return(run) + return _c +} + +// ReportTimeTick provides a mock function with given fields: ctx, msgs +func (_m *MockBroker) ReportTimeTick(ctx context.Context, msgs []*msgpb.DataNodeTtMsg) error { + ret := _m.Called(ctx, msgs) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, []*msgpb.DataNodeTtMsg) error); ok { + r0 = rf(ctx, msgs) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// 
MockBroker_ReportTimeTick_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReportTimeTick' +type MockBroker_ReportTimeTick_Call struct { + *mock.Call +} + +// ReportTimeTick is a helper method to define mock.On call +// - ctx context.Context +// - msgs []*msgpb.DataNodeTtMsg +func (_e *MockBroker_Expecter) ReportTimeTick(ctx interface{}, msgs interface{}) *MockBroker_ReportTimeTick_Call { + return &MockBroker_ReportTimeTick_Call{Call: _e.mock.On("ReportTimeTick", ctx, msgs)} +} + +func (_c *MockBroker_ReportTimeTick_Call) Run(run func(ctx context.Context, msgs []*msgpb.DataNodeTtMsg)) *MockBroker_ReportTimeTick_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].([]*msgpb.DataNodeTtMsg)) + }) + return _c +} + +func (_c *MockBroker_ReportTimeTick_Call) Return(_a0 error) *MockBroker_ReportTimeTick_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockBroker_ReportTimeTick_Call) RunAndReturn(run func(context.Context, []*msgpb.DataNodeTtMsg) error) *MockBroker_ReportTimeTick_Call { + _c.Call.Return(run) + return _c +} + +// SaveBinlogPaths provides a mock function with given fields: ctx, req +func (_m *MockBroker) SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPathsRequest) error { + ret := _m.Called(ctx, req) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.SaveBinlogPathsRequest) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockBroker_SaveBinlogPaths_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SaveBinlogPaths' +type MockBroker_SaveBinlogPaths_Call struct { + *mock.Call +} + +// SaveBinlogPaths is a helper method to define mock.On call +// - ctx context.Context +// - req *datapb.SaveBinlogPathsRequest +func (_e *MockBroker_Expecter) SaveBinlogPaths(ctx interface{}, req interface{}) *MockBroker_SaveBinlogPaths_Call { + return 
&MockBroker_SaveBinlogPaths_Call{Call: _e.mock.On("SaveBinlogPaths", ctx, req)} +} + +func (_c *MockBroker_SaveBinlogPaths_Call) Run(run func(ctx context.Context, req *datapb.SaveBinlogPathsRequest)) *MockBroker_SaveBinlogPaths_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*datapb.SaveBinlogPathsRequest)) + }) + return _c +} + +func (_c *MockBroker_SaveBinlogPaths_Call) Return(_a0 error) *MockBroker_SaveBinlogPaths_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockBroker_SaveBinlogPaths_Call) RunAndReturn(run func(context.Context, *datapb.SaveBinlogPathsRequest) error) *MockBroker_SaveBinlogPaths_Call { + _c.Call.Return(run) + return _c +} + +// SaveImportSegment provides a mock function with given fields: ctx, req +func (_m *MockBroker) SaveImportSegment(ctx context.Context, req *datapb.SaveImportSegmentRequest) error { + ret := _m.Called(ctx, req) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.SaveImportSegmentRequest) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockBroker_SaveImportSegment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SaveImportSegment' +type MockBroker_SaveImportSegment_Call struct { + *mock.Call +} + +// SaveImportSegment is a helper method to define mock.On call +// - ctx context.Context +// - req *datapb.SaveImportSegmentRequest +func (_e *MockBroker_Expecter) SaveImportSegment(ctx interface{}, req interface{}) *MockBroker_SaveImportSegment_Call { + return &MockBroker_SaveImportSegment_Call{Call: _e.mock.On("SaveImportSegment", ctx, req)} +} + +func (_c *MockBroker_SaveImportSegment_Call) Run(run func(ctx context.Context, req *datapb.SaveImportSegmentRequest)) *MockBroker_SaveImportSegment_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*datapb.SaveImportSegmentRequest)) + }) + return _c +} + +func (_c 
*MockBroker_SaveImportSegment_Call) Return(_a0 error) *MockBroker_SaveImportSegment_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockBroker_SaveImportSegment_Call) RunAndReturn(run func(context.Context, *datapb.SaveImportSegmentRequest) error) *MockBroker_SaveImportSegment_Call { + _c.Call.Return(run) + return _c +} + +// ShowPartitions provides a mock function with given fields: ctx, dbName, collectionName +func (_m *MockBroker) ShowPartitions(ctx context.Context, dbName string, collectionName string) (map[string]int64, error) { + ret := _m.Called(ctx, dbName, collectionName) + + var r0 map[string]int64 + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (map[string]int64, error)); ok { + return rf(ctx, dbName, collectionName) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) map[string]int64); ok { + r0 = rf(ctx, dbName, collectionName) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]int64) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, dbName, collectionName) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockBroker_ShowPartitions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ShowPartitions' +type MockBroker_ShowPartitions_Call struct { + *mock.Call +} + +// ShowPartitions is a helper method to define mock.On call +// - ctx context.Context +// - dbName string +// - collectionName string +func (_e *MockBroker_Expecter) ShowPartitions(ctx interface{}, dbName interface{}, collectionName interface{}) *MockBroker_ShowPartitions_Call { + return &MockBroker_ShowPartitions_Call{Call: _e.mock.On("ShowPartitions", ctx, dbName, collectionName)} +} + +func (_c *MockBroker_ShowPartitions_Call) Run(run func(ctx context.Context, dbName string, collectionName string)) *MockBroker_ShowPartitions_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), 
args[1].(string), args[2].(string)) + }) + return _c +} + +func (_c *MockBroker_ShowPartitions_Call) Return(_a0 map[string]int64, _a1 error) *MockBroker_ShowPartitions_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockBroker_ShowPartitions_Call) RunAndReturn(run func(context.Context, string, string) (map[string]int64, error)) *MockBroker_ShowPartitions_Call { + _c.Call.Return(run) + return _c +} + +// UpdateChannelCheckpoint provides a mock function with given fields: ctx, channelName, cp +func (_m *MockBroker) UpdateChannelCheckpoint(ctx context.Context, channelName string, cp *msgpb.MsgPosition) error { + ret := _m.Called(ctx, channelName, cp) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, *msgpb.MsgPosition) error); ok { + r0 = rf(ctx, channelName, cp) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockBroker_UpdateChannelCheckpoint_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateChannelCheckpoint' +type MockBroker_UpdateChannelCheckpoint_Call struct { + *mock.Call +} + +// UpdateChannelCheckpoint is a helper method to define mock.On call +// - ctx context.Context +// - channelName string +// - cp *msgpb.MsgPosition +func (_e *MockBroker_Expecter) UpdateChannelCheckpoint(ctx interface{}, channelName interface{}, cp interface{}) *MockBroker_UpdateChannelCheckpoint_Call { + return &MockBroker_UpdateChannelCheckpoint_Call{Call: _e.mock.On("UpdateChannelCheckpoint", ctx, channelName, cp)} +} + +func (_c *MockBroker_UpdateChannelCheckpoint_Call) Run(run func(ctx context.Context, channelName string, cp *msgpb.MsgPosition)) *MockBroker_UpdateChannelCheckpoint_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(*msgpb.MsgPosition)) + }) + return _c +} + +func (_c *MockBroker_UpdateChannelCheckpoint_Call) Return(_a0 error) *MockBroker_UpdateChannelCheckpoint_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c 
*MockBroker_UpdateChannelCheckpoint_Call) RunAndReturn(run func(context.Context, string, *msgpb.MsgPosition) error) *MockBroker_UpdateChannelCheckpoint_Call { + _c.Call.Return(run) + return _c +} + +// UpdateSegmentStatistics provides a mock function with given fields: ctx, req +func (_m *MockBroker) UpdateSegmentStatistics(ctx context.Context, req *datapb.UpdateSegmentStatisticsRequest) error { + ret := _m.Called(ctx, req) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.UpdateSegmentStatisticsRequest) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockBroker_UpdateSegmentStatistics_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateSegmentStatistics' +type MockBroker_UpdateSegmentStatistics_Call struct { + *mock.Call +} + +// UpdateSegmentStatistics is a helper method to define mock.On call +// - ctx context.Context +// - req *datapb.UpdateSegmentStatisticsRequest +func (_e *MockBroker_Expecter) UpdateSegmentStatistics(ctx interface{}, req interface{}) *MockBroker_UpdateSegmentStatistics_Call { + return &MockBroker_UpdateSegmentStatistics_Call{Call: _e.mock.On("UpdateSegmentStatistics", ctx, req)} +} + +func (_c *MockBroker_UpdateSegmentStatistics_Call) Run(run func(ctx context.Context, req *datapb.UpdateSegmentStatisticsRequest)) *MockBroker_UpdateSegmentStatistics_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*datapb.UpdateSegmentStatisticsRequest)) + }) + return _c +} + +func (_c *MockBroker_UpdateSegmentStatistics_Call) Return(_a0 error) *MockBroker_UpdateSegmentStatistics_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockBroker_UpdateSegmentStatistics_Call) RunAndReturn(run func(context.Context, *datapb.UpdateSegmentStatisticsRequest) error) *MockBroker_UpdateSegmentStatistics_Call { + _c.Call.Return(run) + return _c +} + +// NewMockBroker creates a new instance of MockBroker. 
It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockBroker(t interface { + mock.TestingT + Cleanup(func()) +}) *MockBroker { + mock := &MockBroker{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/datanode/broker/rootcoord.go b/internal/datanode/broker/rootcoord.go new file mode 100644 index 0000000000000..47129f8487427 --- /dev/null +++ b/internal/datanode/broker/rootcoord.go @@ -0,0 +1,114 @@ +package broker + +import ( + "context" + "fmt" + + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/commonpbutil" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type rootCoordBroker struct { + client types.RootCoordClient +} + +func (rc *rootCoordBroker) DescribeCollection(ctx context.Context, collectionID typeutil.UniqueID, timestamp typeutil.Timestamp) (*milvuspb.DescribeCollectionResponse, error) { + log := log.Ctx(ctx).With( + zap.Int64("collectionID", collectionID), + zap.Uint64("timestamp", timestamp), + ) + req := &milvuspb.DescribeCollectionRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_DescribeCollection), + commonpbutil.WithSourceID(paramtable.GetNodeID()), + ), + // please do not specify the collection name alone after database feature. 
+ CollectionID: collectionID, + TimeStamp: timestamp, + } + + resp, err := rc.client.DescribeCollectionInternal(ctx, req) + if err := merr.CheckRPCCall(resp, err); err != nil { + log.Warn("failed to DescribeCollectionInternal", zap.Error(err)) + return nil, err + } + + return resp, nil +} + +func (rc *rootCoordBroker) ShowPartitions(ctx context.Context, dbName, collectionName string) (map[string]int64, error) { + req := &milvuspb.ShowPartitionsRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_ShowPartitions), + ), + DbName: dbName, + CollectionName: collectionName, + } + + log := log.Ctx(ctx).With( + zap.String("dbName", dbName), + zap.String("collectionName", collectionName), + ) + + resp, err := rc.client.ShowPartitions(ctx, req) + if err := merr.CheckRPCCall(resp, err); err != nil { + log.Warn("failed to get partitions of collection", zap.Error(err)) + return nil, err + } + + partitionNames := resp.GetPartitionNames() + partitionIDs := resp.GetPartitionIDs() + if len(partitionNames) != len(partitionIDs) { + log.Warn("partition names and ids are unequal", + zap.Int("partitionNameNumber", len(partitionNames)), + zap.Int("partitionIDNumber", len(partitionIDs))) + return nil, fmt.Errorf("partition names and ids are unequal, number of names: %d, number of ids: %d", + len(partitionNames), len(partitionIDs)) + } + + partitions := make(map[string]int64) + for i := 0; i < len(partitionNames); i++ { + partitions[partitionNames[i]] = partitionIDs[i] + } + + return partitions, nil +} + +func (rc *rootCoordBroker) AllocTimestamp(ctx context.Context, num uint32) (uint64, uint32, error) { + log := log.Ctx(ctx) + + req := &rootcoordpb.AllocTimestampRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_RequestTSO), + commonpbutil.WithSourceID(paramtable.GetNodeID()), + ), + Count: num, + } + + resp, err := rc.client.AllocTimestamp(ctx, req) + if err := merr.CheckRPCCall(resp, err); err != nil { + 
log.Warn("failed to AllocTimestamp", zap.Error(err)) + return 0, 0, err + } + return resp.GetTimestamp(), resp.GetCount(), nil +} + +func (rc *rootCoordBroker) ReportImport(ctx context.Context, req *rootcoordpb.ImportResult) error { + log := log.Ctx(ctx) + resp, err := rc.client.ReportImport(ctx, req) + + if err := merr.CheckRPCCall(resp, err); err != nil { + log.Warn("failed to ReportImport", zap.Error(err)) + return err + } + return nil +} diff --git a/internal/datanode/broker/rootcoord_test.go b/internal/datanode/broker/rootcoord_test.go new file mode 100644 index 0000000000000..e08279fe2f2b2 --- /dev/null +++ b/internal/datanode/broker/rootcoord_test.go @@ -0,0 +1,241 @@ +package broker + +import ( + "context" + "math/rand" + "testing" + "time" + + "github.com/cockroachdb/errors" + "github.com/samber/lo" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "google.golang.org/grpc" + + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/tsoutil" +) + +type rootCoordSuite struct { + suite.Suite + + rc *mocks.MockRootCoordClient + broker Broker +} + +func (s *rootCoordSuite) SetupSuite() { + paramtable.Init() +} + +func (s *rootCoordSuite) SetupTest() { + s.rc = mocks.NewMockRootCoordClient(s.T()) + s.broker = NewCoordBroker(s.rc, nil) +} + +func (s *rootCoordSuite) resetMock() { + s.rc.AssertExpectations(s.T()) + s.rc.ExpectedCalls = nil +} + +func (s *rootCoordSuite) TestDescribeCollection() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + collectionID := int64(100) + timestamp := tsoutil.ComposeTSByTime(time.Now(), 0) + + s.Run("normal_case", func() { + collName := "test_collection_name" + + s.rc.EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything). 
+ Run(func(_ context.Context, req *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) { + s.Equal(collectionID, req.GetCollectionID()) + s.Equal(timestamp, req.GetTimeStamp()) + }). + Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Status(nil), + CollectionID: collectionID, + CollectionName: collName, + }, nil) + + resp, err := s.broker.DescribeCollection(ctx, collectionID, timestamp) + s.NoError(err) + s.Equal(collectionID, resp.GetCollectionID()) + s.Equal(collName, resp.GetCollectionName()) + s.resetMock() + }) + + s.Run("rootcoord_return_error", func() { + s.rc.EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything). + Return(nil, errors.New("mock")) + + _, err := s.broker.DescribeCollection(ctx, collectionID, timestamp) + s.Error(err) + s.resetMock() + }) + + s.Run("rootcoord_return_failure_status", func() { + s.rc.EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything). + Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Status(errors.New("mocked")), + }, nil) + + _, err := s.broker.DescribeCollection(ctx, collectionID, timestamp) + s.Error(err) + s.resetMock() + }) +} + +func (s *rootCoordSuite) TestShowPartitions() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + dbName := "defaultDB" + collName := "testCollection" + + s.Run("normal_case", func() { + partitions := map[string]int64{ + "part1": 1001, + "part2": 1002, + "part3": 1003, + } + + names := lo.Keys(partitions) + ids := lo.Map(names, func(name string, _ int) int64 { + return partitions[name] + }) + + s.rc.EXPECT().ShowPartitions(mock.Anything, mock.Anything). + Run(func(_ context.Context, req *milvuspb.ShowPartitionsRequest, _ ...grpc.CallOption) { + s.Equal(dbName, req.GetDbName()) + s.Equal(collName, req.GetCollectionName()) + }). 
+ Return(&milvuspb.ShowPartitionsResponse{ + Status: merr.Status(nil), + PartitionIDs: ids, + PartitionNames: names, + }, nil) + partNameIDs, err := s.broker.ShowPartitions(ctx, dbName, collName) + s.NoError(err) + s.Equal(len(partitions), len(partNameIDs)) + for name, id := range partitions { + result, ok := partNameIDs[name] + s.True(ok) + s.Equal(id, result) + } + s.resetMock() + }) + + s.Run("rootcoord_return_error", func() { + s.rc.EXPECT().ShowPartitions(mock.Anything, mock.Anything). + Return(nil, errors.New("mock")) + + _, err := s.broker.ShowPartitions(ctx, dbName, collName) + s.Error(err) + s.resetMock() + }) + + s.Run("partition_id_name_not_match", func() { + s.rc.EXPECT().ShowPartitions(mock.Anything, mock.Anything). + Return(&milvuspb.ShowPartitionsResponse{ + Status: merr.Status(nil), + PartitionIDs: []int64{1, 2}, + PartitionNames: []string{"part1"}, + }, nil) + + _, err := s.broker.ShowPartitions(ctx, dbName, collName) + s.Error(err) + s.resetMock() + }) +} + +func (s *rootCoordSuite) TestAllocTimestamp() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("normal_case", func() { + num := rand.Intn(10) + 1 + ts := tsoutil.ComposeTSByTime(time.Now(), 0) + s.rc.EXPECT().AllocTimestamp(mock.Anything, mock.Anything). + Run(func(_ context.Context, req *rootcoordpb.AllocTimestampRequest, _ ...grpc.CallOption) { + s.EqualValues(num, req.GetCount()) + }). + Return(&rootcoordpb.AllocTimestampResponse{ + Status: merr.Status(nil), + Timestamp: ts, + Count: uint32(num), + }, nil) + + timestamp, cnt, err := s.broker.AllocTimestamp(ctx, uint32(num)) + s.NoError(err) + s.Equal(ts, timestamp) + s.EqualValues(num, cnt) + s.resetMock() + }) + + s.Run("rootcoord_return_error", func() { + s.rc.EXPECT().AllocTimestamp(mock.Anything, mock.Anything). 
+ Return(nil, errors.New("mock")) + _, _, err := s.broker.AllocTimestamp(ctx, 1) + s.Error(err) + s.resetMock() + }) + + s.Run("rootcoord_return_failure_status", func() { + s.rc.EXPECT().AllocTimestamp(mock.Anything, mock.Anything). + Return(&rootcoordpb.AllocTimestampResponse{Status: merr.Status(errors.New("mock"))}, nil) + _, _, err := s.broker.AllocTimestamp(ctx, 1) + s.Error(err) + s.resetMock() + }) +} + +func (s *rootCoordSuite) TestReportImport() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + taskID := rand.Int63() + + req := &rootcoordpb.ImportResult{ + Status: merr.Status(nil), + TaskId: taskID, + } + + s.Run("normal_case", func() { + s.rc.EXPECT().ReportImport(mock.Anything, mock.Anything). + Run(func(_ context.Context, req *rootcoordpb.ImportResult, _ ...grpc.CallOption) { + s.Equal(taskID, req.GetTaskId()) + }). + Return(merr.Status(nil), nil) + + err := s.broker.ReportImport(ctx, req) + s.NoError(err) + s.resetMock() + }) + + s.Run("rootcoord_return_error", func() { + s.rc.EXPECT().ReportImport(mock.Anything, mock.Anything). + Return(nil, errors.New("mock")) + + err := s.broker.ReportImport(ctx, req) + s.Error(err) + s.resetMock() + }) + + s.Run("rootcoord_return_failure_status", func() { + s.rc.EXPECT().ReportImport(mock.Anything, mock.Anything). 
+ Return(merr.Status(errors.New("mock")), nil) + + err := s.broker.ReportImport(ctx, req) + s.Error(err) + s.resetMock() + }) +} + +func TestRootCoordBroker(t *testing.T) { + suite.Run(t, new(rootCoordSuite)) +} diff --git a/internal/datanode/buffer.go b/internal/datanode/buffer.go index 6615ee2905736..48e16988b52f0 100644 --- a/internal/datanode/buffer.go +++ b/internal/datanode/buffer.go @@ -62,7 +62,6 @@ func (m *DeltaBufferManager) GetEntriesNum(segID UniqueID) int64 { func (m *DeltaBufferManager) UpdateCompactedSegments() { compactedTo2From := m.channel.listCompactedSegmentIDs() for compactedTo, compactedFrom := range compactedTo2From { - // if the compactedTo segment has 0 numRows, there'll be no segments // in the channel meta, so remove all compacted from segments related if !m.channel.hasSegment(compactedTo, true) { @@ -87,7 +86,6 @@ func (m *DeltaBufferManager) UpdateCompactedSegments() { // only store delBuf if EntriesNum > 0 if compactToDelBuff.EntriesNum > 0 { - m.pushOrFixHeap(compactedTo, compactToDelBuff) // We need to re-add the memorySize because m.Delete(segID) sub them all. 
m.usedMemory.Add(compactToDelBuff.GetMemorySize()) @@ -129,7 +127,8 @@ func (m *DeltaBufferManager) deleteFromHeap(buffer *DelDataBuf) { } func (m *DeltaBufferManager) StoreNewDeletes(segID UniqueID, pks []primaryKey, - tss []Timestamp, tr TimeRange, startPos, endPos *msgpb.MsgPosition) { + tss []Timestamp, tr TimeRange, startPos, endPos *msgpb.MsgPosition, +) { buffer, loaded := m.Load(segID) if !loaded { buffer = newDelDataBuf(segID) @@ -154,7 +153,6 @@ func (m *DeltaBufferManager) Delete(segID UniqueID) { m.usedMemory.Sub(buffer.GetMemorySize()) m.deleteFromHeap(buffer) m.channel.rollDeleteBuffer(segID) - } } @@ -165,7 +163,7 @@ func (m *DeltaBufferManager) popHeapItem() *Item { } func (m *DeltaBufferManager) ShouldFlushSegments() []UniqueID { - var memUsage = m.usedMemory.Load() + memUsage := m.usedMemory.Load() if memUsage < Params.DataNodeCfg.FlushDeleteBufferBytes.GetAsInt64() { return nil } @@ -181,12 +179,11 @@ func (m *DeltaBufferManager) ShouldFlushSegments() []UniqueID { memUsage -= segItem.memorySize if memUsage < Params.DataNodeCfg.FlushDeleteBufferBytes.GetAsInt64() { break - } } - //here we push all selected segment back into the heap - //in order to keep the heap semantically correct + // here we push all selected segment back into the heap + // in order to keep the heap semantically correct m.heapGuard.Lock() for _, segMem := range poppedItems { heap.Push(m.delBufHeap, segMem) @@ -334,7 +331,7 @@ func (ddb *DelDataBuf) Buffer(pks []primaryKey, tss []Timestamp, tr TimeRange, s varCharPk := pks[i].(*varCharPrimaryKey) bufSize += int64(len(varCharPk.Value)) } - //accumulate buf size for timestamp, which is 8 bytes + // accumulate buf size for timestamp, which is 8 bytes bufSize += 8 } @@ -430,13 +427,14 @@ func newBufferData(collSchema *schemapb.CollectionSchema) (*BufferData, error) { limit++ } - //TODO::xige-16 eval vec and string field + // TODO::xige-16 eval vec and string field return &BufferData{ buffer: &InsertData{Data: 
make(map[UniqueID]storage.FieldData)}, size: 0, limit: limit, tsFrom: math.MaxUint64, - tsTo: 0}, nil + tsTo: 0, + }, nil } func newDelDataBuf(segmentID UniqueID) *DelDataBuf { diff --git a/internal/datanode/buffer_test.go b/internal/datanode/buffer_test.go index 742e6ac7f645e..784169e4079c8 100644 --- a/internal/datanode/buffer_test.go +++ b/internal/datanode/buffer_test.go @@ -170,7 +170,7 @@ func Test_CompactSegBuff(t *testing.T) { }, delBufHeap: &PriorityQueue{}, } - //1. set compactTo and compactFrom + // 1. set compactTo and compactFrom targetSeg := &Segment{segmentID: 3333} targetSeg.setType(datapb.SegmentType_Flushed) @@ -190,7 +190,7 @@ func Test_CompactSegBuff(t *testing.T) { channelSegments[seg2.segmentID] = seg2 channelSegments[targetSeg.segmentID] = targetSeg - //2. set up deleteDataBuf for seg1 and seg2 + // 2. set up deleteDataBuf for seg1 and seg2 delDataBuf1 := newDelDataBuf(seg1.segmentID) delDataBuf1.EntriesNum++ delDataBuf1.updateStartAndEndPosition(nil, &msgpb.MsgPosition{Timestamp: 50}) @@ -203,12 +203,12 @@ func Test_CompactSegBuff(t *testing.T) { delBufferManager.updateMeta(seg2.segmentID, delDataBuf2) heap.Push(delBufferManager.delBufHeap, delDataBuf2.item) - //3. test compact + // 3. test compact delBufferManager.UpdateCompactedSegments() - //4. expect results in two aspects: - //4.1 compactedFrom segments are removed from delBufferManager - //4.2 compactedTo seg is set properly with correct entriesNum + // 4. expect results in two aspects: + // 4.1 compactedFrom segments are removed from delBufferManager + // 4.2 compactedTo seg is set properly with correct entriesNum _, seg1Exist := delBufferManager.Load(seg1.segmentID) _, seg2Exist := delBufferManager.Load(seg2.segmentID) assert.False(t, seg1Exist) @@ -221,7 +221,7 @@ func Test_CompactSegBuff(t *testing.T) { assert.NotNil(t, targetSegBuf.item) assert.Equal(t, targetSeg.segmentID, targetSegBuf.item.segmentID) - //5. 
test roll and evict (https://github.com/milvus-io/milvus/issues/20501) + // 5. test roll and evict (https://github.com/milvus-io/milvus/issues/20501) delBufferManager.channel.rollDeleteBuffer(targetSeg.segmentID) _, segCompactedToExist := delBufferManager.Load(targetSeg.segmentID) assert.False(t, segCompactedToExist) @@ -271,25 +271,61 @@ func TestUpdateCompactedSegments(t *testing.T) { expectedSegsRemain []UniqueID }{ - {"zero segments", false, - []UniqueID{}, []UniqueID{}, []UniqueID{}}, - {"segment no compaction", false, - []UniqueID{}, []UniqueID{}, []UniqueID{100, 101}}, - {"segment compacted", true, - []UniqueID{200}, []UniqueID{103}, []UniqueID{100, 101}}, - {"segment compacted 100>201", true, - []UniqueID{201}, []UniqueID{100}, []UniqueID{101, 201}}, - {"segment compacted 100+101>201", true, - []UniqueID{201, 201}, []UniqueID{100, 101}, []UniqueID{201}}, - {"segment compacted 100>201, 101>202", true, - []UniqueID{201, 202}, []UniqueID{100, 101}, []UniqueID{201, 202}}, + { + "zero segments", false, + []UniqueID{}, + []UniqueID{}, + []UniqueID{}, + }, + { + "segment no compaction", false, + []UniqueID{}, + []UniqueID{}, + []UniqueID{100, 101}, + }, + { + "segment compacted", true, + []UniqueID{200}, + []UniqueID{103}, + []UniqueID{100, 101}, + }, + { + "segment compacted 100>201", true, + []UniqueID{201}, + []UniqueID{100}, + []UniqueID{101, 201}, + }, + { + "segment compacted 100+101>201", true, + []UniqueID{201, 201}, + []UniqueID{100, 101}, + []UniqueID{201}, + }, + { + "segment compacted 100>201, 101>202", true, + []UniqueID{201, 202}, + []UniqueID{100, 101}, + []UniqueID{201, 202}, + }, // false - {"segment compacted 100>201", false, - []UniqueID{201}, []UniqueID{100}, []UniqueID{101}}, - {"segment compacted 100+101>201", false, - []UniqueID{201, 201}, []UniqueID{100, 101}, []UniqueID{}}, - {"segment compacted 100>201, 101>202", false, - []UniqueID{201, 202}, []UniqueID{100, 101}, []UniqueID{}}, + { + "segment compacted 100>201", false, + 
[]UniqueID{201}, + []UniqueID{100}, + []UniqueID{101}, + }, + { + "segment compacted 100+101>201", false, + []UniqueID{201, 201}, + []UniqueID{100, 101}, + []UniqueID{}, + }, + { + "segment compacted 100>201, 101>202", false, + []UniqueID{201, 202}, + []UniqueID{100, 101}, + []UniqueID{}, + }, } for _, test := range tests { diff --git a/internal/datanode/channel_manager.go b/internal/datanode/channel_manager.go new file mode 100644 index 0000000000000..3407c8bb33e69 --- /dev/null +++ b/internal/datanode/channel_manager.go @@ -0,0 +1,505 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package datanode + +import ( + "context" + "sync" + "time" + + "github.com/cockroachdb/errors" + "go.uber.org/atomic" + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type releaseFunc func(channel string) + +type ChannelManager struct { + mu sync.RWMutex + dn *DataNode + + communicateCh chan *opState + runningFlowgraphs *flowgraphManager + opRunners *typeutil.ConcurrentMap[string, *opRunner] // channel -> runner + abnormals *typeutil.ConcurrentMap[int64, string] // OpID -> Channel + + releaseFunc releaseFunc + + closeCh chan struct{} + closeOnce sync.Once + closeWaiter sync.WaitGroup +} + +func NewChannelManager(dn *DataNode) *ChannelManager { + fm := newFlowgraphManager() + cm := ChannelManager{ + dn: dn, + + communicateCh: make(chan *opState, 100), + runningFlowgraphs: fm, + opRunners: typeutil.NewConcurrentMap[string, *opRunner](), + abnormals: typeutil.NewConcurrentMap[int64, string](), + + releaseFunc: fm.release, + + closeCh: make(chan struct{}), + } + + return &cm +} + +func (m *ChannelManager) Submit(info *datapb.ChannelWatchInfo) error { + channel := info.GetVchan().GetChannelName() + runner := m.getOrCreateRunner(channel) + return runner.Enqueue(info) +} + +func (m *ChannelManager) GetProgress(info *datapb.ChannelWatchInfo) *datapb.ChannelOperationProgressResponse { + m.mu.RLock() + defer m.mu.RUnlock() + resp := &datapb.ChannelOperationProgressResponse{ + Status: merr.Success(), + OpID: info.GetOpID(), + } + + channel := info.GetVchan().GetChannelName() + switch info.GetState() { + case datapb.ChannelWatchState_ToWatch: + if m.runningFlowgraphs.existWithOpID(channel, info.GetOpID()) { + resp.State = datapb.ChannelWatchState_WatchSuccess + return resp + } + + if runner, ok := m.opRunners.Get(channel); ok { + if runner.Exist(info.GetOpID()) { + resp.State = 
datapb.ChannelWatchState_ToWatch + } else { + resp.State = datapb.ChannelWatchState_WatchFailure + } + return resp + } + resp.State = datapb.ChannelWatchState_WatchFailure + return resp + + case datapb.ChannelWatchState_ToRelease: + if !m.runningFlowgraphs.exist(channel) { + resp.State = datapb.ChannelWatchState_ReleaseSuccess + return resp + } + if runner, ok := m.opRunners.Get(channel); ok && runner.Exist(info.GetOpID()) { + resp.State = datapb.ChannelWatchState_ToRelease + return resp + } + + resp.State = datapb.ChannelWatchState_ReleaseFailure + return resp + default: + err := merr.WrapErrParameterInvalid("ToWatch or ToRelease", info.GetState().String()) + log.Warn("fail to get progress", zap.Error(err)) + resp.Status = merr.Status(err) + return resp + } +} + +func (m *ChannelManager) Close() { + m.closeOnce.Do(func() { + m.opRunners.Range(func(channel string, runner *opRunner) bool { + runner.Close() + return true + }) + m.runningFlowgraphs.close() + close(m.closeCh) + m.closeWaiter.Wait() + }) +} + +func (m *ChannelManager) Start() { + m.closeWaiter.Add(2) + + go m.runningFlowgraphs.start(&m.closeWaiter) + go func() { + defer m.closeWaiter.Done() + log.Info("DataNode ChannelManager start") + for { + select { + case opState := <-m.communicateCh: + m.handleOpState(opState) + case <-m.closeCh: + log.Info("DataNode ChannelManager exit") + return + } + } + }() +} + +func (m *ChannelManager) handleOpState(opState *opState) { + m.mu.Lock() + defer m.mu.Unlock() + log := log.With( + zap.Int64("opID", opState.opID), + zap.String("channel", opState.channel), + zap.String("State", opState.state.String()), + ) + switch opState.state { + case datapb.ChannelWatchState_WatchSuccess: + log.Info("Success to watch") + m.runningFlowgraphs.Add(opState.fg) + m.finishOp(opState.opID, opState.channel) + + case datapb.ChannelWatchState_WatchFailure: + log.Info("Fail to watch") + m.finishOp(opState.opID, opState.channel) + + case datapb.ChannelWatchState_ReleaseSuccess: + 
log.Info("Success to release") + m.finishOp(opState.opID, opState.channel) + m.destoryRunner(opState.channel) + + case datapb.ChannelWatchState_ReleaseFailure: + log.Info("Fail to release, add channel to abnormal lists") + m.abnormals.Insert(opState.opID, opState.channel) + m.finishOp(opState.opID, opState.channel) + m.destoryRunner(opState.channel) + } +} + +func (m *ChannelManager) getOrCreateRunner(channel string) *opRunner { + runner, loaded := m.opRunners.GetOrInsert(channel, NewOpRunner(channel, m.dn, m.releaseFunc, m.communicateCh)) + if !loaded { + runner.Start() + } + return runner +} + +func (m *ChannelManager) destoryRunner(channel string) { + if runner, loaded := m.opRunners.GetAndRemove(channel); loaded { + runner.Close() + } +} + +func (m *ChannelManager) finishOp(opID int64, channel string) { + if runner, loaded := m.opRunners.Get(channel); loaded { + runner.FinishOp(opID) + } +} + +type opInfo struct { + tickler *tickler +} + +type opRunner struct { + channel string + dn *DataNode + releaseFunc releaseFunc + + guard sync.RWMutex + allOps map[UniqueID]*opInfo // opID -> tickler + opsInQueue chan *datapb.ChannelWatchInfo + resultCh chan *opState + + closeWg sync.WaitGroup + closeOnce sync.Once + closeCh chan struct{} +} + +func NewOpRunner(channel string, dn *DataNode, f releaseFunc, resultCh chan *opState) *opRunner { + return &opRunner{ + channel: channel, + dn: dn, + releaseFunc: f, + opsInQueue: make(chan *datapb.ChannelWatchInfo, 10), + allOps: make(map[UniqueID]*opInfo), + resultCh: resultCh, + closeCh: make(chan struct{}), + } +} + +func (r *opRunner) Start() { + r.closeWg.Add(1) + go func() { + defer r.closeWg.Done() + for { + select { + case info := <-r.opsInQueue: + r.NotifyState(r.Execute(info)) + case <-r.closeCh: + return + } + } + }() +} + +func (r *opRunner) FinishOp(opID UniqueID) { + r.guard.Lock() + defer r.guard.Unlock() + delete(r.allOps, opID) +} + +func (r *opRunner) Exist(opID UniqueID) bool { + r.guard.RLock() + defer 
r.guard.RUnlock() + _, ok := r.allOps[opID] + return ok +} + +func (r *opRunner) Enqueue(info *datapb.ChannelWatchInfo) error { + if info.GetState() != datapb.ChannelWatchState_ToWatch && + info.GetState() != datapb.ChannelWatchState_ToRelease { + return errors.New("Invalid channel watch state") + } + + r.guard.Lock() + defer r.guard.Unlock() + if _, ok := r.allOps[info.GetOpID()]; !ok { + r.opsInQueue <- info + r.allOps[info.GetOpID()] = &opInfo{} + } + return nil +} + +func (r *opRunner) UnfinishedOpSize() int { + r.guard.RLock() + defer r.guard.RUnlock() + return len(r.allOps) +} + +// Execute excutes channel operations, channel state is validated during enqueue +func (r *opRunner) Execute(info *datapb.ChannelWatchInfo) *opState { + log.Info("Start to execute channel operation", + zap.String("channel", info.GetVchan().GetChannelName()), + zap.Int64("opID", info.GetOpID()), + zap.String("state", info.GetState().String()), + ) + if info.GetState() == datapb.ChannelWatchState_ToWatch { + return r.watchWithTimer(info) + } + + // ToRelease state + return releaseWithTimer(r.releaseFunc, info.GetVchan().GetChannelName(), info.GetOpID()) +} + +// watchWithTimer will return WatchFailure after WatchTimeoutInterval +func (r *opRunner) watchWithTimer(info *datapb.ChannelWatchInfo) *opState { + opState := &opState{ + channel: info.GetVchan().GetChannelName(), + opID: info.GetOpID(), + } + log := log.With(zap.String("channel", opState.channel), zap.Int64("opID", opState.opID)) + + r.guard.Lock() + opInfo, ok := r.allOps[info.GetOpID()] + if !ok { + opState.state = datapb.ChannelWatchState_WatchFailure + return opState + } + tickler := newTickler() + opInfo.tickler = tickler + r.guard.Unlock() + + var ( + successSig = make(chan struct{}, 1) + waiter sync.WaitGroup + ) + + watchTimeout := Params.DataCoordCfg.WatchTimeoutInterval.GetAsDuration(time.Second) + ctx, cancel := context.WithTimeout(context.Background(), watchTimeout) + defer cancel() + + startTimer := func(wg 
*sync.WaitGroup) { + defer wg.Done() + + timer := time.NewTimer(watchTimeout) + defer timer.Stop() + + log.Info("Start timer for ToWatch operation", zap.Duration("timeout", watchTimeout)) + for { + select { + case <-timer.C: + // watch timeout + tickler.close() + cancel() + log.Info("Stop timer for ToWatch operation timeout", zap.Duration("timeout", watchTimeout)) + return + + case <-tickler.progressSig: + timer.Reset(watchTimeout) + + case <-successSig: + // watch success + log.Info("Stop timer for ToWatch operation succeeded", zap.Duration("timeout", watchTimeout)) + return + } + } + } + + waiter.Add(2) + go startTimer(&waiter) + go func() { + defer waiter.Done() + fg, err := executeWatch(ctx, r.dn, info, tickler) + if err != nil { + opState.state = datapb.ChannelWatchState_WatchFailure + } else { + opState.state = datapb.ChannelWatchState_WatchSuccess + opState.fg = fg + successSig <- struct{}{} + } + }() + + waiter.Wait() + return opState +} + +// releaseWithTimer will return ReleaseFailure after WatchTimeoutInterval +func releaseWithTimer(releaseFunc releaseFunc, channel string, opID UniqueID) *opState { + opState := &opState{ + channel: channel, + opID: opID, + } + var ( + successSig = make(chan struct{}, 1) + waiter sync.WaitGroup + ) + + log := log.With(zap.String("channel", channel)) + startTimer := func(wg *sync.WaitGroup) { + defer wg.Done() + releaseTimeout := Params.DataCoordCfg.WatchTimeoutInterval.GetAsDuration(time.Second) + timer := time.NewTimer(releaseTimeout) + defer timer.Stop() + + log.Info("Start timer for ToRelease operation", zap.Duration("timeout", releaseTimeout)) + for { + select { + case <-timer.C: + log.Info("Stop timer for ToRelease operation timeout", zap.Duration("timeout", releaseTimeout)) + opState.state = datapb.ChannelWatchState_ReleaseFailure + return + + case <-successSig: + log.Info("Stop timer for ToRelease operation succeeded", zap.Duration("timeout", releaseTimeout)) + opState.state = 
datapb.ChannelWatchState_ReleaseSuccess + return + } + } + } + + waiter.Add(1) + go startTimer(&waiter) + go func() { + // TODO: failure should panic this DN, but we're not sure how + // to recover when releaseFunc stuck. + // Whenever we see a stuck, it's a bug need to be fixed. + // In case of the unknown behavior after the stuck of release, + // we'll mark this channel abnormal in this DN. This goroutine might never return. + // + // The channel can still be balanced into other DNs, but not on this one. + // ExclusiveConsumer error happens when the same DN subscribes the same pchannel twice. + releaseFunc(opState.channel) + successSig <- struct{}{} + }() + + waiter.Wait() + return opState +} + +func (r *opRunner) NotifyState(state *opState) { + r.resultCh <- state +} + +func (r *opRunner) Close() { + r.guard.Lock() + for _, info := range r.allOps { + if info.tickler != nil { + info.tickler.close() + } + } + r.guard.Unlock() + + r.closeOnce.Do(func() { + close(r.closeCh) + r.closeWg.Wait() + }) +} + +type opState struct { + channel string + opID int64 + state datapb.ChannelWatchState + fg *dataSyncService +} + +// executeWatch will always return, won't be stuck, either success or fail. +func executeWatch(ctx context.Context, dn *DataNode, info *datapb.ChannelWatchInfo, tickler *tickler) (*dataSyncService, error) { + dataSyncService, err := newDataSyncService(ctx, dn, info, tickler) + if err != nil { + return nil, err + } + + dataSyncService.start() + + return dataSyncService, nil +} + +// tickler counts every time when called inc(), +type tickler struct { + count *atomic.Int32 + total *atomic.Int32 + closedSig *atomic.Bool + + progressSig chan struct{} +} + +func (t *tickler) inc() { + t.count.Inc() + t.progressSig <- struct{}{} +} + +func (t *tickler) setTotal(total int32) { + t.total.Store(total) +} + +// progress returns the count over total if total is set +// else just return the count number. 
+func (t *tickler) progress() int32 { + if t.total.Load() == 0 { + return t.count.Load() + } + return (t.count.Load() / t.total.Load()) * 100 +} + +func (t *tickler) close() { + t.closedSig.CompareAndSwap(false, true) +} + +func (t *tickler) closed() bool { + return t.closedSig.Load() +} + +func newTickler() *tickler { + return &tickler{ + count: atomic.NewInt32(0), + total: atomic.NewInt32(0), + closedSig: atomic.NewBool(false), + progressSig: make(chan struct{}, 200), + } +} diff --git a/internal/datanode/channel_manager_test.go b/internal/datanode/channel_manager_test.go new file mode 100644 index 0000000000000..16281dacc6028 --- /dev/null +++ b/internal/datanode/channel_manager_test.go @@ -0,0 +1,188 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package datanode + +import ( + "context" + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +func TestChannelManagerSuite(t *testing.T) { + suite.Run(t, new(ChannelManagerSuite)) +} + +type ChannelManagerSuite struct { + suite.Suite + + node *DataNode + manager *ChannelManager +} + +func (s *ChannelManagerSuite) SetupTest() { + ctx := context.Background() + s.node = newIDLEDataNodeMock(ctx, schemapb.DataType_Int64) + s.manager = NewChannelManager(s.node) +} + +func getWatchInfoByOpID(opID UniqueID, channel string, state datapb.ChannelWatchState) *datapb.ChannelWatchInfo { + return &datapb.ChannelWatchInfo{ + OpID: opID, + State: state, + Vchan: &datapb.VchannelInfo{ + CollectionID: 1, + ChannelName: channel, + }, + } +} + +func (s *ChannelManagerSuite) TearDownTest() { + s.manager.Close() +} + +func (s *ChannelManagerSuite) TestWatchFail() { + channel := "by-dev-rootcoord-dml-2" + paramtable.Get().Save(Params.DataCoordCfg.WatchTimeoutInterval.Key, "0.000001") + defer paramtable.Get().Reset(Params.DataCoordCfg.WatchTimeoutInterval.Key) + info := getWatchInfoByOpID(100, channel, datapb.ChannelWatchState_ToWatch) + s.Require().Equal(0, s.manager.opRunners.Len()) + err := s.manager.Submit(info) + s.Require().NoError(err) + + opState := <-s.manager.communicateCh + s.Require().NotNil(opState) + s.Equal(info.GetOpID(), opState.opID) + s.Equal(datapb.ChannelWatchState_WatchFailure, opState.state) + + s.manager.handleOpState(opState) + + resp := s.manager.GetProgress(info) + s.Equal(datapb.ChannelWatchState_WatchFailure, resp.GetState()) +} + +func (s *ChannelManagerSuite) TestReleaseStuck() { + var ( + channel = "by-dev-rootcoord-dml-2" + stuckSig = make(chan struct{}) + ) + s.manager.releaseFunc = func(channel string) { + stuckSig <- struct{}{} + } + + info := getWatchInfoByOpID(100, channel, 
datapb.ChannelWatchState_ToWatch) + s.Require().Equal(0, s.manager.opRunners.Len()) + err := s.manager.Submit(info) + s.Require().NoError(err) + + opState := <-s.manager.communicateCh + s.Require().NotNil(opState) + + s.manager.handleOpState(opState) + + releaseInfo := getWatchInfoByOpID(101, channel, datapb.ChannelWatchState_ToRelease) + paramtable.Get().Save(Params.DataCoordCfg.WatchTimeoutInterval.Key, "0.1") + defer paramtable.Get().Reset(Params.DataCoordCfg.WatchTimeoutInterval.Key) + + err = s.manager.Submit(releaseInfo) + s.NoError(err) + + opState = <-s.manager.communicateCh + s.Require().NotNil(opState) + s.Equal(datapb.ChannelWatchState_ReleaseFailure, opState.state) + s.manager.handleOpState(opState) + + s.Equal(1, s.manager.abnormals.Len()) + abchannel, ok := s.manager.abnormals.Get(releaseInfo.GetOpID()) + s.True(ok) + s.Equal(channel, abchannel) + + <-stuckSig + + resp := s.manager.GetProgress(releaseInfo) + s.Equal(datapb.ChannelWatchState_ReleaseFailure, resp.GetState()) +} + +func (s *ChannelManagerSuite) TestSubmitIdempotent() { + channel := "by-dev-rootcoord-dml-1" + + info := getWatchInfoByOpID(100, channel, datapb.ChannelWatchState_ToWatch) + s.Require().Equal(0, s.manager.opRunners.Len()) + + for i := 0; i < 10; i++ { + err := s.manager.Submit(info) + s.NoError(err) + } + + s.Equal(1, s.manager.opRunners.Len()) + s.True(s.manager.opRunners.Contain(channel)) + + runner, ok := s.manager.opRunners.Get(channel) + s.True(ok) + s.Equal(1, runner.UnfinishedOpSize()) +} + +func (s *ChannelManagerSuite) TestSubmitWatchAndRelease() { + channel := "by-dev-rootcoord-dml-0" + + info := getWatchInfoByOpID(100, channel, datapb.ChannelWatchState_ToWatch) + + err := s.manager.Submit(info) + s.NoError(err) + + opState := <-s.manager.communicateCh + s.NotNil(opState) + s.Equal(datapb.ChannelWatchState_WatchSuccess, opState.state) + s.NotNil(opState.fg) + s.Equal(info.GetOpID(), opState.fg.opID) + + resp := s.manager.GetProgress(info) + s.Equal(info.GetOpID(), 
resp.GetOpID()) + s.Equal(datapb.ChannelWatchState_ToWatch, resp.GetState()) + + s.manager.handleOpState(opState) + s.Equal(1, s.manager.runningFlowgraphs.getFlowGraphNum()) + s.True(s.manager.opRunners.Contain(info.GetVchan().GetChannelName())) + s.Equal(1, s.manager.opRunners.Len()) + + resp = s.manager.GetProgress(info) + s.Equal(info.GetOpID(), resp.GetOpID()) + s.Equal(datapb.ChannelWatchState_WatchSuccess, resp.GetState()) + + // release + info = getWatchInfoByOpID(101, channel, datapb.ChannelWatchState_ToRelease) + + err = s.manager.Submit(info) + s.NoError(err) + + opState = <-s.manager.communicateCh + s.NotNil(opState) + s.Equal(datapb.ChannelWatchState_ReleaseSuccess, opState.state) + s.manager.handleOpState(opState) + + resp = s.manager.GetProgress(info) + s.Equal(info.GetOpID(), resp.GetOpID()) + s.Equal(datapb.ChannelWatchState_ReleaseSuccess, resp.GetState()) + + s.Equal(0, s.manager.runningFlowgraphs.getFlowGraphNum()) + s.False(s.manager.opRunners.Contain(info.GetVchan().GetChannelName())) + s.Equal(0, s.manager.opRunners.Len()) +} diff --git a/internal/datanode/channel_meta.go b/internal/datanode/channel_meta.go index 76ce5b05c98c3..3dfa78e764f49 100644 --- a/internal/datanode/channel_meta.go +++ b/internal/datanode/channel_meta.go @@ -33,9 +33,9 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" milvus_storage "github.com/milvus-io/milvus-storage/go/storage" + "github.com/milvus-io/milvus/internal/datanode/broker" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/storage" - "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/conc" @@ -62,7 +62,7 @@ type Channel interface { getCollectionAndPartitionID(segID UniqueID) (collID, partitionID UniqueID, err error) getChannelName(segID UniqueID) string - addSegment(req addSegmentReq) error + 
addSegment(ctx context.Context, req addSegmentReq) error getSegment(segID UniqueID) *Segment removeSegments(segID ...UniqueID) hasSegment(segID UniqueID, countFlushed bool) bool @@ -77,7 +77,7 @@ type Channel interface { listNewSegmentsStartPositions() []*datapb.SegmentStartPosition transferNewSegments(segmentIDs []UniqueID) updateSegmentPKRange(segID UniqueID, ids storage.FieldData) - mergeFlushedSegments(ctx context.Context, seg *Segment, planID UniqueID, compactedFrom []UniqueID) error + mergeFlushedSegments(ctx context.Context, seg *Segment, planID UniqueID, compactedFrom []UniqueID) listCompactedSegmentIDs() map[UniqueID][]UniqueID listSegmentIDsToSync(ts Timestamp) []UniqueID @@ -100,7 +100,11 @@ type Channel interface { // getTotalMemorySize returns the sum of memory sizes of segments. getTotalMemorySize() int64 - forceToSync() + setIsHighMemory(b bool) + getIsHighMemory() bool + + getFlushTs() Timestamp + setFlushTs(ts Timestamp) getSpace(segmentID UniqueID) (*milvus_storage.Space, bool) setSpace(segmentID UniqueID, space *milvus_storage.Space) @@ -118,7 +122,15 @@ type ChannelMeta struct { segMu sync.RWMutex segments map[UniqueID]*Segment - needToSync *atomic.Bool + // isHighMemory is intended to trigger the syncing of segments + // when segment's buffer consumes a significant amount of memory. + isHighMemory *atomic.Bool + + // flushTs is intended to trigger: + // 1. the syncing of segments when consumed ts exceeds flushTs; + // 2. the updating of channelCP when channelCP exceeds flushTs. 
+ flushTs *atomic.Uint64 + syncPolicies []segmentSyncPolicy metaService *metaService @@ -143,8 +155,8 @@ type addSegmentReq struct { var _ Channel = &ChannelMeta{} -func newChannel(channelName string, collID UniqueID, schema *schemapb.CollectionSchema, rc types.RootCoord, cm storage.ChunkManager) *ChannelMeta { - metaService := newMetaService(rc, collID) +func newChannel(channelName string, collID UniqueID, schema *schemapb.CollectionSchema, broker broker.Broker, cm storage.ChunkManager) *ChannelMeta { + metaService := newMetaService(broker, collID) channel := ChannelMeta{ collectionID: collID, @@ -153,10 +165,12 @@ func newChannel(channelName string, collID UniqueID, schema *schemapb.Collection segments: make(map[UniqueID]*Segment), - needToSync: atomic.NewBool(false), + isHighMemory: atomic.NewBool(false), + flushTs: atomic.NewUint64(math.MaxUint64), syncPolicies: []segmentSyncPolicy{ syncPeriodically(), syncMemoryTooHigh(), + syncSegmentsAtTs(), }, metaService: metaService, @@ -202,7 +216,7 @@ func (c *ChannelMeta) getChannelName(segID UniqueID) string { // addSegment adds the segment to current channel. Segments can be added as *new*, *normal* or *flushed*. // Make sure to verify `channel.hasSegment(segID)` == false before calling `channel.addSegment()`. 
-func (c *ChannelMeta) addSegment(req addSegmentReq) error { +func (c *ChannelMeta) addSegment(ctx context.Context, req addSegmentReq) error { if req.collID != c.collectionID { log.Warn("failed to addSegment, collection mismatch", zap.Int64("current collection ID", req.collID), @@ -231,7 +245,7 @@ func (c *ChannelMeta) addSegment(req addSegmentReq) error { } seg.setType(req.segType) // Set up pk stats - err := c.InitPKstats(context.TODO(), seg, req.statsBinLogs, req.recoverTs) + err := c.InitPKstats(ctx, seg, req.statsBinLogs, req.recoverTs) if err != nil { log.Error("failed to init bloom filter", zap.Int64("segmentID", req.segID), @@ -285,7 +299,7 @@ func (c *ChannelMeta) listSegmentIDsToSync(ts Timestamp) []UniqueID { segIDsToSync := typeutil.NewUniqueSet() for _, policy := range c.syncPolicies { - segments := policy(validSegs, ts, c.needToSync) + segments := policy(validSegs, c, ts) for _, segID := range segments { segIDsToSync.Insert(segID) } @@ -670,7 +684,7 @@ func (c *ChannelMeta) getCollectionSchema(collID UniqueID, ts Timestamp) (*schem return c.collSchema, nil } -func (c *ChannelMeta) mergeFlushedSegments(ctx context.Context, seg *Segment, planID UniqueID, compactedFrom []UniqueID) error { +func (c *ChannelMeta) mergeFlushedSegments(ctx context.Context, seg *Segment, planID UniqueID, compactedFrom []UniqueID) { log := log.Ctx(ctx).With( zap.Int64("segmentID", seg.segmentID), zap.Int64("collectionID", seg.collectionID), @@ -679,20 +693,12 @@ func (c *ChannelMeta) mergeFlushedSegments(ctx context.Context, seg *Segment, pl zap.Int64("planID", planID), zap.String("channelName", c.channelName)) - if seg.collectionID != c.collectionID { - log.Warn("failed to mergeFlushedSegments, collection mismatch", - zap.Int64("current collection ID", seg.collectionID), - zap.Int64("expected collection ID", c.collectionID)) - return merr.WrapErrParameterInvalid(c.collectionID, seg.collectionID, "collection not match") - } - var inValidSegments []UniqueID for _, ID := range 
compactedFrom { // no such segments in channel or the segments are unflushed. if !c.hasSegment(ID, true) || c.hasSegment(ID, false) { inValidSegments = append(inValidSegments, ID) } - } if len(inValidSegments) > 0 { @@ -703,12 +709,6 @@ func (c *ChannelMeta) mergeFlushedSegments(ctx context.Context, seg *Segment, pl log.Info("merge flushed segments") c.segMu.Lock() defer c.segMu.Unlock() - select { - case <-ctx.Done(): - log.Warn("the context has been closed", zap.Error(ctx.Err())) - return errors.New("invalid context") - default: - } for _, ID := range compactedFrom { // the existent of the segments are already checked @@ -725,8 +725,6 @@ func (c *ChannelMeta) mergeFlushedSegments(ctx context.Context, seg *Segment, pl seg.setType(datapb.SegmentType_Flushed) c.segments[seg.segmentID] = seg } - - return nil } // for tests only @@ -803,6 +801,7 @@ func (c *ChannelMeta) listNotFlushedSegmentIDs() []UniqueID { } func (c *ChannelMeta) getChannelCheckpoint(ttPos *msgpb.MsgPosition) *msgpb.MsgPosition { + log := log.With().WithRateGroup("ChannelMeta", 1, 60) c.segMu.RLock() defer c.segMu.RUnlock() channelCP := &msgpb.MsgPosition{Timestamp: math.MaxUint64} @@ -824,8 +823,7 @@ func (c *ChannelMeta) getChannelCheckpoint(ttPos *msgpb.MsgPosition) *msgpb.MsgP channelCP = db.startPos } } - // TODO: maybe too many logs would print - log.Debug("getChannelCheckpoint for segment", zap.Int64("segmentID", seg.segmentID), + log.RatedDebug(10, "getChannelCheckpoint for segment", zap.Int64("segmentID", seg.segmentID), zap.Bool("isCurIBEmpty", seg.curInsertBuf == nil), zap.Bool("isCurDBEmpty", seg.curDeleteBuf == nil), zap.Int("len(hisIB)", len(seg.historyInsertBuf)), @@ -932,8 +930,12 @@ func (c *ChannelMeta) evictHistoryDeleteBuffer(segmentID UniqueID, endPos *msgpb log.Warn("cannot find segment when evictHistoryDeleteBuffer", zap.Int64("segmentID", segmentID)) } -func (c *ChannelMeta) forceToSync() { - c.needToSync.Store(true) +func (c *ChannelMeta) setIsHighMemory(b bool) { + 
c.isHighMemory.Store(b) +} + +func (c *ChannelMeta) getIsHighMemory() bool { + return c.isHighMemory.Load() } func (c *ChannelMeta) getTotalMemorySize() int64 { @@ -946,6 +948,14 @@ func (c *ChannelMeta) getTotalMemorySize() int64 { return res } +func (c *ChannelMeta) getFlushTs() Timestamp { + return c.flushTs.Load() +} + +func (c *ChannelMeta) setFlushTs(ts Timestamp) { + c.flushTs.Store(ts) +} + func (c *ChannelMeta) close() { c.closed.Store(true) } diff --git a/internal/datanode/channel_meta_test.go b/internal/datanode/channel_meta_test.go index 53c028833c744..f954335c626b8 100644 --- a/internal/datanode/channel_meta_test.go +++ b/internal/datanode/channel_meta_test.go @@ -33,22 +33,25 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/datanode/broker" "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metautil" ) var channelMetaNodeTestDir = "/tmp/milvus_test/channel_meta" func TestNewChannel(t *testing.T) { - rc := &RootCoordFactory{} + broker := broker.NewMockBroker(t) cm := storage.NewLocalChunkManager(storage.RootPath(channelMetaNodeTestDir)) defer cm.RemoveWithPrefix(context.Background(), cm.RootPath()) - channel := newChannel("channel", 0, nil, rc, cm) + channel := newChannel("channel", 0, nil, broker, cm) assert.NotNil(t, channel) } @@ -110,17 +113,22 @@ func getSimpleFieldBinlog() *datapb.FieldBinlog { func TestChannelMeta_InnerFunction(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - rc := &RootCoordFactory{ - pkType: schemapb.DataType_Int64, - } var ( + broker = 
broker.NewMockBroker(t) collID = UniqueID(1) cm = storage.NewLocalChunkManager(storage.RootPath(channelMetaNodeTestDir)) - channel = newChannel("insert-01", collID, nil, rc, cm) + channel = newChannel("insert-01", collID, nil, broker, cm) ) defer cm.RemoveWithPrefix(ctx, cm.RootPath()) + meta := NewMetaFactory().GetCollectionMeta(collID, "test_collection", schemapb.DataType_Int64) + broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything). + Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Status(nil), + Schema: meta.GetSchema(), + }, nil) + require.False(t, channel.hasSegment(0, true)) require.False(t, channel.hasSegment(0, false)) @@ -129,6 +137,7 @@ func TestChannelMeta_InnerFunction(t *testing.T) { startPos := &msgpb.MsgPosition{ChannelName: "insert-01", Timestamp: Timestamp(100)} endPos := &msgpb.MsgPosition{ChannelName: "insert-01", Timestamp: Timestamp(200)} err = channel.addSegment( + context.TODO(), addSegmentReq{ segType: datapb.SegmentType_New, segID: 0, @@ -202,7 +211,8 @@ func TestChannelMeta_getCollectionAndPartitionID(t *testing.T) { seg.setType(test.segType) channel := &ChannelMeta{ segments: map[UniqueID]*Segment{ - test.segID: &seg}, + test.segID: &seg, + }, } collID, parID, err := channel.getCollectionAndPartitionID(test.segID) @@ -216,16 +226,15 @@ func TestChannelMeta_getCollectionAndPartitionID(t *testing.T) { func TestChannelMeta_segmentFlushed(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - rc := &RootCoordFactory{ - pkType: schemapb.DataType_Int64, - } + broker := broker.NewMockBroker(t) collID := UniqueID(1) cm := storage.NewLocalChunkManager(storage.RootPath(channelMetaNodeTestDir)) defer cm.RemoveWithPrefix(ctx, cm.RootPath()) t.Run("Test coll mot match", func(t *testing.T) { - channel := newChannel("channel", collID, nil, rc, cm) + channel := newChannel("channel", collID, nil, broker, cm) err := channel.addSegment( + context.TODO(), addSegmentReq{ segType: 
datapb.SegmentType_New, segID: 1, @@ -284,9 +293,8 @@ func TestChannelMeta_segmentFlushed(t *testing.T) { func TestChannelMeta_InterfaceMethod(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - rc := &RootCoordFactory{ - pkType: schemapb.DataType_Int64, - } + broker := broker.NewMockBroker(t) + f := MetaFactory{} cm := storage.NewLocalChunkManager(storage.RootPath(channelMetaNodeTestDir)) defer cm.RemoveWithPrefix(ctx, cm.RootPath()) @@ -307,7 +315,7 @@ func TestChannelMeta_InterfaceMethod(t *testing.T) { } for _, test := range tests { t.Run(test.description, func(t *testing.T) { - channel := newChannel("a", test.channelCollID, nil, rc, cm) + channel := newChannel("a", test.channelCollID, nil, broker, cm) if test.isvalid { channel.addFlushedSegmentWithPKs(100, test.incollID, 10, 1, primaryKeyData) @@ -342,9 +350,17 @@ func TestChannelMeta_InterfaceMethod(t *testing.T) { for _, test := range tests { t.Run(test.description, func(t *testing.T) { - channel := newChannel("a", test.channelCollID, nil, rc, cm) + broker.ExpectedCalls = nil + meta := NewMetaFactory().GetCollectionMeta(test.channelCollID, "test_collection", schemapb.DataType_Int64) + broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything). 
+ Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Status(nil), + Schema: meta.GetSchema(), + }, nil) + channel := newChannel("a", test.channelCollID, nil, broker, cm) require.False(t, channel.hasSegment(test.inSegID, true)) err := channel.addSegment( + context.TODO(), addSegmentReq{ segType: datapb.SegmentType_New, segID: test.inSegID, @@ -385,9 +401,10 @@ func TestChannelMeta_InterfaceMethod(t *testing.T) { for _, test := range tests { t.Run(test.description, func(t *testing.T) { - channel := newChannel("a", test.channelCollID, nil, rc, &mockDataCM{}) + channel := newChannel("a", test.channelCollID, nil, broker, &mockDataCM{}) require.False(t, channel.hasSegment(test.inSegID, true)) err := channel.addSegment( + context.TODO(), addSegmentReq{ segType: datapb.SegmentType_Normal, segID: test.inSegID, @@ -413,11 +430,12 @@ func TestChannelMeta_InterfaceMethod(t *testing.T) { }) t.Run("Test_addNormalSegmentWithNilDml", func(t *testing.T) { - channel := newChannel("a", 1, nil, rc, &mockDataCM{}) + channel := newChannel("a", 1, nil, broker, &mockDataCM{}) segID := int64(101) require.False(t, channel.hasSegment(segID, true)) assert.NotPanics(t, func() { err := channel.addSegment( + context.TODO(), addSegmentReq{ segType: datapb.SegmentType_Normal, segID: segID, @@ -447,13 +465,23 @@ func TestChannelMeta_InterfaceMethod(t *testing.T) { for _, test := range tests { t.Run(test.description, func(t *testing.T) { - channel := newChannel("a", test.channelCollID, nil, rc, cm) + channel := newChannel("a", test.channelCollID, nil, broker, cm) if test.metaServiceErr { channel.collSchema = nil - rc.setCollectionID(-1) + broker.ExpectedCalls = nil + broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything). + Return(nil, errors.New("mock")) } else { - rc.setCollectionID(1) + meta := f.GetCollectionMeta(test.channelCollID, "test_collection", schemapb.DataType_Int64) + broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything). 
+ Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Status(nil), + CollectionID: test.channelCollID, + CollectionName: "test_collection", + ShardsNum: common.DefaultShardsNum, + Schema: meta.GetSchema(), + }, nil) } s, err := channel.getCollectionSchema(test.inputCollID, Timestamp(0)) @@ -466,7 +494,7 @@ func TestChannelMeta_InterfaceMethod(t *testing.T) { } }) } - rc.setCollectionID(1) + broker.ExpectedCalls = nil }) t.Run("Test listAllSegmentIDs", func(t *testing.T) { @@ -519,10 +547,18 @@ func TestChannelMeta_InterfaceMethod(t *testing.T) { }) t.Run("Test_addSegmentMinIOLoadError", func(t *testing.T) { - channel := newChannel("a", 1, nil, rc, cm) + broker.ExpectedCalls = nil + meta := NewMetaFactory().GetCollectionMeta(1, "test_collection", schemapb.DataType_Int64) + broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything). + Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Status(nil), + Schema: meta.GetSchema(), + }, nil) + channel := newChannel("a", 1, nil, broker, cm) channel.chunkManager = &mockDataCMError{} err := channel.addSegment( + context.TODO(), addSegmentReq{ segType: datapb.SegmentType_Normal, segID: 1, @@ -534,6 +570,7 @@ func TestChannelMeta_InterfaceMethod(t *testing.T) { }) assert.Error(t, err) err = channel.addSegment( + context.TODO(), addSegmentReq{ segType: datapb.SegmentType_Flushed, segID: 1, @@ -547,11 +584,12 @@ func TestChannelMeta_InterfaceMethod(t *testing.T) { }) t.Run("Test_addSegmentStatsError", func(t *testing.T) { - channel := newChannel("insert-01", 1, nil, rc, cm) + channel := newChannel("insert-01", 1, nil, broker, cm) channel.chunkManager = &mockDataCMStatsError{} var err error err = channel.addSegment( + context.TODO(), addSegmentReq{ segType: datapb.SegmentType_Normal, segID: 1, @@ -563,6 +601,7 @@ func TestChannelMeta_InterfaceMethod(t *testing.T) { }) assert.Error(t, err) err = channel.addSegment( + context.TODO(), addSegmentReq{ segType: datapb.SegmentType_Flushed, segID: 1, 
@@ -576,11 +615,12 @@ func TestChannelMeta_InterfaceMethod(t *testing.T) { }) t.Run("Test_addSegmentPkfilterError", func(t *testing.T) { - channel := newChannel("insert-01", 1, nil, rc, cm) + channel := newChannel("insert-01", 1, nil, broker, cm) channel.chunkManager = &mockPkfilterMergeError{} var err error err = channel.addSegment( + context.TODO(), addSegmentReq{ segType: datapb.SegmentType_Normal, segID: 1, @@ -592,6 +632,7 @@ func TestChannelMeta_InterfaceMethod(t *testing.T) { }) assert.Error(t, err) err = channel.addSegment( + context.TODO(), addSegmentReq{ segType: datapb.SegmentType_Flushed, segID: 1, @@ -605,7 +646,7 @@ func TestChannelMeta_InterfaceMethod(t *testing.T) { }) t.Run("Test_mergeFlushedSegments", func(t *testing.T) { - channel := newChannel("channel", 1, nil, rc, cm) + channel := newChannel("channel", 1, nil, broker, cm) primaryKeyData := &storage.Int64FieldData{ Data: []UniqueID{1}, @@ -658,12 +699,14 @@ func TestChannelMeta_InterfaceMethod(t *testing.T) { if !channel.hasSegment(4, false) { channel.removeSegments(4) - channel.addSegment(addSegmentReq{ - segType: datapb.SegmentType_Normal, - segID: 4, - collID: 1, - partitionID: 0, - }) + channel.addSegment( + context.TODO(), + addSegmentReq{ + segType: datapb.SegmentType_Normal, + segID: 4, + collID: 1, + partitionID: 0, + }) } if channel.hasSegment(3, true) { @@ -676,12 +719,7 @@ func TestChannelMeta_InterfaceMethod(t *testing.T) { require.False(t, channel.hasSegment(3, true)) // tests start - err := channel.mergeFlushedSegments(context.Background(), test.inSeg, 100, test.inCompactedFrom) - if test.isValid { - assert.NoError(t, err) - } else { - assert.Error(t, err) - } + channel.mergeFlushedSegments(context.Background(), test.inSeg, 100, test.inCompactedFrom) if test.stored { assert.True(t, channel.hasSegment(3, true)) @@ -695,18 +733,14 @@ func TestChannelMeta_InterfaceMethod(t *testing.T) { } else { assert.False(t, channel.hasSegment(3, true)) } - }) } }) - } func 
TestChannelMeta_loadStats(t *testing.T) { f := &MetaFactory{} - rc := &RootCoordFactory{ - pkType: schemapb.DataType_Int64, - } + broker := broker.NewMockBroker(t) t.Run("list with merged stats log", func(t *testing.T) { meta := f.GetCollectionMeta(UniqueID(10001), "test_load_stats", schemapb.DataType_Int64) @@ -722,13 +756,13 @@ func TestChannelMeta_loadStats(t *testing.T) { partitionID: 2, } - //gen pk stats bytes + // gen pk stats bytes stats := storage.NewPrimaryKeyStats(106, int64(schemapb.DataType_Int64), 10) iCodec := storage.NewInsertCodecWithSchema(meta) cm := &mockCm{} - channel := newChannel("channel", 1, meta.Schema, rc, cm) + channel := newChannel("channel", 1, meta.Schema, broker, cm) channel.segments[seg1.segmentID] = seg1 channel.segments[seg2.segmentID] = seg2 @@ -744,7 +778,8 @@ func TestChannelMeta_loadStats(t *testing.T) { Binlogs: []*datapb.Binlog{{ /////// LogPath: path.Join(common.SegmentStatslogPath, metautil.JoinIDPath(1, 2, 1, 106, 10)), - }}}}, 0) + }}, + }}, 0) assert.NoError(t, err) // load flushed stats log @@ -759,7 +794,8 @@ func TestChannelMeta_loadStats(t *testing.T) { Binlogs: []*datapb.Binlog{{ /////// LogPath: path.Join(common.SegmentStatslogPath, metautil.JoinIDPath(1, 2, 2, 106), storage.CompoundStatsType.LogIdx()), - }}}}, 0) + }}, + }}, 0) assert.NoError(t, err) }) } @@ -767,21 +803,28 @@ func TestChannelMeta_loadStats(t *testing.T) { func TestChannelMeta_UpdatePKRange(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - rc := &RootCoordFactory{ - pkType: schemapb.DataType_Int64, - } + + broker := broker.NewMockBroker(t) collID := UniqueID(1) partID := UniqueID(2) chanName := "insert-02" startPos := &msgpb.MsgPosition{ChannelName: chanName, Timestamp: Timestamp(100)} endPos := &msgpb.MsgPosition{ChannelName: chanName, Timestamp: Timestamp(200)} + meta := NewMetaFactory().GetCollectionMeta(collID, "test_collection", schemapb.DataType_Int64) + 
broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything). + Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Status(nil), + Schema: meta.GetSchema(), + }, nil) + cm := storage.NewLocalChunkManager(storage.RootPath(channelMetaNodeTestDir)) defer cm.RemoveWithPrefix(ctx, cm.RootPath()) - channel := newChannel("chanName", collID, nil, rc, cm) + channel := newChannel("chanName", collID, nil, broker, cm) channel.chunkManager = &mockDataCM{} err := channel.addSegment( + context.TODO(), addSegmentReq{ segType: datapb.SegmentType_New, segID: 1, @@ -792,6 +835,7 @@ func TestChannelMeta_UpdatePKRange(t *testing.T) { }) assert.NoError(t, err) err = channel.addSegment( + context.TODO(), addSegmentReq{ segType: datapb.SegmentType_Normal, segID: 2, @@ -820,15 +864,13 @@ func TestChannelMeta_UpdatePKRange(t *testing.T) { assert.True(t, segNew.isPKExist(pk)) assert.True(t, segNormal.isPKExist(pk)) } - } func TestChannelMeta_ChannelCP(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - rc := &RootCoordFactory{ - pkType: schemapb.DataType_Int64, - } + + broker := broker.NewMockBroker(t) mockVChannel := "fake-by-dev-rootcoord-dml-1-testchannelcp-v0" mockPChannel := "fake-by-dev-rootcoord-dml-1" @@ -839,13 +881,19 @@ func TestChannelMeta_ChannelCP(t *testing.T) { err := cm.RemoveWithPrefix(ctx, cm.RootPath()) assert.NoError(t, err) }() + meta := NewMetaFactory().GetCollectionMeta(collID, "test_collection", schemapb.DataType_Int64) + broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything). 
+ Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Status(nil), + Schema: meta.GetSchema(), + }, nil) t.Run("get and set", func(t *testing.T) { pos := &msgpb.MsgPosition{ ChannelName: mockPChannel, Timestamp: 1000, } - channel := newChannel(mockVChannel, collID, nil, rc, cm) + channel := newChannel(mockVChannel, collID, nil, broker, cm) channel.chunkManager = &mockDataCM{} position := channel.getChannelCheckpoint(pos) assert.NotNil(t, position) @@ -856,11 +904,13 @@ func TestChannelMeta_ChannelCP(t *testing.T) { t.Run("set insertBuffer&deleteBuffer then get", func(t *testing.T) { run := func(curInsertPos, curDeletePos *msgpb.MsgPosition, hisInsertPoss, hisDeletePoss []*msgpb.MsgPosition, - ttPos, expectedPos *msgpb.MsgPosition) { + ttPos, expectedPos *msgpb.MsgPosition, + ) { segmentID := UniqueID(1) - channel := newChannel(mockVChannel, collID, nil, rc, cm) + channel := newChannel(mockVChannel, collID, nil, broker, cm) channel.chunkManager = &mockDataCM{} err := channel.addSegment( + context.TODO(), addSegmentReq{ segType: datapb.SegmentType_New, segID: segmentID, @@ -940,12 +990,17 @@ type ChannelMetaSuite struct { } func (s *ChannelMetaSuite) SetupSuite() { - rc := &RootCoordFactory{ - pkType: schemapb.DataType_Int64, - } + broker := broker.NewMockBroker(s.T()) + f := MetaFactory{} + meta := f.GetCollectionMeta(1, "testCollection", schemapb.DataType_Int64) + broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything). 
+ Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Status(nil), + Schema: meta.GetSchema(), + }, nil).Maybe() s.collID = 1 s.cm = storage.NewLocalChunkManager(storage.RootPath(channelMetaNodeTestDir)) - s.channel = newChannel("channel", s.collID, nil, rc, s.cm) + s.channel = newChannel("channel", s.collID, nil, broker, s.cm) s.vchanName = "channel" } @@ -955,34 +1010,40 @@ func (s *ChannelMetaSuite) TearDownSuite() { func (s *ChannelMetaSuite) SetupTest() { var err error - err = s.channel.addSegment(addSegmentReq{ - segType: datapb.SegmentType_New, - segID: 1, - collID: s.collID, - partitionID: s.partID, - startPos: &msgpb.MsgPosition{}, - endPos: nil, - }) + err = s.channel.addSegment( + context.TODO(), + addSegmentReq{ + segType: datapb.SegmentType_New, + segID: 1, + collID: s.collID, + partitionID: s.partID, + startPos: &msgpb.MsgPosition{}, + endPos: nil, + }) s.Require().NoError(err) - err = s.channel.addSegment(addSegmentReq{ - segType: datapb.SegmentType_Normal, - segID: 2, - collID: s.collID, - partitionID: s.partID, - numOfRows: 10, - statsBinLogs: nil, - recoverTs: 0, - }) + err = s.channel.addSegment( + context.TODO(), + addSegmentReq{ + segType: datapb.SegmentType_Normal, + segID: 2, + collID: s.collID, + partitionID: s.partID, + numOfRows: 10, + statsBinLogs: nil, + recoverTs: 0, + }) s.Require().NoError(err) - err = s.channel.addSegment(addSegmentReq{ - segType: datapb.SegmentType_Flushed, - segID: 3, - collID: s.collID, - partitionID: s.partID, - numOfRows: 10, - statsBinLogs: nil, - recoverTs: 0, - }) + err = s.channel.addSegment( + context.TODO(), + addSegmentReq{ + segType: datapb.SegmentType_Flushed, + segID: 3, + collID: s.collID, + partitionID: s.partID, + numOfRows: 10, + statsBinLogs: nil, + recoverTs: 0, + }) s.Require().NoError(err) } @@ -1057,13 +1118,18 @@ type ChannelMetaMockSuite struct { } func (s *ChannelMetaMockSuite) SetupTest() { - rc := &RootCoordFactory{ - pkType: schemapb.DataType_Int64, - } + broker := 
broker.NewMockBroker(s.T()) + f := MetaFactory{} + meta := f.GetCollectionMeta(1, "testCollection", schemapb.DataType_Int64) + broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything). + Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Status(nil), + Schema: meta.GetSchema(), + }, nil).Maybe() s.cm = mocks.NewChunkManager(s.T()) s.collID = 1 - s.channel = newChannel("channel", s.collID, nil, rc, s.cm) + s.channel = newChannel("channel", s.collID, nil, broker, s.cm) s.vchanName = "channel" } @@ -1081,20 +1147,22 @@ func (s *ChannelMetaMockSuite) TestAddSegment_SkipBFLoad() { <-ch }).Return([][]byte{}, nil) - err := s.channel.addSegment(addSegmentReq{ - segType: datapb.SegmentType_Flushed, - segID: 100, - collID: s.collID, - partitionID: s.partID, - statsBinLogs: []*datapb.FieldBinlog{ - { - FieldID: 106, - Binlogs: []*datapb.Binlog{ - {LogPath: "rootPath/stats/1/0/100/10001"}, + err := s.channel.addSegment( + context.TODO(), + addSegmentReq{ + segType: datapb.SegmentType_Flushed, + segID: 100, + collID: s.collID, + partitionID: s.partID, + statsBinLogs: []*datapb.FieldBinlog{ + { + FieldID: 106, + Binlogs: []*datapb.Binlog{ + {LogPath: "rootPath/stats/1/0/100/10001"}, + }, }, }, - }, - }) + }) s.NoError(err) @@ -1117,20 +1185,22 @@ func (s *ChannelMetaMockSuite) TestAddSegment_SkipBFLoad() { <-ch }).Return(nil, storage.WrapErrNoSuchKey("rootPath/stats/1/0/100/10001")) - err := s.channel.addSegment(addSegmentReq{ - segType: datapb.SegmentType_Flushed, - segID: 100, - collID: s.collID, - partitionID: s.partID, - statsBinLogs: []*datapb.FieldBinlog{ - { - FieldID: 106, - Binlogs: []*datapb.Binlog{ - {LogPath: "rootPath/stats/1/0/100/10001"}, + err := s.channel.addSegment( + context.TODO(), + addSegmentReq{ + segType: datapb.SegmentType_Flushed, + segID: 100, + collID: s.collID, + partitionID: s.partID, + statsBinLogs: []*datapb.FieldBinlog{ + { + FieldID: 106, + Binlogs: []*datapb.Binlog{ + {LogPath: "rootPath/stats/1/0/100/10001"}, + 
}, }, }, - }, - }) + }) s.NoError(err) @@ -1163,20 +1233,22 @@ func (s *ChannelMetaMockSuite) TestAddSegment_SkipBFLoad() { []byte("ABC"), }, nil) - err := s.channel.addSegment(addSegmentReq{ - segType: datapb.SegmentType_Flushed, - segID: 100, - collID: s.collID, - partitionID: s.partID, - statsBinLogs: []*datapb.FieldBinlog{ - { - FieldID: 106, - Binlogs: []*datapb.Binlog{ - {LogPath: "rootPath/stats/1/0/100/10001"}, + err := s.channel.addSegment( + context.TODO(), + addSegmentReq{ + segType: datapb.SegmentType_Flushed, + segID: 100, + collID: s.collID, + partitionID: s.partID, + statsBinLogs: []*datapb.FieldBinlog{ + { + FieldID: 106, + Binlogs: []*datapb.Binlog{ + {LogPath: "rootPath/stats/1/0/100/10001"}, + }, }, }, - }, - }) + }) s.NoError(err) @@ -1229,20 +1301,22 @@ func (s *ChannelMetaMockSuite) TestAddSegment_SkipBFLoad2() { return nil }) - err := s.channel.addSegment(addSegmentReq{ - segType: datapb.SegmentType_Flushed, - segID: 100, - collID: s.collID, - partitionID: s.partID, - statsBinLogs: []*datapb.FieldBinlog{ - { - FieldID: 106, - Binlogs: []*datapb.Binlog{ - {LogPath: "rootPath/stats/1/0/100/10001"}, + err := s.channel.addSegment( + context.TODO(), + addSegmentReq{ + segType: datapb.SegmentType_Flushed, + segID: 100, + collID: s.collID, + partitionID: s.partID, + statsBinLogs: []*datapb.FieldBinlog{ + { + FieldID: 106, + Binlogs: []*datapb.Binlog{ + {LogPath: "rootPath/stats/1/0/100/10001"}, + }, }, }, - }, - }) + }) s.NoError(err) diff --git a/internal/datanode/compaction_executor_test.go b/internal/datanode/compaction_executor_test.go index b3b92a7702aa4..107eddcd16091 100644 --- a/internal/datanode/compaction_executor_test.go +++ b/internal/datanode/compaction_executor_test.go @@ -20,8 +20,9 @@ import ( "context" "testing" - "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/internal/proto/datapb" ) func TestCompactionExecutor(t *testing.T) { @@ -114,7 +115,6 @@ func 
TestCompactionExecutor(t *testing.T) { t.FailNow() } }) - } func newMockCompactor(isvalid bool) *mockCompactor { @@ -143,7 +143,6 @@ func (mc *mockCompactor) complete() { } func (mc *mockCompactor) injectDone(success bool) { - } func (mc *mockCompactor) compact() (*datapb.CompactionResult, error) { diff --git a/internal/datanode/compactor.go b/internal/datanode/compactor.go index ec2db90246c87..8e265e470348e 100644 --- a/internal/datanode/compactor.go +++ b/internal/datanode/compactor.go @@ -22,16 +22,10 @@ import ( "path" "strconv" "strings" - "sync" "time" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus/pkg/common" - "github.com/milvus-io/milvus/pkg/util/merr" - "github.com/milvus-io/milvus/pkg/util/metautil" - "github.com/milvus-io/milvus/pkg/util/tsoutil" "go.uber.org/zap" - "golang.org/x/sync/errgroup" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" milvus_storage "github.com/milvus-io/milvus-storage/go/storage" @@ -42,8 +36,11 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/metautil" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/timerecord" + "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -91,9 +88,6 @@ type compactionTask struct { space *milvus_storage.Space } -// check if compactionTask implements compactor -var _ compactor = (*compactionTask)(nil) - func newCompactionTask( ctx context.Context, dl downloader, @@ -102,8 +96,8 @@ func newCompactionTask( fm flushManager, alloc allocator.Allocator, plan *datapb.CompactionPlan, - chunkManager storage.ChunkManager) *compactionTask { - + chunkManager storage.ChunkManager, +) *compactionTask { ctx1, cancel := context.WithCancel(ctx) return &compactionTask{ ctx: ctx1, @@ -158,7 +152,7 @@ func (t *compactionTask) 
mergeDeltalogs(dBlobs map[UniqueID][]*Blob) (map[interf mergeStart := time.Now() dCodec := storage.NewDeleteCodec() - var pk2ts = make(map[interface{}]Timestamp) + pk2ts := make(map[interface{}]Timestamp) for _, blobs := range dBlobs { _, _, dData, err := dCodec.Deserialize(blobs) @@ -190,7 +184,8 @@ func (t *compactionTask) uploadRemainLog( stats *storage.PrimaryKeyStats, totRows int64, fID2Content map[UniqueID][]interface{}, - fID2Type map[UniqueID]schemapb.DataType) (map[UniqueID]*datapb.FieldBinlog, map[UniqueID]*datapb.FieldBinlog, error) { + fID2Type map[UniqueID]schemapb.DataType, +) (map[UniqueID]*datapb.FieldBinlog, map[UniqueID]*datapb.FieldBinlog, error) { var iData *InsertData // remain insert data @@ -226,9 +221,11 @@ func (t *compactionTask) uploadSingleInsertLog( partID UniqueID, meta *etcdpb.CollectionMeta, fID2Content map[UniqueID][]interface{}, - fID2Type map[UniqueID]schemapb.DataType) (map[UniqueID]*datapb.FieldBinlog, error) { + fID2Type map[UniqueID]schemapb.DataType, +) (map[UniqueID]*datapb.FieldBinlog, error) { iData := &InsertData{ - Data: make(map[storage.FieldID]storage.FieldData)} + Data: make(map[storage.FieldID]storage.FieldData), + } for fID, content := range fID2Content { tp, ok := fID2Type[fID] @@ -259,7 +256,8 @@ func (t *compactionTask) merge( targetSegID UniqueID, partID UniqueID, meta *etcdpb.CollectionMeta, - delta map[interface{}]Timestamp) ([]*datapb.FieldBinlog, []*datapb.FieldBinlog, int64, error) { + delta map[interface{}]Timestamp, +) ([]*datapb.FieldBinlog, []*datapb.FieldBinlog, int64, error) { log := log.With(zap.Int64("planID", t.getPlanID())) mergeStart := time.Now() @@ -422,7 +420,7 @@ func (t *compactionTask) merge( } fID2Content[fID] = append(fID2Content[fID], vInter) } - //update pk to new stats log + // update pk to new stats log stats.Update(v.PK) currentRows++ @@ -496,7 +494,6 @@ func (t *compactionTask) compact() (*datapb.CompactionResult, error) { var targetSegID UniqueID var err error switch { - case 
t.plan.GetType() == datapb.CompactionType_UndefinedCompaction: log.Warn("compact wrong, compaction type undefined") return nil, errCompactionTypeUndifined @@ -630,18 +627,11 @@ func (t *compactionTask) compact() (*datapb.CompactionResult, error) { <-ti.Injected() log.Info("compact inject elapse", zap.Duration("elapse", time.Since(injectStart))) - var ( - // SegmentID to deltaBlobs - dblobs = make(map[UniqueID][]*Blob) - dmu sync.Mutex - ) - + dblobs := make(map[UniqueID][]*Blob) allPath := make([][]string, 0) downloadStart := time.Now() - g, gCtx := errgroup.WithContext(ctxTimeout) for _, s := range t.plan.GetSegmentBinlogs() { - // Get the number of field binlog files from non-empty segment var binlogNum int for _, b := range s.GetFieldBinlogs() { @@ -665,27 +655,24 @@ func (t *compactionTask) compact() (*datapb.CompactionResult, error) { } segID := s.GetSegmentID() + paths := make([]string, 0) for _, d := range s.GetDeltalogs() { for _, l := range d.GetBinlogs() { path := l.GetLogPath() - g.Go(func() error { - bs, err := t.download(gCtx, []string{path}) - if err != nil { - log.Warn("compact download deltalogs wrong", zap.String("path", path), zap.Error(err)) - return err - } - - dmu.Lock() - dblobs[segID] = append(dblobs[segID], bs...) - dmu.Unlock() + paths = append(paths, path) + } + } - return nil - }) + if len(paths) != 0 { + bs, err := t.download(ctxTimeout, paths) + if err != nil { + log.Warn("compact download deltalogs wrong", zap.Int64("segment", segID), zap.Strings("path", paths), zap.Error(err)) + return nil, err } + dblobs[segID] = append(dblobs[segID], bs...) 
} } - err = g.Wait() log.Info("compact download deltalogs elapse", zap.Duration("elapse", time.Since(downloadStart))) if err != nil { @@ -724,7 +711,7 @@ func (t *compactionTask) compact() (*datapb.CompactionResult, error) { ) log.Info("compact overall elapse", zap.Duration("elapse", time.Since(compactStart))) - metrics.DataNodeCompactionLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(t.tr.ElapseSpan().Seconds()) + metrics.DataNodeCompactionLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(t.tr.ElapseSpan().Milliseconds())) metrics.DataNodeCompactionLatencyInQueue.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(durInQueue.Milliseconds())) return pack, nil @@ -743,7 +730,7 @@ func interface2FieldData(schemaDataType schemapb.DataType, content []interface{} var rst storage.FieldData switch schemaDataType { case schemapb.DataType_Bool: - var data = &storage.BoolFieldData{ + data := &storage.BoolFieldData{ Data: make([]bool, 0, len(content)), } @@ -757,7 +744,7 @@ func interface2FieldData(schemaDataType schemapb.DataType, content []interface{} rst = data case schemapb.DataType_Int8: - var data = &storage.Int8FieldData{ + data := &storage.Int8FieldData{ Data: make([]int8, 0, len(content)), } @@ -771,7 +758,7 @@ func interface2FieldData(schemaDataType schemapb.DataType, content []interface{} rst = data case schemapb.DataType_Int16: - var data = &storage.Int16FieldData{ + data := &storage.Int16FieldData{ Data: make([]int16, 0, len(content)), } @@ -785,7 +772,7 @@ func interface2FieldData(schemaDataType schemapb.DataType, content []interface{} rst = data case schemapb.DataType_Int32: - var data = &storage.Int32FieldData{ + data := &storage.Int32FieldData{ Data: make([]int32, 0, len(content)), } @@ -799,7 +786,7 @@ func interface2FieldData(schemaDataType schemapb.DataType, content []interface{} rst = data case schemapb.DataType_Int64: - var data = &storage.Int64FieldData{ + data := &storage.Int64FieldData{ 
Data: make([]int64, 0, len(content)), } @@ -813,7 +800,7 @@ func interface2FieldData(schemaDataType schemapb.DataType, content []interface{} rst = data case schemapb.DataType_Float: - var data = &storage.FloatFieldData{ + data := &storage.FloatFieldData{ Data: make([]float32, 0, len(content)), } @@ -827,7 +814,7 @@ func interface2FieldData(schemaDataType schemapb.DataType, content []interface{} rst = data case schemapb.DataType_Double: - var data = &storage.DoubleFieldData{ + data := &storage.DoubleFieldData{ Data: make([]float64, 0, len(content)), } @@ -841,7 +828,7 @@ func interface2FieldData(schemaDataType schemapb.DataType, content []interface{} rst = data case schemapb.DataType_String, schemapb.DataType_VarChar: - var data = &storage.StringFieldData{ + data := &storage.StringFieldData{ Data: make([]string, 0, len(content)), } @@ -855,7 +842,7 @@ func interface2FieldData(schemaDataType schemapb.DataType, content []interface{} rst = data case schemapb.DataType_JSON: - var data = &storage.JSONFieldData{ + data := &storage.JSONFieldData{ Data: make([][]byte, 0, len(content)), } @@ -869,7 +856,7 @@ func interface2FieldData(schemaDataType schemapb.DataType, content []interface{} rst = data case schemapb.DataType_FloatVector: - var data = &storage.FloatVectorFieldData{ + data := &storage.FloatVectorFieldData{ Data: []float32{}, } @@ -884,8 +871,24 @@ func interface2FieldData(schemaDataType schemapb.DataType, content []interface{} data.Dim = len(data.Data) / int(numRows) rst = data + case schemapb.DataType_Float16Vector: + data := &storage.Float16VectorFieldData{ + Data: []byte{}, + } + + for _, c := range content { + r, ok := c.([]byte) + if !ok { + return nil, errTransferType + } + data.Data = append(data.Data, r...) 
+ } + + data.Dim = len(data.Data) / 2 / int(numRows) + rst = data + case schemapb.DataType_BinaryVector: - var data = &storage.BinaryVectorFieldData{ + data := &storage.BinaryVectorFieldData{ Data: []byte{}, } diff --git a/internal/datanode/compactor_test.go b/internal/datanode/compactor_test.go index 6aaf1ec45fa36..06d6833801391 100644 --- a/internal/datanode/compactor_test.go +++ b/internal/datanode/compactor_test.go @@ -24,6 +24,7 @@ import ( "testing" "time" + "github.com/cockroachdb/errors" "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -34,12 +35,13 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/datanode/allocator" + "github.com/milvus-io/milvus/internal/datanode/broker" memkv "github.com/milvus-io/milvus/internal/kv/mem" - "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/etcdpb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/timerecord" ) @@ -52,10 +54,18 @@ func TestCompactionTaskInnerMethods(t *testing.T) { cm := storage.NewLocalChunkManager(storage.RootPath(compactTestDir)) defer cm.RemoveWithPrefix(ctx, cm.RootPath()) t.Run("Test getSegmentMeta", func(t *testing.T) { - rc := &RootCoordFactory{ - pkType: schemapb.DataType_Int64, - } - channel := newChannel("a", 1, nil, rc, cm) + f := MetaFactory{} + meta := f.GetCollectionMeta(1, "testCollection", schemapb.DataType_Int64) + broker := broker.NewMockBroker(t) + broker.EXPECT().DescribeCollection(mock.Anything, int64(1), mock.Anything). 
+ Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Status(nil), + CollectionID: 1, + CollectionName: "testCollection", + Schema: meta.GetSchema(), + ShardsNum: common.DefaultShardsNum, + }, nil) + channel := newChannel("a", 1, nil, broker, cm) var err error task := &compactionTask{ @@ -66,14 +76,16 @@ func TestCompactionTaskInnerMethods(t *testing.T) { _, _, _, err = task.getSegmentMeta(100) assert.Error(t, err) - err = channel.addSegment(addSegmentReq{ - segType: datapb.SegmentType_New, - segID: 100, - collID: 1, - partitionID: 10, - startPos: new(msgpb.MsgPosition), - endPos: nil, - }) + err = channel.addSegment( + context.TODO(), + addSegmentReq{ + segType: datapb.SegmentType_New, + segID: 100, + collID: 1, + partitionID: 10, + startPos: new(msgpb.MsgPosition), + endPos: nil, + }) require.NoError(t, err) collID, partID, meta, err := task.getSegmentMeta(100) @@ -82,7 +94,9 @@ func TestCompactionTaskInnerMethods(t *testing.T) { assert.Equal(t, UniqueID(10), partID) assert.NotNil(t, meta) - rc.setCollectionID(-2) + broker.ExpectedCalls = nil + broker.EXPECT().DescribeCollection(mock.Anything, int64(1), mock.Anything). 
+ Return(nil, errors.New("mock")) task.Channel.(*ChannelMeta).collSchema = nil _, _, _, err = task.getSegmentMeta(100) assert.Error(t, err) @@ -108,6 +122,7 @@ func TestCompactionTaskInnerMethods(t *testing.T) { {true, schemapb.DataType_JSON, []interface{}{[]byte("{\"key\":\"value\"}"), []byte("{\"hello\":\"world\"}")}, "valid json"}, {true, schemapb.DataType_FloatVector, []interface{}{[]float32{1.0, 2.0}}, "valid floatvector"}, {true, schemapb.DataType_BinaryVector, []interface{}{[]byte{255}}, "valid binaryvector"}, + {true, schemapb.DataType_Float16Vector, []interface{}{[]byte{255, 255, 255, 255}}, "valid float16vector"}, {false, schemapb.DataType_Bool, []interface{}{1, 2}, "invalid bool"}, {false, schemapb.DataType_Int8, []interface{}{nil, nil}, "invalid int8"}, {false, schemapb.DataType_Int16, []interface{}{nil, nil}, "invalid int16"}, @@ -119,6 +134,7 @@ func TestCompactionTaskInnerMethods(t *testing.T) { {false, schemapb.DataType_JSON, []interface{}{nil, nil}, "invalid json"}, {false, schemapb.DataType_FloatVector, []interface{}{nil, nil}, "invalid floatvector"}, {false, schemapb.DataType_BinaryVector, []interface{}{nil, nil}, "invalid binaryvector"}, + {false, schemapb.DataType_Float16Vector, []interface{}{nil, nil}, "invalid float16vector"}, {false, schemapb.DataType_None, nil, "invalid data type"}, } @@ -135,7 +151,6 @@ func TestCompactionTaskInnerMethods(t *testing.T) { } }) } - }) t.Run("Test mergeDeltalogs", func(t *testing.T) { @@ -215,14 +230,24 @@ func TestCompactionTaskInnerMethods(t *testing.T) { }{ { 0, nil, nil, - 100, []UniqueID{1, 2, 3}, []Timestamp{20000, 30000, 20005}, - 200, []UniqueID{4, 5, 6}, []Timestamp{50000, 50001, 50002}, + 100, + []UniqueID{1, 2, 3}, + []Timestamp{20000, 30000, 20005}, + 200, + []UniqueID{4, 5, 6}, + []Timestamp{50000, 50001, 50002}, 6, "2 segments", }, { - 300, []UniqueID{10, 20}, []Timestamp{20001, 40001}, - 100, []UniqueID{1, 2, 3}, []Timestamp{20000, 30000, 20005}, - 200, []UniqueID{4, 5, 6}, []Timestamp{50000, 
50001, 50002}, + 300, + []UniqueID{10, 20}, + []Timestamp{20001, 40001}, + 100, + []UniqueID{1, 2, 3}, + []Timestamp{20000, 30000, 20005}, + 200, + []UniqueID{4, 5, 6}, + []Timestamp{50000, 50001, 50002}, 8, "3 segments", }, } @@ -255,26 +280,25 @@ func TestCompactionTaskInnerMethods(t *testing.T) { }) } }) - }) t.Run("Test merge", func(t *testing.T) { collectionID := int64(1) meta := NewMetaFactory().GetCollectionMeta(collectionID, "test", schemapb.DataType_Int64) - rc := &mocks.RootCoord{} - rc.EXPECT().DescribeCollection(mock.Anything, mock.Anything). + broker := broker.NewMockBroker(t) + broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything). Return(&milvuspb.DescribeCollectionResponse{ Schema: meta.GetSchema(), - }, nil) - channel := newChannel("a", collectionID, meta.GetSchema(), rc, nil) + }, nil).Maybe() + + channel := newChannel("a", collectionID, meta.GetSchema(), broker, nil) channel.segments[1] = &Segment{numRows: 10} alloc := allocator.NewMockAllocator(t) alloc.EXPECT().GetGenerator(mock.Anything, mock.Anything).Call.Return(validGeneratorFn, nil) alloc.EXPECT().AllocOne().Return(0, nil) t.Run("Merge without expiration", func(t *testing.T) { - mockbIO := &binlogIO{cm, alloc} paramtable.Get().Save(Params.CommonCfg.EntityExpirationTTL.Key, "0") iData := genInsertDataWithExpiredTS() @@ -302,8 +326,10 @@ func TestCompactionTaskInnerMethods(t *testing.T) { Channel: channel, downloader: mockbIO, uploader: mockbIO, done: make(chan struct{}, 1), plan: &datapb.CompactionPlan{ SegmentBinlogs: []*datapb.CompactionSegmentBinlogs{ - {SegmentID: 1}}, - }} + {SegmentID: 1}, + }, + }, + } inPaths, statsPaths, numOfRow, err := ct.merge(context.Background(), allPaths, 2, 0, meta, dm) assert.NoError(t, err) assert.Equal(t, int64(2), numOfRow) @@ -344,8 +370,10 @@ func TestCompactionTaskInnerMethods(t *testing.T) { Channel: channel, downloader: mockbIO, uploader: mockbIO, done: make(chan struct{}, 1), plan: &datapb.CompactionPlan{ SegmentBinlogs: 
[]*datapb.CompactionSegmentBinlogs{ - {SegmentID: 1}}, - }} + {SegmentID: 1}, + }, + }, + } inPaths, statsPaths, numOfRow, err := ct.merge(context.Background(), allPaths, 2, 0, meta, dm) assert.NoError(t, err) assert.Equal(t, int64(2), numOfRow) @@ -357,7 +385,6 @@ func TestCompactionTaskInnerMethods(t *testing.T) { }) // set Params.DataNodeCfg.BinLogMaxSize.Key = 1 to generate multi binlogs, each has only one row t.Run("Merge without expiration3", func(t *testing.T) { - mockbIO := &binlogIO{cm, alloc} paramtable.Get().Save(Params.CommonCfg.EntityExpirationTTL.Key, "0") BinLogMaxSize := Params.DataNodeCfg.BinLogMaxSize.GetAsInt() @@ -390,8 +417,10 @@ func TestCompactionTaskInnerMethods(t *testing.T) { Channel: channel, downloader: mockbIO, uploader: mockbIO, done: make(chan struct{}, 1), plan: &datapb.CompactionPlan{ SegmentBinlogs: []*datapb.CompactionSegmentBinlogs{ - {SegmentID: 1}}, - }} + {SegmentID: 1}, + }, + }, + } inPaths, statsPaths, numOfRow, err := ct.merge(context.Background(), allPaths, 2, 0, meta, dm) assert.NoError(t, err) assert.Equal(t, int64(2), numOfRow) @@ -438,7 +467,8 @@ func TestCompactionTaskInnerMethods(t *testing.T) { plan: &datapb.CompactionPlan{ CollectionTtl: 864000, SegmentBinlogs: []*datapb.CompactionSegmentBinlogs{ - {SegmentID: 1}}, + {SegmentID: 1}, + }, }, done: make(chan struct{}, 1), } @@ -478,8 +508,10 @@ func TestCompactionTaskInnerMethods(t *testing.T) { Channel: channel, downloader: mockbIO, uploader: mockbIO, done: make(chan struct{}, 1), plan: &datapb.CompactionPlan{ SegmentBinlogs: []*datapb.CompactionSegmentBinlogs{ - {SegmentID: 1}}, - }} + {SegmentID: 1}, + }, + }, + } _, _, _, err = ct.merge(context.Background(), allPaths, 2, 0, &etcdpb.CollectionMeta{ Schema: &schemapb.CollectionSchema{Fields: []*schemapb.FieldSchema{ {DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ @@ -522,7 +554,8 @@ func TestCompactionTaskInnerMethods(t *testing.T) { {DataType: schemapb.DataType_FloatVector, 
TypeParams: []*commonpb.KeyValuePair{ {Key: common.DimKey, Value: "bad_dim"}, }}, - }}}, dm) + }}, + }, dm) assert.Error(t, err) }) }) @@ -600,23 +633,22 @@ func TestCompactionTaskInnerMethods(t *testing.T) { }) t.Run("Test getNumRows error", func(t *testing.T) { - rc := &RootCoordFactory{ - pkType: schemapb.DataType_Int64, - } cm := &mockCm{} + broker := broker.NewMockBroker(t) ct := &compactionTask{ - Channel: newChannel("channel", 1, nil, rc, cm), + Channel: newChannel("channel", 1, nil, broker, cm), plan: &datapb.CompactionPlan{ SegmentBinlogs: []*datapb.CompactionSegmentBinlogs{ { SegmentID: 1, - }}, + }, + }, }, done: make(chan struct{}, 1), } - //segment not in channel + // segment not in channel _, err := ct.getNumRows() assert.Error(t, err) }) @@ -665,7 +697,7 @@ func TestCompactionTaskInnerMethods(t *testing.T) { stats := storage.NewPrimaryKeyStats(106, int64(schemapb.DataType_Int64), 10) ct := &compactionTask{ - uploader: &binlogIO{&mockCm{errMultiSave: true}, alloc}, + uploader: &binlogIO{&mockCm{errSave: true}, alloc}, done: make(chan struct{}, 1), } @@ -780,20 +812,27 @@ func TestCompactorInterfaceMethods(t *testing.T) { } for _, c := range cases { - rc := &RootCoordFactory{ - pkType: c.pkType, - } + collName := "test_compact_coll_name" + meta := NewMetaFactory().GetCollectionMeta(c.colID, collName, c.pkType) + broker := broker.NewMockBroker(t) + broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything). 
+ Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Status(nil), + Schema: meta.GetSchema(), + CollectionID: c.colID, + CollectionName: collName, + ShardsNum: common.DefaultShardsNum, + }, nil) mockfm := &mockFlushManager{} mockKv := memkv.NewMemoryKV() mockbIO := &binlogIO{cm, alloc} - channel := newChannel("a", c.colID, nil, rc, cm) + channel := newChannel("a", c.colID, nil, broker, cm) channel.addFlushedSegmentWithPKs(c.segID1, c.colID, c.parID, 2, c.iData1) channel.addFlushedSegmentWithPKs(c.segID2, c.colID, c.parID, 2, c.iData2) require.True(t, channel.hasSegment(c.segID1, true)) require.True(t, channel.hasSegment(c.segID2, true)) - meta := NewMetaFactory().GetCollectionMeta(c.colID, "test_compact_coll_name", c.pkType) iData1 := genInsertDataWithPKs(c.pks1, c.pkType) dData1 := &DeleteData{ Pks: []primaryKey{c.pks1[0]}, @@ -888,23 +927,31 @@ func TestCompactorInterfaceMethods(t *testing.T) { // The merged segment 19530 should only contain 2 rows and both pk=2 // Both pk = 1 rows of the two segments are compacted. var collID, partID, segID1, segID2 UniqueID = 1, 10, 200, 201 + var collName string = "test_compact_coll_name" alloc := allocator.NewMockAllocator(t) alloc.EXPECT().AllocOne().Call.Return(int64(19530), nil) alloc.EXPECT().GetGenerator(mock.Anything, mock.Anything).Call.Return(validGeneratorFn, nil) - rc := &RootCoordFactory{ - pkType: schemapb.DataType_Int64, - } + + meta := NewMetaFactory().GetCollectionMeta(collID, "test_compact_coll_name", schemapb.DataType_Int64) + broker := broker.NewMockBroker(t) + broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything). 
+ Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Status(nil), + Schema: meta.GetSchema(), + CollectionID: collID, + CollectionName: collName, + ShardsNum: common.DefaultShardsNum, + }, nil) mockfm := &mockFlushManager{} mockbIO := &binlogIO{cm, alloc} - channel := newChannel("channelname", collID, nil, rc, cm) + channel := newChannel("channelname", collID, nil, broker, cm) channel.addFlushedSegmentWithPKs(segID1, collID, partID, 2, &storage.Int64FieldData{Data: []UniqueID{1}}) channel.addFlushedSegmentWithPKs(segID2, collID, partID, 2, &storage.Int64FieldData{Data: []UniqueID{1}}) require.True(t, channel.hasSegment(segID1, true)) require.True(t, channel.hasSegment(segID2, true)) - meta := NewMetaFactory().GetCollectionMeta(collID, "test_compact_coll_name", schemapb.DataType_Int64) // the same pk for segmentI and segmentII pks := [2]primaryKey{newInt64PrimaryKey(1), newInt64PrimaryKey(2)} iData1 := genInsertDataWithPKs(pks, schemapb.DataType_Int64) @@ -923,14 +970,14 @@ func TestCompactorInterfaceMethods(t *testing.T) { RowCount: 0, } - stats1 := storage.NewPrimaryKeyStats(1, int64(rc.pkType), 1) + stats1 := storage.NewPrimaryKeyStats(1, int64(schemapb.DataType_Int64), 1) iPaths1, sPaths1, err := mockbIO.uploadStatsLog(context.TODO(), segID1, partID, iData1, stats1, 1, meta) require.NoError(t, err) dPaths1, err := mockbIO.uploadDeltaLog(context.TODO(), segID1, partID, dData1, meta) require.NoError(t, err) require.Equal(t, 12, len(iPaths1)) - stats2 := storage.NewPrimaryKeyStats(1, int64(rc.pkType), 1) + stats2 := storage.NewPrimaryKeyStats(1, int64(schemapb.DataType_Int64), 1) iPaths2, sPaths2, err := mockbIO.uploadStatsLog(context.TODO(), segID2, partID, iData2, stats2, 1, meta) require.NoError(t, err) dPaths2, err := mockbIO.uploadDeltaLog(context.TODO(), segID2, partID, dData2, meta) @@ -1015,7 +1062,7 @@ func (mfm *mockFlushManager) isFull() bool { func (mfm *mockFlushManager) injectFlush(injection *taskInjection, segments ...UniqueID) { go func() { 
time.Sleep(time.Second * time.Duration(mfm.sleepSeconds)) - //injection.injected <- struct{}{} + // injection.injected <- struct{}{} close(injection.injected) <-injection.injectOver mfm.injectOverCount.Lock() diff --git a/internal/datanode/data_node.go b/internal/datanode/data_node.go index c5716a96ada63..91b0c609b4f25 100644 --- a/internal/datanode/data_node.go +++ b/internal/datanode/data_node.go @@ -25,25 +25,20 @@ import ( "io" "math/rand" "os" - "path" - "strings" "sync" "sync/atomic" "syscall" "time" "github.com/cockroachdb/errors" - "github.com/golang/protobuf/proto" - v3rpc "go.etcd.io/etcd/api/v3/v3rpc/rpctypes" clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/datanode/allocator" + "github.com/milvus-io/milvus/internal/datanode/broker" "github.com/milvus-io/milvus/internal/kv" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" - "github.com/milvus-io/milvus/internal/proto/datapb" - "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/dependency" @@ -51,7 +46,6 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/mq/msgdispatcher" - "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/logutil" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -60,15 +54,8 @@ import ( ) const ( - // RPCConnectionTimeout is used to set the timeout for rpc request - RPCConnectionTimeout = 30 * time.Second - // ConnectEtcdMaxRetryTime is used to limit the max retry time for connection etcd ConnectEtcdMaxRetryTime = 100 - - // ImportCallTimeout is the timeout used in Import() method calls - // This value is equal to RootCoord's task expire time - ImportCallTimeout = 15 * 60 * 
time.Second ) var getFlowGraphServiceAttempts = uint(50) @@ -87,8 +74,7 @@ var Params *paramtable.ComponentParam = paramtable.Get() // `etcdCli` is a connection of etcd // `rootCoord` is a grpc client of root coordinator. // `dataCoord` is a grpc client of data service. -// `NodeID` is unique to each datanode. -// `State` is current statement of this data node, indicating whether it's healthy. +// `stateCode` is current statement of this data node, indicating whether it's healthy. // // `clearSignal` is a signal channel for releasing the flowgraph resources. // `segmentCache` stores all flushing and flushed segments. @@ -107,14 +93,15 @@ type DataNode struct { etcdCli *clientv3.Client address string - rootCoord types.RootCoord - dataCoord types.DataCoord + rootCoord types.RootCoordClient + dataCoord types.DataCoordClient + broker broker.Broker - //call once + // call once initOnce sync.Once startOnce sync.Once stopOnce sync.Once - wg sync.WaitGroup + stopWaiter sync.WaitGroup sessionMu sync.Mutex // to fix data race session *sessionutil.Session watchKv kv.WatchKV @@ -167,8 +154,8 @@ func (node *DataNode) SetEtcdClient(etcdCli *clientv3.Client) { node.etcdCli = etcdCli } -// SetRootCoord sets RootCoord's grpc client, error is returned if repeatedly set. -func (node *DataNode) SetRootCoord(rc types.RootCoord) error { +// SetRootCoordClient sets RootCoord's grpc client, error is returned if repeatedly set. +func (node *DataNode) SetRootCoordClient(rc types.RootCoordClient) error { switch { case rc == nil, node.rootCoord != nil: return errors.New("nil parameter or repeatedly set") @@ -178,8 +165,8 @@ func (node *DataNode) SetRootCoord(rc types.RootCoord) error { } } -// SetDataCoord sets data service's grpc client, error is returned if repeatedly set. -func (node *DataNode) SetDataCoord(ds types.DataCoord) error { +// SetDataCoordClient sets data service's grpc client, error is returned if repeatedly set. 
+func (node *DataNode) SetDataCoordClient(ds types.DataCoordClient) error { switch { case ds == nil, node.dataCoord != nil: return errors.New("nil parameter or repeatedly set") @@ -219,6 +206,7 @@ func (node *DataNode) initSession() error { return errors.New("failed to initialize session") } node.session.Init(typeutil.DataNodeRole, node.address, false, true) + sessionutil.SaveServerInfo(typeutil.DataNodeRole, node.session.ServerID) return nil } @@ -245,6 +233,8 @@ func (node *DataNode) Init() error { return } + node.broker = broker.NewCoordBroker(node.rootCoord, node.dataCoord) + err := node.initRateCollector() if err != nil { log.Error("DataNode server init rateCollector failed", zap.Int64("node ID", paramtable.GetNodeID()), zap.Error(err)) @@ -269,72 +259,10 @@ func (node *DataNode) Init() error { node.factory.Init(Params) log.Info("DataNode server init succeeded", zap.String("MsgChannelSubName", Params.CommonCfg.DataNodeSubName.GetValue())) - }) return initError } -// StartWatchChannels start loop to watch channel allocation status via kv(etcd for now) -func (node *DataNode) StartWatchChannels(ctx context.Context) { - defer node.wg.Done() - defer logutil.LogPanic() - // REF MEP#7 watch path should be [prefix]/channel/{node_id}/{channel_name} - // TODO, this is risky, we'd better watch etcd with revision rather simply a path - watchPrefix := path.Join(Params.CommonCfg.DataCoordWatchSubPath.GetValue(), fmt.Sprintf("%d", node.GetSession().ServerID)) - log.Info("Start watch channel", zap.String("prefix", watchPrefix)) - evtChan := node.watchKv.WatchWithPrefix(watchPrefix) - // after watch, first check all exists nodes first - err := node.checkWatchedList() - if err != nil { - log.Warn("StartWatchChannels failed", zap.Error(err)) - return - } - for { - select { - case <-ctx.Done(): - log.Info("watch etcd loop quit") - return - case event, ok := <-evtChan: - if !ok { - log.Warn("datanode failed to watch channel, return") - go node.StartWatchChannels(ctx) - return - } 
- - if err := event.Err(); err != nil { - log.Warn("datanode watch channel canceled", zap.Error(event.Err())) - // https://github.com/etcd-io/etcd/issues/8980 - if event.Err() == v3rpc.ErrCompacted { - go node.StartWatchChannels(ctx) - return - } - // if watch loop return due to event canceled, the datanode is not functional anymore - log.Panic("datanode is not functional for event canceled", zap.Error(err)) - return - } - for _, evt := range event.Events { - // We need to stay in order until events enqueued - node.handleChannelEvt(evt) - } - } - } -} - -// checkWatchedList list all nodes under [prefix]/channel/{node_id} and make sure all nodeds are watched -// serves the corner case for etcd connection lost and missing some events -func (node *DataNode) checkWatchedList() error { - // REF MEP#7 watch path should be [prefix]/channel/{node_id}/{channel_name} - prefix := path.Join(Params.CommonCfg.DataCoordWatchSubPath.GetValue(), fmt.Sprintf("%d", paramtable.GetNodeID())) - keys, values, err := node.watchKv.LoadWithPrefix(prefix) - if err != nil { - return err - } - for i, val := range values { - node.handleWatchInfo(&event{eventType: putEventType}, keys[i], []byte(val)) - } - return nil -} - // handleChannelEvt handles event from kv watch event func (node *DataNode) handleChannelEvt(evt *clientv3.Event) { var e *event @@ -354,123 +282,6 @@ func (node *DataNode) handleChannelEvt(evt *clientv3.Event) { node.handleWatchInfo(e, string(evt.Kv.Key), evt.Kv.Value) } -func (node *DataNode) handleWatchInfo(e *event, key string, data []byte) { - switch e.eventType { - case putEventType: - watchInfo, err := parsePutEventData(data) - if err != nil { - log.Warn("fail to handle watchInfo", zap.Int("event type", e.eventType), zap.String("key", key), zap.Error(err)) - return - } - - if isEndWatchState(watchInfo.State) { - log.Info("DataNode received a PUT event with an end State", zap.String("state", watchInfo.State.String())) - return - } - - if watchInfo.Progress != 0 { - 
log.Info("DataNode received a PUT event with tickler update progress", zap.String("channel", watchInfo.Vchan.ChannelName), zap.Int64("version", e.version)) - return - } - - e.info = watchInfo - e.vChanName = watchInfo.GetVchan().GetChannelName() - log.Info("DataNode is handling watchInfo PUT event", zap.String("key", key), zap.Any("watch state", watchInfo.GetState().String())) - case deleteEventType: - e.vChanName = parseDeleteEventKey(key) - log.Info("DataNode is handling watchInfo DELETE event", zap.String("key", key)) - } - - actualManager, loaded := node.eventManagerMap.GetOrInsert(e.vChanName, newChannelEventManager( - node.handlePutEvent, node.handleDeleteEvent, retryWatchInterval, - )) - - if !loaded { - actualManager.Run() - } - - actualManager.handleEvent(*e) - - // Whenever a delete event comes, this eventManager will be removed from map - if e.eventType == deleteEventType { - if m, loaded := node.eventManagerMap.GetAndRemove(e.vChanName); loaded { - m.Close() - } - } -} - -func parsePutEventData(data []byte) (*datapb.ChannelWatchInfo, error) { - watchInfo := datapb.ChannelWatchInfo{} - err := proto.Unmarshal(data, &watchInfo) - if err != nil { - return nil, fmt.Errorf("invalid event data: fail to parse ChannelWatchInfo, err: %v", err) - } - - if watchInfo.Vchan == nil { - return nil, fmt.Errorf("invalid event: ChannelWatchInfo with nil VChannelInfo") - } - reviseVChannelInfo(watchInfo.GetVchan()) - return &watchInfo, nil -} - -func parseDeleteEventKey(key string) string { - parts := strings.Split(key, "/") - vChanName := parts[len(parts)-1] - return vChanName -} - -func (node *DataNode) handlePutEvent(watchInfo *datapb.ChannelWatchInfo, version int64) (err error) { - vChanName := watchInfo.GetVchan().GetChannelName() - key := path.Join(Params.CommonCfg.DataCoordWatchSubPath.GetValue(), fmt.Sprintf("%d", node.GetSession().ServerID), vChanName) - tickler := newTickler(version, key, watchInfo, node.watchKv, 
Params.DataNodeCfg.WatchEventTicklerInterval.GetAsDuration(time.Second)) - - switch watchInfo.State { - case datapb.ChannelWatchState_Uncomplete, datapb.ChannelWatchState_ToWatch: - if err := node.flowgraphManager.addAndStart(node, watchInfo.GetVchan(), watchInfo.GetSchema(), tickler); err != nil { - log.Warn("handle put event: new data sync service failed", zap.String("vChanName", vChanName), zap.Error(err)) - watchInfo.State = datapb.ChannelWatchState_WatchFailure - } else { - log.Info("handle put event: new data sync service success", zap.String("vChanName", vChanName)) - watchInfo.State = datapb.ChannelWatchState_WatchSuccess - } - case datapb.ChannelWatchState_ToRelease: - // there is no reason why we release fail - node.tryToReleaseFlowgraph(vChanName) - watchInfo.State = datapb.ChannelWatchState_ReleaseSuccess - } - - v, err := proto.Marshal(watchInfo) - if err != nil { - return fmt.Errorf("fail to marshal watchInfo with state, vChanName: %s, state: %s ,err: %w", vChanName, watchInfo.State.String(), err) - } - - success, err := node.watchKv.CompareVersionAndSwap(key, tickler.version, string(v)) - // etcd error - if err != nil { - // flow graph will leak if not release, causing new datanode failed to subscribe - node.tryToReleaseFlowgraph(vChanName) - log.Warn("fail to update watch state to etcd", zap.String("vChanName", vChanName), - zap.String("state", watchInfo.State.String()), zap.Error(err)) - return err - } - // etcd valid but the states updated. 
- if !success { - log.Info("handle put event: failed to compare version and swap, release flowgraph", - zap.String("key", key), zap.String("state", watchInfo.State.String()), - zap.String("vChanName", vChanName)) - // flow graph will leak if not release, causing new datanode failed to subscribe - node.tryToReleaseFlowgraph(vChanName) - return nil - } - log.Info("handle put event success", zap.String("key", key), - zap.String("state", watchInfo.State.String()), zap.String("vChanName", vChanName)) - return nil -} - -func (node *DataNode) handleDeleteEvent(vChanName string) { - node.tryToReleaseFlowgraph(vChanName) -} - // tryToReleaseFlowgraph tries to release a flowgraph func (node *DataNode) tryToReleaseFlowgraph(vChanName string) { log.Info("try to release flowgraph", zap.String("vChanName", vChanName)) @@ -480,7 +291,7 @@ func (node *DataNode) tryToReleaseFlowgraph(vChanName string) { // BackGroundGC runs in background to release datanode resources // GOOSE TODO: remove background GC, using ToRelease for drop-collection after #15846 func (node *DataNode) BackGroundGC(vChannelCh <-chan string) { - defer node.wg.Done() + defer node.stopWaiter.Done() log.Info("DataNode Background GC Start") for { select { @@ -504,33 +315,33 @@ func (node *DataNode) Start() error { } log.Info("start id allocator done", zap.String("role", typeutil.DataNodeRole)) - rep, err := node.rootCoord.AllocTimestamp(node.ctx, &rootcoordpb.AllocTimestampRequest{ - Base: commonpbutil.NewMsgBase( - commonpbutil.WithMsgType(commonpb.MsgType_RequestTSO), - commonpbutil.WithMsgID(0), - commonpbutil.WithSourceID(paramtable.GetNodeID()), - ), - Count: 1, - }) - if err != nil || rep.Status.ErrorCode != commonpb.ErrorCode_Success { - log.Warn("fail to alloc timestamp", zap.Any("rep", rep), zap.Error(err)) - startErr = errors.New("DataNode fail to alloc timestamp") - return - } + /* + rep, err := node.rootCoord.AllocTimestamp(node.ctx, &rootcoordpb.AllocTimestampRequest{ + Base: commonpbutil.NewMsgBase( + 
commonpbutil.WithMsgType(commonpb.MsgType_RequestTSO), + commonpbutil.WithMsgID(0), + commonpbutil.WithSourceID(paramtable.GetNodeID()), + ), + Count: 1, + }) + if err != nil || rep.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + log.Warn("fail to alloc timestamp", zap.Any("rep", rep), zap.Error(err)) + startErr = errors.New("DataNode fail to alloc timestamp") + return + }*/ connectEtcdFn := func() error { etcdKV := etcdkv.NewEtcdKV(node.etcdCli, Params.EtcdCfg.MetaRootPath.GetValue()) node.watchKv = etcdKV return nil } - err = retry.Do(node.ctx, connectEtcdFn, retry.Attempts(ConnectEtcdMaxRetryTime)) + err := retry.Do(node.ctx, connectEtcdFn, retry.Attempts(ConnectEtcdMaxRetryTime)) if err != nil { startErr = errors.New("DataNode fail to connect etcd") return } chunkManager, err := node.factory.NewPersistentStorageChunkManager(node.ctx) - if err != nil { startErr = err return @@ -538,24 +349,24 @@ func (node *DataNode) Start() error { node.chunkManager = chunkManager - node.wg.Add(1) + node.stopWaiter.Add(1) go node.BackGroundGC(node.clearSignal) go node.compactionExecutor.start(node.ctx) if Params.DataNodeCfg.DataNodeTimeTickByRPC.GetAsBool() { - node.timeTickSender = newTimeTickSender(node.dataCoord, node.session.ServerID) + node.timeTickSender = newTimeTickSender(node.broker, node.session.ServerID) go node.timeTickSender.start(node.ctx) } - node.wg.Add(1) + node.stopWaiter.Add(1) // Start node watch node go node.StartWatchChannels(node.ctx) - go node.flowgraphManager.start() + node.stopWaiter.Add(1) + go node.flowgraphManager.start(&node.stopWaiter) node.UpdateStateCode(commonpb.StateCode_Healthy) - }) return startErr } @@ -608,7 +419,7 @@ func (node *DataNode) Stop() error { node.session.Stop() } - node.wg.Wait() + node.stopWaiter.Wait() }) return nil } diff --git a/internal/datanode/data_node_test.go b/internal/datanode/data_node_test.go index fb006af9e4bce..47bb46deba6e9 100644 --- a/internal/datanode/data_node_test.go +++ 
b/internal/datanode/data_node_test.go @@ -18,7 +18,6 @@ package datanode import ( "context" - "fmt" "math/rand" "os" "strconv" @@ -26,13 +25,13 @@ import ( "testing" "time" - "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" + "github.com/milvus-io/milvus/internal/datanode/broker" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/types" @@ -105,16 +104,23 @@ func TestDataNode(t *testing.T) { assert.Empty(t, node.GetAddress()) node.SetAddress("address") assert.Equal(t, "address", node.GetAddress()) + + broker := &broker.MockBroker{} + broker.EXPECT().ReportTimeTick(mock.Anything, mock.Anything).Return(nil).Maybe() + broker.EXPECT().GetSegmentInfo(mock.Anything, mock.Anything).Return([]*datapb.SegmentInfo{}, nil).Maybe() + + node.broker = broker + defer node.Stop() node.chunkManager = storage.NewLocalChunkManager(storage.RootPath("/tmp/milvus_test/datanode")) paramtable.SetNodeID(1) defer cancel() - t.Run("Test SetRootCoord", func(t *testing.T) { + t.Run("Test SetRootCoordClient", func(t *testing.T) { emptyDN := &DataNode{} tests := []struct { - inrc types.RootCoord + inrc types.RootCoordClient isvalid bool description string }{ @@ -124,7 +130,7 @@ func TestDataNode(t *testing.T) { for _, test := range tests { t.Run(test.description, func(t *testing.T) { - err := emptyDN.SetRootCoord(test.inrc) + err := emptyDN.SetRootCoordClient(test.inrc) if test.isvalid { assert.NoError(t, err) } else { @@ -134,10 +140,10 @@ func TestDataNode(t *testing.T) { } }) - t.Run("Test SetDataCoord", func(t *testing.T) { + t.Run("Test SetDataCoordClient", func(t *testing.T) { emptyDN := &DataNode{} tests := []struct { - inrc types.DataCoord + inrc 
types.DataCoordClient isvalid bool description string }{ @@ -147,7 +153,7 @@ func TestDataNode(t *testing.T) { for _, test := range tests { t.Run(test.description, func(t *testing.T) { - err := emptyDN.SetDataCoord(test.inrc) + err := emptyDN.SetDataCoordClient(test.inrc) if test.isvalid { assert.NoError(t, err) } else { @@ -159,7 +165,7 @@ func TestDataNode(t *testing.T) { t.Run("Test getSystemInfoMetrics", func(t *testing.T) { emptyNode := &DataNode{} - emptyNode.SetSession(&sessionutil.Session{ServerID: 1}) + emptyNode.SetSession(&sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}) emptyNode.flowgraphManager = newFlowgraphManager() req, err := metricsinfo.ConstructRequestByMetricType(metricsinfo.SystemInfoMetrics) @@ -174,7 +180,7 @@ func TestDataNode(t *testing.T) { t.Run("Test getSystemInfoMetrics with quotaMetric error", func(t *testing.T) { emptyNode := &DataNode{} - emptyNode.SetSession(&sessionutil.Session{ServerID: 1}) + emptyNode.SetSession(&sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}) emptyNode.flowgraphManager = newFlowgraphManager() req, err := metricsinfo.ConstructRequestByMetricType(metricsinfo.SystemInfoMetrics) @@ -189,7 +195,7 @@ func TestDataNode(t *testing.T) { t.Run("Test BackGroundGC", func(t *testing.T) { vchanNameCh := make(chan string) node.clearSignal = vchanNameCh - node.wg.Add(1) + node.stopWaiter.Add(1) go node.BackGroundGC(vchanNameCh) testDataSyncs := []struct { @@ -200,7 +206,7 @@ func TestDataNode(t *testing.T) { } for _, test := range testDataSyncs { - err = node.flowgraphManager.addAndStart(node, &datapb.VchannelInfo{CollectionID: 1, ChannelName: test.dmChannelName}, nil, genTestTickler()) + err = node.flowgraphManager.addAndStartWithEtcdTickler(node, &datapb.VchannelInfo{CollectionID: 1, ChannelName: test.dmChannelName}, nil, genTestTickler()) assert.NoError(t, err) vchanNameCh <- test.dmChannelName } @@ -214,262 +220,4 @@ func TestDataNode(t *testing.T) { return true }, 2*time.Second, 
10*time.Millisecond) }) - -} - -func TestWatchChannel(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - node := newIDLEDataNodeMock(ctx, schemapb.DataType_Int64) - etcdCli, err := etcd.GetEtcdClient( - Params.EtcdCfg.UseEmbedEtcd.GetAsBool(), - Params.EtcdCfg.EtcdUseSSL.GetAsBool(), - Params.EtcdCfg.Endpoints.GetAsStrings(), - Params.EtcdCfg.EtcdTLSCert.GetValue(), - Params.EtcdCfg.EtcdTLSKey.GetValue(), - Params.EtcdCfg.EtcdTLSCACert.GetValue(), - Params.EtcdCfg.EtcdTLSMinVersion.GetValue()) - assert.NoError(t, err) - defer etcdCli.Close() - node.SetEtcdClient(etcdCli) - err = node.Init() - assert.NoError(t, err) - err = node.Start() - assert.NoError(t, err) - defer node.Stop() - err = node.Register() - assert.NoError(t, err) - - defer cancel() - - t.Run("test watch channel", func(t *testing.T) { - kv := etcdkv.NewEtcdKV(etcdCli, Params.EtcdCfg.MetaRootPath.GetValue()) - oldInvalidCh := "datanode-etcd-test-by-dev-rootcoord-dml-channel-invalid" - path := fmt.Sprintf("%s/%d/%s", Params.CommonCfg.DataCoordWatchSubPath.GetValue(), paramtable.GetNodeID(), oldInvalidCh) - err = kv.Save(path, string([]byte{23})) - assert.NoError(t, err) - - ch := fmt.Sprintf("datanode-etcd-test-by-dev-rootcoord-dml-channel_%d", rand.Int31()) - path = fmt.Sprintf("%s/%d/%s", Params.CommonCfg.DataCoordWatchSubPath.GetValue(), paramtable.GetNodeID(), ch) - - vchan := &datapb.VchannelInfo{ - CollectionID: 1, - ChannelName: ch, - UnflushedSegmentIds: []int64{}, - } - info := &datapb.ChannelWatchInfo{ - State: datapb.ChannelWatchState_ToWatch, - Vchan: vchan, - } - val, err := proto.Marshal(info) - assert.NoError(t, err) - err = kv.Save(path, string(val)) - assert.NoError(t, err) - - assert.Eventually(t, func() bool { - exist := node.flowgraphManager.exist(ch) - if !exist { - return false - } - bs, err := kv.LoadBytes(fmt.Sprintf("%s/%d/%s", Params.CommonCfg.DataCoordWatchSubPath.GetValue(), paramtable.GetNodeID(), ch)) - if err != nil { - return false - } - watchInfo 
:= &datapb.ChannelWatchInfo{} - err = proto.Unmarshal(bs, watchInfo) - if err != nil { - return false - } - return watchInfo.GetState() == datapb.ChannelWatchState_WatchSuccess - }, 3*time.Second, 100*time.Millisecond) - - err = kv.RemoveWithPrefix(fmt.Sprintf("%s/%d", Params.CommonCfg.DataCoordWatchSubPath.GetValue(), paramtable.GetNodeID())) - assert.NoError(t, err) - - assert.Eventually(t, func() bool { - exist := node.flowgraphManager.exist(ch) - return !exist - }, 3*time.Second, 100*time.Millisecond) - }) - - t.Run("Test release channel", func(t *testing.T) { - kv := etcdkv.NewEtcdKV(etcdCli, Params.EtcdCfg.MetaRootPath.GetValue()) - oldInvalidCh := "datanode-etcd-test-by-dev-rootcoord-dml-channel-invalid" - path := fmt.Sprintf("%s/%d/%s", Params.CommonCfg.DataCoordWatchSubPath.GetValue(), paramtable.GetNodeID(), oldInvalidCh) - err = kv.Save(path, string([]byte{23})) - assert.NoError(t, err) - - ch := fmt.Sprintf("datanode-etcd-test-by-dev-rootcoord-dml-channel_%d", rand.Int31()) - path = fmt.Sprintf("%s/%d/%s", Params.CommonCfg.DataCoordWatchSubPath.GetValue(), paramtable.GetNodeID(), ch) - c := make(chan struct{}) - go func() { - ec := kv.WatchWithPrefix(fmt.Sprintf("%s/%d", Params.CommonCfg.DataCoordWatchSubPath.GetValue(), paramtable.GetNodeID())) - c <- struct{}{} - cnt := 0 - for { - evt := <-ec - for _, event := range evt.Events { - if strings.Contains(string(event.Kv.Key), ch) { - cnt++ - } - } - if cnt >= 2 { - break - } - } - c <- struct{}{} - }() - // wait for check goroutine start Watch - <-c - - vchan := &datapb.VchannelInfo{ - CollectionID: 1, - ChannelName: ch, - UnflushedSegmentIds: []int64{}, - } - info := &datapb.ChannelWatchInfo{ - State: datapb.ChannelWatchState_ToRelease, - Vchan: vchan, - } - val, err := proto.Marshal(info) - assert.NoError(t, err) - err = kv.Save(path, string(val)) - assert.NoError(t, err) - - // wait for check goroutine received 2 events - <-c - exist := node.flowgraphManager.exist(ch) - assert.False(t, exist) - - err 
= kv.RemoveWithPrefix(fmt.Sprintf("%s/%d", Params.CommonCfg.DataCoordWatchSubPath.GetValue(), paramtable.GetNodeID())) - assert.NoError(t, err) - //TODO there is not way to sync Release done, use sleep for now - time.Sleep(100 * time.Millisecond) - - exist = node.flowgraphManager.exist(ch) - assert.False(t, exist) - - }) - - t.Run("handle watch info failed", func(t *testing.T) { - e := &event{ - eventType: putEventType, - } - - node.handleWatchInfo(e, "test1", []byte{23}) - - exist := node.flowgraphManager.exist("test1") - assert.False(t, exist) - - info := datapb.ChannelWatchInfo{ - Vchan: nil, - State: datapb.ChannelWatchState_Uncomplete, - } - bs, err := proto.Marshal(&info) - assert.NoError(t, err) - node.handleWatchInfo(e, "test2", bs) - - exist = node.flowgraphManager.exist("test2") - assert.False(t, exist) - - chPut := make(chan struct{}, 1) - chDel := make(chan struct{}, 1) - - ch := fmt.Sprintf("datanode-etcd-test-by-dev-rootcoord-dml-channel_%d", rand.Int31()) - m := newChannelEventManager( - func(info *datapb.ChannelWatchInfo, version int64) error { - r := node.handlePutEvent(info, version) - chPut <- struct{}{} - return r - }, - func(vChan string) { - node.handleDeleteEvent(vChan) - chDel <- struct{}{} - }, time.Millisecond*100, - ) - node.eventManagerMap.Insert(ch, m) - m.Run() - defer m.Close() - - info = datapb.ChannelWatchInfo{ - Vchan: &datapb.VchannelInfo{ChannelName: ch}, - State: datapb.ChannelWatchState_Uncomplete, - } - bs, err = proto.Marshal(&info) - assert.NoError(t, err) - - msFactory := node.factory - defer func() { node.factory = msFactory }() - - // todo review the UT logic - // As we remove timetick channel logic, flow_graph_insert_buffer_node no longer depend on MessageStreamFactory - // so data_sync_service can be created. 
this assert becomes true - node.factory = &FailMessageStreamFactory{} - node.handleWatchInfo(e, ch, bs) - <-chPut - exist = node.flowgraphManager.exist(ch) - assert.True(t, exist) - }) - - t.Run("handle watchinfo out of date", func(t *testing.T) { - chPut := make(chan struct{}, 1) - chDel := make(chan struct{}, 1) - // inject eventManager - ch := fmt.Sprintf("datanode-etcd-test-by-dev-rootcoord-dml-channel_%d", rand.Int31()) - m := newChannelEventManager( - func(info *datapb.ChannelWatchInfo, version int64) error { - r := node.handlePutEvent(info, version) - chPut <- struct{}{} - return r - }, - func(vChan string) { - node.handleDeleteEvent(vChan) - chDel <- struct{}{} - }, time.Millisecond*100, - ) - node.eventManagerMap.Insert(ch, m) - m.Run() - defer m.Close() - e := &event{ - eventType: putEventType, - version: 10000, - } - - info := datapb.ChannelWatchInfo{ - Vchan: &datapb.VchannelInfo{ChannelName: ch}, - State: datapb.ChannelWatchState_Uncomplete, - } - bs, err := proto.Marshal(&info) - assert.NoError(t, err) - - node.handleWatchInfo(e, ch, bs) - <-chPut - exist := node.flowgraphManager.exist("test3") - assert.False(t, exist) - }) - - t.Run("handle watchinfo compatibility", func(t *testing.T) { - info := datapb.ChannelWatchInfo{ - Vchan: &datapb.VchannelInfo{ - CollectionID: 1, - ChannelName: "delta-channel1", - UnflushedSegments: []*datapb.SegmentInfo{{ID: 1}}, - FlushedSegments: []*datapb.SegmentInfo{{ID: 2}}, - DroppedSegments: []*datapb.SegmentInfo{{ID: 3}}, - UnflushedSegmentIds: []int64{1}, - }, - State: datapb.ChannelWatchState_Uncomplete, - } - bs, err := proto.Marshal(&info) - assert.NoError(t, err) - - newWatchInfo, err := parsePutEventData(bs) - assert.NoError(t, err) - - assert.Equal(t, []*datapb.SegmentInfo{}, newWatchInfo.GetVchan().GetUnflushedSegments()) - assert.Equal(t, []*datapb.SegmentInfo{}, newWatchInfo.GetVchan().GetFlushedSegments()) - assert.Equal(t, []*datapb.SegmentInfo{}, newWatchInfo.GetVchan().GetDroppedSegments()) - 
assert.NotEmpty(t, newWatchInfo.GetVchan().GetUnflushedSegmentIds()) - assert.NotEmpty(t, newWatchInfo.GetVchan().GetFlushedSegmentIds()) - assert.NotEmpty(t, newWatchInfo.GetVchan().GetDroppedSegmentIds()) - }) } diff --git a/internal/datanode/data_sync_service.go b/internal/datanode/data_sync_service.go index 565855776d48a..cff061907a713 100644 --- a/internal/datanode/data_sync_service.go +++ b/internal/datanode/data_sync_service.go @@ -18,27 +18,21 @@ package datanode import ( "context" - "fmt" "sync" "github.com/cockroachdb/errors" clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/zap" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/datanode/allocator" + "github.com/milvus-io/milvus/internal/datanode/broker" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/storage" - "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/flowgraph" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/msgdispatcher" "github.com/milvus-io/milvus/pkg/mq/msgstream" - "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" - "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/conc" - "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/retry" ) @@ -46,155 +40,37 @@ import ( type dataSyncService struct { ctx context.Context cancelFn context.CancelFunc - fg *flowgraph.TimeTickedFlowGraph // internal flowgraph processes insert/delta messages - flushCh chan flushMsg - resendTTCh chan resendTTMsg // chan to ask for resending DataNode time tick message. 
- channel Channel // channel stores meta of channel - idAllocator allocator.Allocator // id/timestamp allocator - dispClient msgdispatcher.Client - msFactory msgstream.Factory + channel Channel // channel stores meta of channel + opID int64 collectionID UniqueID // collection id of vchan for which this data sync service serves vchannelName string - dataCoord types.DataCoord // DataCoord instance to interact with - clearSignal chan<- string // signal channel to notify flowgraph close for collection/partition drop msg consumed - - delBufferManager *DeltaBufferManager - flushingSegCache *Cache // a guarding cache stores currently flushing segment ids - flushManager flushManager // flush manager handles flush process - chunkManager storage.ChunkManager - compactor *compactionExecutor // reference to compaction executor - - serverID int64 - stopOnce sync.Once - flushListener chan *segmentFlushPack // chan to listen flush event - timetickSender *timeTickSender // reference to timeTickSender - cli *clientv3.Client -} - -func newDataSyncServiceV2(ctx context.Context, - flushCh chan flushMsg, - resendTTCh chan resendTTMsg, - channel Channel, - alloc allocator.Allocator, - dispClient msgdispatcher.Client, - factory msgstream.Factory, - vchan *datapb.VchannelInfo, - clearSignal chan<- string, - dataCoord types.DataCoord, - flushingSegCache *Cache, - chunkManager storage.ChunkManager, - compactor *compactionExecutor, - tickler *tickler, - serverID int64, - timetickSender *timeTickSender, - cli *clientv3.Client, -) (*dataSyncService, error) { - if channel == nil { - return nil, errors.New("Nil input") - } - ctx1, cancel := context.WithCancel(ctx) - - delBufferManager := &DeltaBufferManager{ - channel: channel, - delBufHeap: &PriorityQueue{}, - } - - service := &dataSyncService{ - ctx: ctx1, - cancelFn: cancel, - fg: nil, - flushCh: flushCh, - resendTTCh: resendTTCh, - channel: channel, - idAllocator: alloc, - dispClient: dispClient, - msFactory: factory, - collectionID: 
vchan.GetCollectionID(), - vchannelName: vchan.GetChannelName(), - dataCoord: dataCoord, - clearSignal: clearSignal, - delBufferManager: delBufferManager, - flushingSegCache: flushingSegCache, - chunkManager: chunkManager, - compactor: compactor, - serverID: serverID, - timetickSender: timetickSender, - cli: cli, - } - - if err := service.initNodes(vchan, tickler); err != nil { - return nil, err - } - if tickler.isWatchFailed.Load() { - return nil, errors.Errorf("tickler watch failed") - } - return service, nil -} + // TODO: should be equal to paramtable.GetNodeID(), but intergrationtest has 1 paramtable for a minicluster, the NodeID + // varies, will cause savebinglogpath check fail. So we pass ServerID into dataSyncService to aviod it failure. + serverID UniqueID -func newDataSyncService(ctx context.Context, - flushCh chan flushMsg, - resendTTCh chan resendTTMsg, - channel Channel, - alloc allocator.Allocator, - dispClient msgdispatcher.Client, - factory msgstream.Factory, - vchan *datapb.VchannelInfo, - clearSignal chan<- string, - dataCoord types.DataCoord, - flushingSegCache *Cache, - chunkManager storage.ChunkManager, - compactor *compactionExecutor, - tickler *tickler, - serverID int64, - timetickSender *timeTickSender, -) (*dataSyncService, error) { - - if channel == nil { - return nil, errors.New("Nil input") - } + fg *flowgraph.TimeTickedFlowGraph // internal flowgraph processes insert/delta messages - ctx1, cancel := context.WithCancel(ctx) + broker broker.Broker + delBufferManager *DeltaBufferManager + flushManager flushManager // flush manager handles flush process - delBufferManager := &DeltaBufferManager{ - channel: channel, - delBufHeap: &PriorityQueue{}, - } + flushCh chan flushMsg + resendTTCh chan resendTTMsg // chan to ask for resending DataNode time tick message. 
+ timetickSender *timeTickSender // reference to timeTickSender + compactor *compactionExecutor // reference to compaction executor + flushingSegCache *Cache // a guarding cache stores currently flushing segment ids - service := &dataSyncService{ - ctx: ctx1, - cancelFn: cancel, - fg: nil, - flushCh: flushCh, - resendTTCh: resendTTCh, - channel: channel, - idAllocator: alloc, - dispClient: dispClient, - msFactory: factory, - collectionID: vchan.GetCollectionID(), - vchannelName: vchan.GetChannelName(), - dataCoord: dataCoord, - clearSignal: clearSignal, - delBufferManager: delBufferManager, - flushingSegCache: flushingSegCache, - chunkManager: chunkManager, - compactor: compactor, - serverID: serverID, - timetickSender: timetickSender, - } + clearSignal chan<- string // signal channel to notify flowgraph close for collection/partition drop msg consumed + idAllocator allocator.Allocator // id/timestamp allocator + msFactory msgstream.Factory + dispClient msgdispatcher.Client + chunkManager storage.ChunkManager - if err := service.initNodes(vchan, tickler); err != nil { - return nil, err - } - if tickler.isWatchFailed.Load() { - return nil, errors.Errorf("tickler watch failed") - } - return service, nil -} + // test only + flushListener chan *segmentFlushPack // chan to listen flush event -type parallelConfig struct { - maxQueueLength int32 - maxParallelism int32 + stopOnce sync.Once } type nodeConfig struct { @@ -203,13 +79,7 @@ type nodeConfig struct { vChannelName string channel Channel // Channel info allocator allocator.Allocator - serverID int64 - // defaults - parallelConfig -} - -func newParallelConfig() parallelConfig { - return parallelConfig{Params.DataNodeCfg.FlowGraphMaxQueueLength.GetAsInt32(), Params.DataNodeCfg.FlowGraphMaxParallelism.GetAsInt32()} + serverID UniqueID } // start the flow graph in datasyncservice @@ -234,7 +104,7 @@ func (dsService *dataSyncService) GracefullyClose() { func (dsService *dataSyncService) close() { 
dsService.stopOnce.Do(func() { - log := log.Ctx(context.Background()).With( + log := log.Ctx(dsService.ctx).With( zap.Int64("collectionID", dsService.collectionID), zap.String("vChanName", dsService.vchannelName), ) @@ -247,8 +117,12 @@ func (dsService *dataSyncService) close() { dsService.clearGlobalFlushingCache() close(dsService.flushCh) - dsService.flushManager.close() - log.Info("dataSyncService flush manager closed") + + if dsService.flushManager != nil { + dsService.flushManager.close() + log.Info("dataSyncService flush manager closed") + } + dsService.cancelFn() dsService.channel.close() @@ -261,51 +135,100 @@ func (dsService *dataSyncService) clearGlobalFlushingCache() { dsService.flushingSegCache.Remove(segments...) } -// initNodes inits a TimetickedFlowGraph -func (dsService *dataSyncService) initNodes(vchanInfo *datapb.VchannelInfo, tickler *tickler) error { - dsService.fg = flowgraph.NewTimeTickedFlowGraph(dsService.ctx) - // initialize flush manager for DataSync Service - if Params.CommonCfg.EnableStorageV2.GetAsBool() { - dsService.flushManager = NewRendezvousFlushManagerV2(dsService.idAllocator, dsService.chunkManager, dsService.channel, - flushNotifyFunc2(dsService, retry.Attempts(50)), dropVirtualChannelFunc(dsService), dsService, dsService.cli) - } else { - dsService.flushManager = NewRendezvousFlushManager(dsService.idAllocator, dsService.chunkManager, dsService.channel, - flushNotifyFunc(dsService, retry.Attempts(50)), dropVirtualChannelFunc(dsService)) +func getChannelWithTickler(initCtx context.Context, node *DataNode, info *datapb.ChannelWatchInfo, tickler *tickler, unflushed, flushed []*datapb.SegmentInfo) (Channel, error) { + var ( + channelName = info.GetVchan().GetChannelName() + collectionID = info.GetVchan().GetCollectionID() + recoverTs = info.GetVchan().GetSeekPosition().GetTimestamp() + ) + + // init channel meta + channel := newChannel(channelName, collectionID, info.GetSchema(), node.broker, node.chunkManager) + + // tickler will 
update addSegment progress to watchInfo + futures := make([]*conc.Future[any], 0, len(unflushed)+len(flushed)) + tickler.setTotal(int32(len(unflushed) + len(flushed))) + + for _, us := range unflushed { + log.Info("recover growing segments from checkpoints", + zap.String("vChannelName", us.GetInsertChannel()), + zap.Int64("segmentID", us.GetID()), + zap.Int64("numRows", us.GetNumOfRows()), + ) + + // avoid closure capture iteration variable + segment := us + future := getOrCreateIOPool().Submit(func() (interface{}, error) { + if err := channel.addSegment(initCtx, addSegmentReq{ + segType: datapb.SegmentType_Normal, + segID: segment.GetID(), + collID: segment.CollectionID, + partitionID: segment.PartitionID, + numOfRows: segment.GetNumOfRows(), + statsBinLogs: segment.Statslogs, + binLogs: segment.GetBinlogs(), + endPos: segment.GetDmlPosition(), + recoverTs: recoverTs, + }); err != nil { + return nil, err + } + tickler.inc() + return nil, nil + }) + futures = append(futures, future) } - log.Info("begin to init data sync service", zap.Int64("collection", vchanInfo.CollectionID), - zap.String("Chan", vchanInfo.ChannelName), - zap.Int64s("unflushed", vchanInfo.GetUnflushedSegmentIds()), - zap.Int64s("flushed", vchanInfo.GetFlushedSegmentIds()), - ) - var err error - // recover segment checkpoints - unflushedSegmentInfos, err := dsService.getSegmentInfos(vchanInfo.GetUnflushedSegmentIds()) - if err != nil { - return err + for _, fs := range flushed { + log.Info("recover sealed segments form checkpoints", + zap.String("vChannelName", fs.GetInsertChannel()), + zap.Int64("segmentID", fs.GetID()), + zap.Int64("numRows", fs.GetNumOfRows()), + ) + // avoid closure capture iteration variable + segment := fs + future := getOrCreateIOPool().Submit(func() (interface{}, error) { + if err := channel.addSegment(initCtx, addSegmentReq{ + segType: datapb.SegmentType_Flushed, + segID: segment.GetID(), + collID: segment.GetCollectionID(), + partitionID: segment.GetPartitionID(), + 
numOfRows: segment.GetNumOfRows(), + statsBinLogs: segment.GetStatslogs(), + binLogs: segment.GetBinlogs(), + recoverTs: recoverTs, + }); err != nil { + return nil, err + } + tickler.inc() + return nil, nil + }) + futures = append(futures, future) } - flushedSegmentInfos, err := dsService.getSegmentInfos(vchanInfo.GetFlushedSegmentIds()) - if err != nil { - return err + + if err := conc.AwaitAll(futures...); err != nil { + return nil, err } - //tickler will update addSegment progress to watchInfo + return channel, nil +} + +// getChannelWithEtcdTickler updates progress into etcd when a new segment is added into channel. +func getChannelWithEtcdTickler(initCtx context.Context, node *DataNode, info *datapb.ChannelWatchInfo, tickler *etcdTickler, unflushed, flushed []*datapb.SegmentInfo) (Channel, error) { + var ( + channelName = info.GetVchan().GetChannelName() + collectionID = info.GetVchan().GetCollectionID() + recoverTs = info.GetVchan().GetSeekPosition().GetTimestamp() + ) + + // init channel meta + channel := newChannel(channelName, collectionID, info.GetSchema(), node.broker, node.chunkManager) + + // tickler will update addSegment progress to watchInfo tickler.watch() defer tickler.stop() - futures := make([]*conc.Future[any], 0, len(unflushedSegmentInfos)+len(flushedSegmentInfos)) - - for _, us := range unflushedSegmentInfos { - if us.CollectionID != dsService.collectionID || - us.GetInsertChannel() != vchanInfo.ChannelName { - log.Warn("Collection ID or ChannelName not match", - zap.Int64("Wanted ID", dsService.collectionID), - zap.Int64("Actual ID", us.CollectionID), - zap.String("Wanted channel Name", vchanInfo.ChannelName), - zap.String("Actual Channel Name", us.GetInsertChannel()), - ) - continue - } + futures := make([]*conc.Future[any], 0, len(unflushed)+len(flushed)) + for _, us := range unflushed { log.Info("recover growing segments from checkpoints", zap.String("vChannelName", us.GetInsertChannel()), zap.Int64("segmentID", us.GetID()), @@ -315,7 
+238,7 @@ func (dsService *dataSyncService) initNodes(vchanInfo *datapb.VchannelInfo, tick // avoid closure capture iteration variable segment := us future := getOrCreateIOPool().Submit(func() (interface{}, error) { - if err := dsService.channel.addSegment(addSegmentReq{ + if err := channel.addSegment(initCtx, addSegmentReq{ segType: datapb.SegmentType_Normal, segID: segment.GetID(), collID: segment.CollectionID, @@ -324,7 +247,8 @@ func (dsService *dataSyncService) initNodes(vchanInfo *datapb.VchannelInfo, tick statsBinLogs: segment.Statslogs, binLogs: segment.GetBinlogs(), endPos: segment.GetDmlPosition(), - recoverTs: vchanInfo.GetSeekPosition().GetTimestamp()}); err != nil { + recoverTs: recoverTs, + }); err != nil { return nil, err } tickler.inc() @@ -333,17 +257,7 @@ func (dsService *dataSyncService) initNodes(vchanInfo *datapb.VchannelInfo, tick futures = append(futures, future) } - for _, fs := range flushedSegmentInfos { - if fs.CollectionID != dsService.collectionID || - fs.GetInsertChannel() != vchanInfo.ChannelName { - log.Warn("Collection ID or ChannelName not match", - zap.Int64("Wanted ID", dsService.collectionID), - zap.Int64("Actual ID", fs.CollectionID), - zap.String("Wanted Channel Name", vchanInfo.ChannelName), - zap.String("Actual Channel Name", fs.GetInsertChannel()), - ) - continue - } + for _, fs := range flushed { log.Info("recover sealed segments form checkpoints", zap.String("vChannelName", fs.GetInsertChannel()), zap.Int64("segmentID", fs.GetID()), @@ -352,7 +266,7 @@ func (dsService *dataSyncService) initNodes(vchanInfo *datapb.VchannelInfo, tick // avoid closure capture iteration variable segment := fs future := getOrCreateIOPool().Submit(func() (interface{}, error) { - if err := dsService.channel.addSegment(addSegmentReq{ + if err := channel.addSegment(initCtx, addSegmentReq{ segType: datapb.SegmentType_Flushed, segID: segment.GetID(), collID: segment.GetCollectionID(), @@ -360,7 +274,7 @@ func (dsService *dataSyncService) 
initNodes(vchanInfo *datapb.VchannelInfo, tick numOfRows: segment.GetNumOfRows(), statsBinLogs: segment.GetStatslogs(), binLogs: segment.GetBinlogs(), - recoverTs: vchanInfo.GetSeekPosition().GetTimestamp(), + recoverTs: recoverTs, }); err != nil { return nil, err } @@ -370,163 +284,182 @@ func (dsService *dataSyncService) initNodes(vchanInfo *datapb.VchannelInfo, tick futures = append(futures, future) } - err = conc.AwaitAll(futures...) - if err != nil { - return err + if err := conc.AwaitAll(futures...); err != nil { + return nil, err } - c := &nodeConfig{ - msFactory: dsService.msFactory, - collectionID: vchanInfo.GetCollectionID(), - vChannelName: vchanInfo.GetChannelName(), - channel: dsService.channel, - allocator: dsService.idAllocator, - - parallelConfig: newParallelConfig(), - serverID: dsService.serverID, + if tickler.isWatchFailed.Load() { + return nil, errors.Errorf("tickler watch failed") } + return channel, nil +} - var dmStreamNode Node - dmStreamNode, err = newDmInputNode(dsService.dispClient, vchanInfo.GetSeekPosition(), c) - if err != nil { - return err - } +func getServiceWithChannel(initCtx context.Context, node *DataNode, info *datapb.ChannelWatchInfo, channel Channel, unflushed, flushed []*datapb.SegmentInfo) (*dataSyncService, error) { + var ( + channelName = info.GetVchan().GetChannelName() + collectionID = info.GetVchan().GetCollectionID() + ) - var ddNode Node - ddNode, err = newDDNode( - dsService.ctx, - dsService.collectionID, - vchanInfo.GetChannelName(), - vchanInfo.GetDroppedSegmentIds(), - flushedSegmentInfos, - unflushedSegmentInfos, - dsService.compactor) - if err != nil { - return err + config := &nodeConfig{ + msFactory: node.factory, + allocator: node.allocator, + + collectionID: collectionID, + vChannelName: channelName, + channel: channel, + serverID: node.session.ServerID, } - var insertBufferNode Node - insertBufferNode, err = newInsertBufferNode( - dsService.ctx, - dsService.collectionID, - dsService.delBufferManager, - 
dsService.flushCh, - dsService.resendTTCh, - dsService.flushManager, - dsService.flushingSegCache, - c, - dsService.timetickSender, + var ( + flushCh = make(chan flushMsg, 100) + resendTTCh = make(chan resendTTMsg, 100) + delBufferManager = &DeltaBufferManager{ + channel: channel, + delBufHeap: &PriorityQueue{}, + } ) - if err != nil { - return err - } - var deleteNode Node - deleteNode, err = newDeleteNode(dsService.ctx, dsService.flushManager, dsService.delBufferManager, dsService.clearSignal, c) - if err != nil { - return err + ctx, cancel := context.WithCancel(node.ctx) + ds := &dataSyncService{ + ctx: ctx, + cancelFn: cancel, + flushCh: flushCh, + resendTTCh: resendTTCh, + delBufferManager: delBufferManager, + opID: info.GetOpID(), + + dispClient: node.dispClient, + msFactory: node.factory, + broker: node.broker, + + idAllocator: config.allocator, + channel: config.channel, + collectionID: config.collectionID, + vchannelName: config.vChannelName, + serverID: config.serverID, + + flushingSegCache: node.segmentCache, + clearSignal: node.clearSignal, + chunkManager: node.chunkManager, + compactor: node.compactionExecutor, + timetickSender: node.timeTickSender, + + fg: nil, + flushManager: nil, } - var ttNode Node - ttNode, err = newTTNode(c, dsService.dataCoord) - if err != nil { - return err + // init flushManager + if Params.CommonCfg.EnableStorageV2.GetAsBool() { + ds.flushManager = NewRendezvousFlushManagerV2(node.allocator, node.chunkManager, channel, + flushNotifyFunc2(ds, retry.Attempts(50)), dropVirtualChannelFunc(ds), ds, ds.cli) + } else { + ds.flushManager = NewRendezvousFlushManagerV2(node.allocator, node.chunkManager, channel, + flushNotifyFunc2(ds, retry.Attempts(50)), dropVirtualChannelFunc(ds), ds) } - dsService.fg.AddNode(dmStreamNode) - dsService.fg.AddNode(ddNode) - dsService.fg.AddNode(insertBufferNode) - dsService.fg.AddNode(deleteNode) - dsService.fg.AddNode(ttNode) + // flushManager := NewRendezvousFlushManager( + // node.allocator, + // 
node.chunkManager, + // channel, + // flushNotifyFunc(ds, retry.Attempts(50)), dropVirtualChannelFunc(ds), + // ) + // ds.flushManager = flushManager - // ddStreamNode - err = dsService.fg.SetEdges(dmStreamNode.Name(), - []string{ddNode.Name()}, - ) + // init flowgraph + fg := flowgraph.NewTimeTickedFlowGraph(node.ctx) + dmStreamNode, err := newDmInputNode(initCtx, node.dispClient, info.GetVchan().GetSeekPosition(), config) if err != nil { - log.Error("set edges failed in node", zap.String("name", dmStreamNode.Name()), zap.Error(err)) - return err + return nil, err } - // ddNode - err = dsService.fg.SetEdges(ddNode.Name(), - []string{insertBufferNode.Name()}, + ddNode, err := newDDNode( + node.ctx, + collectionID, + channelName, + info.GetVchan().GetDroppedSegmentIds(), + flushed, + unflushed, + node.compactionExecutor, ) if err != nil { - log.Error("set edges failed in node", zap.String("name", ddNode.Name()), zap.Error(err)) - return err + return nil, err } - // insertBufferNode - err = dsService.fg.SetEdges(insertBufferNode.Name(), - []string{deleteNode.Name()}, + insertBufferNode, err := newInsertBufferNode( + node.ctx, + flushCh, + resendTTCh, + delBufferManager, + flushManager, + node.segmentCache, + node.timeTickSender, + config, ) if err != nil { - log.Error("set edges failed in node", zap.String("name", insertBufferNode.Name()), zap.Error(err)) - return err + return nil, err } - //deleteNode - err = dsService.fg.SetEdges(deleteNode.Name(), - []string{ttNode.Name()}, - ) + deleteNode, err := newDeleteNode(node.ctx, flushManager, delBufferManager, node.clearSignal, config) if err != nil { - log.Error("set edges failed in node", zap.String("name", deleteNode.Name()), zap.Error(err)) - return err + return nil, err } - // ttNode - err = dsService.fg.SetEdges(ttNode.Name(), - []string{}, - ) + ttNode, err := newTTNode(config, node.broker) if err != nil { - log.Error("set edges failed in node", zap.String("name", ttNode.Name()), zap.Error(err)) - return err + 
return nil, err + } + + if err := fg.AssembleNodes(dmStreamNode, ddNode, insertBufferNode, deleteNode, ttNode); err != nil { + return nil, err } - return nil + ds.fg = fg + + return ds, nil } -// getSegmentInfos return the SegmentInfo details according to the given ids through RPC to datacoord -func (dsService *dataSyncService) getSegmentInfos(segmentIDs []int64) ([]*datapb.SegmentInfo, error) { - infoResp, err := dsService.dataCoord.GetSegmentInfo(dsService.ctx, &datapb.GetSegmentInfoRequest{ - Base: commonpbutil.NewMsgBase( - commonpbutil.WithMsgType(commonpb.MsgType_SegmentInfo), - commonpbutil.WithMsgID(0), - commonpbutil.WithSourceID(paramtable.GetNodeID()), - ), - SegmentIDs: segmentIDs, - IncludeUnHealthy: true, - }) +// newServiceWithEtcdTickler gets a dataSyncService, but flowgraphs are not running +// initCtx is used to init the dataSyncService only, if initCtx.Canceled or initCtx.Timeout +// newServiceWithEtcdTickler stops and returns the initCtx.Err() +func newServiceWithEtcdTickler(initCtx context.Context, node *DataNode, info *datapb.ChannelWatchInfo, tickler *etcdTickler) (*dataSyncService, error) { + // recover segment checkpoints + unflushedSegmentInfos, err := node.broker.GetSegmentInfo(initCtx, info.GetVchan().GetUnflushedSegmentIds()) if err != nil { - log.Error("Fail to get datapb.SegmentInfo by ids from datacoord", zap.Error(err)) return nil, err } - if infoResp.GetStatus().ErrorCode != commonpb.ErrorCode_Success { - err = errors.New(infoResp.GetStatus().Reason) - log.Error("Fail to get datapb.SegmentInfo by ids from datacoord", zap.Error(err)) + flushedSegmentInfos, err := node.broker.GetSegmentInfo(initCtx, info.GetVchan().GetFlushedSegmentIds()) + if err != nil { return nil, err } - return infoResp.Infos, nil + + // init channel meta + channel, err := getChannelWithEtcdTickler(initCtx, node, info, tickler, unflushedSegmentInfos, flushedSegmentInfos) + if err != nil { + return nil, err + } + + return getServiceWithChannel(initCtx, node, 
info, channel, unflushedSegmentInfos, flushedSegmentInfos) } -func (dsService *dataSyncService) getChannelLatestMsgID(ctx context.Context, channelName string, segmentID int64) ([]byte, error) { - pChannelName := funcutil.ToPhysicalChannel(channelName) - dmlStream, err := dsService.msFactory.NewMsgStream(ctx) +// newDataSyncService gets a dataSyncService, but flowgraphs are not running +// initCtx is used to init the dataSyncService only, if initCtx.Canceled or initCtx.Timeout +// newDataSyncService stops and returns the initCtx.Err() +// NOTE: compactiable for event manager +func newDataSyncService(initCtx context.Context, node *DataNode, info *datapb.ChannelWatchInfo, tickler *tickler) (*dataSyncService, error) { + // recover segment checkpoints + unflushedSegmentInfos, err := node.broker.GetSegmentInfo(initCtx, info.GetVchan().GetUnflushedSegmentIds()) + if err != nil { + return nil, err + } + flushedSegmentInfos, err := node.broker.GetSegmentInfo(initCtx, info.GetVchan().GetFlushedSegmentIds()) if err != nil { return nil, err } - defer dmlStream.Close() - subName := fmt.Sprintf("datanode-%d-%s-%d", paramtable.GetNodeID(), channelName, segmentID) - log.Debug("dataSyncService register consumer for getChannelLatestMsgID", - zap.String("pChannelName", pChannelName), - zap.String("subscription", subName), - ) - dmlStream.AsConsumer([]string{pChannelName}, subName, mqwrapper.SubscriptionPositionUnknown) - id, err := dmlStream.GetLatestMsgID(pChannelName) + // init channel meta + channel, err := getChannelWithTickler(initCtx, node, info, tickler, unflushedSegmentInfos, flushedSegmentInfos) if err != nil { - log.Error("fail to GetLatestMsgID", zap.String("pChannelName", pChannelName), zap.Error(err)) return nil, err } - return id.Serialize(), nil + + return getServiceWithChannel(initCtx, node, info, channel, unflushedSegmentInfos, flushedSegmentInfos) } diff --git a/internal/datanode/data_sync_service_test.go b/internal/datanode/data_sync_service_test.go index 
fdea3db63d6c6..114797ae20ff3 100644 --- a/internal/datanode/data_sync_service_test.go +++ b/internal/datanode/data_sync_service_test.go @@ -20,30 +20,34 @@ import ( "bytes" "context" "encoding/binary" + "fmt" "math" + "math/rand" "os" "testing" "time" - "github.com/milvus-io/milvus/pkg/util/tsoutil" + "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/datanode/allocator" + "github.com/milvus-io/milvus/internal/datanode/broker" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/mq/msgdispatcher" "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/milvus-io/milvus/pkg/util/typeutil" + "github.com/milvus-io/milvus/pkg/util/tsoutil" ) var dataSyncServiceTestDir = "/tmp/milvus_test/data_sync_service" @@ -52,6 +56,12 @@ func init() { paramtable.Init() } +func getWatchInfo(info *testInfo) *datapb.ChannelWatchInfo { + return &datapb.ChannelWatchInfo{ + Vchan: getVchanInfo(info), + } +} + func getVchanInfo(info *testInfo) *datapb.VchannelInfo { var ufs []*datapb.SegmentInfo var fs []*datapb.SegmentInfo @@ -97,7 +107,7 @@ func getVchanInfo(info *testInfo) *datapb.VchannelInfo { type testInfo struct { isValidCase bool channelNil bool - inMsgFactory msgstream.Factory + inMsgFactory dependency.Factory collID UniqueID chanName string @@ -116,66 +126,44 @@ type testInfo struct { } func 
TestDataSyncService_newDataSyncService(t *testing.T) { - ctx := context.Background() tests := []*testInfo{ - {true, false, &mockMsgStreamFactory{false, true}, - 0, "by-dev-rootcoord-dml-test_v0", - 0, 0, "", 0, - 0, 0, "", 0, - "SetParamsReturnError"}, - {true, false, &mockMsgStreamFactory{true, true}, - 0, "by-dev-rootcoord-dml-test_v0", + { + true, false, &mockMsgStreamFactory{false, true}, + 1, "by-dev-rootcoord-dml-test_v0", 1, 0, "", 0, - 1, 1, "", 0, - "CollID 0 mismach with seginfo collID 1"}, - {true, false, &mockMsgStreamFactory{true, true}, - 1, "by-dev-rootcoord-dml-test_v1", - 1, 0, "by-dev-rootcoord-dml-test_v2", 0, - 1, 1, "by-dev-rootcoord-dml-test_v3", 0, - "chanName c1 mismach with seginfo chanName c2"}, - {true, false, &mockMsgStreamFactory{true, true}, + 1, 0, "", 0, + "SetParamsReturnError", + }, + { + true, false, &mockMsgStreamFactory{true, true}, 1, "by-dev-rootcoord-dml-test_v1", 1, 0, "by-dev-rootcoord-dml-test_v1", 0, 1, 1, "by-dev-rootcoord-dml-test_v2", 0, - "add normal segments"}, - {true, false, &mockMsgStreamFactory{true, true}, + "add normal segments", + }, + { + true, false, &mockMsgStreamFactory{true, true}, 1, "by-dev-rootcoord-dml-test_v1", 1, 1, "by-dev-rootcoord-dml-test_v1", 0, 1, 2, "by-dev-rootcoord-dml-test_v1", 0, - "add un-flushed and flushed segments"}, + "add un-flushed and flushed segments", + }, } cm := storage.NewLocalChunkManager(storage.RootPath(dataSyncServiceTestDir)) defer cm.RemoveWithPrefix(ctx, cm.RootPath()) + node := newIDLEDataNodeMock(ctx, schemapb.DataType_Int64) + for _, test := range tests { t.Run(test.description, func(t *testing.T) { - df := &DataCoordFactory{} - rc := &RootCoordFactory{pkType: schemapb.DataType_Int64} - - channel := newChannel("channel", test.collID, nil, rc, cm) - if test.channelNil { - channel = nil - } - dispClient := msgdispatcher.NewClient(test.inMsgFactory, typeutil.DataNodeRole, paramtable.GetNodeID()) - - ds, err := newDataSyncService(ctx, - make(chan flushMsg), - make(chan 
resendTTMsg), - channel, - allocator.NewMockAllocator(t), - dispClient, - test.inMsgFactory, - getVchanInfo(test), - make(chan string), - df, - newCache(), - cm, - newCompactionExecutor(), + node.factory = test.inMsgFactory + ds, err := newServiceWithEtcdTickler( + ctx, + node, + getWatchInfo(test), genTestTickler(), - 0, - nil, ) if !test.isValidCase { @@ -191,40 +179,44 @@ func TestDataSyncService_newDataSyncService(t *testing.T) { } }) } - } // NOTE: start pulsar before test func TestDataSyncService_Start(t *testing.T) { const ctxTimeInMillisecond = 10000 + os.RemoveAll("/tmp/milvus") + defer os.RemoveAll("/tmp/milvus") delay := time.Now().Add(ctxTimeInMillisecond * time.Millisecond) ctx, cancel := context.WithDeadline(context.Background(), delay) defer cancel() - // init data node - insertChannelName := "by-dev-rootcoord-dml" + node := newIDLEDataNodeMock(context.Background(), schemapb.DataType_Int64) + node.chunkManager = storage.NewLocalChunkManager(storage.RootPath(dataSyncServiceTestDir)) + defer node.chunkManager.RemoveWithPrefix(ctx, node.chunkManager.RootPath()) - Factory := &MetaFactory{} - collMeta := Factory.GetCollectionMeta(UniqueID(0), "coll1", schemapb.DataType_Int64) - mockRootCoord := &RootCoordFactory{ - pkType: schemapb.DataType_Int64, - } + broker := broker.NewMockBroker(t) + broker.EXPECT().ReportTimeTick(mock.Anything, mock.Anything).Return(nil).Maybe() + broker.EXPECT().SaveBinlogPaths(mock.Anything, mock.Anything).Return(nil).Maybe() + broker.EXPECT().DropVirtualChannel(mock.Anything, mock.Anything).Return(nil).Maybe() + broker.EXPECT().UpdateChannelCheckpoint(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() - flushChan := make(chan flushMsg, 100) - resendTTChan := make(chan resendTTMsg, 100) - cm := storage.NewLocalChunkManager(storage.RootPath(dataSyncServiceTestDir)) - defer cm.RemoveWithPrefix(ctx, cm.RootPath()) - channel := newChannel(insertChannelName, collMeta.ID, collMeta.GetSchema(), mockRootCoord, cm) + 
node.broker = broker alloc := allocator.NewMockAllocator(t) alloc.EXPECT().Alloc(mock.Anything).Call.Return(int64(22222), func(count uint32) int64 { return int64(22222 + count) }, nil) - factory := dependency.NewDefaultFactory(true) - dispClient := msgdispatcher.NewClient(factory, typeutil.DataNodeRole, paramtable.GetNodeID()) - defer os.RemoveAll("/tmp/milvus") + node.allocator = alloc + + var ( + insertChannelName = fmt.Sprintf("by-dev-rootcoord-dml-%d", rand.Int()) + + Factory = &MetaFactory{} + collMeta = Factory.GetCollectionMeta(UniqueID(0), "coll1", schemapb.DataType_Int64) + ) + paramtable.Get().Save(Params.DataNodeCfg.FlushInsertBufferSize.Key, "1") ufs := []*datapb.SegmentInfo{{ @@ -251,35 +243,48 @@ func TestDataSyncService_Start(t *testing.T) { for _, segmentInfo := range fs { fsIds = append(fsIds, segmentInfo.ID) } - vchan := &datapb.VchannelInfo{ - CollectionID: collMeta.ID, - ChannelName: insertChannelName, - UnflushedSegmentIds: ufsIds, - FlushedSegmentIds: fsIds, - } - - signalCh := make(chan string, 100) - - dataCoord := &DataCoordFactory{} - dataCoord.UserSegmentInfo = map[int64]*datapb.SegmentInfo{ - 0: { - ID: 0, - CollectionID: collMeta.ID, - PartitionID: 1, - InsertChannel: insertChannelName, - }, - 1: { - ID: 1, - CollectionID: collMeta.ID, - PartitionID: 1, - InsertChannel: insertChannelName, + watchInfo := &datapb.ChannelWatchInfo{ + Schema: collMeta.GetSchema(), + Vchan: &datapb.VchannelInfo{ + CollectionID: collMeta.ID, + ChannelName: insertChannelName, + UnflushedSegmentIds: ufsIds, + FlushedSegmentIds: fsIds, }, } - atimeTickSender := newTimeTickSender(dataCoord, 0) - sync, err := newDataSyncService(ctx, flushChan, resendTTChan, channel, alloc, dispClient, factory, vchan, signalCh, dataCoord, newCache(), cm, newCompactionExecutor(), genTestTickler(), 0, atimeTickSender) - assert.Nil(t, err) + broker.EXPECT().GetSegmentInfo(mock.Anything, mock.Anything).Call.Return( + func(_ context.Context, segmentIDs []int64) []*datapb.SegmentInfo { + 
data := map[int64]*datapb.SegmentInfo{ + 0: { + ID: 0, + CollectionID: collMeta.ID, + PartitionID: 1, + InsertChannel: insertChannelName, + }, + + 1: { + ID: 1, + CollectionID: collMeta.ID, + PartitionID: 1, + InsertChannel: insertChannelName, + }, + } + return lo.FilterMap(segmentIDs, func(id int64, _ int) (*datapb.SegmentInfo, bool) { + item, ok := data[id] + return item, ok + }) + }, nil) + + sync, err := newServiceWithEtcdTickler( + ctx, + node, + watchInfo, + genTestTickler(), + ) + require.NoError(t, err) + require.NotNil(t, sync) sync.flushListener = make(chan *segmentFlushPack) defer close(sync.flushListener) @@ -288,7 +293,7 @@ func TestDataSyncService_Start(t *testing.T) { timeRange := TimeRange{ timestampMin: 0, - timestampMax: math.MaxUint64, + timestampMax: math.MaxUint64 - 1, } dataFactory := NewDataFactory() insertMessages := dataFactory.GetMsgStreamTsInsertMsgs(2, insertChannelName, tsoutil.GetCurrentTime()) @@ -318,7 +323,7 @@ func TestDataSyncService_Start(t *testing.T) { Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_TimeTick, MsgID: UniqueID(0), - Timestamp: math.MaxUint64, + Timestamp: math.MaxUint64 - 1, SourceID: 0, }, }, @@ -327,7 +332,7 @@ func TestDataSyncService_Start(t *testing.T) { // pulsar produce assert.NoError(t, err) - insertStream, _ := factory.NewMsgStream(ctx) + insertStream, _ := node.factory.NewMsgStream(ctx) insertStream.AsProducer([]string{insertChannelName}) var insertMsgStream msgstream.MsgStream = insertStream @@ -361,13 +366,20 @@ func TestDataSyncService_Close(t *testing.T) { var ( insertChannelName = "by-dev-rootcoord-dml2" - metaFactory = &MetaFactory{} - mockRootCoord = &RootCoordFactory{pkType: schemapb.DataType_Int64} - - collMeta = metaFactory.GetCollectionMeta(UniqueID(0), "coll1", schemapb.DataType_Int64) - cm = storage.NewLocalChunkManager(storage.RootPath(dataSyncServiceTestDir)) + metaFactory = &MetaFactory{} + collMeta = metaFactory.GetCollectionMeta(UniqueID(0), "coll1", schemapb.DataType_Int64) + node 
= newIDLEDataNodeMock(context.Background(), schemapb.DataType_Int64) ) - defer cm.RemoveWithPrefix(ctx, cm.RootPath()) + node.chunkManager = storage.NewLocalChunkManager(storage.RootPath(dataSyncServiceTestDir)) + defer node.chunkManager.RemoveWithPrefix(ctx, node.chunkManager.RootPath()) + + broker := broker.NewMockBroker(t) + broker.EXPECT().ReportTimeTick(mock.Anything, mock.Anything).Return(nil).Maybe() + broker.EXPECT().SaveBinlogPaths(mock.Anything, mock.Anything).Return(nil).Maybe() + broker.EXPECT().DropVirtualChannel(mock.Anything, mock.Anything).Return(nil).Maybe() + broker.EXPECT().UpdateChannelCheckpoint(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + + node.broker = broker ufs := []*datapb.SegmentInfo{{ CollectionID: collMeta.ID, @@ -393,54 +405,62 @@ func TestDataSyncService_Close(t *testing.T) { for _, segmentInfo := range fs { fsIds = append(fsIds, segmentInfo.ID) } - vchan := &datapb.VchannelInfo{ - CollectionID: collMeta.ID, - ChannelName: insertChannelName, - UnflushedSegmentIds: ufsIds, - FlushedSegmentIds: fsIds, + watchInfo := &datapb.ChannelWatchInfo{ + Schema: collMeta.GetSchema(), + Vchan: &datapb.VchannelInfo{ + CollectionID: collMeta.ID, + ChannelName: insertChannelName, + UnflushedSegmentIds: ufsIds, + FlushedSegmentIds: fsIds, + }, } + alloc := allocator.NewMockAllocator(t) - alloc.EXPECT().AllocOne().Call.Return(int64(11111), nil) + alloc.EXPECT().AllocOne().Call.Return(int64(11111), nil).Maybe() alloc.EXPECT().Alloc(mock.Anything).Call.Return(int64(22222), func(count uint32) int64 { return int64(22222 + count) - }, nil) - - var ( - flushChan = make(chan flushMsg, 100) - resendTTChan = make(chan resendTTMsg, 100) - signalCh = make(chan string, 100) - - factory = dependency.NewDefaultFactory(true) - dispClient = msgdispatcher.NewClient(factory, typeutil.DataNodeRole, paramtable.GetNodeID()) - mockDataCoord = &DataCoordFactory{} - ) - mockDataCoord.UserSegmentInfo = map[int64]*datapb.SegmentInfo{ - 0: { - ID: 0, - 
CollectionID: collMeta.ID, - PartitionID: 1, - InsertChannel: insertChannelName, - }, - - 1: { - ID: 1, - CollectionID: collMeta.ID, - PartitionID: 1, - InsertChannel: insertChannelName, - }, - } + }, nil).Maybe() + node.allocator = alloc + + broker.EXPECT().GetSegmentInfo(mock.Anything, mock.Anything).Call.Return( + func(_ context.Context, segmentIDs []int64) []*datapb.SegmentInfo { + data := map[int64]*datapb.SegmentInfo{ + 0: { + ID: 0, + CollectionID: collMeta.ID, + PartitionID: 1, + InsertChannel: insertChannelName, + }, + + 1: { + ID: 1, + CollectionID: collMeta.ID, + PartitionID: 1, + InsertChannel: insertChannelName, + }, + } + segments := lo.FilterMap(segmentIDs, func(id int64, _ int) (*datapb.SegmentInfo, bool) { + item, ok := data[id] + return item, ok + }) + return segments + }, nil).Maybe() // No Auto flush paramtable.Get().Reset(Params.DataNodeCfg.FlushInsertBufferSize.Key) - channel := newChannel(insertChannelName, collMeta.ID, collMeta.GetSchema(), mockRootCoord, cm) - channel.syncPolicies = []segmentSyncPolicy{ + syncService, err := newServiceWithEtcdTickler( + context.Background(), + node, + watchInfo, + genTestTickler(), + ) + require.NoError(t, err) + assert.NotNil(t, syncService) + syncService.channel.(*ChannelMeta).syncPolicies = []segmentSyncPolicy{ syncMemoryTooHigh(), } - atimeTickSender := newTimeTickSender(mockDataCoord, 0) - syncService, err := newDataSyncService(ctx, flushChan, resendTTChan, channel, alloc, dispClient, factory, vchan, signalCh, mockDataCoord, newCache(), cm, newCompactionExecutor(), genTestTickler(), 0, atimeTickSender) - assert.NoError(t, err) syncService.flushListener = make(chan *segmentFlushPack, 10) defer close(syncService.flushListener) @@ -513,7 +533,7 @@ func TestDataSyncService_Close(t *testing.T) { // pulsar produce assert.NoError(t, err) - insertStream, _ := factory.NewMsgStream(ctx) + insertStream, _ := node.factory.NewMsgStream(ctx) insertStream.AsProducer([]string{insertChannelName}) var insertMsgStream 
msgstream.MsgStream = insertStream @@ -555,7 +575,7 @@ func genBytes() (rawData []byte) { const N = 1 // Float vector - var fvector = [DIM]float32{1, 2} + fvector := [DIM]float32{1, 2} for _, ele := range fvector { buf := make([]byte, 4) common.Endian.PutUint32(buf, math.Float32bits(ele)) @@ -565,11 +585,11 @@ func genBytes() (rawData []byte) { // Binary vector // Dimension of binary vector is 32 // size := 4, = 32 / 8 - var bvector = []byte{255, 255, 255, 0} + bvector := []byte{255, 255, 255, 0} rawData = append(rawData, bvector...) // Bool - var fieldBool = true + fieldBool := true buf := new(bytes.Buffer) if err := binary.Write(buf, common.Endian, fieldBool); err != nil { panic(err) @@ -594,12 +614,12 @@ func TestBytesReader(t *testing.T) { // Bytes Reader is able to recording the position rawDataReader := bytes.NewReader(rawData) - var fvector = make([]float32, 2) + fvector := make([]float32, 2) err := binary.Read(rawDataReader, common.Endian, &fvector) assert.NoError(t, err) assert.ElementsMatch(t, fvector, []float32{1, 2}) - var bvector = make([]byte, 4) + bvector := make([]byte, 4) err = binary.Read(rawDataReader, common.Endian, &bvector) assert.NoError(t, err) assert.ElementsMatch(t, bvector, []byte{255, 255, 255, 0}) @@ -615,70 +635,44 @@ func TestBytesReader(t *testing.T) { assert.Equal(t, int8(100), dataInt8) } -func TestGetSegmentInfos(t *testing.T) { - dataCoord := &DataCoordFactory{} - dsService := &dataSyncService{ - dataCoord: dataCoord, - } - segmentInfos, err := dsService.getSegmentInfos([]int64{1}) - assert.NoError(t, err) - assert.Equal(t, 1, len(segmentInfos)) - - dataCoord.GetSegmentInfosError = true - segmentInfos2, err := dsService.getSegmentInfos([]int64{1}) - assert.Error(t, err) - assert.Empty(t, segmentInfos2) - - dataCoord.GetSegmentInfosError = false - dataCoord.GetSegmentInfosNotSuccess = true - segmentInfos3, err := dsService.getSegmentInfos([]int64{1}) - assert.Error(t, err) - assert.Empty(t, segmentInfos3) - - 
dataCoord.GetSegmentInfosError = false - dataCoord.GetSegmentInfosNotSuccess = false - dataCoord.UserSegmentInfo = map[int64]*datapb.SegmentInfo{ - 5: { - ID: 100, - CollectionID: 101, - PartitionID: 102, - InsertChannel: "by-dev-rootcoord-dml-test_v1", - }, - } - - segmentInfos, err = dsService.getSegmentInfos([]int64{5}) - assert.NoError(t, err) - assert.Equal(t, 1, len(segmentInfos)) - assert.Equal(t, int64(100), segmentInfos[0].ID) -} - func TestClearGlobalFlushingCache(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - dataCoord := &DataCoordFactory{} + meta := NewMetaFactory().GetCollectionMeta(1, "test_collection", schemapb.DataType_Int64) + broker := broker.NewMockBroker(t) + broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything). + Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Status(nil), + CollectionID: 1, + CollectionName: "test_collection", + Schema: meta.GetSchema(), + }, nil) cm := storage.NewLocalChunkManager(storage.RootPath(dataSyncServiceTestDir)) defer cm.RemoveWithPrefix(ctx, cm.RootPath()) - channel := newChannel("channel", 1, nil, &RootCoordFactory{pkType: schemapb.DataType_Int64}, cm) + channel := newChannel("channel", 1, nil, broker, cm) var err error cache := newCache() dsService := &dataSyncService{ - dataCoord: dataCoord, + broker: broker, channel: channel, flushingSegCache: cache, } err = channel.addSegment( + context.TODO(), addSegmentReq{ segType: datapb.SegmentType_New, segID: 1, collID: 1, partitionID: 1, startPos: &msgpb.MsgPosition{}, - endPos: &msgpb.MsgPosition{}}) + endPos: &msgpb.MsgPosition{}, + }) assert.NoError(t, err) err = channel.addSegment( + context.TODO(), addSegmentReq{ segType: datapb.SegmentType_Flushed, segID: 2, @@ -691,6 +685,7 @@ func TestClearGlobalFlushingCache(t *testing.T) { assert.NoError(t, err) err = channel.addSegment( + context.TODO(), addSegmentReq{ segType: datapb.SegmentType_Normal, segID: 3, @@ -718,19 +713,72 @@ func 
TestGetChannelLatestMsgID(t *testing.T) { delay := time.Now().Add(ctxTimeInMillisecond * time.Millisecond) ctx, cancel := context.WithDeadline(context.Background(), delay) defer cancel() - factory := dependency.NewDefaultFactory(true) - - dataCoord := &DataCoordFactory{} - dsService := &dataSyncService{ - dataCoord: dataCoord, - msFactory: factory, - } + node := newIDLEDataNodeMock(ctx, schemapb.DataType_Int64) dmlChannelName := "fake-by-dev-rootcoord-dml-channel_12345v0" - insertStream, _ := factory.NewMsgStream(ctx) + insertStream, _ := node.factory.NewMsgStream(ctx) insertStream.AsProducer([]string{dmlChannelName}) - id, err := dsService.getChannelLatestMsgID(ctx, dmlChannelName, 0) + id, err := node.getChannelLatestMsgID(ctx, dmlChannelName, 0) assert.NoError(t, err) assert.NotNil(t, id) } + +func TestGetChannelWithTickler(t *testing.T) { + channelName := "by-dev-rootcoord-dml-0" + info := getWatchInfoByOpID(100, channelName, datapb.ChannelWatchState_ToWatch) + node := newIDLEDataNodeMock(context.Background(), schemapb.DataType_Int64) + node.chunkManager = storage.NewLocalChunkManager(storage.RootPath(dataSyncServiceTestDir)) + defer node.chunkManager.RemoveWithPrefix(context.Background(), node.chunkManager.RootPath()) + + meta := NewMetaFactory().GetCollectionMeta(1, "test_collection", schemapb.DataType_Int64) + broker := broker.NewMockBroker(t) + broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything). 
+ Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Status(nil), + CollectionID: 1, + CollectionName: "test_collection", + Schema: meta.GetSchema(), + }, nil) + node.broker = broker + + unflushed := []*datapb.SegmentInfo{ + { + ID: 100, + CollectionID: 1, + PartitionID: 10, + NumOfRows: 20, + }, + { + ID: 101, + CollectionID: 1, + PartitionID: 10, + NumOfRows: 20, + }, + } + + flushed := []*datapb.SegmentInfo{ + { + ID: 200, + CollectionID: 1, + PartitionID: 10, + NumOfRows: 20, + }, + { + ID: 201, + CollectionID: 1, + PartitionID: 10, + NumOfRows: 20, + }, + } + + channel, err := getChannelWithTickler(context.TODO(), node, info, newTickler(), unflushed, flushed) + assert.NoError(t, err) + assert.NotNil(t, channel) + assert.Equal(t, channelName, channel.getChannelName(100)) + assert.Equal(t, int64(1), channel.getCollectionID()) + assert.True(t, channel.hasSegment(100, true)) + assert.True(t, channel.hasSegment(101, true)) + assert.True(t, channel.hasSegment(200, true)) + assert.True(t, channel.hasSegment(201, true)) +} diff --git a/internal/datanode/event_manager.go b/internal/datanode/event_manager.go index afee116527f53..9db8448076144 100644 --- a/internal/datanode/event_manager.go +++ b/internal/datanode/event_manager.go @@ -17,20 +17,205 @@ package datanode import ( + "context" + "fmt" + "path" + "strings" "sync" "time" "github.com/golang/protobuf/proto" + v3rpc "go.etcd.io/etcd/api/v3/v3rpc/rpctypes" "go.uber.org/atomic" "go.uber.org/zap" "github.com/milvus-io/milvus/internal/kv" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/logutil" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) const retryWatchInterval = 20 * time.Second +// StartWatchChannels start loop to watch channel allocation status via kv(etcd for now) +func (node *DataNode) StartWatchChannels(ctx context.Context) { + defer node.stopWaiter.Done() + defer logutil.LogPanic() + // REF MEP#7 watch path 
should be [prefix]/channel/{node_id}/{channel_name} + // TODO, this is risky, we'd better watch etcd with revision rather simply a path + watchPrefix := path.Join(Params.CommonCfg.DataCoordWatchSubPath.GetValue(), fmt.Sprintf("%d", node.GetSession().ServerID)) + log.Info("Start watch channel", zap.String("prefix", watchPrefix)) + evtChan := node.watchKv.WatchWithPrefix(watchPrefix) + // after watch, first check all exists nodes first + err := node.checkWatchedList() + if err != nil { + log.Warn("StartWatchChannels failed", zap.Error(err)) + return + } + for { + select { + case <-ctx.Done(): + log.Info("watch etcd loop quit") + return + case event, ok := <-evtChan: + if !ok { + log.Warn("datanode failed to watch channel, return") + go node.StartWatchChannels(ctx) + return + } + + if err := event.Err(); err != nil { + log.Warn("datanode watch channel canceled", zap.Error(event.Err())) + // https://github.com/etcd-io/etcd/issues/8980 + if event.Err() == v3rpc.ErrCompacted { + go node.StartWatchChannels(ctx) + return + } + // if watch loop return due to event canceled, the datanode is not functional anymore + log.Panic("datanode is not functional for event canceled", zap.Error(err)) + return + } + for _, evt := range event.Events { + // We need to stay in order until events enqueued + node.handleChannelEvt(evt) + } + } + } +} + +// checkWatchedList list all nodes under [prefix]/channel/{node_id} and make sure all nodeds are watched +// serves the corner case for etcd connection lost and missing some events +func (node *DataNode) checkWatchedList() error { + // REF MEP#7 watch path should be [prefix]/channel/{node_id}/{channel_name} + prefix := path.Join(Params.CommonCfg.DataCoordWatchSubPath.GetValue(), fmt.Sprintf("%d", paramtable.GetNodeID())) + keys, values, err := node.watchKv.LoadWithPrefix(prefix) + if err != nil { + return err + } + for i, val := range values { + node.handleWatchInfo(&event{eventType: putEventType}, keys[i], []byte(val)) + } + return nil +} + 
+func (node *DataNode) handleWatchInfo(e *event, key string, data []byte) { + switch e.eventType { + case putEventType: + watchInfo, err := parsePutEventData(data) + if err != nil { + log.Warn("fail to handle watchInfo", zap.Int("event type", e.eventType), zap.String("key", key), zap.Error(err)) + return + } + + if isEndWatchState(watchInfo.State) { + log.Info("DataNode received a PUT event with an end State", zap.String("state", watchInfo.State.String())) + return + } + + if watchInfo.Progress != 0 { + log.Info("DataNode received a PUT event with tickler update progress", zap.String("channel", watchInfo.Vchan.ChannelName), zap.Int64("version", e.version)) + return + } + + e.info = watchInfo + e.vChanName = watchInfo.GetVchan().GetChannelName() + log.Info("DataNode is handling watchInfo PUT event", zap.String("key", key), zap.Any("watch state", watchInfo.GetState().String())) + case deleteEventType: + e.vChanName = parseDeleteEventKey(key) + log.Info("DataNode is handling watchInfo DELETE event", zap.String("key", key)) + } + + actualManager, loaded := node.eventManagerMap.GetOrInsert(e.vChanName, newChannelEventManager( + node.handlePutEvent, node.handleDeleteEvent, retryWatchInterval, + )) + + if !loaded { + actualManager.Run() + } + + actualManager.handleEvent(*e) + + // Whenever a delete event comes, this eventManager will be removed from map + if e.eventType == deleteEventType { + if m, loaded := node.eventManagerMap.GetAndRemove(e.vChanName); loaded { + m.Close() + } + } +} + +func parsePutEventData(data []byte) (*datapb.ChannelWatchInfo, error) { + watchInfo := datapb.ChannelWatchInfo{} + err := proto.Unmarshal(data, &watchInfo) + if err != nil { + return nil, fmt.Errorf("invalid event data: fail to parse ChannelWatchInfo, err: %v", err) + } + + if watchInfo.Vchan == nil { + return nil, fmt.Errorf("invalid event: ChannelWatchInfo with nil VChannelInfo") + } + reviseVChannelInfo(watchInfo.GetVchan()) + return &watchInfo, nil +} + +func parseDeleteEventKey(key 
string) string { + parts := strings.Split(key, "/") + vChanName := parts[len(parts)-1] + return vChanName +} + +func (node *DataNode) handlePutEvent(watchInfo *datapb.ChannelWatchInfo, version int64) (err error) { + vChanName := watchInfo.GetVchan().GetChannelName() + key := path.Join(Params.CommonCfg.DataCoordWatchSubPath.GetValue(), fmt.Sprintf("%d", node.GetSession().ServerID), vChanName) + tickler := newEtcdTickler(version, key, watchInfo, node.watchKv, Params.DataNodeCfg.WatchEventTicklerInterval.GetAsDuration(time.Second)) + + switch watchInfo.State { + case datapb.ChannelWatchState_Uncomplete, datapb.ChannelWatchState_ToWatch: + if err := node.flowgraphManager.addAndStartWithEtcdTickler(node, watchInfo.GetVchan(), watchInfo.GetSchema(), tickler); err != nil { + log.Warn("handle put event: new data sync service failed", zap.String("vChanName", vChanName), zap.Error(err)) + watchInfo.State = datapb.ChannelWatchState_WatchFailure + } else { + log.Info("handle put event: new data sync service success", zap.String("vChanName", vChanName)) + watchInfo.State = datapb.ChannelWatchState_WatchSuccess + } + case datapb.ChannelWatchState_ToRelease: + // there is no reason why we release fail + node.tryToReleaseFlowgraph(vChanName) + watchInfo.State = datapb.ChannelWatchState_ReleaseSuccess + } + + v, err := proto.Marshal(watchInfo) + if err != nil { + return fmt.Errorf("fail to marshal watchInfo with state, vChanName: %s, state: %s ,err: %w", vChanName, watchInfo.State.String(), err) + } + + success, err := node.watchKv.CompareVersionAndSwap(key, tickler.version, string(v)) + // etcd error + if err != nil { + // flow graph will leak if not release, causing new datanode failed to subscribe + node.tryToReleaseFlowgraph(vChanName) + log.Warn("fail to update watch state to etcd", zap.String("vChanName", vChanName), + zap.String("state", watchInfo.State.String()), zap.Error(err)) + return err + } + // etcd valid but the states updated. 
+ if !success { + log.Info("handle put event: failed to compare version and swap, release flowgraph", + zap.String("key", key), zap.String("state", watchInfo.State.String()), + zap.String("vChanName", vChanName)) + // flow graph will leak if not release, causing new datanode failed to subscribe + node.tryToReleaseFlowgraph(vChanName) + return nil + } + log.Info("handle put event success", zap.String("key", key), + zap.String("state", watchInfo.State.String()), zap.String("vChanName", vChanName)) + return nil +} + +func (node *DataNode) handleDeleteEvent(vChanName string) { + node.tryToReleaseFlowgraph(vChanName) +} + type event struct { eventType int vChanName string @@ -54,7 +239,8 @@ const ( ) func newChannelEventManager(handlePut func(*datapb.ChannelWatchInfo, int64) error, - handleDel func(string), retryInterval time.Duration) *channelEventManager { + handleDel func(string), retryInterval time.Duration, +) *channelEventManager { return &channelEventManager{ eventChan: make(chan event, 10), closeChan: make(chan struct{}), @@ -105,11 +291,11 @@ func isEndWatchState(state datapb.ChannelWatchState) bool { state != datapb.ChannelWatchState_Uncomplete // legacy state, equal to ToWatch } -type tickler struct { +type etcdTickler struct { progress *atomic.Int32 version int64 - kv kv.MetaKv + kv kv.WatchKV path string watchInfo *datapb.ChannelWatchInfo @@ -119,11 +305,11 @@ type tickler struct { isWatchFailed *atomic.Bool } -func (t *tickler) inc() { +func (t *etcdTickler) inc() { t.progress.Inc() } -func (t *tickler) watch() { +func (t *etcdTickler) watch() { if t.interval == 0 { log.Info("zero interval, close ticler watch", zap.String("channelName", t.watchInfo.GetVchan().GetChannelName()), @@ -177,13 +363,13 @@ func (t *tickler) watch() { }() } -func (t *tickler) stop() { +func (t *etcdTickler) stop() { close(t.closeCh) t.closeWg.Wait() } -func newTickler(version int64, path string, watchInfo *datapb.ChannelWatchInfo, kv kv.MetaKv, interval time.Duration) *tickler { - 
return &tickler{ +func newEtcdTickler(version int64, path string, watchInfo *datapb.ChannelWatchInfo, kv kv.WatchKV, interval time.Duration) *etcdTickler { + return &etcdTickler{ progress: atomic.NewInt32(0), path: path, kv: kv, diff --git a/internal/datanode/event_manager_test.go b/internal/datanode/event_manager_test.go index 9d7e6bb7cd0e4..f41421246d13d 100644 --- a/internal/datanode/event_manager_test.go +++ b/internal/datanode/event_manager_test.go @@ -17,18 +17,294 @@ package datanode import ( + "context" "fmt" + "math/rand" "path" + "strings" "testing" "time" "github.com/cockroachdb/errors" - "github.com/golang/protobuf/proto" - "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/datanode/broker" + etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/pkg/util/etcd" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) +func TestWatchChannel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + node := newIDLEDataNodeMock(ctx, schemapb.DataType_Int64) + etcdCli, err := etcd.GetEtcdClient( + Params.EtcdCfg.UseEmbedEtcd.GetAsBool(), + Params.EtcdCfg.EtcdUseSSL.GetAsBool(), + Params.EtcdCfg.Endpoints.GetAsStrings(), + Params.EtcdCfg.EtcdTLSCert.GetValue(), + Params.EtcdCfg.EtcdTLSKey.GetValue(), + Params.EtcdCfg.EtcdTLSCACert.GetValue(), + Params.EtcdCfg.EtcdTLSMinVersion.GetValue()) + assert.NoError(t, err) + defer etcdCli.Close() + node.SetEtcdClient(etcdCli) + err = node.Init() + assert.NoError(t, err) + err = node.Start() + assert.NoError(t, err) + defer node.Stop() + err = node.Register() + assert.NoError(t, err) + + defer cancel() + + broker := broker.NewMockBroker(t) + broker.EXPECT().ReportTimeTick(mock.Anything, mock.Anything).Return(nil).Maybe() + 
broker.EXPECT().SaveBinlogPaths(mock.Anything, mock.Anything).Return(nil).Maybe() + broker.EXPECT().GetSegmentInfo(mock.Anything, mock.Anything).Return([]*datapb.SegmentInfo{}, nil).Maybe() + broker.EXPECT().DropVirtualChannel(mock.Anything, mock.Anything).Return(nil).Maybe() + broker.EXPECT().UpdateChannelCheckpoint(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + + node.broker = broker + + node.timeTickSender = newTimeTickSender(node.broker, 0) + + t.Run("test watch channel", func(t *testing.T) { + kv := etcdkv.NewEtcdKV(etcdCli, Params.EtcdCfg.MetaRootPath.GetValue()) + oldInvalidCh := "datanode-etcd-test-by-dev-rootcoord-dml-channel-invalid" + path := fmt.Sprintf("%s/%d/%s", Params.CommonCfg.DataCoordWatchSubPath.GetValue(), paramtable.GetNodeID(), oldInvalidCh) + err = kv.Save(path, string([]byte{23})) + assert.NoError(t, err) + + ch := fmt.Sprintf("datanode-etcd-test-by-dev-rootcoord-dml-channel_%d", rand.Int31()) + path = fmt.Sprintf("%s/%d/%s", Params.CommonCfg.DataCoordWatchSubPath.GetValue(), paramtable.GetNodeID(), ch) + + vchan := &datapb.VchannelInfo{ + CollectionID: 1, + ChannelName: ch, + UnflushedSegmentIds: []int64{}, + } + info := &datapb.ChannelWatchInfo{ + State: datapb.ChannelWatchState_ToWatch, + Vchan: vchan, + } + val, err := proto.Marshal(info) + assert.NoError(t, err) + err = kv.Save(path, string(val)) + assert.NoError(t, err) + + assert.Eventually(t, func() bool { + exist := node.flowgraphManager.exist(ch) + if !exist { + return false + } + bs, err := kv.LoadBytes(fmt.Sprintf("%s/%d/%s", Params.CommonCfg.DataCoordWatchSubPath.GetValue(), paramtable.GetNodeID(), ch)) + if err != nil { + return false + } + watchInfo := &datapb.ChannelWatchInfo{} + err = proto.Unmarshal(bs, watchInfo) + if err != nil { + return false + } + return watchInfo.GetState() == datapb.ChannelWatchState_WatchSuccess + }, 3*time.Second, 100*time.Millisecond) + + err = kv.RemoveWithPrefix(fmt.Sprintf("%s/%d", 
Params.CommonCfg.DataCoordWatchSubPath.GetValue(), paramtable.GetNodeID())) + assert.NoError(t, err) + + assert.Eventually(t, func() bool { + exist := node.flowgraphManager.exist(ch) + return !exist + }, 3*time.Second, 100*time.Millisecond) + }) + + t.Run("Test release channel", func(t *testing.T) { + kv := etcdkv.NewEtcdKV(etcdCli, Params.EtcdCfg.MetaRootPath.GetValue()) + oldInvalidCh := "datanode-etcd-test-by-dev-rootcoord-dml-channel-invalid" + path := fmt.Sprintf("%s/%d/%s", Params.CommonCfg.DataCoordWatchSubPath.GetValue(), paramtable.GetNodeID(), oldInvalidCh) + err = kv.Save(path, string([]byte{23})) + assert.NoError(t, err) + + ch := fmt.Sprintf("datanode-etcd-test-by-dev-rootcoord-dml-channel_%d", rand.Int31()) + path = fmt.Sprintf("%s/%d/%s", Params.CommonCfg.DataCoordWatchSubPath.GetValue(), paramtable.GetNodeID(), ch) + c := make(chan struct{}) + go func() { + ec := kv.WatchWithPrefix(fmt.Sprintf("%s/%d", Params.CommonCfg.DataCoordWatchSubPath.GetValue(), paramtable.GetNodeID())) + c <- struct{}{} + cnt := 0 + for { + evt := <-ec + for _, event := range evt.Events { + if strings.Contains(string(event.Kv.Key), ch) { + cnt++ + } + } + if cnt >= 2 { + break + } + } + c <- struct{}{} + }() + // wait for check goroutine start Watch + <-c + + vchan := &datapb.VchannelInfo{ + CollectionID: 1, + ChannelName: ch, + UnflushedSegmentIds: []int64{}, + } + info := &datapb.ChannelWatchInfo{ + State: datapb.ChannelWatchState_ToRelease, + Vchan: vchan, + } + val, err := proto.Marshal(info) + assert.NoError(t, err) + err = kv.Save(path, string(val)) + assert.NoError(t, err) + + // wait for check goroutine received 2 events + <-c + exist := node.flowgraphManager.exist(ch) + assert.False(t, exist) + + err = kv.RemoveWithPrefix(fmt.Sprintf("%s/%d", Params.CommonCfg.DataCoordWatchSubPath.GetValue(), paramtable.GetNodeID())) + assert.NoError(t, err) + // TODO there is not way to sync Release done, use sleep for now + time.Sleep(100 * time.Millisecond) + + exist = 
node.flowgraphManager.exist(ch) + assert.False(t, exist) + }) + + t.Run("handle watch info failed", func(t *testing.T) { + e := &event{ + eventType: putEventType, + } + + node.handleWatchInfo(e, "test1", []byte{23}) + + exist := node.flowgraphManager.exist("test1") + assert.False(t, exist) + + info := datapb.ChannelWatchInfo{ + Vchan: nil, + State: datapb.ChannelWatchState_Uncomplete, + } + bs, err := proto.Marshal(&info) + assert.NoError(t, err) + node.handleWatchInfo(e, "test2", bs) + + exist = node.flowgraphManager.exist("test2") + assert.False(t, exist) + + chPut := make(chan struct{}, 1) + chDel := make(chan struct{}, 1) + + ch := fmt.Sprintf("datanode-etcd-test-by-dev-rootcoord-dml-channel_%d", rand.Int31()) + m := newChannelEventManager( + func(info *datapb.ChannelWatchInfo, version int64) error { + r := node.handlePutEvent(info, version) + chPut <- struct{}{} + return r + }, + func(vChan string) { + node.handleDeleteEvent(vChan) + chDel <- struct{}{} + }, time.Millisecond*100, + ) + node.eventManagerMap.Insert(ch, m) + m.Run() + defer m.Close() + + info = datapb.ChannelWatchInfo{ + Vchan: &datapb.VchannelInfo{ChannelName: ch}, + State: datapb.ChannelWatchState_Uncomplete, + } + bs, err = proto.Marshal(&info) + assert.NoError(t, err) + + msFactory := node.factory + defer func() { node.factory = msFactory }() + + // todo review the UT logic + // As we remove timetick channel logic, flow_graph_insert_buffer_node no longer depend on MessageStreamFactory + // so data_sync_service can be created. 
this assert becomes true + node.factory = &FailMessageStreamFactory{} + node.handleWatchInfo(e, ch, bs) + <-chPut + exist = node.flowgraphManager.exist(ch) + assert.True(t, exist) + }) + + t.Run("handle watchinfo out of date", func(t *testing.T) { + chPut := make(chan struct{}, 1) + chDel := make(chan struct{}, 1) + // inject eventManager + ch := fmt.Sprintf("datanode-etcd-test-by-dev-rootcoord-dml-channel_%d", rand.Int31()) + m := newChannelEventManager( + func(info *datapb.ChannelWatchInfo, version int64) error { + r := node.handlePutEvent(info, version) + chPut <- struct{}{} + return r + }, + func(vChan string) { + node.handleDeleteEvent(vChan) + chDel <- struct{}{} + }, time.Millisecond*100, + ) + node.eventManagerMap.Insert(ch, m) + m.Run() + defer m.Close() + e := &event{ + eventType: putEventType, + version: 10000, + } + + info := datapb.ChannelWatchInfo{ + Vchan: &datapb.VchannelInfo{ChannelName: ch}, + State: datapb.ChannelWatchState_Uncomplete, + } + bs, err := proto.Marshal(&info) + assert.NoError(t, err) + + node.handleWatchInfo(e, ch, bs) + <-chPut + exist := node.flowgraphManager.exist("test3") + assert.False(t, exist) + }) + + t.Run("handle watchinfo compatibility", func(t *testing.T) { + info := datapb.ChannelWatchInfo{ + Vchan: &datapb.VchannelInfo{ + CollectionID: 1, + ChannelName: "delta-channel1", + UnflushedSegments: []*datapb.SegmentInfo{{ID: 1}}, + FlushedSegments: []*datapb.SegmentInfo{{ID: 2}}, + DroppedSegments: []*datapb.SegmentInfo{{ID: 3}}, + UnflushedSegmentIds: []int64{1}, + }, + State: datapb.ChannelWatchState_Uncomplete, + } + bs, err := proto.Marshal(&info) + assert.NoError(t, err) + + newWatchInfo, err := parsePutEventData(bs) + assert.NoError(t, err) + + assert.Equal(t, []*datapb.SegmentInfo{}, newWatchInfo.GetVchan().GetUnflushedSegments()) + assert.Equal(t, []*datapb.SegmentInfo{}, newWatchInfo.GetVchan().GetFlushedSegments()) + assert.Equal(t, []*datapb.SegmentInfo{}, newWatchInfo.GetVchan().GetDroppedSegments()) + 
assert.NotEmpty(t, newWatchInfo.GetVchan().GetUnflushedSegmentIds()) + assert.NotEmpty(t, newWatchInfo.GetVchan().GetFlushedSegmentIds()) + assert.NotEmpty(t, newWatchInfo.GetVchan().GetDroppedSegmentIds()) + }) +} + func TestChannelEventManager(t *testing.T) { t.Run("normal case", func(t *testing.T) { ch := make(chan struct{}, 1) @@ -150,7 +426,6 @@ func parseWatchInfo(key string, data []byte) (*datapb.ChannelWatchInfo, error) { watchInfo := datapb.ChannelWatchInfo{} if err := proto.Unmarshal(data, &watchInfo); err != nil { return nil, fmt.Errorf("invalid event data: fail to parse ChannelWatchInfo, key: %s, err: %v", key, err) - } if watchInfo.Vchan == nil { @@ -170,7 +445,7 @@ func TestEventTickler(t *testing.T) { kv.RemoveWithPrefix(etcdPrefix) defer kv.RemoveWithPrefix(etcdPrefix) - tickler := newTickler(0, path.Join(etcdPrefix, channelName), &datapb.ChannelWatchInfo{ + tickler := newEtcdTickler(0, path.Join(etcdPrefix, channelName), &datapb.ChannelWatchInfo{ Vchan: &datapb.VchannelInfo{ ChannelName: channelName, }, @@ -193,7 +468,6 @@ func TestEventTickler(t *testing.T) { } } } - }() tickler.inc() diff --git a/internal/datanode/flow_graph_dd_node.go b/internal/datanode/flow_graph_dd_node.go index ddb21c1aa2057..5487eebe36f5c 100644 --- a/internal/datanode/flow_graph_dd_node.go +++ b/internal/datanode/flow_graph_dd_node.go @@ -99,7 +99,7 @@ func (ddn *ddNode) Operate(in []Msg) []Msg { } if msMsg.IsCloseMsg() { - var fgMsg = flowGraphMsg{ + fgMsg := flowGraphMsg{ BaseMsg: flowgraph.NewBaseMsg(true), insertMessages: make([]*msgstream.InsertMsg, 0), timeRange: TimeRange{ @@ -133,7 +133,7 @@ func (ddn *ddNode) Operate(in []Msg) []Msg { } }() - var fgMsg = flowGraphMsg{ + fgMsg := flowGraphMsg{ insertMessages: make([]*msgstream.InsertMsg, 0), timeRange: TimeRange{ timestampMin: msMsg.TimestampMin(), @@ -279,8 +279,8 @@ func (ddn *ddNode) isDropped(segID UniqueID) bool { func (ddn *ddNode) Close() {} func newDDNode(ctx context.Context, collID UniqueID, vChannelName 
string, droppedSegmentIDs []UniqueID, - sealedSegments []*datapb.SegmentInfo, growingSegments []*datapb.SegmentInfo, compactor *compactionExecutor) (*ddNode, error) { - + sealedSegments []*datapb.SegmentInfo, growingSegments []*datapb.SegmentInfo, compactor *compactionExecutor, +) (*ddNode, error) { baseNode := BaseNode{} baseNode.SetMaxQueueLength(Params.DataNodeCfg.FlowGraphMaxQueueLength.GetAsInt32()) baseNode.SetMaxParallelism(Params.DataNodeCfg.FlowGraphMaxParallelism.GetAsInt32()) diff --git a/internal/datanode/flow_graph_dd_node_test.go b/internal/datanode/flow_graph_dd_node_test.go index 9a8a3dd4e812b..f191c34e8e6ad 100644 --- a/internal/datanode/flow_graph_dd_node_test.go +++ b/internal/datanode/flow_graph_dd_node_test.go @@ -48,9 +48,11 @@ func TestFlowGraph_DDNode_newDDNode(t *testing.T) { []*datapb.SegmentInfo{ getSegmentInfo(100, 10000), getSegmentInfo(101, 10000), - getSegmentInfo(102, 10000)}, + getSegmentInfo(102, 10000), + }, []*datapb.SegmentInfo{ - getSegmentInfo(200, 10000)}, + getSegmentInfo(200, 10000), + }, }, { "0 sealed segments and 0 growing segment", @@ -94,12 +96,18 @@ func TestFlowGraph_DDNode_Operate(t *testing.T) { in []Msg description string }{ - {[]Msg{}, - "Invalid input length == 0"}, - {[]Msg{&flowGraphMsg{}, &flowGraphMsg{}, &flowGraphMsg{}}, - "Invalid input length == 3"}, - {[]Msg{&flowGraphMsg{}}, - "Invalid input length == 1 but input message is not msgStreamMsg"}, + { + []Msg{}, + "Invalid input length == 0", + }, + { + []Msg{&flowGraphMsg{}, &flowGraphMsg{}, &flowGraphMsg{}}, + "Invalid input length == 3", + }, + { + []Msg{&flowGraphMsg{}}, + "Invalid input length == 1 but input message is not msgStreamMsg", + }, } for _, test := range invalidInTests { @@ -117,10 +125,14 @@ func TestFlowGraph_DDNode_Operate(t *testing.T) { description string }{ - {1, 1, 1, - "DropCollectionMsg collID == ddNode collID"}, - {1, 2, 0, - "DropCollectionMsg collID != ddNode collID"}, + { + 1, 1, 1, + "DropCollectionMsg collID == ddNode collID", 
+ }, + { + 1, 2, 0, + "DropCollectionMsg collID != ddNode collID", + }, } for _, test := range tests { @@ -164,10 +176,16 @@ func TestFlowGraph_DDNode_Operate(t *testing.T) { description string }{ - {1, 1, 101, []UniqueID{101}, - "DropCollectionMsg collID == ddNode collID"}, - {1, 2, 101, []UniqueID{}, - "DropCollectionMsg collID != ddNode collID"}, + { + 1, 1, 101, + []UniqueID{101}, + "DropCollectionMsg collID == ddNode collID", + }, + { + 1, 2, 101, + []UniqueID{}, + "DropCollectionMsg collID != ddNode collID", + }, } for _, test := range tests { @@ -195,15 +213,12 @@ func TestFlowGraph_DDNode_Operate(t *testing.T) { fgMsg, ok := rt[0].(*flowGraphMsg) assert.True(t, ok) assert.ElementsMatch(t, test.expectOutput, fgMsg.dropPartitions) - }) } }) t.Run("Test DDNode Operate and filter insert msg", func(t *testing.T) { - var ( - collectionID UniqueID = 1 - ) + var collectionID UniqueID = 1 // Prepare ddNode states ddn := ddNode{ ctx: context.Background(), @@ -260,7 +275,6 @@ func TestFlowGraph_DDNode_Operate(t *testing.T) { }) } }) - } func TestFlowGraph_DDNode_filterMessages(t *testing.T) { @@ -274,19 +288,24 @@ func TestFlowGraph_DDNode_filterMessages(t *testing.T) { inMsg *msgstream.InsertMsg expected bool }{ - {"test dropped segments true", + { + "test dropped segments true", []UniqueID{100}, nil, nil, getInsertMsg(100, 10000), - true}, - {"test dropped segments true 2", + true, + }, + { + "test dropped segments true 2", []UniqueID{100, 101, 102}, nil, nil, getInsertMsg(102, 10000), - true}, - {"test sealed segments msgTs <= segmentTs true", + true, + }, + { + "test sealed segments msgTs <= segmentTs true", []UniqueID{}, map[UniqueID]*datapb.SegmentInfo{ 200: getSegmentInfo(200, 50000), @@ -294,8 +313,10 @@ func TestFlowGraph_DDNode_filterMessages(t *testing.T) { }, nil, getInsertMsg(200, 10000), - true}, - {"test sealed segments msgTs <= segmentTs true", + true, + }, + { + "test sealed segments msgTs <= segmentTs true", []UniqueID{}, 
map[UniqueID]*datapb.SegmentInfo{ 200: getSegmentInfo(200, 50000), @@ -303,8 +324,10 @@ func TestFlowGraph_DDNode_filterMessages(t *testing.T) { }, nil, getInsertMsg(200, 50000), - true}, - {"test sealed segments msgTs > segmentTs false", + true, + }, + { + "test sealed segments msgTs > segmentTs false", []UniqueID{}, map[UniqueID]*datapb.SegmentInfo{ 200: getSegmentInfo(200, 50000), @@ -312,8 +335,10 @@ func TestFlowGraph_DDNode_filterMessages(t *testing.T) { }, nil, getInsertMsg(222, 70000), - false}, - {"test growing segments msgTs <= segmentTs true", + false, + }, + { + "test growing segments msgTs <= segmentTs true", []UniqueID{}, nil, map[UniqueID]*datapb.SegmentInfo{ @@ -321,8 +346,10 @@ func TestFlowGraph_DDNode_filterMessages(t *testing.T) { 300: getSegmentInfo(300, 50000), }, getInsertMsg(200, 10000), - true}, - {"test growing segments msgTs > segmentTs false", + true, + }, + { + "test growing segments msgTs > segmentTs false", []UniqueID{}, nil, map[UniqueID]*datapb.SegmentInfo{ @@ -330,8 +357,10 @@ func TestFlowGraph_DDNode_filterMessages(t *testing.T) { 300: getSegmentInfo(300, 50000), }, getInsertMsg(200, 70000), - false}, - {"test not exist", + false, + }, + { + "test not exist", []UniqueID{}, map[UniqueID]*datapb.SegmentInfo{ 400: getSegmentInfo(500, 50000), @@ -342,14 +371,17 @@ func TestFlowGraph_DDNode_filterMessages(t *testing.T) { 300: getSegmentInfo(300, 50000), }, getInsertMsg(111, 70000), - false}, + false, + }, // for pChannel reuse on same collection - {"test insert msg with different channelName", + { + "test insert msg with different channelName", []UniqueID{100}, nil, nil, getInsertMsgWithChannel(100, 10000, anotherChannelName), - true}, + true, + }, } for _, test := range tests { @@ -364,7 +396,6 @@ func TestFlowGraph_DDNode_filterMessages(t *testing.T) { // Test got := ddn.tryToFilterSegmentInsertMessages(test.inMsg) assert.Equal(t, test.expected, got) - }) } @@ -380,33 +411,39 @@ func TestFlowGraph_DDNode_filterMessages(t *testing.T) 
{ inMsg *msgstream.InsertMsg msgFiltered bool }{ - {"msgTssegTs", + { + "msgTs>segTs", false, 50000, 10000, map[UniqueID]*datapb.SegmentInfo{ 100: getSegmentInfo(100, 70000), - 101: getSegmentInfo(101, 50000)}, + 101: getSegmentInfo(101, 50000), + }, getInsertMsg(300, 60000), false, }, @@ -440,27 +477,33 @@ func TestFlowGraph_DDNode_filterMessages(t *testing.T) { inMsg *msgstream.InsertMsg msgFiltered bool }{ - {"msgTssegTs", + { + "msgTs>segTs", false, map[UniqueID]*datapb.SegmentInfo{ 100: getSegmentInfo(100, 50000), - 101: getSegmentInfo(101, 50000)}, + 101: getSegmentInfo(101, 50000), + }, getInsertMsg(100, 60000), false, }, @@ -497,16 +540,31 @@ func TestFlowGraph_DDNode_isDropped(t *testing.T) { description string }{ - {[]*datapb.SegmentInfo{getSegmentInfo(1, 0), getSegmentInfo(2, 0), getSegmentInfo(3, 0)}, 1, true, - "Input seg 1 in droppedSegs{1,2,3}"}, - {[]*datapb.SegmentInfo{getSegmentInfo(1, 0), getSegmentInfo(2, 0), getSegmentInfo(3, 0)}, 2, true, - "Input seg 2 in droppedSegs{1,2,3}"}, - {[]*datapb.SegmentInfo{getSegmentInfo(1, 0), getSegmentInfo(2, 0), getSegmentInfo(3, 0)}, 3, true, - "Input seg 3 in droppedSegs{1,2,3}"}, - {[]*datapb.SegmentInfo{getSegmentInfo(1, 0), getSegmentInfo(2, 0), getSegmentInfo(3, 0)}, 4, false, - "Input seg 4 not in droppedSegs{1,2,3}"}, - {[]*datapb.SegmentInfo{}, 5, false, - "Input seg 5, no droppedSegs {}"}, + { + []*datapb.SegmentInfo{getSegmentInfo(1, 0), getSegmentInfo(2, 0), getSegmentInfo(3, 0)}, + 1, true, + "Input seg 1 in droppedSegs{1,2,3}", + }, + { + []*datapb.SegmentInfo{getSegmentInfo(1, 0), getSegmentInfo(2, 0), getSegmentInfo(3, 0)}, + 2, true, + "Input seg 2 in droppedSegs{1,2,3}", + }, + { + []*datapb.SegmentInfo{getSegmentInfo(1, 0), getSegmentInfo(2, 0), getSegmentInfo(3, 0)}, + 3, true, + "Input seg 3 in droppedSegs{1,2,3}", + }, + { + []*datapb.SegmentInfo{getSegmentInfo(1, 0), getSegmentInfo(2, 0), getSegmentInfo(3, 0)}, + 4, false, + "Input seg 4 not in droppedSegs{1,2,3}", + }, + { + 
[]*datapb.SegmentInfo{}, + 5, false, + "Input seg 5, no droppedSegs {}", + }, } for _, test := range tests { diff --git a/internal/datanode/flow_graph_delete_node.go b/internal/datanode/flow_graph_delete_node.go index f8d566926ef60..faccda3a53fdd 100644 --- a/internal/datanode/flow_graph_delete_node.go +++ b/internal/datanode/flow_graph_delete_node.go @@ -28,6 +28,7 @@ import ( "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/retry" "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -99,6 +100,7 @@ func (dn *deleteNode) Operate(in []Msg) []Msg { log.Debug("Buffer delete request in DataNode", zap.String("traceID", traceID)) tmpSegIDs, err := dn.bufferDeleteMsg(msg, fgMsg.timeRange, fgMsg.startPositions[0], fgMsg.endPositions[0]) if err != nil { + // should not happen // error occurs only when deleteMsg is misaligned, should not happen log.Fatal("failed to buffer delete msg", zap.String("traceID", traceID), zap.Error(err)) } @@ -127,6 +129,10 @@ func (dn *deleteNode) Operate(in []Msg) []Msg { return dn.flushManager.flushDelData(buf, segmentToFlush, fgMsg.endPositions[0]) }, getFlowGraphRetryOpt()) if err != nil { + if merr.IsCanceledOrTimeout(err) { + log.Warn("skip syncing delete data for context done", zap.Int64("segmentID", segmentToFlush)) + continue + } log.Fatal("failed to flush delete data", zap.Int64("segmentID", segmentToFlush), zap.Error(err)) } // remove delete buf @@ -172,7 +178,8 @@ func (dn *deleteNode) bufferDeleteMsg(msg *msgstream.DeleteMsg, tr TimeRange, st // If the key may exist in the segment, returns it in map. // If the key not exist in the segment, the segment is filter out. 
func (dn *deleteNode) filterSegmentByPK(partID UniqueID, pks []primaryKey, tss []Timestamp) ( - map[UniqueID][]primaryKey, map[UniqueID][]uint64) { + map[UniqueID][]primaryKey, map[UniqueID][]uint64, +) { segID2Pks := make(map[UniqueID][]primaryKey) segID2Tss := make(map[UniqueID][]uint64) segments := dn.channel.filterSegments(partID) @@ -191,8 +198,8 @@ func (dn *deleteNode) filterSegmentByPK(partID UniqueID, pks []primaryKey, tss [ func newDeleteNode(ctx context.Context, fm flushManager, manager *DeltaBufferManager, sig chan<- string, config *nodeConfig) (*deleteNode, error) { baseNode := BaseNode{} - baseNode.SetMaxQueueLength(config.maxQueueLength) - baseNode.SetMaxParallelism(config.maxParallelism) + baseNode.SetMaxQueueLength(Params.DataNodeCfg.FlowGraphMaxQueueLength.GetAsInt32()) + baseNode.SetMaxParallelism(Params.DataNodeCfg.FlowGraphMaxParallelism.GetAsInt32()) return &deleteNode{ ctx: ctx, diff --git a/internal/datanode/flow_graph_delete_node_test.go b/internal/datanode/flow_graph_delete_node_test.go index 0d237f1efc342..87faf6806ac38 100644 --- a/internal/datanode/flow_graph_delete_node_test.go +++ b/internal/datanode/flow_graph_delete_node_test.go @@ -140,12 +140,18 @@ func TestFlowGraphDeleteNode_Operate(t *testing.T) { in []Msg desc string }{ - {[]Msg{}, - "Invalid input length == 0"}, - {[]Msg{&flowGraphMsg{}, &flowGraphMsg{}, &flowGraphMsg{}}, - "Invalid input length == 3"}, - {[]Msg{&flowgraph.MsgStreamMsg{}}, - "Invalid input length == 1 but input message is not flowGraphMsg"}, + { + []Msg{}, + "Invalid input length == 0", + }, + { + []Msg{&flowGraphMsg{}, &flowGraphMsg{}, &flowGraphMsg{}}, + "Invalid input length == 3", + }, + { + []Msg{&flowgraph.MsgStreamMsg{}}, + "Invalid input length == 1 but input message is not flowGraphMsg", + }, } for _, test := range invalidInTests { @@ -399,7 +405,7 @@ func TestFlowGraphDeleteNode_Operate(t *testing.T) { }) t.Run("Test deleteNode auto flush function", func(t *testing.T) { - //for issue + // for issue 
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() @@ -422,8 +428,8 @@ func TestFlowGraphDeleteNode_Operate(t *testing.T) { delNode, err := newDeleteNode(ctx, mockFlushManager, delBufManager, make(chan string, 1), c) assert.NoError(t, err) - //2. here we set flushing segments inside fgmsg to empty - //in order to verify the validity of auto flush function + // 2. here we set flushing segments inside fgmsg to empty + // in order to verify the validity of auto flush function msg := genFlowGraphDeleteMsg(int64Pks, chanName) // delete has to match segment partition ID @@ -433,9 +439,9 @@ func TestFlowGraphDeleteNode_Operate(t *testing.T) { msg.segmentsToSync = []UniqueID{} var fgMsg flowgraph.Msg = &msg - //1. here we set buffer bytes to a relatively high level - //and the sum of memory consumption in this case is 208 - //so no segments will be flushed + // 1. here we set buffer bytes to a relatively high level + // and the sum of memory consumption in this case is 208 + // so no segments will be flushed paramtable.Get().Save(Params.DataNodeCfg.FlushDeleteBufferBytes.Key, "300") fgMsg.(*flowGraphMsg).segmentsToSync = delNode.delBufferManager.ShouldFlushSegments() delNode.Operate([]flowgraph.Msg{fgMsg}) @@ -443,10 +449,10 @@ func TestFlowGraphDeleteNode_Operate(t *testing.T) { assert.Equal(t, int64(208), delNode.delBufferManager.usedMemory.Load()) assert.Equal(t, 5, delNode.delBufferManager.delBufHeap.Len()) - //3. note that the whole memory size used by 5 segments will be 208 - //so when setting delete buffer size equal to 200 - //there will only be one segment to be flushed then the - //memory consumption will be reduced to 160(under 200) + // 3. 
note that the whole memory size used by 5 segments will be 208 + // so when setting delete buffer size equal to 200 + // there will only be one segment to be flushed then the + // memory consumption will be reduced to 160(under 200) msg.deleteMessages = []*msgstream.DeleteMsg{} msg.segmentsToSync = []UniqueID{} paramtable.Get().Save(Params.DataNodeCfg.FlushDeleteBufferBytes.Key, "200") @@ -456,17 +462,17 @@ func TestFlowGraphDeleteNode_Operate(t *testing.T) { assert.Equal(t, int64(160), delNode.delBufferManager.usedMemory.Load()) assert.Equal(t, 4, delNode.delBufferManager.delBufHeap.Len()) - //4. there is no new delete msg and delBufferSize is still 200 - //we expect there will not be any auto flush del + // 4. there is no new delete msg and delBufferSize is still 200 + // we expect there will not be any auto flush del fgMsg.(*flowGraphMsg).segmentsToSync = delNode.delBufferManager.ShouldFlushSegments() delNode.Operate([]flowgraph.Msg{fgMsg}) assert.Equal(t, 1, len(mockFlushManager.flushedSegIDs)) assert.Equal(t, int64(160), delNode.delBufferManager.usedMemory.Load()) assert.Equal(t, 4, delNode.delBufferManager.delBufHeap.Len()) - //5. we reset buffer bytes to 150, then we expect there would be one more - //segment which is 48 in size to be flushed, so the remained del memory size - //will be 112 + // 5. we reset buffer bytes to 150, then we expect there would be one more + // segment which is 48 in size to be flushed, so the remained del memory size + // will be 112 paramtable.Get().Save(Params.DataNodeCfg.FlushDeleteBufferBytes.Key, "150") fgMsg.(*flowGraphMsg).segmentsToSync = delNode.delBufferManager.ShouldFlushSegments() delNode.Operate([]flowgraph.Msg{fgMsg}) @@ -474,8 +480,8 @@ func TestFlowGraphDeleteNode_Operate(t *testing.T) { assert.Equal(t, int64(112), delNode.delBufferManager.usedMemory.Load()) assert.Equal(t, 3, delNode.delBufferManager.delBufHeap.Len()) - //6. 
we reset buffer bytes to 60, then most of the segments will be flushed - //except for the smallest entry with size equaling to 32 + // 6. we reset buffer bytes to 60, then most of the segments will be flushed + // except for the smallest entry with size equaling to 32 paramtable.Get().Save(Params.DataNodeCfg.FlushDeleteBufferBytes.Key, "60") fgMsg.(*flowGraphMsg).segmentsToSync = delNode.delBufferManager.ShouldFlushSegments() delNode.Operate([]flowgraph.Msg{fgMsg}) @@ -483,9 +489,9 @@ func TestFlowGraphDeleteNode_Operate(t *testing.T) { assert.Equal(t, int64(32), delNode.delBufferManager.usedMemory.Load()) assert.Equal(t, 1, delNode.delBufferManager.delBufHeap.Len()) - //7. we reset buffer bytes to 20, then as all segment-memory consumption - //is more than 20, so all five segments will be flushed and the remained - //del memory will be lowered to zero + // 7. we reset buffer bytes to 20, then as all segment-memory consumption + // is more than 20, so all five segments will be flushed and the remained + // del memory will be lowered to zero paramtable.Get().Save(Params.DataNodeCfg.FlushDeleteBufferBytes.Key, "20") fgMsg.(*flowGraphMsg).segmentsToSync = delNode.delBufferManager.ShouldFlushSegments() delNode.Operate([]flowgraph.Msg{fgMsg}) diff --git a/internal/datanode/flow_graph_dmstream_input_node.go b/internal/datanode/flow_graph_dmstream_input_node.go index 21e3ecf56edcf..7add6b06f6cc9 100644 --- a/internal/datanode/flow_graph_dmstream_input_node.go +++ b/internal/datanode/flow_graph_dmstream_input_node.go @@ -17,6 +17,7 @@ package datanode import ( + "context" "fmt" "time" @@ -38,14 +39,14 @@ import ( // // messages between two timeticks to the following flowgraph node. In DataNode, the following flow graph node is // flowgraph ddNode. 
-func newDmInputNode(dispatcherClient msgdispatcher.Client, seekPos *msgpb.MsgPosition, dmNodeConfig *nodeConfig) (*flowgraph.InputNode, error) { +func newDmInputNode(initCtx context.Context, dispatcherClient msgdispatcher.Client, seekPos *msgpb.MsgPosition, dmNodeConfig *nodeConfig) (*flowgraph.InputNode, error) { log := log.With(zap.Int64("nodeID", paramtable.GetNodeID()), zap.Int64("collectionID", dmNodeConfig.collectionID), zap.String("vchannel", dmNodeConfig.vChannelName)) var err error var input <-chan *msgstream.MsgPack if seekPos != nil && len(seekPos.MsgID) != 0 { - input, err = dispatcherClient.Register(dmNodeConfig.vChannelName, seekPos, mqwrapper.SubscriptionPositionUnknown) + input, err = dispatcherClient.Register(initCtx, dmNodeConfig.vChannelName, seekPos, mqwrapper.SubscriptionPositionUnknown) if err != nil { return nil, err } @@ -54,7 +55,7 @@ func newDmInputNode(dispatcherClient msgdispatcher.Client, seekPos *msgpb.MsgPos zap.Time("tsTime", tsoutil.PhysicalTime(seekPos.GetTimestamp())), zap.Duration("tsLag", time.Since(tsoutil.PhysicalTime(seekPos.GetTimestamp())))) } else { - input, err = dispatcherClient.Register(dmNodeConfig.vChannelName, nil, mqwrapper.SubscriptionPositionEarliest) + input, err = dispatcherClient.Register(initCtx, dmNodeConfig.vChannelName, nil, mqwrapper.SubscriptionPositionEarliest) if err != nil { return nil, err } @@ -65,8 +66,8 @@ func newDmInputNode(dispatcherClient msgdispatcher.Client, seekPos *msgpb.MsgPos node := flowgraph.NewInputNode( input, name, - dmNodeConfig.maxQueueLength, - dmNodeConfig.maxParallelism, + Params.DataNodeCfg.FlowGraphMaxQueueLength.GetAsInt32(), + Params.DataNodeCfg.FlowGraphMaxParallelism.GetAsInt32(), typeutil.DataNodeRole, paramtable.GetNodeID(), dmNodeConfig.collectionID, diff --git a/internal/datanode/flow_graph_dmstream_input_node_test.go b/internal/datanode/flow_graph_dmstream_input_node_test.go index ce2a601335db0..75df57af0b49c 100644 --- 
a/internal/datanode/flow_graph_dmstream_input_node_test.go +++ b/internal/datanode/flow_graph_dmstream_input_node_test.go @@ -24,6 +24,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/pkg/mq/msgdispatcher" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" @@ -36,13 +38,14 @@ type mockMsgStreamFactory struct { NewMsgStreamNoError bool } -var _ msgstream.Factory = &mockMsgStreamFactory{} +var ( + _ msgstream.Factory = &mockMsgStreamFactory{} + _ dependency.Factory = (*mockMsgStreamFactory)(nil) +) -func (mm *mockMsgStreamFactory) Init(params *paramtable.ComponentParam) error { - if !mm.InitReturnNil { - return errors.New("Init Error") - } - return nil +func (mm *mockMsgStreamFactory) Init(params *paramtable.ComponentParam) {} +func (mm *mockMsgStreamFactory) NewPersistentStorageChunkManager(ctx context.Context) (storage.ChunkManager, error) { + return nil, nil } func (mm *mockMsgStreamFactory) NewMsgStream(ctx context.Context) (msgstream.MsgStream, error) { @@ -60,8 +63,7 @@ func (mm *mockMsgStreamFactory) NewMsgStreamDisposer(ctx context.Context) func([ return nil } -type mockTtMsgStream struct { -} +type mockTtMsgStream struct{} func (mtm *mockTtMsgStream) Close() {} @@ -71,7 +73,8 @@ func (mtm *mockTtMsgStream) Chan() <-chan *msgstream.MsgPack { func (mtm *mockTtMsgStream) AsProducer(channels []string) {} -func (mtm *mockTtMsgStream) AsConsumer(channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) { +func (mtm *mockTtMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) error { + return nil } func (mtm *mockTtMsgStream) SetRepackFunc(repackFunc msgstream.RepackFunc) {} @@ -88,7 +91,7 @@ func (mtm *mockTtMsgStream) 
Broadcast(*msgstream.MsgPack) (map[string][]msgstrea return nil, nil } -func (mtm *mockTtMsgStream) Seek(offset []*msgpb.MsgPosition) error { +func (mtm *mockTtMsgStream) Seek(ctx context.Context, offset []*msgpb.MsgPosition) error { return nil } @@ -100,9 +103,12 @@ func (mtm *mockTtMsgStream) CheckTopicValid(channel string) error { return nil } +func (mtm *mockTtMsgStream) EnableProduce(can bool) { +} + func TestNewDmInputNode(t *testing.T) { client := msgdispatcher.NewClient(&mockMsgStreamFactory{}, typeutil.DataNodeRole, paramtable.GetNodeID()) - _, err := newDmInputNode(client, new(msgpb.MsgPosition), &nodeConfig{ + _, err := newDmInputNode(context.Background(), client, new(msgpb.MsgPosition), &nodeConfig{ msFactory: &mockMsgStreamFactory{}, vChannelName: "mock_vchannel_0", }) diff --git a/internal/datanode/flow_graph_insert_buffer_node.go b/internal/datanode/flow_graph_insert_buffer_node.go index 84f73810e0ebd..3e816bfb98a42 100644 --- a/internal/datanode/flow_graph_insert_buffer_node.go +++ b/internal/datanode/flow_graph_insert_buffer_node.go @@ -22,6 +22,7 @@ import ( "math" "reflect" + "github.com/cockroachdb/errors" "github.com/golang/protobuf/proto" "go.opentelemetry.io/otel/trace" "go.uber.org/atomic" @@ -38,6 +39,7 @@ import ( "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/retry" "github.com/milvus-io/milvus/pkg/util/tsoutil" @@ -308,7 +310,7 @@ type syncTask struct { } func (ibNode *insertBufferNode) FillInSyncTasks(fgMsg *flowGraphMsg, seg2Upload []UniqueID) map[UniqueID]*syncTask { - var syncTasks = make(map[UniqueID]*syncTask) + syncTasks := make(map[UniqueID]*syncTask) if fgMsg.dropCollection { // All segments in the collection will be synced, not matter empty buffer or not @@ -376,10 +378,10 @@ func (ibNode 
*insertBufferNode) FillInSyncTasks(fgMsg *flowGraphMsg, seg2Upload } // sync delete - //here we adopt a quite radical strategy: - //every time we make sure that the N biggest delDataBuf can be flushed - //when memsize usage reaches a certain level - //the aim for taking all these actions is to guarantee that the memory consumed by delBuf will not exceed a limit + // here we adopt a quite radical strategy: + // every time we make sure that the N biggest delDataBuf can be flushed + // when memsize usage reaches a certain level + // the aim for taking all these actions is to guarantee that the memory consumed by delBuf will not exceed a limit segmentsToFlush := ibNode.delBufferManager.ShouldFlushSegments() for _, segID := range segmentsToFlush { syncTasks[segID] = &syncTask{ @@ -450,7 +452,7 @@ func (ibNode *insertBufferNode) FillInSyncTasks(fgMsg *flowGraphMsg, seg2Upload func (ibNode *insertBufferNode) Sync(fgMsg *flowGraphMsg, seg2Upload []UniqueID, endPosition *msgpb.MsgPosition) []UniqueID { syncTasks := ibNode.FillInSyncTasks(fgMsg, seg2Upload) segmentsToSync := make([]UniqueID, 0, len(syncTasks)) - ibNode.channel.(*ChannelMeta).needToSync.Store(false) + ibNode.channel.setIsHighMemory(false) for _, task := range syncTasks { log := log.With(zap.Int64("segmentID", task.segmentID), @@ -462,6 +464,7 @@ func (ibNode *insertBufferNode) Sync(fgMsg *flowGraphMsg, seg2Upload []UniqueID, ).WithRateGroup("ibNode.sync", 1, 60) // check if segment is syncing segment := ibNode.channel.getSegment(task.segmentID) + if !task.dropped && !task.flushed && segment.isSyncing() { log.RatedInfo(10, "segment is syncing, skip it") continue @@ -484,6 +487,9 @@ func (ibNode *insertBufferNode) Sync(fgMsg *flowGraphMsg, seg2Upload []UniqueID, task.dropped, endPosition) if err != nil { + if errors.Is(err, merr.ErrSegmentNotFound) { + return retry.Unrecoverable(err) + } return err } return nil @@ -495,6 +501,24 @@ func (ibNode *insertBufferNode) Sync(fgMsg *flowGraphMsg, seg2Upload []UniqueID, 
metrics.DataNodeAutoFlushBufferCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.FailLabel).Inc() metrics.DataNodeAutoFlushBufferCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.TotalLabel).Inc() } + + if errors.Is(err, merr.ErrSegmentNotFound) { + if !segment.isValid() { + log.Info("try to flush a compacted segment, ignore..", + zap.Int64("segmentID", task.segmentID), + zap.Error(err)) + } + continue + } + + if merr.IsCanceledOrTimeout(err) { + log.Warn("skip syncing buffer data for context done", + zap.Int64("segmentID", task.segmentID), + zap.Error(err), + ) + continue + } + log.Fatal("insertBufferNode failed to flushBufferData", zap.Int64("segmentID", task.segmentID), zap.Error(err), @@ -521,13 +545,13 @@ func (ibNode *insertBufferNode) Sync(fgMsg *flowGraphMsg, seg2Upload []UniqueID, func (ibNode *insertBufferNode) addSegmentAndUpdateRowNum(insertMsgs []*msgstream.InsertMsg, startPos, endPos *msgpb.MsgPosition) (seg2Upload []UniqueID, err error) { uniqueSeg := make(map[UniqueID]int64) for _, msg := range insertMsgs { - currentSegID := msg.GetSegmentID() collID := msg.GetCollectionID() partitionID := msg.GetPartitionID() if !ibNode.channel.hasSegment(currentSegID, true) { err = ibNode.channel.addSegment( + context.TODO(), addSegmentReq{ segType: datapb.SegmentType_New, segID: currentSegID, @@ -645,7 +669,6 @@ func (ibNode *insertBufferNode) getTimestampRange(tsData *storage.Int64FieldData // WriteTimeTick writes timetick once insertBufferNode operates. 
func (ibNode *insertBufferNode) WriteTimeTick(ts Timestamp, segmentIDs []int64) { - select { case resendTTMsg := <-ibNode.resendTTChan: log.Info("resend TT msg received in insertBufferNode", @@ -677,12 +700,19 @@ func (ibNode *insertBufferNode) getCollectionandPartitionIDbySegID(segmentID Uni return ibNode.channel.getCollectionAndPartitionID(segmentID) } -func newInsertBufferNode(ctx context.Context, collID UniqueID, delBufManager *DeltaBufferManager, flushCh <-chan flushMsg, resendTTCh <-chan resendTTMsg, - fm flushManager, flushingSegCache *Cache, config *nodeConfig, timeTickManager *timeTickSender) (*insertBufferNode, error) { - +func newInsertBufferNode( + ctx context.Context, + flushCh <-chan flushMsg, + resendTTCh <-chan resendTTMsg, + delBufManager *DeltaBufferManager, + fm flushManager, + flushingSegCache *Cache, + timeTickManager *timeTickSender, + config *nodeConfig, +) (*insertBufferNode, error) { baseNode := BaseNode{} - baseNode.SetMaxQueueLength(config.maxQueueLength) - baseNode.SetMaxParallelism(config.maxParallelism) + baseNode.SetMaxQueueLength(Params.DataNodeCfg.FlowGraphMaxQueueLength.GetAsInt32()) + baseNode.SetMaxParallelism(Params.DataNodeCfg.FlowGraphMaxParallelism.GetAsInt32()) if Params.DataNodeCfg.DataNodeTimeTickByRPC.GetAsBool() { return &insertBufferNode{ @@ -702,7 +732,7 @@ func newInsertBufferNode(ctx context.Context, collID UniqueID, delBufManager *De }, nil } - //input stream, data node time tick + // input stream, data node time tick wTt, err := config.msFactory.NewMsgStream(ctx) if err != nil { return nil, err @@ -710,7 +740,8 @@ func newInsertBufferNode(ctx context.Context, collID UniqueID, delBufManager *De wTt.AsProducer([]string{Params.CommonCfg.DataCoordTimeTick.GetValue()}) metrics.DataNodeNumProducers.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc() log.Info("datanode AsProducer", zap.String("TimeTickChannelName", Params.CommonCfg.DataCoordTimeTick.GetValue())) - var wTtMsgStream msgstream.MsgStream = wTt + 
wTtMsgStream := wTt + wTtMsgStream.EnableProduce(true) mt := newMergedTimeTickerSender(func(ts Timestamp, segmentIDs []int64) error { stats := make([]*commonpb.SegmentStats, 0, len(segmentIDs)) @@ -734,7 +765,7 @@ func newInsertBufferNode(ctx context.Context, collID UniqueID, delBufManager *De commonpbutil.WithMsgType(commonpb.MsgType_DataNodeTt), commonpbutil.WithMsgID(0), commonpbutil.WithTimeStamp(ts), - commonpbutil.WithSourceID(config.serverID), + commonpbutil.WithSourceID(paramtable.GetNodeID()), ), ChannelName: config.vChannelName, Timestamp: ts, @@ -745,7 +776,7 @@ func newInsertBufferNode(ctx context.Context, collID UniqueID, delBufManager *De sub := tsoutil.SubByNow(ts) pChan := funcutil.ToPhysicalChannel(config.vChannelName) metrics.DataNodeProduceTimeTickLag. - WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), fmt.Sprint(collID), pChan). + WithLabelValues(fmt.Sprint(config.serverID), fmt.Sprint(config.collectionID), pChan). Set(float64(sub)) return wTtMsgStream.Produce(&msgPack) }) diff --git a/internal/datanode/flow_graph_insert_buffer_node_test.go b/internal/datanode/flow_graph_insert_buffer_node_test.go index 485d9198256cb..890835ed07491 100644 --- a/internal/datanode/flow_graph_insert_buffer_node_test.go +++ b/internal/datanode/flow_graph_insert_buffer_node_test.go @@ -31,12 +31,14 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "go.uber.org/atomic" + "google.golang.org/grpc" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/datanode/allocator" + "github.com/milvus-io/milvus/internal/datanode/broker" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/etcdpb" "github.com/milvus-io/milvus/internal/storage" @@ -44,6 +46,7 @@ import ( 
"github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/flowgraph" "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/retry" "github.com/milvus-io/milvus/pkg/util/tsoutil" @@ -79,14 +82,19 @@ func TestFlowGraphInsertBufferNodeCreate(t *testing.T) { require.NoError(t, err) Params.Save("etcd.rootPath", "/test/datanode/root") - Factory := &MetaFactory{} - collMeta := Factory.GetCollectionMeta(UniqueID(0), "coll1", schemapb.DataType_Int64) - mockRootCoord := &RootCoordFactory{ - pkType: schemapb.DataType_Int64, - } - - channel := newChannel(insertChannelName, collMeta.ID, collMeta.Schema, mockRootCoord, cm) + collMeta := NewMetaFactory().GetCollectionMeta(UniqueID(0), "coll1", schemapb.DataType_Int64) + broker := broker.NewMockBroker(t) + broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything). + Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Status(nil), + Schema: collMeta.GetSchema(), + }, nil).Maybe() + broker.EXPECT().ReportTimeTick(mock.Anything, mock.Anything). 
+ Return(nil).Maybe() + + channel := newChannel(insertChannelName, collMeta.ID, collMeta.Schema, broker, cm) err = channel.addSegment( + context.TODO(), addSegmentReq{ segType: datapb.SegmentType_New, segID: 1, @@ -116,9 +124,8 @@ func TestFlowGraphInsertBufferNodeCreate(t *testing.T) { delBufHeap: &PriorityQueue{}, } - dataCoord := &DataCoordFactory{} - atimeTickSender := newTimeTickSender(dataCoord, 0) - iBNode, err := newInsertBufferNode(ctx, collMeta.ID, delBufManager, flushChan, resendTTChan, fm, newCache(), c, atimeTickSender) + atimeTickSender := newTimeTickSender(broker, 0) + iBNode, err := newInsertBufferNode(ctx, flushChan, resendTTChan, delBufManager, fm, newCache(), atimeTickSender, c) assert.NotNil(t, iBNode) require.NoError(t, err) } @@ -141,12 +148,18 @@ func TestFlowGraphInsertBufferNode_Operate(t *testing.T) { in []Msg description string }{ - {[]Msg{}, - "Invalid input length == 0"}, - {[]Msg{&flowGraphMsg{}, &flowGraphMsg{}, &flowGraphMsg{}}, - "Invalid input length == 3"}, - {[]Msg{&mockMsg{}}, - "Invalid input length == 1 but input message is not flowGraphMsg"}, + { + []Msg{}, + "Invalid input length == 0", + }, + { + []Msg{&flowGraphMsg{}, &flowGraphMsg{}, &flowGraphMsg{}}, + "Invalid input length == 3", + }, + { + []Msg{&mockMsg{}}, + "Invalid input length == 1 but input message is not flowGraphMsg", + }, } for _, test := range invalidInTests { @@ -171,13 +184,19 @@ func TestFlowGraphInsertBufferNode_Operate(t *testing.T) { Factory := &MetaFactory{} collMeta := Factory.GetCollectionMeta(UniqueID(0), "coll1", schemapb.DataType_Int64) - mockRootCoord := &RootCoordFactory{ - pkType: schemapb.DataType_Int64, - } + broker := broker.NewMockBroker(t) + broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything). + Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Status(nil), + Schema: collMeta.GetSchema(), + }, nil).Maybe() + broker.EXPECT().ReportTimeTick(mock.Anything, mock.Anything). 
+ Return(nil).Maybe() - channel := newChannel(insertChannelName, collMeta.ID, collMeta.Schema, mockRootCoord, cm) + channel := newChannel(insertChannelName, collMeta.ID, collMeta.Schema, broker, cm) err = channel.addSegment( + context.TODO(), addSegmentReq{ segType: datapb.SegmentType_New, segID: 1, @@ -210,9 +229,8 @@ func TestFlowGraphInsertBufferNode_Operate(t *testing.T) { delBufHeap: &PriorityQueue{}, } - dataCoord := &DataCoordFactory{} - atimeTickSender := newTimeTickSender(dataCoord, 0) - iBNode, err := newInsertBufferNode(ctx, collMeta.ID, delBufManager, flushChan, resendTTChan, fm, newCache(), c, atimeTickSender) + atimeTickSender := newTimeTickSender(broker, 0) + iBNode, err := newInsertBufferNode(ctx, flushChan, resendTTChan, delBufManager, fm, newCache(), atimeTickSender, c) require.NoError(t, err) flushChan <- flushMsg{ @@ -332,21 +350,25 @@ func TestFlowGraphInsertBufferNode_AutoFlush(t *testing.T) { require.NoError(t, err) Params.Save("etcd.rootPath", "/test/datanode/root") - Factory := &MetaFactory{} - collMeta := Factory.GetCollectionMeta(UniqueID(0), "coll1", schemapb.DataType_Int64) - dataFactory := NewDataFactory() + collMeta := NewMetaFactory().GetCollectionMeta(UniqueID(0), "coll1", schemapb.DataType_Int64) + broker := broker.NewMockBroker(t) + broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything). + Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Status(nil), + Schema: collMeta.GetSchema(), + }, nil).Maybe() + broker.EXPECT().ReportTimeTick(mock.Anything, mock.Anything). 
+ Return(nil).Maybe() - mockRootCoord := &RootCoordFactory{ - pkType: schemapb.DataType_Int64, - } + dataFactory := NewDataFactory() channel := &ChannelMeta{ collectionID: collMeta.ID, segments: make(map[UniqueID]*Segment), - needToSync: atomic.NewBool(false), + isHighMemory: atomic.NewBool(false), } - channel.metaService = newMetaService(mockRootCoord, collMeta.ID) + channel.metaService = newMetaService(broker, collMeta.ID) factory := dependency.NewDefaultFactory(true) @@ -378,6 +400,7 @@ func TestFlowGraphInsertBufferNode_AutoFlush(t *testing.T) { flushChan := make(chan flushMsg, 100) resendTTChan := make(chan resendTTMsg, 100) c := &nodeConfig{ + collectionID: collMeta.GetID(), channel: channel, msFactory: factory, allocator: alloc, @@ -387,9 +410,8 @@ func TestFlowGraphInsertBufferNode_AutoFlush(t *testing.T) { channel: channel, delBufHeap: &PriorityQueue{}, } - dataCoord := &DataCoordFactory{} - atimeTickSender := newTimeTickSender(dataCoord, 0) - iBNode, err := newInsertBufferNode(ctx, collMeta.ID, delBufManager, flushChan, resendTTChan, fm, newCache(), c, atimeTickSender) + atimeTickSender := newTimeTickSender(broker, 0) + iBNode, err := newInsertBufferNode(ctx, flushChan, resendTTChan, delBufManager, fm, newCache(), atimeTickSender, c) require.NoError(t, err) // Auto flush number of rows set to 2 @@ -488,7 +510,6 @@ func TestFlowGraphInsertBufferNode_AutoFlush(t *testing.T) { // // assert.Equal(t, int64(1), iBNode.insertBuffer.size(UniqueID(i+1))) // } } - }) t.Run("Auto with manual flush", func(t *testing.T) { @@ -566,7 +587,6 @@ func TestFlowGraphInsertBufferNode_AutoFlush(t *testing.T) { assert.Equal(t, false, pack.flushed) } } - }) } @@ -579,21 +599,25 @@ func TestInsertBufferNodeRollBF(t *testing.T) { require.NoError(t, err) Params.Save("etcd.rootPath", "/test/datanode/root") - Factory := &MetaFactory{} - collMeta := Factory.GetCollectionMeta(UniqueID(0), "coll1", schemapb.DataType_Int64) - dataFactory := NewDataFactory() + collMeta := 
NewMetaFactory().GetCollectionMeta(UniqueID(0), "coll1", schemapb.DataType_Int64) + broker := broker.NewMockBroker(t) + broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything). + Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Status(nil), + Schema: collMeta.GetSchema(), + }, nil).Maybe() + broker.EXPECT().ReportTimeTick(mock.Anything, mock.Anything). + Return(nil).Maybe() - mockRootCoord := &RootCoordFactory{ - pkType: schemapb.DataType_Int64, - } + dataFactory := NewDataFactory() channel := &ChannelMeta{ collectionID: collMeta.ID, segments: make(map[UniqueID]*Segment), - needToSync: atomic.NewBool(false), + isHighMemory: atomic.NewBool(false), } - channel.metaService = newMetaService(mockRootCoord, collMeta.ID) + channel.metaService = newMetaService(broker, collMeta.ID) factory := dependency.NewDefaultFactory(true) @@ -625,6 +649,7 @@ func TestInsertBufferNodeRollBF(t *testing.T) { flushChan := make(chan flushMsg, 100) resendTTChan := make(chan resendTTMsg, 100) c := &nodeConfig{ + collectionID: collMeta.ID, channel: channel, msFactory: factory, allocator: alloc, @@ -635,9 +660,8 @@ func TestInsertBufferNodeRollBF(t *testing.T) { delBufHeap: &PriorityQueue{}, } - dataCoord := &DataCoordFactory{} - atimeTickSender := newTimeTickSender(dataCoord, 0) - iBNode, err := newInsertBufferNode(ctx, collMeta.ID, delBufManager, flushChan, resendTTChan, fm, newCache(), c, atimeTickSender) + atimeTickSender := newTimeTickSender(broker, 0) + iBNode, err := newInsertBufferNode(ctx, flushChan, resendTTChan, delBufManager, fm, newCache(), atimeTickSender, c) require.NoError(t, err) // Auto flush number of rows set to 2 @@ -728,13 +752,18 @@ type InsertBufferNodeSuite struct { func (s *InsertBufferNodeSuite) SetupSuite() { insertBufferNodeTestDir := "/tmp/milvus_test/insert_buffer_node" - rc := &RootCoordFactory{ - pkType: schemapb.DataType_Int64, - } + + collMeta := NewMetaFactory().GetCollectionMeta(1, "coll1", schemapb.DataType_Int64) + broker 
:= broker.NewMockBroker(s.T()) + broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything). + Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Status(nil), + Schema: collMeta.GetSchema(), + }, nil).Maybe() s.collID = 1 s.partID = 10 - s.channel = newChannel("channel", s.collID, nil, rc, s.cm) + s.channel = newChannel("channel", s.collID, nil, broker, s.cm) s.delBufManager = &DeltaBufferManager{ channel: s.channel, @@ -763,14 +792,16 @@ func (s *InsertBufferNodeSuite) SetupTest() { } for _, seg := range segs { - err := s.channel.addSegment(addSegmentReq{ - segType: seg.sType, - segID: seg.segID, - collID: s.collID, - partitionID: s.partID, - startPos: new(msgpb.MsgPosition), - endPos: new(msgpb.MsgPosition), - }) + err := s.channel.addSegment( + context.TODO(), + addSegmentReq{ + segType: seg.sType, + segID: seg.segID, + collID: s.collID, + partitionID: s.partID, + startPos: new(msgpb.MsgPosition), + endPos: new(msgpb.MsgPosition), + }) s.Require().NoError(err) } } @@ -925,7 +956,6 @@ func (s *InsertBufferNodeSuite) TestFillInSyncTasks() { s.Assert().True(task.auto) } }) - } func TestInsertBufferNodeSuite(t *testing.T) { @@ -934,11 +964,11 @@ func TestInsertBufferNodeSuite(t *testing.T) { // CompactedRootCoord has meta info compacted at ts type CompactedRootCoord struct { - types.RootCoord + types.RootCoordClient compactTs Timestamp } -func (m *CompactedRootCoord) DescribeCollection(ctx context.Context, in *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { +func (m *CompactedRootCoord) DescribeCollection(ctx context.Context, in *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) { if in.TimeStamp != 0 && in.GetTimeStamp() <= m.compactTs { return &milvuspb.DescribeCollectionResponse{ Status: &commonpb.Status{ @@ -947,7 +977,7 @@ func (m *CompactedRootCoord) DescribeCollection(ctx context.Context, in *milvusp }, }, nil } - return 
m.RootCoord.DescribeCollection(ctx, in) + return m.RootCoordClient.DescribeCollection(ctx, in) } func TestInsertBufferNode_bufferInsertMsg(t *testing.T) { @@ -961,7 +991,7 @@ func TestInsertBufferNode_bufferInsertMsg(t *testing.T) { require.NoError(t, err) Params.Save("etcd.rootPath", "/test/datanode/root") - Factory := &MetaFactory{} + factory := &MetaFactory{} tests := []struct { collID UniqueID pkType schemapb.DataType @@ -974,17 +1004,20 @@ func TestInsertBufferNode_bufferInsertMsg(t *testing.T) { cm := storage.NewLocalChunkManager(storage.RootPath(insertNodeTestDir)) defer cm.RemoveWithPrefix(ctx, cm.RootPath()) for _, test := range tests { - collMeta := Factory.GetCollectionMeta(test.collID, "collection", test.pkType) - rcf := &RootCoordFactory{ - pkType: test.pkType, - } - mockRootCoord := &CompactedRootCoord{ - RootCoord: rcf, - compactTs: 100, - } - - channel := newChannel(insertChannelName, collMeta.ID, collMeta.Schema, mockRootCoord, cm) + collMeta := factory.GetCollectionMeta(test.collID, "collection", test.pkType) + + broker := broker.NewMockBroker(t) + broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything). + Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Status(nil), + Schema: collMeta.GetSchema(), + }, nil).Maybe() + broker.EXPECT().ReportTimeTick(mock.Anything, mock.Anything). 
+ Return(nil).Maybe() + + channel := newChannel(insertChannelName, collMeta.ID, collMeta.Schema, broker, cm) err = channel.addSegment( + context.TODO(), addSegmentReq{ segType: datapb.SegmentType_New, segID: 1, @@ -1003,6 +1036,7 @@ func TestInsertBufferNode_bufferInsertMsg(t *testing.T) { flushChan := make(chan flushMsg, 100) resendTTChan := make(chan resendTTMsg, 100) c := &nodeConfig{ + collectionID: collMeta.ID, channel: channel, msFactory: factory, allocator: alloc, @@ -1013,9 +1047,8 @@ func TestInsertBufferNode_bufferInsertMsg(t *testing.T) { delBufHeap: &PriorityQueue{}, } - dataCoord := &DataCoordFactory{} - atimeTickSender := newTimeTickSender(dataCoord, 0) - iBNode, err := newInsertBufferNode(ctx, collMeta.ID, delBufManager, flushChan, resendTTChan, fm, newCache(), c, atimeTickSender) + atimeTickSender := newTimeTickSender(broker, 0) + iBNode, err := newInsertBufferNode(ctx, flushChan, resendTTChan, delBufManager, fm, newCache(), atimeTickSender, c) require.NoError(t, err) inMsg := genFlowGraphInsertMsg(insertChannelName) @@ -1027,7 +1060,7 @@ func TestInsertBufferNode_bufferInsertMsg(t *testing.T) { for _, msg := range inMsg.insertMessages { msg.EndTimestamp = 101 // ts valid - msg.RowIDs = []int64{} //misaligned data + msg.RowIDs = []int64{} // misaligned data err = iBNode.bufferInsertMsg(msg, &msgpb.MsgPosition{}, &msgpb.MsgPosition{}) assert.Error(t, err) } @@ -1050,7 +1083,8 @@ func TestInsertBufferNode_updateSegmentStates(te *testing.T) { } for _, test := range invalideTests { - channel := newChannel("channel", test.channelCollID, nil, &RootCoordFactory{pkType: schemapb.DataType_Int64}, cm) + broker := broker.NewMockBroker(te) + channel := newChannel("channel", test.channelCollID, nil, broker, cm) ibNode := &insertBufferNode{ channel: channel, @@ -1073,7 +1107,6 @@ func TestInsertBufferNode_updateSegmentStates(te *testing.T) { } func TestInsertBufferNode_getTimestampRange(t *testing.T) { - type testCase struct { tag string @@ -1160,14 +1193,26 @@ 
func TestInsertBufferNode_task_pool_is_full(t *testing.T) { collection := UniqueID(0) segmentID := UniqueID(100) - channel := newChannel(channelName, collection, nil, &RootCoordFactory{pkType: schemapb.DataType_Int64}, cm) - err := channel.addSegment(addSegmentReq{ - segType: datapb.SegmentType_New, - segID: segmentID, - collID: collection, - startPos: new(msgpb.MsgPosition), - endPos: new(msgpb.MsgPosition), - }) + meta := NewMetaFactory().GetCollectionMeta(collection, "test_collection", schemapb.DataType_Int64) + broker := broker.NewMockBroker(t) + broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything). + Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Status(nil), + CollectionID: 1, + CollectionName: "test_collection", + Schema: meta.GetSchema(), + }, nil) + + channel := newChannel(channelName, collection, nil, broker, cm) + err := channel.addSegment( + context.TODO(), + addSegmentReq{ + segType: datapb.SegmentType_New, + segID: segmentID, + collID: collection, + startPos: new(msgpb.MsgPosition), + endPos: new(msgpb.MsgPosition), + }) assert.NoError(t, err) channel.setCurInsertBuffer(segmentID, &BufferData{size: 100}) diff --git a/internal/datanode/flow_graph_manager.go b/internal/datanode/flow_graph_manager.go index 86b6f9b77cfc6..765ff0f6e903c 100644 --- a/internal/datanode/flow_graph_manager.go +++ b/internal/datanode/flow_graph_manager.go @@ -17,6 +17,7 @@ package datanode import ( + "context" "fmt" "sort" "sync" @@ -48,7 +49,8 @@ func newFlowgraphManager() *flowgraphManager { } } -func (fm *flowgraphManager) start() { +func (fm *flowgraphManager) start(waiter *sync.WaitGroup) { + defer waiter.Done() ticker := time.NewTicker(3 * time.Second) defer ticker.Stop() for { @@ -108,32 +110,39 @@ func (fm *flowgraphManager) execute(totalMemory uint64) { return channels[i].bufferSize > channels[j].bufferSize }) if fg, ok := fm.flowgraphs.Get(channels[0].channel); ok { // sync the first channel with the largest memory usage - 
fg.channel.forceToSync() + fg.channel.setIsHighMemory(true) log.Info("notify flowgraph to sync", zap.String("channel", channels[0].channel), zap.Int64("bufferSize", channels[0].bufferSize)) } } -func (fm *flowgraphManager) addAndStart(dn *DataNode, vchan *datapb.VchannelInfo, schema *schemapb.CollectionSchema, tickler *tickler) error { +func (fm *flowgraphManager) Add(ds *dataSyncService) { + fm.flowgraphs.Insert(ds.vchannelName, ds) + metrics.DataNodeNumFlowGraphs.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc() +} + +func (fm *flowgraphManager) addAndStartWithEtcdTickler(dn *DataNode, vchan *datapb.VchannelInfo, schema *schemapb.CollectionSchema, tickler *etcdTickler) error { log := log.With(zap.String("channel", vchan.GetChannelName())) if fm.flowgraphs.Contain(vchan.GetChannelName()) { log.Warn("try to add an existed DataSyncService") return nil } - channel := newChannel(vchan.GetChannelName(), vchan.GetCollectionID(), schema, dn.rootCoord, dn.chunkManager) - var dataSyncService *dataSyncService var err error if Params.CommonCfg.EnableStorageV2.GetAsBool() { - dataSyncService, err = newDataSyncServiceV2(dn.ctx, make(chan flushMsg, 100), make(chan resendTTMsg, 100), channel, - dn.allocator, dn.dispClient, dn.factory, vchan, dn.clearSignal, dn.dataCoord, dn.segmentCache, dn.chunkManager, dn.compactionExecutor, tickler, dn.GetSession().ServerID, dn.timeTickSender, dn.etcdCli) + dataSyncService, err := newServiceWithEtcdTicklerV2(context.TODO(), dn, &datapb.ChannelWatchInfo{ + Schema: schema, + Vchan: vchan, + }, tickler) } else { - dataSyncService, err = newDataSyncService(dn.ctx, make(chan flushMsg, 100), make(chan resendTTMsg, 100), channel, - dn.allocator, dn.dispClient, dn.factory, vchan, dn.clearSignal, dn.dataCoord, dn.segmentCache, dn.chunkManager, dn.compactionExecutor, tickler, dn.GetSession().ServerID, dn.timeTickSender) + dataSyncService, err := newServiceWithEtcdTickler(context.TODO(), dn, &datapb.ChannelWatchInfo{ + Schema: schema, + Vchan: 
vchan, + }, tickler) } if err != nil { - log.Warn("fail to create new datasyncservice", zap.Error(err)) + log.Warn("fail to create new DataSyncService", zap.Error(err)) return err } dataSyncService.start() @@ -144,11 +153,13 @@ func (fm *flowgraphManager) addAndStart(dn *DataNode, vchan *datapb.VchannelInfo } func (fm *flowgraphManager) release(vchanName string) { - if fg, loaded := fm.flowgraphs.GetAndRemove(vchanName); loaded { + if fg, loaded := fm.flowgraphs.Get(vchanName); loaded { fg.close() + fm.flowgraphs.Remove(vchanName) + metrics.DataNodeNumFlowGraphs.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Dec() + rateCol.removeFlowGraphChannel(vchanName) } - rateCol.removeFlowGraphChannel(vchanName) } func (fm *flowgraphManager) getFlushCh(segID UniqueID) (chan<- flushMsg, error) { @@ -190,26 +201,6 @@ func (fm *flowgraphManager) getChannel(segID UniqueID) (Channel, error) { return nil, fmt.Errorf("cannot find segment %d in all flowgraphs", segID) } -// resendTT loops through flow graphs, looks for segments that are not flushed, -// and sends them to that flow graph's `resendTTCh` channel so stats of -// these segments will be resent. -func (fm *flowgraphManager) resendTT() []UniqueID { - var unFlushedSegments []UniqueID - fm.flowgraphs.Range(func(key string, fg *dataSyncService) bool { - segIDs := fg.channel.listNotFlushedSegmentIDs() - if len(segIDs) > 0 { - log.Info("un-flushed segments found, stats will be resend", - zap.Int64s("segment IDs", segIDs)) - unFlushedSegments = append(unFlushedSegments, segIDs...) 
- fg.resendTTCh <- resendTTMsg{ - segmentIDs: segIDs, - } - } - return true - }) - return unFlushedSegments -} - func (fm *flowgraphManager) getFlowgraphService(vchan string) (*dataSyncService, bool) { return fm.flowgraphs.Get(vchan) } @@ -219,6 +210,11 @@ func (fm *flowgraphManager) exist(vchan string) bool { return exist } +func (fm *flowgraphManager) existWithOpID(vchan string, opID UniqueID) bool { + ds, exist := fm.getFlowgraphService(vchan) + return exist && ds.opID == opID +} + // getFlowGraphNum returns number of flow graphs. func (fm *flowgraphManager) getFlowGraphNum() int { return fm.flowgraphs.Len() @@ -234,3 +230,13 @@ func (fm *flowgraphManager) dropAll() { return true }) } + +func (fm *flowgraphManager) collections() []int64 { + collectionSet := typeutil.UniqueSet{} + fm.flowgraphs.Range(func(key string, value *dataSyncService) bool { + collectionSet.Insert(value.channel.getCollectionID()) + return true + }) + + return collectionSet.Collect() +} diff --git a/internal/datanode/flow_graph_manager_test.go b/internal/datanode/flow_graph_manager_test.go index 358575ea574e0..88098a7eb62c0 100644 --- a/internal/datanode/flow_graph_manager_test.go +++ b/internal/datanode/flow_graph_manager_test.go @@ -22,12 +22,16 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/datanode/broker" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/pkg/util/etcd" + "github.com/milvus-io/milvus/pkg/util/merr" ) func TestFlowGraphManager(t *testing.T) { @@ -51,10 +55,28 @@ func TestFlowGraphManager(t *testing.T) { err = node.Init() require.Nil(t, err) + meta := NewMetaFactory().GetCollectionMeta(1, "test_collection", schemapb.DataType_Int64) + broker := 
broker.NewMockBroker(t) + broker.EXPECT().ReportTimeTick(mock.Anything, mock.Anything).Return(nil).Maybe() + broker.EXPECT().SaveBinlogPaths(mock.Anything, mock.Anything).Return(nil).Maybe() + broker.EXPECT().GetSegmentInfo(mock.Anything, mock.Anything).Return([]*datapb.SegmentInfo{}, nil).Maybe() + broker.EXPECT().DropVirtualChannel(mock.Anything, mock.Anything).Return(nil).Maybe() + broker.EXPECT().UpdateChannelCheckpoint(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything). + Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Status(nil), + CollectionID: 1, + CollectionName: "test_collection", + Schema: meta.GetSchema(), + }, nil) + + node.broker = broker + fm := newFlowgraphManager() defer func() { fm.dropAll() }() + t.Run("Test addAndStart", func(t *testing.T) { vchanName := "by-dev-rootcoord-dml-test-flowgraphmanager-addAndStart" vchan := &datapb.VchannelInfo{ @@ -63,7 +85,7 @@ func TestFlowGraphManager(t *testing.T) { } require.False(t, fm.exist(vchanName)) - err := fm.addAndStart(node, vchan, nil, genTestTickler()) + err := fm.addAndStartWithEtcdTickler(node, vchan, nil, genTestTickler()) assert.NoError(t, err) assert.True(t, fm.exist(vchanName)) @@ -78,7 +100,7 @@ func TestFlowGraphManager(t *testing.T) { } require.False(t, fm.exist(vchanName)) - err := fm.addAndStart(node, vchan, nil, genTestTickler()) + err := fm.addAndStartWithEtcdTickler(node, vchan, nil, genTestTickler()) assert.NoError(t, err) assert.True(t, fm.exist(vchanName)) @@ -96,19 +118,21 @@ func TestFlowGraphManager(t *testing.T) { } require.False(t, fm.exist(vchanName)) - err := fm.addAndStart(node, vchan, nil, genTestTickler()) + err := fm.addAndStartWithEtcdTickler(node, vchan, nil, genTestTickler()) assert.NoError(t, err) assert.True(t, fm.exist(vchanName)) fg, ok := fm.getFlowgraphService(vchanName) require.True(t, ok) - err = fg.channel.addSegment(addSegmentReq{ - segType: 
datapb.SegmentType_New, - segID: 100, - collID: 1, - partitionID: 10, - startPos: &msgpb.MsgPosition{}, - endPos: &msgpb.MsgPosition{}, - }) + err = fg.channel.addSegment( + context.TODO(), + addSegmentReq{ + segType: datapb.SegmentType_New, + segID: 100, + collID: 1, + partitionID: 10, + startPos: &msgpb.MsgPosition{}, + endPos: &msgpb.MsgPosition{}, + }) require.NoError(t, err) tests := []struct { @@ -144,20 +168,22 @@ func TestFlowGraphManager(t *testing.T) { } require.False(t, fm.exist(vchanName)) - err := fm.addAndStart(node, vchan, nil, genTestTickler()) + err := fm.addAndStartWithEtcdTickler(node, vchan, nil, genTestTickler()) assert.NoError(t, err) assert.True(t, fm.exist(vchanName)) fg, ok := fm.getFlowgraphService(vchanName) require.True(t, ok) - err = fg.channel.addSegment(addSegmentReq{ - segType: datapb.SegmentType_New, - segID: 100, - collID: 1, - partitionID: 10, - startPos: &msgpb.MsgPosition{}, - endPos: &msgpb.MsgPosition{}, - }) + err = fg.channel.addSegment( + context.TODO(), + addSegmentReq{ + segType: datapb.SegmentType_New, + segID: 100, + collID: 1, + partitionID: 10, + startPos: &msgpb.MsgPosition{}, + endPos: &msgpb.MsgPosition{}, + }) require.NoError(t, err) tests := []struct { @@ -199,10 +225,16 @@ func TestFlowGraphManager(t *testing.T) { memorySizes []int64 expectNeedToSync []bool }{ - {"test over the watermark", 100, 0.5, - []int64{15, 16, 17, 18}, []bool{false, false, false, true}}, - {"test below the watermark", 100, 0.5, - []int64{1, 2, 3, 4}, []bool{false, false, false, false}}, + { + "test over the watermark", 100, 0.5, + []int64{15, 16, 17, 18}, + []bool{false, false, false, true}, + }, + { + "test below the watermark", 100, 0.5, + []int64{1, 2, 3, 4}, + []bool{false, false, false, false}, + }, } fm.dropAll() @@ -215,21 +247,21 @@ func TestFlowGraphManager(t *testing.T) { vchan := &datapb.VchannelInfo{ ChannelName: vchannel, } - err = fm.addAndStart(node, vchan, nil, genTestTickler()) + err = fm.addAndStartWithEtcdTickler(node, 
vchan, nil, genTestTickler()) assert.NoError(t, err) fg, ok := fm.flowgraphs.Get(vchannel) assert.True(t, ok) - err = fg.channel.addSegment(addSegmentReq{segID: 0}) + err = fg.channel.addSegment(context.TODO(), addSegmentReq{segID: 0}) assert.NoError(t, err) fg.channel.getSegment(0).memorySize = memorySize - fg.channel.(*ChannelMeta).needToSync.Store(false) + fg.channel.setIsHighMemory(false) } fm.execute(test.totalMemory) for i, needToSync := range test.expectNeedToSync { vchannel := fmt.Sprintf("%s%d", channelPrefix, i) fg, ok := fm.flowgraphs.Get(vchannel) assert.True(t, ok) - assert.Equal(t, needToSync, fg.channel.(*ChannelMeta).needToSync.Load()) + assert.Equal(t, needToSync, fg.channel.getIsHighMemory()) } } }) diff --git a/internal/datanode/flow_graph_message.go b/internal/datanode/flow_graph_message.go index ce1d5bbb9bb2b..c14603529904b 100644 --- a/internal/datanode/flow_graph_message.go +++ b/internal/datanode/flow_graph_message.go @@ -49,7 +49,7 @@ type flowGraphMsg struct { timeRange TimeRange startPositions []*msgpb.MsgPosition endPositions []*msgpb.MsgPosition - //segmentsToSync is the signal used by insertBufferNode to notify deleteNode to flush + // segmentsToSync is the signal used by insertBufferNode to notify deleteNode to flush segmentsToSync []UniqueID dropCollection bool dropPartitions []UniqueID @@ -69,7 +69,7 @@ type flushMsg struct { timestamp Timestamp segmentID UniqueID collectionID UniqueID - //isFlush illustrates if this is a flush or normal sync + // isFlush illustrates if this is a flush or normal sync isFlush bool } diff --git a/internal/datanode/flow_graph_message_test.go b/internal/datanode/flow_graph_message_test.go index 70b9e93491030..d5d9dbbd6cdb3 100644 --- a/internal/datanode/flow_graph_message_test.go +++ b/internal/datanode/flow_graph_message_test.go @@ -38,5 +38,4 @@ func TestInsertMsg_TimeTick(te *testing.T) { assert.Equal(t, test.timeTimestanpMax, fgMsg.TimeTick()) }) } - } diff --git 
a/internal/datanode/flow_graph_time_tick_node.go b/internal/datanode/flow_graph_time_tick_node.go index b5ab67654f449..e81a671f98484 100644 --- a/internal/datanode/flow_graph_time_tick_node.go +++ b/internal/datanode/flow_graph_time_tick_node.go @@ -19,19 +19,18 @@ package datanode import ( "context" "fmt" + "math" "reflect" + "sync" "time" + "go.uber.org/atomic" "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" - "github.com/milvus-io/milvus/internal/proto/datapb" - "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/datanode/broker" "github.com/milvus-io/milvus/internal/util/flowgraph" "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/util/commonpbutil" - "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/tsoutil" ) @@ -47,8 +46,17 @@ type ttNode struct { BaseNode vChannelName string channel Channel - lastUpdateTime time.Time - dataCoord types.DataCoord + lastUpdateTime *atomic.Time + broker broker.Broker + + updateCPLock sync.Mutex + notifyChannel chan checkPoint + closeChannel chan struct{} +} + +type checkPoint struct { + curTs time.Time + pos *msgpb.MsgPosition } // Name returns node name, implementing flowgraph.Node @@ -71,55 +79,66 @@ func (ttn *ttNode) IsValidInMsg(in []Msg) bool { // Operate handles input messages, implementing flowgraph.Node func (ttn *ttNode) Operate(in []Msg) []Msg { fgMsg := in[0].(*flowGraphMsg) + curTs, _ := tsoutil.ParseTS(fgMsg.timeRange.timestampMax) if fgMsg.IsCloseMsg() { if len(fgMsg.endPositions) > 0 { + close(ttn.closeChannel) + channelPos := ttn.channel.getChannelCheckpoint(fgMsg.endPositions[0]) log.Info("flowgraph is closing, force update channel CP", - zap.Uint64("endTs", fgMsg.endPositions[0].GetTimestamp()), - zap.String("channel", fgMsg.endPositions[0].GetChannelName())) - ttn.updateChannelCP(fgMsg.endPositions[0]) + zap.Time("cpTs", 
tsoutil.PhysicalTime(channelPos.GetTimestamp())), + zap.String("channel", channelPos.GetChannelName())) + ttn.updateChannelCP(channelPos, curTs) } return in } - curTs, _ := tsoutil.ParseTS(fgMsg.timeRange.timestampMax) - if curTs.Sub(ttn.lastUpdateTime) >= updateChanCPInterval { - ttn.updateChannelCP(fgMsg.endPositions[0]) - ttn.lastUpdateTime = curTs + // Do not block and async updateCheckPoint + channelPos := ttn.channel.getChannelCheckpoint(fgMsg.endPositions[0]) + nonBlockingNotify := func() { + select { + case ttn.notifyChannel <- checkPoint{curTs, channelPos}: + default: + } + } + + if curTs.Sub(ttn.lastUpdateTime.Load()) >= updateChanCPInterval { + nonBlockingNotify() + return []Msg{} } + if channelPos.GetTimestamp() >= ttn.channel.getFlushTs() { + nonBlockingNotify() + } return []Msg{} } -func (ttn *ttNode) updateChannelCP(ttPos *msgpb.MsgPosition) { - channelPos := ttn.channel.getChannelCheckpoint(ttPos) - if channelPos == nil || channelPos.MsgID == nil { - log.Warn("updateChannelCP failed, get nil check point", zap.String("vChannel", ttn.vChannelName)) - return - } - channelCPTs, _ := tsoutil.ParseTS(channelPos.Timestamp) +func (ttn *ttNode) updateChannelCP(channelPos *msgpb.MsgPosition, curTs time.Time) error { + ttn.updateCPLock.Lock() + defer ttn.updateCPLock.Unlock() + channelCPTs, _ := tsoutil.ParseTS(channelPos.GetTimestamp()) + // TODO, change to ETCD operation, avoid datacoord operation ctx, cancel := context.WithTimeout(context.Background(), updateChanCPTimeout) defer cancel() - resp, err := ttn.dataCoord.UpdateChannelCheckpoint(ctx, &datapb.UpdateChannelCheckpointRequest{ - Base: commonpbutil.NewMsgBase( - commonpbutil.WithSourceID(paramtable.GetNodeID()), - ), - VChannel: ttn.vChannelName, - Position: channelPos, - }) - if err = funcutil.VerifyResponse(resp, err); err != nil { - log.Warn("UpdateChannelCheckpoint failed", zap.String("channel", ttn.vChannelName), - zap.Time("channelCPTs", channelCPTs), zap.Error(err)) - return + + err := 
ttn.broker.UpdateChannelCheckpoint(ctx, ttn.vChannelName, channelPos) + if err != nil { + return err } + ttn.lastUpdateTime.Store(curTs) + // channelPos ts > flushTs means we could stop flush. + if channelPos.GetTimestamp() >= ttn.channel.getFlushTs() { + ttn.channel.setFlushTs(math.MaxUint64) + } log.Info("UpdateChannelCheckpoint success", zap.String("channel", ttn.vChannelName), - zap.Uint64("cpTs", channelPos.Timestamp), + zap.Uint64("cpTs", channelPos.GetTimestamp()), zap.Time("cpTime", channelCPTs)) + return nil } -func newTTNode(config *nodeConfig, dc types.DataCoord) (*ttNode, error) { +func newTTNode(config *nodeConfig, broker broker.Broker) (*ttNode, error) { baseNode := BaseNode{} baseNode.SetMaxQueueLength(Params.DataNodeCfg.FlowGraphMaxQueueLength.GetAsInt32()) baseNode.SetMaxParallelism(Params.DataNodeCfg.FlowGraphMaxParallelism.GetAsInt32()) @@ -128,9 +147,23 @@ func newTTNode(config *nodeConfig, dc types.DataCoord) (*ttNode, error) { BaseNode: baseNode, vChannelName: config.vChannelName, channel: config.channel, - lastUpdateTime: time.Time{}, // set to Zero to update channel checkpoint immediately after fg started - dataCoord: dc, + lastUpdateTime: atomic.NewTime(time.Time{}), // set to Zero to update channel checkpoint immediately after fg started + broker: broker, + notifyChannel: make(chan checkPoint, 1), + closeChannel: make(chan struct{}), } + // check point updater + go func() { + for { + select { + case <-tt.closeChannel: + return + case cp := <-tt.notifyChannel: + tt.updateChannelCP(cp.pos, cp.curTs) + } + } + }() + return tt, nil } diff --git a/internal/datanode/flow_graph_time_ticker.go b/internal/datanode/flow_graph_time_ticker.go index 34dabf995816a..50db66bfb6d22 100644 --- a/internal/datanode/flow_graph_time_ticker.go +++ b/internal/datanode/flow_graph_time_ticker.go @@ -21,9 +21,8 @@ import ( "time" "github.com/samber/lo" - "golang.org/x/exp/maps" - "go.uber.org/zap" + "golang.org/x/exp/maps" "github.com/milvus-io/milvus/pkg/log" ) 
diff --git a/internal/datanode/flush_manager.go b/internal/datanode/flush_manager.go index 6f98bf82639e8..05f6fe4e3f963 100644 --- a/internal/datanode/flush_manager.go +++ b/internal/datanode/flush_manager.go @@ -33,7 +33,6 @@ import ( "go.uber.org/zap" "golang.org/x/sync/errgroup" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" milvus_storage "github.com/milvus-io/milvus-storage/go/storage" @@ -915,7 +914,6 @@ func (m *rendezvousFlushManager) flushBufferData(data *BufferData, segmentID Uni // binlogs for _, blob := range binLogBlobs { - defer func() { logidx++ }() fieldID, err := strconv.ParseInt(blob.GetKey(), 10, 64) if err != nil { log.Error("Flush failed ... cannot parse string to fieldID ..", zap.Error(err)) @@ -933,6 +931,8 @@ func (m *rendezvousFlushManager) flushBufferData(data *BufferData, segmentID Uni LogPath: key, LogSize: int64(fieldMemorySize[fieldID]), } + + logidx += 1 } // pk stats binlog @@ -957,8 +957,8 @@ func (m *rendezvousFlushManager) flushBufferData(data *BufferData, segmentID Uni kvs[key] = pkStatsBlob.Value field2Stats[fieldID] = &datapb.Binlog{ EntriesNum: 0, - TimestampFrom: 0, //TODO - TimestampTo: 0, //TODO, + TimestampFrom: 0, // TODO + TimestampTo: 0, // TODO, LogPath: key, LogSize: int64(len(pkStatsBlob.Value)), } @@ -975,8 +975,8 @@ func (m *rendezvousFlushManager) flushBufferData(data *BufferData, segmentID Uni // notify flush manager del buffer data func (m *rendezvousFlushManager) flushDelData(data *DelDataBuf, segmentID UniqueID, - pos *msgpb.MsgPosition) error { - + pos *msgpb.MsgPosition, +) error { // del signal with empty data if data == nil || data.delData == nil { m.handleDeleteTask(segmentID, &flushBufferDeleteTask{}, nil, pos) @@ -1025,7 +1025,7 @@ func (m *rendezvousFlushManager) injectFlush(injection *taskInjection, segments // fetch meta info for segment func (m *rendezvousFlushManager) 
getSegmentMeta(segmentID UniqueID, pos *msgpb.MsgPosition) (UniqueID, UniqueID, *etcdpb.CollectionMeta, error) { if !m.hasSegment(segmentID, true) { - return -1, -1, nil, fmt.Errorf("no such segment %d in the channel", segmentID) + return -1, -1, nil, merr.WrapErrSegmentNotFound(segmentID, "segment not found during flush") } // fetch meta information of segment @@ -1095,7 +1095,7 @@ func getSyncTaskID(pos *msgpb.MsgPosition) string { // close cleans up all the left members func (m *rendezvousFlushManager) close() { m.dispatcher.Range(func(segmentID int64, queue *orderFlushQueue) bool { - //assertion ok + // assertion ok queue.injectMut.Lock() for i := 0; i < len(queue.injectCh); i++ { go queue.handleInject(<-queue.injectCh) @@ -1222,9 +1222,9 @@ func dropVirtualChannelFunc(dsService *dataSyncService, opts ...retry.Option) fl return func(packs []*segmentFlushPack) { req := &datapb.DropVirtualChannelRequest{ Base: commonpbutil.NewMsgBase( - commonpbutil.WithMsgType(0), //TODO msg type - commonpbutil.WithMsgID(0), //TODO msg id - commonpbutil.WithSourceID(paramtable.GetNodeID()), + commonpbutil.WithMsgType(0), // TODO msg type + commonpbutil.WithMsgID(0), // TODO msg id + commonpbutil.WithSourceID(dsService.serverID), ), ChannelName: dsService.vchannelName, } @@ -1297,21 +1297,14 @@ func dropVirtualChannelFunc(dsService *dataSyncService, opts ...retry.Option) fl req.Segments = segments err := retry.Do(context.Background(), func() error { - rsp, err := dsService.dataCoord.DropVirtualChannel(context.Background(), req) - // should be network issue, return error and retry + err := dsService.broker.DropVirtualChannel(context.Background(), req) if err != nil { - return fmt.Errorf(err.Error()) - } - - // meta error, datanode handles a virtual channel does not belong here - if rsp.GetStatus().GetErrorCode() == commonpb.ErrorCode_MetaFailed { - log.Warn("meta error found, skip sync and start to drop virtual channel", zap.String("channel", dsService.vchannelName)) - return nil 
- } - - // retry for other error - if rsp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { - return fmt.Errorf("data service DropVirtualChannel failed, reason = %s", rsp.GetStatus().GetReason()) + // meta error, datanode handles a virtual channel does not belong here + if errors.Is(err, merr.ErrChannelNotFound) { + log.Warn("meta error found, skip sync and start to drop virtual channel", zap.String("channel", dsService.vchannelName)) + return nil + } + return err } dsService.channel.transferNewSegments(lo.Map(startPos, func(pos *datapb.SegmentStartPosition, _ int) UniqueID { return pos.GetSegmentID() @@ -1497,7 +1490,7 @@ func flushNotifyFunc(dsService *dataSyncService, opts ...retry.Option) notifyMet Base: commonpbutil.NewMsgBase( commonpbutil.WithMsgType(0), commonpbutil.WithMsgID(0), - commonpbutil.WithSourceID(dsService.serverID), + commonpbutil.WithSourceID(paramtable.GetNodeID()), ), SegmentID: pack.segmentID, CollectionID: dsService.collectionID, @@ -1513,30 +1506,25 @@ func flushNotifyFunc(dsService *dataSyncService, opts ...retry.Option) notifyMet Channel: dsService.vchannelName, } err := retry.Do(context.Background(), func() error { - rsp, err := dsService.dataCoord.SaveBinlogPaths(context.Background(), req) - // should be network issue, return error and retry - if err != nil { - return err - } - + err := dsService.broker.SaveBinlogPaths(context.Background(), req) // Segment not found during stale segment flush. Segment might get compacted already. // Stop retry and still proceed to the end, ignoring this error. 
- if !pack.flushed && rsp.GetErrorCode() == commonpb.ErrorCode_SegmentNotFound { + if !pack.flushed && errors.Is(err, merr.ErrSegmentNotFound) { log.Warn("stale segment not found, could be compacted", zap.Int64("segmentID", pack.segmentID)) log.Warn("failed to SaveBinlogPaths", zap.Int64("segmentID", pack.segmentID), - zap.Error(errors.New(rsp.GetReason()))) + zap.Error(err)) return nil } // meta error, datanode handles a virtual channel does not belong here - if rsp.GetErrorCode() == commonpb.ErrorCode_MetaFailed { + if errors.IsAny(err, merr.ErrSegmentNotFound, merr.ErrChannelNotFound) { log.Warn("meta error found, skip sync and start to drop virtual channel", zap.String("channel", dsService.vchannelName)) return nil } - if rsp.ErrorCode != commonpb.ErrorCode_Success { - return fmt.Errorf("data service save bin log path failed, reason = %s", rsp.Reason) + if err != nil { + return err } dsService.channel.transferNewSegments(lo.Map(startPos, func(pos *datapb.SegmentStartPosition, _ int) UniqueID { diff --git a/internal/datanode/flush_manager_test.go b/internal/datanode/flush_manager_test.go index 9398ec1d315a5..5c499c90a1001 100644 --- a/internal/datanode/flush_manager_test.go +++ b/internal/datanode/flush_manager_test.go @@ -25,16 +25,19 @@ import ( "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "go.uber.org/atomic" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/datanode/allocator" + "github.com/milvus-io/milvus/internal/datanode/broker" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/etcdpb" "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/util/merr" 
"github.com/milvus-io/milvus/pkg/util/retry" ) @@ -329,7 +332,6 @@ func TestRendezvousFlushManager_Inject(t *testing.T) { }) assert.Eventually(t, func() bool { return counter.Load() == int64(size+3) }, 3*time.Second, 100*time.Millisecond) assert.EqualValues(t, 4, packs[size+1].segmentID) - } func TestRendezvousFlushManager_getSegmentMeta(t *testing.T) { @@ -455,7 +457,7 @@ func TestRendezvousFlushManager_dropMode(t *testing.T) { channel := newTestChannel() targets := make(map[int64]struct{}) - //init failed segment + // init failed segment testSeg := &Segment{ collectionID: 1, segmentID: -1, @@ -463,7 +465,7 @@ func TestRendezvousFlushManager_dropMode(t *testing.T) { testSeg.setType(datapb.SegmentType_New) channel.segments[testSeg.segmentID] = testSeg - //init target segment + // init target segment for i := 1; i < 11; i++ { targets[int64(i)] = struct{}{} testSeg := &Segment{ @@ -474,7 +476,7 @@ func TestRendezvousFlushManager_dropMode(t *testing.T) { channel.segments[testSeg.segmentID] = testSeg } - //init flush manager + // init flush manager m := NewRendezvousFlushManager(allocator.NewMockAllocator(t), cm, channel, func(pack *segmentFlushPack) { }, func(packs []*segmentFlushPack) { mut.Lock() @@ -532,7 +534,7 @@ func TestRendezvousFlushManager_dropMode(t *testing.T) { var result []*segmentFlushPack signal := make(chan struct{}) channel := newTestChannel() - //init failed segment + // init failed segment testSeg := &Segment{ collectionID: 1, segmentID: -1, @@ -540,7 +542,7 @@ func TestRendezvousFlushManager_dropMode(t *testing.T) { testSeg.setType(datapb.SegmentType_New) channel.segments[testSeg.segmentID] = testSeg - //init target segment + // init target segment for i := 1; i < 11; i++ { seg := &Segment{ collectionID: 1, @@ -558,14 +560,14 @@ func TestRendezvousFlushManager_dropMode(t *testing.T) { close(signal) }) - //flush failed segment before start drop mode + // flush failed segment before start drop mode halfMsgID := []byte{1, 1, 1} _, err := 
m.flushBufferData(nil, -1, true, false, &msgpb.MsgPosition{ MsgID: halfMsgID, }) assert.NoError(t, err) - //inject target segment + // inject target segment injFunc := func(pack *segmentFlushPack) { pack.segmentID = 100 } @@ -617,7 +619,7 @@ func TestRendezvousFlushManager_close(t *testing.T) { channel := newTestChannel() - //init test segment + // init test segment testSeg := &Segment{ collectionID: 1, segmentID: 1, @@ -653,22 +655,28 @@ func TestRendezvousFlushManager_close(t *testing.T) { } func TestFlushNotifyFunc(t *testing.T) { - rcf := &RootCoordFactory{ - pkType: schemapb.DataType_Int64, - } + meta := NewMetaFactory().GetCollectionMeta(1, "testCollection", schemapb.DataType_Int64) + broker := broker.NewMockBroker(t) + broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything). + Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Status(nil), + Schema: meta.GetSchema(), + }, nil).Maybe() + broker.EXPECT().SaveBinlogPaths(mock.Anything, mock.Anything).Return(nil).Maybe() + broker.EXPECT().DropVirtualChannel(mock.Anything, mock.Anything).Return(nil).Maybe() + ctx, cancel := context.WithCancel(context.Background()) defer cancel() cm := storage.NewLocalChunkManager(storage.RootPath(flushTestDir)) defer cm.RemoveWithPrefix(ctx, cm.RootPath()) - channel := newChannel("channel", 1, nil, rcf, cm) + channel := newChannel("channel", 1, nil, broker, cm) - dataCoord := &DataCoordFactory{} flushingCache := newCache() dsService := &dataSyncService{ collectionID: 1, channel: channel, - dataCoord: dataCoord, + broker: broker, flushingSegCache: flushingCache, } notifyFunc := flushNotifyFunc(dsService, retry.Attempts(1)) @@ -693,14 +701,18 @@ func TestFlushNotifyFunc(t *testing.T) { }) t.Run("datacoord save fails", func(t *testing.T) { - dataCoord.SaveBinlogPathStatus = commonpb.ErrorCode_UnexpectedError + broker.ExpectedCalls = nil + broker.EXPECT().SaveBinlogPaths(mock.Anything, mock.Anything). 
+ Return(merr.WrapErrCollectionNotFound("test_collection")) assert.Panics(t, func() { notifyFunc(&segmentFlushPack{}) }) }) t.Run("stale segment not found", func(t *testing.T) { - dataCoord.SaveBinlogPathStatus = commonpb.ErrorCode_SegmentNotFound + broker.ExpectedCalls = nil + broker.EXPECT().SaveBinlogPaths(mock.Anything, mock.Anything). + Return(merr.WrapErrSegmentNotFound(0)) assert.NotPanics(t, func() { notifyFunc(&segmentFlushPack{flushed: false}) }) @@ -709,14 +721,18 @@ func TestFlushNotifyFunc(t *testing.T) { // issue https://github.com/milvus-io/milvus/issues/17097 // meta error, datanode shall not panic, just drop the virtual channel t.Run("datacoord found meta error", func(t *testing.T) { - dataCoord.SaveBinlogPathStatus = commonpb.ErrorCode_MetaFailed + broker.ExpectedCalls = nil + broker.EXPECT().SaveBinlogPaths(mock.Anything, mock.Anything). + Return(merr.WrapErrChannelNotFound("channel")) assert.NotPanics(t, func() { notifyFunc(&segmentFlushPack{}) }) }) t.Run("datacoord call error", func(t *testing.T) { - dataCoord.SaveBinlogPathError = true + broker.ExpectedCalls = nil + broker.EXPECT().SaveBinlogPaths(mock.Anything, mock.Anything). + Return(errors.New("mock")) assert.Panics(t, func() { notifyFunc(&segmentFlushPack{}) }) @@ -724,9 +740,15 @@ func TestFlushNotifyFunc(t *testing.T) { } func TestDropVirtualChannelFunc(t *testing.T) { - rcf := &RootCoordFactory{ - pkType: schemapb.DataType_Int64, - } + meta := NewMetaFactory().GetCollectionMeta(1, "testCollection", schemapb.DataType_Int64) + broker := broker.NewMockBroker(t) + broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything). 
+ Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Status(nil), + Schema: meta.GetSchema(), + }, nil) + broker.EXPECT().SaveBinlogPaths(mock.Anything, mock.Anything).Return(nil).Maybe() + broker.EXPECT().DropVirtualChannel(mock.Anything, mock.Anything).Return(nil).Maybe() vchanName := "vchan_01" ctx, cancel := context.WithCancel(context.Background()) @@ -734,20 +756,20 @@ func TestDropVirtualChannelFunc(t *testing.T) { cm := storage.NewLocalChunkManager(storage.RootPath(flushTestDir)) defer cm.RemoveWithPrefix(ctx, cm.RootPath()) - channel := newChannel(vchanName, 1, nil, rcf, cm) + channel := newChannel(vchanName, 1, nil, broker, cm) - dataCoord := &DataCoordFactory{} flushingCache := newCache() dsService := &dataSyncService{ collectionID: 1, channel: channel, - dataCoord: dataCoord, + broker: broker, flushingSegCache: flushingCache, vchannelName: vchanName, } dropFunc := dropVirtualChannelFunc(dsService, retry.Attempts(1)) t.Run("normal run", func(t *testing.T) { channel.addSegment( + context.TODO(), addSegmentReq{ segType: datapb.SegmentType_New, segID: 2, @@ -757,7 +779,8 @@ func TestDropVirtualChannelFunc(t *testing.T) { ChannelName: vchanName, MsgID: []byte{1, 2, 3}, Timestamp: 10, - }, endPos: nil}) + }, endPos: nil, + }) assert.NotPanics(t, func() { dropFunc([]*segmentFlushPack{ { @@ -785,17 +808,10 @@ func TestDropVirtualChannelFunc(t *testing.T) { }) }) }) - t.Run("datacoord drop fails", func(t *testing.T) { - dataCoord.DropVirtualChannelStatus = commonpb.ErrorCode_UnexpectedError - assert.Panics(t, func() { - dropFunc(nil) - }) - }) - - t.Run("datacoord call error", func(t *testing.T) { - - dataCoord.DropVirtualChannelStatus = commonpb.ErrorCode_UnexpectedError - dataCoord.DropVirtualChannelError = true + t.Run("datacoord_return_error", func(t *testing.T) { + broker.ExpectedCalls = nil + broker.EXPECT().DropVirtualChannel(mock.Anything, mock.Anything). 
+ Return(errors.New("mock")) assert.Panics(t, func() { dropFunc(nil) }) @@ -803,12 +819,12 @@ func TestDropVirtualChannelFunc(t *testing.T) { // issue https://github.com/milvus-io/milvus/issues/17097 // meta error, datanode shall not panic, just drop the virtual channel - t.Run("datacoord found meta error", func(t *testing.T) { - dataCoord.DropVirtualChannelStatus = commonpb.ErrorCode_MetaFailed - dataCoord.DropVirtualChannelError = false + t.Run("datacoord_return_channel_not_found", func(t *testing.T) { + broker.ExpectedCalls = nil + broker.EXPECT().DropVirtualChannel(mock.Anything, mock.Anything). + Return(merr.WrapErrChannelNotFound("channel")) assert.NotPanics(t, func() { dropFunc(nil) }) }) - } diff --git a/internal/datanode/flush_task.go b/internal/datanode/flush_task.go index 96b1529ccc51f..c914cfdb47dd6 100644 --- a/internal/datanode/flush_task.go +++ b/internal/datanode/flush_task.go @@ -24,7 +24,6 @@ import ( "github.com/cockroachdb/errors" milvus_storage "github.com/milvus-io/milvus-storage/go/storage" "github.com/milvus-io/milvus-storage/go/storage/options" - "github.com/milvus-io/milvus/pkg/util/tsoutil" clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/zap" @@ -34,6 +33,7 @@ import ( "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/retry" + "github.com/milvus-io/milvus/pkg/util/tsoutil" ) // errStart used for retry start @@ -68,7 +68,7 @@ type flushTaskRunner struct { segmentID UniqueID insertLogs map[UniqueID]*datapb.Binlog statsLogs map[UniqueID]*datapb.Binlog - deltaLogs []*datapb.Binlog //[]*DelDataBuf + deltaLogs []*datapb.Binlog // []*DelDataBuf pos *msgpb.MsgPosition flushed bool dropped bool @@ -104,7 +104,7 @@ func newTaskInjection(segmentCnt int, pf func(pack *segmentFlushPack)) *taskInje // Injected returns a chan, which will be closed after pre set segments counts an injected func (ti *taskInjection) Injected() <-chan struct{} { - + return ti.injected } @@ -143,7 
+143,8 @@ func (t *flushTaskRunner) init(f notifyMetaFunc, postFunc taskPostFunc, signal < // runFlushInsert executes flush insert task with once and retry func (t *flushTaskRunner) runFlushInsert(task flushInsertTask, - binlogs, statslogs map[UniqueID]*datapb.Binlog, flushed bool, dropped bool, pos *msgpb.MsgPosition, opts ...retry.Option) { + binlogs, statslogs map[UniqueID]*datapb.Binlog, flushed bool, dropped bool, pos *msgpb.MsgPosition, opts ...retry.Option, +) { t.insertOnce.Do(func() { t.insertLogs = binlogs t.statsLogs = statslogs @@ -173,7 +174,7 @@ func (t *flushTaskRunner) runFlushInsert(task flushInsertTask, func (t *flushTaskRunner) runFlushDel(task flushDeleteTask, deltaLogs *DelDataBuf, opts ...retry.Option) { t.deleteOnce.Do(func() { if deltaLogs == nil { - t.deltaLogs = nil //[]*DelDataBuf{} + t.deltaLogs = nil // []*DelDataBuf{} } else { t.deltaLogs = []*datapb.Binlog{ { diff --git a/internal/datanode/flush_task_test.go b/internal/datanode/flush_task_test.go index bc4489bbe489a..0bdcc06895069 100644 --- a/internal/datanode/flush_task_test.go +++ b/internal/datanode/flush_task_test.go @@ -19,8 +19,9 @@ package datanode import ( "testing" - "github.com/milvus-io/milvus/pkg/util/retry" "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/pkg/util/retry" ) func TestFlushTaskRunner(t *testing.T) { @@ -91,7 +92,6 @@ func TestFlushTaskRunner_FailError(t *testing.T) { assert.True(t, errFlag) assert.True(t, nextFlag) - } func TestFlushTaskRunner_Injection(t *testing.T) { diff --git a/internal/datanode/io_pool.go b/internal/datanode/io_pool.go index 5f5feaa0c9dc0..00bb1f7e4a1b8 100644 --- a/internal/datanode/io_pool.go +++ b/internal/datanode/io_pool.go @@ -7,11 +7,15 @@ import ( "github.com/milvus-io/milvus/pkg/util/conc" ) -var ioPool *conc.Pool[any] -var ioPoolInitOnce sync.Once +var ( + ioPool *conc.Pool[any] + ioPoolInitOnce sync.Once +) -var statsPool *conc.Pool[any] -var statsPoolInitOnce sync.Once +var ( + statsPool *conc.Pool[any] 
+ statsPoolInitOnce sync.Once +) func initIOPool() { capacity := Params.DataNodeCfg.IOConcurrency.GetAsInt() @@ -28,10 +32,28 @@ func getOrCreateIOPool() *conc.Pool[any] { } func initStatsPool() { - statsPool = conc.NewPool[any](runtime.GOMAXPROCS(0), conc.WithPreAlloc(false), conc.WithNonBlocking(false)) + poolSize := Params.DataNodeCfg.ChannelWorkPoolSize.GetAsInt() + if poolSize <= 0 { + poolSize = runtime.GOMAXPROCS(0) + } + statsPool = conc.NewPool[any](poolSize, conc.WithPreAlloc(false), conc.WithNonBlocking(false)) } func getOrCreateStatsPool() *conc.Pool[any] { statsPoolInitOnce.Do(initStatsPool) return statsPool } + +func initMultiReadPool() { + capacity := Params.DataNodeCfg.FileReadConcurrency.GetAsInt() + if capacity > runtime.GOMAXPROCS(0) { + capacity = runtime.GOMAXPROCS(0) + } + // error only happens with negative expiry duration or with negative pre-alloc size. + ioPool = conc.NewPool[any](capacity) +} + +func getMultiReadPool() *conc.Pool[any] { + ioPoolInitOnce.Do(initMultiReadPool) + return ioPool +} diff --git a/internal/datanode/iterators/binlog_iterator.go b/internal/datanode/iterators/binlog_iterator.go new file mode 100644 index 0000000000000..884efaf14e2c4 --- /dev/null +++ b/internal/datanode/iterators/binlog_iterator.go @@ -0,0 +1,103 @@ +package iterator + +import ( + "sync" + + "go.uber.org/atomic" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type BinlogIterator struct { + disposed atomic.Bool + disposedCh chan struct{} + disposedOnce sync.Once + + data *storage.InsertData + label *Label + pkFieldID int64 + pkType schemapb.DataType + pos int +} + +var _ Iterator = (*BinlogIterator)(nil) + +// NewInsertBinlogIterator creates a new iterator +func NewInsertBinlogIterator(v [][]byte, pkFieldID typeutil.UniqueID, pkType schemapb.DataType, label *Label) (*BinlogIterator, error) { + 
blobs := make([]*storage.Blob, len(v)) + for i := range blobs { + blobs[i] = &storage.Blob{Value: v[i]} + } + + reader := storage.NewInsertCodec() + _, _, iData, err := reader.Deserialize(blobs) + if err != nil { + return nil, err + } + + return &BinlogIterator{ + disposedCh: make(chan struct{}), + data: iData, + pkFieldID: pkFieldID, + pkType: pkType, + label: label, + }, nil +} + +// HasNext returns true if the iterator have unread record +func (i *BinlogIterator) HasNext() bool { + return !i.isDisposed() && i.hasNext() +} + +func (i *BinlogIterator) Next() (*LabeledRowData, error) { + if i.isDisposed() { + return nil, ErrDisposed + } + + if !i.hasNext() { + return nil, ErrNoMoreRecord + } + + fields := make(map[int64]interface{}) + for fieldID, fieldData := range i.data.Data { + fields[fieldID] = fieldData.GetRow(i.pos) + } + + pk, err := storage.GenPrimaryKeyByRawData(i.data.Data[i.pkFieldID].GetRow(i.pos), i.pkType) + if err != nil { + return nil, err + } + + row := &InsertRow{ + ID: i.data.Data[common.RowIDField].GetRow(i.pos).(int64), + Timestamp: uint64(i.data.Data[common.TimeStampField].GetRow(i.pos).(int64)), + PK: pk, + Value: fields, + } + i.pos++ + return NewLabeledRowData(row, i.label), nil +} + +// Dispose disposes the iterator +func (i *BinlogIterator) Dispose() { + i.disposed.CompareAndSwap(false, true) + i.disposedOnce.Do(func() { + close(i.disposedCh) + }) +} + +func (i *BinlogIterator) hasNext() bool { + return i.pos < i.data.GetRowNum() +} + +func (i *BinlogIterator) isDisposed() bool { + return i.disposed.Load() +} + +// Disposed wait forever for the iterator to dispose +func (i *BinlogIterator) WaitForDisposed() { + <-i.disposedCh +} diff --git a/internal/datanode/iterators/binlog_iterator_test.go b/internal/datanode/iterators/binlog_iterator_test.go new file mode 100644 index 0000000000000..e9e95ba44290b --- /dev/null +++ b/internal/datanode/iterators/binlog_iterator_test.go @@ -0,0 +1,305 @@ +package iterator + +import ( + "testing" + + 
"github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/proto/etcdpb" + "github.com/milvus-io/milvus/internal/storage" +) + +func TestInsertBinlogIteratorSuite(t *testing.T) { + suite.Run(t, new(InsertBinlogIteratorSuite)) +} + +const ( + CollectionID = 10000 + PartitionID = 10001 + SegmentID = 10002 + RowIDField = 0 + TimestampField = 1 + BoolField = 100 + Int8Field = 101 + Int16Field = 102 + Int32Field = 103 + Int64Field = 104 + FloatField = 105 + DoubleField = 106 + StringField = 107 + BinaryVectorField = 108 + FloatVectorField = 109 + ArrayField = 110 + JSONField = 111 + Float16VectorField = 112 +) + +type InsertBinlogIteratorSuite struct { + suite.Suite + + i *BinlogIterator +} + +func (s *InsertBinlogIteratorSuite) TestBinlogIterator() { + insertData, meta := genTestInsertData() + writer := storage.NewInsertCodecWithSchema(meta) + blobs, err := writer.Serialize(PartitionID, SegmentID, insertData) + s.Require().NoError(err) + + values := [][]byte{} + for _, b := range blobs { + values = append(values, b.Value[:]) + } + s.Run("invalid blobs", func() { + iter, err := NewInsertBinlogIterator([][]byte{}, Int64Field, schemapb.DataType_Int64, nil) + s.Error(err) + s.Nil(iter) + }) + + s.Run("invalid pk type", func() { + iter, err := NewInsertBinlogIterator(values, Int64Field, schemapb.DataType_Float, &Label{segmentID: 19530}) + s.NoError(err) + + _, err = iter.Next() + s.Error(err) + }) + + s.Run("normal", func() { + iter, err := NewInsertBinlogIterator(values, Int64Field, schemapb.DataType_Int64, &Label{segmentID: 19530}) + s.NoError(err) + + rows := []interface{}{} + var idx int = 0 // row number + + for iter.HasNext() { + labeled, err := iter.Next() + s.NoError(err) + s.Equal(int64(19530), labeled.GetSegmentID()) + + rows = append(rows, labeled.data) + + insertRow, ok := labeled.data.(*InsertRow) + s.True(ok) + + s.Equal(insertData.Data[Int64Field].GetRow(idx).(int64), 
insertRow.PK.GetValue().(int64)) + s.Equal(insertData.Data[RowIDField].GetRow(idx).(int64), insertRow.ID) + s.Equal(insertData.Data[BoolField].GetRow(idx).(bool), insertRow.Value[BoolField].(bool)) + s.Equal(insertData.Data[Int8Field].GetRow(idx).(int8), insertRow.Value[Int8Field].(int8)) + s.Equal(insertData.Data[Int16Field].GetRow(idx).(int16), insertRow.Value[Int16Field].(int16)) + s.Equal(insertData.Data[Int32Field].GetRow(idx).(int32), insertRow.Value[Int32Field].(int32)) + s.Equal(insertData.Data[Int64Field].GetRow(idx).(int64), insertRow.Value[Int64Field].(int64)) + s.Equal(insertData.Data[Int64Field].GetRow(idx).(int64), insertRow.Value[Int64Field].(int64)) + s.Equal(insertData.Data[FloatField].GetRow(idx).(float32), insertRow.Value[FloatField].(float32)) + s.Equal(insertData.Data[DoubleField].GetRow(idx).(float64), insertRow.Value[DoubleField].(float64)) + s.Equal(insertData.Data[StringField].GetRow(idx).(string), insertRow.Value[StringField].(string)) + s.Equal(insertData.Data[ArrayField].GetRow(idx).(*schemapb.ScalarField).GetIntData().Data, insertRow.Value[ArrayField].(*schemapb.ScalarField).GetIntData().Data) + s.Equal(insertData.Data[JSONField].GetRow(idx).([]byte), insertRow.Value[JSONField].([]byte)) + s.Equal(insertData.Data[BinaryVectorField].GetRow(idx).([]byte), insertRow.Value[BinaryVectorField].([]byte)) + s.Equal(insertData.Data[FloatVectorField].GetRow(idx).([]float32), insertRow.Value[FloatVectorField].([]float32)) + s.Equal(insertData.Data[Float16VectorField].GetRow(idx).([]byte), insertRow.Value[Float16VectorField].([]byte)) + + idx++ + } + + s.Equal(2, len(rows)) + + _, err = iter.Next() + s.ErrorIs(err, ErrNoMoreRecord) + + iter.Dispose() + iter.WaitForDisposed() + + _, err = iter.Next() + s.ErrorIs(err, ErrDisposed) + }) +} + +func genTestInsertData() (*storage.InsertData, *etcdpb.CollectionMeta) { + meta := &etcdpb.CollectionMeta{ + ID: CollectionID, + CreateTime: 1, + SegmentIDs: []int64{SegmentID}, + PartitionTags: 
[]string{"partition_0", "partition_1"}, + Schema: &schemapb.CollectionSchema{ + Name: "schema", + Description: "schema", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + { + FieldID: RowIDField, + Name: "row_id", + IsPrimaryKey: false, + Description: "row_id", + DataType: schemapb.DataType_Int64, + }, + { + FieldID: TimestampField, + Name: "Timestamp", + IsPrimaryKey: false, + Description: "Timestamp", + DataType: schemapb.DataType_Int64, + }, + { + FieldID: BoolField, + Name: "field_bool", + IsPrimaryKey: false, + Description: "bool", + DataType: schemapb.DataType_Bool, + }, + { + FieldID: Int8Field, + Name: "field_int8", + IsPrimaryKey: false, + Description: "int8", + DataType: schemapb.DataType_Int8, + }, + { + FieldID: Int16Field, + Name: "field_int16", + IsPrimaryKey: false, + Description: "int16", + DataType: schemapb.DataType_Int16, + }, + { + FieldID: Int32Field, + Name: "field_int32", + IsPrimaryKey: false, + Description: "int32", + DataType: schemapb.DataType_Int32, + }, + { + FieldID: Int64Field, + Name: "field_int64", + IsPrimaryKey: true, + Description: "int64", + DataType: schemapb.DataType_Int64, + }, + { + FieldID: FloatField, + Name: "field_float", + IsPrimaryKey: false, + Description: "float", + DataType: schemapb.DataType_Float, + }, + { + FieldID: DoubleField, + Name: "field_double", + IsPrimaryKey: false, + Description: "double", + DataType: schemapb.DataType_Double, + }, + { + FieldID: StringField, + Name: "field_string", + IsPrimaryKey: false, + Description: "string", + DataType: schemapb.DataType_String, + }, + { + FieldID: ArrayField, + Name: "field_int32_array", + Description: "int32 array", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int32, + }, + { + FieldID: JSONField, + Name: "field_json", + Description: "json", + DataType: schemapb.DataType_JSON, + }, + { + FieldID: BinaryVectorField, + Name: "field_binary_vector", + IsPrimaryKey: false, + Description: "binary_vector", + DataType: 
schemapb.DataType_BinaryVector, + }, + { + FieldID: FloatVectorField, + Name: "field_float_vector", + IsPrimaryKey: false, + Description: "float_vector", + DataType: schemapb.DataType_FloatVector, + }, + { + FieldID: Float16VectorField, + Name: "field_float16_vector", + IsPrimaryKey: false, + Description: "float16_vector", + DataType: schemapb.DataType_Float16Vector, + }, + }, + }, + } + insertData := storage.InsertData{ + Data: map[int64]storage.FieldData{ + RowIDField: &storage.Int64FieldData{ + Data: []int64{3, 4}, + }, + TimestampField: &storage.Int64FieldData{ + Data: []int64{3, 4}, + }, + BoolField: &storage.BoolFieldData{ + Data: []bool{true, false}, + }, + Int8Field: &storage.Int8FieldData{ + Data: []int8{3, 4}, + }, + Int16Field: &storage.Int16FieldData{ + Data: []int16{3, 4}, + }, + Int32Field: &storage.Int32FieldData{ + Data: []int32{3, 4}, + }, + Int64Field: &storage.Int64FieldData{ + Data: []int64{3, 4}, + }, + FloatField: &storage.FloatFieldData{ + Data: []float32{3, 4}, + }, + DoubleField: &storage.DoubleFieldData{ + Data: []float64{3, 4}, + }, + StringField: &storage.StringFieldData{ + Data: []string{"3", "4"}, + }, + BinaryVectorField: &storage.BinaryVectorFieldData{ + Data: []byte{0, 255}, + Dim: 8, + }, + FloatVectorField: &storage.FloatVectorFieldData{ + Data: []float32{4, 5, 6, 7, 4, 5, 6, 7}, + Dim: 4, + }, + ArrayField: &storage.ArrayFieldData{ + ElementType: schemapb.DataType_Int32, + Data: []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{Data: []int32{3, 2, 1}}, + }, + }, + { + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{Data: []int32{6, 5, 4}}, + }, + }, + }, + }, + JSONField: &storage.JSONFieldData{ + Data: [][]byte{ + []byte(`{"batch":2}`), + []byte(`{"key":"world"}`), + }, + }, + Float16VectorField: &storage.Float16VectorFieldData{ + Data: []byte{0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255}, + Dim: 4, + }, + }, + } + + return &insertData, meta +} 
diff --git a/internal/datanode/iterators/deltalog_iterator.go b/internal/datanode/iterators/deltalog_iterator.go new file mode 100644 index 0000000000000..3d63ee1b2ce47 --- /dev/null +++ b/internal/datanode/iterators/deltalog_iterator.go @@ -0,0 +1,80 @@ +package iterator + +import ( + "sync" + + "go.uber.org/atomic" + + "github.com/milvus-io/milvus/internal/storage" +) + +var _ Iterator = (*DeltalogIterator)(nil) + +type DeltalogIterator struct { + disposeCh chan struct{} + disposedOnce sync.Once + disposed atomic.Bool + + data *storage.DeleteData + label *Label + pos int +} + +func NewDeltalogIterator(v [][]byte, label *Label) (*DeltalogIterator, error) { + blobs := make([]*storage.Blob, len(v)) + for i := range blobs { + blobs[i] = &storage.Blob{Value: v[i]} + } + + reader := storage.NewDeleteCodec() + _, _, dData, err := reader.Deserialize(blobs) + if err != nil { + return nil, err + } + return &DeltalogIterator{ + disposeCh: make(chan struct{}), + data: dData, + label: label, + }, nil +} + +func (d *DeltalogIterator) HasNext() bool { + return !d.isDisposed() && d.hasNext() +} + +func (d *DeltalogIterator) Next() (*LabeledRowData, error) { + if d.isDisposed() { + return nil, ErrDisposed + } + + if !d.hasNext() { + return nil, ErrNoMoreRecord + } + + row := &DeltalogRow{ + Pk: d.data.Pks[d.pos], + Timestamp: d.data.Tss[d.pos], + } + d.pos++ + + return NewLabeledRowData(row, d.label), nil +} + +func (d *DeltalogIterator) Dispose() { + d.disposed.CompareAndSwap(false, true) + d.disposedOnce.Do(func() { + close(d.disposeCh) + }) +} + +func (d *DeltalogIterator) hasNext() bool { + return int64(d.pos) < d.data.RowCount +} + +func (d *DeltalogIterator) isDisposed() bool { + return d.disposed.Load() +} + +func (d *DeltalogIterator) WaitForDisposed() { + <-d.disposeCh +} diff --git a/internal/datanode/iterators/deltalog_iterator_test.go b/internal/datanode/iterators/deltalog_iterator_test.go new file mode 100644 index 0000000000000..930b3f0f17fa5 --- /dev/null +++ 
b/internal/datanode/iterators/deltalog_iterator_test.go @@ -0,0 +1,61 @@ +package iterator + +import ( + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus/internal/storage" +) + +func TestDeltalogIteratorSuite(t *testing.T) { + suite.Run(t, new(DeltalogIteratorSuite)) +} + +type DeltalogIteratorSuite struct { + suite.Suite +} + +func (s *DeltalogIteratorSuite) TestDeltalogIteratorIntPK() { + testpks := []int64{1, 2, 3, 4} + testtss := []uint64{43757345, 43757346, 43757347, 43757348} + + dData := &storage.DeleteData{} + for i := range testpks { + dData.Append(storage.NewInt64PrimaryKey(testpks[i]), testtss[i]) + } + + dCodec := storage.NewDeleteCodec() + blob, err := dCodec.Serialize(CollectionID, 1, 1, dData) + s.Require().NoError(err) + value := [][]byte{blob.Value[:]} + + iter, err := NewDeltalogIterator(value, &Label{segmentID: 100}) + s.NoError(err) + + var ( + gotpks = []int64{} + gottss = []uint64{} + ) + + for iter.HasNext() { + labeled, err := iter.Next() + s.NoError(err) + + s.Equal(labeled.GetSegmentID(), int64(100)) + gotpks = append(gotpks, labeled.data.(*DeltalogRow).Pk.GetValue().(int64)) + gottss = append(gottss, labeled.data.(*DeltalogRow).Timestamp) + } + + s.ElementsMatch(gotpks, testpks) + s.ElementsMatch(gottss, testtss) + + _, err = iter.Next() + s.ErrorIs(err, ErrNoMoreRecord) + + iter.Dispose() + iter.WaitForDisposed() + + _, err = iter.Next() + s.ErrorIs(err, ErrDisposed) +} diff --git a/internal/datanode/iterators/iterator.go b/internal/datanode/iterators/iterator.go new file mode 100644 index 0000000000000..d7d9c26c3ccfc --- /dev/null +++ b/internal/datanode/iterators/iterator.go @@ -0,0 +1,62 @@ +package iterator + +import ( + "github.com/cockroachdb/errors" + + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +var ( + // ErrNoMoreRecord is the error that the iterator does not have next record. 
+ ErrNoMoreRecord = errors.New("no more record") + // ErrDisposed is the error that the iterator is disposed. + ErrDisposed = errors.New("iterator is disposed") +) + +const InvalidID int64 = -1 + +type Row interface{} + +type InsertRow struct { + ID int64 + PK storage.PrimaryKey + Timestamp typeutil.Timestamp + Value map[storage.FieldID]interface{} +} + +type DeltalogRow struct { + Pk storage.PrimaryKey + Timestamp typeutil.Timestamp +} + +type Label struct { + segmentID typeutil.UniqueID +} + +type LabeledRowData struct { + label *Label + data Row +} + +func (l *LabeledRowData) GetSegmentID() typeutil.UniqueID { + if l.label == nil { + return InvalidID + } + + return l.label.segmentID +} + +func NewLabeledRowData(data Row, label *Label) *LabeledRowData { + return &LabeledRowData{ + label: label, + data: data, + } +} + +type Iterator interface { + HasNext() bool + Next() (*LabeledRowData, error) + Dispose() + WaitForDisposed() // wait until the iterator is disposed +} diff --git a/internal/datanode/meta_service.go b/internal/datanode/meta_service.go index 4a7cfea7fa33b..32514d404ebbd 100644 --- a/internal/datanode/meta_service.go +++ b/internal/datanode/meta_service.go @@ -22,30 +22,25 @@ import ( "go.uber.org/zap" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/datanode/broker" "github.com/milvus-io/milvus/internal/proto/etcdpb" - "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/util/commonpbutil" - "github.com/milvus-io/milvus/pkg/util/merr" - "github.com/milvus-io/milvus/pkg/util/paramtable" ) // metaService initialize channel collection in data node from root coord. // Initializing channel collection happens on data node starting. It depends on // a healthy root coord and a valid root coord grpc client. 
type metaService struct { - channel Channel collectionID UniqueID - rootCoord types.RootCoord + broker broker.Broker } // newMetaService creates a new metaService with provided RootCoord and collectionID. -func newMetaService(rc types.RootCoord, collectionID UniqueID) *metaService { +func newMetaService(broker broker.Broker, collectionID UniqueID) *metaService { return &metaService{ - rootCoord: rc, + broker: broker, collectionID: collectionID, } } @@ -53,34 +48,17 @@ func newMetaService(rc types.RootCoord, collectionID UniqueID) *metaService { // getCollectionSchema get collection schema with provided collection id at specified timestamp. func (mService *metaService) getCollectionSchema(ctx context.Context, collID UniqueID, timestamp Timestamp) (*schemapb.CollectionSchema, error) { response, err := mService.getCollectionInfo(ctx, collID, timestamp) - if response != nil { - return response.GetSchema(), err + if err != nil { + return nil, err } - return nil, err + return response.GetSchema(), nil } // getCollectionInfo get collection info with provided collection id at specified timestamp. func (mService *metaService) getCollectionInfo(ctx context.Context, collID UniqueID, timestamp Timestamp) (*milvuspb.DescribeCollectionResponse, error) { - req := &milvuspb.DescribeCollectionRequest{ - Base: commonpbutil.NewMsgBase( - commonpbutil.WithMsgType(commonpb.MsgType_DescribeCollection), - commonpbutil.WithMsgID(0), //GOOSE TODO - commonpbutil.WithSourceID(paramtable.GetNodeID()), - ), - // please do not specify the collection name alone after database feature. 
- CollectionID: collID, - TimeStamp: timestamp, - } - - response, err := mService.rootCoord.DescribeCollectionInternal(ctx, req) + response, err := mService.broker.DescribeCollection(ctx, collID, timestamp) if err != nil { - log.Error("grpc error when describe", zap.Int64("collectionID", collID), zap.Error(err)) - return nil, err - } - - if response.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { - err := merr.Error(response.Status) - log.Error("describe collection from rootcoord failed", zap.Int64("collectionID", collID), zap.Error(err)) + log.Error("failed to describe collection from rootcoord", zap.Int64("collectionID", collID), zap.Error(err)) return nil, err } diff --git a/internal/datanode/meta_service_test.go b/internal/datanode/meta_service_test.go index 297d32cfedd70..c1a0c9ed60db8 100644 --- a/internal/datanode/meta_service_test.go +++ b/internal/datanode/meta_service_test.go @@ -21,11 +21,15 @@ import ( "testing" "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "google.golang.org/grpc" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus/internal/datanode/broker" + "github.com/milvus-io/milvus/pkg/util/merr" ) const ( @@ -39,15 +43,17 @@ func TestMetaService_All(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - mFactory := &RootCoordFactory{ - pkType: schemapb.DataType_Int64, - } - mFactory.setCollectionID(collectionID0) - mFactory.setCollectionName(collectionName0) - ms := newMetaService(mFactory, collectionID0) + meta := NewMetaFactory().GetCollectionMeta(collectionID0, collectionName0, schemapb.DataType_Int64) + broker := broker.NewMockBroker(t) + broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything). 
+ Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Status(nil), + Schema: meta.GetSchema(), + }, nil).Maybe() - t.Run("Test getCollectionSchema", func(t *testing.T) { + ms := newMetaService(broker, collectionID0) + t.Run("Test getCollectionSchema", func(t *testing.T) { sch, err := ms.getCollectionSchema(ctx, collectionID0, 0) assert.NoError(t, err) assert.NotNil(t, sch) @@ -67,7 +73,7 @@ type RootCoordFails1 struct { } // DescribeCollectionInternal override method that will fails -func (rc *RootCoordFails1) DescribeCollectionInternal(ctx context.Context, req *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { +func (rc *RootCoordFails1) DescribeCollectionInternal(ctx context.Context, req *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) { return nil, errors.New("always fail") } @@ -77,30 +83,23 @@ type RootCoordFails2 struct { } // DescribeCollectionInternal override method that will fails -func (rc *RootCoordFails2) DescribeCollectionInternal(ctx context.Context, req *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { +func (rc *RootCoordFails2) DescribeCollectionInternal(ctx context.Context, req *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) { return &milvuspb.DescribeCollectionResponse{ Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, }, nil } func TestMetaServiceRootCoodFails(t *testing.T) { - t.Run("Test Describe with error", func(t *testing.T) { rc := &RootCoordFails1{} rc.setCollectionID(collectionID0) rc.setCollectionName(collectionName0) - ms := newMetaService(rc, collectionID0) - _, err := ms.getCollectionSchema(context.Background(), collectionID1, 0) - assert.Error(t, err) - }) - - t.Run("Test Describe wit nil response", func(t *testing.T) { - rc := &RootCoordFails2{} - rc.setCollectionID(collectionID0) - 
rc.setCollectionName(collectionName0) + broker := broker.NewMockBroker(t) + broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything). + Return(nil, errors.New("mock")) - ms := newMetaService(rc, collectionID0) + ms := newMetaService(broker, collectionID0) _, err := ms.getCollectionSchema(context.Background(), collectionID1, 0) assert.Error(t, err) }) diff --git a/internal/datanode/metacache/meta_cache.go b/internal/datanode/metacache/meta_cache.go new file mode 100644 index 0000000000000..792b6f1845ccc --- /dev/null +++ b/internal/datanode/metacache/meta_cache.go @@ -0,0 +1,133 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package metacache + +import ( + "sync" + + "github.com/pingcap/log" + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/proto/datapb" +) + +type MetaCache interface { + NewSegment(segmentID, partitionID int64) + UpdateSegment(newSegmentID, partitionID int64, dropSegmentIDs ...int64) + GetSegmentIDsBy(filters ...SegmentFilter) []int64 +} + +type SegmentFilter func(info *SegmentInfo) bool + +type SegmentInfo struct { + segmentID int64 + partitionID int64 +} + +func newSegmentInfo(segmentID, partitionID int64) *SegmentInfo { + return &SegmentInfo{ + segmentID: segmentID, + partitionID: partitionID, + } +} + +func WithPartitionID(partitionID int64) func(info *SegmentInfo) bool { + return func(info *SegmentInfo) bool { + return info.partitionID == partitionID + } +} + +var _ MetaCache = (*MetaCacheImpl)(nil) + +type MetaCacheImpl struct { + collectionID int64 + vChannelName string + segmentInfos map[int64]*SegmentInfo + mu sync.Mutex +} + +func NewMetaCache(vchannel *datapb.VchannelInfo) MetaCache { + cache := &MetaCacheImpl{ + collectionID: vchannel.GetCollectionID(), + vChannelName: vchannel.GetChannelName(), + segmentInfos: make(map[int64]*SegmentInfo), + } + + cache.init(vchannel) + return cache +} + +func (c *MetaCacheImpl) init(vchannel *datapb.VchannelInfo) { + for _, seg := range vchannel.FlushedSegments { + c.segmentInfos[seg.GetID()] = newSegmentInfo(seg.GetID(), seg.GetPartitionID()) + } + + for _, seg := range vchannel.UnflushedSegments { + c.segmentInfos[seg.GetID()] = newSegmentInfo(seg.GetID(), seg.GetPartitionID()) + } +} + +func (c *MetaCacheImpl) NewSegment(segmentID, partitionID int64) { + c.mu.Lock() + defer c.mu.Unlock() + + if _, ok := c.segmentInfos[segmentID]; !ok { + c.segmentInfos[segmentID] = newSegmentInfo(segmentID, partitionID) + } +} + +func (c *MetaCacheImpl) UpdateSegment(newSegmentID, partitionID int64, dropSegmentIDs ...int64) { + c.mu.Lock() + defer c.mu.Unlock() + + for _, dropSeg := range dropSegmentIDs { + if _, ok := 
c.segmentInfos[dropSeg]; ok { + delete(c.segmentInfos, dropSeg) + } else { + log.Warn("some dropped segment not exist in meta cache", + zap.String("channel", c.vChannelName), + zap.Int64("collectionID", c.collectionID), + zap.Int64("segmentID", dropSeg)) + } + } + + if _, ok := c.segmentInfos[newSegmentID]; !ok { + c.segmentInfos[newSegmentID] = newSegmentInfo(newSegmentID, partitionID) + } +} + +func (c *MetaCacheImpl) GetSegmentIDsBy(filters ...SegmentFilter) []int64 { + c.mu.Lock() + defer c.mu.Unlock() + + filter := func(info *SegmentInfo) bool { + for _, filter := range filters { + if !filter(info) { + return false + } + } + return true + } + + segments := []int64{} + for _, info := range c.segmentInfos { + if filter(info) { + segments = append(segments, info.segmentID) + } + } + return segments +} diff --git a/internal/datanode/metacache/meta_cache_test.go b/internal/datanode/metacache/meta_cache_test.go new file mode 100644 index 0000000000000..90e36dc96531c --- /dev/null +++ b/internal/datanode/metacache/meta_cache_test.go @@ -0,0 +1,106 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package metacache + +import ( + "testing" + + "github.com/samber/lo" + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus/internal/proto/datapb" +) + +type MetaCacheSuite struct { + suite.Suite + + collectionID int64 + vchannel string + invaliedSeg int64 + partitionIDs []int64 + flushedSegments []int64 + growingSegments []int64 + newSegments []int64 + cache MetaCache +} + +func (s *MetaCacheSuite) SetupSuite() { + s.collectionID = 1 + s.vchannel = "test" + s.partitionIDs = []int64{1, 2, 3, 4} + s.flushedSegments = []int64{1, 2, 3, 4} + s.growingSegments = []int64{5, 6, 7, 8} + s.newSegments = []int64{9, 10, 11, 12} + s.invaliedSeg = 111 +} + +func (s *MetaCacheSuite) SetupTest() { + flushSegmentInfos := lo.RepeatBy(len(s.flushedSegments), func(i int) *datapb.SegmentInfo { + return &datapb.SegmentInfo{ + ID: s.flushedSegments[i], + PartitionID: s.partitionIDs[i], + } + }) + + growingSegmentInfos := lo.RepeatBy(len(s.growingSegments), func(i int) *datapb.SegmentInfo { + return &datapb.SegmentInfo{ + ID: s.growingSegments[i], + PartitionID: s.partitionIDs[i], + } + }) + + s.cache = NewMetaCache(&datapb.VchannelInfo{ + CollectionID: s.collectionID, + ChannelName: s.vchannel, + FlushedSegments: flushSegmentInfos, + UnflushedSegments: growingSegmentInfos, + }) +} + +func (s *MetaCacheSuite) TestNewSegment() { + for i, seg := range s.newSegments { + s.cache.NewSegment(seg, s.partitionIDs[i]) + } + + for id, partitionID := range s.partitionIDs { + segs := s.cache.GetSegmentIDsBy(WithPartitionID(partitionID)) + targets := []int64{s.flushedSegments[id], s.growingSegments[id], s.newSegments[id]} + s.Equal(len(targets), len(segs)) + for _, seg := range segs { + s.True(lo.Contains(targets, seg)) + } + } +} + +func (s *MetaCacheSuite) TestUpdateSegment() { + for i, seg := range s.newSegments { + // compaction from flushed[i], unflushed[i] and invalidSeg to new[i] + s.cache.UpdateSegment(seg, s.partitionIDs[i], s.flushedSegments[i], s.growingSegments[i], 
s.invaliedSeg) + } + + for i, partitionID := range s.partitionIDs { + segs := s.cache.GetSegmentIDsBy(WithPartitionID(partitionID)) + s.Equal(1, len(segs)) + for _, seg := range segs { + s.Equal(seg, s.newSegments[i]) + } + } +} + +func TestMetaCacheSuite(t *testing.T) { + suite.Run(t, new(MetaCacheSuite)) +} diff --git a/internal/datanode/metrics_info.go b/internal/datanode/metrics_info.go index d4f43e9d9e162..6bbafef9dc10c 100644 --- a/internal/datanode/metrics_info.go +++ b/internal/datanode/metrics_info.go @@ -19,7 +19,6 @@ package datanode import ( "context" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/pkg/util/hardware" "github.com/milvus-io/milvus/pkg/util/merr" @@ -50,15 +49,6 @@ func (node *DataNode) getQuotaMetrics() (*metricsinfo.DataNodeQuotaMetrics, erro return nil, err } - getAllCollections := func() []int64 { - collectionSet := typeutil.UniqueSet{} - node.flowgraphManager.flowgraphs.Range(func(key string, fg *dataSyncService) bool { - collectionSet.Insert(fg.channel.getCollectionID()) - return true - }) - - return collectionSet.Collect() - } minFGChannel, minFGTt := rateCol.getMinFlowGraphTt() return &metricsinfo.DataNodeQuotaMetrics{ Hms: metricsinfo.HardwareMetrics{}, @@ -70,7 +60,7 @@ func (node *DataNode) getQuotaMetrics() (*metricsinfo.DataNodeQuotaMetrics, erro }, Effect: metricsinfo.NodeEffect{ NodeID: node.GetSession().ServerID, - CollectionIDs: getAllCollections(), + CollectionIDs: node.flowgraphManager.collections(), }, }, nil } @@ -83,10 +73,7 @@ func (node *DataNode) getSystemInfoMetrics(ctx context.Context, req *milvuspb.Ge quotaMetrics, err := node.getQuotaMetrics() if err != nil { return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), ComponentName: metricsinfo.ConstructComponentName(typeutil.DataNodeRole, 
paramtable.GetNodeID()), }, nil } @@ -122,17 +109,14 @@ func (node *DataNode) getSystemInfoMetrics(ctx context.Context, req *milvuspb.Ge resp, err := metricsinfo.MarshalComponentInfos(nodeInfos) if err != nil { return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), Response: "", ComponentName: metricsinfo.ConstructComponentName(typeutil.DataNodeRole, paramtable.GetNodeID()), }, nil } return &milvuspb.GetMetricsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Response: resp, ComponentName: metricsinfo.ConstructComponentName(typeutil.DataNodeRole, paramtable.GetNodeID()), }, nil diff --git a/internal/datanode/mock_test.go b/internal/datanode/mock_test.go index 6117a4b33b65d..ce65b951cdf7d 100644 --- a/internal/datanode/mock_test.go +++ b/internal/datanode/mock_test.go @@ -25,20 +25,21 @@ import ( "time" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus/pkg/util/tsoutil" + "github.com/stretchr/testify/mock" "go.uber.org/zap" + "google.golang.org/grpc" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/datanode/broker" "github.com/milvus-io/milvus/internal/kv" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/etcdpb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/storage" - s "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/sessionutil" @@ -47,14 +48,14 @@ import ( "github.com/milvus-io/milvus/pkg/mq/msgdispatcher" 
"github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/etcd" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/milvus-io/milvus/pkg/util/typeutil" ) const ctxTimeInMillisecond = 5000 -const debug = false - // As used in data_sync_service_test.go var segID2SegInfo = map[int64]*datapb.SegmentInfo{ 1: { @@ -80,62 +81,19 @@ var emptyFlushAndDropFunc flushAndDropFunc = func(_ []*segmentFlushPack) {} func newIDLEDataNodeMock(ctx context.Context, pkType schemapb.DataType) *DataNode { factory := dependency.NewDefaultFactory(true) node := NewDataNode(ctx, factory) - node.SetSession(&sessionutil.Session{ServerID: 1}) + node.SetSession(&sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}) node.dispClient = msgdispatcher.NewClient(factory, typeutil.DataNodeRole, paramtable.GetNodeID()) - rc := &RootCoordFactory{ - ID: 0, - collectionID: 1, - collectionName: "collection-1", - pkType: pkType, - } - node.rootCoord = rc - - ds := &DataCoordFactory{} - node.dataCoord = ds - - return node -} - -func newHEALTHDataNodeMock(dmChannelName string) *DataNode { - var ctx context.Context - - if debug { - ctx = context.Background() - } else { - var cancel context.CancelFunc - d := time.Now().Add(ctxTimeInMillisecond * time.Millisecond) - ctx, cancel = context.WithDeadline(context.Background(), d) - go func() { - <-ctx.Done() - cancel() - }() - } - - factory := dependency.NewDefaultFactory(true) - node := NewDataNode(ctx, factory) - - ms := &RootCoordFactory{ - ID: 0, - collectionID: 1, - collectionName: "collection-1", - } - node.rootCoord = ms + broker := &broker.MockBroker{} + broker.EXPECT().ReportTimeTick(mock.Anything, mock.Anything).Return(nil).Maybe() + broker.EXPECT().GetSegmentInfo(mock.Anything, mock.Anything).Return([]*datapb.SegmentInfo{}, nil).Maybe() - ds := &DataCoordFactory{} - node.dataCoord = ds + node.broker = broker + 
node.timeTickSender = newTimeTickSender(node.broker, 0) return node } -func makeNewChannelNames(names []string, suffix string) []string { - var ret []string - for _, name := range names { - ret = append(ret, name+suffix) - } - return ret -} - func newTestEtcdKV() (kv.WatchKV, error) { etcdCli, err := etcd.GetEtcdClient( Params.EtcdCfg.UseEmbedEtcd.GetAsBool(), @@ -186,11 +144,9 @@ func clearEtcd(rootPath string) error { } log.Debug("Clear ETCD with prefix writer/ddl") return nil - } -type MetaFactory struct { -} +type MetaFactory struct{} func NewMetaFactory() *MetaFactory { return &MetaFactory{} @@ -202,7 +158,7 @@ type DataFactory struct { } type RootCoordFactory struct { - types.RootCoord + types.RootCoordClient ID UniqueID collectionName string collectionID UniqueID @@ -218,15 +174,15 @@ type RootCoordFactory struct { } type DataCoordFactory struct { - types.DataCoord + types.DataCoordClient SaveBinlogPathError bool - SaveBinlogPathStatus commonpb.ErrorCode + SaveBinlogPathStatus *commonpb.Status CompleteCompactionError bool CompleteCompactionNotSuccess bool + DropVirtualChannelError bool - DropVirtualChannelError bool DropVirtualChannelStatus commonpb.ErrorCode GetSegmentInfosError bool @@ -241,14 +197,12 @@ type DataCoordFactory struct { ReportDataNodeTtMsgsNotSuccess bool } -func (ds *DataCoordFactory) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest) (*datapb.AssignSegmentIDResponse, error) { +func (ds *DataCoordFactory) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest, opts ...grpc.CallOption) (*datapb.AssignSegmentIDResponse, error) { if ds.AddSegmentError { return nil, errors.New("Error") } res := &datapb.AssignSegmentIDResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), SegIDAssignments: []*datapb.SegmentIDAssignment{ { SegID: 666, @@ -265,7 +219,7 @@ func (ds *DataCoordFactory) AssignSegmentID(ctx context.Context, req *datapb.Ass return res, nil } 
-func (ds *DataCoordFactory) CompleteCompaction(ctx context.Context, req *datapb.CompactionResult) (*commonpb.Status, error) { +func (ds *DataCoordFactory) CompleteCompaction(ctx context.Context, req *datapb.CompactionResult, opts ...grpc.CallOption) (*commonpb.Status, error) { if ds.CompleteCompactionError { return nil, errors.New("Error") } @@ -276,14 +230,14 @@ func (ds *DataCoordFactory) CompleteCompaction(ctx context.Context, req *datapb. return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil } -func (ds *DataCoordFactory) SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPathsRequest) (*commonpb.Status, error) { +func (ds *DataCoordFactory) SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPathsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { if ds.SaveBinlogPathError { return nil, errors.New("Error") } - return &commonpb.Status{ErrorCode: ds.SaveBinlogPathStatus}, nil + return ds.SaveBinlogPathStatus, nil } -func (ds *DataCoordFactory) DropVirtualChannel(ctx context.Context, req *datapb.DropVirtualChannelRequest) (*datapb.DropVirtualChannelResponse, error) { +func (ds *DataCoordFactory) DropVirtualChannel(ctx context.Context, req *datapb.DropVirtualChannelRequest, opts ...grpc.CallOption) (*datapb.DropVirtualChannelResponse, error) { if ds.DropVirtualChannelError { return nil, errors.New("error") } @@ -294,19 +248,15 @@ func (ds *DataCoordFactory) DropVirtualChannel(ctx context.Context, req *datapb. 
}, nil } -func (ds *DataCoordFactory) UpdateSegmentStatistics(ctx context.Context, req *datapb.UpdateSegmentStatisticsRequest) (*commonpb.Status, error) { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil +func (ds *DataCoordFactory) UpdateSegmentStatistics(ctx context.Context, req *datapb.UpdateSegmentStatisticsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return merr.Success(), nil } -func (ds *DataCoordFactory) UpdateChannelCheckpoint(ctx context.Context, req *datapb.UpdateChannelCheckpointRequest) (*commonpb.Status, error) { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil +func (ds *DataCoordFactory) UpdateChannelCheckpoint(ctx context.Context, req *datapb.UpdateChannelCheckpointRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return merr.Success(), nil } -func (ds *DataCoordFactory) ReportDataNodeTtMsgs(ctx context.Context, req *datapb.ReportDataNodeTtMsgsRequest) (*commonpb.Status, error) { +func (ds *DataCoordFactory) ReportDataNodeTtMsgs(ctx context.Context, req *datapb.ReportDataNodeTtMsgsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { if ds.ReportDataNodeTtMsgsError { return nil, errors.New("mock ReportDataNodeTtMsgs error") } @@ -315,42 +265,32 @@ func (ds *DataCoordFactory) ReportDataNodeTtMsgs(ctx context.Context, req *datap ErrorCode: commonpb.ErrorCode_UnexpectedError, }, nil } - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil + return merr.Success(), nil } -func (ds *DataCoordFactory) SaveImportSegment(ctx context.Context, req *datapb.SaveImportSegmentRequest) (*commonpb.Status, error) { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil +func (ds *DataCoordFactory) SaveImportSegment(ctx context.Context, req *datapb.SaveImportSegmentRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return merr.Success(), nil } -func (ds *DataCoordFactory) UnsetIsImportingState(context.Context, 
*datapb.UnsetIsImportingStateRequest) (*commonpb.Status, error) { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil +func (ds *DataCoordFactory) UnsetIsImportingState(ctx context.Context, req *datapb.UnsetIsImportingStateRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return merr.Success(), nil } -func (ds *DataCoordFactory) MarkSegmentsDropped(context.Context, *datapb.MarkSegmentsDroppedRequest) (*commonpb.Status, error) { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil +func (ds *DataCoordFactory) MarkSegmentsDropped(ctx context.Context, req *datapb.MarkSegmentsDroppedRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return merr.Success(), nil } -func (ds *DataCoordFactory) BroadcastAlteredCollection(ctx context.Context, req *datapb.AlterCollectionRequest) (*commonpb.Status, error) { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil +func (ds *DataCoordFactory) BroadcastAlteredCollection(ctx context.Context, req *datapb.AlterCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return merr.Success(), nil } -func (ds *DataCoordFactory) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { +func (ds *DataCoordFactory) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest, opts ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error) { return &milvuspb.CheckHealthResponse{ IsHealthy: true, }, nil } -func (ds *DataCoordFactory) GetSegmentInfo(ctx context.Context, req *datapb.GetSegmentInfoRequest) (*datapb.GetSegmentInfoResponse, error) { +func (ds *DataCoordFactory) GetSegmentInfo(ctx context.Context, req *datapb.GetSegmentInfoRequest, opts ...grpc.CallOption) (*datapb.GetSegmentInfoResponse, error) { if ds.GetSegmentInfosError { return nil, errors.New("mock get segment info error") } @@ -370,15 +310,14 @@ func (ds *DataCoordFactory) GetSegmentInfo(ctx context.Context, 
req *datapb.GetS segmentInfos = append(segmentInfos, segInfo) } else { segmentInfos = append(segmentInfos, &datapb.SegmentInfo{ - ID: segmentID, + ID: segmentID, + CollectionID: 1, }) } } return &datapb.GetSegmentInfoResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, - Infos: segmentInfos, + Status: merr.Success(), + Infos: segmentInfos, }, nil } @@ -546,10 +485,9 @@ func NewDataFactory() *DataFactory { func GenRowData() (rawData []byte) { const DIM = 2 - const N = 1 // Float vector - var fvector = [DIM]float32{1, 2} + fvector := [DIM]float32{1, 2} for _, ele := range fvector { buf := make([]byte, 4) common.Endian.PutUint32(buf, math.Float32bits(ele)) @@ -559,11 +497,11 @@ func GenRowData() (rawData []byte) { // Binary vector // Dimension of binary vector is 32 // size := 4, = 32 / 8 - var bvector = []byte{255, 255, 255, 0} + bvector := []byte{255, 255, 255, 0} rawData = append(rawData, bvector...) // Bool - var fieldBool = true + fieldBool := true buf := new(bytes.Buffer) if err := binary.Write(buf, common.Endian, fieldBool); err != nil { panic(err) @@ -612,7 +550,7 @@ func GenRowData() (rawData []byte) { rawData = append(rawData, bfloat32.Bytes()...) 
// float64 - var datafloat64 = 2.2 + datafloat64 := 2.2 bfloat64 := new(bytes.Buffer) if err := binary.Write(bfloat64, common.Endian, datafloat64); err != nil { panic(err) @@ -624,7 +562,7 @@ func GenRowData() (rawData []byte) { func GenColumnData() (fieldsData []*schemapb.FieldData) { // Float vector - var fVector = []float32{1, 2} + fVector := []float32{1, 2} floatVectorData := &schemapb.FieldData{ Type: schemapb.DataType_FloatVector, FieldName: "float_vector_field", @@ -769,7 +707,7 @@ func GenColumnData() (fieldsData []*schemapb.FieldData) { } fieldsData = append(fieldsData, floatFieldData) - //double + // double doubleData := []float64{2.2} doubleFieldData := &schemapb.FieldData{ Type: schemapb.DataType_Double, @@ -787,7 +725,7 @@ func GenColumnData() (fieldsData []*schemapb.FieldData) { } fieldsData = append(fieldsData, doubleFieldData) - //var char + // var char varCharData := []string{"test"} varCharFieldData := &schemapb.FieldData{ Type: schemapb.DataType_VarChar, @@ -809,7 +747,7 @@ func GenColumnData() (fieldsData []*schemapb.FieldData) { } func (df *DataFactory) GenMsgStreamInsertMsg(idx int, chanName string) *msgstream.InsertMsg { - var msg = &msgstream.InsertMsg{ + msg := &msgstream.InsertMsg{ BaseMsg: msgstream.BaseMsg{ HashValues: []uint32{uint32(idx)}, }, @@ -837,7 +775,7 @@ func (df *DataFactory) GenMsgStreamInsertMsg(idx int, chanName string) *msgstrea } func (df *DataFactory) GenMsgStreamInsertMsgWithTs(idx int, chanName string, ts Timestamp) *msgstream.InsertMsg { - var msg = &msgstream.InsertMsg{ + msg := &msgstream.InsertMsg{ BaseMsg: msgstream.BaseMsg{ HashValues: []uint32{uint32(idx)}, BeginTimestamp: ts, @@ -868,7 +806,7 @@ func (df *DataFactory) GenMsgStreamInsertMsgWithTs(idx int, chanName string, ts func (df *DataFactory) GetMsgStreamTsInsertMsgs(n int, chanName string, ts Timestamp) (inMsgs []msgstream.TsMsg) { for i := 0; i < n; i++ { - var msg = df.GenMsgStreamInsertMsgWithTs(i, chanName, ts) + msg := 
df.GenMsgStreamInsertMsgWithTs(i, chanName, ts) var tsMsg msgstream.TsMsg = msg inMsgs = append(inMsgs, tsMsg) } @@ -877,7 +815,7 @@ func (df *DataFactory) GetMsgStreamTsInsertMsgs(n int, chanName string, ts Times func (df *DataFactory) GetMsgStreamInsertMsgs(n int) (msgs []*msgstream.InsertMsg) { for i := 0; i < n; i++ { - var msg = df.GenMsgStreamInsertMsg(i, "") + msg := df.GenMsgStreamInsertMsg(i, "") msgs = append(msgs, msg) } return @@ -889,7 +827,7 @@ func (df *DataFactory) GenMsgStreamDeleteMsg(pks []primaryKey, chanName string) for i := 0; i < len(pks); i++ { timestamps[i] = Timestamp(i) + 1000 } - var msg = &msgstream.DeleteMsg{ + msg := &msgstream.DeleteMsg{ BaseMsg: msgstream.BaseMsg{ HashValues: []uint32{uint32(idx)}, }, @@ -904,7 +842,7 @@ func (df *DataFactory) GenMsgStreamDeleteMsg(pks []primaryKey, chanName string) PartitionName: "default", PartitionID: 1, ShardName: chanName, - PrimaryKeys: s.ParsePrimaryKeys2IDs(pks), + PrimaryKeys: storage.ParsePrimaryKeys2IDs(pks), Timestamps: timestamps, NumRows: int64(len(pks)), }, @@ -913,7 +851,7 @@ func (df *DataFactory) GenMsgStreamDeleteMsg(pks []primaryKey, chanName string) } func (df *DataFactory) GenMsgStreamDeleteMsgWithTs(idx int, pks []primaryKey, chanName string, ts Timestamp) *msgstream.DeleteMsg { - var msg = &msgstream.DeleteMsg{ + msg := &msgstream.DeleteMsg{ BaseMsg: msgstream.BaseMsg{ HashValues: []uint32{uint32(idx)}, BeginTimestamp: ts, @@ -931,7 +869,7 @@ func (df *DataFactory) GenMsgStreamDeleteMsgWithTs(idx int, pks []primaryKey, ch PartitionID: 1, CollectionID: UniqueID(0), ShardName: chanName, - PrimaryKeys: s.ParsePrimaryKeys2IDs(pks), + PrimaryKeys: storage.ParsePrimaryKeys2IDs(pks), Timestamps: []Timestamp{ts}, NumRows: int64(len(pks)), }, @@ -953,7 +891,7 @@ func genFlowGraphInsertMsg(chanName string) flowGraphMsg { }, } - var fgMsg = &flowGraphMsg{ + fgMsg := &flowGraphMsg{ insertMessages: make([]*msgstream.InsertMsg, 0), timeRange: TimeRange{ timestampMin: 
timeRange.timestampMin, @@ -983,7 +921,7 @@ func genFlowGraphDeleteMsg(pks []primaryKey, chanName string) flowGraphMsg { }, } - var fgMsg = &flowGraphMsg{ + fgMsg := &flowGraphMsg{ insertMessages: make([]*msgstream.InsertMsg, 0), timeRange: TimeRange{ timestampMin: timeRange.timestampMin, @@ -999,12 +937,6 @@ func genFlowGraphDeleteMsg(pks []primaryKey, chanName string) flowGraphMsg { return *fgMsg } -// If id == 0, AllocID will return not successful status -// If id == -1, AllocID will return err -func (m *RootCoordFactory) setID(id UniqueID) { - m.ID = id // GOOSE TODO: random ID generator -} - func (m *RootCoordFactory) setCollectionID(id UniqueID) { m.collectionID = id } @@ -1013,11 +945,12 @@ func (m *RootCoordFactory) setCollectionName(name string) { m.collectionName = name } -func (m *RootCoordFactory) AllocID(ctx context.Context, in *rootcoordpb.AllocIDRequest) (*rootcoordpb.AllocIDResponse, error) { +func (m *RootCoordFactory) AllocID(ctx context.Context, in *rootcoordpb.AllocIDRequest, opts ...grpc.CallOption) (*rootcoordpb.AllocIDResponse, error) { resp := &rootcoordpb.AllocIDResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, - }} + }, + } if in.Count == 12 { resp.Status.ErrorCode = commonpb.ErrorCode_Success @@ -1031,7 +964,7 @@ func (m *RootCoordFactory) AllocID(ctx context.Context, in *rootcoordpb.AllocIDR } if m.ID == -1 { - return nil, errors.New(resp.Status.GetReason()) + return nil, merr.Error(resp.Status) } resp.ID = m.ID @@ -1040,7 +973,7 @@ func (m *RootCoordFactory) AllocID(ctx context.Context, in *rootcoordpb.AllocIDR return resp, nil } -func (m *RootCoordFactory) AllocTimestamp(ctx context.Context, in *rootcoordpb.AllocTimestampRequest) (*rootcoordpb.AllocTimestampResponse, error) { +func (m *RootCoordFactory) AllocTimestamp(ctx context.Context, in *rootcoordpb.AllocTimestampRequest, opts ...grpc.CallOption) (*rootcoordpb.AllocTimestampResponse, error) { resp := &rootcoordpb.AllocTimestampResponse{ Status: 
&commonpb.Status{}, Timestamp: 1000, @@ -1055,16 +988,15 @@ func (m *RootCoordFactory) AllocTimestamp(ctx context.Context, in *rootcoordpb.A return resp, nil } -func (m *RootCoordFactory) ShowCollections(ctx context.Context, in *milvuspb.ShowCollectionsRequest) (*milvuspb.ShowCollectionsResponse, error) { +func (m *RootCoordFactory) ShowCollections(ctx context.Context, in *milvuspb.ShowCollectionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowCollectionsResponse, error) { resp := &milvuspb.ShowCollectionsResponse{ Status: &commonpb.Status{}, CollectionNames: []string{m.collectionName}, } return resp, nil - } -func (m *RootCoordFactory) DescribeCollectionInternal(ctx context.Context, in *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { +func (m *RootCoordFactory) DescribeCollectionInternal(ctx context.Context, in *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) { f := MetaFactory{} meta := f.GetCollectionMeta(m.collectionID, m.collectionName, m.pkType) resp := &milvuspb.DescribeCollectionResponse{ @@ -1079,8 +1011,7 @@ func (m *RootCoordFactory) DescribeCollectionInternal(ctx context.Context, in *m } if m.collectionID == -1 { - resp.Status.ErrorCode = commonpb.ErrorCode_Success - return resp, errors.New(resp.Status.GetReason()) + return nil, merr.Error(resp.Status) } resp.CollectionID = m.collectionID @@ -1090,12 +1021,10 @@ func (m *RootCoordFactory) DescribeCollectionInternal(ctx context.Context, in *m return resp, nil } -func (m *RootCoordFactory) ShowPartitions(ctx context.Context, req *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) { +func (m *RootCoordFactory) ShowPartitions(ctx context.Context, req *milvuspb.ShowPartitionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowPartitionsResponse, error) { if m.ShowPartitionsErr { return &milvuspb.ShowPartitionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + 
Status: merr.Success(), }, fmt.Errorf("mock show partitions error") } @@ -1109,43 +1038,35 @@ func (m *RootCoordFactory) ShowPartitions(ctx context.Context, req *milvuspb.Sho } return &milvuspb.ShowPartitionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), PartitionNames: m.ShowPartitionsNames, PartitionIDs: m.ShowPartitionsIDs, }, nil } -func (m *RootCoordFactory) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { +func (m *RootCoordFactory) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { return &milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{}, SubcomponentStates: make([]*milvuspb.ComponentInfo, 0), - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, nil } -func (m *RootCoordFactory) ReportImport(ctx context.Context, req *rootcoordpb.ImportResult) (*commonpb.Status, error) { +func (m *RootCoordFactory) ReportImport(ctx context.Context, req *rootcoordpb.ImportResult, opts ...grpc.CallOption) (*commonpb.Status, error) { if ctx != nil && ctx.Value(ctxKey{}) != nil { if v := ctx.Value(ctxKey{}).(string); v == returnError { return nil, fmt.Errorf("injected error") } } if m.ReportImportErr { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, fmt.Errorf("mock report import error") + return merr.Success(), fmt.Errorf("mock report import error") } if m.ReportImportNotSuccess { return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, }, nil } - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil + return merr.Success(), nil } // FailMessageStreamFactory mock MessageStreamFactory failure @@ -1169,15 +1090,15 @@ func genInsertDataWithPKs(PKs [2]primaryKey, dataType schemapb.DataType) *Insert for index, pk := range PKs { values[index] = pk.(*int64PrimaryKey).Value } - 
iD.Data[106].(*s.Int64FieldData).Data = values + iD.Data[106].(*storage.Int64FieldData).Data = values case schemapb.DataType_VarChar: values := make([]string, len(PKs)) for index, pk := range PKs { values[index] = pk.(*varCharPrimaryKey).Value } - iD.Data[109].(*s.StringFieldData).Data = values + iD.Data[109].(*storage.StringFieldData).Data = values default: - //TODO:: + // TODO:: } return iD } @@ -1195,134 +1116,137 @@ func genTestStat(meta *etcdpb.CollectionMeta) *storage.PrimaryKeyStats { func genInsertData() *InsertData { return &InsertData{ - Data: map[int64]s.FieldData{ - 0: &s.Int64FieldData{ + Data: map[int64]storage.FieldData{ + 0: &storage.Int64FieldData{ Data: []int64{1, 2}, }, - 1: &s.Int64FieldData{ + 1: &storage.Int64FieldData{ Data: []int64{3, 4}, }, - 100: &s.FloatVectorFieldData{ + 100: &storage.FloatVectorFieldData{ Data: []float32{1.0, 6.0, 7.0, 8.0}, Dim: 2, }, - 101: &s.BinaryVectorFieldData{ + 101: &storage.BinaryVectorFieldData{ Data: []byte{0, 255, 255, 255, 128, 128, 128, 0}, Dim: 32, }, - 102: &s.BoolFieldData{ + 102: &storage.BoolFieldData{ Data: []bool{true, false}, }, - 103: &s.Int8FieldData{ + 103: &storage.Int8FieldData{ Data: []int8{5, 6}, }, - 104: &s.Int16FieldData{ + 104: &storage.Int16FieldData{ Data: []int16{7, 8}, }, - 105: &s.Int32FieldData{ + 105: &storage.Int32FieldData{ Data: []int32{9, 10}, }, - 106: &s.Int64FieldData{ + 106: &storage.Int64FieldData{ Data: []int64{1, 2}, }, - 107: &s.FloatFieldData{ + 107: &storage.FloatFieldData{ Data: []float32{2.333, 2.334}, }, - 108: &s.DoubleFieldData{ + 108: &storage.DoubleFieldData{ Data: []float64{3.333, 3.334}, }, - 109: &s.StringFieldData{ + 109: &storage.StringFieldData{ Data: []string{"test1", "test2"}, }, - }} + }, + } } func genEmptyInsertData() *InsertData { return &InsertData{ - Data: map[int64]s.FieldData{ - 0: &s.Int64FieldData{ + Data: map[int64]storage.FieldData{ + 0: &storage.Int64FieldData{ Data: []int64{}, }, - 1: &s.Int64FieldData{ + 1: &storage.Int64FieldData{ 
Data: []int64{}, }, - 100: &s.FloatVectorFieldData{ + 100: &storage.FloatVectorFieldData{ Data: []float32{}, Dim: 2, }, - 101: &s.BinaryVectorFieldData{ + 101: &storage.BinaryVectorFieldData{ Data: []byte{}, Dim: 32, }, - 102: &s.BoolFieldData{ + 102: &storage.BoolFieldData{ Data: []bool{}, }, - 103: &s.Int8FieldData{ + 103: &storage.Int8FieldData{ Data: []int8{}, }, - 104: &s.Int16FieldData{ + 104: &storage.Int16FieldData{ Data: []int16{}, }, - 105: &s.Int32FieldData{ + 105: &storage.Int32FieldData{ Data: []int32{}, }, - 106: &s.Int64FieldData{ + 106: &storage.Int64FieldData{ Data: []int64{}, }, - 107: &s.FloatFieldData{ + 107: &storage.FloatFieldData{ Data: []float32{}, }, - 108: &s.DoubleFieldData{ + 108: &storage.DoubleFieldData{ Data: []float64{}, }, - 109: &s.StringFieldData{ + 109: &storage.StringFieldData{ Data: []string{}, }, - }} + }, + } } func genInsertDataWithExpiredTS() *InsertData { return &InsertData{ - Data: map[int64]s.FieldData{ - 0: &s.Int64FieldData{ + Data: map[int64]storage.FieldData{ + 0: &storage.Int64FieldData{ Data: []int64{11, 22}, }, - 1: &s.Int64FieldData{ + 1: &storage.Int64FieldData{ Data: []int64{329749364736000000, 329500223078400000}, // 2009-11-10 23:00:00 +0000 UTC, 2009-10-31 23:00:00 +0000 UTC }, - 100: &s.FloatVectorFieldData{ + 100: &storage.FloatVectorFieldData{ Data: []float32{1.0, 6.0, 7.0, 8.0}, Dim: 2, }, - 101: &s.BinaryVectorFieldData{ + 101: &storage.BinaryVectorFieldData{ Data: []byte{0, 255, 255, 255, 128, 128, 128, 0}, Dim: 32, }, - 102: &s.BoolFieldData{ + 102: &storage.BoolFieldData{ Data: []bool{true, false}, }, - 103: &s.Int8FieldData{ + 103: &storage.Int8FieldData{ Data: []int8{5, 6}, }, - 104: &s.Int16FieldData{ + 104: &storage.Int16FieldData{ Data: []int16{7, 8}, }, - 105: &s.Int32FieldData{ + 105: &storage.Int32FieldData{ Data: []int32{9, 10}, }, - 106: &s.Int64FieldData{ + 106: &storage.Int64FieldData{ Data: []int64{1, 2}, }, - 107: &s.FloatFieldData{ + 107: &storage.FloatFieldData{ Data: []float32{2.333, 
2.334}, }, - 108: &s.DoubleFieldData{ + 108: &storage.DoubleFieldData{ Data: []float64{3.333, 3.334}, }, - 109: &s.StringFieldData{ + 109: &storage.StringFieldData{ Data: []string{"test1", "test2"}, }, - }} + }, + } } func genTimestamp() typeutil.Timestamp { @@ -1331,6 +1255,6 @@ func genTimestamp() typeutil.Timestamp { return tsoutil.ComposeTSByTime(gb, 0) } -func genTestTickler() *tickler { - return newTickler(0, "", nil, nil, 0) +func genTestTickler() *etcdTickler { + return newEtcdTickler(0, "", nil, nil, 0) } diff --git a/internal/datanode/rate_collector.go b/internal/datanode/rate_collector.go index 96d5e59150662..b7052c3bb1a37 100644 --- a/internal/datanode/rate_collector.go +++ b/internal/datanode/rate_collector.go @@ -24,8 +24,10 @@ import ( ) // rateCol is global rateCollector in DataNode. -var rateCol *rateCollector -var initOnce sync.Once +var ( + rateCol *rateCollector + initOnce sync.Once +) // rateCollector helps to collect and calculate values (like rate, timeTick and etc...). type rateCollector struct { diff --git a/internal/datanode/segment.go b/internal/datanode/segment.go index f0432c7a6c60e..57fbf2820e28d 100644 --- a/internal/datanode/segment.go +++ b/internal/datanode/segment.go @@ -24,7 +24,6 @@ import ( "sync/atomic" "github.com/bits-and-blooms/bloom/v3" - "github.com/milvus-io/milvus/pkg/util/tsoutil" "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" @@ -34,6 +33,7 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/tsoutil" ) // Segment contains the latest segment infos from channel. 
diff --git a/internal/datanode/segment_sync_policy.go b/internal/datanode/segment_sync_policy.go index a4ef4e177bf98..72498efb4b6cf 100644 --- a/internal/datanode/segment_sync_policy.go +++ b/internal/datanode/segment_sync_policy.go @@ -21,20 +21,21 @@ import ( "sort" "time" + "github.com/samber/lo" + "go.uber.org/zap" + "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/tsoutil" - "go.uber.org/atomic" - "go.uber.org/zap" ) const minSyncSize = 0.5 * 1024 * 1024 // segmentsSyncPolicy sync policy applies to segments -type segmentSyncPolicy func(segments []*Segment, ts Timestamp, needToSync *atomic.Bool) []UniqueID +type segmentSyncPolicy func(segments []*Segment, c Channel, ts Timestamp) []UniqueID // syncPeriodically get segmentSyncPolicy with segments sync periodically. func syncPeriodically() segmentSyncPolicy { - return func(segments []*Segment, ts Timestamp, _ *atomic.Bool) []UniqueID { + return func(segments []*Segment, c Channel, ts Timestamp) []UniqueID { segmentsToSync := make([]UniqueID, 0) for _, seg := range segments { endPosTime := tsoutil.PhysicalTime(ts) @@ -45,7 +46,7 @@ func syncPeriodically() segmentSyncPolicy { } } if len(segmentsToSync) > 0 { - log.Info("sync segment periodically", zap.Int64s("segmentID", segmentsToSync)) + log.Info("sync segment periodically", zap.Int64s("segmentIDs", segmentsToSync)) } return segmentsToSync } @@ -53,8 +54,8 @@ func syncPeriodically() segmentSyncPolicy { // syncMemoryTooHigh force sync the largest segment. 
func syncMemoryTooHigh() segmentSyncPolicy { - return func(segments []*Segment, ts Timestamp, needToSync *atomic.Bool) []UniqueID { - if len(segments) == 0 || !needToSync.Load() { + return func(segments []*Segment, c Channel, _ Timestamp) []UniqueID { + if len(segments) == 0 || !c.getIsHighMemory() { return nil } sort.Slice(segments, func(i, j int) bool { @@ -74,3 +75,22 @@ func syncMemoryTooHigh() segmentSyncPolicy { return syncSegments } } + +// syncSegmentsAtTs returns a new segmentSyncPolicy, sync segments when ts exceeds ChannelMeta.flushTs +func syncSegmentsAtTs() segmentSyncPolicy { + return func(segments []*Segment, c Channel, ts Timestamp) []UniqueID { + flushTs := c.getFlushTs() + if flushTs != 0 && ts >= flushTs { + segmentsWithBuffer := lo.Filter(segments, func(segment *Segment, _ int) bool { + return !segment.isBufferEmpty() + }) + segmentIDs := lo.Map(segmentsWithBuffer, func(segment *Segment, _ int) UniqueID { + return segment.segmentID + }) + log.Info("sync segment at ts", zap.Int64s("segmentIDs", segmentIDs), + zap.Time("ts", tsoutil.PhysicalTime(ts)), zap.Time("flushTs", tsoutil.PhysicalTime(flushTs))) + return segmentIDs + } + return nil + } +} diff --git a/internal/datanode/segment_sync_policy_test.go b/internal/datanode/segment_sync_policy_test.go index e54eca3a117e7..bc6a14533d66c 100644 --- a/internal/datanode/segment_sync_policy_test.go +++ b/internal/datanode/segment_sync_policy_test.go @@ -18,11 +18,11 @@ package datanode import ( "fmt" + "math" "testing" "time" "github.com/stretchr/testify/assert" - "go.uber.org/atomic" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/pkg/util/tsoutil" @@ -55,7 +55,7 @@ func TestSyncPeriodically(t *testing.T) { if test.isBufferEmpty { segment.curInsertBuf = nil } - res := policy([]*Segment{segment}, tsoutil.ComposeTSByTime(test.endPosTs, 0), nil) + res := policy([]*Segment{segment}, nil, tsoutil.ComposeTSByTime(test.endPosTs, 0)) assert.Equal(t, test.shouldSyncNum, 
len(res)) }) } @@ -65,26 +65,46 @@ func TestSyncMemoryTooHigh(t *testing.T) { tests := []struct { testName string syncSegmentNum int - needToSync bool + isHighMemory bool memorySizesInMB []float64 shouldSyncSegs []UniqueID }{ - {"test normal 1", 3, true, - []float64{1, 2, 3, 4, 5}, []UniqueID{5, 4, 3}}, - {"test normal 2", 2, true, - []float64{1, 2, 3, 4, 5}, []UniqueID{5, 4}}, - {"test normal 3", 5, true, - []float64{1, 2, 3, 4, 5}, []UniqueID{5, 4, 3, 2, 1}}, - {"test needToSync false", 3, false, - []float64{1, 2, 3, 4, 5}, []UniqueID{}}, - {"test syncSegmentNum 1", 1, true, - []float64{1, 2, 3, 4, 5}, []UniqueID{5}}, - {"test with small segment", 3, true, - []float64{0.1, 0.1, 0.1, 4, 5}, []UniqueID{5, 4}}, + { + "test normal 1", 3, true, + []float64{1, 2, 3, 4, 5}, + []UniqueID{5, 4, 3}, + }, + { + "test normal 2", 2, true, + []float64{1, 2, 3, 4, 5}, + []UniqueID{5, 4}, + }, + { + "test normal 3", 5, true, + []float64{1, 2, 3, 4, 5}, + []UniqueID{5, 4, 3, 2, 1}, + }, + { + "test isHighMemory false", 3, false, + []float64{1, 2, 3, 4, 5}, + []UniqueID{}, + }, + { + "test syncSegmentNum 1", 1, true, + []float64{1, 2, 3, 4, 5}, + []UniqueID{5}, + }, + { + "test with small segment", 3, true, + []float64{0.1, 0.1, 0.1, 4, 5}, + []UniqueID{5, 4}, + }, } for _, test := range tests { t.Run(test.testName, func(t *testing.T) { + channel := newChannel("channel", 0, nil, nil, nil) + channel.setIsHighMemory(test.isHighMemory) Params.Save(Params.DataNodeCfg.MemoryForceSyncSegmentNum.Key, fmt.Sprintf("%d", test.syncSegmentNum)) policy := syncMemoryTooHigh() segments := make([]*Segment, len(test.memorySizesInMB)) @@ -93,8 +113,39 @@ func TestSyncMemoryTooHigh(t *testing.T) { segmentID: UniqueID(i + 1), memorySize: int64(test.memorySizesInMB[i] * 1024 * 1024), } } - segs := policy(segments, 0, atomic.NewBool(test.needToSync)) + segs := policy(segments, channel, 0) assert.ElementsMatch(t, segs, test.shouldSyncSegs) }) } } + +func TestSyncSegmentsAtTs(t *testing.T) { + tests := 
[]struct { + testName string + ts Timestamp + flushTs Timestamp + shouldSyncNum int + }{ + {"test ts < flushTs", 100, 200, 0}, + {"test ts > flushTs", 300, 200, 1}, + {"test ts = flushTs", 100, 100, 1}, + {"test flushTs = 0", 100, 0, 0}, + {"test flushTs = maxUint64", 100, math.MaxUint64, 0}, + } + + for _, test := range tests { + t.Run(test.testName, func(t *testing.T) { + channel := newChannel("channel", 0, nil, nil, nil) + channel.setFlushTs(test.flushTs) + + segment := &Segment{} + segment.setInsertBuffer(&BufferData{ + startPos: &msgpb.MsgPosition{}, + }) + + policy := syncSegmentsAtTs() + res := policy([]*Segment{segment}, channel, test.ts) + assert.Equal(t, test.shouldSyncNum, len(res)) + }) + } +} diff --git a/internal/datanode/segment_test.go b/internal/datanode/segment_test.go index b21130b2f32b6..fef988d7712ca 100644 --- a/internal/datanode/segment_test.go +++ b/internal/datanode/segment_test.go @@ -20,9 +20,10 @@ import ( "math/rand" "testing" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/pkg/common" - "github.com/stretchr/testify/assert" ) func TestSegment_UpdatePKRange(t *testing.T) { diff --git a/internal/datanode/services.go b/internal/datanode/services.go index c70eb207efe41..591d6b026c106 100644 --- a/internal/datanode/services.go +++ b/internal/datanode/services.go @@ -27,7 +27,6 @@ import ( "time" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/samber/lo" "go.uber.org/zap" @@ -44,13 +43,16 @@ import ( "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" "github.com/milvus-io/milvus/pkg/util/commonpbutil" + "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metautil" "github.com/milvus-io/milvus/pkg/util/metricsinfo" 
"github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/retry" "github.com/milvus-io/milvus/pkg/util/timerecord" + "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -59,14 +61,11 @@ func (node *DataNode) WatchDmChannels(ctx context.Context, in *datapb.WatchDmCha log.Warn("DataNode WatchDmChannels is not in use") // TODO ERROR OF GRPC NOT IN USE - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "watchDmChannels do nothing", - }, nil + return merr.Success(), nil } // GetComponentStates will return current state of DataNode -func (node *DataNode) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { +func (node *DataNode) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) { log.Debug("DataNode current state", zap.Any("State", node.stateCode.Load())) nodeID := common.NotRegisteredID if node.GetSession() != nil && node.session.Registered() { @@ -80,7 +79,7 @@ func (node *DataNode) GetComponentStates(ctx context.Context) (*milvuspb.Compone StateCode: node.stateCode.Load().(commonpb.StateCode), }, SubcomponentStates: make([]*milvuspb.ComponentInfo, 0), - Status: merr.Status(nil), + Status: merr.Success(), } return states, nil } @@ -95,8 +94,7 @@ func (node *DataNode) FlushSegments(ctx context.Context, req *datapb.FlushSegmen fmt.Sprint(paramtable.GetNodeID()), metrics.TotalLabel).Inc() - if !node.isHealthy() { - err := merr.WrapErrServiceNotReady(node.GetStateCode().String()) + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { log.Warn("DataNode.FlushSegments failed", zap.Int64("nodeId", paramtable.GetNodeID()), zap.Error(err)) return merr.Status(err), nil @@ -161,42 +159,37 @@ func (node *DataNode) FlushSegments(ctx context.Context, req *datapb.FlushSegmen metrics.DataNodeFlushReqCounter.WithLabelValues( fmt.Sprint(paramtable.GetNodeID()), 
metrics.SuccessLabel).Inc() - return merr.Status(nil), nil + return merr.Success(), nil } -// ResendSegmentStats resend un-flushed segment stats back upstream to DataCoord by resending DataNode time tick message. +// ResendSegmentStats . ResendSegmentStats resend un-flushed segment stats back upstream to DataCoord by resending DataNode time tick message. // It returns a list of segments to be sent. +// Deprecated in 2.3.2, reversed it just for compatibility during rolling back func (node *DataNode) ResendSegmentStats(ctx context.Context, req *datapb.ResendSegmentStatsRequest) (*datapb.ResendSegmentStatsResponse, error) { - log.Info("start resending segment stats, if any", - zap.Int64("DataNode ID", paramtable.GetNodeID())) - segResent := node.flowgraphManager.resendTT() - log.Info("found segment(s) with stats to resend", - zap.Int64s("segment IDs", segResent)) return &datapb.ResendSegmentStatsResponse{ - Status: merr.Status(nil), - SegResent: segResent, + Status: merr.Success(), + SegResent: make([]int64, 0), }, nil } // GetTimeTickChannel currently do nothing -func (node *DataNode) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (node *DataNode) GetTimeTickChannel(ctx context.Context, req *internalpb.GetTimeTickChannelRequest) (*milvuspb.StringResponse, error) { return &milvuspb.StringResponse{ - Status: merr.Status(nil), + Status: merr.Success(), }, nil } // GetStatisticsChannel currently do nothing -func (node *DataNode) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (node *DataNode) GetStatisticsChannel(ctx context.Context, req *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error) { return &milvuspb.StringResponse{ - Status: merr.Status(nil), + Status: merr.Success(), }, nil } // ShowConfigurations returns the configurations of DataNode matching req.Pattern func (node *DataNode) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) 
(*internalpb.ShowConfigurationsResponse, error) { log.Debug("DataNode.ShowConfigurations", zap.String("pattern", req.Pattern)) - if !node.isHealthy() { - err := merr.WrapErrServiceNotReady(node.GetStateCode().String()) + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { log.Warn("DataNode.ShowConfigurations failed", zap.Int64("nodeId", paramtable.GetNodeID()), zap.Error(err)) return &internalpb.ShowConfigurationsResponse{ @@ -214,15 +207,14 @@ func (node *DataNode) ShowConfigurations(ctx context.Context, req *internalpb.Sh } return &internalpb.ShowConfigurationsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Configuations: configList, }, nil } // GetMetrics return datanode metrics func (node *DataNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - if !node.isHealthy() { - err := merr.WrapErrServiceNotReady(node.GetStateCode().String()) + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { log.Warn("DataNode.GetMetrics failed", zap.Int64("nodeId", paramtable.GetNodeID()), zap.Error(err)) return &milvuspb.GetMetricsResponse{ @@ -267,8 +259,7 @@ func (node *DataNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRe // Compaction handles compaction request from DataCoord // returns status as long as compaction task enqueued or invalid func (node *DataNode) Compaction(ctx context.Context, req *datapb.CompactionPlan) (*commonpb.Status, error) { - if !node.isHealthy() { - err := merr.WrapErrServiceNotReady(node.GetStateCode().String()) + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { log.Warn("DataNode.Compaction failed", zap.Int64("nodeId", paramtable.GetNodeID()), zap.Error(err)) return merr.Status(err), nil } @@ -297,14 +288,13 @@ func (node *DataNode) Compaction(ctx context.Context, req *datapb.CompactionPlan node.compactionExecutor.execute(task) - return merr.Status(nil), nil + return merr.Success(), nil } // GetCompactionState called by 
DataCoord // return status of all compaction plans func (node *DataNode) GetCompactionState(ctx context.Context, req *datapb.CompactionStateRequest) (*datapb.CompactionStateResponse, error) { - if !node.isHealthy() { - err := merr.WrapErrServiceNotReady(node.GetStateCode().String()) + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { log.Warn("DataNode.GetCompactionState failed", zap.Int64("nodeId", paramtable.GetNodeID()), zap.Error(err)) return &datapb.CompactionStateResponse{ Status: merr.Status(err), @@ -334,7 +324,7 @@ func (node *DataNode) GetCompactionState(ctx context.Context, req *datapb.Compac log.Info("Compaction results", zap.Int64s("planIDs", planIDs)) } return &datapb.CompactionStateResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Results: results, }, nil } @@ -348,8 +338,7 @@ func (node *DataNode) SyncSegments(ctx context.Context, req *datapb.SyncSegments zap.Int64("numOfRows", req.GetNumOfRows()), ) - if !node.isHealthy() { - err := merr.WrapErrServiceNotReady(node.GetStateCode().String()) + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { log.Warn("DataNode.SyncSegments failed", zap.Int64("nodeId", paramtable.GetNodeID()), zap.Error(err)) return merr.Status(err), nil } @@ -362,6 +351,8 @@ func (node *DataNode) SyncSegments(ctx context.Context, req *datapb.SyncSegments oneSegment int64 channel Channel err error + ds *dataSyncService + ok bool ) for _, fromSegment := range req.GetCompactedFrom() { @@ -370,12 +361,17 @@ func (node *DataNode) SyncSegments(ctx context.Context, req *datapb.SyncSegments log.Ctx(ctx).Warn("fail to get the channel", zap.Int64("segment", fromSegment), zap.Error(err)) continue } + ds, ok = node.flowgraphManager.getFlowgraphService(channel.getChannelName(fromSegment)) + if !ok { + log.Ctx(ctx).Warn("fail to find flow graph service", zap.Int64("segment", fromSegment)) + continue + } oneSegment = fromSegment break } if oneSegment == 0 { log.Ctx(ctx).Warn("no valid segment, maybe the 
request is a retry") - return merr.Status(nil), nil + return merr.Success(), nil } // oneSegment is definitely in the channel, guaranteed by the check before. @@ -392,11 +388,23 @@ func (node *DataNode) SyncSegments(ctx context.Context, req *datapb.SyncSegments return merr.Status(err), nil } - if err := channel.mergeFlushedSegments(ctx, targetSeg, req.GetPlanID(), req.GetCompactedFrom()); err != nil { - return merr.Status(err), nil - } + ds.fg.Blockall() + defer ds.fg.Unblock() + channel.mergeFlushedSegments(ctx, targetSeg, req.GetPlanID(), req.GetCompactedFrom()) node.compactionExecutor.injectDone(req.GetPlanID(), true) - return merr.Status(nil), nil + return merr.Success(), nil +} + +func (node *DataNode) NotifyChannelOperation(ctx context.Context, req *datapb.ChannelOperationsRequest) (*commonpb.Status, error) { + log.Warn("DataNode NotifyChannelOperation is unimplemented") + return merr.Status(merr.ErrServiceUnavailable), nil +} + +func (node *DataNode) CheckChannelOperationProgress(ctx context.Context, req *datapb.ChannelWatchInfo) (*datapb.ChannelOperationProgressResponse, error) { + log.Warn("DataNode CheckChannelOperationProgress is unimplemented") + return &datapb.ChannelOperationProgressResponse{ + Status: merr.Status(merr.ErrServiceUnavailable), + }, nil } // Import data files(json, numpy, etc.) on MinIO/S3 storage, read and parse them into sealed segments @@ -416,7 +424,7 @@ func (node *DataNode) Import(ctx context.Context, req *datapb.ImportTaskRequest) }() importResult := &rootcoordpb.ImportResult{ - Status: merr.Status(nil), + Status: merr.Success(), TaskId: req.GetImportTask().TaskId, DatanodeId: paramtable.GetNodeID(), State: commonpb.ImportState_ImportStarted, @@ -433,22 +441,18 @@ func (node *DataNode) Import(ctx context.Context, req *datapb.ImportTaskRequest) // function to report import state to RootCoord. 
// retry 10 times, if the rootcoord is down, the report function will cost 20+ seconds reportFunc := reportImportFunc(node) - returnFailFunc := func(msg string, inputErr error) (*commonpb.Status, error) { - logFields = append(logFields, zap.Error(inputErr)) + returnFailFunc := func(msg string, err error) (*commonpb.Status, error) { + logFields = append(logFields, zap.Error(err)) log.Warn(msg, logFields...) importResult.State = commonpb.ImportState_ImportFailed - importResult.Infos = append(importResult.Infos, &commonpb.KeyValuePair{Key: importutil.FailedReason, Value: inputErr.Error()}) + importResult.Infos = append(importResult.Infos, &commonpb.KeyValuePair{Key: importutil.FailedReason, Value: err.Error()}) reportFunc(importResult) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: inputErr.Error(), - }, nil + return merr.Status(err), nil } - if !node.isHealthy() { - err := merr.WrapErrServiceNotReady(node.GetStateCode().String()) + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { logFields = append(logFields, zap.Error(err)) log.Warn("DataNode import failed, node is not healthy", logFields...) return merr.Status(err), nil @@ -456,38 +460,36 @@ func (node *DataNode) Import(ctx context.Context, req *datapb.ImportTaskRequest) // get a timestamp for all the rows // Ignore cancellation from parent context. 
- rep, err := node.rootCoord.AllocTimestamp(newCtx, &rootcoordpb.AllocTimestampRequest{ - Base: commonpbutil.NewMsgBase( - commonpbutil.WithMsgType(commonpb.MsgType_RequestTSO), - commonpbutil.WithMsgID(0), - commonpbutil.WithTimeStamp(0), - commonpbutil.WithSourceID(paramtable.GetNodeID()), - ), - Count: 1, - }) - - if rep.Status.ErrorCode != commonpb.ErrorCode_Success || err != nil { + ts, _, err := node.broker.AllocTimestamp(newCtx, 1) + if err != nil { return returnFailFunc("DataNode alloc ts failed", err) } - ts := rep.GetTimestamp() - // get collection schema and shard number - metaService := newMetaService(node.rootCoord, req.GetImportTask().GetCollectionId()) + metaService := newMetaService(node.broker, req.GetImportTask().GetCollectionId()) colInfo, err := metaService.getCollectionInfo(newCtx, req.GetImportTask().GetCollectionId(), 0) if err != nil { return returnFailFunc("failed to get collection info for collection ID", err) } - // the colInfo doesn't have a collect database name(it is empty). use the database name passed from rootcoord. - partitions, err := node.getPartitions(ctx, req.GetImportTask().GetDatabaseName(), colInfo.GetCollectionName()) - if err != nil { - return returnFailFunc("failed to get partition id list", err) - } - - partitionIDs, err := importutil.DeduceTargetPartitions(partitions, colInfo.GetSchema(), req.GetImportTask().GetPartitionId()) - if err != nil { - return returnFailFunc("failed to decude target partitions", err) + var partitionIDs []int64 + if req.GetImportTask().GetPartitionId() == 0 { + if !typeutil.HasPartitionKey(colInfo.GetSchema()) { + err = errors.New("try auto-distribute data but the collection has no partition key") + return returnFailFunc(err.Error(), err) + } + // TODO: prefer to set partitionIDs in coord instead of get here. + // the colInfo doesn't have a correct database name(it is empty). use the database name passed from rootcoord. 
+ partitions, err := node.broker.ShowPartitions(ctx, req.GetImportTask().GetDatabaseName(), colInfo.GetCollectionName()) + if err != nil { + return returnFailFunc("failed to get partition id list", err) + } + _, partitionIDs, err = typeutil.RearrangePartitionsForPartitionKey(partitions) + if err != nil { + return returnFailFunc("failed to rearrange target partitions", err) + } + } else { + partitionIDs = []int64{req.GetImportTask().GetPartitionId()} } collectionInfo, err := importutil.NewCollectionInfo(colInfo.GetSchema(), colInfo.GetShardsNum(), partitionIDs) @@ -516,49 +518,31 @@ func (node *DataNode) Import(ctx context.Context, req *datapb.ImportTaskRequest) return returnFailFunc("failed to import files", err) } - resp := merr.Status(nil) + resp := merr.Success() return resp, nil } -func (node *DataNode) getPartitions(ctx context.Context, dbName string, collectionName string) (map[string]int64, error) { - req := &milvuspb.ShowPartitionsRequest{ - Base: commonpbutil.NewMsgBase( - commonpbutil.WithMsgType(commonpb.MsgType_ShowPartitions), - ), - DbName: dbName, - CollectionName: collectionName, - } +func (node *DataNode) FlushChannels(ctx context.Context, req *datapb.FlushChannelsRequest) (*commonpb.Status, error) { + log := log.Ctx(ctx).With(zap.Int64("nodeId", paramtable.GetNodeID()), + zap.Time("flushTs", tsoutil.PhysicalTime(req.GetFlushTs())), + zap.Strings("channels", req.GetChannels())) - logFields := []zap.Field{ - zap.String("dbName", dbName), - zap.String("collectionName", collectionName), - } - resp, err := node.rootCoord.ShowPartitions(ctx, req) - if err != nil { - logFields = append(logFields, zap.Error(err)) - log.Warn("failed to get partitions of collection", logFields...) - return nil, err - } - if resp.Status.ErrorCode != commonpb.ErrorCode_Success { - log.Warn("failed to get partitions of collection", logFields...) 
- return nil, errors.New(resp.Status.Reason) - } + log.Info("DataNode receives FlushChannels request") - partitionNames := resp.GetPartitionNames() - partitionIDs := resp.GetPartitionIDs() - if len(partitionNames) != len(partitionIDs) { - logFields = append(logFields, zap.Int("number of names", len(partitionNames)), zap.Int("number of ids", len(partitionIDs))) - log.Warn("partition names and ids are unequal", logFields...) - return nil, fmt.Errorf("partition names and ids are unequal, number of names: %d, number of ids: %d", - len(partitionNames), len(partitionIDs)) + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + log.Warn("DataNode.FlushChannels failed", zap.Error(err)) + return merr.Status(err), nil } - partitions := make(map[string]int64) - for i := 0; i < len(partitionNames); i++ { - partitions[partitionNames[i]] = partitionIDs[i] + for _, channel := range req.GetChannels() { + fg, ok := node.flowgraphManager.getFlowgraphService(channel) + if !ok { + return merr.Status(merr.WrapErrChannelNotFound(channel)), nil + } + fg.channel.setFlushTs(req.GetFlushTs()) } - return partitions, nil + return merr.Success(), nil } // AddImportSegment adds the import segment to the current DataNode. @@ -596,7 +580,7 @@ func (node *DataNode) AddImportSegment(ctx context.Context, req *datapb.AddImpor // Get the current dml channel position ID, that will be used in segments start positions and end positions. var posID []byte err = retry.Do(ctx, func() error { - id, innerError := ds.getChannelLatestMsgID(context.Background(), req.GetChannelName(), req.GetSegmentId()) + id, innerError := node.getChannelLatestMsgID(context.Background(), req.GetChannelName(), req.GetSegmentId()) posID = id return innerError }, retry.Attempts(30)) @@ -616,6 +600,7 @@ func (node *DataNode) AddImportSegment(ctx context.Context, req *datapb.AddImpor // Add segment as a flushed segment, but set `importing` to true to add extra information of the segment. 
// By 'extra information' we mean segment info while adding a `SegmentType_Flushed` typed segment. if err := ds.channel.addSegment( + context.TODO(), addSegmentReq{ segType: datapb.SegmentType_Flushed, segID: req.GetSegmentId(), @@ -649,11 +634,33 @@ func (node *DataNode) AddImportSegment(ctx context.Context, req *datapb.AddImpor } ds.flushingSegCache.Remove(req.GetSegmentId()) return &datapb.AddImportSegmentResponse{ - Status: merr.Status(nil), + Status: merr.Success(), ChannelPos: posID, }, nil } +func (node *DataNode) getChannelLatestMsgID(ctx context.Context, channelName string, segmentID int64) ([]byte, error) { + pChannelName := funcutil.ToPhysicalChannel(channelName) + dmlStream, err := node.factory.NewMsgStream(ctx) + if err != nil { + return nil, err + } + defer dmlStream.Close() + + subName := fmt.Sprintf("datanode-%d-%s-%d", paramtable.GetNodeID(), channelName, segmentID) + log.Debug("dataSyncService register consumer for getChannelLatestMsgID", + zap.String("pChannelName", pChannelName), + zap.String("subscription", subName), + ) + dmlStream.AsConsumer(ctx, []string{pChannelName}, subName, mqwrapper.SubscriptionPositionUnknown) + id, err := dmlStream.GetLatestMsgID(pChannelName) + if err != nil { + log.Error("fail to GetLatestMsgID", zap.String("pChannelName", pChannelName), zap.Error(err)) + return nil, err + } + return id.Serialize(), nil +} + func assignSegmentFunc(node *DataNode, req *datapb.ImportTaskRequest) importutil.AssignSegmentFunc { return func(shardID int, partID int64) (int64, string, error) { chNames := req.GetImportTask().GetChannelNames() @@ -679,18 +686,16 @@ func assignSegmentFunc(node *DataNode, req *datapb.ImportTaskRequest) importutil logFields = append(logFields, zap.Int64("collection ID", colID)) logFields = append(logFields, zap.String("target channel name", targetChName)) log.Info("assign segment for the import task", logFields...) 
- resp, err := node.dataCoord.AssignSegmentID(context.Background(), segmentIDReq) + ids, err := node.broker.AssignSegmentID(context.Background(), segmentIDReq.GetSegmentIDRequests()...) if err != nil { - return 0, "", fmt.Errorf("syncSegmentID Failed:%w", err) - } - if resp.Status.ErrorCode != commonpb.ErrorCode_Success { - return 0, "", fmt.Errorf("syncSegmentID Failed:%s", resp.Status.Reason) + return 0, "", errors.Wrap(err, "failed to AssignSegmentID") } - if len(resp.SegIDAssignments) == 0 || resp.SegIDAssignments[0] == nil { - return 0, "", fmt.Errorf("syncSegmentID Failed: the collection was dropped") + + if len(ids) == 0 { + return 0, "", merr.WrapErrSegmentNotFound(0, "failed to assign segment id") } - segmentID := resp.SegIDAssignments[0].SegID + segmentID := ids[0] logFields = append(logFields, zap.Int64("segmentID", segmentID)) log.Info("new segment assigned", logFields...) @@ -698,7 +703,7 @@ func assignSegmentFunc(node *DataNode, req *datapb.ImportTaskRequest) importutil // ignore the returned error, since even report failed the segments still can be cleaned // retry 10 times, if the rootcoord is down, the report function will cost 20+ seconds importResult := &rootcoordpb.ImportResult{ - Status: merr.Status(nil), + Status: merr.Success(), TaskId: req.GetImportTask().TaskId, DatanodeId: paramtable.GetNodeID(), State: commonpb.ImportState_ImportStarted, @@ -755,7 +760,8 @@ func createBinLogsFunc(node *DataNode, req *datapb.ImportTaskRequest, schema *sc func saveSegmentFunc(node *DataNode, req *datapb.ImportTaskRequest, res *rootcoordpb.ImportResult, ts Timestamp) importutil.SaveSegmentFunc { importTaskID := req.GetImportTask().GetTaskId() return func(fieldsInsert []*datapb.FieldBinlog, fieldsStats []*datapb.FieldBinlog, segmentID int64, - targetChName string, rowCount int64, partID int64) error { + targetChName string, rowCount int64, partID int64, + ) error { logFields := []zap.Field{ zap.Int64("task ID", importTaskID), zap.Int64("partitionID", partID), 
@@ -768,7 +774,7 @@ func saveSegmentFunc(node *DataNode, req *datapb.ImportTaskRequest, res *rootcoo err := retry.Do(context.Background(), func() error { // Ask DataCoord to save binlog path and add segment to the corresponding DataNode flow graph. - resp, err := node.dataCoord.SaveImportSegment(context.Background(), &datapb.SaveImportSegmentRequest{ + err := node.broker.SaveImportSegment(context.Background(), &datapb.SaveImportSegmentRequest{ Base: commonpbutil.NewMsgBase( commonpbutil.WithTimeStamp(ts), // Pass current timestamp downstream. commonpbutil.WithSourceID(paramtable.GetNodeID()), @@ -804,12 +810,10 @@ func saveSegmentFunc(node *DataNode, req *datapb.ImportTaskRequest, res *rootcoo }) // Only retrying when DataCoord is unhealthy or err != nil, otherwise return immediately. if err != nil { - return fmt.Errorf(err.Error()) - } - if resp.ErrorCode != commonpb.ErrorCode_Success && resp.ErrorCode != commonpb.ErrorCode_NotReadyServe { - return retry.Unrecoverable(fmt.Errorf("failed to save import segment, reason = %s", resp.Reason)) - } else if resp.ErrorCode == commonpb.ErrorCode_NotReadyServe { - return fmt.Errorf("failed to save import segment: %s", resp.GetReason()) + if errors.Is(err, merr.ErrServiceNotReady) { + return retry.Unrecoverable(err) + } + return err } return nil }) @@ -825,7 +829,8 @@ func saveSegmentFunc(node *DataNode, req *datapb.ImportTaskRequest, res *rootcoo } func composeAssignSegmentIDRequest(rowNum int, shardID int, chNames []string, - collID int64, partID int64) *datapb.AssignSegmentIDRequest { + collID int64, partID int64, +) *datapb.AssignSegmentIDRequest { // use the first field's row count as segment row count // all the fields row count are same, checked by ImportWrapper // ask DataCoord to alloc a new segment @@ -847,8 +852,8 @@ func composeAssignSegmentIDRequest(rowNum int, shardID int, chNames []string, } func createBinLogs(rowNum int, schema *schemapb.CollectionSchema, ts Timestamp, - fields 
map[storage.FieldID]storage.FieldData, node *DataNode, segmentID, colID, partID UniqueID) ([]*datapb.FieldBinlog, []*datapb.FieldBinlog, error) { - + fields map[storage.FieldID]storage.FieldData, node *DataNode, segmentID, colID, partID UniqueID, +) ([]*datapb.FieldBinlog, []*datapb.FieldBinlog, error) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -860,13 +865,15 @@ func createBinLogs(rowNum int, schema *schemapb.CollectionSchema, ts Timestamp, Data: tsFieldData, } - if status, _ := node.dataCoord.UpdateSegmentStatistics(context.TODO(), &datapb.UpdateSegmentStatisticsRequest{ - Stats: []*commonpb.SegmentStats{{ - SegmentID: segmentID, - NumRows: int64(rowNum), - }}, - }); status.GetErrorCode() != commonpb.ErrorCode_Success { - return nil, nil, fmt.Errorf(status.GetReason()) + if err := node.broker.UpdateSegmentStatistics(context.TODO(), &datapb.UpdateSegmentStatisticsRequest{ + Stats: []*commonpb.SegmentStats{ + { + SegmentID: segmentID, + NumRows: int64(rowNum), + }, + }, + }); err != nil { + return nil, nil, err } data := BufferData{buffer: &InsertData{ @@ -963,15 +970,11 @@ func createBinLogs(rowNum int, schema *schemapb.CollectionSchema, ts Timestamp, func reportImportFunc(node *DataNode) importutil.ReportFunc { return func(importResult *rootcoordpb.ImportResult) error { err := retry.Do(context.Background(), func() error { - status, err := node.rootCoord.ReportImport(context.Background(), importResult) + err := node.broker.ReportImport(context.Background(), importResult) if err != nil { - log.Error("fail to report import state to RootCoord", zap.Error(err)) - return err + log.Error("failed to report import state to RootCoord", zap.Error(err)) } - if status != nil && status.ErrorCode != commonpb.ErrorCode_Success { - return errors.New(status.GetReason()) - } - return nil + return err }, retry.Attempts(node.reportImportRetryTimes)) return err diff --git a/internal/datanode/services_test.go b/internal/datanode/services_test.go index 
6834a6212eaec..283433bc734f3 100644 --- a/internal/datanode/services_test.go +++ b/internal/datanode/services_test.go @@ -23,7 +23,9 @@ import ( "path/filepath" "sync" "testing" + "time" + "github.com/cockroachdb/errors" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" clientv3 "go.etcd.io/etcd/client/v3" @@ -35,6 +37,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" allocator2 "github.com/milvus-io/milvus/internal/allocator" "github.com/milvus-io/milvus/internal/datanode/allocator" + "github.com/milvus-io/milvus/internal/datanode/broker" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/storage" @@ -48,11 +51,13 @@ import ( "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/tsoutil" ) type DataNodeServicesSuite struct { suite.Suite + broker *broker.MockBroker node *DataNode etcdCli *clientv3.Client ctx context.Context @@ -96,6 +101,25 @@ func (s *DataNodeServicesSuite) SetupTest() { }, nil).Maybe() s.node.allocator = alloc + meta := NewMetaFactory().GetCollectionMeta(1, "collection", schemapb.DataType_Int64) + broker := broker.NewMockBroker(s.T()) + broker.EXPECT().GetSegmentInfo(mock.Anything, mock.Anything). + Return([]*datapb.SegmentInfo{}, nil).Maybe() + broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything). 
+ Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Status(nil), + Schema: meta.GetSchema(), + ShardsNum: common.DefaultShardsNum, + }, nil).Maybe() + broker.EXPECT().ReportTimeTick(mock.Anything, mock.Anything).Return(nil).Maybe() + broker.EXPECT().SaveBinlogPaths(mock.Anything, mock.Anything).Return(nil).Maybe() + broker.EXPECT().UpdateChannelCheckpoint(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + broker.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Call.Return(tsoutil.ComposeTSByTime(time.Now(), 0), + func(_ context.Context, num uint32) uint32 { return num }, nil).Maybe() + + s.broker = broker + s.node.broker = broker + err = s.node.Start() s.Require().NoError(err) @@ -104,8 +128,15 @@ func (s *DataNodeServicesSuite) SetupTest() { } func (s *DataNodeServicesSuite) TearDownTest() { - s.node.Stop() - s.node = nil + if s.broker != nil { + s.broker.AssertExpectations(s.T()) + s.broker = nil + } + + if s.node != nil { + s.node.Stop() + s.node = nil + } } func (s *DataNodeServicesSuite) TearDownSuite() { @@ -121,25 +152,25 @@ func (s *DataNodeServicesSuite) TestNotInUseAPIs() { s.Assert().True(merr.Ok(status)) }) s.Run("GetTimeTickChannel", func() { - _, err := s.node.GetTimeTickChannel(s.ctx) + _, err := s.node.GetTimeTickChannel(s.ctx, nil) s.Assert().NoError(err) }) s.Run("GetStatisticsChannel", func() { - _, err := s.node.GetStatisticsChannel(s.ctx) + _, err := s.node.GetStatisticsChannel(s.ctx, nil) s.Assert().NoError(err) }) } func (s *DataNodeServicesSuite) TestGetComponentStates() { - resp, err := s.node.GetComponentStates(s.ctx) + resp, err := s.node.GetComponentStates(s.ctx, nil) s.Assert().NoError(err) s.Assert().True(merr.Ok(resp.GetStatus())) s.Assert().Equal(common.NotRegisteredID, resp.State.NodeID) s.node.SetSession(&sessionutil.Session{}) s.node.session.UpdateRegistered(true) - resp, err = s.node.GetComponentStates(context.Background()) + resp, err = s.node.GetComponentStates(context.Background(), nil) 
s.Assert().NoError(err) s.Assert().True(merr.Ok(resp.GetStatus())) } @@ -190,20 +221,22 @@ func (s *DataNodeServicesSuite) TestFlushSegments() { FlushedSegmentIds: []int64{}, } - err := s.node.flowgraphManager.addAndStart(s.node, vchan, nil, genTestTickler()) + err := s.node.flowgraphManager.addAndStartWithEtcdTickler(s.node, vchan, nil, genTestTickler()) s.Require().NoError(err) fgservice, ok := s.node.flowgraphManager.getFlowgraphService(dmChannelName) s.Require().True(ok) - err = fgservice.channel.addSegment(addSegmentReq{ - segType: datapb.SegmentType_New, - segID: 0, - collID: 1, - partitionID: 1, - startPos: &msgpb.MsgPosition{}, - endPos: &msgpb.MsgPosition{}, - }) + err = fgservice.channel.addSegment( + context.TODO(), + addSegmentReq{ + segType: datapb.SegmentType_New, + segID: 0, + collID: 1, + partitionID: 1, + startPos: &msgpb.MsgPosition{}, + endPos: &msgpb.MsgPosition{}, + }) s.Require().NoError(err) req := &datapb.FlushSegmentsRequest{ @@ -317,9 +350,9 @@ func (s *DataNodeServicesSuite) TestShowConfigurations() { Pattern: pattern, } - //test closed server + // test closed server node := &DataNode{} - node.SetSession(&sessionutil.Session{ServerID: 1}) + node.SetSession(&sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}) node.stateCode.Store(commonpb.StateCode_Abnormal) resp, err := node.ShowConfigurations(s.ctx, req) @@ -336,7 +369,7 @@ func (s *DataNodeServicesSuite) TestShowConfigurations() { func (s *DataNodeServicesSuite) TestGetMetrics() { node := &DataNode{} - node.SetSession(&sessionutil.Session{ServerID: 1}) + node.SetSession(&sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}) node.flowgraphManager = newFlowgraphManager() // server is closed node.stateCode.Store(commonpb.StateCode_Abnormal) @@ -378,9 +411,8 @@ func (s *DataNodeServicesSuite) TestImport() { collectionID: 100, pkType: schemapb.DataType_Int64, } - s.node.reportImportRetryTimes = 1 // save test time cost from 440s to 180s - s.Run("test 
normal", func() { - content := []byte(`{ + + content := []byte(`{ "rows":[ {"bool_field": true, "int8_field": 10, "int16_field": 101, "int32_field": 1001, "int64_field": 10001, "float32_field": 3.14, "float64_field": 1.56, "varChar_field": "hello world", "binary_vector_field": [254, 0, 254, 0], "float_vector_field": [1.1, 1.2]}, {"bool_field": false, "int8_field": 11, "int16_field": 102, "int32_field": 1002, "int64_field": 10002, "float32_field": 3.15, "float64_field": 2.56, "varChar_field": "hello world", "binary_vector_field": [253, 0, 253, 0], "float_vector_field": [2.1, 2.2]}, @@ -389,17 +421,25 @@ func (s *DataNodeServicesSuite) TestImport() { {"bool_field": true, "int8_field": 14, "int16_field": 105, "int32_field": 1005, "int64_field": 10005, "float32_field": 3.18, "float64_field": 5.56, "varChar_field": "hello world", "binary_vector_field": [250, 0, 250, 0], "float_vector_field": [5.1, 5.2]} ] }`) + filePath := filepath.Join(s.node.chunkManager.RootPath(), "rows_1.json") + err := s.node.chunkManager.Write(s.ctx, filePath, content) + s.Require().NoError(err) + s.node.reportImportRetryTimes = 1 // save test time cost from 440s to 180s + s.Run("test normal", func() { + defer func() { + s.TearDownTest() + }() chName1 := "fake-by-dev-rootcoord-dml-testimport-1" chName2 := "fake-by-dev-rootcoord-dml-testimport-2" - err := s.node.flowgraphManager.addAndStart(s.node, &datapb.VchannelInfo{ + err := s.node.flowgraphManager.addAndStartWithEtcdTickler(s.node, &datapb.VchannelInfo{ CollectionID: 100, ChannelName: chName1, UnflushedSegmentIds: []int64{}, FlushedSegmentIds: []int64{}, }, nil, genTestTickler()) s.Require().Nil(err) - err = s.node.flowgraphManager.addAndStart(s.node, &datapb.VchannelInfo{ + err = s.node.flowgraphManager.addAndStartWithEtcdTickler(s.node, &datapb.VchannelInfo{ CollectionID: 100, ChannelName: chName2, UnflushedSegmentIds: []int64{}, @@ -412,9 +452,6 @@ func (s *DataNodeServicesSuite) TestImport() { _, ok = 
s.node.flowgraphManager.getFlowgraphService(chName2) s.Require().True(ok) - filePath := filepath.Join(s.node.chunkManager.RootPath(), "rows_1.json") - err = s.node.chunkManager.Write(s.ctx, filePath, content) - s.Require().Nil(err) req := &datapb.ImportTaskRequest{ ImportTask: &datapb.ImportTask{ CollectionId: 100, @@ -424,48 +461,48 @@ func (s *DataNodeServicesSuite) TestImport() { RowBased: true, }, } - s.node.rootCoord.(*RootCoordFactory).ReportImportErr = true - _, err = s.node.Import(s.ctx, req) - s.Assert().NoError(err) - s.node.rootCoord.(*RootCoordFactory).ReportImportErr = false - - s.node.rootCoord.(*RootCoordFactory).ReportImportNotSuccess = true - _, err = s.node.Import(context.WithValue(s.ctx, ctxKey{}, ""), req) - s.Assert().NoError(err) - s.node.rootCoord.(*RootCoordFactory).ReportImportNotSuccess = false - s.node.dataCoord.(*DataCoordFactory).AddSegmentError = true - _, err = s.node.Import(context.WithValue(s.ctx, ctxKey{}, ""), req) - s.Assert().NoError(err) - s.node.dataCoord.(*DataCoordFactory).AddSegmentError = false - - s.node.dataCoord.(*DataCoordFactory).AddSegmentNotSuccess = true - _, err = s.node.Import(context.WithValue(s.ctx, ctxKey{}, ""), req) - s.Assert().NoError(err) - s.node.dataCoord.(*DataCoordFactory).AddSegmentNotSuccess = false + s.broker.EXPECT().ReportImport(mock.Anything, mock.Anything).Return(nil) + s.broker.EXPECT().UpdateSegmentStatistics(mock.Anything, mock.Anything).Return(nil) + s.broker.EXPECT().AssignSegmentID(mock.Anything, mock.Anything). 
+ Return([]int64{10001}, nil) + s.broker.EXPECT().SaveImportSegment(mock.Anything, mock.Anything).Return(nil) - s.node.dataCoord.(*DataCoordFactory).AddSegmentEmpty = true - _, err = s.node.Import(context.WithValue(s.ctx, ctxKey{}, ""), req) - s.Assert().NoError(err) - s.node.dataCoord.(*DataCoordFactory).AddSegmentEmpty = false + s.node.Import(s.ctx, req) stat, err := s.node.Import(context.WithValue(s.ctx, ctxKey{}, ""), req) s.Assert().NoError(err) s.Assert().True(merr.Ok(stat)) s.Assert().Equal("", stat.GetReason()) + + reqWithoutPartition := &datapb.ImportTaskRequest{ + ImportTask: &datapb.ImportTask{ + CollectionId: 100, + ChannelNames: []string{chName1, chName2}, + Files: []string{filePath}, + RowBased: true, + }, + } + stat2, err := s.node.Import(context.WithValue(s.ctx, ctxKey{}, ""), reqWithoutPartition) + s.Assert().NoError(err) + s.Assert().False(merr.Ok(stat2)) }) s.Run("Test Import bad flow graph", func() { + s.SetupTest() + defer func() { + s.TearDownTest() + }() chName1 := "fake-by-dev-rootcoord-dml-testimport-1-badflowgraph" chName2 := "fake-by-dev-rootcoord-dml-testimport-2-badflowgraph" - err := s.node.flowgraphManager.addAndStart(s.node, &datapb.VchannelInfo{ + err := s.node.flowgraphManager.addAndStartWithEtcdTickler(s.node, &datapb.VchannelInfo{ CollectionID: 100, ChannelName: chName1, UnflushedSegmentIds: []int64{}, FlushedSegmentIds: []int64{}, }, nil, genTestTickler()) s.Require().Nil(err) - err = s.node.flowgraphManager.addAndStart(s.node, &datapb.VchannelInfo{ + err = s.node.flowgraphManager.addAndStartWithEtcdTickler(s.node, &datapb.VchannelInfo{ CollectionID: 999, // wrong collection ID. 
ChannelName: chName2, UnflushedSegmentIds: []int64{}, @@ -478,19 +515,12 @@ func (s *DataNodeServicesSuite) TestImport() { _, ok = s.node.flowgraphManager.getFlowgraphService(chName2) s.Require().True(ok) - content := []byte(`{ - "rows":[ - {"bool_field": true, "int8_field": 10, "int16_field": 101, "int32_field": 1001, "int64_field": 10001, "float32_field": 3.14, "float64_field": 1.56, "varChar_field": "hello world", "binary_vector_field": [254, 0, 254, 0], "float_vector_field": [1.1, 1.2]}, - {"bool_field": false, "int8_field": 11, "int16_field": 102, "int32_field": 1002, "int64_field": 10002, "float32_field": 3.15, "float64_field": 2.56, "varChar_field": "hello world", "binary_vector_field": [253, 0, 253, 0], "float_vector_field": [2.1, 2.2]}, - {"bool_field": true, "int8_field": 12, "int16_field": 103, "int32_field": 1003, "int64_field": 10003, "float32_field": 3.16, "float64_field": 3.56, "varChar_field": "hello world", "binary_vector_field": [252, 0, 252, 0], "float_vector_field": [3.1, 3.2]}, - {"bool_field": false, "int8_field": 13, "int16_field": 104, "int32_field": 1004, "int64_field": 10004, "float32_field": 3.17, "float64_field": 4.56, "varChar_field": "hello world", "binary_vector_field": [251, 0, 251, 0], "float_vector_field": [4.1, 4.2]}, - {"bool_field": true, "int8_field": 14, "int16_field": 105, "int32_field": 1005, "int64_field": 10005, "float32_field": 3.18, "float64_field": 5.56, "varChar_field": "hello world", "binary_vector_field": [250, 0, 250, 0], "float_vector_field": [5.1, 5.2]} - ] - }`) + s.broker.EXPECT().UpdateSegmentStatistics(mock.Anything, mock.Anything).Return(nil) + s.broker.EXPECT().ReportImport(mock.Anything, mock.Anything).Return(nil) + s.broker.EXPECT().AssignSegmentID(mock.Anything, mock.Anything). 
+ Return([]int64{10001}, nil) + s.broker.EXPECT().SaveImportSegment(mock.Anything, mock.Anything).Return(nil) - filePath := filepath.Join(s.node.chunkManager.RootPath(), "rows_1.json") - err = s.node.chunkManager.Write(s.ctx, filePath, content) - s.Assert().NoError(err) req := &datapb.ImportTaskRequest{ ImportTask: &datapb.ImportTask{ CollectionId: 100, @@ -505,25 +535,19 @@ func (s *DataNodeServicesSuite) TestImport() { s.Assert().True(merr.Ok(stat)) s.Assert().Equal("", stat.GetReason()) }) - s.Run("Test Import report import error", func() { - s.node.rootCoord = &RootCoordFactory{ - collectionID: 100, - pkType: schemapb.DataType_Int64, - ReportImportErr: true, - } - content := []byte(`{ - "rows":[ - {"bool_field": true, "int8_field": 10, "int16_field": 101, "int32_field": 1001, "int64_field": 10001, "float32_field": 3.14, "float64_field": 1.56, "varChar_field": "hello world", "binary_vector_field": [254, 0, 254, 0], "float_vector_field": [1.1, 1.2]}, - {"bool_field": false, "int8_field": 11, "int16_field": 102, "int32_field": 1002, "int64_field": 10002, "float32_field": 3.15, "float64_field": 2.56, "varChar_field": "hello world", "binary_vector_field": [253, 0, 253, 0], "float_vector_field": [2.1, 2.2]}, - {"bool_field": true, "int8_field": 12, "int16_field": 103, "int32_field": 1003, "int64_field": 10003, "float32_field": 3.16, "float64_field": 3.56, "varChar_field": "hello world", "binary_vector_field": [252, 0, 252, 0], "float_vector_field": [3.1, 3.2]}, - {"bool_field": false, "int8_field": 13, "int16_field": 104, "int32_field": 1004, "int64_field": 10004, "float32_field": 3.17, "float64_field": 4.56, "varChar_field": "hello world", "binary_vector_field": [251, 0, 251, 0], "float_vector_field": [4.1, 4.2]}, - {"bool_field": true, "int8_field": 14, "int16_field": 105, "int32_field": 1005, "int64_field": 10005, "float32_field": 3.18, "float64_field": 5.56, "varChar_field": "hello world", "binary_vector_field": [250, 0, 250, 0], "float_vector_field": [5.1, 5.2]} 
- ] - }`) + s.Run("test_Import_report_import_error", func() { + s.SetupTest() + s.node.reportImportRetryTimes = 1 + defer func() { + s.TearDownTest() + }() + + s.broker.EXPECT().AssignSegmentID(mock.Anything, mock.Anything). + Return([]int64{10001}, nil) + s.broker.EXPECT().ReportImport(mock.Anything, mock.Anything).Return(errors.New("mocked")) + s.broker.EXPECT().UpdateSegmentStatistics(mock.Anything, mock.Anything).Return(nil) + s.broker.EXPECT().SaveImportSegment(mock.Anything, mock.Anything).Return(nil) - filePath := filepath.Join(s.node.chunkManager.RootPath(), "rows_1.json") - err := s.node.chunkManager.Write(s.ctx, filePath, content) - s.Assert().NoError(err) req := &datapb.ImportTaskRequest{ ImportTask: &datapb.ImportTask{ CollectionId: 100, @@ -538,8 +562,25 @@ func (s *DataNodeServicesSuite) TestImport() { s.Assert().False(merr.Ok(stat)) }) - s.Run("Test Import error", func() { - s.node.rootCoord = &RootCoordFactory{collectionID: -1} + s.Run("test_import_error", func() { + s.SetupTest() + defer func() { + s.TearDownTest() + }() + s.broker.ExpectedCalls = nil + s.broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything). + Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Status(merr.WrapErrCollectionNotFound("collection")), + }, nil) + s.broker.EXPECT().GetSegmentInfo(mock.Anything, mock.Anything). 
+ Return([]*datapb.SegmentInfo{}, nil).Maybe() + s.broker.EXPECT().ReportTimeTick(mock.Anything, mock.Anything).Return(nil).Maybe() + s.broker.EXPECT().SaveBinlogPaths(mock.Anything, mock.Anything).Return(nil).Maybe() + s.broker.EXPECT().UpdateChannelCheckpoint(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + s.broker.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).Call.Return(tsoutil.ComposeTSByTime(time.Now(), 0), + func(_ context.Context, num uint32) uint32 { return num }, nil).Maybe() + + s.broker.EXPECT().ReportImport(mock.Anything, mock.Anything).Return(nil) req := &datapb.ImportTaskRequest{ ImportTask: &datapb.ImportTask{ CollectionId: 100, @@ -559,42 +600,6 @@ func (s *DataNodeServicesSuite) TestImport() { s.Assert().NoError(err) s.Assert().False(merr.Ok(stat)) }) - - s.Run("test get partitions", func() { - s.node.rootCoord = &RootCoordFactory{ - ShowPartitionsErr: true, - } - - _, err := s.node.getPartitions(context.Background(), "", "") - s.Assert().Error(err) - - s.node.rootCoord = &RootCoordFactory{ - ShowPartitionsNotSuccess: true, - } - - _, err = s.node.getPartitions(context.Background(), "", "") - s.Assert().Error(err) - - s.node.rootCoord = &RootCoordFactory{ - ShowPartitionsNames: []string{"a", "b"}, - ShowPartitionsIDs: []int64{1}, - } - - _, err = s.node.getPartitions(context.Background(), "", "") - s.Assert().Error(err) - - s.node.rootCoord = &RootCoordFactory{ - ShowPartitionsNames: []string{"a", "b"}, - ShowPartitionsIDs: []int64{1, 2}, - } - - partitions, err := s.node.getPartitions(context.Background(), "", "") - s.Assert().NoError(err) - s.Assert().Contains(partitions, "a") - s.Assert().Equal(int64(1), partitions["a"]) - s.Assert().Contains(partitions, "b") - s.Assert().Equal(int64(2), partitions["b"]) - }) } func (s *DataNodeServicesSuite) TestAddImportSegment() { @@ -606,14 +611,14 @@ func (s *DataNodeServicesSuite) TestAddImportSegment() { chName1 := "fake-by-dev-rootcoord-dml-testaddsegment-1" chName2 := 
"fake-by-dev-rootcoord-dml-testaddsegment-2" - err := s.node.flowgraphManager.addAndStart(s.node, &datapb.VchannelInfo{ + err := s.node.flowgraphManager.addAndStartWithEtcdTickler(s.node, &datapb.VchannelInfo{ CollectionID: 100, ChannelName: chName1, UnflushedSegmentIds: []int64{}, FlushedSegmentIds: []int64{}, }, nil, genTestTickler()) s.Require().NoError(err) - err = s.node.flowgraphManager.addAndStart(s.node, &datapb.VchannelInfo{ + err = s.node.flowgraphManager.addAndStartWithEtcdTickler(s.node, &datapb.VchannelInfo{ CollectionID: 100, ChannelName: chName2, UnflushedSegmentIds: []int64{}, @@ -651,13 +656,13 @@ func (s *DataNodeServicesSuite) TestAddImportSegment() { s.Assert().False(merr.Ok(resp.GetStatus())) // s.Assert().Equal(merr.Code(merr.ErrChannelNotFound), stat.GetStatus().GetCode()) }) - } func (s *DataNodeServicesSuite) TestSyncSegments() { chanName := "fake-by-dev-rootcoord-dml-test-syncsegments-1" - err := s.node.flowgraphManager.addAndStart(s.node, &datapb.VchannelInfo{ + err := s.node.flowgraphManager.addAndStartWithEtcdTickler(s.node, &datapb.VchannelInfo{ + CollectionID: 1, ChannelName: chanName, UnflushedSegmentIds: []int64{}, FlushedSegmentIds: []int64{100, 200, 300}, @@ -666,9 +671,9 @@ func (s *DataNodeServicesSuite) TestSyncSegments() { fg, ok := s.node.flowgraphManager.getFlowgraphService(chanName) s.Assert().True(ok) - s1 := Segment{segmentID: 100} - s2 := Segment{segmentID: 200} - s3 := Segment{segmentID: 300} + s1 := Segment{segmentID: 100, collectionID: 1} + s2 := Segment{segmentID: 200, collectionID: 1} + s3 := Segment{segmentID: 300, collectionID: 1} s1.setType(datapb.SegmentType_Flushed) s2.setType(datapb.SegmentType_Flushed) s3.setType(datapb.SegmentType_Flushed) @@ -701,13 +706,7 @@ func (s *DataNodeServicesSuite) TestSyncSegments() { CompactedTo: 102, NumOfRows: 100, } - cancelCtx, cancel := context.WithCancel(context.Background()) - cancel() - status, err := s.node.SyncSegments(cancelCtx, req) - s.Assert().NoError(err) - 
s.Assert().False(merr.Ok(status)) - - status, err = s.node.SyncSegments(s.ctx, req) + status, err := s.node.SyncSegments(s.ctx, req) s.Assert().NoError(err) s.Assert().True(merr.Ok(status)) @@ -755,38 +754,44 @@ func (s *DataNodeServicesSuite) TestResendSegmentStats() { FlushedSegmentIds: []int64{}, } - err := s.node.flowgraphManager.addAndStart(s.node, vChan, nil, genTestTickler()) + err := s.node.flowgraphManager.addAndStartWithEtcdTickler(s.node, vChan, nil, genTestTickler()) s.Require().Nil(err) fgService, ok := s.node.flowgraphManager.getFlowgraphService(dmChannelName) s.Assert().True(ok) - err = fgService.channel.addSegment(addSegmentReq{ - segType: datapb.SegmentType_New, - segID: 0, - collID: 1, - partitionID: 1, - startPos: &msgpb.MsgPosition{}, - endPos: &msgpb.MsgPosition{}, - }) + err = fgService.channel.addSegment( + context.TODO(), + addSegmentReq{ + segType: datapb.SegmentType_New, + segID: 0, + collID: 1, + partitionID: 1, + startPos: &msgpb.MsgPosition{}, + endPos: &msgpb.MsgPosition{}, + }) s.Assert().Nil(err) - err = fgService.channel.addSegment(addSegmentReq{ - segType: datapb.SegmentType_New, - segID: 1, - collID: 1, - partitionID: 2, - startPos: &msgpb.MsgPosition{}, - endPos: &msgpb.MsgPosition{}, - }) + err = fgService.channel.addSegment( + context.TODO(), + addSegmentReq{ + segType: datapb.SegmentType_New, + segID: 1, + collID: 1, + partitionID: 2, + startPos: &msgpb.MsgPosition{}, + endPos: &msgpb.MsgPosition{}, + }) s.Assert().Nil(err) - err = fgService.channel.addSegment(addSegmentReq{ - segType: datapb.SegmentType_New, - segID: 2, - collID: 1, - partitionID: 3, - startPos: &msgpb.MsgPosition{}, - endPos: &msgpb.MsgPosition{}, - }) + err = fgService.channel.addSegment( + context.TODO(), + addSegmentReq{ + segType: datapb.SegmentType_New, + segID: 2, + collID: 1, + partitionID: 3, + startPos: &msgpb.MsgPosition{}, + endPos: &msgpb.MsgPosition{}, + }) s.Assert().Nil(err) req := &datapb.ResendSegmentStatsRequest{ @@ -799,11 +804,55 @@ func 
(s *DataNodeServicesSuite) TestResendSegmentStats() { resp, err := s.node.ResendSegmentStats(s.ctx, req) s.Assert().NoError(err) s.Assert().True(merr.Ok(resp.GetStatus())) - s.Assert().ElementsMatch([]UniqueID{0, 1, 2}, resp.GetSegResent()) + s.Assert().Empty(resp.GetSegResent()) // Duplicate call. resp, err = s.node.ResendSegmentStats(s.ctx, req) s.Assert().NoError(err) s.Assert().True(merr.Ok(resp.GetStatus())) - s.Assert().ElementsMatch([]UniqueID{0, 1, 2}, resp.GetSegResent()) + s.Assert().Empty(resp.GetSegResent()) +} + +func (s *DataNodeServicesSuite) TestFlushChannels() { + dmChannelName := "fake-by-dev-rootcoord-dml-channel-TestFlushChannels" + + vChan := &datapb.VchannelInfo{ + CollectionID: 1, + ChannelName: dmChannelName, + UnflushedSegmentIds: []int64{}, + FlushedSegmentIds: []int64{}, + } + + err := s.node.flowgraphManager.addAndStartWithEtcdTickler(s.node, vChan, nil, genTestTickler()) + s.Require().NoError(err) + + fgService, ok := s.node.flowgraphManager.getFlowgraphService(dmChannelName) + s.Require().True(ok) + + flushTs := Timestamp(100) + + req := &datapb.FlushChannelsRequest{ + Base: &commonpb.MsgBase{ + TargetID: s.node.GetSession().ServerID, + }, + FlushTs: flushTs, + Channels: []string{dmChannelName}, + } + + status, err := s.node.FlushChannels(s.ctx, req) + s.Assert().NoError(err) + s.Assert().True(merr.Ok(status)) + + s.Assert().True(fgService.channel.getFlushTs() == flushTs) +} + +func (s *DataNodeServicesSuite) TestRPCWatch() { + ctx := context.Background() + status, err := s.node.NotifyChannelOperation(ctx, nil) + s.NoError(err) + s.NotNil(status) + + resp, err := s.node.CheckChannelOperationProgress(ctx, nil) + s.NoError(err) + s.NotNil(resp) } diff --git a/internal/datanode/timetick_sender.go b/internal/datanode/timetick_sender.go index d9924c02fa7e2..f1957409e2ace 100644 --- a/internal/datanode/timetick_sender.go +++ b/internal/datanode/timetick_sender.go @@ -21,25 +21,22 @@ import ( "sync" "time" - "github.com/cockroachdb/errors" 
"go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" - "github.com/milvus-io/milvus/internal/proto/datapb" - "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/datanode/broker" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/retry" - "github.com/milvus-io/milvus/pkg/util/tsoutil" ) // timeTickSender is to merge channel states updated by flow graph node and send to datacoord periodically // timeTickSender hold a SegmentStats time sequence cache for each channel, // after send succeeds will clean the cache earlier than the sended timestamp type timeTickSender struct { - nodeID int64 - dataCoord types.DataCoord + nodeID int64 + broker broker.Broker mu sync.Mutex channelStatesCaches map[string]*segmentStatesSequence // string -> *segmentStatesSequence @@ -50,10 +47,10 @@ type segmentStatesSequence struct { data map[uint64][]*commonpb.SegmentStats // ts -> segmentStats } -func newTimeTickSender(dataCoord types.DataCoord, nodeID int64) *timeTickSender { +func newTimeTickSender(broker broker.Broker, nodeID int64) *timeTickSender { return &timeTickSender{ nodeID: nodeID, - dataCoord: dataCoord, + broker: broker, channelStatesCaches: make(map[string]*segmentStatesSequence, 0), } } @@ -159,27 +156,7 @@ func (m *timeTickSender) sendReport(ctx context.Context) error { toSendMsgs, sendLastTss := m.mergeDatanodeTtMsg() log.RatedDebug(30, "timeTickSender send datanode timetick message", zap.Any("toSendMsgs", toSendMsgs), zap.Any("sendLastTss", sendLastTss)) err := retry.Do(ctx, func() error { - submitTs := tsoutil.ComposeTSByTime(time.Now(), 0) - statusResp, err := m.dataCoord.ReportDataNodeTtMsgs(ctx, &datapb.ReportDataNodeTtMsgsRequest{ - Base: commonpbutil.NewMsgBase( - commonpbutil.WithMsgType(commonpb.MsgType_DataNodeTt), - commonpbutil.WithTimeStamp(submitTs), - 
commonpbutil.WithSourceID(m.nodeID), - ), - Msgs: toSendMsgs, - }) - if err != nil { - log.Warn("error happen when ReportDataNodeTtMsgs", zap.Error(err)) - return err - } - if statusResp.GetErrorCode() != commonpb.ErrorCode_Success { - log.Warn("ReportDataNodeTtMsgs resp status not succeed", - zap.String("error_code", statusResp.GetErrorCode().String()), - zap.Int32("code", statusResp.GetCode()), - zap.String("reason", statusResp.GetReason())) - return errors.New(statusResp.GetReason()) - } - return nil + return m.broker.ReportTimeTick(ctx, toSendMsgs) }, retry.Attempts(20), retry.Sleep(time.Millisecond*100)) if err != nil { log.Error("ReportDataNodeTtMsgs fail after retry", zap.Error(err)) diff --git a/internal/datanode/timetick_sender_test.go b/internal/datanode/timetick_sender_test.go index 7d2711d446f01..1f96168643a32 100644 --- a/internal/datanode/timetick_sender_test.go +++ b/internal/datanode/timetick_sender_test.go @@ -21,18 +21,25 @@ import ( "testing" "time" + "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "go.uber.org/atomic" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus/internal/datanode/broker" "github.com/milvus-io/milvus/internal/mocks" - "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/pkg/util/merr" ) func TestTimetickManagerNormal(t *testing.T) { ctx := context.Background() - manager := newTimeTickSender(&DataCoordFactory{}, 0) + + broker := broker.NewMockBroker(t) + broker.EXPECT().ReportTimeTick(mock.Anything, mock.Anything).Return(nil).Maybe() + + manager := newTimeTickSender(broker, 0) channelName1 := "channel1" ts := uint64(time.Now().UnixMilli()) @@ -127,26 +134,11 @@ func TestTimetickManagerNormal(t *testing.T) { func TestTimetickManagerSendErr(t *testing.T) { ctx := context.Background() - manager := 
newTimeTickSender(&DataCoordFactory{ReportDataNodeTtMsgsError: true}, 0) - channelName1 := "channel1" - ts := uint64(time.Now().Unix()) - var segmentID1 int64 = 28257 - segmentStats := []*commonpb.SegmentStats{ - { - SegmentID: segmentID1, - NumRows: 100, - }, - } - // update first time - manager.update(channelName1, ts, segmentStats) - err := manager.sendReport(ctx) - assert.Error(t, err) -} + broker := broker.NewMockBroker(t) + broker.EXPECT().ReportTimeTick(mock.Anything, mock.Anything).Return(errors.New("mock")).Maybe() -func TestTimetickManagerSendNotSuccess(t *testing.T) { - ctx := context.Background() - manager := newTimeTickSender(&DataCoordFactory{ReportDataNodeTtMsgsNotSuccess: true}, 0) + manager := newTimeTickSender(broker, 0) channelName1 := "channel1" ts := uint64(time.Now().Unix()) @@ -166,21 +158,21 @@ func TestTimetickManagerSendNotSuccess(t *testing.T) { func TestTimetickManagerSendReport(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - mockDataCoord := mocks.NewMockDataCoord(t) - tsInMill := time.Now().UnixMilli() - - validTs := atomic.NewBool(false) - mockDataCoord.EXPECT().ReportDataNodeTtMsgs(mock.Anything, mock.Anything).Run(func(ctx context.Context, req *datapb.ReportDataNodeTtMsgsRequest) { - if req.GetBase().Timestamp > uint64(tsInMill) { - validTs.Store(true) - } - }).Return(&commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil) - manager := newTimeTickSender(mockDataCoord, 0) + mockDataCoord := mocks.NewMockDataCoordClient(t) + + called := atomic.NewBool(false) + + broker := broker.NewMockBroker(t) + broker.EXPECT().ReportTimeTick(mock.Anything, mock.Anything). + Run(func(_ context.Context, _ []*msgpb.DataNodeTtMsg) { + called.Store(true) + }). 
+ Return(nil) + mockDataCoord.EXPECT().ReportDataNodeTtMsgs(mock.Anything, mock.Anything).Return(merr.Status(nil), nil).Maybe() + manager := newTimeTickSender(broker, 0) go manager.start(ctx) assert.Eventually(t, func() bool { - return validTs.Load() + return called.Load() }, 2*time.Second, 500*time.Millisecond) } diff --git a/internal/datanode/util.go b/internal/datanode/util.go index 2d9f7c3f4272e..dc1fbdab2c278 100644 --- a/internal/datanode/util.go +++ b/internal/datanode/util.go @@ -19,10 +19,11 @@ package datanode import ( "context" - "github.com/milvus-io/milvus/pkg/mq/msgstream" - "github.com/milvus-io/milvus/pkg/util/typeutil" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/trace" + + "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) type ( diff --git a/internal/distributed/connection_manager.go b/internal/distributed/connection_manager.go index e43d1817e975e..682a751daeb5a 100644 --- a/internal/distributed/connection_manager.go +++ b/internal/distributed/connection_manager.go @@ -24,24 +24,23 @@ import ( "time" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus/internal/util/sessionutil" - "github.com/milvus-io/milvus/pkg/tracer" - grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry" + "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" + "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/tracer" "github.com/milvus-io/milvus/pkg/util/retry" 
"github.com/milvus-io/milvus/pkg/util/typeutil" - "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/credentials/insecure" - - "go.uber.org/zap" ) // ConnectionManager handles connection to other components of the system @@ -123,6 +122,7 @@ func (cm *ConnectionManager) AddDependency(roleName string) error { return nil } + func (cm *ConnectionManager) Start() { go cm.receiveFinishTask() } @@ -369,7 +369,6 @@ func newBuildClientTask(session *sessionutil.Session, notify chan int64, retryOp notify: notify, } - } func (bct *buildClientTask) Run() { @@ -420,6 +419,7 @@ func (bct *buildClientTask) Run() { } }() } + func (bct *buildClientTask) Stop() { bct.cancel() } diff --git a/internal/distributed/connection_manager_test.go b/internal/distributed/connection_manager_test.go index 30703f814c673..ce3eff2273b09 100644 --- a/internal/distributed/connection_manager_test.go +++ b/internal/distributed/connection_manager_test.go @@ -27,6 +27,10 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + "google.golang.org/grpc" + "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/querypb" @@ -36,9 +40,6 @@ import ( "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" - "github.com/stretchr/testify/assert" - "go.uber.org/zap" - "google.golang.org/grpc" ) func TestMain(t *testing.M) { @@ -221,8 +222,10 @@ func TestConnectionManager_processEvent(t *testing.T) { cm := &ConnectionManager{ closeCh: make(chan struct{}), session: &sessionutil.Session{ - ServerID: 1, - TriggerKill: true, + SessionRaw: sessionutil.SessionRaw{ + ServerID: 1, + TriggerKill: true, + }, }, } diff --git a/internal/distributed/datacoord/client/client.go 
b/internal/distributed/datacoord/client/client.go index fb3ed367c3e26..97c3bfd908a13 100644 --- a/internal/distributed/datacoord/client/client.go +++ b/internal/distributed/datacoord/client/client.go @@ -20,8 +20,6 @@ import ( "context" "fmt" - "github.com/milvus-io/milvus/internal/util/grpcclient" - "github.com/milvus-io/milvus/internal/util/sessionutil" clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/zap" "google.golang.org/grpc" @@ -32,6 +30,8 @@ import ( "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/grpcclient" + "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/funcutil" @@ -41,7 +41,7 @@ import ( var Params *paramtable.ComponentParam = paramtable.Get() -var _ types.DataCoord = (*Client)(nil) +var _ types.DataCoordClient = (*Client)(nil) // Client is the datacoord grpc client type Client struct { @@ -93,26 +93,11 @@ func (c *Client) getDataCoordAddr() (string, error) { return ms.Address, nil } -// Init initializes the client -func (c *Client) Init() error { - return nil -} - -// Start enables the client -func (c *Client) Start() error { - return nil -} - // Stop stops the client -func (c *Client) Stop() error { +func (c *Client) Close() error { return c.grpcClient.Close() } -// Register dummy -func (c *Client) Register() error { - return nil -} - func wrapGrpcCall[T any](ctx context.Context, c *Client, call func(coordClient datapb.DataCoordClient) (*T, error)) (*T, error) { ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { @@ -127,28 +112,28 @@ func wrapGrpcCall[T any](ctx context.Context, c *Client, call func(coordClient d } // GetComponentStates calls DataCoord GetComponentStates services -func (c 
*Client) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { +func (c *Client) GetComponentStates(ctx context.Context, _ *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { return wrapGrpcCall(ctx, c, func(client datapb.DataCoordClient) (*milvuspb.ComponentStates, error) { return client.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) }) } // GetTimeTickChannel return the name of time tick channel. -func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (c *Client) GetTimeTickChannel(ctx context.Context, _ *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { return wrapGrpcCall(ctx, c, func(client datapb.DataCoordClient) (*milvuspb.StringResponse, error) { return client.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{}) }) } // GetStatisticsChannel return the name of statistics channel. -func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (c *Client) GetStatisticsChannel(ctx context.Context, _ *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { return wrapGrpcCall(ctx, c, func(client datapb.DataCoordClient) (*milvuspb.StringResponse, error) { return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) }) } // Flush flushes a collection's data -func (c *Client) Flush(ctx context.Context, req *datapb.FlushRequest) (*datapb.FlushResponse, error) { +func (c *Client) Flush(ctx context.Context, req *datapb.FlushRequest, opts ...grpc.CallOption) (*datapb.FlushResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -172,7 +157,7 @@ func (c *Client) Flush(ctx context.Context, req *datapb.FlushRequest) (*datapb.F // `AssignSegmentID` will applies current configured allocation policies for each request // if the VChannel is newly 
used, `WatchDmlChannels` will be invoked to notify a `DataNode`(selected by policy) to watch it // if there is anything make the allocation impossible, the response will not contain the corresponding result -func (c *Client) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest) (*datapb.AssignSegmentIDResponse, error) { +func (c *Client) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest, opts ...grpc.CallOption) (*datapb.AssignSegmentIDResponse, error) { return wrapGrpcCall(ctx, c, func(client datapb.DataCoordClient) (*datapb.AssignSegmentIDResponse, error) { return client.AssignSegmentID(ctx, req) }) @@ -189,7 +174,7 @@ func (c *Client) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentI // otherwise the Segment State and Start position information will be returned // // error is returned only when some communication issue occurs -func (c *Client) GetSegmentStates(ctx context.Context, req *datapb.GetSegmentStatesRequest) (*datapb.GetSegmentStatesResponse, error) { +func (c *Client) GetSegmentStates(ctx context.Context, req *datapb.GetSegmentStatesRequest, opts ...grpc.CallOption) (*datapb.GetSegmentStatesResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -210,7 +195,7 @@ func (c *Client) GetSegmentStates(ctx context.Context, req *datapb.GetSegmentSta // and corresponding binlog path list // // error is returned only when some communication issue occurs -func (c *Client) GetInsertBinlogPaths(ctx context.Context, req *datapb.GetInsertBinlogPathsRequest) (*datapb.GetInsertBinlogPathsResponse, error) { +func (c *Client) GetInsertBinlogPaths(ctx context.Context, req *datapb.GetInsertBinlogPathsRequest, opts ...grpc.CallOption) (*datapb.GetInsertBinlogPathsResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -231,7 +216,7 @@ func (c *Client) GetInsertBinlogPaths(ctx context.Context, req *datapb.GetInsert // only row count for now // // error 
is returned only when some communication issue occurs -func (c *Client) GetCollectionStatistics(ctx context.Context, req *datapb.GetCollectionStatisticsRequest) (*datapb.GetCollectionStatisticsResponse, error) { +func (c *Client) GetCollectionStatistics(ctx context.Context, req *datapb.GetCollectionStatisticsRequest, opts ...grpc.CallOption) (*datapb.GetCollectionStatisticsResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -252,7 +237,7 @@ func (c *Client) GetCollectionStatistics(ctx context.Context, req *datapb.GetCol // only row count for now // // error is returned only when some communication issue occurs -func (c *Client) GetPartitionStatistics(ctx context.Context, req *datapb.GetPartitionStatisticsRequest) (*datapb.GetPartitionStatisticsResponse, error) { +func (c *Client) GetPartitionStatistics(ctx context.Context, req *datapb.GetPartitionStatisticsRequest, opts ...grpc.CallOption) (*datapb.GetPartitionStatisticsResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -265,7 +250,7 @@ func (c *Client) GetPartitionStatistics(ctx context.Context, req *datapb.GetPart // GetSegmentInfoChannel DEPRECATED // legacy api to get SegmentInfo Channel name -func (c *Client) GetSegmentInfoChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (c *Client) GetSegmentInfoChannel(ctx context.Context, _ *datapb.GetSegmentInfoChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { return wrapGrpcCall(ctx, c, func(client datapb.DataCoordClient) (*milvuspb.StringResponse, error) { return client.GetSegmentInfoChannel(ctx, &datapb.GetSegmentInfoChannelRequest{}) }) @@ -278,7 +263,7 @@ func (c *Client) GetSegmentInfoChannel(ctx context.Context) (*milvuspb.StringRes // // response struct `GetSegmentInfoResponse` contains the list of segment info // error is returned only when some communication issue occurs -func (c *Client) GetSegmentInfo(ctx context.Context, req 
*datapb.GetSegmentInfoRequest) (*datapb.GetSegmentInfoResponse, error) { +func (c *Client) GetSegmentInfo(ctx context.Context, req *datapb.GetSegmentInfoRequest, opts ...grpc.CallOption) (*datapb.GetSegmentInfoResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -303,7 +288,7 @@ func (c *Client) GetSegmentInfo(ctx context.Context, req *datapb.GetSegmentInfoR // // the root reason is each `SaveBinlogPaths` will overwrite the checkpoint position // if the constraint is broken, the checkpoint position will not be monotonically increasing and the integrity will be compromised -func (c *Client) SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPathsRequest) (*commonpb.Status, error) { +func (c *Client) SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPathsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { // use Call here on purpose req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( @@ -322,7 +307,7 @@ func (c *Client) SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPath // // response struct `GetRecoveryInfoResponse` contains the list of segments info and corresponding vchannel info // error is returned only when some communication issue occurs -func (c *Client) GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInfoRequest) (*datapb.GetRecoveryInfoResponse, error) { +func (c *Client) GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInfoRequest, opts ...grpc.CallOption) (*datapb.GetRecoveryInfoResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -340,7 +325,7 @@ func (c *Client) GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInf // // response struct `GetRecoveryInfoResponseV2` contains the list of segments info and corresponding vchannel info // error is returned only when some communication issue occurs -func (c *Client) GetRecoveryInfoV2(ctx context.Context, req *datapb.GetRecoveryInfoRequestV2) 
(*datapb.GetRecoveryInfoResponseV2, error) { +func (c *Client) GetRecoveryInfoV2(ctx context.Context, req *datapb.GetRecoveryInfoRequestV2, opts ...grpc.CallOption) (*datapb.GetRecoveryInfoResponseV2, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -360,7 +345,7 @@ func (c *Client) GetRecoveryInfoV2(ctx context.Context, req *datapb.GetRecoveryI // // response struct `GetFlushedSegmentsResponse` contains flushed segment id list // error is returned only when some communication issue occurs -func (c *Client) GetFlushedSegments(ctx context.Context, req *datapb.GetFlushedSegmentsRequest) (*datapb.GetFlushedSegmentsResponse, error) { +func (c *Client) GetFlushedSegments(ctx context.Context, req *datapb.GetFlushedSegmentsRequest, opts ...grpc.CallOption) (*datapb.GetFlushedSegmentsResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -379,7 +364,7 @@ func (c *Client) GetFlushedSegments(ctx context.Context, req *datapb.GetFlushedS // // response struct `GetSegmentsByStatesResponse` contains segment id list // error is returned only when some communication issue occurs -func (c *Client) GetSegmentsByStates(ctx context.Context, req *datapb.GetSegmentsByStatesRequest) (*datapb.GetSegmentsByStatesResponse, error) { +func (c *Client) GetSegmentsByStates(ctx context.Context, req *datapb.GetSegmentsByStatesRequest, opts ...grpc.CallOption) (*datapb.GetSegmentsByStatesResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -391,7 +376,7 @@ func (c *Client) GetSegmentsByStates(ctx context.Context, req *datapb.GetSegment } // ShowConfigurations gets specified configurations para of DataCoord -func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { +func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest, opts ...grpc.CallOption) 
(*internalpb.ShowConfigurationsResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -403,7 +388,7 @@ func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowCon } // GetMetrics gets all metrics of datacoord -func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { +func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -415,49 +400,49 @@ func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest } // ManualCompaction triggers a compaction for a collection -func (c *Client) ManualCompaction(ctx context.Context, req *milvuspb.ManualCompactionRequest) (*milvuspb.ManualCompactionResponse, error) { +func (c *Client) ManualCompaction(ctx context.Context, req *milvuspb.ManualCompactionRequest, opts ...grpc.CallOption) (*milvuspb.ManualCompactionResponse, error) { return wrapGrpcCall(ctx, c, func(client datapb.DataCoordClient) (*milvuspb.ManualCompactionResponse, error) { return client.ManualCompaction(ctx, req) }) } // GetCompactionState gets the state of a compaction -func (c *Client) GetCompactionState(ctx context.Context, req *milvuspb.GetCompactionStateRequest) (*milvuspb.GetCompactionStateResponse, error) { +func (c *Client) GetCompactionState(ctx context.Context, req *milvuspb.GetCompactionStateRequest, opts ...grpc.CallOption) (*milvuspb.GetCompactionStateResponse, error) { return wrapGrpcCall(ctx, c, func(client datapb.DataCoordClient) (*milvuspb.GetCompactionStateResponse, error) { return client.GetCompactionState(ctx, req) }) } // GetCompactionStateWithPlans gets the state of a compaction by plan -func (c *Client) GetCompactionStateWithPlans(ctx context.Context, req *milvuspb.GetCompactionPlansRequest) (*milvuspb.GetCompactionPlansResponse, error) { +func (c 
*Client) GetCompactionStateWithPlans(ctx context.Context, req *milvuspb.GetCompactionPlansRequest, opts ...grpc.CallOption) (*milvuspb.GetCompactionPlansResponse, error) { return wrapGrpcCall(ctx, c, func(client datapb.DataCoordClient) (*milvuspb.GetCompactionPlansResponse, error) { return client.GetCompactionStateWithPlans(ctx, req) }) } // WatchChannels notifies DataCoord to watch vchannels of a collection -func (c *Client) WatchChannels(ctx context.Context, req *datapb.WatchChannelsRequest) (*datapb.WatchChannelsResponse, error) { +func (c *Client) WatchChannels(ctx context.Context, req *datapb.WatchChannelsRequest, opts ...grpc.CallOption) (*datapb.WatchChannelsResponse, error) { return wrapGrpcCall(ctx, c, func(client datapb.DataCoordClient) (*datapb.WatchChannelsResponse, error) { return client.WatchChannels(ctx, req) }) } -// GetFlushState gets the flush state of multiple segments -func (c *Client) GetFlushState(ctx context.Context, req *milvuspb.GetFlushStateRequest) (*milvuspb.GetFlushStateResponse, error) { +// GetFlushState gets the flush state of the collection based on the provided flush ts and segment IDs. +func (c *Client) GetFlushState(ctx context.Context, req *datapb.GetFlushStateRequest, opts ...grpc.CallOption) (*milvuspb.GetFlushStateResponse, error) { return wrapGrpcCall(ctx, c, func(client datapb.DataCoordClient) (*milvuspb.GetFlushStateResponse, error) { return client.GetFlushState(ctx, req) }) } // GetFlushAllState checks if all DML messages before `FlushAllTs` have been flushed. 
-func (c *Client) GetFlushAllState(ctx context.Context, req *milvuspb.GetFlushAllStateRequest) (*milvuspb.GetFlushAllStateResponse, error) { +func (c *Client) GetFlushAllState(ctx context.Context, req *milvuspb.GetFlushAllStateRequest, opts ...grpc.CallOption) (*milvuspb.GetFlushAllStateResponse, error) { return wrapGrpcCall(ctx, c, func(client datapb.DataCoordClient) (*milvuspb.GetFlushAllStateResponse, error) { return client.GetFlushAllState(ctx, req) }) } // DropVirtualChannel drops virtual channel in datacoord. -func (c *Client) DropVirtualChannel(ctx context.Context, req *datapb.DropVirtualChannelRequest) (*datapb.DropVirtualChannelResponse, error) { +func (c *Client) DropVirtualChannel(ctx context.Context, req *datapb.DropVirtualChannelRequest, opts ...grpc.CallOption) (*datapb.DropVirtualChannelResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -469,7 +454,7 @@ func (c *Client) DropVirtualChannel(ctx context.Context, req *datapb.DropVirtual } // SetSegmentState sets the state of a given segment. -func (c *Client) SetSegmentState(ctx context.Context, req *datapb.SetSegmentStateRequest) (*datapb.SetSegmentStateResponse, error) { +func (c *Client) SetSegmentState(ctx context.Context, req *datapb.SetSegmentStateRequest, opts ...grpc.CallOption) (*datapb.SetSegmentStateResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -481,7 +466,7 @@ func (c *Client) SetSegmentState(ctx context.Context, req *datapb.SetSegmentStat } // Import data files(json, numpy, etc.) 
on MinIO/S3 storage, read and parse them into sealed segments -func (c *Client) Import(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { +func (c *Client) Import(ctx context.Context, req *datapb.ImportTaskRequest, opts ...grpc.CallOption) (*datapb.ImportTaskResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -493,7 +478,7 @@ func (c *Client) Import(ctx context.Context, req *datapb.ImportTaskRequest) (*da } // UpdateSegmentStatistics is the client side caller of UpdateSegmentStatistics. -func (c *Client) UpdateSegmentStatistics(ctx context.Context, req *datapb.UpdateSegmentStatisticsRequest) (*commonpb.Status, error) { +func (c *Client) UpdateSegmentStatistics(ctx context.Context, req *datapb.UpdateSegmentStatisticsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -505,7 +490,7 @@ func (c *Client) UpdateSegmentStatistics(ctx context.Context, req *datapb.Update } // UpdateChannelCheckpoint updates channel checkpoint in dataCoord. -func (c *Client) UpdateChannelCheckpoint(ctx context.Context, req *datapb.UpdateChannelCheckpointRequest) (*commonpb.Status, error) { +func (c *Client) UpdateChannelCheckpoint(ctx context.Context, req *datapb.UpdateChannelCheckpointRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -517,7 +502,7 @@ func (c *Client) UpdateChannelCheckpoint(ctx context.Context, req *datapb.Update } // SaveImportSegment is the DataCoord client side code for SaveImportSegment call. 
-func (c *Client) SaveImportSegment(ctx context.Context, req *datapb.SaveImportSegmentRequest) (*commonpb.Status, error) { +func (c *Client) SaveImportSegment(ctx context.Context, req *datapb.SaveImportSegmentRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -528,7 +513,7 @@ func (c *Client) SaveImportSegment(ctx context.Context, req *datapb.SaveImportSe }) } -func (c *Client) UnsetIsImportingState(ctx context.Context, req *datapb.UnsetIsImportingStateRequest) (*commonpb.Status, error) { +func (c *Client) UnsetIsImportingState(ctx context.Context, req *datapb.UnsetIsImportingStateRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -539,7 +524,7 @@ func (c *Client) UnsetIsImportingState(ctx context.Context, req *datapb.UnsetIsI }) } -func (c *Client) MarkSegmentsDropped(ctx context.Context, req *datapb.MarkSegmentsDroppedRequest) (*commonpb.Status, error) { +func (c *Client) MarkSegmentsDropped(ctx context.Context, req *datapb.MarkSegmentsDroppedRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -551,82 +536,82 @@ func (c *Client) MarkSegmentsDropped(ctx context.Context, req *datapb.MarkSegmen } // BroadcastAlteredCollection is the DataCoord client side code for BroadcastAlteredCollection call. 
-func (c *Client) BroadcastAlteredCollection(ctx context.Context, req *datapb.AlterCollectionRequest) (*commonpb.Status, error) { +func (c *Client) BroadcastAlteredCollection(ctx context.Context, req *datapb.AlterCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) return wrapGrpcCall(ctx, c, func(client datapb.DataCoordClient) (*commonpb.Status, error) { return client.BroadcastAlteredCollection(ctx, req) }) } -func (c *Client) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { +func (c *Client) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest, opts ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error) { return wrapGrpcCall(ctx, c, func(client datapb.DataCoordClient) (*milvuspb.CheckHealthResponse, error) { return client.CheckHealth(ctx, req) }) } -func (c *Client) GcConfirm(ctx context.Context, req *datapb.GcConfirmRequest) (*datapb.GcConfirmResponse, error) { +func (c *Client) GcConfirm(ctx context.Context, req *datapb.GcConfirmRequest, opts ...grpc.CallOption) (*datapb.GcConfirmResponse, error) { return wrapGrpcCall(ctx, c, func(client datapb.DataCoordClient) (*datapb.GcConfirmResponse, error) { return client.GcConfirm(ctx, req) }) } // CreateIndex sends the build index request to IndexCoord. -func (c *Client) CreateIndex(ctx context.Context, req *indexpb.CreateIndexRequest) (*commonpb.Status, error) { +func (c *Client) CreateIndex(ctx context.Context, req *indexpb.CreateIndexRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return wrapGrpcCall(ctx, c, func(client datapb.DataCoordClient) (*commonpb.Status, error) { return client.CreateIndex(ctx, req) }) } // GetIndexState gets the index states from IndexCoord. 
-func (c *Client) GetIndexState(ctx context.Context, req *indexpb.GetIndexStateRequest) (*indexpb.GetIndexStateResponse, error) { +func (c *Client) GetIndexState(ctx context.Context, req *indexpb.GetIndexStateRequest, opts ...grpc.CallOption) (*indexpb.GetIndexStateResponse, error) { return wrapGrpcCall(ctx, c, func(client datapb.DataCoordClient) (*indexpb.GetIndexStateResponse, error) { return client.GetIndexState(ctx, req) }) } // GetSegmentIndexState gets the index states from IndexCoord. -func (c *Client) GetSegmentIndexState(ctx context.Context, req *indexpb.GetSegmentIndexStateRequest) (*indexpb.GetSegmentIndexStateResponse, error) { +func (c *Client) GetSegmentIndexState(ctx context.Context, req *indexpb.GetSegmentIndexStateRequest, opts ...grpc.CallOption) (*indexpb.GetSegmentIndexStateResponse, error) { return wrapGrpcCall(ctx, c, func(client datapb.DataCoordClient) (*indexpb.GetSegmentIndexStateResponse, error) { return client.GetSegmentIndexState(ctx, req) }) } // GetIndexInfos gets the index file paths from IndexCoord. -func (c *Client) GetIndexInfos(ctx context.Context, req *indexpb.GetIndexInfoRequest) (*indexpb.GetIndexInfoResponse, error) { +func (c *Client) GetIndexInfos(ctx context.Context, req *indexpb.GetIndexInfoRequest, opts ...grpc.CallOption) (*indexpb.GetIndexInfoResponse, error) { return wrapGrpcCall(ctx, c, func(client datapb.DataCoordClient) (*indexpb.GetIndexInfoResponse, error) { return client.GetIndexInfos(ctx, req) }) } // DescribeIndex describe the index info of the collection. 
-func (c *Client) DescribeIndex(ctx context.Context, req *indexpb.DescribeIndexRequest) (*indexpb.DescribeIndexResponse, error) { +func (c *Client) DescribeIndex(ctx context.Context, req *indexpb.DescribeIndexRequest, opts ...grpc.CallOption) (*indexpb.DescribeIndexResponse, error) { return wrapGrpcCall(ctx, c, func(client datapb.DataCoordClient) (*indexpb.DescribeIndexResponse, error) { return client.DescribeIndex(ctx, req) }) } // GetIndexStatistics get the statistics of the index. -func (c *Client) GetIndexStatistics(ctx context.Context, req *indexpb.GetIndexStatisticsRequest) (*indexpb.GetIndexStatisticsResponse, error) { +func (c *Client) GetIndexStatistics(ctx context.Context, req *indexpb.GetIndexStatisticsRequest, opts ...grpc.CallOption) (*indexpb.GetIndexStatisticsResponse, error) { return wrapGrpcCall(ctx, c, func(client datapb.DataCoordClient) (*indexpb.GetIndexStatisticsResponse, error) { return client.GetIndexStatistics(ctx, req) }) } // GetIndexBuildProgress describe the progress of the index. -func (c *Client) GetIndexBuildProgress(ctx context.Context, req *indexpb.GetIndexBuildProgressRequest) (*indexpb.GetIndexBuildProgressResponse, error) { +func (c *Client) GetIndexBuildProgress(ctx context.Context, req *indexpb.GetIndexBuildProgressRequest, opts ...grpc.CallOption) (*indexpb.GetIndexBuildProgressResponse, error) { return wrapGrpcCall(ctx, c, func(client datapb.DataCoordClient) (*indexpb.GetIndexBuildProgressResponse, error) { return client.GetIndexBuildProgress(ctx, req) }) } // DropIndex sends the drop index request to IndexCoord. 
-func (c *Client) DropIndex(ctx context.Context, req *indexpb.DropIndexRequest) (*commonpb.Status, error) { +func (c *Client) DropIndex(ctx context.Context, req *indexpb.DropIndexRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return wrapGrpcCall(ctx, c, func(client datapb.DataCoordClient) (*commonpb.Status, error) { return client.DropIndex(ctx, req) }) } -func (c *Client) ReportDataNodeTtMsgs(ctx context.Context, req *datapb.ReportDataNodeTtMsgsRequest) (*commonpb.Status, error) { +func (c *Client) ReportDataNodeTtMsgs(ctx context.Context, req *datapb.ReportDataNodeTtMsgsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return wrapGrpcCall(ctx, c, func(client datapb.DataCoordClient) (*commonpb.Status, error) { return client.ReportDataNodeTtMsgs(ctx, req) }) diff --git a/internal/distributed/datacoord/client/client_test.go b/internal/distributed/datacoord/client/client_test.go index 67fe3b25e3188..ab4b78590a862 100644 --- a/internal/distributed/datacoord/client/client_test.go +++ b/internal/distributed/datacoord/client/client_test.go @@ -25,16 +25,16 @@ import ( "time" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus/internal/util/mock" + "github.com/stretchr/testify/assert" "go.uber.org/zap" + "google.golang.org/grpc" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proxy" + "github.com/milvus-io/milvus/internal/util/mock" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/stretchr/testify/assert" - "google.golang.org/grpc" ) func TestMain(m *testing.M) { @@ -70,15 +70,6 @@ func Test_NewClient(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, client) - err = client.Init() - assert.NoError(t, err) - - err = client.Start() - assert.NoError(t, err) - - err = client.Register() - assert.NoError(t, err) - checkFunc := func(retNotNil bool) { retCheck := func(notNil bool, ret any, err error) { 
if notNil { @@ -90,13 +81,13 @@ func Test_NewClient(t *testing.T) { } } - r1, err := client.GetComponentStates(ctx) + r1, err := client.GetComponentStates(ctx, nil) retCheck(retNotNil, r1, err) - r2, err := client.GetTimeTickChannel(ctx) + r2, err := client.GetTimeTickChannel(ctx, nil) retCheck(retNotNil, r2, err) - r3, err := client.GetStatisticsChannel(ctx) + r3, err := client.GetStatisticsChannel(ctx, nil) retCheck(retNotNil, r3, err) r4, err := client.Flush(ctx, nil) @@ -120,7 +111,7 @@ func Test_NewClient(t *testing.T) { r10, err := client.GetPartitionStatistics(ctx, nil) retCheck(retNotNil, r10, err) - r11, err := client.GetSegmentInfoChannel(ctx) + r11, err := client.GetSegmentInfoChannel(ctx, nil) retCheck(retNotNil, r11, err) // r12, err := client.SaveBinlogPaths(ctx, nil) @@ -259,6 +250,6 @@ func Test_NewClient(t *testing.T) { assert.NotNil(t, ret) assert.NoError(t, err) - err = client.Stop() + err = client.Close() assert.NoError(t, err) } diff --git a/internal/distributed/datacoord/service.go b/internal/distributed/datacoord/service.go index 911ecc52dab49..0d798d46b6109 100644 --- a/internal/distributed/datacoord/service.go +++ b/internal/distributed/datacoord/service.go @@ -24,12 +24,8 @@ import ( "sync" "time" - "github.com/milvus-io/milvus/internal/proto/indexpb" - "github.com/milvus-io/milvus/internal/util/dependency" - "github.com/milvus-io/milvus/pkg/tracer" - "github.com/milvus-io/milvus/pkg/util/interceptor" - grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" + "github.com/tikv/client-go/v2/txnkv" clientv3 "go.etcd.io/etcd/client/v3" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "go.uber.org/atomic" @@ -40,14 +36,21 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/datacoord" + "github.com/milvus-io/milvus/internal/distributed/utils" "github.com/milvus-io/milvus/internal/proto/datapb" + 
"github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/tracer" + "github.com/milvus-io/milvus/pkg/util" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/interceptor" "github.com/milvus-io/milvus/pkg/util/logutil" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/tikv" ) // Server is the grpc server of datacoord @@ -61,6 +64,7 @@ type Server struct { dataCoord types.DataCoordComponent etcdCli *clientv3.Client + tikvCli *txnkv.Client grpcErrChan chan error grpcServer *grpc.Server @@ -79,9 +83,11 @@ func NewServer(ctx context.Context, factory dependency.Factory, opts ...datacoor return s } +var getTiKVClient = tikv.GetTiKVClient + func (s *Server) init() error { - etcdConfig := ¶mtable.Get().EtcdCfg - Params := ¶mtable.Get().DataCoordGrpcServerCfg + params := paramtable.Get() + etcdConfig := ¶ms.EtcdCfg etcdCli, err := etcd.GetEtcdClient( etcdConfig.UseEmbedEtcd.GetAsBool(), @@ -97,7 +103,18 @@ func (s *Server) init() error { } s.etcdCli = etcdCli s.dataCoord.SetEtcdClient(etcdCli) - s.dataCoord.SetAddress(Params.GetAddress()) + s.dataCoord.SetAddress(params.DataCoordGrpcServerCfg.GetAddress()) + + if params.MetaStoreCfg.MetaStoreType.GetValue() == util.MetaStoreTypeTiKV { + log.Info("Connecting to tikv metadata storage.") + tikvCli, err := getTiKVClient(¶mtable.Get().TiKVCfg) + if err != nil { + log.Warn("DataCoord failed to connect to tikv", zap.Error(err)) + return err + } + s.dataCoord.SetTiKVClient(tikvCli) + log.Info("Connected to tikv. 
Using tikv as metadata storage.") + } err = s.startGrpc() if err != nil { @@ -137,12 +154,12 @@ func (s *Server) startGrpcLoop(grpcPort int) { ctx, cancel := context.WithCancel(s.ctx) defer cancel() - var kaep = keepalive.EnforcementPolicy{ + kaep := keepalive.EnforcementPolicy{ MinTime: 5 * time.Second, // If a client pings more than once every 5 seconds, terminate the connection PermitWithoutStream: true, // Allow pings even when there are no active streams } - var kasp = keepalive.ServerParameters{ + kasp := keepalive.ServerParameters{ Time: 60 * time.Second, // Ping the client if it is idle for 60 seconds to ensure the connection is still active Timeout: 10 * time.Second, // Wait 10 second for the ping ack before assuming the connection is dead } @@ -209,9 +226,11 @@ func (s *Server) Stop() error { if s.etcdCli != nil { defer s.etcdCli.Close() } + if s.tikvCli != nil { + defer s.tikvCli.Close() + } if s.grpcServer != nil { - log.Debug("Graceful stop grpc server...") - s.grpcServer.GracefulStop() + utils.GracefulStopGRPCServer(s.grpcServer) } err = s.dataCoord.Stop() @@ -240,17 +259,17 @@ func (s *Server) Run() error { // GetComponentStates gets states of datacoord and datanodes func (s *Server) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) { - return s.dataCoord.GetComponentStates(ctx) + return s.dataCoord.GetComponentStates(ctx, req) } // GetTimeTickChannel gets timetick channel func (s *Server) GetTimeTickChannel(ctx context.Context, req *internalpb.GetTimeTickChannelRequest) (*milvuspb.StringResponse, error) { - return s.dataCoord.GetTimeTickChannel(ctx) + return s.dataCoord.GetTimeTickChannel(ctx, req) } // GetStatisticsChannel gets statistics channel func (s *Server) GetStatisticsChannel(ctx context.Context, req *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error) { - return s.dataCoord.GetStatisticsChannel(ctx) + return s.dataCoord.GetStatisticsChannel(ctx, req) } 
// GetSegmentInfo gets segment information according to segment id @@ -290,7 +309,7 @@ func (s *Server) GetPartitionStatistics(ctx context.Context, req *datapb.GetPart // GetSegmentInfoChannel gets channel to which datacoord sends segment information func (s *Server) GetSegmentInfoChannel(ctx context.Context, req *datapb.GetSegmentInfoChannelRequest) (*milvuspb.StringResponse, error) { - return s.dataCoord.GetSegmentInfoChannel(ctx) + return s.dataCoord.GetSegmentInfoChannel(ctx, req) } // SaveBinlogPaths implement DataCoordServer, saves segment, collection binlog according to datanode request @@ -348,8 +367,8 @@ func (s *Server) WatchChannels(ctx context.Context, req *datapb.WatchChannelsReq return s.dataCoord.WatchChannels(ctx, req) } -// GetFlushState gets the flush state of multiple segments -func (s *Server) GetFlushState(ctx context.Context, req *milvuspb.GetFlushStateRequest) (*milvuspb.GetFlushStateResponse, error) { +// GetFlushState gets the flush state of the collection based on the provided flush ts and segment IDs. 
+func (s *Server) GetFlushState(ctx context.Context, req *datapb.GetFlushStateRequest) (*milvuspb.GetFlushStateResponse, error) { return s.dataCoord.GetFlushState(ctx, req) } diff --git a/internal/distributed/datacoord/service_test.go b/internal/distributed/datacoord/service_test.go index 53f289bf5e912..2188be1685dd0 100644 --- a/internal/distributed/datacoord/service_test.go +++ b/internal/distributed/datacoord/service_test.go @@ -18,9 +18,13 @@ package grpcdatacoord import ( "context" + "fmt" "testing" "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" + "github.com/tikv/client-go/v2/txnkv" + clientv3 "go.etcd.io/etcd/client/v3" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" @@ -28,9 +32,9 @@ import ( "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/stretchr/testify/assert" - clientv3 "go.etcd.io/etcd/client/v3" + "github.com/milvus-io/milvus/pkg/util/tikv" ) type MockDataCoord struct { @@ -104,24 +108,27 @@ func (*MockDataCoord) SetAddress(address string) { func (m *MockDataCoord) SetEtcdClient(etcdClient *clientv3.Client) { } -func (m *MockDataCoord) SetRootCoord(rootCoord types.RootCoord) { +func (m *MockDataCoord) SetTiKVClient(client *txnkv.Client) { } -func (m *MockDataCoord) SetDataNodeCreator(func(context.Context, string, int64) (types.DataNode, error)) { +func (m *MockDataCoord) SetRootCoordClient(rootCoord types.RootCoordClient) { } -func (m *MockDataCoord) SetIndexNodeCreator(func(context.Context, string, int64) (types.IndexNode, error)) { +func (m *MockDataCoord) SetDataNodeCreator(func(context.Context, string, int64) (types.DataNodeClient, error)) { } -func (m *MockDataCoord) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { +func 
(m *MockDataCoord) SetIndexNodeCreator(func(context.Context, string, int64) (types.IndexNodeClient, error)) { +} + +func (m *MockDataCoord) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) { return m.states, m.err } -func (m *MockDataCoord) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (m *MockDataCoord) GetTimeTickChannel(ctx context.Context, req *internalpb.GetTimeTickChannelRequest) (*milvuspb.StringResponse, error) { return m.strResp, m.err } -func (m *MockDataCoord) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (m *MockDataCoord) GetStatisticsChannel(ctx context.Context, req *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error) { return m.strResp, m.err } @@ -153,7 +160,7 @@ func (m *MockDataCoord) GetPartitionStatistics(ctx context.Context, req *datapb. return m.partStatResp, m.err } -func (m *MockDataCoord) GetSegmentInfoChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (m *MockDataCoord) GetSegmentInfoChannel(ctx context.Context, req *datapb.GetSegmentInfoChannelRequest) (*milvuspb.StringResponse, error) { return m.strResp, m.err } @@ -201,7 +208,7 @@ func (m *MockDataCoord) WatchChannels(ctx context.Context, req *datapb.WatchChan return m.watchChannelsResp, m.err } -func (m *MockDataCoord) GetFlushState(ctx context.Context, req *milvuspb.GetFlushStateRequest) (*milvuspb.GetFlushStateResponse, error) { +func (m *MockDataCoord) GetFlushState(ctx context.Context, req *datapb.GetFlushStateRequest) (*milvuspb.GetFlushStateResponse, error) { return m.getFlushStateResp, m.err } @@ -285,424 +292,437 @@ func (m *MockDataCoord) DropIndex(ctx context.Context, req *indexpb.DropIndexReq func Test_NewServer(t *testing.T) { paramtable.Init() - ctx := context.Background() - server := NewServer(ctx, nil) - assert.NotNil(t, server) - - t.Run("Run", func(t *testing.T) { - server.dataCoord = 
&MockDataCoord{} - //indexCoord := mocks.NewMockIndexCoord(t) - //indexCoord.EXPECT().Init().Return(nil) - //server.indexCoord = indexCoord - - err := server.Run() - assert.NoError(t, err) - }) - - t.Run("GetComponentStates", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - states: &milvuspb.ComponentStates{}, - } - states, err := server.GetComponentStates(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, states) - }) - - t.Run("GetTimeTickChannel", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - strResp: &milvuspb.StringResponse{}, - } - resp, err := server.GetTimeTickChannel(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("GetStatisticsChannel", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - strResp: &milvuspb.StringResponse{}, - } - resp, err := server.GetStatisticsChannel(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("GetSegmentInfo", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - infoResp: &datapb.GetSegmentInfoResponse{}, - } - resp, err := server.GetSegmentInfo(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("Flush", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - flushResp: &datapb.FlushResponse{}, - } - resp, err := server.Flush(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("AssignSegmentID", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - assignResp: &datapb.AssignSegmentIDResponse{}, - } - resp, err := server.AssignSegmentID(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("GetSegmentStates", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - segStateResp: &datapb.GetSegmentStatesResponse{}, - } - resp, err := server.GetSegmentStates(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("GetInsertBinlogPaths", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - binResp: 
&datapb.GetInsertBinlogPathsResponse{}, - } - resp, err := server.GetInsertBinlogPaths(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("GetCollectionStatistics", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - colStatResp: &datapb.GetCollectionStatisticsResponse{}, - } - resp, err := server.GetCollectionStatistics(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("GetPartitionStatistics", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - partStatResp: &datapb.GetPartitionStatisticsResponse{}, - } - resp, err := server.GetPartitionStatistics(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("GetSegmentInfoChannel", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - strResp: &milvuspb.StringResponse{}, - } - resp, err := server.GetSegmentInfoChannel(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("SaveBinlogPaths", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - status: &commonpb.Status{}, - } - resp, err := server.SaveBinlogPaths(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("GetRecoveryInfo", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - recoverResp: &datapb.GetRecoveryInfoResponse{}, - } - resp, err := server.GetRecoveryInfo(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("GetFlushedSegments", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - flushSegResp: &datapb.GetFlushedSegmentsResponse{}, - } - resp, err := server.GetFlushedSegments(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("ShowConfigurations", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - configResp: &internalpb.ShowConfigurationsResponse{}, - } - resp, err := server.ShowConfigurations(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("GetMetrics", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - 
metricResp: &milvuspb.GetMetricsResponse{}, - } - resp, err := server.GetMetrics(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("WatchChannels", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - watchChannelsResp: &datapb.WatchChannelsResponse{}, - } - resp, err := server.WatchChannels(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("GetFlushState", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - getFlushStateResp: &milvuspb.GetFlushStateResponse{}, - } - resp, err := server.GetFlushState(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("GetFlushAllState", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - getFlushAllStateResp: &milvuspb.GetFlushAllStateResponse{}, - } - resp, err := server.GetFlushAllState(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("DropVirtualChannel", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - dropVChanResp: &datapb.DropVirtualChannelResponse{}, - } - resp, err := server.DropVirtualChannel(ctx, nil) + parameters := []string{"tikv", "etcd"} + for _, v := range parameters { + paramtable.Get().Save(paramtable.Get().MetaStoreCfg.MetaStoreType.Key, v) + ctx := context.Background() + getTiKVClient = func(cfg *paramtable.TiKVConfig) (*txnkv.Client, error) { + return tikv.SetupLocalTxn(), nil + } + defer func() { + getTiKVClient = tikv.GetTiKVClient + }() + server := NewServer(ctx, nil) + assert.NotNil(t, server) + + t.Run("Run", func(t *testing.T) { + server.dataCoord = &MockDataCoord{} + // indexCoord := mocks.NewMockIndexCoord(t) + // indexCoord.EXPECT().Init().Return(nil) + // server.indexCoord = indexCoord + + err := server.Run() + assert.NoError(t, err) + }) + + t.Run("GetComponentStates", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + states: &milvuspb.ComponentStates{}, + } + states, err := server.GetComponentStates(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, 
states) + }) + + t.Run("GetTimeTickChannel", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + strResp: &milvuspb.StringResponse{}, + } + resp, err := server.GetTimeTickChannel(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("GetStatisticsChannel", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + strResp: &milvuspb.StringResponse{}, + } + resp, err := server.GetStatisticsChannel(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("GetSegmentInfo", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + infoResp: &datapb.GetSegmentInfoResponse{}, + } + resp, err := server.GetSegmentInfo(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("Flush", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + flushResp: &datapb.FlushResponse{}, + } + resp, err := server.Flush(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("AssignSegmentID", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + assignResp: &datapb.AssignSegmentIDResponse{}, + } + resp, err := server.AssignSegmentID(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("GetSegmentStates", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + segStateResp: &datapb.GetSegmentStatesResponse{}, + } + resp, err := server.GetSegmentStates(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("GetInsertBinlogPaths", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + binResp: &datapb.GetInsertBinlogPathsResponse{}, + } + resp, err := server.GetInsertBinlogPaths(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("GetCollectionStatistics", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + colStatResp: &datapb.GetCollectionStatisticsResponse{}, + } + resp, err := server.GetCollectionStatistics(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("GetPartitionStatistics", 
func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + partStatResp: &datapb.GetPartitionStatisticsResponse{}, + } + resp, err := server.GetPartitionStatistics(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("GetSegmentInfoChannel", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + strResp: &milvuspb.StringResponse{}, + } + resp, err := server.GetSegmentInfoChannel(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("SaveBinlogPaths", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + status: &commonpb.Status{}, + } + resp, err := server.SaveBinlogPaths(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("GetRecoveryInfo", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + recoverResp: &datapb.GetRecoveryInfoResponse{}, + } + resp, err := server.GetRecoveryInfo(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("GetFlushedSegments", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + flushSegResp: &datapb.GetFlushedSegmentsResponse{}, + } + resp, err := server.GetFlushedSegments(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("ShowConfigurations", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + configResp: &internalpb.ShowConfigurationsResponse{}, + } + resp, err := server.ShowConfigurations(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("GetMetrics", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + metricResp: &milvuspb.GetMetricsResponse{}, + } + resp, err := server.GetMetrics(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("WatchChannels", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + watchChannelsResp: &datapb.WatchChannelsResponse{}, + } + resp, err := server.WatchChannels(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("GetFlushState", func(t *testing.T) { + server.dataCoord = 
&MockDataCoord{ + getFlushStateResp: &milvuspb.GetFlushStateResponse{}, + } + resp, err := server.GetFlushState(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("GetFlushAllState", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + getFlushAllStateResp: &milvuspb.GetFlushAllStateResponse{}, + } + resp, err := server.GetFlushAllState(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("DropVirtualChannel", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + dropVChanResp: &datapb.DropVirtualChannelResponse{}, + } + resp, err := server.DropVirtualChannel(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("ManualCompaction", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + manualCompactionResp: &milvuspb.ManualCompactionResponse{}, + } + resp, err := server.ManualCompaction(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("GetCompactionState", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + compactionStateResp: &milvuspb.GetCompactionStateResponse{}, + } + resp, err := server.GetCompactionState(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("GetCompactionStateWithPlans", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + compactionPlansResp: &milvuspb.GetCompactionPlansResponse{}, + } + resp, err := server.GetCompactionStateWithPlans(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("set segment state", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + setSegmentStateResp: &datapb.SetSegmentStateResponse{}, + } + resp, err := server.SetSegmentState(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("import", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + importResp: &datapb.ImportTaskResponse{ + Status: &commonpb.Status{}, + }, + } + resp, err := server.Import(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) 
+ + t.Run("update seg stat", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + updateSegStatResp: merr.Success(), + } + resp, err := server.UpdateSegmentStatistics(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("UpdateChannelCheckpoint", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + updateChanPos: merr.Success(), + } + resp, err := server.UpdateChannelCheckpoint(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("save import segment", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + addSegmentResp: merr.Success(), + } + resp, err := server.SaveImportSegment(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("unset isImporting state", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + unsetIsImportingStateResp: merr.Success(), + } + resp, err := server.UnsetIsImportingState(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("mark segments dropped", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + markSegmentsDroppedResp: merr.Success(), + } + resp, err := server.MarkSegmentsDropped(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("broadcast altered collection", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + broadCastResp: &commonpb.Status{}, + } + resp, err := server.BroadcastAlteredCollection(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("CheckHealth", func(t *testing.T) { + server.dataCoord = &MockDataCoord{} + ret, err := server.CheckHealth(ctx, nil) + assert.NoError(t, err) + assert.Equal(t, true, ret.IsHealthy) + }) + + t.Run("CreateIndex", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + createIndexResp: &commonpb.Status{}, + } + ret, err := server.CreateIndex(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, ret) + }) + + t.Run("DescribeIndex", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + describeIndexResp: 
&indexpb.DescribeIndexResponse{}, + } + ret, err := server.DescribeIndex(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, ret) + }) + + t.Run("GetIndexStatistics", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + getIndexStatisticsResp: &indexpb.GetIndexStatisticsResponse{}, + } + ret, err := server.GetIndexStatistics(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, ret) + }) + + t.Run("DropIndex", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + dropIndexResp: &commonpb.Status{}, + } + ret, err := server.DropIndex(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, ret) + }) + + t.Run("GetIndexState", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + getIndexStateResp: &indexpb.GetIndexStateResponse{}, + } + ret, err := server.GetIndexState(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, ret) + }) + + t.Run("GetIndexBuildProgress", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + getIndexBuildProgressResp: &indexpb.GetIndexBuildProgressResponse{}, + } + ret, err := server.GetIndexBuildProgress(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, ret) + }) + + t.Run("GetSegmentIndexState", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + getSegmentIndexStateResp: &indexpb.GetSegmentIndexStateResponse{}, + } + ret, err := server.GetSegmentIndexState(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, ret) + }) + + t.Run("GetIndexInfos", func(t *testing.T) { + server.dataCoord = &MockDataCoord{ + getIndexInfosResp: &indexpb.GetIndexInfoResponse{}, + } + ret, err := server.GetIndexInfos(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, ret) + }) + + err := server.Stop() assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("ManualCompaction", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - manualCompactionResp: &milvuspb.ManualCompactionResponse{}, - } - resp, err := server.ManualCompaction(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - 
t.Run("GetCompactionState", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - compactionStateResp: &milvuspb.GetCompactionStateResponse{}, - } - resp, err := server.GetCompactionState(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("GetCompactionStateWithPlans", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - compactionPlansResp: &milvuspb.GetCompactionPlansResponse{}, - } - resp, err := server.GetCompactionStateWithPlans(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("set segment state", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - setSegmentStateResp: &datapb.SetSegmentStateResponse{}, - } - resp, err := server.SetSegmentState(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("import", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - importResp: &datapb.ImportTaskResponse{ - Status: &commonpb.Status{}, - }, - } - resp, err := server.Import(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("update seg stat", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - updateSegStatResp: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, - } - resp, err := server.UpdateSegmentStatistics(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("UpdateChannelCheckpoint", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - updateChanPos: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, - } - resp, err := server.UpdateChannelCheckpoint(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("save import segment", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - addSegmentResp: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, - } - resp, err := server.SaveImportSegment(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("unset isImporting state", func(t *testing.T) { - server.dataCoord = 
&MockDataCoord{ - unsetIsImportingStateResp: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, - } - resp, err := server.UnsetIsImportingState(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("mark segments dropped", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - markSegmentsDroppedResp: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, - } - resp, err := server.MarkSegmentsDropped(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("broadcast altered collection", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - broadCastResp: &commonpb.Status{}, - } - resp, err := server.BroadcastAlteredCollection(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("CheckHealth", func(t *testing.T) { - server.dataCoord = &MockDataCoord{} - ret, err := server.CheckHealth(ctx, nil) - assert.NoError(t, err) - assert.Equal(t, true, ret.IsHealthy) - }) - - t.Run("CreateIndex", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - createIndexResp: &commonpb.Status{}, - } - ret, err := server.CreateIndex(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, ret) - }) - - t.Run("DescribeIndex", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - describeIndexResp: &indexpb.DescribeIndexResponse{}, - } - ret, err := server.DescribeIndex(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, ret) - }) - - t.Run("GetIndexStatistics", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - getIndexStatisticsResp: &indexpb.GetIndexStatisticsResponse{}, - } - ret, err := server.GetIndexStatistics(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, ret) - }) - - t.Run("DropIndex", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - dropIndexResp: &commonpb.Status{}, - } - ret, err := server.DropIndex(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, ret) - }) - - t.Run("GetIndexState", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ 
- getIndexStateResp: &indexpb.GetIndexStateResponse{}, - } - ret, err := server.GetIndexState(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, ret) - }) - - t.Run("GetIndexBuildProgress", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - getIndexBuildProgressResp: &indexpb.GetIndexBuildProgressResponse{}, - } - ret, err := server.GetIndexBuildProgress(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, ret) - }) - - t.Run("GetSegmentIndexState", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - getSegmentIndexStateResp: &indexpb.GetSegmentIndexStateResponse{}, - } - ret, err := server.GetSegmentIndexState(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, ret) - }) - - t.Run("GetIndexInfos", func(t *testing.T) { - server.dataCoord = &MockDataCoord{ - getIndexInfosResp: &indexpb.GetIndexInfoResponse{}, - } - ret, err := server.GetIndexInfos(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, ret) - }) - - err := server.Stop() - assert.NoError(t, err) + } } func Test_Run(t *testing.T) { - ctx := context.Background() - server := NewServer(ctx, nil) - assert.NotNil(t, server) - - server.dataCoord = &MockDataCoord{ - regErr: errors.New("error"), - } - - err := server.Run() - assert.Error(t, err) - - server.dataCoord = &MockDataCoord{ - startErr: errors.New("error"), - } - - err = server.Run() - assert.Error(t, err) - - server.dataCoord = &MockDataCoord{ - initErr: errors.New("error"), - } - - err = server.Run() - assert.Error(t, err) - - server.dataCoord = &MockDataCoord{ - stopErr: errors.New("error"), + paramtable.Init() + parameters := []string{"tikv", "etcd"} + for _, v := range parameters { + t.Run(fmt.Sprintf("Run server with %s as metadata storage", v), func(t *testing.T) { + paramtable.Get().Save(paramtable.Get().MetaStoreCfg.MetaStoreType.Key, v) + ctx := context.Background() + getTiKVClient = func(cfg *paramtable.TiKVConfig) (*txnkv.Client, error) { + return tikv.SetupLocalTxn(), nil + } + defer func() { + getTiKVClient = 
tikv.GetTiKVClient + }() + server := NewServer(ctx, nil) + assert.NotNil(t, server) + + server.dataCoord = &MockDataCoord{ + regErr: errors.New("error"), + } + + err := server.Run() + assert.Error(t, err) + + server.dataCoord = &MockDataCoord{ + startErr: errors.New("error"), + } + + err = server.Run() + assert.Error(t, err) + + server.dataCoord = &MockDataCoord{ + initErr: errors.New("error"), + } + + err = server.Run() + assert.Error(t, err) + + server.dataCoord = &MockDataCoord{ + stopErr: errors.New("error"), + } + + err = server.Stop() + assert.Error(t, err) + }) } - - err = server.Stop() - assert.Error(t, err) } diff --git a/internal/distributed/datanode/client/client.go b/internal/distributed/datanode/client/client.go index f9f1197e850a3..9db4d3a7658ef 100644 --- a/internal/distributed/datanode/client/client.go +++ b/internal/distributed/datanode/client/client.go @@ -20,6 +20,8 @@ import ( "context" "fmt" + "google.golang.org/grpc" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/proto/datapb" @@ -29,7 +31,6 @@ import ( "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" - "google.golang.org/grpc" ) var Params *paramtable.ComponentParam = paramtable.Get() @@ -50,7 +51,8 @@ func NewClient(ctx context.Context, addr string, nodeID int64) (*Client, error) addr: addr, grpcClient: grpcclient.NewClientBase[datapb.DataNodeClient](config, "milvus.proto.data.DataNode"), } - client.grpcClient.SetRole(typeutil.DataNodeRole) + // node shall specify node id + client.grpcClient.SetRole(fmt.Sprintf("%s-%d", typeutil.DataNodeRole, nodeID)) client.grpcClient.SetGetAddrFunc(client.getAddr) client.grpcClient.SetNewGrpcClientFunc(client.newGrpcClient) client.grpcClient.SetNodeID(nodeID) @@ -58,28 +60,12 @@ func NewClient(ctx context.Context, addr string, nodeID int64) (*Client, error) 
return client, nil } -// Init initializes the client. -func (c *Client) Init() error { - return nil -} - -// Start starts the client. -// Currently, it does nothing. -func (c *Client) Start() error { - return nil -} - -// Stop stops the client. +// Close stops the client. // Currently, it closes the grpc connection with the DataNode. -func (c *Client) Stop() error { +func (c *Client) Close() error { return c.grpcClient.Close() } -// Register does nothing. -func (c *Client) Register() error { - return nil -} - func (c *Client) newGrpcClient(cc *grpc.ClientConn) datapb.DataNodeClient { return datapb.NewDataNodeClient(cc) } @@ -102,7 +88,7 @@ func wrapGrpcCall[T any](ctx context.Context, c *Client, call func(grpcClient da } // GetComponentStates returns ComponentStates -func (c *Client) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { +func (c *Client) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { return wrapGrpcCall(ctx, c, func(client datapb.DataNodeClient) (*milvuspb.ComponentStates, error) { return client.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) }) @@ -110,7 +96,7 @@ func (c *Client) GetComponentStates(ctx context.Context) (*milvuspb.ComponentSta // GetStatisticsChannel return the statistics channel in string // Statistics channel contains statistics infos of query nodes, such as segment infos, memory infos -func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (c *Client) GetStatisticsChannel(ctx context.Context, req *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { return wrapGrpcCall(ctx, c, func(client datapb.DataNodeClient) (*milvuspb.StringResponse, error) { return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) }) @@ -118,7 +104,7 @@ func (c *Client) GetStatisticsChannel(ctx 
context.Context) (*milvuspb.StringResp // Deprecated // WatchDmChannels create consumers on dmChannels to reveive Incremental data -func (c *Client) WatchDmChannels(ctx context.Context, req *datapb.WatchDmChannelsRequest) (*commonpb.Status, error) { +func (c *Client) WatchDmChannels(ctx context.Context, req *datapb.WatchDmChannelsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -140,7 +126,7 @@ func (c *Client) WatchDmChannels(ctx context.Context, req *datapb.WatchDmChannel // Return Success code in status and trigers background flush: // // Log an info log if a segment is under flushing -func (c *Client) FlushSegments(ctx context.Context, req *datapb.FlushSegmentsRequest) (*commonpb.Status, error) { +func (c *Client) FlushSegments(ctx context.Context, req *datapb.FlushSegmentsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -151,7 +137,7 @@ func (c *Client) FlushSegments(ctx context.Context, req *datapb.FlushSegmentsReq } // ShowConfigurations gets specified configurations para of DataNode -func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { +func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest, opts ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -162,7 +148,7 @@ func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowCon } // GetMetrics returns metrics -func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { +func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { req = 
typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -173,13 +159,13 @@ func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest } // Compaction return compaction by given plan -func (c *Client) Compaction(ctx context.Context, req *datapb.CompactionPlan) (*commonpb.Status, error) { +func (c *Client) Compaction(ctx context.Context, req *datapb.CompactionPlan, opts ...grpc.CallOption) (*commonpb.Status, error) { return wrapGrpcCall(ctx, c, func(client datapb.DataNodeClient) (*commonpb.Status, error) { return client.Compaction(ctx, req) }) } -func (c *Client) GetCompactionState(ctx context.Context, req *datapb.CompactionStateRequest) (*datapb.CompactionStateResponse, error) { +func (c *Client) GetCompactionState(ctx context.Context, req *datapb.CompactionStateRequest, opts ...grpc.CallOption) (*datapb.CompactionStateResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -190,7 +176,7 @@ func (c *Client) GetCompactionState(ctx context.Context, req *datapb.CompactionS } // Import data files(json, numpy, etc.) 
on MinIO/S3 storage, read and parse them into sealed segments -func (c *Client) Import(ctx context.Context, req *datapb.ImportTaskRequest) (*commonpb.Status, error) { +func (c *Client) Import(ctx context.Context, req *datapb.ImportTaskRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -200,7 +186,7 @@ func (c *Client) Import(ctx context.Context, req *datapb.ImportTaskRequest) (*co }) } -func (c *Client) ResendSegmentStats(ctx context.Context, req *datapb.ResendSegmentStatsRequest) (*datapb.ResendSegmentStatsResponse, error) { +func (c *Client) ResendSegmentStats(ctx context.Context, req *datapb.ResendSegmentStatsRequest, opts ...grpc.CallOption) (*datapb.ResendSegmentStatsResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -211,7 +197,7 @@ func (c *Client) ResendSegmentStats(ctx context.Context, req *datapb.ResendSegme } // AddImportSegment is the DataNode client side code for AddImportSegment call. -func (c *Client) AddImportSegment(ctx context.Context, req *datapb.AddImportSegmentRequest) (*datapb.AddImportSegmentResponse, error) { +func (c *Client) AddImportSegment(ctx context.Context, req *datapb.AddImportSegmentRequest, opts ...grpc.CallOption) (*datapb.AddImportSegmentResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -222,8 +208,27 @@ func (c *Client) AddImportSegment(ctx context.Context, req *datapb.AddImportSegm } // SyncSegments is the DataNode client side code for SyncSegments call. 
-func (c *Client) SyncSegments(ctx context.Context, req *datapb.SyncSegmentsRequest) (*commonpb.Status, error) { +func (c *Client) SyncSegments(ctx context.Context, req *datapb.SyncSegmentsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return wrapGrpcCall(ctx, c, func(client datapb.DataNodeClient) (*commonpb.Status, error) { return client.SyncSegments(ctx, req) }) } + +// FlushChannels notifies DataNode to sync all the segments belongs to the target channels. +func (c *Client) FlushChannels(ctx context.Context, req *datapb.FlushChannelsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return wrapGrpcCall(ctx, c, func(client datapb.DataNodeClient) (*commonpb.Status, error) { + return client.FlushChannels(ctx, req) + }) +} + +func (c *Client) NotifyChannelOperation(ctx context.Context, req *datapb.ChannelOperationsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return wrapGrpcCall(ctx, c, func(client datapb.DataNodeClient) (*commonpb.Status, error) { + return client.NotifyChannelOperation(ctx, req) + }) +} + +func (c *Client) CheckChannelOperationProgress(ctx context.Context, req *datapb.ChannelWatchInfo, opts ...grpc.CallOption) (*datapb.ChannelOperationProgressResponse, error) { + return wrapGrpcCall(ctx, c, func(client datapb.DataNodeClient) (*datapb.ChannelOperationProgressResponse, error) { + return client.CheckChannelOperationProgress(ctx, req) + }) +} diff --git a/internal/distributed/datanode/client/client_test.go b/internal/distributed/datanode/client/client_test.go index 4b50a1f09d08d..a8364c5c7187b 100644 --- a/internal/distributed/datanode/client/client_test.go +++ b/internal/distributed/datanode/client/client_test.go @@ -21,13 +21,12 @@ import ( "testing" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus/internal/util/mock" - "github.com/milvus-io/milvus/pkg/util/paramtable" - - "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/stretchr/testify/assert" "google.golang.org/grpc" - 
"github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/util/mock" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) func Test_NewClient(t *testing.T) { @@ -41,15 +40,6 @@ func Test_NewClient(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, client) - err = client.Init() - assert.NoError(t, err) - - err = client.Start() - assert.NoError(t, err) - - err = client.Register() - assert.NoError(t, err) - checkFunc := func(retNotNil bool) { retCheck := func(notNil bool, ret interface{}, err error) { if notNil { @@ -61,10 +51,10 @@ func Test_NewClient(t *testing.T) { } } - r1, err := client.GetComponentStates(ctx) + r1, err := client.GetComponentStates(ctx, nil) retCheck(retNotNil, r1, err) - r2, err := client.GetStatisticsChannel(ctx) + r2, err := client.GetStatisticsChannel(ctx, nil) retCheck(retNotNil, r2, err) r3, err := client.WatchDmChannels(ctx, nil) @@ -93,6 +83,12 @@ func Test_NewClient(t *testing.T) { r11, err := client.GetCompactionState(ctx, nil) retCheck(retNotNil, r11, err) + + r12, err := client.NotifyChannelOperation(ctx, nil) + retCheck(retNotNil, r12, err) + + r13, err := client.CheckChannelOperationProgress(ctx, nil) + retCheck(retNotNil, r13, err) } client.grpcClient = &mock.GRPCClientBase[datapb.DataNodeClient]{ @@ -129,6 +125,6 @@ func Test_NewClient(t *testing.T) { checkFunc(true) - err = client.Stop() + err = client.Close() assert.NoError(t, err) } diff --git a/internal/distributed/datanode/service.go b/internal/distributed/datanode/service.go index 62d4ce401a427..3d3dba2537d70 100644 --- a/internal/distributed/datanode/service.go +++ b/internal/distributed/datanode/service.go @@ -24,12 +24,6 @@ import ( "sync" "time" - "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus/internal/util/componentutil" - "github.com/milvus-io/milvus/internal/util/dependency" - "github.com/milvus-io/milvus/pkg/tracer" - "github.com/milvus-io/milvus/pkg/util/interceptor" - 
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" clientv3 "go.etcd.io/etcd/client/v3" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" @@ -43,13 +37,19 @@ import ( dn "github.com/milvus-io/milvus/internal/datanode" dcc "github.com/milvus-io/milvus/internal/distributed/datacoord/client" rcc "github.com/milvus-io/milvus/internal/distributed/rootcoord/client" + "github.com/milvus-io/milvus/internal/distributed/utils" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/componentutil" + "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/tracer" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/interceptor" "github.com/milvus-io/milvus/pkg/util/logutil" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/retry" ) @@ -69,22 +69,22 @@ type Server struct { rootCoord types.RootCoord dataCoord types.DataCoord - newRootCoordClient func(string, *clientv3.Client) (types.RootCoord, error) - newDataCoordClient func(string, *clientv3.Client) (types.DataCoord, error) + newRootCoordClient func(string, *clientv3.Client) (types.RootCoordClient, error) + newDataCoordClient func(string, *clientv3.Client) (types.DataCoordClient, error) } // NewServer new DataNode grpc server func NewServer(ctx context.Context, factory dependency.Factory) (*Server, error) { ctx1, cancel := context.WithCancel(ctx) - var s = &Server{ + s := &Server{ ctx: ctx1, cancel: cancel, factory: factory, grpcErrChan: make(chan error), - newRootCoordClient: func(etcdMetaRoot string, client *clientv3.Client) (types.RootCoord, error) { + newRootCoordClient: func(etcdMetaRoot string, client 
*clientv3.Client) (types.RootCoordClient, error) { return rcc.NewClient(ctx1, etcdMetaRoot, client) }, - newDataCoordClient: func(etcdMetaRoot string, client *clientv3.Client) (types.DataCoord, error) { + newDataCoordClient: func(etcdMetaRoot string, client *clientv3.Client) (types.DataCoordClient, error) { return dcc.NewClient(ctx1, etcdMetaRoot, client) }, } @@ -107,12 +107,12 @@ func (s *Server) startGrpc() error { func (s *Server) startGrpcLoop(grpcPort int) { defer s.wg.Done() Params := ¶mtable.Get().DataNodeGrpcServerCfg - var kaep = keepalive.EnforcementPolicy{ + kaep := keepalive.EnforcementPolicy{ MinTime: 5 * time.Second, // If a client pings more than once every 5 seconds, terminate the connection PermitWithoutStream: true, // Allow pings even when there are no active streams } - var kasp = keepalive.ServerParameters{ + kasp := keepalive.ServerParameters{ Time: 60 * time.Second, // Ping the client if it is idle for 60 seconds to ensure the connection is still active Timeout: 10 * time.Second, // Wait 10 second for the ping ack before assuming the connection is dead } @@ -124,7 +124,6 @@ func (s *Server) startGrpcLoop(grpcPort int) { lis, err = net.Listen("tcp", addr) return err }, retry.Attempts(10)) - if err != nil { log.Error("DataNode GrpcServer:failed to listen", zap.Error(err)) s.grpcErrChan <- err @@ -169,19 +168,18 @@ func (s *Server) startGrpcLoop(grpcPort int) { log.Warn("DataNode failed to start gRPC") s.grpcErrChan <- err } - } func (s *Server) SetEtcdClient(client *clientv3.Client) { s.datanode.SetEtcdClient(client) } -func (s *Server) SetRootCoordInterface(ms types.RootCoord) error { - return s.datanode.SetRootCoord(ms) +func (s *Server) SetRootCoordInterface(ms types.RootCoordClient) error { + return s.datanode.SetRootCoordClient(ms) } -func (s *Server) SetDataCoordInterface(ds types.DataCoord) error { - return s.datanode.SetDataCoord(ds) +func (s *Server) SetDataCoordInterface(ds types.DataCoordClient) error { + return 
s.datanode.SetDataCoordClient(ds) } // Run initializes and starts Datanode's grpc service. @@ -208,22 +206,7 @@ func (s *Server) Stop() error { defer s.etcdCli.Close() } if s.grpcServer != nil { - log.Debug("Graceful stop grpc server...") - // make graceful stop has a timeout - stopped := make(chan struct{}) - go func() { - s.grpcServer.GracefulStop() - close(stopped) - }() - - t := time.NewTimer(10 * time.Second) - select { - case <-t.C: - // hard stop since grace timeout - s.grpcServer.Stop() - case <-stopped: - t.Stop() - } + utils.GracefulStopGRPCServer(s.grpcServer) } err := s.datanode.Stop() @@ -274,14 +257,7 @@ func (s *Server) init() error { log.Error("failed to create new RootCoord client", zap.Error(err)) panic(err) } - if err = rootCoordClient.Init(); err != nil { - log.Error("failed to init RootCoord client", zap.Error(err)) - panic(err) - } - if err = rootCoordClient.Start(); err != nil { - log.Error("failed to start RootCoord client", zap.Error(err)) - panic(err) - } + if err = componentutil.WaitForComponentHealthy(ctx, rootCoordClient, "RootCoord", 1000000, time.Millisecond*200); err != nil { log.Error("failed to wait for RootCoord client to be ready", zap.Error(err)) panic(err) @@ -300,14 +276,7 @@ func (s *Server) init() error { log.Error("failed to create new DataCoord client", zap.Error(err)) panic(err) } - if err = dataCoordClient.Init(); err != nil { - log.Error("failed to init DataCoord client", zap.Error(err)) - panic(err) - } - if err = dataCoordClient.Start(); err != nil { - log.Error("failed to start DataCoord client", zap.Error(err)) - panic(err) - } + if err = componentutil.WaitForComponentInitOrHealthy(ctx, dataCoordClient, "DataCoord", 1000000, time.Millisecond*200); err != nil { log.Error("failed to wait for DataCoord client to be ready", zap.Error(err)) panic(err) @@ -343,12 +312,12 @@ func (s *Server) start() error { // GetComponentStates gets the component states of Datanode func (s *Server) GetComponentStates(ctx context.Context, 
req *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) { - return s.datanode.GetComponentStates(ctx) + return s.datanode.GetComponentStates(ctx, req) } // GetStatisticsChannel gets the statistics channel of Datanode. func (s *Server) GetStatisticsChannel(ctx context.Context, req *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error) { - return s.datanode.GetStatisticsChannel(ctx) + return s.datanode.GetStatisticsChannel(ctx, req) } // Deprecated @@ -357,11 +326,8 @@ func (s *Server) WatchDmChannels(ctx context.Context, req *datapb.WatchDmChannel } func (s *Server) FlushSegments(ctx context.Context, req *datapb.FlushSegmentsRequest) (*commonpb.Status, error) { - if s.datanode.GetStateCode() != commonpb.StateCode_Healthy { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "DataNode isn't healthy.", - }, errors.New("DataNode is not ready yet") + if err := merr.CheckHealthy(s.datanode.GetStateCode()); err != nil { + return merr.Status(err), nil } return s.datanode.FlushSegments(ctx, req) } @@ -400,3 +366,15 @@ func (s *Server) AddImportSegment(ctx context.Context, request *datapb.AddImport func (s *Server) SyncSegments(ctx context.Context, request *datapb.SyncSegmentsRequest) (*commonpb.Status, error) { return s.datanode.SyncSegments(ctx, request) } + +func (s *Server) FlushChannels(ctx context.Context, req *datapb.FlushChannelsRequest) (*commonpb.Status, error) { + return s.datanode.FlushChannels(ctx, req) +} + +func (s *Server) NotifyChannelOperation(ctx context.Context, req *datapb.ChannelOperationsRequest) (*commonpb.Status, error) { + return s.datanode.NotifyChannelOperation(ctx, req) +} + +func (s *Server) CheckChannelOperationProgress(ctx context.Context, req *datapb.ChannelWatchInfo) (*datapb.ChannelOperationProgressResponse, error) { + return s.datanode.CheckChannelOperationProgress(ctx, req) +} diff --git a/internal/distributed/datanode/service_test.go 
b/internal/distributed/datanode/service_test.go index ff8248e587ed0..e6550c84c23a7 100644 --- a/internal/distributed/datanode/service_test.go +++ b/internal/distributed/datanode/service_test.go @@ -22,16 +22,18 @@ import ( "testing" "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" + clientv3 "go.etcd.io/etcd/client/v3" + "google.golang.org/grpc" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" - "github.com/stretchr/testify/assert" - clientv3 "go.etcd.io/etcd/client/v3" ) // ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -89,19 +91,19 @@ func (m *MockDataNode) GetAddress() string { return "" } -func (m *MockDataNode) SetRootCoord(rc types.RootCoord) error { +func (m *MockDataNode) SetRootCoordClient(rc types.RootCoordClient) error { return m.err } -func (m *MockDataNode) SetDataCoord(dc types.DataCoord) error { +func (m *MockDataNode) SetDataCoordClient(dc types.DataCoordClient) error { return m.err } -func (m *MockDataNode) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { +func (m *MockDataNode) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) { return m.states, m.err } -func (m *MockDataNode) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (m *MockDataNode) GetStatisticsChannel(ctx context.Context, req *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error) { return m.strResp, m.err } @@ -148,25 +150,29 @@ func (m *MockDataNode) SyncSegments(ctx context.Context, req 
*datapb.SyncSegment return m.status, m.err } -// ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -type mockDataCoord struct { - types.DataCoord +func (m *MockDataNode) FlushChannels(ctx context.Context, req *datapb.FlushChannelsRequest) (*commonpb.Status, error) { + return m.status, m.err +} + +func (m *MockDataNode) NotifyChannelOperation(ctx context.Context, req *datapb.ChannelOperationsRequest) (*commonpb.Status, error) { + return m.status, m.err } -func (m *mockDataCoord) Init() error { - return nil +func (m *MockDataNode) CheckChannelOperationProgress(ctx context.Context, req *datapb.ChannelWatchInfo) (*datapb.ChannelOperationProgressResponse, error) { + return &datapb.ChannelOperationProgressResponse{}, m.err } -func (m *mockDataCoord) Start() error { - return nil + +// ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +type mockDataCoord struct { + types.DataCoordClient } -func (m *mockDataCoord) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { + +func (m *mockDataCoord) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { return &milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{ StateCode: commonpb.StateCode_Healthy, }, - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), SubcomponentStates: []*milvuspb.ComponentInfo{ { StateCode: commonpb.StateCode_Healthy, @@ -174,29 +180,22 @@ func (m *mockDataCoord) GetComponentStates(ctx context.Context) (*milvuspb.Compo }, }, nil } + func (m *mockDataCoord) Stop() error { return fmt.Errorf("stop error") } // ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////// type mockRootCoord struct { - types.RootCoord + types.RootCoordClient } -func (m 
*mockRootCoord) Init() error { - return nil -} -func (m *mockRootCoord) Start() error { - return nil -} -func (m *mockRootCoord) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { +func (m *mockRootCoord) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { return &milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{ StateCode: commonpb.StateCode_Healthy, }, - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), SubcomponentStates: []*milvuspb.ComponentInfo{ { StateCode: commonpb.StateCode_Healthy, @@ -204,6 +203,7 @@ func (m *mockRootCoord) GetComponentStates(ctx context.Context) (*milvuspb.Compo }, }, nil } + func (m *mockRootCoord) Stop() error { return fmt.Errorf("stop error") } @@ -216,11 +216,11 @@ func Test_NewServer(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, server) - server.newRootCoordClient = func(string, *clientv3.Client) (types.RootCoord, error) { + server.newRootCoordClient = func(string, *clientv3.Client) (types.RootCoordClient, error) { return &mockRootCoord{}, nil } - server.newDataCoordClient = func(string, *clientv3.Client) (types.DataCoord, error) { + server.newDataCoordClient = func(string, *clientv3.Client) (types.DataCoordClient, error) { return &mockDataCoord{}, nil } @@ -262,7 +262,7 @@ func Test_NewServer(t *testing.T) { status: &commonpb.Status{}, } states, err := server.FlushSegments(ctx, nil) - assert.Error(t, err) + assert.NoError(t, err) assert.NotNil(t, states) }) @@ -315,9 +315,7 @@ func Test_NewServer(t *testing.T) { server.datanode = &MockDataNode{ status: &commonpb.Status{}, addImportSegmentResp: &datapb.AddImportSegmentResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, } resp, err := server.AddImportSegment(ctx, nil) @@ -325,6 +323,24 @@ func Test_NewServer(t *testing.T) { assert.NotNil(t, 
resp) }) + t.Run("NotifyChannelOperation", func(t *testing.T) { + server.datanode = &MockDataNode{ + status: &commonpb.Status{}, + } + resp, err := server.NotifyChannelOperation(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("CheckChannelOperationProgress", func(t *testing.T) { + server.datanode = &MockDataNode{ + status: &commonpb.Status{}, + } + resp, err := server.CheckChannelOperationProgress(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + err = server.Stop() assert.NoError(t, err) } @@ -339,11 +355,11 @@ func Test_Run(t *testing.T) { regErr: errors.New("error"), } - server.newRootCoordClient = func(string, *clientv3.Client) (types.RootCoord, error) { + server.newRootCoordClient = func(string, *clientv3.Client) (types.RootCoordClient, error) { return &mockRootCoord{}, nil } - server.newDataCoordClient = func(string, *clientv3.Client) (types.DataCoord, error) { + server.newDataCoordClient = func(string, *clientv3.Client) (types.DataCoordClient, error) { return &mockDataCoord{}, nil } diff --git a/internal/distributed/indexnode/client/client.go b/internal/distributed/indexnode/client/client.go index 999f6dddc9fa7..af45015b7f597 100644 --- a/internal/distributed/indexnode/client/client.go +++ b/internal/distributed/indexnode/client/client.go @@ -20,13 +20,13 @@ import ( "context" "fmt" - "github.com/milvus-io/milvus/internal/util/grpcclient" "google.golang.org/grpc" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/util/grpcclient" "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -51,7 +51,8 @@ func NewClient(ctx context.Context, addr string, nodeID int64, encryption bool) addr: addr, grpcClient: 
grpcclient.NewClientBase[indexpb.IndexNodeClient](config, "milvus.proto.index.IndexNode"), } - client.grpcClient.SetRole(typeutil.IndexNodeRole) + // node shall specify node id + client.grpcClient.SetRole(fmt.Sprintf("%s-%d", typeutil.IndexNodeRole, nodeID)) client.grpcClient.SetGetAddrFunc(client.getAddr) client.grpcClient.SetNewGrpcClientFunc(client.newGrpcClient) client.grpcClient.SetNodeID(nodeID) @@ -61,26 +62,11 @@ func NewClient(ctx context.Context, addr string, nodeID int64, encryption bool) return client, nil } -// Init initializes IndexNode's grpc client. -func (c *Client) Init() error { - return nil -} - -// Start starts IndexNode's client service. But it does nothing here. -func (c *Client) Start() error { - return nil -} - -// Stop stops IndexNode's grpc client. -func (c *Client) Stop() error { +// Close stops IndexNode's grpc client. +func (c *Client) Close() error { return c.grpcClient.Close() } -// Register dummy -func (c *Client) Register() error { - return nil -} - func (c *Client) newGrpcClient(cc *grpc.ClientConn) indexpb.IndexNodeClient { return indexpb.NewIndexNodeClient(cc) } @@ -103,48 +89,48 @@ func wrapGrpcCall[T any](ctx context.Context, c *Client, call func(indexClient i } // GetComponentStates gets the component states of IndexNode. 
-func (c *Client) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { +func (c *Client) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { return wrapGrpcCall(ctx, c, func(client indexpb.IndexNodeClient) (*milvuspb.ComponentStates, error) { return client.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) }) } -func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (c *Client) GetStatisticsChannel(ctx context.Context, req *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { return wrapGrpcCall(ctx, c, func(client indexpb.IndexNodeClient) (*milvuspb.StringResponse, error) { return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) }) } // CreateJob sends the build index request to IndexNode. -func (c *Client) CreateJob(ctx context.Context, req *indexpb.CreateJobRequest) (*commonpb.Status, error) { +func (c *Client) CreateJob(ctx context.Context, req *indexpb.CreateJobRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return wrapGrpcCall(ctx, c, func(client indexpb.IndexNodeClient) (*commonpb.Status, error) { return client.CreateJob(ctx, req) }) } // QueryJobs query the task info of the index task. -func (c *Client) QueryJobs(ctx context.Context, req *indexpb.QueryJobsRequest) (*indexpb.QueryJobsResponse, error) { +func (c *Client) QueryJobs(ctx context.Context, req *indexpb.QueryJobsRequest, opts ...grpc.CallOption) (*indexpb.QueryJobsResponse, error) { return wrapGrpcCall(ctx, c, func(client indexpb.IndexNodeClient) (*indexpb.QueryJobsResponse, error) { return client.QueryJobs(ctx, req) }) } // DropJobs query the task info of the index task. 
-func (c *Client) DropJobs(ctx context.Context, req *indexpb.DropJobsRequest) (*commonpb.Status, error) { +func (c *Client) DropJobs(ctx context.Context, req *indexpb.DropJobsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return wrapGrpcCall(ctx, c, func(client indexpb.IndexNodeClient) (*commonpb.Status, error) { return client.DropJobs(ctx, req) }) } // GetJobStats query the task info of the index task. -func (c *Client) GetJobStats(ctx context.Context, req *indexpb.GetJobStatsRequest) (*indexpb.GetJobStatsResponse, error) { +func (c *Client) GetJobStats(ctx context.Context, req *indexpb.GetJobStatsRequest, opts ...grpc.CallOption) (*indexpb.GetJobStatsResponse, error) { return wrapGrpcCall(ctx, c, func(client indexpb.IndexNodeClient) (*indexpb.GetJobStatsResponse, error) { return client.GetJobStats(ctx, req) }) } // ShowConfigurations gets specified configurations para of IndexNode -func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { +func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest, opts ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -155,7 +141,7 @@ func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowCon } // GetMetrics gets the metrics info of IndexNode. 
-func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { +func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), diff --git a/internal/distributed/indexnode/client/client_test.go b/internal/distributed/indexnode/client/client_test.go index c0141008e1b20..07dc65ce1889c 100644 --- a/internal/distributed/indexnode/client/client_test.go +++ b/internal/distributed/indexnode/client/client_test.go @@ -21,18 +21,12 @@ import ( "testing" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus/internal/util/dependency" - "github.com/milvus-io/milvus/internal/util/mock" - "github.com/stretchr/testify/assert" "google.golang.org/grpc" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - grpcindexnode "github.com/milvus-io/milvus/internal/distributed/indexnode" - "github.com/milvus-io/milvus/internal/indexnode" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/internalpb" - "github.com/milvus-io/milvus/pkg/util/etcd" + "github.com/milvus-io/milvus/internal/util/mock" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -48,15 +42,6 @@ func Test_NewClient(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, client) - err = client.Init() - assert.NoError(t, err) - - err = client.Start() - assert.NoError(t, err) - - err = client.Register() - assert.NoError(t, err) - checkFunc := func(retNotNil bool) { retCheck := func(notNil bool, ret interface{}, err error) { if notNil { @@ -68,10 +53,10 @@ func Test_NewClient(t *testing.T) { } } - r1, err := client.GetComponentStates(ctx) + r1, err := client.GetComponentStates(ctx, nil) retCheck(retNotNil, r1, err) - r3, err := client.GetStatisticsChannel(ctx) + r3, err := client.GetStatisticsChannel(ctx, nil) 
retCheck(retNotNil, r3, err) r4, err := client.CreateJob(ctx, nil) @@ -118,57 +103,23 @@ func Test_NewClient(t *testing.T) { client.grpcClient.SetNewGrpcClientFunc(newFunc3) checkFunc(true) - err = client.Stop() + err = client.Close() assert.NoError(t, err) } func TestIndexNodeClient(t *testing.T) { - paramtable.Init() - ctx := context.Background() - - factory := dependency.NewDefaultFactory(true) - ins, err := grpcindexnode.NewServer(ctx, factory) - assert.NoError(t, err) - assert.NotNil(t, ins) - - inm := indexnode.NewIndexNodeMock() - etcdCli, err := etcd.GetEtcdClient( - Params.EtcdCfg.UseEmbedEtcd.GetAsBool(), - Params.EtcdCfg.EtcdUseSSL.GetAsBool(), - Params.EtcdCfg.Endpoints.GetAsStrings(), - Params.EtcdCfg.EtcdTLSCert.GetValue(), - Params.EtcdCfg.EtcdTLSKey.GetValue(), - Params.EtcdCfg.EtcdTLSCACert.GetValue(), - Params.EtcdCfg.EtcdTLSMinVersion.GetValue()) - assert.NoError(t, err) - inm.SetEtcdClient(etcdCli) - err = ins.SetClient(inm) - assert.NoError(t, err) - - err = ins.Run() - assert.NoError(t, err) - - inc, err := NewClient(ctx, "localhost:21121", paramtable.GetNodeID(), false) - assert.NoError(t, err) + inc := &mock.GrpcIndexNodeClient{Err: nil} assert.NotNil(t, inc) - err = inc.Init() - assert.NoError(t, err) - - err = inc.Start() - assert.NoError(t, err) - + ctx := context.TODO() t.Run("GetComponentStates", func(t *testing.T) { - states, err := inc.GetComponentStates(ctx) + _, err := inc.GetComponentStates(ctx, nil) assert.NoError(t, err) - assert.Equal(t, commonpb.StateCode_Healthy, states.State.StateCode) - assert.Equal(t, commonpb.ErrorCode_Success, states.Status.ErrorCode) }) t.Run("GetStatisticsChannel", func(t *testing.T) { - resp, err := inc.GetStatisticsChannel(ctx) + _, err := inc.GetStatisticsChannel(ctx, nil) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) }) t.Run("CreatJob", func(t *testing.T) { @@ -176,52 +127,43 @@ func TestIndexNodeClient(t *testing.T) { ClusterID: "0", BuildID: 0, } - 
resp, err := inc.CreateJob(ctx, req) + _, err := inc.CreateJob(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode) }) t.Run("QueryJob", func(t *testing.T) { req := &indexpb.QueryJobsRequest{} - resp, err := inc.QueryJobs(ctx, req) + _, err := inc.QueryJobs(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) }) t.Run("DropJob", func(t *testing.T) { req := &indexpb.DropJobsRequest{} - resp, err := inc.DropJobs(ctx, req) + _, err := inc.DropJobs(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode) }) t.Run("ShowConfigurations", func(t *testing.T) { req := &internalpb.ShowConfigurationsRequest{ Pattern: "", } - resp, err := inc.ShowConfigurations(ctx, req) + _, err := inc.ShowConfigurations(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) }) t.Run("GetMetrics", func(t *testing.T) { req, err := metricsinfo.ConstructRequestByMetricType(metricsinfo.SystemInfoMetrics) assert.NoError(t, err) - resp, err := inc.GetMetrics(ctx, req) + _, err = inc.GetMetrics(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) }) t.Run("GetJobStats", func(t *testing.T) { req := &indexpb.GetJobStatsRequest{} - resp, err := inc.GetJobStats(ctx, req) + _, err := inc.GetJobStats(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) }) - err = ins.Stop() - assert.NoError(t, err) - - err = inc.Stop() + err := inc.Close() assert.NoError(t, err) } diff --git a/internal/distributed/indexnode/service.go b/internal/distributed/indexnode/service.go index 1aa8fdee98049..4b969a5dfb726 100644 --- a/internal/distributed/indexnode/service.go +++ b/internal/distributed/indexnode/service.go @@ -25,8 +25,6 @@ import ( "time" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" - 
"github.com/milvus-io/milvus/internal/util/dependency" - "github.com/milvus-io/milvus/pkg/tracer" clientv3 "go.etcd.io/etcd/client/v3" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "go.uber.org/atomic" @@ -36,11 +34,14 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/distributed/utils" "github.com/milvus-io/milvus/internal/indexnode" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/tracer" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/interceptor" @@ -85,7 +86,7 @@ func (s *Server) startGrpcLoop(grpcPort int) { log.Debug("IndexNode", zap.String("network address", Params.GetAddress()), zap.Int("network port: ", grpcPort)) lis, err := net.Listen("tcp", ":"+strconv.Itoa(grpcPort)) if err != nil { - log.Warn("IndexNode", zap.String("GrpcServer:failed to listen", err.Error())) + log.Warn("IndexNode", zap.Error(err)) s.grpcErrChan <- err return } @@ -93,12 +94,12 @@ func (s *Server) startGrpcLoop(grpcPort int) { ctx, cancel := context.WithCancel(s.loopCtx) defer cancel() - var kaep = keepalive.EnforcementPolicy{ + kaep := keepalive.EnforcementPolicy{ MinTime: 5 * time.Second, // If a client pings more than once every 5 seconds, terminate the connection PermitWithoutStream: true, // Allow pings even when there are no active streams } - var kasp = keepalive.ServerParameters{ + kasp := keepalive.ServerParameters{ Time: 60 * time.Second, // Ping the client if it is idle for 60 seconds to ensure the connection is still active Timeout: 10 * time.Second, // Wait 10 second for the ping ack before assuming the connection is 
dead } @@ -217,17 +218,16 @@ func (s *Server) Stop() error { defer s.etcdCli.Close() } if s.grpcServer != nil { - log.Debug("Graceful stop grpc server...") - s.grpcServer.GracefulStop() + utils.GracefulStopGRPCServer(s.grpcServer) } s.loopWg.Wait() return nil } -// SetClient sets the IndexNode's instance. -func (s *Server) SetClient(indexNodeClient types.IndexNodeComponent) error { - s.indexnode = indexNodeClient +// setServer sets the IndexNode's instance. +func (s *Server) setServer(indexNode types.IndexNodeComponent) error { + s.indexnode = indexNode return nil } @@ -238,12 +238,12 @@ func (s *Server) SetEtcdClient(etcdCli *clientv3.Client) { // GetComponentStates gets the component states of IndexNode. func (s *Server) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) { - return s.indexnode.GetComponentStates(ctx) + return s.indexnode.GetComponentStates(ctx, req) } // GetStatisticsChannel gets the statistics channel of IndexNode. func (s *Server) GetStatisticsChannel(ctx context.Context, req *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error) { - return s.indexnode.GetStatisticsChannel(ctx) + return s.indexnode.GetStatisticsChannel(ctx, req) } // CreateJob sends the create index request to IndexNode. 
diff --git a/internal/distributed/indexnode/service_test.go b/internal/distributed/indexnode/service_test.go index 5a7d877c7196c..edfc175423e59 100644 --- a/internal/distributed/indexnode/service_test.go +++ b/internal/distributed/indexnode/service_test.go @@ -20,7 +20,6 @@ import ( "context" "testing" - "github.com/milvus-io/milvus/internal/util/dependency" "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" @@ -28,13 +27,11 @@ import ( "github.com/milvus-io/milvus/internal/indexnode" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/internalpb" - "github.com/milvus-io/milvus/pkg/util/etcd" + "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" ) -var ParamsGlobal paramtable.ComponentParam - func TestIndexNodeServer(t *testing.T) { paramtable.Init() ctx := context.Background() @@ -44,18 +41,7 @@ func TestIndexNodeServer(t *testing.T) { assert.NotNil(t, server) inm := indexnode.NewIndexNodeMock() - ParamsGlobal.Init(paramtable.NewBaseTable(paramtable.SkipRemote(true))) - etcdCli, err := etcd.GetEtcdClient( - ParamsGlobal.EtcdCfg.UseEmbedEtcd.GetAsBool(), - ParamsGlobal.EtcdCfg.EtcdUseSSL.GetAsBool(), - ParamsGlobal.EtcdCfg.Endpoints.GetAsStrings(), - ParamsGlobal.EtcdCfg.EtcdTLSCert.GetValue(), - ParamsGlobal.EtcdCfg.EtcdTLSKey.GetValue(), - ParamsGlobal.EtcdCfg.EtcdTLSCACert.GetValue(), - ParamsGlobal.EtcdCfg.EtcdTLSMinVersion.GetValue()) - assert.NoError(t, err) - inm.SetEtcdClient(etcdCli) - err = server.SetClient(inm) + err = server.setServer(inm) assert.NoError(t, err) err = server.Run() @@ -72,7 +58,7 @@ func TestIndexNodeServer(t *testing.T) { req := &internalpb.GetStatisticsChannelRequest{} resp, err := server.GetStatisticsChannel(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, 
commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) t.Run("CreateJob", func(t *testing.T) { @@ -91,7 +77,7 @@ func TestIndexNodeServer(t *testing.T) { req := &indexpb.QueryJobsRequest{} resp, err := server.QueryJobs(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) t.Run("DropJobs", func(t *testing.T) { @@ -107,7 +93,7 @@ func TestIndexNodeServer(t *testing.T) { } resp, err := server.ShowConfigurations(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) t.Run("GetMetrics", func(t *testing.T) { @@ -115,14 +101,14 @@ func TestIndexNodeServer(t *testing.T) { assert.NoError(t, err) resp, err := server.GetMetrics(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) t.Run("GetTaskSlots", func(t *testing.T) { req := &indexpb.GetJobStatsRequest{} resp, err := server.GetJobStats(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) err = server.Stop() diff --git a/internal/distributed/proxy/client/client.go b/internal/distributed/proxy/client/client.go index 18f3249994fd3..b45f00efd7efc 100644 --- a/internal/distributed/proxy/client/client.go +++ b/internal/distributed/proxy/client/client.go @@ -20,6 +20,8 @@ import ( "context" "fmt" + "google.golang.org/grpc" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/proto/internalpb" @@ -29,7 +31,6 @@ import ( "github.com/milvus-io/milvus/pkg/util/funcutil" 
"github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" - "google.golang.org/grpc" ) var Params *paramtable.ComponentParam = paramtable.Get() @@ -50,18 +51,14 @@ func NewClient(ctx context.Context, addr string, nodeID int64) (*Client, error) addr: addr, grpcClient: grpcclient.NewClientBase[proxypb.ProxyClient](config, "milvus.proto.proxy.Proxy"), } - client.grpcClient.SetRole(typeutil.ProxyRole) + // node shall specify node id + client.grpcClient.SetRole(fmt.Sprintf("%s-%d", typeutil.ProxyRole, nodeID)) client.grpcClient.SetGetAddrFunc(client.getAddr) client.grpcClient.SetNewGrpcClientFunc(client.newGrpcClient) client.grpcClient.SetNodeID(nodeID) return client, nil } -// Init initializes Proxy's grpc client. -func (c *Client) Init() error { - return nil -} - func (c *Client) newGrpcClient(cc *grpc.ClientConn) proxypb.ProxyClient { return proxypb.NewProxyClient(cc) } @@ -70,21 +67,11 @@ func (c *Client) getAddr() (string, error) { return c.addr, nil } -// Start dummy -func (c *Client) Start() error { - return nil -} - -// Stop stops the client, closes the connection -func (c *Client) Stop() error { +// Close stops the client, closes the connection +func (c *Client) Close() error { return c.grpcClient.Close() } -// Register dummy -func (c *Client) Register() error { - return nil -} - func wrapGrpcCall[T any](ctx context.Context, c *Client, call func(proxyClient proxypb.ProxyClient) (*T, error)) (*T, error) { ret, err := c.grpcClient.ReCall(ctx, func(client proxypb.ProxyClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { @@ -99,21 +86,21 @@ func wrapGrpcCall[T any](ctx context.Context, c *Client, call func(proxyClient p } // GetComponentStates get the component state. 
-func (c *Client) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { +func (c *Client) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { return wrapGrpcCall(ctx, c, func(client proxypb.ProxyClient) (*milvuspb.ComponentStates, error) { return client.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) }) } // GetStatisticsChannel return the statistics channel in string -func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (c *Client) GetStatisticsChannel(ctx context.Context, req *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { return wrapGrpcCall(ctx, c, func(client proxypb.ProxyClient) (*milvuspb.StringResponse, error) { return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) }) } // InvalidateCollectionMetaCache invalidate collection meta cache -func (c *Client) InvalidateCollectionMetaCache(ctx context.Context, req *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { +func (c *Client) InvalidateCollectionMetaCache(ctx context.Context, req *proxypb.InvalidateCollMetaCacheRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -124,7 +111,7 @@ func (c *Client) InvalidateCollectionMetaCache(ctx context.Context, req *proxypb }) } -func (c *Client) InvalidateCredentialCache(ctx context.Context, req *proxypb.InvalidateCredCacheRequest) (*commonpb.Status, error) { +func (c *Client) InvalidateCredentialCache(ctx context.Context, req *proxypb.InvalidateCredCacheRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -135,7 +122,7 @@ func (c *Client) InvalidateCredentialCache(ctx context.Context, req *proxypb.Inv }) } -func (c *Client) 
UpdateCredentialCache(ctx context.Context, req *proxypb.UpdateCredCacheRequest) (*commonpb.Status, error) { +func (c *Client) UpdateCredentialCache(ctx context.Context, req *proxypb.UpdateCredCacheRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -146,7 +133,7 @@ func (c *Client) UpdateCredentialCache(ctx context.Context, req *proxypb.UpdateC }) } -func (c *Client) RefreshPolicyInfoCache(ctx context.Context, req *proxypb.RefreshPolicyInfoCacheRequest) (*commonpb.Status, error) { +func (c *Client) RefreshPolicyInfoCache(ctx context.Context, req *proxypb.RefreshPolicyInfoCacheRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -159,7 +146,7 @@ func (c *Client) RefreshPolicyInfoCache(ctx context.Context, req *proxypb.Refres // GetProxyMetrics gets the metrics of proxy, it's an internal interface which is different from GetMetrics interface, // because it only obtains the metrics of Proxy, not including the topological metrics of Query cluster and Data cluster. -func (c *Client) GetProxyMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { +func (c *Client) GetProxyMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -171,7 +158,7 @@ func (c *Client) GetProxyMetrics(ctx context.Context, req *milvuspb.GetMetricsRe } // SetRates notifies Proxy to limit rates of requests. 
-func (c *Client) SetRates(ctx context.Context, req *proxypb.SetRatesRequest) (*commonpb.Status, error) { +func (c *Client) SetRates(ctx context.Context, req *proxypb.SetRatesRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -182,7 +169,7 @@ func (c *Client) SetRates(ctx context.Context, req *proxypb.SetRatesRequest) (*c }) } -func (c *Client) ListClientInfos(ctx context.Context, req *proxypb.ListClientInfosRequest) (*proxypb.ListClientInfosResponse, error) { +func (c *Client) ListClientInfos(ctx context.Context, req *proxypb.ListClientInfosRequest, opts ...grpc.CallOption) (*proxypb.ListClientInfosResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -192,3 +179,9 @@ func (c *Client) ListClientInfos(ctx context.Context, req *proxypb.ListClientInf return client.ListClientInfos(ctx, req) }) } + +func (c *Client) GetDdChannel(ctx context.Context, req *internalpb.GetDdChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + return wrapGrpcCall(ctx, c, func(client proxypb.ProxyClient) (*milvuspb.StringResponse, error) { + return client.GetDdChannel(ctx, req) + }) +} diff --git a/internal/distributed/proxy/client/client_test.go b/internal/distributed/proxy/client/client_test.go index bece027060ebd..1043bf8f53c04 100644 --- a/internal/distributed/proxy/client/client_test.go +++ b/internal/distributed/proxy/client/client_test.go @@ -22,12 +22,12 @@ import ( "time" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus/internal/util/mock" - "github.com/milvus-io/milvus/pkg/util/paramtable" - - "github.com/milvus-io/milvus/internal/proto/proxypb" "github.com/stretchr/testify/assert" "google.golang.org/grpc" + + "github.com/milvus-io/milvus/internal/proto/proxypb" + "github.com/milvus-io/milvus/internal/util/mock" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) func Test_NewClient(t *testing.T) { @@ -42,12 +42,6 @@ func 
Test_NewClient(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, client) - err = client.Start() - assert.NoError(t, err) - - err = client.Register() - assert.NoError(t, err) - checkFunc := func(retNotNil bool) { retCheck := func(notNil bool, ret interface{}, err error) { if notNil { @@ -59,10 +53,10 @@ func Test_NewClient(t *testing.T) { } } - r1, err := client.GetComponentStates(ctx) + r1, err := client.GetComponentStates(ctx, nil) retCheck(retNotNil, r1, err) - r2, err := client.GetStatisticsChannel(ctx) + r2, err := client.GetStatisticsChannel(ctx, nil) retCheck(retNotNil, r2, err) r3, err := client.InvalidateCollectionMetaCache(ctx, nil) @@ -123,10 +117,10 @@ func Test_NewClient(t *testing.T) { assert.Error(t, err) } - r1Timeout, err := client.GetComponentStates(shortCtx) + r1Timeout, err := client.GetComponentStates(shortCtx, nil) retCheck(r1Timeout, err) - r2Timeout, err := client.GetStatisticsChannel(shortCtx) + r2Timeout, err := client.GetStatisticsChannel(shortCtx, nil) retCheck(r2Timeout, err) r3Timeout, err := client.InvalidateCollectionMetaCache(shortCtx, nil) @@ -144,6 +138,6 @@ func Test_NewClient(t *testing.T) { } // cleanup - err = client.Stop() + err = client.Close() assert.NoError(t, err) } diff --git a/internal/distributed/proxy/httpserver/constant.go b/internal/distributed/proxy/httpserver/constant.go index 9bd4bfc193b66..03e33c41b859c 100644 --- a/internal/distributed/proxy/httpserver/constant.go +++ b/internal/distributed/proxy/httpserver/constant.go @@ -7,6 +7,7 @@ const ( VectorCollectionsDescribePath = "/vector/collections/describe" VectorCollectionsDropPath = "/vector/collections/drop" VectorInsertPath = "/vector/insert" + VectorUpsertPath = "/vector/upsert" VectorSearchPath = "/vector/search" VectorGetPath = "/vector/get" VectorQueryPath = "/vector/query" diff --git a/internal/distributed/proxy/httpserver/handler.go b/internal/distributed/proxy/httpserver/handler.go index 8a99c7fa40570..2685448874ec7 100644 --- 
a/internal/distributed/proxy/httpserver/handler.go +++ b/internal/distributed/proxy/httpserver/handler.go @@ -5,6 +5,7 @@ import ( "github.com/gin-gonic/gin" "github.com/golang/protobuf/proto" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/types" ) @@ -79,7 +80,6 @@ func (h *Handlers) RegisterRoutesTo(router gin.IRouter) { router.PATCH("/credential", wrapHandler(h.handleUpdateCredential)) router.DELETE("/credential", wrapHandler(h.handleDeleteCredential)) router.GET("/credential/users", wrapHandler(h.handleListCredUsers)) - } func (h *Handlers) handleGetHealth(c *gin.Context) (interface{}, error) { diff --git a/internal/distributed/proxy/httpserver/handler_test.go b/internal/distributed/proxy/httpserver/handler_test.go index 95cfd651f24c9..7957888e22dfc 100644 --- a/internal/distributed/proxy/httpserver/handler_test.go +++ b/internal/distributed/proxy/httpserver/handler_test.go @@ -10,14 +10,14 @@ import ( "testing" "github.com/cockroachdb/errors" - "github.com/gin-gonic/gin" "github.com/gin-gonic/gin/binding" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/types" - "github.com/stretchr/testify/assert" ) func Test_WrappedInsertRequest_JSONMarshal_AsInsertRequest(t *testing.T) { @@ -46,8 +46,10 @@ func (m *mockProxyComponent) Dummy(ctx context.Context, request *milvuspb.DummyR return nil, nil } -var emptyBody = &gin.H{} -var testStatus = &commonpb.Status{Reason: "ok"} +var ( + emptyBody = &gin.H{} + testStatus = &commonpb.Status{Reason: "ok"} +) func (m *mockProxyComponent) CreateDatabase(ctx context.Context, in *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) { return testStatus, nil @@ -112,9 +114,11 @@ func (m *mockProxyComponent) HasPartition(ctx context.Context, request *milvuspb func (m *mockProxyComponent) 
LoadPartitions(ctx context.Context, request *milvuspb.LoadPartitionsRequest) (*commonpb.Status, error) { return testStatus, nil } + func (m *mockProxyComponent) ReleasePartitions(ctx context.Context, request *milvuspb.ReleasePartitionsRequest) (*commonpb.Status, error) { return testStatus, nil } + func (m *mockProxyComponent) GetPartitionStatistics(ctx context.Context, request *milvuspb.GetPartitionStatisticsRequest) (*milvuspb.GetPartitionStatisticsResponse, error) { return &milvuspb.GetPartitionStatisticsResponse{Status: testStatus}, nil } @@ -413,7 +417,8 @@ func TestHandlers(t *testing.T) { }, { http.MethodGet, "/partition/statistics", emptyBody, - http.StatusOK, milvuspb.GetPartitionStatisticsResponse{Status: testStatus}, + http.StatusOK, + milvuspb.GetPartitionStatisticsResponse{Status: testStatus}, }, { http.MethodGet, "/partitions", emptyBody, @@ -456,26 +461,32 @@ func TestHandlers(t *testing.T) { http.StatusOK, &milvuspb.MutationResult{Acknowledged: true}, }, { - http.MethodDelete, "/entities", milvuspb.DeleteRequest{Expr: "some expr"}, + http.MethodDelete, "/entities", + milvuspb.DeleteRequest{Expr: "some expr"}, http.StatusOK, &milvuspb.MutationResult{Acknowledged: true}, }, { - http.MethodPost, "/search", milvuspb.SearchRequest{Dsl: "some dsl"}, + http.MethodPost, "/search", + milvuspb.SearchRequest{Dsl: "some dsl"}, http.StatusOK, &searchResult, }, { - http.MethodPost, "/query", milvuspb.QueryRequest{Expr: "some expr"}, + http.MethodPost, "/query", + milvuspb.QueryRequest{Expr: "some expr"}, http.StatusOK, &queryResult, }, { - http.MethodPost, "/persist", milvuspb.FlushRequest{CollectionNames: []string{"c1"}}, + http.MethodPost, "/persist", + milvuspb.FlushRequest{CollectionNames: []string{"c1"}}, http.StatusOK, flushResult, }, { - http.MethodGet, "/distance", milvuspb.CalcDistanceRequest{ + http.MethodGet, "/distance", + milvuspb.CalcDistanceRequest{ Params: []*commonpb.KeyValuePair{ {Key: "key", Value: "val"}, - }}, + }, + }, http.StatusOK, 
calcDistanceResult, }, { diff --git a/internal/distributed/proxy/httpserver/handler_v1.go b/internal/distributed/proxy/httpserver/handler_v1.go index df4e4b83cb62c..431e0ee634238 100644 --- a/internal/distributed/proxy/httpserver/handler_v1.go +++ b/internal/distributed/proxy/httpserver/handler_v1.go @@ -1,53 +1,52 @@ package httpserver import ( + "context" "encoding/json" "net/http" "strconv" - "github.com/milvus-io/milvus/pkg/util/merr" - - "github.com/cockroachdb/errors" - "github.com/gin-gonic/gin" "github.com/gin-gonic/gin/binding" "github.com/golang/protobuf/proto" + "github.com/tidwall/gjson" + "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proxy" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" - "github.com/tidwall/gjson" - "go.uber.org/zap" + "github.com/milvus-io/milvus/pkg/util/merr" ) -func checkAuthorization(c *gin.Context, req interface{}) error { +func checkAuthorization(ctx context.Context, c *gin.Context, req interface{}) error { if proxy.Params.CommonCfg.AuthorizationEnabled.GetAsBool() { username, ok := c.Get(ContextUsername) - if !ok { - c.JSON(http.StatusUnauthorized, gin.H{HTTPReturnCode: Code(merr.ErrNeedAuthenticate), HTTPReturnMessage: merr.ErrNeedAuthenticate.Error()}) + if !ok || username.(string) == "" { + c.JSON(http.StatusUnauthorized, gin.H{HTTPReturnCode: merr.Code(merr.ErrNeedAuthenticate), HTTPReturnMessage: merr.ErrNeedAuthenticate.Error()}) return merr.ErrNeedAuthenticate } - _, authErr := proxy.PrivilegeInterceptorWithUsername(c, username.(string), req) + _, authErr := proxy.PrivilegeInterceptor(ctx, req) if authErr != nil { - c.JSON(http.StatusForbidden, gin.H{HTTPReturnCode: Code(authErr), HTTPReturnMessage: authErr.Error()}) + c.JSON(http.StatusForbidden, gin.H{HTTPReturnCode: merr.Code(authErr), HTTPReturnMessage: 
authErr.Error()}) return authErr } } return nil } -func (h *Handlers) checkDatabase(c *gin.Context, dbName string) bool { +func (h *Handlers) checkDatabase(ctx context.Context, c *gin.Context, dbName string) bool { if dbName == DefaultDbName { return true } - response, err := h.proxy.ListDatabases(c, &milvuspb.ListDatabasesRequest{}) + response, err := h.proxy.ListDatabases(ctx, &milvuspb.ListDatabasesRequest{}) + if err == nil { + err = merr.Error(response.GetStatus()) + } if err != nil { - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()}) - return false - } else if response.Status.ErrorCode != commonpb.ErrorCode_Success { - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: int32(response.Status.ErrorCode), HTTPReturnMessage: response.Status.Reason}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) return false } for _, db := range response.DbNames { @@ -55,27 +54,27 @@ func (h *Handlers) checkDatabase(c *gin.Context, dbName string) bool { return true } } - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrDatabaseNotfound), HTTPReturnMessage: merr.ErrDatabaseNotfound.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrDatabaseNotFound), HTTPReturnMessage: merr.ErrDatabaseNotFound.Error()}) return false } -func (h *Handlers) describeCollection(c *gin.Context, dbName string, collectionName string, needAuth bool) (*milvuspb.DescribeCollectionResponse, error) { +func (h *Handlers) describeCollection(ctx context.Context, c *gin.Context, dbName string, collectionName string, needAuth bool) (*milvuspb.DescribeCollectionResponse, error) { req := milvuspb.DescribeCollectionRequest{ DbName: dbName, CollectionName: collectionName, } if needAuth { - if err := checkAuthorization(c, &req); err != nil { + if err := checkAuthorization(ctx, c, &req); err != nil { return nil, err } } - response, err := 
h.proxy.DescribeCollection(c, &req) + response, err := h.proxy.DescribeCollection(ctx, &req) + if err == nil { + err = merr.Error(response.GetStatus()) + } if err != nil { - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) return nil, err - } else if response.Status.ErrorCode != commonpb.ErrorCode_Success { - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: int32(response.Status.ErrorCode), HTTPReturnMessage: response.Status.Reason}) - return nil, errors.New(response.Status.Reason) } primaryField, ok := getPrimaryField(response.Schema) if ok && primaryField.AutoID && !response.Schema.AutoID { @@ -85,21 +84,20 @@ func (h *Handlers) describeCollection(c *gin.Context, dbName string, collectionN return response, nil } -func (h *Handlers) hasCollection(c *gin.Context, dbName string, collectionName string) (bool, error) { +func (h *Handlers) hasCollection(ctx context.Context, c *gin.Context, dbName string, collectionName string) (bool, error) { req := milvuspb.HasCollectionRequest{ DbName: dbName, CollectionName: collectionName, } - response, err := h.proxy.HasCollection(c, &req) + response, err := h.proxy.HasCollection(ctx, &req) + if err == nil { + err = merr.Error(response.GetStatus()) + } if err != nil { - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) return false, err - } else if response.Status.ErrorCode != commonpb.ErrorCode_Success { - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: int32(response.Status.ErrorCode), HTTPReturnMessage: response.Status.Reason}) - return false, errors.New(response.Status.Reason) - } else { - return response.Value, nil } + return response.Value, nil } func (h *Handlers) 
RegisterRoutesToV1(router gin.IRouter) { @@ -111,6 +109,7 @@ func (h *Handlers) RegisterRoutesToV1(router gin.IRouter) { router.POST(VectorGetPath, h.get) router.POST(VectorDeletePath, h.delete) router.POST(VectorInsertPath, h.insert) + router.POST(VectorUpsertPath, h.upsert) router.POST(VectorSearchPath, h.search) } @@ -119,26 +118,29 @@ func (h *Handlers) listCollections(c *gin.Context) { req := milvuspb.ShowCollectionsRequest{ DbName: dbName, } - if err := checkAuthorization(c, &req); err != nil { + username, _ := c.Get(ContextUsername) + ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName) + if err := checkAuthorization(ctx, c, &req); err != nil { return } - if !h.checkDatabase(c, dbName) { + if !h.checkDatabase(ctx, c, dbName) { return } - response, err := h.proxy.ShowCollections(c, &req) + response, err := h.proxy.ShowCollections(ctx, &req) + if err == nil { + err = merr.Error(response.GetStatus()) + } if err != nil { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()}) - } else if response.Status.ErrorCode != commonpb.ErrorCode_Success { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: int32(response.Status.ErrorCode), HTTPReturnMessage: response.Status.Reason}) + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) + return + } + var collections []string + if response.CollectionNames != nil { + collections = response.CollectionNames } else { - var collections []string - if response.CollectionNames != nil { - collections = response.CollectionNames - } else { - collections = []string{} - } - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: collections}) + collections = []string{} } + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: collections}) } func (h *Handlers) createCollection(c *gin.Context) { @@ -150,12 +152,12 @@ func (h *Handlers) createCollection(c *gin.Context) { } if err := c.ShouldBindBodyWith(&httpReq, 
binding.JSON); err != nil { log.Warn("high level restful api, the parameter of create collection is incorrect", zap.Any("request", httpReq), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()}) return } if httpReq.CollectionName == "" || httpReq.Dimension == 0 { log.Warn("high level restful api, create collection require parameters: [collectionName, dimension], but miss", zap.Any("request", httpReq)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()}) return } schema, err := proto.Marshal(&schemapb.CollectionSchema{ @@ -187,7 +189,7 @@ func (h *Handlers) createCollection(c *gin.Context) { }) if err != nil { log.Warn("high level restful api, marshal collection schema fail", zap.Any("request", httpReq), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMarshalCollectionSchema), HTTPReturnMessage: merr.ErrMarshalCollectionSchema.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrMarshalCollectionSchema), HTTPReturnMessage: merr.ErrMarshalCollectionSchema.Error()}) return } req := milvuspb.CreateCollectionRequest{ @@ -197,44 +199,46 @@ func (h *Handlers) createCollection(c *gin.Context) { ShardsNum: ShardNumDefault, ConsistencyLevel: commonpb.ConsistencyLevel_Bounded, } - if err := checkAuthorization(c, &req); err != nil { + username, _ := c.Get(ContextUsername) + ctx := proxy.NewContextWithMetadata(c, username.(string), 
req.DbName) + if err := checkAuthorization(ctx, c, &req); err != nil { return } - if !h.checkDatabase(c, req.DbName) { + if !h.checkDatabase(ctx, c, req.DbName) { return } - response, err := h.proxy.CreateCollection(c, &req) + response, err := h.proxy.CreateCollection(ctx, &req) + if err == nil { + err = merr.Error(response) + } if err != nil { - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()}) - return - } else if response.ErrorCode != commonpb.ErrorCode_Success { - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: int32(response.ErrorCode), HTTPReturnMessage: response.Reason}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) return } - response, err = h.proxy.CreateIndex(c, &milvuspb.CreateIndexRequest{ + response, err = h.proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{ DbName: httpReq.DbName, CollectionName: httpReq.CollectionName, FieldName: httpReq.VectorField, IndexName: DefaultIndexName, ExtraParams: []*commonpb.KeyValuePair{{Key: common.MetricTypeKey, Value: httpReq.MetricType}}, }) + if err == nil { + err = merr.Error(response) + } if err != nil { - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()}) - return - } else if response.ErrorCode != commonpb.ErrorCode_Success { - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: int32(response.ErrorCode), HTTPReturnMessage: response.Reason}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) return } - response, err = h.proxy.LoadCollection(c, &milvuspb.LoadCollectionRequest{ + response, err = h.proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{ DbName: httpReq.DbName, CollectionName: httpReq.CollectionName, }) + if err == nil { + err = merr.Error(response) + } if err != nil { - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), 
HTTPReturnMessage: err.Error()}) - return - } else if response.ErrorCode != commonpb.ErrorCode_Success { - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: int32(response.ErrorCode), HTTPReturnMessage: response.Reason}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) return } c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{}}) @@ -244,26 +248,32 @@ func (h *Handlers) getCollectionDetails(c *gin.Context) { collectionName := c.Query(HTTPCollectionName) if collectionName == "" { log.Warn("high level restful api, desc collection require parameter: [collectionName], but miss") - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()}) return } dbName := c.DefaultQuery(HTTPDbName, DefaultDbName) - if !h.checkDatabase(c, dbName) { + username, _ := c.Get(ContextUsername) + ctx := proxy.NewContextWithMetadata(c, username.(string), dbName) + if !h.checkDatabase(ctx, c, dbName) { return } - coll, err := h.describeCollection(c, dbName, collectionName, true) + coll, err := h.describeCollection(ctx, c, dbName, collectionName, true) if err != nil { return } - stateResp, stateErr := h.proxy.GetLoadState(c, &milvuspb.GetLoadStateRequest{ + stateResp, err := h.proxy.GetLoadState(ctx, &milvuspb.GetLoadStateRequest{ DbName: dbName, CollectionName: collectionName, }) collLoadState := "" - if stateErr != nil { - log.Warn("get collection load state fail", zap.String("collection", collectionName), zap.String("err", stateErr.Error())) - } else if stateResp.Status.ErrorCode != commonpb.ErrorCode_Success { - log.Warn("get collection load state fail", zap.String("collection", collectionName), zap.String("err", 
stateResp.Status.Reason)) + if err == nil { + err = merr.Error(stateResp.GetStatus()) + } + if err != nil { + log.Warn("get collection load state fail", + zap.String("collection", collectionName), + zap.Error(err), + ) } else { collLoadState = stateResp.State.String() } @@ -274,18 +284,22 @@ func (h *Handlers) getCollectionDetails(c *gin.Context) { break } } - indexResp, indexErr := h.proxy.DescribeIndex(c, &milvuspb.DescribeIndexRequest{ + indexResp, err := h.proxy.DescribeIndex(ctx, &milvuspb.DescribeIndexRequest{ DbName: dbName, CollectionName: collectionName, FieldName: vectorField, }) + if err == nil { + err = merr.Error(indexResp.GetStatus()) + } var indexDesc []gin.H - if indexErr != nil { - indexDesc = []gin.H{} - log.Warn("get indexes description fail", zap.String("collection", collectionName), zap.String("vectorField", vectorField), zap.String("err", indexErr.Error())) - } else if indexResp.Status.ErrorCode != commonpb.ErrorCode_Success { + if err != nil { indexDesc = []gin.H{} - log.Warn("get indexes description fail", zap.String("collection", collectionName), zap.String("vectorField", vectorField), zap.String("err", indexResp.Status.Reason)) + log.Warn("get indexes description fail", + zap.String("collection", collectionName), + zap.String("vectorField", vectorField), + zap.Error(err), + ) } else { indexDesc = printIndexes(indexResp.IndexDescriptions) } @@ -306,37 +320,40 @@ func (h *Handlers) dropCollection(c *gin.Context) { } if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil { log.Warn("high level restful api, the parameter of drop collection is incorrect", zap.Any("request", httpReq), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()}) return } 
if httpReq.CollectionName == "" { log.Warn("high level restful api, drop collection require parameter: [collectionName], but miss") - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()}) return } req := milvuspb.DropCollectionRequest{ DbName: httpReq.DbName, CollectionName: httpReq.CollectionName, } - if err := checkAuthorization(c, &req); err != nil { + username, _ := c.Get(ContextUsername) + ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName) + if err := checkAuthorization(ctx, c, &req); err != nil { return } - if !h.checkDatabase(c, req.DbName) { + if !h.checkDatabase(ctx, c, req.DbName) { return } - has, err := h.hasCollection(c, httpReq.DbName, httpReq.CollectionName) + has, err := h.hasCollection(ctx, c, httpReq.DbName, httpReq.CollectionName) if err != nil { return } if !has { - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrCollectionNotFound), HTTPReturnMessage: merr.ErrCollectionNotFound.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrCollectionNotFound), HTTPReturnMessage: merr.ErrCollectionNotFound.Error()}) return } - response, err := h.proxy.DropCollection(c, &req) + response, err := h.proxy.DropCollection(ctx, &req) + if err == nil { + err = merr.Error(response) + } if err != nil { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()}) - } else if response.ErrorCode != commonpb.ErrorCode_Success { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: int32(response.ErrorCode), HTTPReturnMessage: response.Reason}) + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) } else { c.JSON(http.StatusOK, gin.H{HTTPReturnCode: 
http.StatusOK, HTTPReturnData: gin.H{}}) } @@ -350,12 +367,12 @@ func (h *Handlers) query(c *gin.Context) { } if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil { log.Warn("high level restful api, the parameter of query is incorrect", zap.Any("request", httpReq), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()}) return } if httpReq.CollectionName == "" || httpReq.Filter == "" { log.Warn("high level restful api, query require parameter: [collectionName, filter], but miss") - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()}) return } req := milvuspb.QueryRequest{ @@ -372,22 +389,25 @@ func (h *Handlers) query(c *gin.Context) { if httpReq.Limit > 0 { req.QueryParams = append(req.QueryParams, &commonpb.KeyValuePair{Key: ParamLimit, Value: strconv.FormatInt(int64(httpReq.Limit), 10)}) } - if err := checkAuthorization(c, &req); err != nil { + username, _ := c.Get(ContextUsername) + ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName) + if err := checkAuthorization(ctx, c, &req); err != nil { return } - if !h.checkDatabase(c, req.DbName) { + if !h.checkDatabase(ctx, c, req.DbName) { return } - response, err := h.proxy.Query(c, &req) + response, err := h.proxy.Query(ctx, &req) + if err == nil { + err = merr.Error(response.GetStatus()) + } if err != nil { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()}) - } else if 
response.Status.ErrorCode != commonpb.ErrorCode_Success { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: int32(response.Status.ErrorCode), HTTPReturnMessage: response.Status.Reason}) + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) } else { outputData, err := buildQueryResp(int64(0), response.OutputFields, response.FieldsData, nil, nil) if err != nil { log.Warn("high level restful api, fail to deal with query result", zap.Any("response", response), zap.Error(err)) - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrInvalidSearchResult), HTTPReturnMessage: merr.ErrInvalidSearchResult.Error()}) + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrInvalidSearchResult), HTTPReturnMessage: merr.ErrInvalidSearchResult.Error()}) } else { c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: outputData}) } @@ -401,12 +421,12 @@ func (h *Handlers) get(c *gin.Context) { } if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil { log.Warn("high level restful api, the parameter of get is incorrect", zap.Any("request", httpReq), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()}) return } if httpReq.CollectionName == "" || httpReq.ID == nil { log.Warn("high level restful api, get require parameter: [collectionName, id], but miss") - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()}) return } req := 
milvuspb.QueryRequest{ @@ -415,33 +435,36 @@ func (h *Handlers) get(c *gin.Context) { OutputFields: httpReq.OutputFields, GuaranteeTimestamp: BoundedTimestamp, } - if err := checkAuthorization(c, &req); err != nil { + username, _ := c.Get(ContextUsername) + ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName) + if err := checkAuthorization(ctx, c, &req); err != nil { return } - if !h.checkDatabase(c, req.DbName) { + if !h.checkDatabase(ctx, c, req.DbName) { return } - coll, err := h.describeCollection(c, httpReq.DbName, httpReq.CollectionName, false) + coll, err := h.describeCollection(ctx, c, httpReq.DbName, httpReq.CollectionName, false) if err != nil || coll == nil { return } body, _ := c.Get(gin.BodyBytesKey) filter, err := checkGetPrimaryKey(coll.Schema, gjson.Get(string(body.([]byte)), DefaultPrimaryFieldName)) if err != nil { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrCheckPrimaryKey), HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error()}) + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrCheckPrimaryKey), HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error()}) return } req.Expr = filter - response, err := h.proxy.Query(c, &req) + response, err := h.proxy.Query(ctx, &req) + if err == nil { + err = merr.Error(response.GetStatus()) + } if err != nil { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()}) - } else if response.Status.ErrorCode != commonpb.ErrorCode_Success { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: int32(response.Status.ErrorCode), HTTPReturnMessage: response.Status.Reason}) + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) } else { outputData, err := buildQueryResp(int64(0), response.OutputFields, response.FieldsData, nil, nil) if err != nil { log.Warn("high level restful api, fail to deal with get result", zap.Any("response", response), zap.Error(err)) - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: 
Code(merr.ErrInvalidSearchResult), HTTPReturnMessage: merr.ErrInvalidSearchResult.Error()}) + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrInvalidSearchResult), HTTPReturnMessage: merr.ErrInvalidSearchResult.Error()}) } else { c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: outputData}) log.Error("get resultIS: ", zap.Any("res", outputData)) @@ -455,40 +478,46 @@ func (h *Handlers) delete(c *gin.Context) { } if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil { log.Warn("high level restful api, the parameter of delete is incorrect", zap.Any("request", httpReq), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()}) return } - if httpReq.CollectionName == "" || httpReq.ID == nil { - log.Warn("high level restful api, delete require parameter: [collectionName, id], but miss") - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()}) + if httpReq.CollectionName == "" || (httpReq.ID == nil && httpReq.Filter == "") { + log.Warn("high level restful api, delete require parameter: [collectionName, id/filter], but miss") + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()}) return } req := milvuspb.DeleteRequest{ DbName: httpReq.DbName, CollectionName: httpReq.CollectionName, } - if err := checkAuthorization(c, &req); err != nil { + username, _ := c.Get(ContextUsername) + ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName) + if err := checkAuthorization(ctx, c, &req); err != nil { return } - if 
!h.checkDatabase(c, req.DbName) { + if !h.checkDatabase(ctx, c, req.DbName) { return } - coll, err := h.describeCollection(c, httpReq.DbName, httpReq.CollectionName, false) + coll, err := h.describeCollection(ctx, c, httpReq.DbName, httpReq.CollectionName, false) if err != nil || coll == nil { return } - body, _ := c.Get(gin.BodyBytesKey) - filter, err := checkGetPrimaryKey(coll.Schema, gjson.Get(string(body.([]byte)), DefaultPrimaryFieldName)) - if err != nil { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrCheckPrimaryKey), HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error()}) - return + req.Expr = httpReq.Filter + if req.Expr == "" { + body, _ := c.Get(gin.BodyBytesKey) + filter, err := checkGetPrimaryKey(coll.Schema, gjson.Get(string(body.([]byte)), DefaultPrimaryFieldName)) + if err != nil { + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrCheckPrimaryKey), HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error()}) + return + } + req.Expr = filter + } + response, err := h.proxy.Delete(ctx, &req) + if err == nil { + err = merr.Error(response.GetStatus()) } - req.Expr = filter - response, err := h.proxy.Delete(c, &req) if err != nil { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()}) - } else if response.Status.ErrorCode != commonpb.ErrorCode_Success { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: int32(response.Status.ErrorCode), HTTPReturnMessage: response.Status.Reason}) + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) } else { c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{}}) } @@ -504,7 +533,7 @@ func (h *Handlers) insert(c *gin.Context) { } if err = c.ShouldBindBodyWith(&singleInsertReq, binding.JSON); err != nil { log.Warn("high level restful api, the parameter of insert is incorrect", zap.Any("request", httpReq), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: 
Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()}) return } httpReq.DbName = singleInsertReq.DbName @@ -513,7 +542,7 @@ func (h *Handlers) insert(c *gin.Context) { } if httpReq.CollectionName == "" || httpReq.Data == nil { log.Warn("high level restful api, insert require parameter: [collectionName, data], but miss") - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()}) return } req := milvuspb.InsertRequest{ @@ -522,34 +551,37 @@ func (h *Handlers) insert(c *gin.Context) { PartitionName: "_default", NumRows: uint32(len(httpReq.Data)), } - if err := checkAuthorization(c, &req); err != nil { + username, _ := c.Get(ContextUsername) + ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName) + if err := checkAuthorization(ctx, c, &req); err != nil { return } - if !h.checkDatabase(c, req.DbName) { + if !h.checkDatabase(ctx, c, req.DbName) { return } - coll, err := h.describeCollection(c, httpReq.DbName, httpReq.CollectionName, false) + coll, err := h.describeCollection(ctx, c, httpReq.DbName, httpReq.CollectionName, false) if err != nil || coll == nil { return } body, _ := c.Get(gin.BodyBytesKey) - err = checkAndSetData(string(body.([]byte)), coll, &httpReq) + err, httpReq.Data = checkAndSetData(string(body.([]byte)), coll) if err != nil { log.Warn("high level restful api, fail to deal with insert data", zap.Any("body", body), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrInvalidInsertData), HTTPReturnMessage: 
merr.ErrInvalidInsertData.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrInvalidInsertData), HTTPReturnMessage: merr.ErrInvalidInsertData.Error()}) return } req.FieldsData, err = anyToColumns(httpReq.Data, coll.Schema) if err != nil { log.Warn("high level restful api, fail to deal with insert data", zap.Any("data", httpReq.Data), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrInvalidInsertData), HTTPReturnMessage: merr.ErrInvalidInsertData.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrInvalidInsertData), HTTPReturnMessage: merr.ErrInvalidInsertData.Error()}) return } - response, err := h.proxy.Insert(c, &req) + response, err := h.proxy.Insert(ctx, &req) + if err == nil { + err = merr.Error(response.GetStatus()) + } if err != nil { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()}) - } else if response.Status.ErrorCode != commonpb.ErrorCode_Success { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: int32(response.Status.ErrorCode), HTTPReturnMessage: response.Status.Reason}) + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) } else { switch response.IDs.GetIdField().(type) { case *schemapb.IDs_IntId: @@ -557,7 +589,83 @@ func (h *Handlers) insert(c *gin.Context) { case *schemapb.IDs_StrId: c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"insertCount": response.InsertCnt, "insertIds": response.IDs.IdField.(*schemapb.IDs_StrId).StrId.Data}}) default: - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrCheckPrimaryKey), HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error()}) + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrCheckPrimaryKey), HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error()}) + } + } +} + +func (h *Handlers) upsert(c *gin.Context) { + httpReq := UpsertReq{ + DbName: DefaultDbName, + } + if err := 
c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil { + singleUpsertReq := SingleUpsertReq{ + DbName: DefaultDbName, + } + if err = c.ShouldBindBodyWith(&singleUpsertReq, binding.JSON); err != nil { + log.Warn("high level restful api, the parameter of insert is incorrect", zap.Any("request", httpReq), zap.Error(err)) + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()}) + return + } + httpReq.DbName = singleUpsertReq.DbName + httpReq.CollectionName = singleUpsertReq.CollectionName + httpReq.Data = []map[string]interface{}{singleUpsertReq.Data} + } + if httpReq.CollectionName == "" || httpReq.Data == nil { + log.Warn("high level restful api, insert require parameter: [collectionName, data], but miss") + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()}) + return + } + req := milvuspb.UpsertRequest{ + DbName: httpReq.DbName, + CollectionName: httpReq.CollectionName, + PartitionName: "_default", + NumRows: uint32(len(httpReq.Data)), + } + username, _ := c.Get(ContextUsername) + ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName) + if err := checkAuthorization(ctx, c, &req); err != nil { + return + } + if !h.checkDatabase(ctx, c, req.DbName) { + return + } + coll, err := h.describeCollection(ctx, c, httpReq.DbName, httpReq.CollectionName, false) + if err != nil || coll == nil { + return + } + if coll.Schema.AutoID { + err := merr.WrapErrParameterInvalid("autoID: false", "autoID: true", "cannot upsert an autoID collection") + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) + return + } + body, _ := c.Get(gin.BodyBytesKey) + err, httpReq.Data = checkAndSetData(string(body.([]byte)), coll) + if err != nil { + log.Warn("high level restful api, fail to deal with 
insert data", zap.Any("body", body), zap.Error(err)) + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrInvalidInsertData), HTTPReturnMessage: merr.ErrInvalidInsertData.Error()}) + return + } + req.FieldsData, err = anyToColumns(httpReq.Data, coll.Schema) + if err != nil { + log.Warn("high level restful api, fail to deal with insert data", zap.Any("data", httpReq.Data), zap.Error(err)) + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrInvalidInsertData), HTTPReturnMessage: merr.ErrInvalidInsertData.Error()}) + return + } + response, err := h.proxy.Upsert(ctx, &req) + if err == nil { + err = merr.Error(response.GetStatus()) + } + if err != nil { + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) + } else { + switch response.IDs.GetIdField().(type) { + case *schemapb.IDs_IntId: + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"upsertCount": response.UpsertCnt, "upsertIds": response.IDs.IdField.(*schemapb.IDs_IntId).IntId.Data}}) + case *schemapb.IDs_StrId: + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"upsertCount": response.UpsertCnt, "upsertIds": response.IDs.IdField.(*schemapb.IDs_StrId).StrId.Data}}) + default: + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrCheckPrimaryKey), HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error()}) } } } @@ -569,15 +677,15 @@ func (h *Handlers) search(c *gin.Context) { } if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil { log.Warn("high level restful api, the parameter of search is incorrect", zap.Any("request", httpReq), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: 
merr.ErrIncorrectParameterFormat.Error()}) return } if httpReq.CollectionName == "" || httpReq.Vector == nil { log.Warn("high level restful api, search require parameter: [collectionName, vector], but miss") - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()}) return } - params := map[string]interface{}{ //auto generated mapping + params := map[string]interface{}{ // auto generated mapping "level": int(commonpb.ConsistencyLevel_Bounded), } bs, _ := json.Marshal(params) @@ -598,17 +706,20 @@ func (h *Handlers) search(c *gin.Context) { GuaranteeTimestamp: BoundedTimestamp, Nq: int64(1), } - if err := checkAuthorization(c, &req); err != nil { + username, _ := c.Get(ContextUsername) + ctx := proxy.NewContextWithMetadata(c, username.(string), req.DbName) + if err := checkAuthorization(ctx, c, &req); err != nil { return } - if !h.checkDatabase(c, req.DbName) { + if !h.checkDatabase(ctx, c, req.DbName) { return } - response, err := h.proxy.Search(c, &req) + response, err := h.proxy.Search(ctx, &req) + if err == nil { + err = merr.Error(response.GetStatus()) + } if err != nil { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(err), HTTPReturnMessage: err.Error()}) - } else if response.Status.ErrorCode != commonpb.ErrorCode_Success { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: int32(response.Status.ErrorCode), HTTPReturnMessage: response.Status.Reason}) + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) } else { if response.Results.TopK == int64(0) { c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: []interface{}{}}) @@ -616,7 +727,7 @@ func (h *Handlers) search(c *gin.Context) { outputData, err := 
buildQueryResp(response.Results.TopK, response.Results.OutputFields, response.Results.FieldsData, response.Results.Ids, response.Results.Scores) if err != nil { log.Warn("high level restful api, fail to deal with search result", zap.Any("result", response.Results), zap.Error(err)) - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: Code(merr.ErrInvalidSearchResult), HTTPReturnMessage: merr.ErrInvalidSearchResult.Error()}) + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrInvalidSearchResult), HTTPReturnMessage: merr.ErrInvalidSearchResult.Error()}) } else { c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: outputData}) } diff --git a/internal/distributed/proxy/httpserver/handler_v1_test.go b/internal/distributed/proxy/httpserver/handler_v1_test.go index 4fdcfe3ec3c12..facbdb71832db 100644 --- a/internal/distributed/proxy/httpserver/handler_v1_test.go +++ b/internal/distributed/proxy/httpserver/handler_v1_test.go @@ -8,13 +8,11 @@ import ( "net/http/httptest" "testing" - "github.com/milvus-io/milvus/pkg/util/merr" - "github.com/cockroachdb/errors" - - "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" @@ -22,8 +20,8 @@ import ( "github.com/milvus-io/milvus/internal/proxy" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/util" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) const ( @@ -36,7 +34,7 @@ const ( ReturnTrue = 3 ReturnFalse = 4 - URIPrefix = "/v1" + URIPrefixV1 = "/v1" ) var StatusSuccess = commonpb.Status{ @@ -55,7 +53,8 @@ var DefaultDescCollectionResp = milvuspb.DescribeCollectionResponse{ 
CollectionName: DefaultCollectionName, Schema: generateCollectionSchema(false), ShardsNum: ShardNumDefault, - Status: &StatusSuccess} + Status: &StatusSuccess, +} var DefaultLoadStateResp = milvuspb.GetLoadStateResponse{ Status: &StatusSuccess, @@ -77,10 +76,14 @@ var DefaultFalseResp = milvuspb.BoolResponse{ Value: false, } +func versional(path string) string { + return URIPrefixV1 + path +} + func initHTTPServer(proxy types.ProxyComponent, needAuth bool) *gin.Engine { h := NewHandlers(proxy) ginHandler := gin.Default() - app := ginHandler.Group("/v1", genAuthMiddleWare(needAuth)) + app := ginHandler.Group(URIPrefixV1, genAuthMiddleWare(needAuth)) NewHandlers(h.proxy).RegisterRoutesToV1(app) return ginHandler } @@ -104,11 +107,12 @@ func initHTTPServer(proxy types.ProxyComponent, needAuth bool) *gin.Engine { func genAuthMiddleWare(needAuth bool) gin.HandlerFunc { if needAuth { return func(c *gin.Context) { + c.Set(ContextUsername, "") username, password, ok := ParseUsernamePassword(c) if !ok { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{HTTPReturnCode: Code(merr.ErrNeedAuthenticate), HTTPReturnMessage: merr.ErrNeedAuthenticate.Error()}) + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{HTTPReturnCode: merr.Code(merr.ErrNeedAuthenticate), HTTPReturnMessage: merr.ErrNeedAuthenticate.Error()}) } else if username == util.UserRoot && password != util.DefaultRootPassword { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{HTTPReturnCode: Code(merr.ErrNeedAuthenticate), HTTPReturnMessage: merr.ErrNeedAuthenticate.Error()}) + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{HTTPReturnCode: merr.Code(merr.ErrNeedAuthenticate), HTTPReturnMessage: merr.ErrNeedAuthenticate.Error()}) } else { c.Set(ContextUsername, username) } @@ -124,7 +128,7 @@ func Print(code int32, message string) string { } func PrintErr(err error) string { - return Print(Code(err), err.Error()) + return Print(merr.Code(err), err.Error()) } func TestVectorAuthenticate(t 
*testing.T) { @@ -139,7 +143,7 @@ func TestVectorAuthenticate(t *testing.T) { testEngine := initHTTPServer(mp, true) t.Run("need authentication", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/v1/vector/collections", nil) + req := httptest.NewRequest(http.MethodGet, versional(VectorCollectionsPath), nil) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) assert.Equal(t, w.Code, http.StatusUnauthorized) @@ -147,7 +151,7 @@ func TestVectorAuthenticate(t *testing.T) { }) t.Run("username or password incorrect", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/v1/vector/collections", nil) + req := httptest.NewRequest(http.MethodGet, versional(VectorCollectionsPath), nil) req.SetBasicAuth(util.UserRoot, util.UserRoot) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -156,7 +160,7 @@ func TestVectorAuthenticate(t *testing.T) { }) t.Run("root's password correct", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/v1/vector/collections", nil) + req := httptest.NewRequest(http.MethodGet, versional(VectorCollectionsPath), nil) req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -165,7 +169,7 @@ func TestVectorAuthenticate(t *testing.T) { }) t.Run("username and password both provided", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/v1/vector/collections", nil) + req := httptest.NewRequest(http.MethodGet, versional(VectorCollectionsPath), nil) req.SetBasicAuth("test", util.UserRoot) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -186,19 +190,16 @@ func TestVectorListCollection(t *testing.T) { expectedBody: PrintErr(ErrDefault), }) - reason := "cannot create folder" mp1 := mocks.NewMockProxy(t) + err := merr.WrapErrIoFailedReason("cannot create folder") mp1.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&milvuspb.ShowCollectionsResponse{ - Status: &commonpb.Status{ - ErrorCode: 
commonpb.ErrorCode_CannotCreateFolder, - Reason: reason, - }, + Status: merr.Status(err), }, nil).Once() testCases = append(testCases, testCase{ name: "show collections fail", mp: mp1, exceptCode: 200, - expectedBody: Print(int32(commonpb.ErrorCode_CannotCreateFolder), reason), + expectedBody: PrintErr(err), }) mp := mocks.NewMockProxy(t) @@ -213,7 +214,7 @@ func TestVectorListCollection(t *testing.T) { for _, tt := range testCases { t.Run(tt.name, func(t *testing.T) { testEngine := initHTTPServer(tt.mp, true) - req := httptest.NewRequest(http.MethodGet, "/v1/vector/collections", nil) + req := httptest.NewRequest(http.MethodGet, versional(VectorCollectionsPath), nil) req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -272,7 +273,7 @@ func TestVectorCollectionsDescribe(t *testing.T) { for _, tt := range testCases { t.Run(tt.name, func(t *testing.T) { testEngine := initHTTPServer(tt.mp, true) - req := httptest.NewRequest(http.MethodGet, "/v1/vector/collections/describe?collectionName="+DefaultCollectionName, nil) + req := httptest.NewRequest(http.MethodGet, versional(VectorCollectionsDescribePath)+"?collectionName="+DefaultCollectionName, nil) req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -282,7 +283,7 @@ func TestVectorCollectionsDescribe(t *testing.T) { } t.Run("need collectionName", func(t *testing.T) { testEngine := initHTTPServer(mocks.NewMockProxy(t), true) - req := httptest.NewRequest(http.MethodGet, "/v1/vector/collections/describe?"+DefaultCollectionName, nil) + req := httptest.NewRequest(http.MethodGet, versional(VectorCollectionsDescribePath)+"?"+DefaultCollectionName, nil) req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -304,17 +305,14 @@ func TestVectorCreateCollection(t *testing.T) { expectedBody: PrintErr(ErrDefault), }) - reason := "collection " 
+ DefaultCollectionName + " already exists" + err := merr.WrapErrCollectionResourceLimitExceeded() mp2 := mocks.NewMockProxy(t) - mp2.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(&commonpb.Status{ - ErrorCode: commonpb.ErrorCode_CannotCreateFile, // 18 - Reason: reason, - }, nil).Once() + mp2.EXPECT().CreateCollection(mock.Anything, mock.Anything).Return(merr.Status(err), nil).Once() testCases = append(testCases, testCase{ name: "create collection fail", mp: mp2, exceptCode: 200, - expectedBody: Print(int32(commonpb.ErrorCode_CannotCreateFile), reason), + expectedBody: PrintErr(err), }) mp3 := mocks.NewMockProxy(t) @@ -354,7 +352,7 @@ func TestVectorCreateCollection(t *testing.T) { testEngine := initHTTPServer(tt.mp, true) jsonBody := []byte(`{"collectionName": "` + DefaultCollectionName + `", "dimension": 2}`) bodyReader := bytes.NewReader(jsonBody) - req := httptest.NewRequest(http.MethodPost, "/v1/vector/collections/create", bodyReader) + req := httptest.NewRequest(http.MethodPost, versional(VectorCollectionsCreatePath), bodyReader) req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -381,18 +379,15 @@ func TestVectorDropCollection(t *testing.T) { expectedBody: PrintErr(ErrDefault), }) - reason := "cannot find collection " + DefaultCollectionName + err := merr.WrapErrCollectionNotFound(DefaultCollectionName) mp2 := mocks.NewMockProxy(t) mp2, _ = wrapWithHasCollection(t, mp2, ReturnTrue, 1, nil) - mp2.EXPECT().DropCollection(mock.Anything, mock.Anything).Return(&commonpb.Status{ - ErrorCode: commonpb.ErrorCode_CollectionNotExists, // 4 - Reason: reason, - }, nil).Once() + mp2.EXPECT().DropCollection(mock.Anything, mock.Anything).Return(merr.Status(err), nil).Once() testCases = append(testCases, testCase{ name: "drop collection fail", mp: mp2, exceptCode: 200, - expectedBody: Print(int32(commonpb.ErrorCode_CollectionNotExists), reason), + expectedBody: PrintErr(err), }) mp3 := 
mocks.NewMockProxy(t) @@ -410,7 +405,7 @@ func TestVectorDropCollection(t *testing.T) { testEngine := initHTTPServer(tt.mp, true) jsonBody := []byte(`{"collectionName": "` + DefaultCollectionName + `"}`) bodyReader := bytes.NewReader(jsonBody) - req := httptest.NewRequest(http.MethodPost, "/v1/vector/collections/drop", bodyReader) + req := httptest.NewRequest(http.MethodPost, versional(VectorCollectionsDropPath), bodyReader) req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -434,20 +429,17 @@ func TestQuery(t *testing.T) { expectedBody: PrintErr(ErrDefault), }) - reason := DefaultCollectionName + " name not found" + err := merr.WrapErrCollectionNotFound(DefaultCollectionName) mp3 := mocks.NewMockProxy(t) mp3, _ = wrapWithDescribeColl(t, mp3, ReturnSuccess, 1, nil) mp3.EXPECT().Query(mock.Anything, mock.Anything).Return(&milvuspb.QueryResults{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_CollectionNameNotFound, // 28 - Reason: reason, - }, + Status: merr.Status(err), }, nil).Twice() testCases = append(testCases, testCase{ name: "query fail", mp: mp3, exceptCode: 200, - expectedBody: Print(int32(commonpb.ErrorCode_CollectionNameNotFound), reason), + expectedBody: PrintErr(err), }) mp4 := mocks.NewMockProxy(t) @@ -493,14 +485,14 @@ func TestQuery(t *testing.T) { func genQueryRequest() *http.Request { jsonBody := []byte(`{"collectionName": "` + DefaultCollectionName + `" , "filter": "book_id in [1,2,3]"}`) bodyReader := bytes.NewReader(jsonBody) - req := httptest.NewRequest(http.MethodPost, "/v1/vector/query", bodyReader) + req := httptest.NewRequest(http.MethodPost, versional(VectorQueryPath), bodyReader) return req } func genGetRequest() *http.Request { jsonBody := []byte(`{"collectionName": "` + DefaultCollectionName + `" , "id": [1,2,3]}`) bodyReader := bytes.NewReader(jsonBody) - req := httptest.NewRequest(http.MethodPost, "/v1/vector/get", bodyReader) + req := 
httptest.NewRequest(http.MethodPost, versional(VectorGetPath), bodyReader) return req } @@ -520,20 +512,17 @@ func TestDelete(t *testing.T) { expectedBody: PrintErr(ErrDefault), }) - reason := DefaultCollectionName + " name not found" + err := merr.WrapErrCollectionNotFound(DefaultCollectionName) mp3 := mocks.NewMockProxy(t) mp3, _ = wrapWithDescribeColl(t, mp3, ReturnSuccess, 1, nil) mp3.EXPECT().Delete(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_CollectionNameNotFound, // 28 - Reason: reason, - }, + Status: merr.Status(err), }, nil).Once() testCases = append(testCases, testCase{ name: "delete fail", mp: mp3, exceptCode: 200, - expectedBody: Print(int32(commonpb.ErrorCode_CollectionNameNotFound), reason), + expectedBody: PrintErr(err), }) mp4 := mocks.NewMockProxy(t) @@ -553,7 +542,7 @@ func TestDelete(t *testing.T) { testEngine := initHTTPServer(tt.mp, true) jsonBody := []byte(`{"collectionName": "` + DefaultCollectionName + `" , "id": [1,2,3]}`) bodyReader := bytes.NewReader(jsonBody) - req := httptest.NewRequest(http.MethodPost, "/v1/vector/delete", bodyReader) + req := httptest.NewRequest(http.MethodPost, versional(VectorDeletePath), bodyReader) req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -566,6 +555,34 @@ func TestDelete(t *testing.T) { } } +func TestDeleteForFilter(t *testing.T) { + jsonBodyList := [][]byte{ + []byte(`{"collectionName": "` + DefaultCollectionName + `" , "id": [1,2,3]}`), + []byte(`{"collectionName": "` + DefaultCollectionName + `" , "filter": "id in [1,2,3]"}`), + []byte(`{"collectionName": "` + DefaultCollectionName + `" , "id": [1,2,3], "filter": "id in [1,2,3]"}`), + } + for _, jsonBody := range jsonBodyList { + t.Run("delete success", func(t *testing.T) { + mp := mocks.NewMockProxy(t) + mp, _ = wrapWithDescribeColl(t, mp, ReturnSuccess, 1, nil) + mp.EXPECT().Delete(mock.Anything, 
mock.Anything).Return(&milvuspb.MutationResult{ + Status: &StatusSuccess, + }, nil).Once() + testEngine := initHTTPServer(mp, true) + bodyReader := bytes.NewReader(jsonBody) + req := httptest.NewRequest(http.MethodPost, versional(VectorDeletePath), bodyReader) + req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) + w := httptest.NewRecorder() + testEngine.ServeHTTP(w, req) + assert.Equal(t, w.Code, 200) + assert.Equal(t, w.Body.String(), "{\"code\":200,\"data\":{}}") + resp := map[string]interface{}{} + err := json.Unmarshal(w.Body.Bytes(), &resp) + assert.Equal(t, err, nil) + }) + } +} + func TestInsert(t *testing.T) { paramtable.Init() testCases := []testCase{} @@ -582,20 +599,17 @@ func TestInsert(t *testing.T) { expectedBody: PrintErr(ErrDefault), }) - reason := DefaultCollectionName + " name not found" + err := merr.WrapErrCollectionNotFound(DefaultCollectionName) mp3 := mocks.NewMockProxy(t) mp3, _ = wrapWithDescribeColl(t, mp3, ReturnSuccess, 1, nil) mp3.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_CollectionNameNotFound, // 28 - Reason: reason, - }, + Status: merr.Status(err), }, nil).Once() testCases = append(testCases, testCase{ name: "insert fail", mp: mp3, exceptCode: 200, - expectedBody: Print(int32(commonpb.ErrorCode_CollectionNameNotFound), reason), + expectedBody: PrintErr(err), }) mp4 := mocks.NewMockProxy(t) @@ -638,16 +652,16 @@ func TestInsert(t *testing.T) { expectedBody: "{\"code\":200,\"data\":{\"insertCount\":3,\"insertIds\":[\"1\",\"2\",\"3\"]}}", }) + rows := generateSearchResult() + data, _ := json.Marshal(map[string]interface{}{ + HTTPCollectionName: DefaultCollectionName, + HTTPReturnData: rows[0], + }) for _, tt := range testCases { t.Run(tt.name, func(t *testing.T) { testEngine := initHTTPServer(tt.mp, true) - rows := generateSearchResult() - data, _ := json.Marshal(map[string]interface{}{ - HTTPCollectionName: DefaultCollectionName, - 
HTTPReturnData: rows[0], - }) bodyReader := bytes.NewReader(data) - req := httptest.NewRequest(http.MethodPost, "/v1/vector/insert", bodyReader) + req := httptest.NewRequest(http.MethodPost, versional(VectorInsertPath), bodyReader) req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -664,7 +678,7 @@ func TestInsert(t *testing.T) { mp, _ = wrapWithDescribeColl(t, mp, ReturnSuccess, 1, nil) testEngine := initHTTPServer(mp, true) bodyReader := bytes.NewReader([]byte(`{"collectionName": "` + DefaultCollectionName + `", "data": {}}`)) - req := httptest.NewRequest(http.MethodPost, "/v1/vector/insert", bodyReader) + req := httptest.NewRequest(http.MethodPost, versional(VectorInsertPath), bodyReader) req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -704,7 +718,7 @@ func TestInsertForDataType(t *testing.T) { HTTPReturnData: rows, }) bodyReader := bytes.NewReader(data) - req := httptest.NewRequest(http.MethodPost, "/v1/vector/insert", bodyReader) + req := httptest.NewRequest(http.MethodPost, versional(VectorInsertPath), bodyReader) req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -731,7 +745,7 @@ func TestInsertForDataType(t *testing.T) { HTTPReturnData: rows, }) bodyReader := bytes.NewReader(data) - req := httptest.NewRequest(http.MethodPost, "/v1/vector/insert", bodyReader) + req := httptest.NewRequest(http.MethodPost, versional(VectorInsertPath), bodyReader) req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -741,6 +755,113 @@ func TestInsertForDataType(t *testing.T) { } } +func TestUpsert(t *testing.T) { + paramtable.Init() + testCases := []testCase{} + _, testCases = wrapWithDescribeColl(t, nil, ReturnFail, 1, testCases) + _, testCases = wrapWithDescribeColl(t, nil, ReturnWrongStatus, 1, testCases) + + 
mp2 := mocks.NewMockProxy(t) + mp2, _ = wrapWithDescribeColl(t, mp2, ReturnSuccess, 1, nil) + mp2.EXPECT().Upsert(mock.Anything, mock.Anything).Return(nil, ErrDefault).Once() + testCases = append(testCases, testCase{ + name: "insert fail", + mp: mp2, + exceptCode: 200, + expectedBody: PrintErr(ErrDefault), + }) + + err := merr.WrapErrCollectionNotFound(DefaultCollectionName) + mp3 := mocks.NewMockProxy(t) + mp3, _ = wrapWithDescribeColl(t, mp3, ReturnSuccess, 1, nil) + mp3.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{ + Status: merr.Status(err), + }, nil).Once() + testCases = append(testCases, testCase{ + name: "insert fail", + mp: mp3, + exceptCode: 200, + expectedBody: PrintErr(err), + }) + + mp4 := mocks.NewMockProxy(t) + mp4, _ = wrapWithDescribeColl(t, mp4, ReturnSuccess, 1, nil) + mp4.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{ + Status: &StatusSuccess, + }, nil).Once() + testCases = append(testCases, testCase{ + name: "id type invalid", + mp: mp4, + exceptCode: 200, + expectedBody: PrintErr(merr.ErrCheckPrimaryKey), + }) + + mp5 := mocks.NewMockProxy(t) + mp5, _ = wrapWithDescribeColl(t, mp5, ReturnSuccess, 1, nil) + mp5.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{ + Status: &StatusSuccess, + IDs: getIntIds(), + UpsertCnt: 3, + }, nil).Once() + testCases = append(testCases, testCase{ + name: "upsert success", + mp: mp5, + exceptCode: 200, + expectedBody: "{\"code\":200,\"data\":{\"upsertCount\":3,\"upsertIds\":[1,2,3]}}", + }) + + mp6 := mocks.NewMockProxy(t) + mp6, _ = wrapWithDescribeColl(t, mp6, ReturnSuccess, 1, nil) + mp6.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{ + Status: &StatusSuccess, + IDs: getStrIds(), + UpsertCnt: 3, + }, nil).Once() + testCases = append(testCases, testCase{ + name: "upsert success", + mp: mp6, + exceptCode: 200, + expectedBody: 
"{\"code\":200,\"data\":{\"upsertCount\":3,\"upsertIds\":[\"1\",\"2\",\"3\"]}}", + }) + + rows := generateSearchResult() + data, _ := json.Marshal(map[string]interface{}{ + HTTPCollectionName: DefaultCollectionName, + HTTPReturnData: rows[0], + }) + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + testEngine := initHTTPServer(tt.mp, true) + bodyReader := bytes.NewReader(data) + req := httptest.NewRequest(http.MethodPost, versional(VectorUpsertPath), bodyReader) + req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) + w := httptest.NewRecorder() + testEngine.ServeHTTP(w, req) + assert.Equal(t, w.Code, tt.exceptCode) + assert.Equal(t, w.Body.String(), tt.expectedBody) + resp := map[string]interface{}{} + err := json.Unmarshal(w.Body.Bytes(), &resp) + assert.Equal(t, err, nil) + }) + } + + t.Run("wrong request body", func(t *testing.T) { + mp := mocks.NewMockProxy(t) + mp, _ = wrapWithDescribeColl(t, mp, ReturnSuccess, 1, nil) + testEngine := initHTTPServer(mp, true) + bodyReader := bytes.NewReader([]byte(`{"collectionName": "` + DefaultCollectionName + `", "data": {}}`)) + req := httptest.NewRequest(http.MethodPost, versional(VectorUpsertPath), bodyReader) + req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) + w := httptest.NewRecorder() + testEngine.ServeHTTP(w, req) + assert.Equal(t, w.Code, 200) + assert.Equal(t, w.Body.String(), PrintErr(merr.ErrInvalidInsertData)) + resp := map[string]interface{}{} + err := json.Unmarshal(w.Body.Bytes(), &resp) + assert.Equal(t, err, nil) + }) +} + func getIntIds() *schemapb.IDs { ids := schemapb.IDs{ IdField: &schemapb.IDs_IntId{ @@ -776,19 +897,16 @@ func TestSearch(t *testing.T) { expectedBody: PrintErr(ErrDefault), }) - reason := DefaultCollectionName + " name not found" + err := merr.WrapErrCollectionNotFound(DefaultCollectionName) mp3 := mocks.NewMockProxy(t) mp3.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{ - Status: &commonpb.Status{ - ErrorCode: 
commonpb.ErrorCode_CollectionNameNotFound, // 28 - Reason: reason, - }, + Status: merr.Status(err), }, nil).Once() testCases = append(testCases, testCase{ name: "search fail", mp: mp3, exceptCode: 200, - expectedBody: Print(int32(commonpb.ErrorCode_CollectionNameNotFound), reason), + expectedBody: PrintErr(err), }) mp4 := mocks.NewMockProxy(t) @@ -816,7 +934,7 @@ func TestSearch(t *testing.T) { "vector": rows, }) bodyReader := bytes.NewReader(data) - req := httptest.NewRequest(http.MethodPost, "/v1/vector/search", bodyReader) + req := httptest.NewRequest(http.MethodPost, versional(VectorSearchPath), bodyReader) req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -862,17 +980,15 @@ func wrapWithDescribeColl(t *testing.T, mp *mocks.MockProxy, returnType ReturnTy expectedBody: PrintErr(ErrDefault), } case ReturnWrongStatus: + err := merr.WrapErrCollectionNotFound(DefaultCollectionName) call = mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_CollectionNotExists, - Reason: "can't find collection: " + DefaultCollectionName, - }, + Status: merr.Status(err), }, nil) testcase = testCase{ name: "[share] collection not found", mp: mp, exceptCode: 200, - expectedBody: "{\"code\":4,\"message\":\"can't find collection: " + DefaultCollectionName + "\"}", + expectedBody: PrintErr(err), } } if times == 2 { @@ -918,18 +1034,15 @@ func wrapWithHasCollection(t *testing.T, mp *mocks.MockProxy, returnType ReturnT expectedBody: PrintErr(ErrDefault), } case ReturnWrongStatus: - reason := "can't find collection: " + DefaultCollectionName + err := merr.WrapErrCollectionNotFound(DefaultCollectionName) call = mp.EXPECT().HasCollection(mock.Anything, mock.Anything).Return(&milvuspb.BoolResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, // 1 - Reason: reason, - }, + Status: 
merr.Status(err), }, nil) testcase = testCase{ name: "[share] unexpected error", mp: mp, exceptCode: 200, - expectedBody: Print(int32(commonpb.ErrorCode_UnexpectedError), reason), + expectedBody: PrintErr(err), } } if times == 2 { @@ -959,28 +1072,31 @@ func TestHttpRequestFormat(t *testing.T) { } paths := [][]string{ { - URIPrefix + VectorCollectionsCreatePath, - URIPrefix + VectorCollectionsDropPath, - URIPrefix + VectorGetPath, - URIPrefix + VectorSearchPath, - URIPrefix + VectorQueryPath, - URIPrefix + VectorInsertPath, - URIPrefix + VectorDeletePath, + versional(VectorCollectionsCreatePath), + versional(VectorCollectionsDropPath), + versional(VectorGetPath), + versional(VectorSearchPath), + versional(VectorQueryPath), + versional(VectorInsertPath), + versional(VectorUpsertPath), + versional(VectorDeletePath), }, { - URIPrefix + VectorCollectionsDropPath, - URIPrefix + VectorGetPath, - URIPrefix + VectorSearchPath, - URIPrefix + VectorQueryPath, - URIPrefix + VectorInsertPath, - URIPrefix + VectorDeletePath, + versional(VectorCollectionsDropPath), + versional(VectorGetPath), + versional(VectorSearchPath), + versional(VectorQueryPath), + versional(VectorInsertPath), + versional(VectorUpsertPath), + versional(VectorDeletePath), }, { - URIPrefix + VectorCollectionsCreatePath, + versional(VectorCollectionsCreatePath), }, { - URIPrefix + VectorGetPath, - URIPrefix + VectorSearchPath, - URIPrefix + VectorQueryPath, - URIPrefix + VectorInsertPath, - URIPrefix + VectorDeletePath, + versional(VectorGetPath), + versional(VectorSearchPath), + versional(VectorQueryPath), + versional(VectorInsertPath), + versional(VectorUpsertPath), + versional(VectorDeletePath), }, } for i, pathArr := range paths { @@ -1002,15 +1118,16 @@ func TestHttpRequestFormat(t *testing.T) { func TestAuthorization(t *testing.T) { paramtable.Init() paramtable.Get().Save(proxy.Params.CommonCfg.AuthorizationEnabled.Key, "true") - errorStr := Print(Code(merr.ErrServiceUnavailable), "internal: Milvus 
Proxy is not ready yet. please wait: service unavailable") + errorStr := Print(merr.Code(merr.ErrServiceUnavailable), "internal: Milvus Proxy is not ready yet. please wait: service unavailable") jsons := map[string][]byte{ errorStr: []byte(`{"collectionName": "` + DefaultCollectionName + `", "vector": [0.1, 0.2], "filter": "id in [2]", "id": [2], "dimension": 2, "data":[{"book_id":1,"book_intro":[0.1,0.11],"distance":0.01,"word_count":1000},{"book_id":2,"book_intro":[0.2,0.22],"distance":0.04,"word_count":2000},{"book_id":3,"book_intro":[0.3,0.33],"distance":0.09,"word_count":3000}]}`), } paths := map[string][]string{ errorStr: { - URIPrefix + VectorGetPath, - URIPrefix + VectorInsertPath, - URIPrefix + VectorDeletePath, + versional(VectorGetPath), + versional(VectorInsertPath), + versional(VectorUpsertPath), + versional(VectorDeletePath), }, } for res, pathArr := range paths { @@ -1031,7 +1148,7 @@ func TestAuthorization(t *testing.T) { paths = map[string][]string{ errorStr: { - URIPrefix + VectorCollectionsCreatePath, + versional(VectorCollectionsCreatePath), }, } for res, pathArr := range paths { @@ -1052,7 +1169,7 @@ func TestAuthorization(t *testing.T) { paths = map[string][]string{ errorStr: { - URIPrefix + VectorCollectionsDropPath, + versional(VectorCollectionsDropPath), }, } for res, pathArr := range paths { @@ -1073,8 +1190,8 @@ func TestAuthorization(t *testing.T) { paths = map[string][]string{ errorStr: { - URIPrefix + VectorCollectionsPath, - URIPrefix + VectorCollectionsDescribePath + "?collectionName=" + DefaultCollectionName, + versional(VectorCollectionsPath), + versional(VectorCollectionsDescribePath) + "?collectionName=" + DefaultCollectionName, }, } for res, pathArr := range paths { @@ -1093,8 +1210,8 @@ func TestAuthorization(t *testing.T) { } paths = map[string][]string{ errorStr: { - URIPrefix + VectorQueryPath, - URIPrefix + VectorSearchPath, + versional(VectorQueryPath), + versional(VectorSearchPath), }, } for res, pathArr := range paths { 
@@ -1112,7 +1229,6 @@ func TestAuthorization(t *testing.T) { }) } } - } func TestDatabaseNotFound(t *testing.T) { @@ -1122,7 +1238,7 @@ func TestDatabaseNotFound(t *testing.T) { mp := mocks.NewMockProxy(t) mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(nil, ErrDefault).Once() testEngine := initHTTPServer(mp, true) - req := httptest.NewRequest(http.MethodGet, URIPrefix+VectorCollectionsPath+"?dbName=test", nil) + req := httptest.NewRequest(http.MethodGet, versional(VectorCollectionsPath)+"?dbName=test", nil) req.Header.Set("authorization", "Bearer root:Milvus") w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -1132,19 +1248,17 @@ func TestDatabaseNotFound(t *testing.T) { t.Run("list database without success code", func(t *testing.T) { mp := mocks.NewMockProxy(t) + err := errors.New("unexpected error") mp.EXPECT().ListDatabases(mock.Anything, mock.Anything).Return(&milvuspb.ListDatabasesResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "", - }, + Status: merr.Status(err), }, nil).Once() testEngine := initHTTPServer(mp, true) - req := httptest.NewRequest(http.MethodGet, URIPrefix+VectorCollectionsPath+"?dbName=test", nil) + req := httptest.NewRequest(http.MethodGet, versional(VectorCollectionsPath)+"?dbName=test", nil) req.Header.Set("authorization", "Bearer root:Milvus") w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) assert.Equal(t, w.Code, http.StatusOK) - assert.Equal(t, w.Body.String(), Print(int32(commonpb.ErrorCode_UnexpectedError), "")) + assert.Equal(t, w.Body.String(), PrintErr(err)) }) t.Run("list database success", func(t *testing.T) { @@ -1160,7 +1274,7 @@ func TestDatabaseNotFound(t *testing.T) { CollectionNames: nil, }, nil).Once() testEngine := initHTTPServer(mp, true) - req := httptest.NewRequest(http.MethodGet, URIPrefix+VectorCollectionsPath+"?dbName=test", nil) + req := httptest.NewRequest(http.MethodGet, versional(VectorCollectionsPath)+"?dbName=test", nil) 
req.Header.Set("authorization", "Bearer root:Milvus") w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -1168,11 +1282,11 @@ func TestDatabaseNotFound(t *testing.T) { assert.Equal(t, w.Body.String(), "{\"code\":200,\"data\":[]}") }) - errorStr := PrintErr(merr.ErrDatabaseNotfound) + errorStr := PrintErr(merr.ErrDatabaseNotFound) paths := map[string][]string{ errorStr: { - URIPrefix + VectorCollectionsPath + "?dbName=test", - URIPrefix + VectorCollectionsDescribePath + "?dbName=test&collectionName=" + DefaultCollectionName, + versional(VectorCollectionsPath) + "?dbName=test", + versional(VectorCollectionsDescribePath) + "?dbName=test&collectionName=" + DefaultCollectionName, }, } for res, pathArr := range paths { @@ -1197,13 +1311,14 @@ func TestDatabaseNotFound(t *testing.T) { requestBody := `{"dbName": "test", "collectionName": "` + DefaultCollectionName + `", "vector": [0.1, 0.2], "filter": "id in [2]", "id": [2], "dimension": 2, "data":[{"book_id":1,"book_intro":[0.1,0.11],"distance":0.01,"word_count":1000},{"book_id":2,"book_intro":[0.2,0.22],"distance":0.04,"word_count":2000},{"book_id":3,"book_intro":[0.3,0.33],"distance":0.09,"word_count":3000}]}` paths = map[string][]string{ requestBody: { - URIPrefix + VectorCollectionsCreatePath, - URIPrefix + VectorCollectionsDropPath, - URIPrefix + VectorInsertPath, - URIPrefix + VectorDeletePath, - URIPrefix + VectorQueryPath, - URIPrefix + VectorGetPath, - URIPrefix + VectorSearchPath, + versional(VectorCollectionsCreatePath), + versional(VectorCollectionsDropPath), + versional(VectorInsertPath), + versional(VectorUpsertPath), + versional(VectorDeletePath), + versional(VectorQueryPath), + versional(VectorGetPath), + versional(VectorSearchPath), }, } for request, pathArr := range paths { @@ -1347,6 +1462,7 @@ func Test_Handles_VectorCollectionsDescribe(t *testing.T) { h := NewHandlers(mp) testEngine := gin.New() app := testEngine.Group("/", func(c *gin.Context) { + c.Set(ContextUsername, "") username, _, ok 
:= ParseUsernamePassword(c) if ok { c.Set(ContextUsername, username) @@ -1368,7 +1484,7 @@ func Test_Handles_VectorCollectionsDescribe(t *testing.T) { w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) assert.Equal(t, w.Code, http.StatusForbidden) - assert.Equal(t, w.Body.String(), Print(Code(merr.ErrServiceUnavailable), "internal: Milvus Proxy is not ready yet. please wait: service unavailable")) + assert.Equal(t, w.Body.String(), Print(merr.Code(merr.ErrServiceUnavailable), "internal: Milvus Proxy is not ready yet. please wait: service unavailable")) }) t.Run("describe collection fail with error", func(t *testing.T) { @@ -1386,18 +1502,20 @@ func Test_Handles_VectorCollectionsDescribe(t *testing.T) { }) t.Run("describe collection fail with status code", func(t *testing.T) { + err := merr.WrapErrDatabaseNotFound(DefaultDbName) paramtable.Get().Save(proxy.Params.CommonCfg.AuthorizationEnabled.Key, "false") mp.EXPECT(). DescribeCollection(mock.Anything, mock.Anything). Return(&milvuspb.DescribeCollectionResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}}, nil). + Status: merr.Status(err), + }, nil). Once() req := httptest.NewRequest(http.MethodGet, "/vector/collections/describe?collectionName=book", nil) req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) assert.Equal(t, w.Code, http.StatusOK) - assert.Equal(t, w.Body.String(), "{\"code\":1,\"message\":\"\"}") + assert.Equal(t, w.Body.String(), PrintErr(err)) }) t.Run("get load state and describe index fail with error", func(t *testing.T) { @@ -1405,7 +1523,8 @@ func Test_Handles_VectorCollectionsDescribe(t *testing.T) { DescribeCollection(mock.Anything, mock.Anything). Return(&milvuspb.DescribeCollectionResponse{ Schema: getCollectionSchema("collectionName"), - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}}, nil). + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + }, nil). 
Once() mp.EXPECT(). GetLoadState(mock.Anything, mock.Anything). @@ -1428,17 +1547,20 @@ func Test_Handles_VectorCollectionsDescribe(t *testing.T) { DescribeCollection(mock.Anything, mock.Anything). Return(&milvuspb.DescribeCollectionResponse{ Schema: getCollectionSchema("collectionName"), - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}}, nil). + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + }, nil). Once() mp.EXPECT(). GetLoadState(mock.Anything, mock.Anything). Return(&milvuspb.GetLoadStateResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}}, nil). + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, + }, nil). Once() mp.EXPECT(). DescribeIndex(mock.Anything, mock.Anything). Return(&milvuspb.DescribeIndexResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}}, nil). + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, + }, nil). Once() req := httptest.NewRequest(http.MethodGet, "/vector/collections/describe?collectionName=book", nil) req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) @@ -1453,7 +1575,8 @@ func Test_Handles_VectorCollectionsDescribe(t *testing.T) { DescribeCollection(mock.Anything, mock.Anything). Return(&milvuspb.DescribeCollectionResponse{ Schema: getCollectionSchema("collectionName"), - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}}, nil). + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + }, nil). Once() mp.EXPECT(). GetLoadState(mock.Anything, mock.Anything). @@ -1477,7 +1600,8 @@ func Test_Handles_VectorCollectionsDescribe(t *testing.T) { }, }, }, - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}}, nil). + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, + }, nil). 
Once() req := httptest.NewRequest(http.MethodGet, "/vector/collections/describe?collectionName=book", nil) req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) diff --git a/internal/distributed/proxy/httpserver/request.go b/internal/distributed/proxy/httpserver/request.go index c14cb68343ddc..0ffded910444c 100644 --- a/internal/distributed/proxy/httpserver/request.go +++ b/internal/distributed/proxy/httpserver/request.go @@ -34,7 +34,8 @@ type GetReq struct { type DeleteReq struct { DbName string `json:"dbName"` CollectionName string `json:"collectionName" validate:"required"` - ID interface{} `json:"id" validate:"required"` + ID interface{} `json:"id"` + Filter string `json:"filter"` } type InsertReq struct { @@ -49,6 +50,18 @@ type SingleInsertReq struct { Data map[string]interface{} `json:"data" validate:"required"` } +type UpsertReq struct { + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" validate:"required"` + Data []map[string]interface{} `json:"data" validate:"required"` +} + +type SingleUpsertReq struct { + DbName string `json:"dbName"` + CollectionName string `json:"collectionName" validate:"required"` + Data map[string]interface{} `json:"data" validate:"required"` +} + type SearchReq struct { DbName string `json:"dbName"` CollectionName string `json:"collectionName" validate:"required"` diff --git a/internal/distributed/proxy/httpserver/utils.go b/internal/distributed/proxy/httpserver/utils.go index fd29eb93893dd..0a502d3344cbd 100644 --- a/internal/distributed/proxy/httpserver/utils.go +++ b/internal/distributed/proxy/httpserver/utils.go @@ -10,39 +10,32 @@ import ( "strconv" "strings" - "github.com/milvus-io/milvus/pkg/util/parameterutil.go" - - "github.com/milvus-io/milvus/pkg/util/merr" - "github.com/cockroachdb/errors" - "github.com/gin-gonic/gin" "github.com/golang/protobuf/proto" "github.com/spf13/cast" "github.com/tidwall/gjson" + "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" 
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/parameterutil.go" ) func ParseUsernamePassword(c *gin.Context) (string, string, bool) { username, password, ok := c.Request.BasicAuth() if !ok { - auth := c.Request.Header.Get("Authorization") - if auth != "" { - token := strings.TrimPrefix(auth, "Bearer ") - if token != auth { - i := strings.IndexAny(token, ":") - if i != -1 { - username = token[:i] - password = token[i+1:] - } - } + token := GetAuthorization(c) + i := strings.IndexAny(token, util.CredentialSeperator) + if i != -1 { + username = token[:i] + password = token[i+1:] } } else { c.Header("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`) @@ -50,6 +43,11 @@ func ParseUsernamePassword(c *gin.Context) (string, string, bool) { return username, password, username != "" && password != "" } +func GetAuthorization(c *gin.Context) string { + auth := c.Request.Header.Get("Authorization") + return strings.TrimPrefix(auth, "Bearer ") +} + // find the primary field of collection func getPrimaryField(schema *schemapb.CollectionSchema) (*schemapb.FieldSchema, bool) { for _, field := range schema.Fields { @@ -174,12 +172,12 @@ func printIndexes(indexes []*milvuspb.IndexDescription) []gin.H { // --------------------- insert param --------------------- // -func checkAndSetData(body string, collDescResp *milvuspb.DescribeCollectionResponse, req *InsertReq) error { +func checkAndSetData(body string, collDescResp *milvuspb.DescribeCollectionResponse) (error, []map[string]interface{}) { var reallyDataArray []map[string]interface{} dataResult := gjson.Get(body, "data") dataResultArray := dataResult.Array() if len(dataResultArray) == 0 { - return 
errors.New("data is required") + return merr.ErrMissingRequiredParameters, reallyDataArray } var fieldNames []string @@ -200,7 +198,7 @@ func checkAndSetData(body string, collDescResp *milvuspb.DescribeCollectionRespo if field.IsPrimaryKey && collDescResp.Schema.AutoID { if dataString != "" { - return fmt.Errorf("fieldName %s AutoId already open, not support insert data %s", fieldName, dataString) + return merr.WrapErrParameterInvalid("", "set primary key but autoID == true"), reallyDataArray } continue } @@ -219,31 +217,31 @@ func checkAndSetData(body string, collDescResp *milvuspb.DescribeCollectionRespo case schemapb.DataType_Bool: result, err := cast.ToBoolE(dataString) if err != nil { - return fmt.Errorf("dataString %s cast to bool error: %s", dataString, err.Error()) + return merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(fieldType)], dataString, err.Error()), reallyDataArray } reallyData[fieldName] = result case schemapb.DataType_Int8: result, err := cast.ToInt8E(dataString) if err != nil { - return fmt.Errorf("dataString %s cast to int8 error: %s", dataString, err.Error()) + return merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(fieldType)], dataString, err.Error()), reallyDataArray } reallyData[fieldName] = result case schemapb.DataType_Int16: result, err := cast.ToInt16E(dataString) if err != nil { - return fmt.Errorf("dataString %s cast to int16 error: %s", dataString, err.Error()) + return merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(fieldType)], dataString, err.Error()), reallyDataArray } reallyData[fieldName] = result case schemapb.DataType_Int32: result, err := cast.ToInt32E(dataString) if err != nil { - return fmt.Errorf("dataString %s cast to int32 error: %s", dataString, err.Error()) + return merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(fieldType)], dataString, err.Error()), reallyDataArray } reallyData[fieldName] = result case schemapb.DataType_Int64: result, err := cast.ToInt64E(dataString) if err 
!= nil { - return fmt.Errorf("dataString %s cast to int64 error: %s", dataString, err.Error()) + return merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(fieldType)], dataString, err.Error()), reallyDataArray } reallyData[fieldName] = result case schemapb.DataType_JSON: @@ -251,13 +249,13 @@ func checkAndSetData(body string, collDescResp *milvuspb.DescribeCollectionRespo case schemapb.DataType_Float: result, err := cast.ToFloat32E(dataString) if err != nil { - return fmt.Errorf("dataString %s cast to float32 error: %s", dataString, err.Error()) + return merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(fieldType)], dataString, err.Error()), reallyDataArray } reallyData[fieldName] = result case schemapb.DataType_Double: result, err := cast.ToFloat64E(dataString) if err != nil { - return fmt.Errorf("dataString %s cast to float64 error: %s", dataString, err.Error()) + return merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(fieldType)], dataString, err.Error()), reallyDataArray } reallyData[fieldName] = result case schemapb.DataType_VarChar: @@ -265,7 +263,7 @@ func checkAndSetData(body string, collDescResp *milvuspb.DescribeCollectionRespo case schemapb.DataType_String: reallyData[fieldName] = dataString default: - return fmt.Errorf("not support fieldName %s dataType %s", fieldName, fieldType) + return merr.WrapErrParameterInvalid("", schemapb.DataType_name[int32(fieldType)], "fieldName: "+fieldName), reallyDataArray } } @@ -274,20 +272,23 @@ func checkAndSetData(body string, collDescResp *milvuspb.DescribeCollectionRespo for mapKey, mapValue := range data.Map() { if !containsString(fieldNames, mapKey) { mapValueStr := mapValue.String() - if mapValue.Type == gjson.True || mapValue.Type == gjson.False { + switch mapValue.Type { + case gjson.True, gjson.False: reallyData[mapKey] = cast.ToBool(mapValueStr) - } else if mapValue.Type == gjson.String { + case gjson.String: reallyData[mapKey] = mapValueStr - } else if mapValue.Type == gjson.Number { 
+ case gjson.Number: if strings.Contains(mapValue.Raw, ".") { reallyData[mapKey] = cast.ToFloat64(mapValue.Raw) } else { reallyData[mapKey] = cast.ToInt64(mapValueStr) } - } else if mapValue.Type == gjson.JSON { + case gjson.JSON: reallyData[mapKey] = mapValue.Value() - } else { - + case gjson.Null: + // skip null + default: + log.Warn("unknown json type found", zap.Int("mapValue.Type", int(mapValue.Type))) } } } @@ -295,11 +296,10 @@ func checkAndSetData(body string, collDescResp *milvuspb.DescribeCollectionRespo reallyDataArray = append(reallyDataArray, reallyData) } else { - return fmt.Errorf("dataType %s not Json", data.Type) + return merr.WrapErrParameterInvalid(gjson.JSON, data.Type, "NULL:0, FALSE:1, NUMBER:2, STRING:3, TRUE:4, JSON:5"), reallyDataArray } } - req.Data = reallyDataArray - return nil + return nil, reallyDataArray } func containsString(arr []string, s string) bool { @@ -830,7 +830,6 @@ func buildQueryResp(rowsNum int64, needFields []string, fieldDataList []*schemap } } } - } default: row[fieldDataList[j].FieldName] = "" @@ -857,9 +856,3 @@ func buildQueryResp(rowsNum int64, needFields []string, fieldDataList []*schemap return queryResp, nil } - -// --------------------- error code --------------------- // - -func Code(err error) int32 { - return merr.Code(err) -} diff --git a/internal/distributed/proxy/httpserver/utils_test.go b/internal/distributed/proxy/httpserver/utils_test.go index cb10e4aa023ea..abf61fe170846 100644 --- a/internal/distributed/proxy/httpserver/utils_test.go +++ b/internal/distributed/proxy/httpserver/utils_test.go @@ -5,12 +5,13 @@ import ( "testing" "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/tidwall/gjson" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/common" - "github.com/stretchr/testify/assert" - "github.com/tidwall/gjson" ) const ( 
@@ -115,7 +116,8 @@ func generateIndexes() []*milvuspb.IndexDescription { { Key: "index_type", Value: "IVF_FLAT", - }, { + }, + { Key: Params, Value: "{\"nlist\":1024}", }, @@ -252,25 +254,29 @@ func TestPrintCollectionDetails(t *testing.T) { HTTPReturnFieldType: "Int64", HTTPReturnFieldPrimaryKey: true, HTTPReturnFieldAutoID: false, - HTTPReturnDescription: ""}, + HTTPReturnDescription: "", + }, { HTTPReturnFieldName: FieldWordCount, HTTPReturnFieldType: "Int64", HTTPReturnFieldPrimaryKey: false, HTTPReturnFieldAutoID: false, - HTTPReturnDescription: ""}, + HTTPReturnDescription: "", + }, { HTTPReturnFieldName: FieldBookIntro, HTTPReturnFieldType: "FloatVector(2)", HTTPReturnFieldPrimaryKey: false, HTTPReturnFieldAutoID: false, - HTTPReturnDescription: ""}, + HTTPReturnDescription: "", + }, }) assert.Equal(t, printIndexes(indexes), []gin.H{ { HTTPReturnIndexName: DefaultIndexName, HTTPReturnIndexField: FieldBookIntro, - HTTPReturnIndexMetricsType: DefaultMetricType}, + HTTPReturnIndexMetricsType: DefaultMetricType, + }, }) assert.Equal(t, getMetricType(indexes[0].Params), DefaultMetricType) assert.Equal(t, getMetricType(nil), DefaultMetricType) @@ -286,7 +292,8 @@ func TestPrintCollectionDetails(t *testing.T) { HTTPReturnFieldType: "VarChar(10)", HTTPReturnFieldPrimaryKey: false, HTTPReturnFieldAutoID: false, - HTTPReturnDescription: ""}, + HTTPReturnDescription: "", + }, }) } @@ -321,13 +328,14 @@ func TestPrimaryField(t *testing.T) { } func TestInsertWithDynamicFields(t *testing.T) { - body := "{\"data\": {\"id\": 0, \"book_id\": 1, \"book_intro\": [0.1, 0.2], \"word_count\": 2}}" + body := "{\"data\": {\"id\": 0, \"book_id\": 1, \"book_intro\": [0.1, 0.2], \"word_count\": 2, \"classified\": false, \"databaseID\": null}}" req := InsertReq{} coll := generateCollectionSchema(false) - err := checkAndSetData(body, &milvuspb.DescribeCollectionResponse{ + var err error + err, req.Data = checkAndSetData(body, &milvuspb.DescribeCollectionResponse{ Status: 
&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, Schema: coll, - }, &req) + }) assert.Equal(t, err, nil) assert.Equal(t, req.Data[0]["id"], int64(0)) assert.Equal(t, req.Data[0]["book_id"], int64(1)) @@ -336,13 +344,13 @@ func TestInsertWithDynamicFields(t *testing.T) { assert.Equal(t, err, nil) assert.Equal(t, fieldsData[len(fieldsData)-1].IsDynamic, true) assert.Equal(t, fieldsData[len(fieldsData)-1].Type, schemapb.DataType_JSON) - assert.Equal(t, string(fieldsData[len(fieldsData)-1].GetScalars().GetJsonData().GetData()[0]), "{\"id\":0}") + assert.Equal(t, string(fieldsData[len(fieldsData)-1].GetScalars().GetJsonData().GetData()[0]), "{\"classified\":false,\"id\":0}") } func TestSerialize(t *testing.T) { parameters := []float32{0.11111, 0.22222} - //assert.Equal(t, string(serialize(parameters)), "\ufffd\ufffd\ufffd=\ufffd\ufffdc\u003e") - //assert.Equal(t, string(vector2PlaceholderGroupBytes(parameters)), "vector2PlaceholderGroupBytes") // todo + // assert.Equal(t, string(serialize(parameters)), "\ufffd\ufffd\ufffd=\ufffd\ufffdc\u003e") + // assert.Equal(t, string(vector2PlaceholderGroupBytes(parameters)), "vector2PlaceholderGroupBytes") // todo assert.Equal(t, string(serialize(parameters)), "\xa4\x8d\xe3=\xa4\x8dc>") assert.Equal(t, string(vector2PlaceholderGroupBytes(parameters)), "\n\x10\n\x02$0\x10e\x1a\b\xa4\x8d\xe3=\xa4\x8dc>") // todo } @@ -373,7 +381,6 @@ func compareRow64(m1 map[string]interface{}, m2 map[string]interface{}) bool { } } return true - } func compareRow(m1 map[string]interface{}, m2 map[string]interface{}) bool { @@ -413,7 +420,6 @@ func compareRow(m1 map[string]interface{}, m2 map[string]interface{}) bool { } } return true - } type CompareFunc func(map[string]interface{}, map[string]interface{}) bool @@ -785,11 +791,13 @@ func TestBuildQueryResps(t *testing.T) { assert.Equal(t, compareRows(rows, exceptRows, compareRow), true) } - dataTypes := []schemapb.DataType{schemapb.DataType_FloatVector, schemapb.DataType_BinaryVector, + 
dataTypes := []schemapb.DataType{ + schemapb.DataType_FloatVector, schemapb.DataType_BinaryVector, schemapb.DataType_Bool, schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32, schemapb.DataType_Float, schemapb.DataType_Double, schemapb.DataType_String, schemapb.DataType_VarChar, - schemapb.DataType_JSON, schemapb.DataType_Array} + schemapb.DataType_JSON, schemapb.DataType_Array, + } for _, dateType := range dataTypes { _, err := buildQueryResp(int64(0), outputFields, newFieldData([]*schemapb.FieldData{}, dateType), generateIds(3), []float32{0.01, 0.04, 0.09}) assert.Equal(t, err, nil) diff --git a/internal/distributed/proxy/httpserver/wrap_request.go b/internal/distributed/proxy/httpserver/wrap_request.go index b9d463731cdaa..a8f5eec8b98e1 100644 --- a/internal/distributed/proxy/httpserver/wrap_request.go +++ b/internal/distributed/proxy/httpserver/wrap_request.go @@ -7,8 +7,8 @@ import ( "math" "github.com/cockroachdb/errors" - "github.com/golang/protobuf/proto" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" diff --git a/internal/distributed/proxy/httpserver/wrap_request_test.go b/internal/distributed/proxy/httpserver/wrap_request_test.go index 92cde8d4d9b7c..defddf831a2c7 100644 --- a/internal/distributed/proxy/httpserver/wrap_request_test.go +++ b/internal/distributed/proxy/httpserver/wrap_request_test.go @@ -4,9 +4,10 @@ import ( "encoding/json" "testing" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/stretchr/testify/assert" ) func TestFieldData_AsSchemapb(t *testing.T) { diff --git a/internal/distributed/proxy/httpserver/wrapper.go b/internal/distributed/proxy/httpserver/wrapper.go index 267d719abc6b0..69bf66b2196e5 100644 --- a/internal/distributed/proxy/httpserver/wrapper.go +++ 
b/internal/distributed/proxy/httpserver/wrapper.go @@ -5,15 +5,13 @@ import ( "net/http" "github.com/cockroachdb/errors" - "github.com/gin-gonic/gin" "github.com/gin-gonic/gin/binding" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" ) -var ( - errBadRequest = errors.New("bad request") -) +var errBadRequest = errors.New("bad request") // handlerFunc handles http request with gin context type handlerFunc func(c *gin.Context) (interface{}, error) diff --git a/internal/distributed/proxy/httpserver/wrapper_test.go b/internal/distributed/proxy/httpserver/wrapper_test.go index f62e040ea92f0..6e591ab3d09db 100644 --- a/internal/distributed/proxy/httpserver/wrapper_test.go +++ b/internal/distributed/proxy/httpserver/wrapper_test.go @@ -6,13 +6,12 @@ import ( "testing" "github.com/cockroachdb/errors" - "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" ) func TestWrapHandler(t *testing.T) { - var testWrapFunc = func(c *gin.Context) (interface{}, error) { + testWrapFunc := func(c *gin.Context) (interface{}, error) { Case := c.Param("case") switch Case { case "0": @@ -55,5 +54,4 @@ func TestWrapHandler(t *testing.T) { testEngine.ServeHTTP(w, req) assert.Equal(t, http.StatusInternalServerError, w.Code) }) - } diff --git a/internal/distributed/proxy/service.go b/internal/distributed/proxy/service.go index cab2894a0d31c..a1545fcdb707f 100644 --- a/internal/distributed/proxy/service.go +++ b/internal/distributed/proxy/service.go @@ -22,31 +22,29 @@ import ( "crypto/x509" "fmt" "io" - "io/ioutil" "net" "net/http" "os" "strconv" + "strings" "sync" "time" - "github.com/milvus-io/milvus/pkg/util/merr" - - "google.golang.org/grpc/credentials" - - management "github.com/milvus-io/milvus/internal/http" - "github.com/milvus-io/milvus/internal/proxy/accesslog" - "github.com/milvus-io/milvus/internal/util/componentutil" - "github.com/milvus-io/milvus/internal/util/dependency" - "github.com/milvus-io/milvus/pkg/tracer" - 
"github.com/milvus-io/milvus/pkg/util/interceptor" - "github.com/milvus-io/milvus/pkg/util/metricsinfo" - "github.com/soheilhy/cmux" - "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" - "github.com/gin-gonic/gin" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth" + "github.com/soheilhy/cmux" + clientv3 "go.etcd.io/etcd/client/v3" + "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" + "go.uber.org/atomic" + "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/keepalive" + "google.golang.org/grpc/status" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/federpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" @@ -54,23 +52,25 @@ import ( "github.com/milvus-io/milvus/internal/distributed/proxy/httpserver" qcc "github.com/milvus-io/milvus/internal/distributed/querycoord/client" rcc "github.com/milvus-io/milvus/internal/distributed/rootcoord/client" + "github.com/milvus-io/milvus/internal/distributed/utils" + management "github.com/milvus-io/milvus/internal/http" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/proxypb" "github.com/milvus-io/milvus/internal/proxy" + "github.com/milvus-io/milvus/internal/proxy/accesslog" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/componentutil" + "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/tracer" + "github.com/milvus-io/milvus/pkg/util" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/interceptor" "github.com/milvus-io/milvus/pkg/util/logutil" + 
"github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" - clientv3 "go.etcd.io/etcd/client/v3" - "go.uber.org/atomic" - "go.uber.org/zap" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/health/grpc_health_v1" - "google.golang.org/grpc/keepalive" - "google.golang.org/grpc/status" ) var ( @@ -97,14 +97,13 @@ type Server struct { serverID atomic.Int64 etcdCli *clientv3.Client - rootCoordClient types.RootCoord - dataCoordClient types.DataCoord - queryCoordClient types.QueryCoord + rootCoordClient types.RootCoordClient + dataCoordClient types.DataCoordClient + queryCoordClient types.QueryCoordClient } // NewServer create a Proxy server. func NewServer(ctx context.Context, factory dependency.Factory) (*Server, error) { - var err error server := &Server{ ctx: ctx, @@ -118,9 +117,11 @@ func NewServer(ctx context.Context, factory dependency.Factory) (*Server, error) } func authenticate(c *gin.Context) { + c.Set(httpserver.ContextUsername, "") if !proxy.Params.CommonCfg.AuthorizationEnabled.GetAsBool() { return } + // TODO fubang username, password, ok := httpserver.ParseUsernamePassword(c) if ok { if proxy.PasswordVerify(c, username, password) { @@ -129,7 +130,16 @@ func authenticate(c *gin.Context) { return } } - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{httpserver.HTTPReturnCode: httpserver.Code(merr.ErrNeedAuthenticate), httpserver.HTTPReturnMessage: merr.ErrNeedAuthenticate.Error()}) + rawToken := httpserver.GetAuthorization(c) + if rawToken != "" && !strings.Contains(rawToken, util.CredentialSeperator) { + user, err := proxy.VerifyAPIKey(rawToken) + if err == nil { + c.Set(httpserver.ContextUsername, user) + return + } + log.Warn("fail to verify apikey", zap.Error(err)) + } + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{httpserver.HTTPReturnCode: merr.Code(merr.ErrNeedAuthenticate), httpserver.HTTPReturnMessage: 
merr.ErrNeedAuthenticate.Error()}) } // registerHTTPServer register the http server, panic when failed @@ -193,12 +203,12 @@ func (s *Server) startExternalRPCServer(grpcExternalPort int, errChan chan error func (s *Server) startExternalGrpc(grpcPort int, errChan chan error) { defer s.wg.Done() Params := ¶mtable.Get().ProxyGrpcServerCfg - var kaep = keepalive.EnforcementPolicy{ + kaep := keepalive.EnforcementPolicy{ MinTime: 5 * time.Second, // If a client pings more than once every 5 seconds, terminate the connection PermitWithoutStream: true, // Allow pings even when there are no active streams } - var kasp = keepalive.ServerParameters{ + kasp := keepalive.ServerParameters{ Time: 60 * time.Second, // Ping the client if it is idle for 60 seconds to ensure the connection is still active Timeout: 10 * time.Second, // Wait 10 second for the ping ack before assuming the connection is dead } @@ -248,7 +258,7 @@ func (s *Server) startExternalGrpc(grpcPort int, errChan chan error) { } certPool := x509.NewCertPool() - rootBuf, err := ioutil.ReadFile(Params.CaPemPath.GetValue()) + rootBuf, err := os.ReadFile(Params.CaPemPath.GetValue()) if err != nil { log.Warn("failed read ca pem", zap.Error(err)) errChan <- err @@ -288,12 +298,12 @@ func (s *Server) startExternalGrpc(grpcPort int, errChan chan error) { func (s *Server) startInternalGrpc(grpcPort int, errChan chan error) { defer s.wg.Done() Params := ¶mtable.Get().ProxyGrpcServerCfg - var kaep = keepalive.EnforcementPolicy{ + kaep := keepalive.EnforcementPolicy{ MinTime: 5 * time.Second, // If a client pings more than once every 5 seconds, terminate the connection PermitWithoutStream: true, // Allow pings even when there are no active streams } - var kasp = keepalive.ServerParameters{ + kasp := keepalive.ServerParameters{ Time: 60 * time.Second, // Ping the client if it is idle for 60 seconds to ensure the connection is still active Timeout: 10 * time.Second, // Wait 10 second for the ping ack before assuming the 
connection is dead } @@ -469,7 +479,7 @@ func (s *Server) init() error { } certPool := x509.NewCertPool() - rootBuf, err := ioutil.ReadFile(Params.CaPemPath.GetValue()) + rootBuf, err := os.ReadFile(Params.CaPemPath.GetValue()) if err != nil { log.Error("failed read ca pem", zap.Error(err)) return err @@ -492,7 +502,6 @@ func (s *Server) init() error { } } } - } { s.startExternalRPCServer(Params.Port.GetAsInt(), errChan) @@ -520,13 +529,6 @@ func (s *Server) init() error { log.Debug("create RootCoord client for Proxy done") } - log.Debug("init RootCoord client for Proxy") - if err := s.rootCoordClient.Init(); err != nil { - log.Warn("failed to init RootCoord client for Proxy", zap.Error(err)) - return err - } - log.Debug("init RootCoord client for Proxy done") - log.Debug("Proxy wait for RootCoord to be healthy") if err := componentutil.WaitForComponentHealthy(s.ctx, s.rootCoordClient, "RootCoord", 1000000, time.Millisecond*200); err != nil { log.Warn("Proxy failed to wait for RootCoord to be healthy", zap.Error(err)) @@ -549,13 +551,6 @@ func (s *Server) init() error { log.Debug("create DataCoord client for Proxy done") } - log.Debug("init DataCoord client for Proxy") - if err := s.dataCoordClient.Init(); err != nil { - log.Warn("failed to init DataCoord client for Proxy", zap.Error(err)) - return err - } - log.Debug("init DataCoord client for Proxy done") - log.Debug("Proxy wait for DataCoord to be healthy") if err := componentutil.WaitForComponentHealthy(s.ctx, s.dataCoordClient, "DataCoord", 1000000, time.Millisecond*200); err != nil { log.Warn("Proxy failed to wait for DataCoord to be healthy", zap.Error(err)) @@ -578,13 +573,6 @@ func (s *Server) init() error { log.Debug("create QueryCoord client for Proxy done") } - log.Debug("init QueryCoord client for Proxy") - if err := s.queryCoordClient.Init(); err != nil { - log.Warn("failed to init QueryCoord client for Proxy", zap.Error(err)) - return err - } - log.Debug("init QueryCoord client for Proxy done") - 
log.Debug("Proxy wait for QueryCoord to be healthy") if err := componentutil.WaitForComponentHealthy(s.ctx, s.queryCoordClient, "QueryCoord", 1000000, time.Millisecond*200); err != nil { log.Warn("Proxy failed to wait for QueryCoord to be healthy", zap.Error(err)) @@ -653,15 +641,13 @@ func (s *Server) Stop() error { go func() { defer gracefulWg.Done() if s.grpcInternalServer != nil { - log.Debug("Graceful stop grpc internal server...") - s.grpcInternalServer.GracefulStop() + utils.GracefulStopGRPCServer(s.grpcInternalServer) } if s.tcpServer != nil { log.Info("Graceful stop Proxy tcp server...") s.tcpServer.Close() } else if s.grpcExternalServer != nil { - log.Info("Graceful stop grpc external server...") - s.grpcExternalServer.GracefulStop() + utils.GracefulStopGRPCServer(s.grpcExternalServer) if s.httpServer != nil { log.Info("Graceful stop grpc http server...") s.httpServer.Close() @@ -682,12 +668,12 @@ func (s *Server) Stop() error { // GetComponentStates get the component states func (s *Server) GetComponentStates(ctx context.Context, request *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) { - return s.proxy.GetComponentStates(ctx) + return s.proxy.GetComponentStates(ctx, request) } // GetStatisticsChannel get the statistics channel func (s *Server) GetStatisticsChannel(ctx context.Context, request *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error) { - return s.proxy.GetStatisticsChannel(ctx) + return s.proxy.GetStatisticsChannel(ctx, request) } // InvalidateCollectionMetaCache notifies Proxy to clear all the meta cache of specific collection. @@ -858,7 +844,6 @@ func (s *Server) GetPersistentSegmentInfo(ctx context.Context, request *milvuspb // GetQuerySegmentInfo notifies Proxy to get query segment info. 
func (s *Server) GetQuerySegmentInfo(ctx context.Context, request *milvuspb.GetQuerySegmentInfoRequest) (*milvuspb.GetQuerySegmentInfoResponse, error) { return s.proxy.GetQuerySegmentInfo(ctx, request) - } func (s *Server) Dummy(ctx context.Context, request *milvuspb.DummyRequest) (*milvuspb.DummyResponse, error) { @@ -895,19 +880,13 @@ func (s *Server) AlterAlias(ctx context.Context, request *milvuspb.AlterAliasReq func (s *Server) DescribeAlias(ctx context.Context, request *milvuspb.DescribeAliasRequest) (*milvuspb.DescribeAliasResponse, error) { return &milvuspb.DescribeAliasResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "TODO: implement me", - }, + Status: merr.Status(merr.WrapErrServiceUnavailable("DescribeAlias unimplemented")), }, nil } func (s *Server) ListAliases(ctx context.Context, request *milvuspb.ListAliasesRequest) (*milvuspb.ListAliasesResponse, error) { return &milvuspb.ListAliasesResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "TODO: implement me", - }, + Status: merr.Status(merr.WrapErrServiceUnavailable("ListAliases unimplemented")), }, nil } @@ -925,7 +904,7 @@ func (s *Server) GetCompactionStateWithPlans(ctx context.Context, req *milvuspb. return s.proxy.GetCompactionStateWithPlans(ctx, req) } -// GetFlushState gets the flush state of multiple segments +// GetFlushState gets the flush state of the collection based on the provided flush ts and segment IDs. 
func (s *Server) GetFlushState(ctx context.Context, req *milvuspb.GetFlushStateRequest) (*milvuspb.GetFlushStateResponse, error) { return s.proxy.GetFlushState(ctx, req) } @@ -956,11 +935,11 @@ func (s *Server) Check(ctx context.Context, req *grpc_health_v1.HealthCheckReque ret := &grpc_health_v1.HealthCheckResponse{ Status: grpc_health_v1.HealthCheckResponse_NOT_SERVING, } - state, err := s.proxy.GetComponentStates(ctx) + state, err := s.proxy.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) if err != nil { return ret, err } - if state.Status.ErrorCode != commonpb.ErrorCode_Success { + if state.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { return ret, nil } if state.State.StateCode != commonpb.StateCode_Healthy { @@ -975,11 +954,11 @@ func (s *Server) Watch(req *grpc_health_v1.HealthCheckRequest, server grpc_healt ret := &grpc_health_v1.HealthCheckResponse{ Status: grpc_health_v1.HealthCheckResponse_NOT_SERVING, } - state, err := s.proxy.GetComponentStates(s.ctx) + state, err := s.proxy.GetComponentStates(s.ctx, nil) if err != nil { return server.Send(ret) } - if state.Status.ErrorCode != commonpb.ErrorCode_Success { + if state.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { return server.Send(ret) } if state.State.StateCode != commonpb.StateCode_Healthy { @@ -1058,7 +1037,7 @@ func (s *Server) GetProxyMetrics(ctx context.Context, request *milvuspb.GetMetri func (s *Server) GetVersion(ctx context.Context, request *milvuspb.GetVersionRequest) (*milvuspb.GetVersionResponse, error) { buildTags := os.Getenv(metricsinfo.GitBuildTagsEnvKey) return &milvuspb.GetVersionResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Version: buildTags, }, nil } @@ -1097,19 +1076,13 @@ func (s *Server) ListResourceGroups(ctx context.Context, req *milvuspb.ListResou func (s *Server) ListIndexedSegment(ctx context.Context, req *federpb.ListIndexedSegmentRequest) (*federpb.ListIndexedSegmentResponse, error) { return 
&federpb.ListIndexedSegmentResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "not implemented", - }, + Status: merr.Status(merr.WrapErrServiceUnavailable("ListIndexedSegment unimplemented")), }, nil } func (s *Server) DescribeSegmentIndexData(ctx context.Context, req *federpb.DescribeSegmentIndexDataRequest) (*federpb.DescribeSegmentIndexDataResponse, error) { return &federpb.DescribeSegmentIndexDataResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "not implemented", - }, + Status: merr.Status(merr.WrapErrServiceUnavailable("DescribeSegmentIndexData unimplemented")), }, nil } @@ -1136,3 +1109,7 @@ func (s *Server) ListDatabases(ctx context.Context, request *milvuspb.ListDataba func (s *Server) AllocTimestamp(ctx context.Context, req *milvuspb.AllocTimestampRequest) (*milvuspb.AllocTimestampResponse, error) { return s.proxy.AllocTimestamp(ctx, req) } + +func (s *Server) ReplicateMessage(ctx context.Context, req *milvuspb.ReplicateMessageRequest) (*milvuspb.ReplicateMessageResponse, error) { + return s.proxy.ReplicateMessage(ctx, req) +} diff --git a/internal/distributed/proxy/service_test.go b/internal/distributed/proxy/service_test.go index 38c0d1aa3b69a..34d506d348ebf 100644 --- a/internal/distributed/proxy/service_test.go +++ b/internal/distributed/proxy/service_test.go @@ -22,16 +22,14 @@ import ( "crypto/x509" "encoding/json" "fmt" + "net/http/httptest" "os" "strconv" "testing" "time" "github.com/cockroachdb/errors" - milvusmock "github.com/milvus-io/milvus/internal/util/mock" - - "github.com/milvus-io/milvus/internal/proto/indexpb" - + "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" clientv3 "go.etcd.io/etcd/client/v3" @@ -44,13 +42,13 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/federpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + 
"github.com/milvus-io/milvus/internal/distributed/proxy/httpserver" "github.com/milvus-io/milvus/internal/mocks" - "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/proxypb" - "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/proxy" "github.com/milvus-io/milvus/internal/types" + milvusmock "github.com/milvus-io/milvus/internal/util/mock" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -76,7 +74,7 @@ func (m *MockBase) On(methodName string, arguments ...interface{}) *mock.Call { return m.Mock.On(methodName, arguments...) } -func (m *MockBase) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { +func (m *MockBase) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) { if m.isMockGetComponentStatesOn { ret1 := &milvuspb.ComponentStates{} var ret2 error @@ -97,424 +95,56 @@ func (m *MockBase) GetComponentStates(ctx context.Context) (*milvuspb.ComponentS }, nil } -func (m *MockBase) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (m *MockBase) GetTimeTickChannel(ctx context.Context, req *internalpb.GetTimeTickChannelRequest) (*milvuspb.StringResponse, error) { return nil, nil } -func (m *MockBase) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (m *MockBase) GetStatisticsChannel(ctx context.Context, req *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error) { return nil, nil } // ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -type MockRootCoord struct { - MockBase - initErr error - startErr error - regErr error - stopErr error -} - -func (m *MockRootCoord) Init() error { - return 
m.initErr -} - -func (m *MockRootCoord) Start() error { - return m.startErr -} - -func (m *MockRootCoord) Stop() error { - return m.stopErr -} - -func (m *MockRootCoord) Register() error { - return m.regErr -} - -func (m *MockRootCoord) CreateDatabase(ctx context.Context, in *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) { - return nil, nil -} - -func (m *MockRootCoord) DropDatabase(ctx context.Context, in *milvuspb.DropDatabaseRequest) (*commonpb.Status, error) { - return nil, nil -} - -func (m *MockRootCoord) ListDatabases(ctx context.Context, in *milvuspb.ListDatabasesRequest) (*milvuspb.ListDatabasesResponse, error) { - return nil, nil -} - -func (m *MockRootCoord) CreateCollection(ctx context.Context, req *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) { - return nil, nil -} - -func (m *MockRootCoord) DropCollection(ctx context.Context, req *milvuspb.DropCollectionRequest) (*commonpb.Status, error) { - return nil, nil -} - -func (m *MockRootCoord) HasCollection(ctx context.Context, req *milvuspb.HasCollectionRequest) (*milvuspb.BoolResponse, error) { - return nil, nil -} - -func (m *MockRootCoord) DescribeCollection(ctx context.Context, req *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { - return nil, nil -} - -func (m *MockRootCoord) DescribeCollectionInternal(ctx context.Context, req *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { - return nil, nil -} - -func (m *MockRootCoord) ShowCollections(ctx context.Context, req *milvuspb.ShowCollectionsRequest) (*milvuspb.ShowCollectionsResponse, error) { - return nil, nil -} - -func (m *MockRootCoord) AlterCollection(ctx context.Context, request *milvuspb.AlterCollectionRequest) (*commonpb.Status, error) { - return nil, nil -} - -func (m *MockRootCoord) CreatePartition(ctx context.Context, req *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) { - return nil, nil -} - -func (m *MockRootCoord) DropPartition(ctx 
context.Context, req *milvuspb.DropPartitionRequest) (*commonpb.Status, error) { - return nil, nil -} - -func (m *MockRootCoord) HasPartition(ctx context.Context, req *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error) { - return nil, nil -} - -func (m *MockRootCoord) ShowPartitions(ctx context.Context, req *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) { - return nil, nil -} - -func (m *MockRootCoord) ShowPartitionsInternal(ctx context.Context, req *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) { - return nil, nil -} - -func (m *MockRootCoord) CreateAlias(ctx context.Context, req *milvuspb.CreateAliasRequest) (*commonpb.Status, error) { - return nil, nil -} - -func (m *MockRootCoord) DropAlias(ctx context.Context, req *milvuspb.DropAliasRequest) (*commonpb.Status, error) { - return nil, nil -} - -func (m *MockRootCoord) AlterAlias(ctx context.Context, req *milvuspb.AlterAliasRequest) (*commonpb.Status, error) { - return nil, nil -} - -func (m *MockRootCoord) AllocTimestamp(ctx context.Context, req *rootcoordpb.AllocTimestampRequest) (*rootcoordpb.AllocTimestampResponse, error) { - return nil, nil -} - -func (m *MockRootCoord) AllocID(ctx context.Context, req *rootcoordpb.AllocIDRequest) (*rootcoordpb.AllocIDResponse, error) { - return nil, nil -} - -func (m *MockRootCoord) UpdateChannelTimeTick(ctx context.Context, req *internalpb.ChannelTimeTickMsg) (*commonpb.Status, error) { - return nil, nil -} - -func (m *MockRootCoord) ShowSegments(ctx context.Context, req *milvuspb.ShowSegmentsRequest) (*milvuspb.ShowSegmentsResponse, error) { - return nil, nil -} - -func (m *MockRootCoord) InvalidateCollectionMetaCache(ctx context.Context, in *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { - return nil, nil -} - -func (m *MockRootCoord) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { - return nil, nil 
-} - -func (m *MockRootCoord) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - return nil, nil -} - -func (m *MockRootCoord) Import(ctx context.Context, req *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) { - return nil, nil -} - -func (m *MockRootCoord) GetImportState(ctx context.Context, req *milvuspb.GetImportStateRequest) (*milvuspb.GetImportStateResponse, error) { - return nil, nil -} - -func (m *MockRootCoord) ListImportTasks(ctx context.Context, in *milvuspb.ListImportTasksRequest) (*milvuspb.ListImportTasksResponse, error) { - return nil, nil -} - -func (m *MockRootCoord) ReportImport(ctx context.Context, req *rootcoordpb.ImportResult) (*commonpb.Status, error) { - return nil, nil -} - -func (m *MockRootCoord) CreateCredential(ctx context.Context, req *internalpb.CredentialInfo) (*commonpb.Status, error) { - return nil, nil -} - -func (m *MockRootCoord) UpdateCredential(ctx context.Context, req *internalpb.CredentialInfo) (*commonpb.Status, error) { - return nil, nil -} - -func (m *MockRootCoord) DeleteCredential(ctx context.Context, req *milvuspb.DeleteCredentialRequest) (*commonpb.Status, error) { - return nil, nil -} - -func (m *MockRootCoord) ListCredUsers(ctx context.Context, req *milvuspb.ListCredUsersRequest) (*milvuspb.ListCredUsersResponse, error) { - return nil, nil -} - -func (m *MockRootCoord) GetCredential(ctx context.Context, req *rootcoordpb.GetCredentialRequest) (*rootcoordpb.GetCredentialResponse, error) { - return nil, nil -} - -func (m *MockRootCoord) CreateRole(ctx context.Context, req *milvuspb.CreateRoleRequest) (*commonpb.Status, error) { - return nil, nil -} - -func (m *MockRootCoord) DropRole(ctx context.Context, in *milvuspb.DropRoleRequest) (*commonpb.Status, error) { - return nil, nil -} - -func (m *MockRootCoord) OperateUserRole(ctx context.Context, in *milvuspb.OperateUserRoleRequest) (*commonpb.Status, error) { - return nil, nil -} - -func (m *MockRootCoord) 
SelectRole(ctx context.Context, in *milvuspb.SelectRoleRequest) (*milvuspb.SelectRoleResponse, error) { - return nil, nil -} - -func (m *MockRootCoord) SelectUser(ctx context.Context, in *milvuspb.SelectUserRequest) (*milvuspb.SelectUserResponse, error) { - return nil, nil -} - -func (m *MockRootCoord) OperatePrivilege(ctx context.Context, in *milvuspb.OperatePrivilegeRequest) (*commonpb.Status, error) { - return nil, nil -} - -func (m *MockRootCoord) SelectGrant(ctx context.Context, in *milvuspb.SelectGrantRequest) (*milvuspb.SelectGrantResponse, error) { - return nil, nil -} - -func (m *MockRootCoord) ListPolicy(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) { - return nil, nil -} - -func (m *MockRootCoord) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { - return &milvuspb.CheckHealthResponse{ - IsHealthy: true, - }, nil -} - -func (m *MockRootCoord) RenameCollection(ctx context.Context, req *milvuspb.RenameCollectionRequest) (*commonpb.Status, error) { - return nil, nil -} - -// ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -type MockDataCoord struct { +type MockProxy struct { MockBase err error initErr error startErr error stopErr error regErr error + isMockOn bool } -func (m *MockDataCoord) Init() error { - return m.initErr -} - -func (m *MockDataCoord) Start() error { - return m.startErr -} - -func (m *MockDataCoord) Stop() error { - return m.stopErr -} - -func (m *MockDataCoord) Register() error { - return m.regErr -} - -func (m *MockDataCoord) GetSegmentInfo(ctx context.Context, req *datapb.GetSegmentInfoRequest) (*datapb.GetSegmentInfoResponse, error) { - return nil, nil -} - -func (m *MockDataCoord) Flush(ctx context.Context, req *datapb.FlushRequest) (*datapb.FlushResponse, error) { - return nil, nil -} - -func (m *MockDataCoord) SaveImportSegment(ctx context.Context, req 
*datapb.SaveImportSegmentRequest) (*commonpb.Status, error) { - return nil, nil -} - -func (m *MockDataCoord) UnsetIsImportingState(ctx context.Context, req *datapb.UnsetIsImportingStateRequest) (*commonpb.Status, error) { - return nil, nil -} - -func (m *MockDataCoord) MarkSegmentsDropped(ctx context.Context, req *datapb.MarkSegmentsDroppedRequest) (*commonpb.Status, error) { - return nil, nil -} - -func (m *MockDataCoord) AlterCollection(ctx context.Context, request *milvuspb.AlterCollectionRequest) (*commonpb.Status, error) { - return nil, nil -} - -func (m *MockDataCoord) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest) (*datapb.AssignSegmentIDResponse, error) { - return nil, nil -} - -func (m *MockDataCoord) GetSegmentStates(ctx context.Context, req *datapb.GetSegmentStatesRequest) (*datapb.GetSegmentStatesResponse, error) { - return nil, nil -} - -func (m *MockDataCoord) GetInsertBinlogPaths(ctx context.Context, req *datapb.GetInsertBinlogPathsRequest) (*datapb.GetInsertBinlogPathsResponse, error) { - return nil, nil -} - -func (m *MockDataCoord) GetCollectionStatistics(ctx context.Context, req *datapb.GetCollectionStatisticsRequest) (*datapb.GetCollectionStatisticsResponse, error) { - return nil, nil -} - -func (m *MockDataCoord) GetPartitionStatistics(ctx context.Context, req *datapb.GetPartitionStatisticsRequest) (*datapb.GetPartitionStatisticsResponse, error) { - return nil, nil -} - -func (m *MockDataCoord) GetSegmentInfoChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - return nil, nil -} - -func (m *MockDataCoord) SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPathsRequest) (*commonpb.Status, error) { - return nil, nil -} - -func (m *MockDataCoord) GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInfoRequest) (*datapb.GetRecoveryInfoResponse, error) { - return nil, nil -} - -func (m *MockDataCoord) GetRecoveryInfoV2(ctx context.Context, req *datapb.GetRecoveryInfoRequestV2) 
(*datapb.GetRecoveryInfoResponseV2, error) { - return nil, nil -} - -func (m *MockDataCoord) GetFlushedSegments(ctx context.Context, req *datapb.GetFlushedSegmentsRequest) (*datapb.GetFlushedSegmentsResponse, error) { - return nil, nil -} - -func (m *MockDataCoord) GetSegmentsByStates(ctx context.Context, req *datapb.GetSegmentsByStatesRequest) (*datapb.GetSegmentsByStatesResponse, error) { - return nil, nil -} - -func (m *MockDataCoord) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { - return nil, nil -} - -func (m *MockDataCoord) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - return nil, nil -} - -func (m *MockDataCoord) CompleteCompaction(ctx context.Context, req *datapb.CompactionResult) (*commonpb.Status, error) { - return nil, nil -} - -func (m *MockDataCoord) ManualCompaction(ctx context.Context, req *milvuspb.ManualCompactionRequest) (*milvuspb.ManualCompactionResponse, error) { - return nil, nil -} - -func (m *MockDataCoord) GetCompactionState(ctx context.Context, req *milvuspb.GetCompactionStateRequest) (*milvuspb.GetCompactionStateResponse, error) { - return nil, nil -} - -func (m *MockDataCoord) GetCompactionStateWithPlans(ctx context.Context, req *milvuspb.GetCompactionPlansRequest) (*milvuspb.GetCompactionPlansResponse, error) { - return nil, nil -} - -func (m *MockDataCoord) WatchChannels(ctx context.Context, req *datapb.WatchChannelsRequest) (*datapb.WatchChannelsResponse, error) { - return nil, nil -} - -func (m *MockDataCoord) GetFlushState(ctx context.Context, req *milvuspb.GetFlushStateRequest) (*milvuspb.GetFlushStateResponse, error) { - return nil, nil -} - -func (m *MockDataCoord) GetFlushAllState(ctx context.Context, req *milvuspb.GetFlushAllStateRequest) (*milvuspb.GetFlushAllStateResponse, error) { - return nil, nil -} - -func (m *MockDataCoord) DropVirtualChannel(ctx context.Context, req 
*datapb.DropVirtualChannelRequest) (*datapb.DropVirtualChannelResponse, error) { - return &datapb.DropVirtualChannelResponse{}, nil -} - -func (m *MockDataCoord) SetSegmentState(ctx context.Context, req *datapb.SetSegmentStateRequest) (*datapb.SetSegmentStateResponse, error) { - return &datapb.SetSegmentStateResponse{}, nil -} - -func (m *MockDataCoord) Import(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { - return nil, nil -} - -func (m *MockDataCoord) UpdateSegmentStatistics(ctx context.Context, req *datapb.UpdateSegmentStatisticsRequest) (*commonpb.Status, error) { - return nil, nil -} - -func (m *MockDataCoord) UpdateChannelCheckpoint(ctx context.Context, req *datapb.UpdateChannelCheckpointRequest) (*commonpb.Status, error) { - return nil, nil -} - -func (m *MockDataCoord) BroadcastAlteredCollection(ctx context.Context, req *datapb.AlterCollectionRequest) (*commonpb.Status, error) { - return nil, nil -} - -func (m *MockDataCoord) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { - return nil, nil -} - -func (m *MockDataCoord) GcConfirm(ctx context.Context, req *datapb.GcConfirmRequest) (*datapb.GcConfirmResponse, error) { - return nil, nil -} - -func (m *MockDataCoord) CreateIndex(ctx context.Context, req *indexpb.CreateIndexRequest) (*commonpb.Status, error) { - return nil, nil -} - -func (m *MockDataCoord) DropIndex(ctx context.Context, req *indexpb.DropIndexRequest) (*commonpb.Status, error) { +func (m *MockProxy) DescribeAlias(ctx context.Context, request *milvuspb.DescribeAliasRequest) (*milvuspb.DescribeAliasResponse, error) { return nil, nil } -func (m *MockDataCoord) GetIndexState(ctx context.Context, req *indexpb.GetIndexStateRequest) (*indexpb.GetIndexStateResponse, error) { +func (m *MockProxy) ListAliases(ctx context.Context, request *milvuspb.ListAliasesRequest) (*milvuspb.ListAliasesResponse, error) { return nil, nil } -func (m *MockDataCoord) 
GetSegmentIndexState(ctx context.Context, req *indexpb.GetSegmentIndexStateRequest) (*indexpb.GetSegmentIndexStateResponse, error) { +func (m *MockProxy) GetVersion(ctx context.Context, request *milvuspb.GetVersionRequest) (*milvuspb.GetVersionResponse, error) { return nil, nil } -func (m *MockDataCoord) GetIndexInfos(ctx context.Context, req *indexpb.GetIndexInfoRequest) (*indexpb.GetIndexInfoResponse, error) { +func (m *MockProxy) ListIndexedSegment(ctx context.Context, request *federpb.ListIndexedSegmentRequest) (*federpb.ListIndexedSegmentResponse, error) { return nil, nil } -func (m *MockDataCoord) DescribeIndex(ctx context.Context, req *indexpb.DescribeIndexRequest) (*indexpb.DescribeIndexResponse, error) { +func (m *MockProxy) DescribeSegmentIndexData(ctx context.Context, request *federpb.DescribeSegmentIndexDataRequest) (*federpb.DescribeSegmentIndexDataResponse, error) { return nil, nil } -func (m *MockDataCoord) GetIndexStatistics(ctx context.Context, req *indexpb.GetIndexStatisticsRequest) (*indexpb.GetIndexStatisticsResponse, error) { - return nil, nil +func (m *MockProxy) SetRootCoordClient(rootCoord types.RootCoordClient) { } -func (m *MockDataCoord) GetIndexBuildProgress(ctx context.Context, req *indexpb.GetIndexBuildProgressRequest) (*indexpb.GetIndexBuildProgressResponse, error) { - return nil, nil +func (m *MockProxy) SetDataCoordClient(dataCoord types.DataCoordClient) { } -func (m *MockDataCoord) ReportDataNodeTtMsgs(ctx context.Context, req *datapb.ReportDataNodeTtMsgsRequest) (*commonpb.Status, error) { - return nil, nil +func (m *MockProxy) SetQueryCoordClient(queryCoord types.QueryCoordClient) { } -// ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -type MockProxy struct { - MockBase - err error - initErr error - startErr error - stopErr error - regErr error - isMockOn bool +func (m *MockProxy) SetQueryNodeCreator(f func(ctx context.Context, addr string, nodeID int64) 
(types.QueryNodeClient, error)) { + panic("error") } func (m *MockProxy) Init() error { @@ -729,28 +359,11 @@ func (m *MockProxy) GetProxyMetrics(ctx context.Context, request *milvuspb.GetMe return nil, nil } -func (m *MockProxy) SetRootCoordClient(rootCoord types.RootCoord) { - -} - -func (m *MockProxy) SetDataCoordClient(dataCoord types.DataCoord) { - -} - -func (m *MockProxy) SetQueryCoordClient(queryCoord types.QueryCoord) { - -} - -func (m *MockProxy) SetQueryNodeCreator(func(ctx context.Context, addr string, nodeID int64) (types.QueryNode, error)) { - -} - func (m *MockProxy) GetRateLimiter() (types.Limiter, error) { return nil, nil } func (m *MockProxy) UpdateStateCode(stateCode commonpb.StateCode) { - } func (m *MockProxy) SetAddress(address string) { @@ -895,7 +508,11 @@ func (m *MockProxy) AllocTimestamp(ctx context.Context, req *milvuspb.AllocTimes return nil, nil } -/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +func (m *MockProxy) ReplicateMessage(ctx context.Context, req *milvuspb.ReplicateMessageRequest) (*milvuspb.ReplicateMessageResponse, error) { + return nil, nil +} + +// ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////// type WaitOption struct { Duration time.Duration `json:"duration"` @@ -997,9 +614,11 @@ func waitForGrpcReady(opt *WaitOption) { } // TODO: should tls-related configurations be hard code here? -var waitDuration = time.Second * 1 -var clientPemPath = "../../../configs/cert/client.pem" -var clientKeyPath = "../../../configs/cert/client.key" +var ( + waitDuration = time.Second * 1 + clientPemPath = "../../../configs/cert/client.pem" + clientKeyPath = "../../../configs/cert/client.key" +) // waitForServerReady wait for internal grpc service and external service to be ready, according to the params. 
func waitForServerReady() { @@ -1461,7 +1080,7 @@ func TestServer_Watch(t *testing.T) { watchServer := milvusmock.NewGrpcHealthWatchServer() resultChan := watchServer.Chan() req := &grpc_health_v1.HealthCheckRequest{Service: ""} - //var ret *grpc_health_v1.HealthCheckResponse + // var ret *grpc_health_v1.HealthCheckResponse err := server.Watch(req, watchServer) ret := <-resultChan @@ -1537,13 +1156,12 @@ func getServer(t *testing.T) *Server { assert.NoError(t, err) server.proxy = &MockProxy{} - server.rootCoordClient = &MockRootCoord{} - server.dataCoordClient = &MockDataCoord{} + server.rootCoordClient = &milvusmock.GrpcRootCoordClient{} + server.dataCoordClient = &milvusmock.GrpcDataCoordClient{} - mockQC := &mocks.MockQueryCoord{} + mockQC := &mocks.MockQueryCoordClient{} server.queryCoordClient = mockQC - mockQC.EXPECT().Init().Return(nil) - mockQC.EXPECT().GetComponentStates(mock.Anything).Return(&milvuspb.ComponentStates{ + mockQC.EXPECT().GetComponentStates(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{ NodeID: int64(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()), Role: "MockQueryCoord", @@ -1740,3 +1358,25 @@ func TestNotImplementedAPIs(t *testing.T) { }) }) } + +func TestHttpAuthenticate(t *testing.T) { + paramtable.Get().Save(proxy.Params.CommonCfg.AuthorizationEnabled.Key, "true") + defer paramtable.Get().Reset(proxy.Params.CommonCfg.AuthorizationEnabled.Key) + ctx, _ := gin.CreateTestContext(nil) + ctx.Request = httptest.NewRequest("GET", "/test", nil) + { + assert.Panics(t, func() { + ctx.Request.Header.Set("Authorization", "Bearer 123456") + authenticate(ctx) + }) + } + + { + proxy.SetMockAPIHook("foo", nil) + defer proxy.SetMockAPIHook("", nil) + ctx.Request.Header.Set("Authorization", "Bearer 123456") + authenticate(ctx) + ctxName, _ := ctx.Get(httpserver.ContextUsername) + assert.Equal(t, "foo", ctxName) + } +} diff --git a/internal/distributed/querycoord/client/client.go 
b/internal/distributed/querycoord/client/client.go index 655445c5ae2e8..22eedf074a068 100644 --- a/internal/distributed/querycoord/client/client.go +++ b/internal/distributed/querycoord/client/client.go @@ -20,6 +20,10 @@ import ( "context" "fmt" + clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/zap" + "google.golang.org/grpc" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/proto/internalpb" @@ -31,9 +35,6 @@ import ( "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" - clientv3 "go.etcd.io/etcd/client/v3" - "go.uber.org/zap" - "google.golang.org/grpc" ) var Params *paramtable.ComponentParam = paramtable.Get() @@ -65,11 +66,6 @@ func NewClient(ctx context.Context, metaRoot string, etcdCli *clientv3.Client) ( return client, nil } -// Init initializes QueryCoord's grpc client. -func (c *Client) Init() error { - return nil -} - func (c *Client) getQueryCoordAddr() (string, error) { key := c.grpcClient.GetRole() msess, _, err := c.sess.GetSessions(key) @@ -90,21 +86,11 @@ func (c *Client) newGrpcClient(cc *grpc.ClientConn) querypb.QueryCoordClient { return querypb.NewQueryCoordClient(cc) } -// Start starts QueryCoordinator's client service. But it does nothing here. -func (c *Client) Start() error { - return nil -} - -// Stop stops QueryCoordinator's grpc client server. -func (c *Client) Stop() error { +// Close stops QueryCoordinator's grpc client server. 
+func (c *Client) Close() error { return c.grpcClient.Close() } -// Register dummy -func (c *Client) Register() error { - return nil -} - func wrapGrpcCall[T any](ctx context.Context, c *Client, call func(grpcClient querypb.QueryCoordClient) (*T, error)) (*T, error) { ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { @@ -119,28 +105,28 @@ func wrapGrpcCall[T any](ctx context.Context, c *Client, call func(grpcClient qu } // GetComponentStates gets the component states of QueryCoord. -func (c *Client) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { +func (c *Client) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { return wrapGrpcCall(ctx, c, func(client querypb.QueryCoordClient) (*milvuspb.ComponentStates, error) { return client.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) }) } // GetTimeTickChannel gets the time tick channel of QueryCoord. -func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (c *Client) GetTimeTickChannel(ctx context.Context, req *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { return wrapGrpcCall(ctx, c, func(client querypb.QueryCoordClient) (*milvuspb.StringResponse, error) { return client.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{}) }) } // GetStatisticsChannel gets the statistics channel of QueryCoord. 
-func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (c *Client) GetStatisticsChannel(ctx context.Context, req *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { return wrapGrpcCall(ctx, c, func(client querypb.QueryCoordClient) (*milvuspb.StringResponse, error) { return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) }) } // ShowCollections shows the collections in the QueryCoord. -func (c *Client) ShowCollections(ctx context.Context, req *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) { +func (c *Client) ShowCollections(ctx context.Context, req *querypb.ShowCollectionsRequest, opts ...grpc.CallOption) (*querypb.ShowCollectionsResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -152,7 +138,7 @@ func (c *Client) ShowCollections(ctx context.Context, req *querypb.ShowCollectio } // LoadCollection loads the data of the specified collections in the QueryCoord. -func (c *Client) LoadCollection(ctx context.Context, req *querypb.LoadCollectionRequest) (*commonpb.Status, error) { +func (c *Client) LoadCollection(ctx context.Context, req *querypb.LoadCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -164,7 +150,7 @@ func (c *Client) LoadCollection(ctx context.Context, req *querypb.LoadCollection } // ReleaseCollection release the data of the specified collections in the QueryCoord. 
-func (c *Client) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) { +func (c *Client) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -176,7 +162,7 @@ func (c *Client) ReleaseCollection(ctx context.Context, req *querypb.ReleaseColl } // ShowPartitions shows the partitions in the QueryCoord. -func (c *Client) ShowPartitions(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) { +func (c *Client) ShowPartitions(ctx context.Context, req *querypb.ShowPartitionsRequest, opts ...grpc.CallOption) (*querypb.ShowPartitionsResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -188,7 +174,7 @@ func (c *Client) ShowPartitions(ctx context.Context, req *querypb.ShowPartitions } // LoadPartitions loads the data of the specified partitions in the QueryCoord. -func (c *Client) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) { +func (c *Client) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -200,7 +186,7 @@ func (c *Client) LoadPartitions(ctx context.Context, req *querypb.LoadPartitions } // ReleasePartitions release the data of the specified partitions in the QueryCoord. 
-func (c *Client) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) { +func (c *Client) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -212,7 +198,7 @@ func (c *Client) ReleasePartitions(ctx context.Context, req *querypb.ReleasePart } // SyncNewCreatedPartition notifies QueryCoord to sync new created partition if collection is loaded. -func (c *Client) SyncNewCreatedPartition(ctx context.Context, req *querypb.SyncNewCreatedPartitionRequest) (*commonpb.Status, error) { +func (c *Client) SyncNewCreatedPartition(ctx context.Context, req *querypb.SyncNewCreatedPartitionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -224,7 +210,7 @@ func (c *Client) SyncNewCreatedPartition(ctx context.Context, req *querypb.SyncN } // GetPartitionStates gets the states of the specified partition. -func (c *Client) GetPartitionStates(ctx context.Context, req *querypb.GetPartitionStatesRequest) (*querypb.GetPartitionStatesResponse, error) { +func (c *Client) GetPartitionStates(ctx context.Context, req *querypb.GetPartitionStatesRequest, opts ...grpc.CallOption) (*querypb.GetPartitionStatesResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -236,7 +222,7 @@ func (c *Client) GetPartitionStates(ctx context.Context, req *querypb.GetPartiti } // GetSegmentInfo gets the information of the specified segment from QueryCoord. 
-func (c *Client) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) { +func (c *Client) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest, opts ...grpc.CallOption) (*querypb.GetSegmentInfoResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -248,7 +234,7 @@ func (c *Client) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfo } // LoadBalance migrate the sealed segments on the source node to the dst nodes. -func (c *Client) LoadBalance(ctx context.Context, req *querypb.LoadBalanceRequest) (*commonpb.Status, error) { +func (c *Client) LoadBalance(ctx context.Context, req *querypb.LoadBalanceRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -260,7 +246,7 @@ func (c *Client) LoadBalance(ctx context.Context, req *querypb.LoadBalanceReques } // ShowConfigurations gets specified configurations para of QueryCoord -func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { +func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest, opts ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -272,7 +258,7 @@ func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowCon } // GetMetrics gets the metrics information of QueryCoord. 
-func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { +func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -284,7 +270,7 @@ func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest } // GetReplicas gets the replicas of a certain collection. -func (c *Client) GetReplicas(ctx context.Context, req *milvuspb.GetReplicasRequest) (*milvuspb.GetReplicasResponse, error) { +func (c *Client) GetReplicas(ctx context.Context, req *milvuspb.GetReplicasRequest, opts ...grpc.CallOption) (*milvuspb.GetReplicasResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -296,7 +282,7 @@ func (c *Client) GetReplicas(ctx context.Context, req *milvuspb.GetReplicasReque } // GetShardLeaders gets the shard leaders of a certain collection. 
-func (c *Client) GetShardLeaders(ctx context.Context, req *querypb.GetShardLeadersRequest) (*querypb.GetShardLeadersResponse, error) { +func (c *Client) GetShardLeaders(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -307,13 +293,13 @@ func (c *Client) GetShardLeaders(ctx context.Context, req *querypb.GetShardLeade }) } -func (c *Client) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { +func (c *Client) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest, opts ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error) { return wrapGrpcCall(ctx, c, func(client querypb.QueryCoordClient) (*milvuspb.CheckHealthResponse, error) { return client.CheckHealth(ctx, req) }) } -func (c *Client) CreateResourceGroup(ctx context.Context, req *milvuspb.CreateResourceGroupRequest) (*commonpb.Status, error) { +func (c *Client) CreateResourceGroup(ctx context.Context, req *milvuspb.CreateResourceGroupRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -324,7 +310,7 @@ func (c *Client) CreateResourceGroup(ctx context.Context, req *milvuspb.CreateRe }) } -func (c *Client) DropResourceGroup(ctx context.Context, req *milvuspb.DropResourceGroupRequest) (*commonpb.Status, error) { +func (c *Client) DropResourceGroup(ctx context.Context, req *milvuspb.DropResourceGroupRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -335,7 +321,7 @@ func (c *Client) DropResourceGroup(ctx context.Context, req *milvuspb.DropResour }) } -func (c *Client) DescribeResourceGroup(ctx context.Context, req *querypb.DescribeResourceGroupRequest) (*querypb.DescribeResourceGroupResponse, error) { +func (c *Client) 
DescribeResourceGroup(ctx context.Context, req *querypb.DescribeResourceGroupRequest, opts ...grpc.CallOption) (*querypb.DescribeResourceGroupResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -346,7 +332,7 @@ func (c *Client) DescribeResourceGroup(ctx context.Context, req *querypb.Describ }) } -func (c *Client) TransferNode(ctx context.Context, req *milvuspb.TransferNodeRequest) (*commonpb.Status, error) { +func (c *Client) TransferNode(ctx context.Context, req *milvuspb.TransferNodeRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -357,7 +343,7 @@ func (c *Client) TransferNode(ctx context.Context, req *milvuspb.TransferNodeReq }) } -func (c *Client) TransferReplica(ctx context.Context, req *querypb.TransferReplicaRequest) (*commonpb.Status, error) { +func (c *Client) TransferReplica(ctx context.Context, req *querypb.TransferReplicaRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -368,7 +354,7 @@ func (c *Client) TransferReplica(ctx context.Context, req *querypb.TransferRepli }) } -func (c *Client) ListResourceGroups(ctx context.Context, req *milvuspb.ListResourceGroupsRequest) (*milvuspb.ListResourceGroupsResponse, error) { +func (c *Client) ListResourceGroups(ctx context.Context, req *milvuspb.ListResourceGroupsRequest, opts ...grpc.CallOption) (*milvuspb.ListResourceGroupsResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), diff --git a/internal/distributed/querycoord/client/client_test.go b/internal/distributed/querycoord/client/client_test.go index ead3d65eaa1d9..248d19d8c561b 100644 --- a/internal/distributed/querycoord/client/client_test.go +++ b/internal/distributed/querycoord/client/client_test.go @@ -25,17 +25,16 @@ import ( "time" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus/internal/util/mock" + 
"github.com/stretchr/testify/assert" "go.uber.org/zap" - - "github.com/milvus-io/milvus/internal/proto/querypb" - "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/util/paramtable" "google.golang.org/grpc" + "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proxy" + "github.com/milvus-io/milvus/internal/util/mock" + "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/etcd" - "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) func TestMain(m *testing.M) { @@ -72,15 +71,6 @@ func Test_NewClient(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, client) - err = client.Init() - assert.NoError(t, err) - - err = client.Start() - assert.NoError(t, err) - - err = client.Register() - assert.NoError(t, err) - checkFunc := func(retNotNil bool) { retCheck := func(notNil bool, ret any, err error) { if notNil { @@ -92,13 +82,13 @@ func Test_NewClient(t *testing.T) { } } - r1, err := client.GetComponentStates(ctx) + r1, err := client.GetComponentStates(ctx, nil) retCheck(retNotNil, r1, err) - r2, err := client.GetTimeTickChannel(ctx) + r2, err := client.GetTimeTickChannel(ctx, nil) retCheck(retNotNil, r2, err) - r3, err := client.GetStatisticsChannel(ctx) + r3, err := client.GetStatisticsChannel(ctx, nil) retCheck(retNotNil, r3, err) r4, err := client.ShowCollections(ctx, nil) @@ -205,6 +195,6 @@ func Test_NewClient(t *testing.T) { checkFunc(true) - err = client.Stop() + err = client.Close() assert.NoError(t, err) } diff --git a/internal/distributed/querycoord/service.go b/internal/distributed/querycoord/service.go index 4f3b393caae5f..4281358f12312 100644 --- a/internal/distributed/querycoord/service.go +++ b/internal/distributed/querycoord/service.go @@ -24,10 +24,7 @@ import ( "time" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" - "github.com/milvus-io/milvus/internal/util/componentutil" - 
"github.com/milvus-io/milvus/internal/util/dependency" - "github.com/milvus-io/milvus/pkg/tracer" - "github.com/milvus-io/milvus/pkg/util/interceptor" + "github.com/tikv/client-go/v2/txnkv" clientv3 "go.etcd.io/etcd/client/v3" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "go.uber.org/atomic" @@ -39,15 +36,22 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" dcc "github.com/milvus-io/milvus/internal/distributed/datacoord/client" rcc "github.com/milvus-io/milvus/internal/distributed/rootcoord/client" + "github.com/milvus-io/milvus/internal/distributed/utils" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" qc "github.com/milvus-io/milvus/internal/querycoordv2" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/componentutil" + "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/tracer" + "github.com/milvus-io/milvus/pkg/util" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/interceptor" "github.com/milvus-io/milvus/pkg/util/logutil" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/tikv" ) // Server is the grpc server of QueryCoord. @@ -66,9 +70,10 @@ type Server struct { factory dependency.Factory etcdCli *clientv3.Client + tikvCli *txnkv.Client - dataCoord types.DataCoord - rootCoord types.RootCoord + dataCoord types.DataCoordClient + rootCoord types.RootCoordClient } // NewServer create a new QueryCoord grpc server. @@ -91,7 +96,6 @@ func NewServer(ctx context.Context, factory dependency.Factory) (*Server, error) // Run initializes and starts QueryCoord's grpc service. 
func (s *Server) Run() error { - if err := s.init(); err != nil { return err } @@ -104,10 +108,13 @@ func (s *Server) Run() error { return nil } +var getTiKVClient = tikv.GetTiKVClient + // init initializes QueryCoord's grpc service. func (s *Server) init() error { - etcdConfig := ¶mtable.Get().EtcdCfg - Params := ¶mtable.Get().QueryCoordGrpcServerCfg + params := paramtable.Get() + etcdConfig := ¶ms.EtcdCfg + rpcParams := ¶ms.QueryCoordGrpcServerCfg etcdCli, err := etcd.GetEtcdClient( etcdConfig.UseEmbedEtcd.GetAsBool(), @@ -123,10 +130,21 @@ func (s *Server) init() error { } s.etcdCli = etcdCli s.SetEtcdClient(etcdCli) - s.queryCoord.SetAddress(Params.GetAddress()) + s.queryCoord.SetAddress(rpcParams.GetAddress()) + + if params.MetaStoreCfg.MetaStoreType.GetValue() == util.MetaStoreTypeTiKV { + log.Info("Connecting to tikv metadata storage.") + s.tikvCli, err = getTiKVClient(¶mtable.Get().TiKVCfg) + if err != nil { + log.Warn("QueryCoord failed to connect to tikv", zap.Error(err)) + return err + } + s.SetTiKVClient(s.tikvCli) + log.Info("Connected to tikv. 
Using tikv as metadata storage.") + } s.wg.Add(1) - go s.startGrpcLoop(Params.Port.GetAsInt()) + go s.startGrpcLoop(rpcParams.Port.GetAsInt()) // wait for grpc server loop start err = <-s.grpcErrChan if err != nil { @@ -142,15 +160,6 @@ func (s *Server) init() error { } } - if err = s.rootCoord.Init(); err != nil { - log.Error("QueryCoord RootCoordClient Init failed", zap.Error(err)) - panic(err) - } - - if err = s.rootCoord.Start(); err != nil { - log.Error("QueryCoord RootCoordClient Start failed", zap.Error(err)) - panic(err) - } // wait for master init or healthy log.Debug("QueryCoord try to wait for RootCoord ready") err = componentutil.WaitForComponentHealthy(s.loopCtx, s.rootCoord, "RootCoord", 1000000, time.Millisecond*200) @@ -173,14 +182,6 @@ func (s *Server) init() error { } } - if err = s.dataCoord.Init(); err != nil { - log.Error("QueryCoord DataCoordClient Init failed", zap.Error(err)) - panic(err) - } - if err = s.dataCoord.Start(); err != nil { - log.Error("QueryCoord DataCoordClient Start failed", zap.Error(err)) - panic(err) - } log.Debug("QueryCoord try to wait for DataCoord ready") err = componentutil.WaitForComponentHealthy(s.loopCtx, s.dataCoord, "DataCoord", 1000000, time.Millisecond*200) if err != nil { @@ -201,19 +202,19 @@ func (s *Server) init() error { func (s *Server) startGrpcLoop(grpcPort int) { defer s.wg.Done() Params := ¶mtable.Get().QueryCoordGrpcServerCfg - var kaep = keepalive.EnforcementPolicy{ + kaep := keepalive.EnforcementPolicy{ MinTime: 5 * time.Second, // If a client pings more than once every 5 seconds, terminate the connection PermitWithoutStream: true, // Allow pings even when there are no active streams } - var kasp = keepalive.ServerParameters{ + kasp := keepalive.ServerParameters{ Time: 60 * time.Second, // Ping the client if it is idle for 60 seconds to ensure the connection is still active Timeout: 10 * time.Second, // Wait 10 second for the ping ack before assuming the connection is dead } log.Debug("network", 
zap.String("port", strconv.Itoa(grpcPort))) lis, err := net.Listen("tcp", ":"+strconv.Itoa(grpcPort)) if err != nil { - log.Debug("GrpcServer:failed to listen:", zap.String("error", err.Error())) + log.Debug("GrpcServer:failed to listen:", zap.Error(err)) s.grpcErrChan <- err return } @@ -273,12 +274,12 @@ func (s *Server) Stop() error { if s.etcdCli != nil { defer s.etcdCli.Close() } - err := s.queryCoord.Stop() s.loopCancel() if s.grpcServer != nil { - log.Debug("Graceful stop grpc server...") - s.grpcServer.GracefulStop() + utils.GracefulStopGRPCServer(s.grpcServer) } + err := s.queryCoord.Stop() + return err } @@ -287,31 +288,35 @@ func (s *Server) SetEtcdClient(etcdClient *clientv3.Client) { s.queryCoord.SetEtcdClient(etcdClient) } +func (s *Server) SetTiKVClient(client *txnkv.Client) { + s.queryCoord.SetTiKVClient(client) +} + // SetRootCoord sets the RootCoord's client for QueryCoord component. -func (s *Server) SetRootCoord(m types.RootCoord) error { - s.queryCoord.SetRootCoord(m) +func (s *Server) SetRootCoord(m types.RootCoordClient) error { + s.queryCoord.SetRootCoordClient(m) return nil } // SetDataCoord sets the DataCoord's client for QueryCoord component. -func (s *Server) SetDataCoord(d types.DataCoord) error { - s.queryCoord.SetDataCoord(d) +func (s *Server) SetDataCoord(d types.DataCoordClient) error { + s.queryCoord.SetDataCoordClient(d) return nil } // GetComponentStates gets the component states of QueryCoord. func (s *Server) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) { - return s.queryCoord.GetComponentStates(ctx) + return s.queryCoord.GetComponentStates(ctx, req) } // GetTimeTickChannel gets the time tick channel of QueryCoord. 
func (s *Server) GetTimeTickChannel(ctx context.Context, req *internalpb.GetTimeTickChannelRequest) (*milvuspb.StringResponse, error) { - return s.queryCoord.GetTimeTickChannel(ctx) + return s.queryCoord.GetTimeTickChannel(ctx, req) } // GetStatisticsChannel gets the statistics channel of QueryCoord. func (s *Server) GetStatisticsChannel(ctx context.Context, req *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error) { - return s.queryCoord.GetStatisticsChannel(ctx) + return s.queryCoord.GetStatisticsChannel(ctx, req) } // ShowCollections shows the collections in the QueryCoord. diff --git a/internal/distributed/querycoord/service_test.go b/internal/distributed/querycoord/service_test.go index 1e05264cfc8f0..359cd74fa5b54 100644 --- a/internal/distributed/querycoord/service_test.go +++ b/internal/distributed/querycoord/service_test.go @@ -22,46 +22,34 @@ import ( "testing" "github.com/cockroachdb/errors" - - "github.com/milvus-io/milvus/internal/mocks" - "github.com/milvus-io/milvus/internal/types" - "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/tikv/client-go/v2/txnkv" + "google.golang.org/grpc" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/tikv" ) // ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////// type MockRootCoord struct { - types.RootCoord - initErr error - startErr error - regErr error + 
types.RootCoordClient stopErr error stateErr commonpb.ErrorCode } -func (m *MockRootCoord) Init() error { - return m.initErr -} - -func (m *MockRootCoord) Start() error { - return m.startErr -} - -func (m *MockRootCoord) Stop() error { +func (m *MockRootCoord) Close() error { return m.stopErr } -func (m *MockRootCoord) Register() error { - return m.regErr -} - -func (m *MockRootCoord) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { +func (m *MockRootCoord) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest, opt ...grpc.CallOption) (*milvuspb.ComponentStates, error) { return &milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{StateCode: commonpb.StateCode_Healthy}, Status: &commonpb.Status{ErrorCode: m.stateErr}, @@ -70,31 +58,16 @@ func (m *MockRootCoord) GetComponentStates(ctx context.Context) (*milvuspb.Compo // ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////// type MockDataCoord struct { - types.DataCoord - initErr error - startErr error + types.DataCoordClient stopErr error - regErr error stateErr commonpb.ErrorCode } -func (m *MockDataCoord) Init() error { - return m.initErr -} - -func (m *MockDataCoord) Start() error { - return m.startErr -} - -func (m *MockDataCoord) Stop() error { +func (m *MockDataCoord) Close() error { return m.stopErr } -func (m *MockDataCoord) Register() error { - return m.regErr -} - -func (m *MockDataCoord) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { +func (m *MockDataCoord) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { return &milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{StateCode: commonpb.StateCode_Healthy}, Status: &commonpb.Status{ErrorCode: m.stateErr}, @@ -109,305 +82,261 @@ func TestMain(m *testing.M) { } func Test_NewServer(t *testing.T) { - ctx := 
context.Background() - server, err := NewServer(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, server) - - mdc := &MockDataCoord{ - stateErr: commonpb.ErrorCode_Success, - } - - mrc := &MockRootCoord{ - stateErr: commonpb.ErrorCode_Success, - } - - mqc := getQueryCoord() - successStatus := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - } - - t.Run("Run", func(t *testing.T) { - server.queryCoord = mqc - server.dataCoord = mdc - server.rootCoord = mrc - - err = server.Run() - assert.NoError(t, err) - }) - - t.Run("GetComponentStates", func(t *testing.T) { - mqc.EXPECT().GetComponentStates(mock.Anything).Return(&milvuspb.ComponentStates{ - State: &milvuspb.ComponentInfo{ - NodeID: 0, - Role: "MockQueryCoord", - StateCode: commonpb.StateCode_Healthy, - }, - Status: successStatus, - }, nil) - - req := &milvuspb.GetComponentStatesRequest{} - states, err := server.GetComponentStates(ctx, req) - assert.NoError(t, err) - assert.Equal(t, commonpb.StateCode_Healthy, states.State.StateCode) - }) - - t.Run("GetStatisticsChannel", func(t *testing.T) { - req := &internalpb.GetStatisticsChannelRequest{} - mqc.EXPECT().GetStatisticsChannel(mock.Anything).Return( - &milvuspb.StringResponse{ - Status: successStatus, - }, nil, - ) - resp, err := server.GetStatisticsChannel(ctx, req) - assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) - }) - - t.Run("GetTimeTickChannel", func(t *testing.T) { - req := &internalpb.GetTimeTickChannelRequest{} - mqc.EXPECT().GetTimeTickChannel(mock.Anything).Return( - &milvuspb.StringResponse{ - Status: successStatus, - }, nil, - ) - resp, err := server.GetTimeTickChannel(ctx, req) - assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) - }) - - t.Run("ShowCollections", func(t *testing.T) { - mqc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return( - &querypb.ShowCollectionsResponse{ - Status: successStatus, - }, nil, - ) - resp, err := 
server.ShowCollections(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) - }) - - t.Run("LoadCollection", func(t *testing.T) { - mqc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(successStatus, nil) - resp, err := server.LoadCollection(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("ReleaseCollection", func(t *testing.T) { - mqc.EXPECT().ReleaseCollection(mock.Anything, mock.Anything).Return(successStatus, nil) - resp, err := server.ReleaseCollection(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("ShowPartitions", func(t *testing.T) { - mqc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&querypb.ShowPartitionsResponse{Status: successStatus}, nil) - resp, err := server.ShowPartitions(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - t.Run("GetPartitionStates", func(t *testing.T) { - mqc.EXPECT().GetPartitionStates(mock.Anything, mock.Anything).Return(&querypb.GetPartitionStatesResponse{Status: successStatus}, nil) - resp, err := server.GetPartitionStates(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("LoadPartitions", func(t *testing.T) { - mqc.EXPECT().LoadPartitions(mock.Anything, mock.Anything).Return(successStatus, nil) - resp, err := server.LoadPartitions(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("ReleasePartitions", func(t *testing.T) { - mqc.EXPECT().ReleasePartitions(mock.Anything, mock.Anything).Return(successStatus, nil) - resp, err := server.ReleasePartitions(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("GetTimeTickChannel", func(t *testing.T) { - mqc.EXPECT().GetTimeTickChannel(mock.Anything).Return(&milvuspb.StringResponse{Status: successStatus}, nil) - resp, err := server.GetTimeTickChannel(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, resp) - }) - - t.Run("GetSegmentInfo", 
func(t *testing.T) { - req := &querypb.GetSegmentInfoRequest{} - mqc.EXPECT().GetSegmentInfo(mock.Anything, req).Return(&querypb.GetSegmentInfoResponse{Status: successStatus}, nil) - resp, err := server.GetSegmentInfo(ctx, req) - assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) - }) - - t.Run("LoadBalance", func(t *testing.T) { - req := &querypb.LoadBalanceRequest{} - mqc.EXPECT().LoadBalance(mock.Anything, req).Return(successStatus, nil) - resp, err := server.LoadBalance(ctx, req) - assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode) - }) - - t.Run("GetMetrics", func(t *testing.T) { - req := &milvuspb.GetMetricsRequest{ - Request: "", + parameters := []string{"tikv", "etcd"} + for _, v := range parameters { + paramtable.Get().Save(paramtable.Get().MetaStoreCfg.MetaStoreType.Key, v) + ctx := context.Background() + getTiKVClient = func(cfg *paramtable.TiKVConfig) (*txnkv.Client, error) { + return tikv.SetupLocalTxn(), nil } - mqc.EXPECT().GetMetrics(mock.Anything, req).Return(&milvuspb.GetMetricsResponse{Status: successStatus}, nil) - resp, err := server.GetMetrics(ctx, req) - assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) - }) - - t.Run("CheckHealth", func(t *testing.T) { - mqc.EXPECT().CheckHealth(mock.Anything, mock.Anything).Return( - &milvuspb.CheckHealthResponse{Status: successStatus, IsHealthy: true}, nil) - ret, err := server.CheckHealth(ctx, nil) - assert.NoError(t, err) - assert.Equal(t, true, ret.IsHealthy) - }) - - t.Run("CreateResourceGroup", func(t *testing.T) { - mqc.EXPECT().CreateResourceGroup(mock.Anything, mock.Anything).Return(successStatus, nil) - resp, err := server.CreateResourceGroup(ctx, nil) + defer func() { + getTiKVClient = tikv.GetTiKVClient + }() + server, err := NewServer(ctx, nil) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode) - }) + assert.NotNil(t, server) - 
t.Run("DropResourceGroup", func(t *testing.T) { - mqc.EXPECT().DropResourceGroup(mock.Anything, mock.Anything).Return(successStatus, nil) - resp, err := server.DropResourceGroup(ctx, nil) - assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode) - }) - - t.Run("TransferNode", func(t *testing.T) { - mqc.EXPECT().TransferNode(mock.Anything, mock.Anything).Return(successStatus, nil) - resp, err := server.TransferNode(ctx, nil) - assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode) - }) - - t.Run("TransferReplica", func(t *testing.T) { - mqc.EXPECT().TransferReplica(mock.Anything, mock.Anything).Return(successStatus, nil) - resp, err := server.TransferReplica(ctx, nil) - assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode) - }) + mdc := &MockDataCoord{ + stateErr: commonpb.ErrorCode_Success, + } - t.Run("ListResourceGroups", func(t *testing.T) { - req := &milvuspb.ListResourceGroupsRequest{} - mqc.EXPECT().ListResourceGroups(mock.Anything, req).Return(&milvuspb.ListResourceGroupsResponse{Status: successStatus}, nil) - resp, err := server.ListResourceGroups(ctx, req) - assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) - }) + mrc := &MockRootCoord{ + stateErr: commonpb.ErrorCode_Success, + } - t.Run("DescribeResourceGroup", func(t *testing.T) { - mqc.EXPECT().DescribeResourceGroup(mock.Anything, mock.Anything).Return(&querypb.DescribeResourceGroupResponse{Status: successStatus}, nil) - resp, err := server.DescribeResourceGroup(ctx, nil) + mqc := getQueryCoord() + successStatus := merr.Success() + + t.Run("Run", func(t *testing.T) { + server.queryCoord = mqc + server.dataCoord = mdc + server.rootCoord = mrc + + err = server.Run() + assert.NoError(t, err) + }) + + t.Run("GetComponentStates", func(t *testing.T) { + mqc.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{ + State: 
&milvuspb.ComponentInfo{ + NodeID: 0, + Role: "MockQueryCoord", + StateCode: commonpb.StateCode_Healthy, + }, + Status: successStatus, + }, nil) + + req := &milvuspb.GetComponentStatesRequest{} + states, err := server.GetComponentStates(ctx, req) + assert.NoError(t, err) + assert.Equal(t, commonpb.StateCode_Healthy, states.State.StateCode) + }) + + t.Run("GetStatisticsChannel", func(t *testing.T) { + req := &internalpb.GetStatisticsChannelRequest{} + mqc.EXPECT().GetStatisticsChannel(mock.Anything, mock.Anything).Return( + &milvuspb.StringResponse{ + Status: successStatus, + }, nil, + ) + resp, err := server.GetStatisticsChannel(ctx, req) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) + }) + + t.Run("GetTimeTickChannel", func(t *testing.T) { + req := &internalpb.GetTimeTickChannelRequest{} + mqc.EXPECT().GetTimeTickChannel(mock.Anything, mock.Anything).Return( + &milvuspb.StringResponse{ + Status: successStatus, + }, nil, + ) + resp, err := server.GetTimeTickChannel(ctx, req) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) + }) + + t.Run("ShowCollections", func(t *testing.T) { + mqc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return( + &querypb.ShowCollectionsResponse{ + Status: successStatus, + }, nil, + ) + resp, err := server.ShowCollections(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) + }) + + t.Run("LoadCollection", func(t *testing.T) { + mqc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(successStatus, nil) + resp, err := server.LoadCollection(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("ReleaseCollection", func(t *testing.T) { + mqc.EXPECT().ReleaseCollection(mock.Anything, mock.Anything).Return(successStatus, nil) + resp, err := server.ReleaseCollection(ctx, nil) + assert.NoError(t, err) + 
assert.NotNil(t, resp) + }) + + t.Run("ShowPartitions", func(t *testing.T) { + mqc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&querypb.ShowPartitionsResponse{Status: successStatus}, nil) + resp, err := server.ShowPartitions(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + t.Run("GetPartitionStates", func(t *testing.T) { + mqc.EXPECT().GetPartitionStates(mock.Anything, mock.Anything).Return(&querypb.GetPartitionStatesResponse{Status: successStatus}, nil) + resp, err := server.GetPartitionStates(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("LoadPartitions", func(t *testing.T) { + mqc.EXPECT().LoadPartitions(mock.Anything, mock.Anything).Return(successStatus, nil) + resp, err := server.LoadPartitions(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("ReleasePartitions", func(t *testing.T) { + mqc.EXPECT().ReleasePartitions(mock.Anything, mock.Anything).Return(successStatus, nil) + resp, err := server.ReleasePartitions(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("GetTimeTickChannel", func(t *testing.T) { + mqc.EXPECT().GetTimeTickChannel(mock.Anything, mock.Anything).Return(&milvuspb.StringResponse{Status: successStatus}, nil) + resp, err := server.GetTimeTickChannel(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, resp) + }) + + t.Run("GetSegmentInfo", func(t *testing.T) { + req := &querypb.GetSegmentInfoRequest{} + mqc.EXPECT().GetSegmentInfo(mock.Anything, req).Return(&querypb.GetSegmentInfoResponse{Status: successStatus}, nil) + resp, err := server.GetSegmentInfo(ctx, req) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) + }) + + t.Run("LoadBalance", func(t *testing.T) { + req := &querypb.LoadBalanceRequest{} + mqc.EXPECT().LoadBalance(mock.Anything, req).Return(successStatus, nil) + resp, err := server.LoadBalance(ctx, req) + assert.NoError(t, err) + assert.Equal(t, 
commonpb.ErrorCode_Success, resp.ErrorCode) + }) + + t.Run("GetMetrics", func(t *testing.T) { + req := &milvuspb.GetMetricsRequest{ + Request: "", + } + mqc.EXPECT().GetMetrics(mock.Anything, req).Return(&milvuspb.GetMetricsResponse{Status: successStatus}, nil) + resp, err := server.GetMetrics(ctx, req) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) + }) + + t.Run("CheckHealth", func(t *testing.T) { + mqc.EXPECT().CheckHealth(mock.Anything, mock.Anything).Return( + &milvuspb.CheckHealthResponse{Status: successStatus, IsHealthy: true}, nil) + ret, err := server.CheckHealth(ctx, nil) + assert.NoError(t, err) + assert.Equal(t, true, ret.IsHealthy) + }) + + t.Run("CreateResourceGroup", func(t *testing.T) { + mqc.EXPECT().CreateResourceGroup(mock.Anything, mock.Anything).Return(successStatus, nil) + resp, err := server.CreateResourceGroup(ctx, nil) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode) + }) + + t.Run("DropResourceGroup", func(t *testing.T) { + mqc.EXPECT().DropResourceGroup(mock.Anything, mock.Anything).Return(successStatus, nil) + resp, err := server.DropResourceGroup(ctx, nil) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode) + }) + + t.Run("TransferNode", func(t *testing.T) { + mqc.EXPECT().TransferNode(mock.Anything, mock.Anything).Return(successStatus, nil) + resp, err := server.TransferNode(ctx, nil) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode) + }) + + t.Run("TransferReplica", func(t *testing.T) { + mqc.EXPECT().TransferReplica(mock.Anything, mock.Anything).Return(successStatus, nil) + resp, err := server.TransferReplica(ctx, nil) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode) + }) + + t.Run("ListResourceGroups", func(t *testing.T) { + req := &milvuspb.ListResourceGroupsRequest{} + mqc.EXPECT().ListResourceGroups(mock.Anything, 
req).Return(&milvuspb.ListResourceGroupsResponse{Status: successStatus}, nil) + resp, err := server.ListResourceGroups(ctx, req) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) + }) + + t.Run("DescribeResourceGroup", func(t *testing.T) { + mqc.EXPECT().DescribeResourceGroup(mock.Anything, mock.Anything).Return(&querypb.DescribeResourceGroupResponse{Status: successStatus}, nil) + resp, err := server.DescribeResourceGroup(ctx, nil) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) + }) + + err = server.Stop() assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) - }) - - err = server.Stop() - assert.NoError(t, err) + } } // This test will no longer return error immediately. func TestServer_Run1(t *testing.T) { - t.Skip() - ctx := context.Background() - server, err := NewServer(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, server) - - mqc := getQueryCoord() - mqc.EXPECT().Start().Return(errors.New("error")) - server.queryCoord = mqc - err = server.Run() - assert.Error(t, err) - - err = server.Stop() - assert.NoError(t, err) -} + parameters := []string{"tikv", "etcd"} + for _, v := range parameters { + paramtable.Get().Save(paramtable.Get().MetaStoreCfg.MetaStoreType.Key, v) + t.Skip() + ctx := context.Background() + getTiKVClient = func(cfg *paramtable.TiKVConfig) (*txnkv.Client, error) { + return tikv.SetupLocalTxn(), nil + } + defer func() { + getTiKVClient = tikv.GetTiKVClient + }() + server, err := NewServer(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, server) -func TestServer_Run2(t *testing.T) { - ctx := context.Background() - server, err := NewServer(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, server) + mqc := getQueryCoord() + mqc.EXPECT().Start().Return(errors.New("error")) + server.queryCoord = mqc + err = server.Run() + assert.Error(t, err) - server.queryCoord = getQueryCoord() - 
server.rootCoord = &MockRootCoord{ - initErr: errors.New("error"), + err = server.Stop() + assert.NoError(t, err) } - assert.Panics(t, func() { server.Run() }) - err = server.Stop() - assert.NoError(t, err) } func getQueryCoord() *mocks.MockQueryCoord { mqc := &mocks.MockQueryCoord{} mqc.EXPECT().Init().Return(nil) mqc.EXPECT().SetEtcdClient(mock.Anything) + mqc.EXPECT().SetTiKVClient(mock.Anything) mqc.EXPECT().SetAddress(mock.Anything) - mqc.EXPECT().SetRootCoord(mock.Anything).Return(nil) - mqc.EXPECT().SetDataCoord(mock.Anything).Return(nil) + mqc.EXPECT().SetRootCoordClient(mock.Anything).Return(nil) + mqc.EXPECT().SetDataCoordClient(mock.Anything).Return(nil) mqc.EXPECT().UpdateStateCode(mock.Anything) mqc.EXPECT().Register().Return(nil) mqc.EXPECT().Start().Return(nil) mqc.EXPECT().Stop().Return(nil) - return mqc } - -func TestServer_Run3(t *testing.T) { - ctx := context.Background() - server, err := NewServer(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, server) - server.queryCoord = getQueryCoord() - server.rootCoord = &MockRootCoord{ - startErr: errors.New("error"), - } - assert.Panics(t, func() { server.Run() }) - err = server.Stop() - assert.NoError(t, err) - -} - -func TestServer_Run4(t *testing.T) { - ctx := context.Background() - server, err := NewServer(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, server) - - server.queryCoord = getQueryCoord() - server.rootCoord = &MockRootCoord{} - server.dataCoord = &MockDataCoord{ - initErr: errors.New("error"), - } - assert.Panics(t, func() { server.Run() }) - err = server.Stop() - assert.NoError(t, err) -} - -func TestServer_Run5(t *testing.T) { - ctx := context.Background() - server, err := NewServer(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, server) - - server.queryCoord = getQueryCoord() - server.rootCoord = &MockRootCoord{} - server.dataCoord = &MockDataCoord{ - startErr: errors.New("error"), - } - assert.Panics(t, func() { server.Run() }) - err = server.Stop() - 
assert.NoError(t, err) -} diff --git a/internal/distributed/querynode/client/client.go b/internal/distributed/querynode/client/client.go index e92e12e8a258a..a7291a9ad50f7 100644 --- a/internal/distributed/querynode/client/client.go +++ b/internal/distributed/querynode/client/client.go @@ -20,13 +20,13 @@ import ( "context" "fmt" - "github.com/milvus-io/milvus/internal/util/grpcclient" "google.golang.org/grpc" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/util/grpcclient" "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -49,7 +49,8 @@ func NewClient(ctx context.Context, addr string, nodeID int64) (*Client, error) addr: addr, grpcClient: grpcclient.NewClientBase[querypb.QueryNodeClient](config, "milvus.proto.query.QueryNode"), } - client.grpcClient.SetRole(typeutil.QueryNodeRole) + // node shall specify node id + client.grpcClient.SetRole(fmt.Sprintf("%s-%d", typeutil.QueryNodeRole, nodeID)) client.grpcClient.SetGetAddrFunc(client.getAddr) client.grpcClient.SetNewGrpcClientFunc(client.newGrpcClient) client.grpcClient.SetNodeID(nodeID) @@ -57,26 +58,11 @@ func NewClient(ctx context.Context, addr string, nodeID int64) (*Client, error) return client, nil } -// Init initializes QueryNode's grpc client. -func (c *Client) Init() error { - return nil -} - -// Start starts QueryNode's client service. But it does nothing here. -func (c *Client) Start() error { - return nil -} - -// Stop stops QueryNode's grpc client server. 
-func (c *Client) Stop() error { +// Close close QueryNode's grpc client +func (c *Client) Close() error { return c.grpcClient.Close() } -// Register dummy -func (c *Client) Register() error { - return nil -} - func (c *Client) newGrpcClient(cc *grpc.ClientConn) querypb.QueryNodeClient { return querypb.NewQueryNodeClient(cc) } @@ -99,28 +85,28 @@ func wrapGrpcCall[T any](ctx context.Context, c *Client, call func(grpcClient qu } // GetComponentStates gets the component states of QueryNode. -func (c *Client) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { +func (c *Client) GetComponentStates(ctx context.Context, _ *milvuspb.GetComponentStatesRequest, _ ...grpc.CallOption) (*milvuspb.ComponentStates, error) { return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*milvuspb.ComponentStates, error) { return client.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) }) } // GetTimeTickChannel gets the time tick channel of QueryNode. -func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (c *Client) GetTimeTickChannel(ctx context.Context, req *internalpb.GetTimeTickChannelRequest, _ ...grpc.CallOption) (*milvuspb.StringResponse, error) { return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*milvuspb.StringResponse, error) { return client.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{}) }) } // GetStatisticsChannel gets the statistics channel of QueryNode. 
-func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (c *Client) GetStatisticsChannel(ctx context.Context, req *internalpb.GetStatisticsChannelRequest, _ ...grpc.CallOption) (*milvuspb.StringResponse, error) { return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*milvuspb.StringResponse, error) { return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) }) } // WatchDmChannels watches the channels about data manipulation. -func (c *Client) WatchDmChannels(ctx context.Context, req *querypb.WatchDmChannelsRequest) (*commonpb.Status, error) { +func (c *Client) WatchDmChannels(ctx context.Context, req *querypb.WatchDmChannelsRequest, _ ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -131,7 +117,7 @@ func (c *Client) WatchDmChannels(ctx context.Context, req *querypb.WatchDmChanne } // UnsubDmChannel unsubscribes the channels about data manipulation. -func (c *Client) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmChannelRequest) (*commonpb.Status, error) { +func (c *Client) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmChannelRequest, _ ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -142,7 +128,7 @@ func (c *Client) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmChannel } // LoadSegments loads the segments to search. -func (c *Client) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error) { +func (c *Client) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest, _ ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -153,7 +139,7 @@ func (c *Client) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequ } // ReleaseCollection releases the data of the specified collection in QueryNode. 
-func (c *Client) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) { +func (c *Client) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest, _ ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -164,7 +150,7 @@ func (c *Client) ReleaseCollection(ctx context.Context, req *querypb.ReleaseColl } // LoadPartitions updates partitions meta info in QueryNode. -func (c *Client) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) { +func (c *Client) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest, _ ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -175,7 +161,7 @@ func (c *Client) LoadPartitions(ctx context.Context, req *querypb.LoadPartitions } // ReleasePartitions releases the data of the specified partitions in QueryNode. -func (c *Client) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) { +func (c *Client) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest, _ ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -186,7 +172,7 @@ func (c *Client) ReleasePartitions(ctx context.Context, req *querypb.ReleasePart } // ReleaseSegments releases the data of the specified segments in QueryNode. 
-func (c *Client) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) { +func (c *Client) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest, _ ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -197,33 +183,61 @@ func (c *Client) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmen } // Search performs replica search tasks in QueryNode. -func (c *Client) Search(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) { +func (c *Client) Search(ctx context.Context, req *querypb.SearchRequest, _ ...grpc.CallOption) (*internalpb.SearchResults, error) { return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*internalpb.SearchResults, error) { return client.Search(ctx, req) }) } -func (c *Client) SearchSegments(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) { +func (c *Client) SearchSegments(ctx context.Context, req *querypb.SearchRequest, _ ...grpc.CallOption) (*internalpb.SearchResults, error) { return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*internalpb.SearchResults, error) { return client.SearchSegments(ctx, req) }) } // Query performs replica query tasks in QueryNode. 
-func (c *Client) Query(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) { +func (c *Client) Query(ctx context.Context, req *querypb.QueryRequest, _ ...grpc.CallOption) (*internalpb.RetrieveResults, error) { return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*internalpb.RetrieveResults, error) { return client.Query(ctx, req) }) } -func (c *Client) QuerySegments(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) { +func (c *Client) QueryStream(ctx context.Context, req *querypb.QueryRequest, _ ...grpc.CallOption) (querypb.QueryNode_QueryStreamClient, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryNodeClient) (any, error) { + if !funcutil.CheckCtxValid(ctx) { + return nil, ctx.Err() + } + + return client.QueryStream(ctx, req) + }) + if err != nil || ret == nil { + return nil, err + } + return ret.(querypb.QueryNode_QueryStreamClient), nil +} + +func (c *Client) QuerySegments(ctx context.Context, req *querypb.QueryRequest, _ ...grpc.CallOption) (*internalpb.RetrieveResults, error) { return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*internalpb.RetrieveResults, error) { return client.QuerySegments(ctx, req) }) } +func (c *Client) QueryStreamSegments(ctx context.Context, req *querypb.QueryRequest, _ ...grpc.CallOption) (querypb.QueryNode_QueryStreamSegmentsClient, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryNodeClient) (any, error) { + if !funcutil.CheckCtxValid(ctx) { + return nil, ctx.Err() + } + + return client.QueryStreamSegments(ctx, req) + }) + if err != nil || ret == nil { + return nil, err + } + return ret.(querypb.QueryNode_QueryStreamSegmentsClient), nil +} + // GetSegmentInfo gets the information of the specified segments in QueryNode. 
-func (c *Client) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) { +func (c *Client) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest, _ ...grpc.CallOption) (*querypb.GetSegmentInfoResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -234,7 +248,7 @@ func (c *Client) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfo } // SyncReplicaSegments syncs replica node segments information to shard leaders. -func (c *Client) SyncReplicaSegments(ctx context.Context, req *querypb.SyncReplicaSegmentsRequest) (*commonpb.Status, error) { +func (c *Client) SyncReplicaSegments(ctx context.Context, req *querypb.SyncReplicaSegmentsRequest, _ ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -245,7 +259,7 @@ func (c *Client) SyncReplicaSegments(ctx context.Context, req *querypb.SyncRepli } // ShowConfigurations gets specified configurations para of QueryNode -func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { +func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest, _ ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -256,7 +270,7 @@ func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowCon } // GetMetrics gets the metrics information of QueryNode. 
-func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { +func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest, _ ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -266,13 +280,13 @@ func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest }) } -func (c *Client) GetStatistics(ctx context.Context, request *querypb.GetStatisticsRequest) (*internalpb.GetStatisticsResponse, error) { +func (c *Client) GetStatistics(ctx context.Context, request *querypb.GetStatisticsRequest, _ ...grpc.CallOption) (*internalpb.GetStatisticsResponse, error) { return wrapGrpcCall(ctx, c, func(client querypb.QueryNodeClient) (*internalpb.GetStatisticsResponse, error) { return client.GetStatistics(ctx, request) }) } -func (c *Client) GetDataDistribution(ctx context.Context, req *querypb.GetDataDistributionRequest) (*querypb.GetDataDistributionResponse, error) { +func (c *Client) GetDataDistribution(ctx context.Context, req *querypb.GetDataDistributionRequest, _ ...grpc.CallOption) (*querypb.GetDataDistributionResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -282,7 +296,7 @@ func (c *Client) GetDataDistribution(ctx context.Context, req *querypb.GetDataDi }) } -func (c *Client) SyncDistribution(ctx context.Context, req *querypb.SyncDistributionRequest) (*commonpb.Status, error) { +func (c *Client) SyncDistribution(ctx context.Context, req *querypb.SyncDistributionRequest, _ ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -293,7 +307,7 @@ func (c *Client) SyncDistribution(ctx context.Context, req *querypb.SyncDistribu } // Delete is used to forward delete message between delegator and workers. 
-func (c *Client) Delete(ctx context.Context, req *querypb.DeleteRequest) (*commonpb.Status, error) { +func (c *Client) Delete(ctx context.Context, req *querypb.DeleteRequest, _ ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), diff --git a/internal/distributed/querynode/client/client_test.go b/internal/distributed/querynode/client/client_test.go index bfaa4ad242203..bd617d81fa979 100644 --- a/internal/distributed/querynode/client/client_test.go +++ b/internal/distributed/querynode/client/client_test.go @@ -21,12 +21,12 @@ import ( "testing" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus/internal/util/mock" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/util/mock" "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/stretchr/testify/assert" - "google.golang.org/grpc" ) func Test_NewClient(t *testing.T) { @@ -41,12 +41,6 @@ func Test_NewClient(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, client) - err = client.Start() - assert.NoError(t, err) - - err = client.Register() - assert.NoError(t, err) - ctx, cancel := context.WithCancel(ctx) checkFunc := func(retNotNil bool) { @@ -60,13 +54,13 @@ func Test_NewClient(t *testing.T) { } } - r1, err := client.GetComponentStates(ctx) + r1, err := client.GetComponentStates(ctx, nil) retCheck(retNotNil, r1, err) - r2, err := client.GetTimeTickChannel(ctx) + r2, err := client.GetTimeTickChannel(ctx, nil) retCheck(retNotNil, r2, err) - r3, err := client.GetStatisticsChannel(ctx) + r3, err := client.GetStatisticsChannel(ctx, nil) retCheck(retNotNil, r3, err) r6, err := client.WatchDmChannels(ctx, nil) @@ -113,6 +107,10 @@ func Test_NewClient(t *testing.T) { r20, err := client.SearchSegments(ctx, nil) retCheck(retNotNil, r20, err) + + // stream rpc + client, err := client.QueryStream(ctx, nil) + retCheck(retNotNil, 
client, err) } client.grpcClient = &mock.GRPCClientBase[querypb.QueryNodeClient]{ @@ -157,6 +155,6 @@ func Test_NewClient(t *testing.T) { cancel() // make context canceled checkFunc(false) - err = client.Stop() + err = client.Close() assert.NoError(t, err) } diff --git a/internal/distributed/querynode/service.go b/internal/distributed/querynode/service.go index 914688166cd17..0233b3a98b1ab 100644 --- a/internal/distributed/querynode/service.go +++ b/internal/distributed/querynode/service.go @@ -25,9 +25,6 @@ import ( "time" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" - "github.com/milvus-io/milvus/internal/util/dependency" - "github.com/milvus-io/milvus/pkg/tracer" - "github.com/milvus-io/milvus/pkg/util/interceptor" clientv3 "go.etcd.io/etcd/client/v3" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "go.uber.org/atomic" @@ -37,13 +34,17 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/distributed/utils" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" qn "github.com/milvus-io/milvus/internal/querynodev2" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/tracer" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/interceptor" "github.com/milvus-io/milvus/pkg/util/logutil" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/retry" @@ -148,12 +149,12 @@ func (s *Server) start() error { func (s *Server) startGrpcLoop(grpcPort int) { defer s.wg.Done() Params := ¶mtable.Get().QueryNodeGrpcServerCfg - var kaep = keepalive.EnforcementPolicy{ + kaep := keepalive.EnforcementPolicy{ MinTime: 5 * time.Second, // If a 
client pings more than once every 5 seconds, terminate the connection PermitWithoutStream: true, // Allow pings even when there are no active streams } - var kasp = keepalive.ServerParameters{ + kasp := keepalive.ServerParameters{ Time: 60 * time.Second, // Ping the client if it is idle for 60 seconds to ensure the connection is still active Timeout: 10 * time.Second, // Wait 10 second for the ping ack before assuming the connection is dead } @@ -214,12 +215,10 @@ func (s *Server) startGrpcLoop(grpcPort int) { log.Debug("QueryNode Start Grpc Failed!!!!") s.grpcErrChan <- err } - } // Run initializes and starts QueryNode's grpc service. func (s *Server) Run() error { - if err := s.init(); err != nil { return err } @@ -246,8 +245,7 @@ func (s *Server) Stop() error { s.cancel() if s.grpcServer != nil { - log.Debug("Graceful stop grpc server...") - s.grpcServer.GracefulStop() + utils.GracefulStopGRPCServer(s.grpcServer) } s.wg.Wait() return nil @@ -260,18 +258,18 @@ func (s *Server) SetEtcdClient(etcdCli *clientv3.Client) { // GetTimeTickChannel gets the time tick channel of QueryNode. func (s *Server) GetTimeTickChannel(ctx context.Context, req *internalpb.GetTimeTickChannelRequest) (*milvuspb.StringResponse, error) { - return s.querynode.GetTimeTickChannel(ctx) + return s.querynode.GetTimeTickChannel(ctx, req) } // GetStatisticsChannel gets the statistics channel of QueryNode. func (s *Server) GetStatisticsChannel(ctx context.Context, req *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error) { - return s.querynode.GetStatisticsChannel(ctx) + return s.querynode.GetStatisticsChannel(ctx, req) } // GetComponentStates gets the component states of QueryNode. 
func (s *Server) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) { // ignore ctx and in - return s.querynode.GetComponentStates(ctx) + return s.querynode.GetComponentStates(ctx, req) } // WatchDmChannels watches the channels about data manipulation. @@ -332,6 +330,14 @@ func (s *Server) Query(ctx context.Context, req *querypb.QueryRequest) (*interna return s.querynode.Query(ctx, req) } +func (s *Server) QueryStream(req *querypb.QueryRequest, srv querypb.QueryNode_QueryStreamServer) error { + return s.querynode.QueryStream(req, srv) +} + +func (s *Server) QueryStreamSegments(req *querypb.QueryRequest, srv querypb.QueryNode_QueryStreamSegmentsServer) error { + return s.querynode.QueryStreamSegments(req, srv) +} + func (s *Server) QuerySegments(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) { return s.querynode.QuerySegments(ctx, req) } diff --git a/internal/distributed/querynode/service_test.go b/internal/distributed/querynode/service_test.go index 7102a71b299e4..7565aa59e1f86 100644 --- a/internal/distributed/querynode/service_test.go +++ b/internal/distributed/querynode/service_test.go @@ -22,19 +22,17 @@ import ( "testing" "github.com/cockroachdb/errors" - - "github.com/milvus-io/milvus/internal/mocks" - "github.com/milvus-io/milvus/internal/types" - "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" clientv3 "go.etcd.io/etcd/client/v3" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) type MockRootCoord struct { @@ -65,7 +63,7 @@ func (m *MockRootCoord) Register() error { func 
(m *MockRootCoord) SetEtcdClient(client *clientv3.Client) { } -func (m *MockRootCoord) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { +func (m *MockRootCoord) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) { return &milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{StateCode: commonpb.StateCode_Healthy}, Status: &commonpb.Status{ErrorCode: m.stateErr}, @@ -101,10 +99,11 @@ func Test_NewServer(t *testing.T) { }) t.Run("GetComponentStates", func(t *testing.T) { - mockQN.EXPECT().GetComponentStates(mock.Anything).Return(&milvuspb.ComponentStates{ + mockQN.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{ StateCode: commonpb.StateCode_Healthy, - }}, nil) + }, + }, nil) req := &milvuspb.GetComponentStatesRequest{} states, err := server.GetComponentStates(ctx, req) assert.NoError(t, err) @@ -112,19 +111,19 @@ func Test_NewServer(t *testing.T) { }) t.Run("GetStatisticsChannel", func(t *testing.T) { - mockQN.EXPECT().GetStatisticsChannel(mock.Anything).Return(&milvuspb.StringResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}}, nil) + mockQN.EXPECT().GetStatisticsChannel(mock.Anything, mock.Anything).Return(&milvuspb.StringResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}}, nil) req := &internalpb.GetStatisticsChannelRequest{} resp, err := server.GetStatisticsChannel(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) t.Run("GetTimeTickChannel", func(t *testing.T) { - mockQN.EXPECT().GetTimeTickChannel(mock.Anything).Return(&milvuspb.StringResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}}, nil) + mockQN.EXPECT().GetTimeTickChannel(mock.Anything, mock.Anything).Return(&milvuspb.StringResponse{Status: 
&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}}, nil) req := &internalpb.GetTimeTickChannelRequest{} resp, err := server.GetTimeTickChannel(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) t.Run("WatchDmChannels", func(t *testing.T) { @@ -177,11 +176,12 @@ func Test_NewServer(t *testing.T) { t.Run("GetSegmentInfo", func(t *testing.T) { mockQN.EXPECT().GetSegmentInfo(mock.Anything, mock.Anything).Return(&querypb.GetSegmentInfoResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}}, nil) + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + }, nil) req := &querypb.GetSegmentInfoRequest{} resp, err := server.GetSegmentInfo(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) t.Run("GetMetrics", func(t *testing.T) { @@ -192,12 +192,13 @@ func Test_NewServer(t *testing.T) { } resp, err := server.GetMetrics(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) t.Run("Search", func(t *testing.T) { mockQN.EXPECT().Search(mock.Anything, mock.Anything).Return(&internalpb.SearchResults{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}}, nil) + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + }, nil) req := &querypb.SearchRequest{} resp, err := server.Search(ctx, req) assert.NoError(t, err) @@ -206,7 +207,8 @@ func Test_NewServer(t *testing.T) { t.Run("SearchSegments", func(t *testing.T) { mockQN.EXPECT().SearchSegments(mock.Anything, mock.Anything).Return(&internalpb.SearchResults{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}}, nil) + Status: &commonpb.Status{ErrorCode: 
commonpb.ErrorCode_Success}, + }, nil) req := &querypb.SearchRequest{} resp, err := server.SearchSegments(ctx, req) assert.NoError(t, err) @@ -215,22 +217,36 @@ func Test_NewServer(t *testing.T) { t.Run("Query", func(t *testing.T) { mockQN.EXPECT().Query(mock.Anything, mock.Anything).Return(&internalpb.RetrieveResults{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}}, nil) + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + }, nil) req := &querypb.QueryRequest{} resp, err := server.Query(ctx, req) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) + t.Run("QueryStream", func(t *testing.T) { + mockQN.EXPECT().QueryStream(mock.Anything, mock.Anything).Return(nil) + ret := server.QueryStream(nil, nil) + assert.Nil(t, ret) + }) + t.Run("QuerySegments", func(t *testing.T) { mockQN.EXPECT().QuerySegments(mock.Anything, mock.Anything).Return(&internalpb.RetrieveResults{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}}, nil) + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + }, nil) req := &querypb.QueryRequest{} resp, err := server.QuerySegments(ctx, req) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) + t.Run("QueryStreamSegments", func(t *testing.T) { + mockQN.EXPECT().QueryStreamSegments(mock.Anything, mock.Anything).Return(nil) + ret := server.QueryStreamSegments(nil, nil) + assert.Nil(t, ret) + }) + t.Run("SyncReplicaSegments", func(t *testing.T) { mockQN.EXPECT().SyncReplicaSegments(mock.Anything, mock.Anything).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil) req := &querypb.SyncReplicaSegmentsRequest{} diff --git a/internal/distributed/rootcoord/client/client.go b/internal/distributed/rootcoord/client/client.go index 72b6f6173e275..1f07f904bd631 100644 --- a/internal/distributed/rootcoord/client/client.go +++ b/internal/distributed/rootcoord/client/client.go @@ -20,6 
+20,12 @@ import ( "context" "fmt" + clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/zap" + "google.golang.org/grpc" + grpcCodes "google.golang.org/grpc/codes" + grpcStatus "google.golang.org/grpc/status" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/proto/internalpb" @@ -32,11 +38,6 @@ import ( "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" - clientv3 "go.etcd.io/etcd/client/v3" - "go.uber.org/zap" - "google.golang.org/grpc" - grpcCodes "google.golang.org/grpc/codes" - grpcStatus "google.golang.org/grpc/status" ) var Params *paramtable.ComponentParam = paramtable.Get() @@ -73,10 +74,6 @@ func NewClient(ctx context.Context, metaRoot string, etcdCli *clientv3.Client) ( } // Init initialize grpc parameters -func (c *Client) Init() error { - return nil -} - func (c *Client) newGrpcClient(cc *grpc.ClientConn) rootcoordpb.RootCoordClient { return rootcoordpb.NewRootCoordClient(cc) } @@ -101,21 +98,11 @@ func (c *Client) getRootCoordAddr() (string, error) { return ms.Address, nil } -// Start dummy -func (c *Client) Start() error { - return nil -} - -// Stop terminate grpc connection -func (c *Client) Stop() error { +// Close terminate grpc connection +func (c *Client) Close() error { return c.grpcClient.Close() } -// Register dummy -func (c *Client) Register() error { - return nil -} - func wrapGrpcCall[T any](ctx context.Context, c *Client, call func(grpcClient rootcoordpb.RootCoordClient) (*T, error)) (*T, error) { ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { @@ -130,28 +117,28 @@ func wrapGrpcCall[T any](ctx context.Context, c *Client, call func(grpcClient ro } // GetComponentStates TODO: timeout need to be propagated through ctx -func (c *Client) GetComponentStates(ctx 
context.Context) (*milvuspb.ComponentStates, error) { +func (c *Client) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { return wrapGrpcCall(ctx, c, func(client rootcoordpb.RootCoordClient) (*milvuspb.ComponentStates, error) { return client.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) }) } // GetTimeTickChannel get timetick channel name -func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (c *Client) GetTimeTickChannel(ctx context.Context, req *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { return wrapGrpcCall(ctx, c, func(client rootcoordpb.RootCoordClient) (*milvuspb.StringResponse, error) { return client.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{}) }) } // GetStatisticsChannel just define a channel, not used currently -func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (c *Client) GetStatisticsChannel(ctx context.Context, req *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { return wrapGrpcCall(ctx, c, func(client rootcoordpb.RootCoordClient) (*milvuspb.StringResponse, error) { return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) }) } // CreateCollection create collection -func (c *Client) CreateCollection(ctx context.Context, in *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) { +func (c *Client) CreateCollection(ctx context.Context, in *milvuspb.CreateCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { in = typeutil.Clone(in) commonpbutil.UpdateMsgBase( in.GetBase(), @@ -163,7 +150,7 @@ func (c *Client) CreateCollection(ctx context.Context, in *milvuspb.CreateCollec } // DropCollection drop collection -func (c *Client) DropCollection(ctx context.Context, in 
*milvuspb.DropCollectionRequest) (*commonpb.Status, error) { +func (c *Client) DropCollection(ctx context.Context, in *milvuspb.DropCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { in = typeutil.Clone(in) commonpbutil.UpdateMsgBase( in.GetBase(), @@ -175,7 +162,7 @@ func (c *Client) DropCollection(ctx context.Context, in *milvuspb.DropCollection } // HasCollection check collection existence -func (c *Client) HasCollection(ctx context.Context, in *milvuspb.HasCollectionRequest) (*milvuspb.BoolResponse, error) { +func (c *Client) HasCollection(ctx context.Context, in *milvuspb.HasCollectionRequest, opts ...grpc.CallOption) (*milvuspb.BoolResponse, error) { in = typeutil.Clone(in) commonpbutil.UpdateMsgBase( in.GetBase(), @@ -187,7 +174,7 @@ func (c *Client) HasCollection(ctx context.Context, in *milvuspb.HasCollectionRe } // DescribeCollection return collection info -func (c *Client) DescribeCollection(ctx context.Context, in *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { +func (c *Client) DescribeCollection(ctx context.Context, in *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) { in = typeutil.Clone(in) commonpbutil.UpdateMsgBase( in.GetBase(), @@ -199,7 +186,7 @@ func (c *Client) DescribeCollection(ctx context.Context, in *milvuspb.DescribeCo } // describeCollectionInternal return collection info -func (c *Client) describeCollectionInternal(ctx context.Context, in *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { +func (c *Client) describeCollectionInternal(ctx context.Context, in *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) { in = typeutil.Clone(in) commonpbutil.UpdateMsgBase( in.GetBase(), @@ -210,7 +197,7 @@ func (c *Client) describeCollectionInternal(ctx context.Context, in *milvuspb.De }) } -func (c *Client) DescribeCollectionInternal(ctx 
context.Context, in *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { +func (c *Client) DescribeCollectionInternal(ctx context.Context, in *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) { resp, err := c.describeCollectionInternal(ctx, in) status, ok := grpcStatus.FromError(err) if ok && status.Code() == grpcCodes.Unimplemented { @@ -220,7 +207,7 @@ func (c *Client) DescribeCollectionInternal(ctx context.Context, in *milvuspb.De } // ShowCollections list all collection names -func (c *Client) ShowCollections(ctx context.Context, in *milvuspb.ShowCollectionsRequest) (*milvuspb.ShowCollectionsResponse, error) { +func (c *Client) ShowCollections(ctx context.Context, in *milvuspb.ShowCollectionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowCollectionsResponse, error) { in = typeutil.Clone(in) commonpbutil.UpdateMsgBase( in.GetBase(), @@ -231,7 +218,7 @@ func (c *Client) ShowCollections(ctx context.Context, in *milvuspb.ShowCollectio }) } -func (c *Client) AlterCollection(ctx context.Context, request *milvuspb.AlterCollectionRequest) (*commonpb.Status, error) { +func (c *Client) AlterCollection(ctx context.Context, request *milvuspb.AlterCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { request = typeutil.Clone(request) commonpbutil.UpdateMsgBase( request.GetBase(), @@ -243,7 +230,7 @@ func (c *Client) AlterCollection(ctx context.Context, request *milvuspb.AlterCol } // CreatePartition create partition -func (c *Client) CreatePartition(ctx context.Context, in *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) { +func (c *Client) CreatePartition(ctx context.Context, in *milvuspb.CreatePartitionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { in = typeutil.Clone(in) commonpbutil.UpdateMsgBase( in.GetBase(), @@ -255,7 +242,7 @@ func (c *Client) CreatePartition(ctx context.Context, in *milvuspb.CreatePartiti } // DropPartition drop 
partition -func (c *Client) DropPartition(ctx context.Context, in *milvuspb.DropPartitionRequest) (*commonpb.Status, error) { +func (c *Client) DropPartition(ctx context.Context, in *milvuspb.DropPartitionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { in = typeutil.Clone(in) commonpbutil.UpdateMsgBase( in.GetBase(), @@ -267,7 +254,7 @@ func (c *Client) DropPartition(ctx context.Context, in *milvuspb.DropPartitionRe } // HasPartition check partition existence -func (c *Client) HasPartition(ctx context.Context, in *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error) { +func (c *Client) HasPartition(ctx context.Context, in *milvuspb.HasPartitionRequest, opts ...grpc.CallOption) (*milvuspb.BoolResponse, error) { in = typeutil.Clone(in) commonpbutil.UpdateMsgBase( in.GetBase(), @@ -279,7 +266,7 @@ func (c *Client) HasPartition(ctx context.Context, in *milvuspb.HasPartitionRequ } // ShowPartitions list all partitions in collection -func (c *Client) ShowPartitions(ctx context.Context, in *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) { +func (c *Client) ShowPartitions(ctx context.Context, in *milvuspb.ShowPartitionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowPartitionsResponse, error) { in = typeutil.Clone(in) commonpbutil.UpdateMsgBase( in.GetBase(), @@ -291,7 +278,7 @@ func (c *Client) ShowPartitions(ctx context.Context, in *milvuspb.ShowPartitions } // showPartitionsInternal list all partitions in collection -func (c *Client) showPartitionsInternal(ctx context.Context, in *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) { +func (c *Client) showPartitionsInternal(ctx context.Context, in *milvuspb.ShowPartitionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowPartitionsResponse, error) { in = typeutil.Clone(in) commonpbutil.UpdateMsgBase( in.GetBase(), @@ -302,7 +289,7 @@ func (c *Client) showPartitionsInternal(ctx context.Context, in *milvuspb.ShowPa }) } -func (c *Client) 
ShowPartitionsInternal(ctx context.Context, in *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) { +func (c *Client) ShowPartitionsInternal(ctx context.Context, in *milvuspb.ShowPartitionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowPartitionsResponse, error) { resp, err := c.showPartitionsInternal(ctx, in) status, ok := grpcStatus.FromError(err) if ok && status.Code() == grpcCodes.Unimplemented { @@ -312,7 +299,7 @@ func (c *Client) ShowPartitionsInternal(ctx context.Context, in *milvuspb.ShowPa } // AllocTimestamp global timestamp allocator -func (c *Client) AllocTimestamp(ctx context.Context, in *rootcoordpb.AllocTimestampRequest) (*rootcoordpb.AllocTimestampResponse, error) { +func (c *Client) AllocTimestamp(ctx context.Context, in *rootcoordpb.AllocTimestampRequest, opts ...grpc.CallOption) (*rootcoordpb.AllocTimestampResponse, error) { in = typeutil.Clone(in) commonpbutil.UpdateMsgBase( in.GetBase(), @@ -324,7 +311,7 @@ func (c *Client) AllocTimestamp(ctx context.Context, in *rootcoordpb.AllocTimest } // AllocID global ID allocator -func (c *Client) AllocID(ctx context.Context, in *rootcoordpb.AllocIDRequest) (*rootcoordpb.AllocIDResponse, error) { +func (c *Client) AllocID(ctx context.Context, in *rootcoordpb.AllocIDRequest, opts ...grpc.CallOption) (*rootcoordpb.AllocIDResponse, error) { in = typeutil.Clone(in) commonpbutil.UpdateMsgBase( in.GetBase(), @@ -336,7 +323,7 @@ func (c *Client) AllocID(ctx context.Context, in *rootcoordpb.AllocIDRequest) (* } // UpdateChannelTimeTick used to handle ChannelTimeTickMsg -func (c *Client) UpdateChannelTimeTick(ctx context.Context, in *internalpb.ChannelTimeTickMsg) (*commonpb.Status, error) { +func (c *Client) UpdateChannelTimeTick(ctx context.Context, in *internalpb.ChannelTimeTickMsg, opts ...grpc.CallOption) (*commonpb.Status, error) { in = typeutil.Clone(in) commonpbutil.UpdateMsgBase( in.GetBase(), @@ -348,7 +335,7 @@ func (c *Client) UpdateChannelTimeTick(ctx context.Context, in 
*internalpb.Chann } // ShowSegments list all segments -func (c *Client) ShowSegments(ctx context.Context, in *milvuspb.ShowSegmentsRequest) (*milvuspb.ShowSegmentsResponse, error) { +func (c *Client) ShowSegments(ctx context.Context, in *milvuspb.ShowSegmentsRequest, opts ...grpc.CallOption) (*milvuspb.ShowSegmentsResponse, error) { in = typeutil.Clone(in) commonpbutil.UpdateMsgBase( in.GetBase(), @@ -360,7 +347,7 @@ func (c *Client) ShowSegments(ctx context.Context, in *milvuspb.ShowSegmentsRequ } // InvalidateCollectionMetaCache notifies RootCoord to release the collection cache in Proxies. -func (c *Client) InvalidateCollectionMetaCache(ctx context.Context, in *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { +func (c *Client) InvalidateCollectionMetaCache(ctx context.Context, in *proxypb.InvalidateCollMetaCacheRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { in = typeutil.Clone(in) commonpbutil.UpdateMsgBase( in.GetBase(), @@ -372,7 +359,7 @@ func (c *Client) InvalidateCollectionMetaCache(ctx context.Context, in *proxypb. 
} // ShowConfigurations gets specified configurations para of RootCoord -func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { +func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest, opts ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -384,7 +371,7 @@ func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowCon } // GetMetrics get metrics -func (c *Client) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { +func (c *Client) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { in = typeutil.Clone(in) commonpbutil.UpdateMsgBase( in.GetBase(), @@ -396,7 +383,7 @@ func (c *Client) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest) } // CreateAlias create collection alias -func (c *Client) CreateAlias(ctx context.Context, req *milvuspb.CreateAliasRequest) (*commonpb.Status, error) { +func (c *Client) CreateAlias(ctx context.Context, req *milvuspb.CreateAliasRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -408,7 +395,7 @@ func (c *Client) CreateAlias(ctx context.Context, req *milvuspb.CreateAliasReque } // DropAlias drop collection alias -func (c *Client) DropAlias(ctx context.Context, req *milvuspb.DropAliasRequest) (*commonpb.Status, error) { +func (c *Client) DropAlias(ctx context.Context, req *milvuspb.DropAliasRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -420,7 +407,7 @@ func (c *Client) DropAlias(ctx context.Context, req *milvuspb.DropAliasRequest) } // AlterAlias alter collection alias -func (c 
*Client) AlterAlias(ctx context.Context, req *milvuspb.AlterAliasRequest) (*commonpb.Status, error) { +func (c *Client) AlterAlias(ctx context.Context, req *milvuspb.AlterAliasRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -432,40 +419,40 @@ func (c *Client) AlterAlias(ctx context.Context, req *milvuspb.AlterAliasRequest } // Import data files(json, numpy, etc.) on MinIO/S3 storage, read and parse them into sealed segments -func (c *Client) Import(ctx context.Context, req *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) { +func (c *Client) Import(ctx context.Context, req *milvuspb.ImportRequest, opts ...grpc.CallOption) (*milvuspb.ImportResponse, error) { return wrapGrpcCall(ctx, c, func(client rootcoordpb.RootCoordClient) (*milvuspb.ImportResponse, error) { return client.Import(ctx, req) }) } // Check import task state from datanode -func (c *Client) GetImportState(ctx context.Context, req *milvuspb.GetImportStateRequest) (*milvuspb.GetImportStateResponse, error) { +func (c *Client) GetImportState(ctx context.Context, req *milvuspb.GetImportStateRequest, opts ...grpc.CallOption) (*milvuspb.GetImportStateResponse, error) { return wrapGrpcCall(ctx, c, func(client rootcoordpb.RootCoordClient) (*milvuspb.GetImportStateResponse, error) { return client.GetImportState(ctx, req) }) } // List id array of all import tasks -func (c *Client) ListImportTasks(ctx context.Context, req *milvuspb.ListImportTasksRequest) (*milvuspb.ListImportTasksResponse, error) { +func (c *Client) ListImportTasks(ctx context.Context, req *milvuspb.ListImportTasksRequest, opts ...grpc.CallOption) (*milvuspb.ListImportTasksResponse, error) { return wrapGrpcCall(ctx, c, func(client rootcoordpb.RootCoordClient) (*milvuspb.ListImportTasksResponse, error) { return client.ListImportTasks(ctx, req) }) } // Report impot task state to rootcoord -func (c *Client) ReportImport(ctx context.Context, req 
*rootcoordpb.ImportResult) (*commonpb.Status, error) { +func (c *Client) ReportImport(ctx context.Context, req *rootcoordpb.ImportResult, opts ...grpc.CallOption) (*commonpb.Status, error) { return wrapGrpcCall(ctx, c, func(client rootcoordpb.RootCoordClient) (*commonpb.Status, error) { return client.ReportImport(ctx, req) }) } -func (c *Client) CreateCredential(ctx context.Context, req *internalpb.CredentialInfo) (*commonpb.Status, error) { +func (c *Client) CreateCredential(ctx context.Context, req *internalpb.CredentialInfo, opts ...grpc.CallOption) (*commonpb.Status, error) { return wrapGrpcCall(ctx, c, func(client rootcoordpb.RootCoordClient) (*commonpb.Status, error) { return client.CreateCredential(ctx, req) }) } -func (c *Client) GetCredential(ctx context.Context, req *rootcoordpb.GetCredentialRequest) (*rootcoordpb.GetCredentialResponse, error) { +func (c *Client) GetCredential(ctx context.Context, req *rootcoordpb.GetCredentialRequest, opts ...grpc.CallOption) (*rootcoordpb.GetCredentialResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -476,13 +463,13 @@ func (c *Client) GetCredential(ctx context.Context, req *rootcoordpb.GetCredenti }) } -func (c *Client) UpdateCredential(ctx context.Context, req *internalpb.CredentialInfo) (*commonpb.Status, error) { +func (c *Client) UpdateCredential(ctx context.Context, req *internalpb.CredentialInfo, opts ...grpc.CallOption) (*commonpb.Status, error) { return wrapGrpcCall(ctx, c, func(client rootcoordpb.RootCoordClient) (*commonpb.Status, error) { return client.UpdateCredential(ctx, req) }) } -func (c *Client) DeleteCredential(ctx context.Context, req *milvuspb.DeleteCredentialRequest) (*commonpb.Status, error) { +func (c *Client) DeleteCredential(ctx context.Context, req *milvuspb.DeleteCredentialRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -493,7 +480,7 @@ func (c *Client) 
DeleteCredential(ctx context.Context, req *milvuspb.DeleteCrede }) } -func (c *Client) ListCredUsers(ctx context.Context, req *milvuspb.ListCredUsersRequest) (*milvuspb.ListCredUsersResponse, error) { +func (c *Client) ListCredUsers(ctx context.Context, req *milvuspb.ListCredUsersRequest, opts ...grpc.CallOption) (*milvuspb.ListCredUsersResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -504,7 +491,7 @@ func (c *Client) ListCredUsers(ctx context.Context, req *milvuspb.ListCredUsersR }) } -func (c *Client) CreateRole(ctx context.Context, req *milvuspb.CreateRoleRequest) (*commonpb.Status, error) { +func (c *Client) CreateRole(ctx context.Context, req *milvuspb.CreateRoleRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -515,7 +502,7 @@ func (c *Client) CreateRole(ctx context.Context, req *milvuspb.CreateRoleRequest }) } -func (c *Client) DropRole(ctx context.Context, req *milvuspb.DropRoleRequest) (*commonpb.Status, error) { +func (c *Client) DropRole(ctx context.Context, req *milvuspb.DropRoleRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -526,7 +513,7 @@ func (c *Client) DropRole(ctx context.Context, req *milvuspb.DropRoleRequest) (* }) } -func (c *Client) OperateUserRole(ctx context.Context, req *milvuspb.OperateUserRoleRequest) (*commonpb.Status, error) { +func (c *Client) OperateUserRole(ctx context.Context, req *milvuspb.OperateUserRoleRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -537,7 +524,7 @@ func (c *Client) OperateUserRole(ctx context.Context, req *milvuspb.OperateUserR }) } -func (c *Client) SelectRole(ctx context.Context, req *milvuspb.SelectRoleRequest) (*milvuspb.SelectRoleResponse, error) { +func (c *Client) SelectRole(ctx context.Context, req 
*milvuspb.SelectRoleRequest, opts ...grpc.CallOption) (*milvuspb.SelectRoleResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -548,7 +535,7 @@ func (c *Client) SelectRole(ctx context.Context, req *milvuspb.SelectRoleRequest }) } -func (c *Client) SelectUser(ctx context.Context, req *milvuspb.SelectUserRequest) (*milvuspb.SelectUserResponse, error) { +func (c *Client) SelectUser(ctx context.Context, req *milvuspb.SelectUserRequest, opts ...grpc.CallOption) (*milvuspb.SelectUserResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -559,7 +546,7 @@ func (c *Client) SelectUser(ctx context.Context, req *milvuspb.SelectUserRequest }) } -func (c *Client) OperatePrivilege(ctx context.Context, req *milvuspb.OperatePrivilegeRequest) (*commonpb.Status, error) { +func (c *Client) OperatePrivilege(ctx context.Context, req *milvuspb.OperatePrivilegeRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -570,7 +557,7 @@ func (c *Client) OperatePrivilege(ctx context.Context, req *milvuspb.OperatePriv }) } -func (c *Client) SelectGrant(ctx context.Context, req *milvuspb.SelectGrantRequest) (*milvuspb.SelectGrantResponse, error) { +func (c *Client) SelectGrant(ctx context.Context, req *milvuspb.SelectGrantRequest, opts ...grpc.CallOption) (*milvuspb.SelectGrantResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -581,7 +568,7 @@ func (c *Client) SelectGrant(ctx context.Context, req *milvuspb.SelectGrantReque }) } -func (c *Client) ListPolicy(ctx context.Context, req *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) { +func (c *Client) ListPolicy(ctx context.Context, req *internalpb.ListPolicyRequest, opts ...grpc.CallOption) (*internalpb.ListPolicyResponse, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -592,13 +579,13 @@ func (c *Client) 
ListPolicy(ctx context.Context, req *internalpb.ListPolicyReque }) } -func (c *Client) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { +func (c *Client) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest, opts ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error) { return wrapGrpcCall(ctx, c, func(client rootcoordpb.RootCoordClient) (*milvuspb.CheckHealthResponse, error) { return client.CheckHealth(ctx, req) }) } -func (c *Client) RenameCollection(ctx context.Context, req *milvuspb.RenameCollectionRequest) (*commonpb.Status, error) { +func (c *Client) RenameCollection(ctx context.Context, req *milvuspb.RenameCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { req = typeutil.Clone(req) commonpbutil.UpdateMsgBase( req.GetBase(), @@ -609,7 +596,7 @@ func (c *Client) RenameCollection(ctx context.Context, req *milvuspb.RenameColle }) } -func (c *Client) CreateDatabase(ctx context.Context, in *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) { +func (c *Client) CreateDatabase(ctx context.Context, in *milvuspb.CreateDatabaseRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { in = typeutil.Clone(in) commonpbutil.UpdateMsgBase( in.GetBase(), @@ -628,7 +615,7 @@ func (c *Client) CreateDatabase(ctx context.Context, in *milvuspb.CreateDatabase return ret.(*commonpb.Status), err } -func (c *Client) DropDatabase(ctx context.Context, in *milvuspb.DropDatabaseRequest) (*commonpb.Status, error) { +func (c *Client) DropDatabase(ctx context.Context, in *milvuspb.DropDatabaseRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { in = typeutil.Clone(in) commonpbutil.UpdateMsgBase( in.GetBase(), @@ -647,7 +634,7 @@ func (c *Client) DropDatabase(ctx context.Context, in *milvuspb.DropDatabaseRequ return ret.(*commonpb.Status), err } -func (c *Client) ListDatabases(ctx context.Context, in *milvuspb.ListDatabasesRequest) (*milvuspb.ListDatabasesResponse, error) { 
+func (c *Client) ListDatabases(ctx context.Context, in *milvuspb.ListDatabasesRequest, opts ...grpc.CallOption) (*milvuspb.ListDatabasesResponse, error) { in = typeutil.Clone(in) commonpbutil.UpdateMsgBase( in.GetBase(), diff --git a/internal/distributed/rootcoord/client/client_test.go b/internal/distributed/rootcoord/client/client_test.go index 3d339a4356c69..ca5aa24ca4613 100644 --- a/internal/distributed/rootcoord/client/client_test.go +++ b/internal/distributed/rootcoord/client/client_test.go @@ -25,17 +25,16 @@ import ( "time" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus/internal/util/mock" + "github.com/stretchr/testify/assert" "go.uber.org/zap" - - "github.com/milvus-io/milvus/internal/proto/rootcoordpb" - "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/util/paramtable" "google.golang.org/grpc" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/proxy" + "github.com/milvus-io/milvus/internal/util/mock" + "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/etcd" - "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) func TestMain(m *testing.M) { @@ -71,15 +70,6 @@ func Test_NewClient(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, client) - err = client.Init() - assert.NoError(t, err) - - err = client.Start() - assert.NoError(t, err) - - err = client.Register() - assert.NoError(t, err) - checkFunc := func(retNotNil bool) { retCheck := func(notNil bool, ret interface{}, err error) { if notNil { @@ -92,15 +82,15 @@ func Test_NewClient(t *testing.T) { } { - r, err := client.GetComponentStates(ctx) + r, err := client.GetComponentStates(ctx, nil) retCheck(retNotNil, r, err) } { - r, err := client.GetTimeTickChannel(ctx) + r, err := client.GetTimeTickChannel(ctx, nil) retCheck(retNotNil, r, err) } { - r, err := client.GetStatisticsChannel(ctx) + r, err := client.GetStatisticsChannel(ctx, nil) 
retCheck(retNotNil, r, err) } { @@ -306,15 +296,15 @@ func Test_NewClient(t *testing.T) { assert.Error(t, err) } { - rTimeout, err := client.GetComponentStates(shortCtx) + rTimeout, err := client.GetComponentStates(shortCtx, nil) retCheck(rTimeout, err) } { - rTimeout, err := client.GetTimeTickChannel(shortCtx) + rTimeout, err := client.GetTimeTickChannel(shortCtx, nil) retCheck(rTimeout, err) } { - rTimeout, err := client.GetStatisticsChannel(shortCtx) + rTimeout, err := client.GetStatisticsChannel(shortCtx, nil) retCheck(rTimeout, err) } { @@ -474,6 +464,6 @@ func Test_NewClient(t *testing.T) { retCheck(rTimeout, err) } // clean up - err = client.Stop() + err = client.Close() assert.NoError(t, err) } diff --git a/internal/distributed/rootcoord/service.go b/internal/distributed/rootcoord/service.go index f02397b1a7fef..e7be73f655c70 100644 --- a/internal/distributed/rootcoord/service.go +++ b/internal/distributed/rootcoord/service.go @@ -24,9 +24,7 @@ import ( "time" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" - "github.com/milvus-io/milvus/internal/util/dependency" - "github.com/milvus-io/milvus/pkg/tracer" - "github.com/milvus-io/milvus/pkg/util/interceptor" + "github.com/tikv/client-go/v2/txnkv" clientv3 "go.etcd.io/etcd/client/v3" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "go.uber.org/atomic" @@ -36,19 +34,24 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + dcc "github.com/milvus-io/milvus/internal/distributed/datacoord/client" + qcc "github.com/milvus-io/milvus/internal/distributed/querycoord/client" + "github.com/milvus-io/milvus/internal/distributed/utils" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/proxypb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/rootcoord" "github.com/milvus-io/milvus/internal/types" + 
"github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/tracer" + "github.com/milvus-io/milvus/pkg/util" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/interceptor" "github.com/milvus-io/milvus/pkg/util/logutil" "github.com/milvus-io/milvus/pkg/util/paramtable" - - dcc "github.com/milvus-io/milvus/internal/distributed/datacoord/client" - qcc "github.com/milvus-io/milvus/internal/distributed/querycoord/client" + "github.com/milvus-io/milvus/pkg/util/tikv" ) // Server grpc wrapper @@ -65,11 +68,12 @@ type Server struct { serverID atomic.Int64 etcdCli *clientv3.Client - dataCoord types.DataCoord - queryCoord types.QueryCoord + tikvCli *txnkv.Client + dataCoord types.DataCoordClient + queryCoord types.QueryCoordClient - newDataCoordClient func(string, *clientv3.Client) types.DataCoord - newQueryCoordClient func(string, *clientv3.Client) types.QueryCoord + newDataCoordClient func(string, *clientv3.Client) types.DataCoordClient + newQueryCoordClient func(string, *clientv3.Client) types.QueryCoordClient } func (s *Server) CreateDatabase(ctx context.Context, request *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) { @@ -121,7 +125,7 @@ func NewServer(ctx context.Context, factory dependency.Factory) (*Server, error) } func (s *Server) setClient() { - s.newDataCoordClient = func(etcdMetaRoot string, etcdCli *clientv3.Client) types.DataCoord { + s.newDataCoordClient = func(etcdMetaRoot string, etcdCli *clientv3.Client) types.DataCoordClient { dsClient, err := dcc.NewClient(s.ctx, etcdMetaRoot, etcdCli) if err != nil { panic(err) @@ -129,7 +133,7 @@ func (s *Server) setClient() { return dsClient } - s.newQueryCoordClient = func(metaRootPath string, etcdCli *clientv3.Client) types.QueryCoord { + s.newQueryCoordClient = func(metaRootPath string, etcdCli *clientv3.Client) types.QueryCoordClient { qsClient, err := 
qcc.NewClient(s.ctx, metaRootPath, etcdCli) if err != nil { panic(err) @@ -152,9 +156,12 @@ func (s *Server) Run() error { return nil } +var getTiKVClient = tikv.GetTiKVClient + func (s *Server) init() error { - etcdConfig := ¶mtable.Get().EtcdCfg - Params := ¶mtable.Get().RootCoordGrpcServerCfg + params := paramtable.Get() + etcdConfig := ¶ms.EtcdCfg + rpcParams := ¶ms.RootCoordGrpcServerCfg log.Debug("init params done..") etcdCli, err := etcd.GetEtcdClient( @@ -171,10 +178,21 @@ func (s *Server) init() error { } s.etcdCli = etcdCli s.rootCoord.SetEtcdClient(s.etcdCli) - s.rootCoord.SetAddress(Params.GetAddress()) + s.rootCoord.SetAddress(rpcParams.GetAddress()) log.Debug("etcd connect done ...") - err = s.startGrpc(Params.Port.GetAsInt()) + if params.MetaStoreCfg.MetaStoreType.GetValue() == util.MetaStoreTypeTiKV { + log.Info("Connecting to tikv metadata storage.") + s.tikvCli, err = getTiKVClient(¶mtable.Get().TiKVCfg) + if err != nil { + log.Debug("RootCoord failed to connect to tikv", zap.Error(err)) + return err + } + s.rootCoord.SetTiKVClient(s.tikvCli) + log.Info("Connected to tikv. 
Using tikv as metadata storage.") + } + + err = s.startGrpc(rpcParams.Port.GetAsInt()) if err != nil { return err } @@ -184,15 +202,7 @@ func (s *Server) init() error { log.Debug("RootCoord start to create DataCoord client") dataCoord := s.newDataCoordClient(rootcoord.Params.EtcdCfg.MetaRootPath.GetValue(), s.etcdCli) s.dataCoord = dataCoord - if err = s.dataCoord.Init(); err != nil { - log.Error("RootCoord DataCoordClient Init failed", zap.Error(err)) - panic(err) - } - if err = s.dataCoord.Start(); err != nil { - log.Error("RootCoord DataCoordClient Start failed", zap.Error(err)) - panic(err) - } - if err := s.rootCoord.SetDataCoord(dataCoord); err != nil { + if err := s.rootCoord.SetDataCoordClient(dataCoord); err != nil { panic(err) } } @@ -201,15 +211,7 @@ func (s *Server) init() error { log.Debug("RootCoord start to create QueryCoord client") queryCoord := s.newQueryCoordClient(rootcoord.Params.EtcdCfg.MetaRootPath.GetValue(), s.etcdCli) s.queryCoord = queryCoord - if err := s.queryCoord.Init(); err != nil { - log.Error("RootCoord QueryCoordClient Init failed", zap.Error(err)) - panic(err) - } - if err := s.queryCoord.Start(); err != nil { - log.Error("RootCoord QueryCoordClient Start failed", zap.Error(err)) - panic(err) - } - if err := s.rootCoord.SetQueryCoord(queryCoord); err != nil { + if err := s.rootCoord.SetQueryCoordClient(queryCoord); err != nil { panic(err) } } @@ -228,19 +230,19 @@ func (s *Server) startGrpc(port int) error { func (s *Server) startGrpcLoop(port int) { defer s.wg.Done() Params := ¶mtable.Get().RootCoordGrpcServerCfg - var kaep = keepalive.EnforcementPolicy{ + kaep := keepalive.EnforcementPolicy{ MinTime: 5 * time.Second, // If a client pings more than once every 5 seconds, terminate the connection PermitWithoutStream: true, // Allow pings even when there are no active streams } - var kasp = keepalive.ServerParameters{ + kasp := keepalive.ServerParameters{ Time: 60 * time.Second, // Ping the client if it is idle for 60 seconds to 
ensure the connection is still active Timeout: 10 * time.Second, // Wait 10 second for the ping ack before assuming the connection is dead } log.Debug("start grpc ", zap.Int("port", port)) lis, err := net.Listen("tcp", ":"+strconv.Itoa(port)) if err != nil { - log.Error("GrpcServer:failed to listen", zap.String("error", err.Error())) + log.Error("GrpcServer:failed to listen", zap.Error(err)) s.grpcErrChan <- err return } @@ -305,13 +307,16 @@ func (s *Server) Stop() error { if s.etcdCli != nil { defer s.etcdCli.Close() } + if s.tikvCli != nil { + defer s.tikvCli.Close() + } if s.dataCoord != nil { - if err := s.dataCoord.Stop(); err != nil { + if err := s.dataCoord.Close(); err != nil { log.Error("Failed to close dataCoord client", zap.Error(err)) } } if s.queryCoord != nil { - if err := s.queryCoord.Stop(); err != nil { + if err := s.queryCoord.Close(); err != nil { log.Error("Failed to close queryCoord client", zap.Error(err)) } } @@ -323,8 +328,7 @@ func (s *Server) Stop() error { log.Debug("Rootcoord begin to stop grpc server") s.cancel() if s.grpcServer != nil { - log.Debug("Graceful stop grpc server...") - s.grpcServer.GracefulStop() + utils.GracefulStopGRPCServer(s.grpcServer) } s.wg.Wait() return nil @@ -332,17 +336,17 @@ func (s *Server) Stop() error { // GetComponentStates gets the component states of RootCoord. 
func (s *Server) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) { - return s.rootCoord.GetComponentStates(ctx) + return s.rootCoord.GetComponentStates(ctx, req) } // GetTimeTickChannel receiver time tick from proxy service, and put it into this channel func (s *Server) GetTimeTickChannel(ctx context.Context, req *internalpb.GetTimeTickChannelRequest) (*milvuspb.StringResponse, error) { - return s.rootCoord.GetTimeTickChannel(ctx) + return s.rootCoord.GetTimeTickChannel(ctx, req) } // GetStatisticsChannel just define a channel, not used currently func (s *Server) GetStatisticsChannel(ctx context.Context, req *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error) { - return s.rootCoord.GetStatisticsChannel(ctx) + return s.rootCoord.GetStatisticsChannel(ctx, req) } // CreateCollection creates a collection diff --git a/internal/distributed/rootcoord/service_test.go b/internal/distributed/rootcoord/service_test.go index 78c0d187ad8b0..e8843c032de65 100644 --- a/internal/distributed/rootcoord/service_test.go +++ b/internal/distributed/rootcoord/service_test.go @@ -25,18 +25,20 @@ import ( "time" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus/internal/util/sessionutil" - - clientv3 "go.etcd.io/etcd/client/v3" - "github.com/stretchr/testify/assert" + "github.com/tikv/client-go/v2/txnkv" + clientv3 "go.etcd.io/etcd/client/v3" + "google.golang.org/grpc" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/rootcoord" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/util/etcd" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/tikv" ) type mockCore struct { @@ -76,15 +78,18 @@ func (m *mockCore) 
SetAddress(address string) { func (m *mockCore) SetEtcdClient(etcdClient *clientv3.Client) { } -func (m *mockCore) SetDataCoord(types.DataCoord) error { +func (m *mockCore) SetTiKVClient(client *txnkv.Client) { +} + +func (m *mockCore) SetDataCoordClient(client types.DataCoordClient) error { return nil } -func (m *mockCore) SetQueryCoord(types.QueryCoord) error { +func (m *mockCore) SetQueryCoordClient(client types.QueryCoordClient) error { return nil } -func (m *mockCore) SetProxyCreator(func(ctx context.Context, addr string, nodeID int64) (types.Proxy, error)) { +func (m *mockCore) SetProxyCreator(func(ctx context.Context, addr string, nodeID int64) (types.ProxyClient, error)) { } func (m *mockCore) Register() error { @@ -104,25 +109,19 @@ func (m *mockCore) Stop() error { } type mockDataCoord struct { - types.DataCoord - initErr error - startErr error + types.DataCoordClient } -func (m *mockDataCoord) Init() error { - return m.initErr -} -func (m *mockDataCoord) Start() error { - return m.startErr +func (m *mockDataCoord) Close() error { + return nil } -func (m *mockDataCoord) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { + +func (m *mockDataCoord) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { return &milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{ StateCode: commonpb.StateCode_Healthy, }, - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), SubcomponentStates: []*milvuspb.ComponentInfo{ { StateCode: commonpb.StateCode_Healthy, @@ -130,165 +129,208 @@ func (m *mockDataCoord) GetComponentStates(ctx context.Context) (*milvuspb.Compo }, }, nil } + func (m *mockDataCoord) Stop() error { return fmt.Errorf("stop error") } type mockQueryCoord struct { - types.QueryCoord + types.QueryCoordClient initErr error startErr error } -func (m *mockQueryCoord) Init() error { - return m.initErr -} 
- -func (m *mockQueryCoord) Start() error { - return m.startErr -} - -func (m *mockQueryCoord) Stop() error { +func (m *mockQueryCoord) Close() error { return fmt.Errorf("stop error") } func TestRun(t *testing.T) { paramtable.Init() - ctx, cancel := context.WithCancel(context.Background()) - svr := Server{ - rootCoord: &mockCore{}, - ctx: ctx, - cancel: cancel, - grpcErrChan: make(chan error), - } - rcServerConfig := ¶mtable.Get().RootCoordGrpcServerCfg - paramtable.Get().Save(rcServerConfig.Port.Key, "1000000") - err := svr.Run() - assert.Error(t, err) - assert.EqualError(t, err, "listen tcp: address 1000000: invalid port") - - svr.newDataCoordClient = func(string, *clientv3.Client) types.DataCoord { - return &mockDataCoord{} - } - svr.newQueryCoordClient = func(string, *clientv3.Client) types.QueryCoord { - return &mockQueryCoord{} - } - - paramtable.Get().Save(rcServerConfig.Port.Key, fmt.Sprintf("%d", rand.Int()%100+10000)) - etcdConfig := ¶mtable.Get().EtcdCfg - - rand.Seed(time.Now().UnixNano()) - randVal := rand.Int() - rootPath := fmt.Sprintf("/%d/test", randVal) - rootcoord.Params.Save("etcd.rootPath", rootPath) - - etcdCli, err := etcd.GetEtcdClient( - etcdConfig.UseEmbedEtcd.GetAsBool(), - etcdConfig.EtcdUseSSL.GetAsBool(), - etcdConfig.Endpoints.GetAsStrings(), - etcdConfig.EtcdTLSCert.GetValue(), - etcdConfig.EtcdTLSKey.GetValue(), - etcdConfig.EtcdTLSCACert.GetValue(), - etcdConfig.EtcdTLSMinVersion.GetValue()) - assert.NoError(t, err) - sessKey := path.Join(rootcoord.Params.EtcdCfg.MetaRootPath.GetValue(), sessionutil.DefaultServiceRoot) - _, err = etcdCli.Delete(ctx, sessKey, clientv3.WithPrefix()) - assert.NoError(t, err) - err = svr.Run() - assert.NoError(t, err) - - t.Run("CheckHealth", func(t *testing.T) { - ret, err := svr.CheckHealth(ctx, nil) + parameters := []string{"tikv", "etcd"} + for _, v := range parameters { + paramtable.Get().Save(paramtable.Get().MetaStoreCfg.MetaStoreType.Key, v) + ctx, cancel := 
context.WithCancel(context.Background()) + getTiKVClient = func(cfg *paramtable.TiKVConfig) (*txnkv.Client, error) { + return tikv.SetupLocalTxn(), nil + } + defer func() { + getTiKVClient = tikv.GetTiKVClient + }() + svr := Server{ + rootCoord: &mockCore{}, + ctx: ctx, + cancel: cancel, + grpcErrChan: make(chan error), + } + rcServerConfig := ¶mtable.Get().RootCoordGrpcServerCfg + paramtable.Get().Save(rcServerConfig.Port.Key, "1000000") + err := svr.Run() + assert.Error(t, err) + assert.EqualError(t, err, "listen tcp: address 1000000: invalid port") + + svr.newDataCoordClient = func(string, *clientv3.Client) types.DataCoordClient { + return &mockDataCoord{} + } + svr.newQueryCoordClient = func(string, *clientv3.Client) types.QueryCoordClient { + return &mockQueryCoord{} + } + + paramtable.Get().Save(rcServerConfig.Port.Key, fmt.Sprintf("%d", rand.Int()%100+10000)) + etcdConfig := ¶mtable.Get().EtcdCfg + + rand.Seed(time.Now().UnixNano()) + randVal := rand.Int() + rootPath := fmt.Sprintf("/%d/test", randVal) + rootcoord.Params.Save("etcd.rootPath", rootPath) + + etcdCli, err := etcd.GetEtcdClient( + etcdConfig.UseEmbedEtcd.GetAsBool(), + etcdConfig.EtcdUseSSL.GetAsBool(), + etcdConfig.Endpoints.GetAsStrings(), + etcdConfig.EtcdTLSCert.GetValue(), + etcdConfig.EtcdTLSKey.GetValue(), + etcdConfig.EtcdTLSCACert.GetValue(), + etcdConfig.EtcdTLSMinVersion.GetValue()) + assert.NoError(t, err) + sessKey := path.Join(rootcoord.Params.EtcdCfg.MetaRootPath.GetValue(), sessionutil.DefaultServiceRoot) + _, err = etcdCli.Delete(ctx, sessKey, clientv3.WithPrefix()) + assert.NoError(t, err) + err = svr.Run() assert.NoError(t, err) - assert.Equal(t, true, ret.IsHealthy) - }) - t.Run("RenameCollection", func(t *testing.T) { - _, err := svr.RenameCollection(ctx, nil) + t.Run("CheckHealth", func(t *testing.T) { + ret, err := svr.CheckHealth(ctx, nil) + assert.NoError(t, err) + assert.Equal(t, true, ret.IsHealthy) + }) + + t.Run("RenameCollection", func(t *testing.T) { + _, err := 
svr.RenameCollection(ctx, nil) + assert.NoError(t, err) + }) + + t.Run("CreateDatabase", func(t *testing.T) { + ret, err := svr.CreateDatabase(ctx, nil) + assert.Nil(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, ret.ErrorCode) + }) + + t.Run("DropDatabase", func(t *testing.T) { + ret, err := svr.DropDatabase(ctx, nil) + assert.Nil(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, ret.ErrorCode) + }) + + t.Run("ListDatabases", func(t *testing.T) { + ret, err := svr.ListDatabases(ctx, nil) + assert.Nil(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, ret.GetStatus().GetErrorCode()) + }) + err = svr.Stop() assert.NoError(t, err) - }) - - t.Run("CreateDatabase", func(t *testing.T) { - ret, err := svr.CreateDatabase(ctx, nil) - assert.Nil(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, ret.ErrorCode) - }) - - t.Run("DropDatabase", func(t *testing.T) { - ret, err := svr.DropDatabase(ctx, nil) - assert.Nil(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, ret.ErrorCode) - }) - - t.Run("ListDatabases", func(t *testing.T) { - ret, err := svr.ListDatabases(ctx, nil) - assert.Nil(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, ret.Status.ErrorCode) - }) - err = svr.Stop() - assert.NoError(t, err) + } } func TestServerRun_DataCoordClientInitErr(t *testing.T) { paramtable.Init() - ctx := context.Background() - server, err := NewServer(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, server) + parameters := []string{"tikv", "etcd"} + for _, v := range parameters { + paramtable.Get().Save(paramtable.Get().MetaStoreCfg.MetaStoreType.Key, v) + ctx := context.Background() + getTiKVClient = func(cfg *paramtable.TiKVConfig) (*txnkv.Client, error) { + return tikv.SetupLocalTxn(), nil + } + defer func() { + getTiKVClient = tikv.GetTiKVClient + }() + server, err := NewServer(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, server) - server.newDataCoordClient = func(string, *clientv3.Client) types.DataCoord { - return &mockDataCoord{initErr: 
errors.New("mock datacoord init error")} - } - assert.Panics(t, func() { server.Run() }) + server.newDataCoordClient = func(string, *clientv3.Client) types.DataCoordClient { + return &mockDataCoord{} + } + assert.Panics(t, func() { server.Run() }) - err = server.Stop() - assert.NoError(t, err) + err = server.Stop() + assert.NoError(t, err) + } } func TestServerRun_DataCoordClientStartErr(t *testing.T) { paramtable.Init() - ctx := context.Background() - server, err := NewServer(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, server) + parameters := []string{"tikv", "etcd"} + for _, v := range parameters { + paramtable.Get().Save(paramtable.Get().MetaStoreCfg.MetaStoreType.Key, v) + ctx := context.Background() + getTiKVClient = func(cfg *paramtable.TiKVConfig) (*txnkv.Client, error) { + return tikv.SetupLocalTxn(), nil + } + defer func() { + getTiKVClient = tikv.GetTiKVClient + }() + server, err := NewServer(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, server) - server.newDataCoordClient = func(string, *clientv3.Client) types.DataCoord { - return &mockDataCoord{startErr: errors.New("mock datacoord start error")} - } - assert.Panics(t, func() { server.Run() }) + server.newDataCoordClient = func(string, *clientv3.Client) types.DataCoordClient { + return &mockDataCoord{} + } + assert.Panics(t, func() { server.Run() }) - err = server.Stop() - assert.NoError(t, err) + err = server.Stop() + assert.NoError(t, err) + } } func TestServerRun_QueryCoordClientInitErr(t *testing.T) { paramtable.Init() - ctx := context.Background() - server, err := NewServer(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, server) + parameters := []string{"tikv", "etcd"} + for _, v := range parameters { + paramtable.Get().Save(paramtable.Get().MetaStoreCfg.MetaStoreType.Key, v) + ctx := context.Background() + getTiKVClient = func(cfg *paramtable.TiKVConfig) (*txnkv.Client, error) { + return tikv.SetupLocalTxn(), nil + } + defer func() { + getTiKVClient = tikv.GetTiKVClient + 
}() + server, err := NewServer(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, server) - server.newQueryCoordClient = func(string, *clientv3.Client) types.QueryCoord { - return &mockQueryCoord{initErr: errors.New("mock querycoord init error")} - } - assert.Panics(t, func() { server.Run() }) + server.newQueryCoordClient = func(string, *clientv3.Client) types.QueryCoordClient { + return &mockQueryCoord{initErr: errors.New("mock querycoord init error")} + } + assert.Panics(t, func() { server.Run() }) - err = server.Stop() - assert.NoError(t, err) + err = server.Stop() + assert.NoError(t, err) + } } func TestServer_QueryCoordClientStartErr(t *testing.T) { paramtable.Init() - ctx := context.Background() - server, err := NewServer(ctx, nil) - assert.NoError(t, err) - assert.NotNil(t, server) + parameters := []string{"tikv", "etcd"} + for _, v := range parameters { + paramtable.Get().Save(paramtable.Get().MetaStoreCfg.MetaStoreType.Key, v) + ctx := context.Background() + getTiKVClient = func(cfg *paramtable.TiKVConfig) (*txnkv.Client, error) { + return tikv.SetupLocalTxn(), nil + } + defer func() { + getTiKVClient = tikv.GetTiKVClient + }() + server, err := NewServer(ctx, nil) + assert.NoError(t, err) + assert.NotNil(t, server) - server.newQueryCoordClient = func(string, *clientv3.Client) types.QueryCoord { - return &mockQueryCoord{startErr: errors.New("mock querycoord start error")} - } - assert.Panics(t, func() { server.Run() }) + server.newQueryCoordClient = func(string, *clientv3.Client) types.QueryCoordClient { + return &mockQueryCoord{startErr: errors.New("mock querycoord start error")} + } + assert.Panics(t, func() { server.Run() }) - err = server.Stop() - assert.NoError(t, err) + err = server.Stop() + assert.NoError(t, err) + } } diff --git a/internal/distributed/utils/util.go b/internal/distributed/utils/util.go new file mode 100644 index 0000000000000..f2cc161ead0b1 --- /dev/null +++ b/internal/distributed/utils/util.go @@ -0,0 +1,32 @@ +package utils + 
+import ( + "time" + + "google.golang.org/grpc" + + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +func GracefulStopGRPCServer(s *grpc.Server) { + if s == nil { + return + } + ch := make(chan struct{}) + go func() { + defer close(ch) + log.Debug("try to graceful stop grpc server...") + // will block until all rpc finished. + s.GracefulStop() + }() + select { + case <-ch: + case <-time.After(paramtable.Get().ProxyGrpcServerCfg.GracefulStopTimeout.GetAsDuration(time.Second)): + // took too long, manually close grpc server + log.Debug("stop grpc server...") + s.Stop() + // concurrent GracefulStop should be interrupted + <-ch + } +} diff --git a/internal/distributed/utils/util_test.go b/internal/distributed/utils/util_test.go new file mode 100644 index 0000000000000..cf65f3feb1d10 --- /dev/null +++ b/internal/distributed/utils/util_test.go @@ -0,0 +1,20 @@ +package utils + +import ( + "testing" + + "google.golang.org/grpc" + + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +func TestGracefulStopGrpcServer(t *testing.T) { + paramtable.Init() + + // expected close by gracefulStop + s1 := grpc.NewServer() + GracefulStopGRPCServer(s1) + + // expected not panic + GracefulStopGRPCServer(nil) +} diff --git a/internal/http/healthz/healthz_handler.go b/internal/http/healthz/healthz_handler.go index f0de6c3d7d4e2..7509710851e3a 100644 --- a/internal/http/healthz/healthz_handler.go +++ b/internal/http/healthz/healthz_handler.go @@ -22,10 +22,10 @@ import ( "fmt" "net/http" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/pkg/log" ) diff --git a/internal/http/server_test.go b/internal/http/server_test.go index 357f1ef0b352b..7cba7885d0bdb 100644 --- a/internal/http/server_test.go +++ 
b/internal/http/server_test.go @@ -20,17 +20,17 @@ import ( "bytes" "context" "encoding/json" - "io/ioutil" + "io" "net/http" "net/http/httptest" "testing" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus/internal/http/healthz" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus/internal/http/healthz" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -54,7 +54,6 @@ type HTTPServerTestSuite struct { func (suite *HTTPServerTestSuite) SetupSuite() { suite.server = httptest.NewServer(nil) registerDefaults() - } func (suite *HTTPServerTestSuite) TearDownSuite() { @@ -85,7 +84,7 @@ func (suite *HTTPServerTestSuite) TestDefaultLogHandler() { suite.Require().NoError(err) defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) suite.Require().NoError(err) suite.Equal("{\"level\":\"error\"}\n", string(body)) suite.Equal(zap.ErrorLevel, log.GetLevel()) @@ -101,7 +100,7 @@ func (suite *HTTPServerTestSuite) TestHealthzHandler() { resp, err := client.Do(req) suite.Nil(err) defer resp.Body.Close() - body, _ := ioutil.ReadAll(resp.Body) + body, _ := io.ReadAll(resp.Body) suite.Equal("OK", string(body)) req, _ = http.NewRequest(http.MethodGet, url, nil) @@ -109,7 +108,7 @@ func (suite *HTTPServerTestSuite) TestHealthzHandler() { resp, err = client.Do(req) suite.Nil(err) defer resp.Body.Close() - body, _ = ioutil.ReadAll(resp.Body) + body, _ = io.ReadAll(resp.Body) suite.Equal("{\"state\":\"OK\",\"detail\":[{\"name\":\"m1\",\"code\":1}]}", string(body)) healthz.Register(&MockIndicator{"m2", commonpb.StateCode_Abnormal}) @@ -118,7 +117,7 @@ func (suite *HTTPServerTestSuite) TestHealthzHandler() { resp, err = client.Do(req) suite.Nil(err) defer resp.Body.Close() - body, _ = ioutil.ReadAll(resp.Body) + body, _ = io.ReadAll(resp.Body) 
suite.Equal("{\"state\":\"component m2 state is Abnormal\",\"detail\":[{\"name\":\"m1\",\"code\":1},{\"name\":\"m2\",\"code\":2}]}", string(body)) } diff --git a/internal/indexnode/chunk_mgr_factory.go b/internal/indexnode/chunk_mgr_factory.go index dac70ce9cf53d..4d6894da3739c 100644 --- a/internal/indexnode/chunk_mgr_factory.go +++ b/internal/indexnode/chunk_mgr_factory.go @@ -6,7 +6,6 @@ import ( "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/storage" - "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -25,19 +24,22 @@ func NewChunkMgrFactory() *chunkMgrFactory { } func (m *chunkMgrFactory) NewChunkManager(ctx context.Context, config *indexpb.StorageConfig) (storage.ChunkManager, error) { - key := m.cacheKey(config.GetStorageType(), config.GetBucketName(), config.GetAddress()) - if v, ok := m.cached.Get(key); ok { - return v, nil - } - - chunkManagerFactory := storage.NewChunkManagerFactoryWithParam(Params) - mgr, err := chunkManagerFactory.NewPersistentStorageChunkManager(ctx) - if err != nil { - return nil, err - } - v, _ := m.cached.GetOrInsert(key, mgr) - log.Ctx(ctx).Info("index node successfully init chunk manager") - return v, nil + chunkManagerFactory := storage.NewChunkManagerFactory(config.GetStorageType(), + storage.RootPath(config.GetRootPath()), + storage.Address(config.GetAddress()), + storage.AccessKeyID(config.GetAccessKeyID()), + storage.SecretAccessKeyID(config.GetSecretAccessKey()), + storage.UseSSL(config.GetUseSSL()), + storage.BucketName(config.GetBucketName()), + storage.UseIAM(config.GetUseIAM()), + storage.CloudProvider(config.GetCloudProvider()), + storage.IAMEndpoint(config.GetIAMEndpoint()), + storage.UseVirtualHost(config.GetUseVirtualHost()), + storage.RequestTimeout(config.GetRequestTimeoutMs()), + storage.Region(config.GetRegion()), + storage.CreateBucket(true), + ) + return chunkManagerFactory.NewPersistentStorageChunkManager(ctx) } func (m 
*chunkMgrFactory) cacheKey(storageType, bucket, address string) string { diff --git a/internal/indexnode/chunkmgr_mock.go b/internal/indexnode/chunkmgr_mock.go index ee554fe796c05..f911372b3e43b 100644 --- a/internal/indexnode/chunkmgr_mock.go +++ b/internal/indexnode/chunkmgr_mock.go @@ -72,9 +72,7 @@ var ( } ) -var ( - mockChunkMgrIns = &mockChunkmgr{} -) +var mockChunkMgrIns = &mockChunkmgr{} type mockStorageFactory struct{} diff --git a/internal/indexnode/errors_test.go b/internal/indexnode/errors_test.go deleted file mode 100644 index a58d419149458..0000000000000 --- a/internal/indexnode/errors_test.go +++ /dev/null @@ -1,40 +0,0 @@ -// Licensed to the LF AI & Data foundation under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package indexnode - -import ( - "testing" - - "go.uber.org/zap" - - "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/util/typeutil" -) - -func TestMsgIndexNodeIsUnhealthy(t *testing.T) { - nodeIDList := []typeutil.UniqueID{1, 2, 3} - for _, nodeID := range nodeIDList { - log.Info("TestMsgIndexNodeIsUnhealthy", zap.String("msg", msgIndexNodeIsUnhealthy(nodeID))) - } -} - -func TestErrIndexNodeIsUnhealthy(t *testing.T) { - nodeIDList := []typeutil.UniqueID{1, 2, 3} - for _, nodeID := range nodeIDList { - log.Info("TestErrIndexNodeIsUnhealthy", zap.Error(errIndexNodeIsUnhealthy(nodeID))) - } -} diff --git a/internal/indexnode/etcd_mock.go b/internal/indexnode/etcd_mock.go index 40616ec7ab729..bf7039b266e28 100644 --- a/internal/indexnode/etcd_mock.go +++ b/internal/indexnode/etcd_mock.go @@ -2,7 +2,6 @@ package indexnode import ( "fmt" - "io/ioutil" "net/url" "os" "sync" @@ -25,7 +24,7 @@ var ( func startEmbedEtcd() { startSvr.Do(func() { - dir, err := ioutil.TempDir(os.TempDir(), "milvus_ut_etcd") + dir, err := os.MkdirTemp(os.TempDir(), "milvus_ut_etcd") if err != nil { panic(err) } diff --git a/internal/indexnode/indexnode.go b/internal/indexnode/indexnode.go index 64e36772316cd..f271a23a06a34 100644 --- a/internal/indexnode/indexnode.go +++ b/internal/indexnode/indexnode.go @@ -40,7 +40,6 @@ import ( "unsafe" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus/internal/util/sessionutil" clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/zap" @@ -50,10 +49,10 @@ import ( "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/initcore" + "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" - "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/hardware" 
"github.com/milvus-io/milvus/pkg/util/lifetime" "github.com/milvus-io/milvus/pkg/util/merr" @@ -74,6 +73,14 @@ var _ types.IndexNodeComponent = (*IndexNode)(nil) // Params is a GlobalParamTable singleton of indexnode var Params *paramtable.ComponentParam = paramtable.Get() +func getCurrentIndexVersion(v int32) int32 { + cCurrent := int32(C.GetCurrentIndexVersion()) + if cCurrent < v { + return cCurrent + } + return v +} + type taskKey struct { ClusterID string BuildID UniqueID @@ -127,7 +134,7 @@ func (i *IndexNode) Register() error { i.session.Register() metrics.NumNodes.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), typeutil.IndexNodeRole).Inc() - //start liveness check + // start liveness check i.session.LivenessCheck(i.loopCtx, func() { log.Error("Index Node disconnected from etcd, process will exit", zap.Int64("Server Id", i.session.ServerID)) if err := i.Stop(); err != nil { @@ -158,7 +165,7 @@ func (i *IndexNode) initSegcore() { cIndexSliceSize := C.int64_t(Params.CommonCfg.IndexSliceSize.GetAsInt64()) C.InitIndexSliceSize(cIndexSliceSize) - //set up thread pool for different priorities + // set up thread pool for different priorities cHighPriorityThreadCoreCoefficient := C.int64_t(paramtable.Get().CommonCfg.HighPriorityThreadCoreCoefficient.GetAsInt64()) C.InitHighPriorityThreadCoreCoefficient(cHighPriorityThreadCoreCoefficient) cMiddlePriorityThreadCoreCoefficient := C.int64_t(paramtable.Get().CommonCfg.MiddlePriorityThreadCoreCoefficient.GetAsInt64()) @@ -181,11 +188,12 @@ func (i *IndexNode) CloseSegcore() { } func (i *IndexNode) initSession() error { - i.session = sessionutil.NewSession(i.loopCtx, Params.EtcdCfg.MetaRootPath.GetValue(), i.etcdCli) + i.session = sessionutil.NewSession(i.loopCtx, Params.EtcdCfg.MetaRootPath.GetValue(), i.etcdCli, sessionutil.WithEnableDisk(Params.IndexNodeCfg.EnableDisk.GetAsBool())) if i.session == nil { return errors.New("failed to initialize session") } i.session.Init(typeutil.IndexNodeRole, i.address, false, 
true) + sessionutil.SaveServerInfo(typeutil.IndexNodeRole, i.session.ServerID) return nil } @@ -273,7 +281,7 @@ func (i *IndexNode) SetEtcdClient(client *clientv3.Client) { } // GetComponentStates gets the component states of IndexNode. -func (i *IndexNode) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { +func (i *IndexNode) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) { log.RatedInfo(10, "get IndexNode components states ...") nodeID := common.NotRegisteredID if i.session != nil && i.session.Registered() { @@ -289,7 +297,7 @@ func (i *IndexNode) GetComponentStates(ctx context.Context) (*milvuspb.Component ret := &milvuspb.ComponentStates{ State: stateInfo, SubcomponentStates: nil, // todo add subcomponents states - Status: merr.Status(nil), + Status: merr.Success(), } log.RatedInfo(10, "IndexNode Component states", @@ -300,19 +308,19 @@ func (i *IndexNode) GetComponentStates(ctx context.Context) (*milvuspb.Component } // GetTimeTickChannel gets the time tick channel of IndexNode. -func (i *IndexNode) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (i *IndexNode) GetTimeTickChannel(ctx context.Context, req *internalpb.GetTimeTickChannelRequest) (*milvuspb.StringResponse, error) { log.RatedInfo(10, "get IndexNode time tick channel ...") return &milvuspb.StringResponse{ - Status: merr.Status(nil), + Status: merr.Success(), }, nil } // GetStatisticsChannel gets the statistics channel of IndexNode. 
-func (i *IndexNode) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (i *IndexNode) GetStatisticsChannel(ctx context.Context, req *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error) { log.RatedInfo(10, "get IndexNode statistics channel ...") return &milvuspb.StringResponse{ - Status: merr.Status(nil), + Status: merr.Success(), }, nil } @@ -322,17 +330,14 @@ func (i *IndexNode) GetNodeID() int64 { // ShowConfigurations returns the configurations of indexNode matching req.Pattern func (i *IndexNode) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { - if !i.lifetime.Add(commonpbutil.IsHealthyOrStopping) { + if err := i.lifetime.Add(merr.IsHealthyOrStopping); err != nil { log.Warn("IndexNode.ShowConfigurations failed", zap.Int64("nodeId", paramtable.GetNodeID()), zap.String("req", req.Pattern), - zap.Error(errIndexNodeIsUnhealthy(paramtable.GetNodeID()))) + zap.Error(err)) return &internalpb.ShowConfigurationsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: msgIndexNodeIsUnhealthy(paramtable.GetNodeID()), - }, + Status: merr.Status(err), Configuations: nil, }, nil } @@ -347,7 +352,7 @@ func (i *IndexNode) ShowConfigurations(ctx context.Context, req *internalpb.Show } return &internalpb.ShowConfigurationsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Configuations: configList, }, nil } diff --git a/internal/indexnode/indexnode_mock.go b/internal/indexnode/indexnode_mock.go index 76bd8bef8d21c..fc1b9249ccea2 100644 --- a/internal/indexnode/indexnode_mock.go +++ b/internal/indexnode/indexnode_mock.go @@ -85,16 +85,16 @@ func NewIndexNodeMock() *Mock { StateCode: commonpb.StateCode_Healthy, }, SubcomponentStates: nil, - Status: merr.Status(nil), + Status: merr.Success(), }, nil }, CallGetStatisticsChannel: func(ctx context.Context) (*milvuspb.StringResponse, error) { 
return &milvuspb.StringResponse{ - Status: merr.Status(nil), + Status: merr.Success(), }, nil }, CallCreateJob: func(ctx context.Context, req *indexpb.CreateJobRequest) (*commonpb.Status, error) { - return merr.Status(nil), nil + return merr.Success(), nil }, CallQueryJobs: func(ctx context.Context, in *indexpb.QueryJobsRequest) (*indexpb.QueryJobsResponse, error) { indexInfos := make([]*indexpb.IndexTaskInfo, 0) @@ -106,17 +106,17 @@ func NewIndexNodeMock() *Mock { }) } return &indexpb.QueryJobsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), ClusterID: in.ClusterID, IndexInfos: indexInfos, }, nil }, CallDropJobs: func(ctx context.Context, in *indexpb.DropJobsRequest) (*commonpb.Status, error) { - return merr.Status(nil), nil + return merr.Success(), nil }, CallGetJobStats: func(ctx context.Context, in *indexpb.GetJobStatsRequest) (*indexpb.GetJobStatsResponse, error) { return &indexpb.GetJobStatsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), TotalJobNum: 1, EnqueueJobNum: 0, InProgressJobNum: 1, @@ -137,7 +137,7 @@ func NewIndexNodeMock() *Mock { }, CallShowConfigurations: func(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { return &internalpb.ShowConfigurationsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), }, nil }, } @@ -155,11 +155,11 @@ func (m *Mock) Stop() error { return m.CallStop() } -func (m *Mock) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { +func (m *Mock) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) { return m.CallGetComponentStates(ctx) } -func (m *Mock) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (m *Mock) GetStatisticsChannel(ctx context.Context, req *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error) { return m.CallGetStatisticsChannel(ctx) } @@ -239,7 +239,7 @@ func 
getMockSystemInfoMetrics( resp, _ := metricsinfo.MarshalComponentInfos(nodeInfos) return &milvuspb.GetMetricsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Response: resp, ComponentName: metricsinfo.ConstructComponentName(typeutil.IndexNodeRole, paramtable.GetNodeID()), }, nil diff --git a/internal/indexnode/indexnode_service.go b/internal/indexnode/indexnode_service.go index 5cf58a3aec979..96f7f72ae93f3 100644 --- a/internal/indexnode/indexnode_service.go +++ b/internal/indexnode/indexnode_service.go @@ -33,7 +33,6 @@ import ( "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" - "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -42,54 +41,55 @@ import ( ) func (i *IndexNode) CreateJob(ctx context.Context, req *indexpb.CreateJobRequest) (*commonpb.Status, error) { - if !i.lifetime.Add(commonpbutil.IsHealthy) { - stateCode := i.lifetime.GetState() - log.Ctx(ctx).Warn("index node not ready", zap.String("state", stateCode.String()), zap.String("ClusterID", req.ClusterID), zap.Int64("IndexBuildID", req.BuildID)) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "state code is not healthy", - }, nil + log := log.Ctx(ctx).With( + zap.String("clusterID", req.GetClusterID()), + zap.Int64("indexBuildID", req.GetBuildID()), + ) + + if err := i.lifetime.Add(merr.IsHealthy); err != nil { + log.Warn("index node not ready", + zap.Error(err), + ) + return merr.Status(err), nil } defer i.lifetime.Done() - log.Ctx(ctx).Info("IndexNode building index ...", - zap.String("ClusterID", req.ClusterID), - zap.Int64("IndexBuildID", req.BuildID), - zap.Int64("IndexID", req.IndexID), - zap.String("IndexName", req.IndexName), - zap.String("IndexFilePrefix", req.IndexFilePrefix), - zap.Int64("IndexVersion", req.IndexVersion), 
- zap.Strings("DataPaths", req.DataPaths), - zap.Any("TypeParams", req.TypeParams), - zap.Any("IndexParams", req.IndexParams), - zap.Int64("num_rows", req.GetNumRows())) + log.Info("IndexNode building index ...", + zap.Int64("indexID", req.GetIndexID()), + zap.String("indexName", req.GetIndexName()), + zap.String("indexFilePrefix", req.GetIndexFilePrefix()), + zap.Int64("indexVersion", req.GetIndexVersion()), + zap.Strings("dataPaths", req.GetDataPaths()), + zap.Any("typeParams", req.GetTypeParams()), + zap.Any("indexParams", req.GetIndexParams()), + zap.Int64("numRows", req.GetNumRows()), + zap.Int32("current_index_version", req.GetCurrentIndexVersion()), + ) ctx, sp := otel.Tracer(typeutil.IndexNodeRole).Start(ctx, "IndexNode-CreateIndex", trace.WithAttributes( - attribute.Int64("IndexBuildID", req.BuildID), - attribute.String("ClusterID", req.ClusterID), + attribute.Int64("indexBuildID", req.GetBuildID()), + attribute.String("clusterID", req.GetClusterID()), )) defer sp.End() metrics.IndexNodeBuildIndexTaskCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.TotalLabel).Inc() taskCtx, taskCancel := context.WithCancel(i.loopCtx) - if oldInfo := i.loadOrStoreTask(req.ClusterID, req.BuildID, &taskInfo{ + if oldInfo := i.loadOrStoreTask(req.GetClusterID(), req.GetBuildID(), &taskInfo{ cancel: taskCancel, - state: commonpb.IndexState_InProgress}); oldInfo != nil { - log.Ctx(ctx).Warn("duplicated index build task", zap.String("ClusterID", req.ClusterID), zap.Int64("BuildID", req.BuildID)) + state: commonpb.IndexState_InProgress, + }); oldInfo != nil { + err := merr.WrapErrIndexDuplicate(req.GetIndexName(), "building index task existed") + log.Warn("duplicated index build task", zap.Error(err)) metrics.IndexNodeBuildIndexTaskCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.FailLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_BuildIndexError, - Reason: "duplicated index build task", - }, nil + return 
merr.Status(err), nil } - cm, err := i.storageFactory.NewChunkManager(i.loopCtx, req.StorageConfig) + cm, err := i.storageFactory.NewChunkManager(i.loopCtx, req.GetStorageConfig()) if err != nil { - log.Ctx(ctx).Error("create chunk manager failed", zap.String("Bucket", req.StorageConfig.BucketName), - zap.String("AccessKey", req.StorageConfig.AccessKeyID), - zap.String("ClusterID", req.ClusterID), zap.Int64("IndexBuildID", req.BuildID)) + log.Error("create chunk manager failed", zap.String("bucket", req.GetStorageConfig().GetBucketName()), + zap.String("accessKey", req.GetStorageConfig().GetAccessKeyID()), + zap.Error(err), + ) + i.deleteTaskInfos(ctx, []taskKey{{ClusterID: req.GetClusterID(), BuildID: req.GetBuildID()}}) metrics.IndexNodeBuildIndexTaskCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.FailLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_BuildIndexError, - Reason: "create chunk manager failed", - }, nil + return merr.Status(err), nil } var task task if Params.CommonCfg.EnableStorageV2.GetAsBool() { @@ -98,8 +98,8 @@ func (i *IndexNode) CreateJob(ctx context.Context, req *indexpb.CreateJobRequest ident: fmt.Sprintf("%s/%d", req.ClusterID, req.BuildID), ctx: taskCtx, cancel: taskCancel, - BuildID: req.BuildID, - ClusterID: req.ClusterID, + BuildID: req.GetBuildID(), + ClusterID: req.GetClusterID(), node: i, req: req, cm: cm, @@ -113,8 +113,8 @@ func (i *IndexNode) CreateJob(ctx context.Context, req *indexpb.CreateJobRequest ident: fmt.Sprintf("%s/%d", req.ClusterID, req.BuildID), ctx: taskCtx, cancel: taskCancel, - BuildID: req.BuildID, - ClusterID: req.ClusterID, + BuildID: req.GetBuildID(), + ClusterID: req.GetClusterID(), node: i, req: req, cm: cm, @@ -123,51 +123,49 @@ func (i *IndexNode) CreateJob(ctx context.Context, req *indexpb.CreateJobRequest serializedSize: 0, } } - ret := merr.Status(nil) + ret := merr.Success() if err := i.sched.IndexBuildQueue.Enqueue(task); err != nil { - 
log.Ctx(ctx).Warn("IndexNode failed to schedule", zap.Int64("IndexBuildID", req.BuildID), zap.String("ClusterID", req.ClusterID), zap.Error(err)) - ret.ErrorCode = commonpb.ErrorCode_UnexpectedError - ret.Reason = err.Error() + log.Warn("IndexNode failed to schedule", + zap.Error(err)) + ret = merr.Status(err) metrics.IndexNodeBuildIndexTaskCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.FailLabel).Inc() return ret, nil } metrics.IndexNodeBuildIndexTaskCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SuccessLabel).Inc() - log.Ctx(ctx).Info("IndexNode successfully scheduled", zap.Int64("IndexBuildID", req.BuildID), zap.String("ClusterID", req.ClusterID), zap.String("indexName", req.IndexName)) + log.Info("IndexNode successfully scheduled", + zap.String("indexName", req.GetIndexName())) return ret, nil } func (i *IndexNode) QueryJobs(ctx context.Context, req *indexpb.QueryJobsRequest) (*indexpb.QueryJobsResponse, error) { log := log.Ctx(ctx).With( - zap.String("ClusterID", req.GetClusterID()), + zap.String("clusterID", req.GetClusterID()), ).WithRateGroup("in.queryJobs", 1, 60) - if !i.lifetime.Add(commonpbutil.IsHealthyOrStopping) { - stateCode := i.lifetime.GetState() - log.Warn("index node not ready", zap.String("state", stateCode.String())) + if err := i.lifetime.Add(merr.IsHealthyOrStopping); err != nil { + log.Warn("index node not ready", zap.Error(err)) return &indexpb.QueryJobsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "state code is not healthy", - }, + Status: merr.Status(err), }, nil } defer i.lifetime.Done() infos := make(map[UniqueID]*taskInfo) i.foreachTaskInfo(func(ClusterID string, buildID UniqueID, info *taskInfo) { - if ClusterID == req.ClusterID { + if ClusterID == req.GetClusterID() { infos[buildID] = &taskInfo{ - state: info.state, - fileKeys: common.CloneStringList(info.fileKeys), - serializedSize: info.serializedSize, - failReason: 
info.failReason, + state: info.state, + fileKeys: common.CloneStringList(info.fileKeys), + serializedSize: info.serializedSize, + failReason: info.failReason, + currentIndexVersion: info.currentIndexVersion, } } }) ret := &indexpb.QueryJobsResponse{ - Status: merr.Status(nil), - ClusterID: req.ClusterID, - IndexInfos: make([]*indexpb.IndexTaskInfo, 0, len(req.BuildIDs)), + Status: merr.Success(), + ClusterID: req.GetClusterID(), + IndexInfos: make([]*indexpb.IndexTaskInfo, 0, len(req.GetBuildIDs())), } - for i, buildID := range req.BuildIDs { + for i, buildID := range req.GetBuildIDs() { ret.IndexInfos = append(ret.IndexInfos, &indexpb.IndexTaskInfo{ BuildID: buildID, State: commonpb.IndexState_IndexStateNone, @@ -179,49 +177,47 @@ func (i *IndexNode) QueryJobs(ctx context.Context, req *indexpb.QueryJobsRequest ret.IndexInfos[i].IndexFileKeys = info.fileKeys ret.IndexInfos[i].SerializedSize = info.serializedSize ret.IndexInfos[i].FailReason = info.failReason + ret.IndexInfos[i].CurrentIndexVersion = info.currentIndexVersion log.RatedDebug(5, "querying index build task", - zap.Int64("IndexBuildID", buildID), zap.String("state", info.state.String()), - zap.String("fail reason", info.failReason)) + zap.Int64("indexBuildID", buildID), + zap.String("state", info.state.String()), + zap.String("reason", info.failReason), + ) } } return ret, nil } func (i *IndexNode) DropJobs(ctx context.Context, req *indexpb.DropJobsRequest) (*commonpb.Status, error) { - log.Ctx(ctx).Info("drop index build jobs", zap.String("ClusterID", req.ClusterID), zap.Int64s("IndexBuildIDs", req.BuildIDs)) - if !i.lifetime.Add(commonpbutil.IsHealthyOrStopping) { - stateCode := i.lifetime.GetState() - log.Ctx(ctx).Warn("index node not ready", zap.String("state", stateCode.String()), zap.String("ClusterID", req.ClusterID)) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "state code is not healthy", - }, nil + log.Ctx(ctx).Info("drop index build jobs", + 
zap.String("clusterID", req.ClusterID), + zap.Int64s("indexBuildIDs", req.BuildIDs), + ) + if err := i.lifetime.Add(merr.IsHealthyOrStopping); err != nil { + log.Ctx(ctx).Warn("index node not ready", zap.Error(err), zap.String("clusterID", req.ClusterID)) + return merr.Status(err), nil } defer i.lifetime.Done() - keys := make([]taskKey, 0, len(req.BuildIDs)) - for _, buildID := range req.BuildIDs { - keys = append(keys, taskKey{ClusterID: req.ClusterID, BuildID: buildID}) + keys := make([]taskKey, 0, len(req.GetBuildIDs())) + for _, buildID := range req.GetBuildIDs() { + keys = append(keys, taskKey{ClusterID: req.GetClusterID(), BuildID: buildID}) } - infos := i.deleteTaskInfos(keys) + infos := i.deleteTaskInfos(ctx, keys) for _, info := range infos { if info.cancel != nil { info.cancel() } } - log.Ctx(ctx).Info("drop index build jobs success", zap.String("ClusterID", req.ClusterID), - zap.Int64s("IndexBuildIDs", req.BuildIDs)) - return merr.Status(nil), nil + log.Ctx(ctx).Info("drop index build jobs success", zap.String("clusterID", req.GetClusterID()), + zap.Int64s("indexBuildIDs", req.GetBuildIDs())) + return merr.Success(), nil } func (i *IndexNode) GetJobStats(ctx context.Context, req *indexpb.GetJobStatsRequest) (*indexpb.GetJobStatsResponse, error) { - if !i.lifetime.Add(commonpbutil.IsHealthyOrStopping) { - stateCode := i.lifetime.GetState() - log.Ctx(ctx).Warn("index node not ready", zap.String("state", stateCode.String())) + if err := i.lifetime.Add(merr.IsHealthyOrStopping); err != nil { + log.Ctx(ctx).Warn("index node not ready", zap.Error(err)) return &indexpb.GetJobStatsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "state code is not healthy", - }, + Status: merr.Status(err), }, nil } defer i.lifetime.Done() @@ -236,9 +232,13 @@ func (i *IndexNode) GetJobStats(ctx context.Context, req *indexpb.GetJobStatsReq if i.sched.buildParallel > unissued+active { slots = i.sched.buildParallel - unissued - active 
} - log.Ctx(ctx).Info("Get Index Job Stats", zap.Int("Unissued", unissued), zap.Int("Active", active), zap.Int("Slot", slots)) + log.Ctx(ctx).Info("Get Index Job Stats", + zap.Int("unissued", unissued), + zap.Int("active", active), + zap.Int("slot", slots), + ) return &indexpb.GetJobStatsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), TotalJobNum: int64(active) + int64(unissued), InProgressJobNum: int64(active), EnqueueJobNum: int64(unissued), @@ -251,35 +251,27 @@ func (i *IndexNode) GetJobStats(ctx context.Context, req *indexpb.GetJobStatsReq // GetMetrics gets the metrics info of IndexNode. // TODO(dragondriver): cache the Metrics and set a retention to the cache func (i *IndexNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - if !i.lifetime.Add(commonpbutil.IsHealthyOrStopping) { + if err := i.lifetime.Add(merr.IsHealthyOrStopping); err != nil { log.Ctx(ctx).Warn("IndexNode.GetMetrics failed", zap.Int64("nodeID", paramtable.GetNodeID()), - zap.String("req", req.Request), - zap.Error(errIndexNodeIsUnhealthy(paramtable.GetNodeID()))) + zap.String("req", req.GetRequest()), + zap.Error(err)) return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: msgIndexNodeIsUnhealthy(paramtable.GetNodeID()), - }, - Response: "", + Status: merr.Status(err), }, nil } defer i.lifetime.Done() - metricType, err := metricsinfo.ParseMetricType(req.Request) + metricType, err := metricsinfo.ParseMetricType(req.GetRequest()) if err != nil { log.Ctx(ctx).Warn("IndexNode.GetMetrics failed to parse metric type", zap.Int64("nodeID", paramtable.GetNodeID()), - zap.String("req", req.Request), + zap.String("req", req.GetRequest()), zap.Error(err)) return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, - Response: "", + Status: merr.Status(err), }, nil } @@ -288,8 
+280,8 @@ func (i *IndexNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequ log.Ctx(ctx).RatedDebug(60, "IndexNode.GetMetrics", zap.Int64("nodeID", paramtable.GetNodeID()), - zap.String("req", req.Request), - zap.String("metric_type", metricType), + zap.String("req", req.GetRequest()), + zap.String("metricType", metricType), zap.Error(err)) return metrics, nil @@ -297,14 +289,10 @@ func (i *IndexNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequ log.Ctx(ctx).RatedWarn(60, "IndexNode.GetMetrics failed, request metric type is not implemented yet", zap.Int64("nodeID", paramtable.GetNodeID()), - zap.String("req", req.Request), - zap.String("metric_type", metricType)) + zap.String("req", req.GetRequest()), + zap.String("metricType", metricType)) return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: metricsinfo.MsgUnimplementedMetric, - }, - Response: "", + Status: merr.Status(merr.WrapErrMetricNotFound(metricType)), }, nil } diff --git a/internal/indexnode/indexnode_service_test.go b/internal/indexnode/indexnode_service_test.go index 1b3bab174f6db..255551d3e2fa6 100644 --- a/internal/indexnode/indexnode_service_test.go +++ b/internal/indexnode/indexnode_service_test.go @@ -26,6 +26,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" ) @@ -36,27 +37,27 @@ func TestAbnormalIndexNode(t *testing.T) { ctx := context.TODO() status, err := in.CreateJob(ctx, &indexpb.CreateJobRequest{}) assert.NoError(t, err) - assert.Equal(t, status.ErrorCode, commonpb.ErrorCode_UnexpectedError) + assert.ErrorIs(t, merr.Error(status), merr.ErrServiceNotReady) qresp, err := in.QueryJobs(ctx, &indexpb.QueryJobsRequest{}) assert.NoError(t, err) - assert.Equal(t, 
qresp.Status.ErrorCode, commonpb.ErrorCode_UnexpectedError) + assert.ErrorIs(t, merr.Error(qresp.GetStatus()), merr.ErrServiceNotReady) status, err = in.DropJobs(ctx, &indexpb.DropJobsRequest{}) assert.NoError(t, err) - assert.Equal(t, status.ErrorCode, commonpb.ErrorCode_UnexpectedError) + assert.ErrorIs(t, merr.Error(status), merr.ErrServiceNotReady) jobNumRsp, err := in.GetJobStats(ctx, &indexpb.GetJobStatsRequest{}) assert.NoError(t, err) - assert.Equal(t, jobNumRsp.Status.ErrorCode, commonpb.ErrorCode_UnexpectedError) + assert.ErrorIs(t, merr.Error(jobNumRsp.GetStatus()), merr.ErrServiceNotReady) metricsResp, err := in.GetMetrics(ctx, &milvuspb.GetMetricsRequest{}) - assert.NoError(t, err) - assert.Equal(t, metricsResp.Status.ErrorCode, commonpb.ErrorCode_UnexpectedError) + err = merr.CheckRPCCall(metricsResp, err) + assert.ErrorIs(t, err, merr.ErrServiceNotReady) configurationResp, err := in.ShowConfigurations(ctx, &internalpb.ShowConfigurationsRequest{}) - assert.NoError(t, err) - assert.Equal(t, configurationResp.Status.ErrorCode, commonpb.ErrorCode_UnexpectedError) + err = merr.CheckRPCCall(configurationResp, err) + assert.ErrorIs(t, err, merr.ErrServiceNotReady) } func TestGetMetrics(t *testing.T) { @@ -69,14 +70,12 @@ func TestGetMetrics(t *testing.T) { defer in.Stop() resp, err := in.GetMetrics(ctx, metricReq) assert.NoError(t, err) - assert.Equal(t, resp.Status.ErrorCode, commonpb.ErrorCode_Success) + assert.True(t, merr.Ok(resp.GetStatus())) t.Logf("Component: %s, Metrics: %s", resp.ComponentName, resp.Response) } func TestGetMetricsError(t *testing.T) { - var ( - ctx = context.TODO() - ) + ctx := context.TODO() in, err := NewMockIndexNodeComponent(ctx) assert.NoError(t, err) @@ -86,15 +85,14 @@ func TestGetMetricsError(t *testing.T) { } resp, err := in.GetMetrics(ctx, errReq) assert.NoError(t, err) - assert.Equal(t, resp.Status.ErrorCode, commonpb.ErrorCode_UnexpectedError) + assert.Equal(t, resp.GetStatus().GetErrorCode(), 
commonpb.ErrorCode_UnexpectedError) unsupportedReq := &milvuspb.GetMetricsRequest{ Request: `{"metric_type": "application_info"}`, } resp, err = in.GetMetrics(ctx, unsupportedReq) assert.NoError(t, err) - assert.Equal(t, resp.Status.ErrorCode, commonpb.ErrorCode_UnexpectedError) - assert.Equal(t, resp.Status.Reason, metricsinfo.MsgUnimplementedMetric) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrMetricNotFound) } func TestMockFieldData(t *testing.T) { diff --git a/internal/indexnode/indexnode_test.go b/internal/indexnode/indexnode_test.go index a10c3eb5870c0..5c2ff6ccebac0 100644 --- a/internal/indexnode/indexnode_test.go +++ b/internal/indexnode/indexnode_test.go @@ -22,9 +22,10 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/stretchr/testify/assert" ) //func TestRegister(t *testing.T) { @@ -404,20 +405,20 @@ import ( // t.Run("GetComponentStates", func(t *testing.T) { // resp, err := in.GetComponentStates(ctx) // assert.NoError(t, err) -// assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) +// assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) // assert.Equal(t, commonpb.StateCode_Healthy, resp.State.StateCode) // }) // // t.Run("GetTimeTickChannel", func(t *testing.T) { // resp, err := in.GetTimeTickChannel(ctx) // assert.NoError(t, err) -// assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) +// assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) // }) // // t.Run("GetStatisticsChannel", func(t *testing.T) { // resp, err := in.GetStatisticsChannel(ctx) // assert.NoError(t, err) -// assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) +// assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) // }) // // t.Run("ShowConfigurations", func(t *testing.T) { @@ -432,7 +433,7 @@ import ( // // resp, 
err := in.ShowConfigurations(ctx, req) // assert.NoError(t, err) -// assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) +// assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) // assert.Equal(t, 1, len(resp.Configuations)) // assert.Equal(t, "indexnode.port", resp.Configuations[0].Key) // }) @@ -464,28 +465,28 @@ func TestComponentState(t *testing.T) { paramtable.Init() in := NewIndexNode(ctx, factory) in.SetEtcdClient(getEtcdClient()) - state, err := in.GetComponentStates(ctx) + state, err := in.GetComponentStates(ctx, nil) assert.NoError(t, err) - assert.Equal(t, state.Status.ErrorCode, commonpb.ErrorCode_Success) + assert.Equal(t, state.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success) assert.Equal(t, state.State.StateCode, commonpb.StateCode_Abnormal) assert.Nil(t, in.Init()) - state, err = in.GetComponentStates(ctx) + state, err = in.GetComponentStates(ctx, nil) assert.NoError(t, err) - assert.Equal(t, state.Status.ErrorCode, commonpb.ErrorCode_Success) + assert.Equal(t, state.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success) assert.Equal(t, state.State.StateCode, commonpb.StateCode_Initializing) assert.Nil(t, in.Start()) - state, err = in.GetComponentStates(ctx) + state, err = in.GetComponentStates(ctx, nil) assert.NoError(t, err) - assert.Equal(t, state.Status.ErrorCode, commonpb.ErrorCode_Success) + assert.Equal(t, state.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success) assert.Equal(t, state.State.StateCode, commonpb.StateCode_Healthy) assert.Nil(t, in.Stop()) assert.Nil(t, in.Stop()) - state, err = in.GetComponentStates(ctx) + state, err = in.GetComponentStates(ctx, nil) assert.NoError(t, err) - assert.Equal(t, state.Status.ErrorCode, commonpb.ErrorCode_Success) + assert.Equal(t, state.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success) assert.Equal(t, state.State.StateCode, commonpb.StateCode_Abnormal) } @@ -498,9 +499,9 @@ func TestGetTimeTickChannel(t *testing.T) { ) paramtable.Init() in := 
NewIndexNode(ctx, factory) - ret, err := in.GetTimeTickChannel(ctx) + ret, err := in.GetTimeTickChannel(ctx, nil) assert.NoError(t, err) - assert.Equal(t, ret.Status.ErrorCode, commonpb.ErrorCode_Success) + assert.Equal(t, ret.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success) } func TestGetStatisticChannel(t *testing.T) { @@ -513,9 +514,9 @@ func TestGetStatisticChannel(t *testing.T) { paramtable.Init() in := NewIndexNode(ctx, factory) - ret, err := in.GetStatisticsChannel(ctx) + ret, err := in.GetStatisticsChannel(ctx, nil) assert.NoError(t, err) - assert.Equal(t, ret.Status.ErrorCode, commonpb.ErrorCode_Success) + assert.Equal(t, ret.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success) } func TestIndexTaskWhenStoppingNode(t *testing.T) { diff --git a/internal/indexnode/metrics_info.go b/internal/indexnode/metrics_info.go index e91708055938e..2eb6eef0493f7 100644 --- a/internal/indexnode/metrics_info.go +++ b/internal/indexnode/metrics_info.go @@ -19,7 +19,6 @@ package indexnode import ( "context" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/pkg/util/hardware" "github.com/milvus-io/milvus/pkg/util/merr" @@ -64,17 +63,14 @@ func getSystemInfoMetrics( resp, err := metricsinfo.MarshalComponentInfos(nodeInfos) if err != nil { return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), Response: "", ComponentName: metricsinfo.ConstructComponentName(typeutil.IndexNodeRole, paramtable.GetNodeID()), }, nil } return &milvuspb.GetMetricsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Response: resp, ComponentName: metricsinfo.ConstructComponentName(typeutil.IndexNodeRole, paramtable.GetNodeID()), }, nil diff --git a/internal/indexnode/task.go b/internal/indexnode/task.go index 32aba92623bf6..212f3f842cec3 100644 --- a/internal/indexnode/task.go 
+++ b/internal/indexnode/task.go @@ -40,6 +40,7 @@ import ( "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/indexparams" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/timerecord" ) @@ -52,11 +53,12 @@ var ( type Blob = storage.Blob type taskInfo struct { - cancel context.CancelFunc - state commonpb.IndexState - fileKeys []string - serializedSize uint64 - failReason string + cancel context.CancelFunc + state commonpb.IndexState + fileKeys []string + serializedSize uint64 + failReason string + currentIndexVersion int32 // task statistics statistic *indexpb.JobInfo @@ -206,28 +208,29 @@ type indexBuildTask struct { cancel context.CancelFunc ctx context.Context - cm storage.ChunkManager - index indexcgowrapper.CodecIndex - savePaths []string - req *indexpb.CreateJobRequest - BuildID UniqueID - nodeID UniqueID - ClusterID string - collectionID UniqueID - partitionID UniqueID - segmentID UniqueID - fieldID UniqueID - fieldName string - fieldType schemapb.DataType - fieldData storage.FieldData - indexBlobs []*storage.Blob - newTypeParams map[string]string - newIndexParams map[string]string - serializedSize uint64 - tr *timerecord.TimeRecorder - queueDur time.Duration - statistic indexpb.JobInfo - node *IndexNode + cm storage.ChunkManager + index indexcgowrapper.CodecIndex + savePaths []string + req *indexpb.CreateJobRequest + currentIndexVersion int32 + BuildID UniqueID + nodeID UniqueID + ClusterID string + collectionID UniqueID + partitionID UniqueID + segmentID UniqueID + fieldID UniqueID + fieldName string + fieldType schemapb.DataType + fieldData storage.FieldData + indexBlobs []*storage.Blob + newTypeParams map[string]string + newIndexParams map[string]string + serializedSize uint64 + tr *timerecord.TimeRecorder + queueDur time.Duration + statistic indexpb.JobInfo + node *IndexNode } func 
(it *indexBuildTask) Reset() { @@ -313,8 +316,8 @@ func (it *indexBuildTask) LoadData(ctx context.Context) error { getValueByPath := func(path string) ([]byte, error) { data, err := it.cm.Read(ctx, path) if err != nil { - if errors.Is(err, ErrNoSuchKey) { - return nil, ErrNoSuchKey + if errors.Is(err, merr.ErrIoKeyNotFound) { + return nil, err } return nil, err } @@ -354,7 +357,7 @@ func (it *indexBuildTask) LoadData(ctx context.Context) error { } loadFieldDataLatency := it.tr.CtxRecord(ctx, "load field data done") - metrics.IndexNodeLoadFieldLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(loadFieldDataLatency.Milliseconds())) + metrics.IndexNodeLoadFieldLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(loadFieldDataLatency.Seconds()) err = it.decodeBlobs(ctx, blobs) if err != nil { @@ -463,6 +466,12 @@ func (it *indexBuildTask) BuildIndex(ctx context.Context) error { } } + it.currentIndexVersion = getCurrentIndexVersion(it.req.GetCurrentIndexVersion()) + if err := buildIndexInfo.AppendIndexEngineVersion(it.currentIndexVersion); err != nil { + log.Ctx(ctx).Warn("append index engine version failed", zap.Error(err)) + return err + } + it.index, err = indexcgowrapper.CreateIndex(ctx, buildIndexInfo) if err != nil { if it.index != nil && it.index.CleanLocalData() != nil { @@ -475,14 +484,13 @@ func (it *indexBuildTask) BuildIndex(ctx context.Context) error { } buildIndexLatency := it.tr.RecordSpan() - metrics.IndexNodeKnowhereBuildIndexLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(buildIndexLatency.Milliseconds())) + metrics.IndexNodeKnowhereBuildIndexLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(buildIndexLatency.Seconds()) log.Ctx(ctx).Info("Successfully build index", zap.Int64("buildID", it.BuildID), zap.Int64("Collection", it.collectionID), zap.Int64("SegmentID", it.segmentID)) return nil } func (it *indexBuildTask) 
SaveIndexFiles(ctx context.Context) error { - gcIndex := func() { if err := it.index.Delete(); err != nil { log.Ctx(ctx).Error("IndexNode indexBuildTask Execute CIndexDelete failed", zap.Error(err)) @@ -495,7 +503,7 @@ func (it *indexBuildTask) SaveIndexFiles(ctx context.Context) error { return err } encodeIndexFileDur := it.tr.Record("index serialize and upload done") - metrics.IndexNodeEncodeIndexFileLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(encodeIndexFileDur.Milliseconds())) + metrics.IndexNodeEncodeIndexFileLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(encodeIndexFileDur.Seconds()) // early release index for gc, and we can ensure that Delete is idempotent. gcIndex() @@ -511,10 +519,10 @@ func (it *indexBuildTask) SaveIndexFiles(ctx context.Context) error { } it.statistic.EndTime = time.Now().UnixMicro() - it.node.storeIndexFilesAndStatistic(it.ClusterID, it.BuildID, saveFileKeys, it.serializedSize, &it.statistic) + it.node.storeIndexFilesAndStatistic(it.ClusterID, it.BuildID, saveFileKeys, it.serializedSize, &it.statistic, it.currentIndexVersion) log.Ctx(ctx).Debug("save index files done", zap.Strings("IndexFiles", saveFileKeys)) saveIndexFileDur := it.tr.RecordSpan() - metrics.IndexNodeSaveIndexFileLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(saveIndexFileDur.Milliseconds())) + metrics.IndexNodeSaveIndexFileLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(saveIndexFileDur.Seconds()) it.tr.Elapse("index building all done") log.Ctx(ctx).Info("Successfully save index files", zap.Int64("buildID", it.BuildID), zap.Int64("Collection", it.collectionID), zap.Int64("partition", it.partitionID), zap.Int64("SegmentId", it.segmentID)) @@ -524,12 +532,12 @@ func (it *indexBuildTask) SaveIndexFiles(ctx context.Context) error { func (it *indexBuildTask) parseFieldMetaFromBinlog(ctx context.Context) error { toLoadDataPaths := 
it.req.GetDataPaths() if len(toLoadDataPaths) == 0 { - return ErrEmptyInsertPaths + return merr.WrapErrParameterInvalidMsg("data insert path must be not empty") } data, err := it.cm.Read(ctx, toLoadDataPaths[0]) if err != nil { - if errors.Is(err, ErrNoSuchKey) { - return ErrNoSuchKey + if errors.Is(err, merr.ErrIoKeyNotFound) { + return err } return err } @@ -540,7 +548,7 @@ func (it *indexBuildTask) parseFieldMetaFromBinlog(ctx context.Context) error { return err } if len(insertData.Data) != 1 { - return errors.New("we expect only one field in deserialized insert data") + return merr.WrapErrParameterInvalidMsg("we expect only one field in deserialized insert data") } it.collectionID = collectionID @@ -561,11 +569,10 @@ func (it *indexBuildTask) decodeBlobs(ctx context.Context, blobs []*storage.Blob if err2 != nil { return err2 } - decodeDuration := it.tr.RecordSpan().Milliseconds() - metrics.IndexNodeDecodeFieldLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(decodeDuration)) + metrics.IndexNodeDecodeFieldLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(it.tr.RecordSpan().Seconds()) if len(insertData.Data) != 1 { - return errors.New("we expect only one field in deserialized insert data") + return merr.WrapErrParameterInvalidMsg("we expect only one field in deserialized insert data") } it.collectionID = collectionID it.partitionID = partitionID diff --git a/internal/indexnode/task_scheduler.go b/internal/indexnode/task_scheduler.go index 8beccd7792d62..ab7b9e6ed6b02 100644 --- a/internal/indexnode/task_scheduler.go +++ b/internal/indexnode/task_scheduler.go @@ -24,12 +24,12 @@ import ( "sync" "github.com/cockroachdb/errors" - "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -221,10 +221,10 @@ 
func (sched *TaskScheduler) processTask(t task, q TaskQueue) { pipelines := []func(context.Context) error{t.Prepare, t.BuildIndex, t.SaveIndexFiles} for _, fn := range pipelines { if err := wrap(fn); err != nil { - if err == errCancel { - log.Ctx(t.Ctx()).Warn("index build task canceled", zap.String("task", t.Name())) - t.SetState(commonpb.IndexState_Failed, err.Error()) - } else if errors.Is(err, ErrNoSuchKey) { + if errors.Is(err, errCancel) { + log.Ctx(t.Ctx()).Warn("index build task canceled, retry it", zap.String("task", t.Name())) + t.SetState(commonpb.IndexState_Retry, err.Error()) + } else if errors.Is(err, merr.ErrIoKeyNotFound) { t.SetState(commonpb.IndexState_Failed, err.Error()) } else { t.SetState(commonpb.IndexState_Retry, err.Error()) @@ -234,7 +234,7 @@ func (sched *TaskScheduler) processTask(t task, q TaskQueue) { } t.SetState(commonpb.IndexState_Finished, "") if indexBuildTask, ok := t.(*indexBuildTask); ok { - metrics.IndexNodeBuildIndexLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(indexBuildTask.tr.ElapseSpan().Milliseconds())) + metrics.IndexNodeBuildIndexLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(indexBuildTask.tr.ElapseSpan().Seconds()) metrics.IndexNodeIndexTaskLatencyInQueue.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(indexBuildTask.queueDur.Milliseconds())) } } diff --git a/internal/indexnode/task_scheduler_test.go b/internal/indexnode/task_scheduler_test.go index 46b6b9ac5235b..2393fd2b7e1b7 100644 --- a/internal/indexnode/task_scheduler_test.go +++ b/internal/indexnode/task_scheduler_test.go @@ -7,9 +7,10 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/stretchr/testify/assert" ) type fakeTaskState int @@ -164,9 +165,9 @@ func TestIndexTaskScheduler(t *testing.T) { tasks := make([]task, 0) tasks = append(tasks, - 
newTask(fakeTaskEnqueued, nil, commonpb.IndexState_Failed), - newTask(fakeTaskPrepared, nil, commonpb.IndexState_Failed), - newTask(fakeTaskBuiltIndex, nil, commonpb.IndexState_Failed), + newTask(fakeTaskEnqueued, nil, commonpb.IndexState_Retry), + newTask(fakeTaskPrepared, nil, commonpb.IndexState_Retry), + newTask(fakeTaskBuiltIndex, nil, commonpb.IndexState_Retry), newTask(fakeTaskSavedIndexes, nil, commonpb.IndexState_Finished), newTask(fakeTaskSavedIndexes, map[fakeTaskState]error{fakeTaskSavedIndexes: fmt.Errorf("auth failed")}, commonpb.IndexState_Retry)) diff --git a/internal/indexnode/taskinfo_ops.go b/internal/indexnode/taskinfo_ops.go index 63607e6dddcc4..7a0680efa3b4c 100644 --- a/internal/indexnode/taskinfo_ops.go +++ b/internal/indexnode/taskinfo_ops.go @@ -5,11 +5,12 @@ import ( "time" "github.com/golang/protobuf/proto" + "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" - "go.uber.org/zap" ) func (i *IndexNode) loadOrStoreTask(ClusterID string, buildID UniqueID, info *taskInfo) *taskInfo { @@ -55,7 +56,14 @@ func (i *IndexNode) foreachTaskInfo(fn func(ClusterID string, buildID UniqueID, } } -func (i *IndexNode) storeIndexFilesAndStatistic(ClusterID string, buildID UniqueID, fileKeys []string, serializedSize uint64, statistic *indexpb.JobInfo) { +func (i *IndexNode) storeIndexFilesAndStatistic( + ClusterID string, + buildID UniqueID, + fileKeys []string, + serializedSize uint64, + statistic *indexpb.JobInfo, + currentIndexVersion int32, +) { key := taskKey{ClusterID: ClusterID, BuildID: buildID} i.stateLock.Lock() defer i.stateLock.Unlock() @@ -63,11 +71,12 @@ func (i *IndexNode) storeIndexFilesAndStatistic(ClusterID string, buildID Unique info.fileKeys = common.CloneStringList(fileKeys) info.serializedSize = serializedSize info.statistic = proto.Clone(statistic).(*indexpb.JobInfo) + 
info.currentIndexVersion = currentIndexVersion return } } -func (i *IndexNode) deleteTaskInfos(keys []taskKey) []*taskInfo { +func (i *IndexNode) deleteTaskInfos(ctx context.Context, keys []taskKey) []*taskInfo { i.stateLock.Lock() defer i.stateLock.Unlock() deleted := make([]*taskInfo, 0, len(keys)) @@ -76,6 +85,8 @@ func (i *IndexNode) deleteTaskInfos(keys []taskKey) []*taskInfo { if ok { deleted = append(deleted, info) delete(i.tasks, key) + log.Ctx(ctx).Info("delete task infos", + zap.String("cluster_id", key.ClusterID), zap.Int64("build_id", key.BuildID)) } } return deleted diff --git a/internal/indexnode/util.go b/internal/indexnode/util.go index 0ddab2d80bffe..07f41f8a048ce 100644 --- a/internal/indexnode/util.go +++ b/internal/indexnode/util.go @@ -31,6 +31,8 @@ func estimateFieldDataSize(dim int64, numRows int64, dataType schemapb.DataType) if dataType == schemapb.DataType_BinaryVector { return uint64(dim) / 8 * uint64(numRows), nil } - + if dataType == schemapb.DataType_Float16Vector { + return uint64(dim) * uint64(numRows) * 2, nil + } return 0, nil } diff --git a/internal/kv/etcd/embed_etcd_config_test.go b/internal/kv/etcd/embed_etcd_config_test.go index f75aff241f12b..1a121f6a09bf9 100644 --- a/internal/kv/etcd/embed_etcd_config_test.go +++ b/internal/kv/etcd/embed_etcd_config_test.go @@ -37,7 +37,7 @@ func TestEtcdConfigLoad(te *testing.T) { te.Setenv("etcd.data.dir", "etcd.test.data.dir") param.Init(paramtable.NewBaseTable()) - //clean up data + // clean up data defer func() { os.RemoveAll("etcd.test.data.dir") }() diff --git a/internal/kv/etcd/embed_etcd_kv.go b/internal/kv/etcd/embed_etcd_kv.go index 91cf404397773..668e1ba2daba3 100644 --- a/internal/kv/etcd/embed_etcd_kv.go +++ b/internal/kv/etcd/embed_etcd_kv.go @@ -30,8 +30,9 @@ import ( "go.uber.org/zap" "github.com/milvus-io/milvus/internal/kv" - "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/internal/kv/predicates" "github.com/milvus-io/milvus/pkg/log" + 
"github.com/milvus-io/milvus/pkg/util/merr" ) // implementation assertion @@ -60,7 +61,7 @@ func NewEmbededEtcdKV(cfg *embed.Config, rootPath string) (*EmbedEtcdKV, error) etcd: e, } - //wait until embed etcd is ready + // wait until embed etcd is ready select { case <-e.Server.ReadyNotify(): log.Info("Embedded etcd is ready!") @@ -77,7 +78,6 @@ func (kv *EmbedEtcdKV) Close() { kv.client.Close() kv.etcd.Close() }) - } // GetPath returns the full path by given key @@ -130,6 +130,7 @@ func (kv *EmbedEtcdKV) LoadWithPrefix(key string) ([]string, []string, error) { if err != nil { return nil, nil, err } + keys := make([]string, 0, resp.Count) values := make([]string, 0, resp.Count) for _, kv := range resp.Kvs { @@ -218,7 +219,7 @@ func (kv *EmbedEtcdKV) Load(key string) (string, error) { return "", err } if resp.Count <= 0 { - return "", common.NewKeyNotExistError(key) + return "", merr.WrapErrIoKeyNotFound(key) } return string(resp.Kvs[0].Value), nil @@ -234,7 +235,7 @@ func (kv *EmbedEtcdKV) LoadBytes(key string) ([]byte, error) { return nil, err } if resp.Count <= 0 { - return nil, common.NewKeyNotExistError(key) + return nil, merr.WrapErrIoKeyNotFound(key) } return resp.Kvs[0].Value, nil @@ -422,7 +423,12 @@ func (kv *EmbedEtcdKV) MultiRemove(keys []string) error { } // MultiSaveAndRemove saves the key-value pairs and removes the keys in a transaction. -func (kv *EmbedEtcdKV) MultiSaveAndRemove(saves map[string]string, removals []string) error { +func (kv *EmbedEtcdKV) MultiSaveAndRemove(saves map[string]string, removals []string, preds ...predicates.Predicate) error { + cmps, err := parsePredicates(kv.rootPath, preds...) 
+ if err != nil { + return err + } + ops := make([]clientv3.Op, 0, len(saves)+len(removals)) for key, value := range saves { ops = append(ops, clientv3.OpPut(path.Join(kv.rootPath, key), value)) @@ -435,8 +441,15 @@ func (kv *EmbedEtcdKV) MultiSaveAndRemove(saves map[string]string, removals []st ctx, cancel := context.WithTimeout(context.TODO(), RequestTimeout) defer cancel() - _, err := kv.client.Txn(ctx).If().Then(ops...).Commit() - return err + resp, err := kv.client.Txn(ctx).If(cmps...).Then(ops...).Commit() + if err != nil { + return err + } + + if !resp.Succeeded { + return merr.WrapErrIoFailedReason("failed to execute transaction") + } + return nil } // MultiSaveBytesAndRemove saves the key-value pairs and removes the keys in a transaction. @@ -475,21 +488,13 @@ func (kv *EmbedEtcdKV) WatchWithRevision(key string, revision int64) clientv3.Wa return rch } -func (kv *EmbedEtcdKV) MultiRemoveWithPrefix(keys []string) error { - ops := make([]clientv3.Op, 0, len(keys)) - for _, k := range keys { - op := clientv3.OpDelete(path.Join(kv.rootPath, k), clientv3.WithPrefix()) - ops = append(ops, op) +// MultiSaveAndRemoveWithPrefix saves kv in @saves and removes the keys with given prefix in @removals. +func (kv *EmbedEtcdKV) MultiSaveAndRemoveWithPrefix(saves map[string]string, removals []string, preds ...predicates.Predicate) error { + cmps, err := parsePredicates(kv.rootPath, preds...) + if err != nil { + return err } - ctx, cancel := context.WithTimeout(context.TODO(), RequestTimeout) - defer cancel() - _, err := kv.client.Txn(ctx).If().Then(ops...).Commit() - return err -} - -// MultiSaveAndRemoveWithPrefix saves kv in @saves and removes the keys with given prefix in @removals. 
-func (kv *EmbedEtcdKV) MultiSaveAndRemoveWithPrefix(saves map[string]string, removals []string) error { ops := make([]clientv3.Op, 0, len(saves)+len(removals)) for key, value := range saves { ops = append(ops, clientv3.OpPut(path.Join(kv.rootPath, key), value)) @@ -502,8 +507,15 @@ func (kv *EmbedEtcdKV) MultiSaveAndRemoveWithPrefix(saves map[string]string, rem ctx, cancel := context.WithTimeout(context.TODO(), RequestTimeout) defer cancel() - _, err := kv.client.Txn(ctx).If().Then(ops...).Commit() - return err + resp, err := kv.client.Txn(ctx).If(cmps...).Then(ops...).Commit() + if err != nil { + return err + } + + if !resp.Succeeded { + return merr.WrapErrIoFailedReason("failed to execute transaction") + } + return nil } // MultiSaveBytesAndRemoveWithPrefix saves kv in @saves and removes the keys with given prefix in @removals. diff --git a/internal/kv/etcd/embed_etcd_kv_test.go b/internal/kv/etcd/embed_etcd_kv_test.go index 314762f03f071..d41684ff2de74 100644 --- a/internal/kv/etcd/embed_etcd_kv_test.go +++ b/internal/kv/etcd/embed_etcd_kv_test.go @@ -18,15 +18,20 @@ package etcdkv_test import ( "fmt" + "path" "sort" "testing" "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" "golang.org/x/exp/maps" + "github.com/milvus-io/milvus/internal/kv" embed_etcd_kv "github.com/milvus-io/milvus/internal/kv/etcd" + "github.com/milvus-io/milvus/internal/kv/predicates" + "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -98,17 +103,20 @@ func TestEmbedEtcd(te *testing.T) { metaKv.GetPath("test1"), metaKv.GetPath("test2"), metaKv.GetPath("test1/a"), - metaKv.GetPath("test1/b")}, []string{"value1", "value2", "value_a", "value_b"}, nil}, + metaKv.GetPath("test1/b"), + }, []string{"value1", "value2", "value_a", "value_b"}, nil}, {"test1", []string{ metaKv.GetPath("test1"), 
metaKv.GetPath("test1/a"), - metaKv.GetPath("test1/b")}, []string{"value1", "value_a", "value_b"}, nil}, + metaKv.GetPath("test1/b"), + }, []string{"value1", "value_a", "value_b"}, nil}, {"test2", []string{metaKv.GetPath("test2")}, []string{"value2"}, nil}, {"", []string{ metaKv.GetPath("test1"), metaKv.GetPath("test2"), metaKv.GetPath("test1/a"), - metaKv.GetPath("test1/b")}, []string{"value1", "value2", "value_a", "value_b"}, nil}, + metaKv.GetPath("test1/b"), + }, []string{"value1", "value2", "value_a", "value_b"}, nil}, {"test1/a", []string{metaKv.GetPath("test1/a")}, []string{"value_a"}, nil}, {"a", []string{}, []string{}, nil}, {"root", []string{}, []string{}, nil}, @@ -203,17 +211,20 @@ func TestEmbedEtcd(te *testing.T) { metaKv.GetPath("test1"), metaKv.GetPath("test2"), metaKv.GetPath("test1/a"), - metaKv.GetPath("test1/b")}, [][]byte{[]byte("value1"), []byte("value2"), []byte("value_a"), []byte("value_b")}, nil}, + metaKv.GetPath("test1/b"), + }, [][]byte{[]byte("value1"), []byte("value2"), []byte("value_a"), []byte("value_b")}, nil}, {"test1", []string{ metaKv.GetPath("test1"), metaKv.GetPath("test1/a"), - metaKv.GetPath("test1/b")}, [][]byte{[]byte("value1"), []byte("value_a"), []byte("value_b")}, nil}, + metaKv.GetPath("test1/b"), + }, [][]byte{[]byte("value1"), []byte("value_a"), []byte("value_b")}, nil}, {"test2", []string{metaKv.GetPath("test2")}, [][]byte{[]byte("value2")}, nil}, {"", []string{ metaKv.GetPath("test1"), metaKv.GetPath("test2"), metaKv.GetPath("test1/a"), - metaKv.GetPath("test1/b")}, [][]byte{[]byte("value1"), []byte("value2"), []byte("value_a"), []byte("value_b")}, nil}, + metaKv.GetPath("test1/b"), + }, [][]byte{[]byte("value1"), []byte("value2"), []byte("value_a"), []byte("value_b")}, nil}, {"test1/a", []string{metaKv.GetPath("test1/a")}, [][]byte{[]byte("value_a")}, nil}, {"a", []string{}, [][]byte{}, nil}, {"root", []string{}, [][]byte{}, nil}, @@ -300,7 +311,6 @@ func TestEmbedEtcd(te *testing.T) { assert.ElementsMatch(t, 
test.expectedValues, values) assert.NotZero(t, revision) } - }) te.Run("etcdKV MultiSaveAndMultiLoad", func(t *testing.T) { @@ -522,7 +532,7 @@ func TestEmbedEtcd(te *testing.T) { assert.Empty(t, vs) }) - te.Run("etcdKV MultiRemoveWithPrefix", func(t *testing.T) { + te.Run("etcdKV MultiSaveAndRemoveWithPrefix", func(t *testing.T) { rootPath := "/etcd/test/root/multi_remove_with_prefix" metaKv, err := embed_etcd_kv.NewMetaKvFactory(rootPath, ¶m.EtcdCfg) require.NoError(t, err) @@ -539,45 +549,6 @@ func TestEmbedEtcd(te *testing.T) { "x/den/2": "200", } - err = metaKv.MultiSave(prepareTests) - require.NoError(t, err) - - multiRemoveWithPrefixTests := []struct { - prefix []string - - testKey string - expectedValue string - }{ - {[]string{"x/abc"}, "x/abc/1", ""}, - {[]string{}, "x/abc/2", ""}, - {[]string{}, "x/def/1", "10"}, - {[]string{}, "x/def/2", "20"}, - {[]string{}, "x/den/1", "100"}, - {[]string{}, "x/den/2", "200"}, - {[]string{}, "not-exist", ""}, - {[]string{"x/def", "x/den"}, "x/def/1", ""}, - {[]string{}, "x/def/1", ""}, - {[]string{}, "x/def/2", ""}, - {[]string{}, "x/den/1", ""}, - {[]string{}, "x/den/2", ""}, - {[]string{}, "not-exist", ""}, - } - - for _, test := range multiRemoveWithPrefixTests { - if len(test.prefix) > 0 { - err = metaKv.MultiRemoveWithPrefix(test.prefix) - assert.NoError(t, err) - } - - v, _ := metaKv.Load(test.testKey) - assert.Equal(t, test.expectedValue, v) - } - - k, v, err := metaKv.LoadWithPrefix("/") - assert.NoError(t, err) - assert.Zero(t, len(k)) - assert.Zero(t, len(v)) - // MultiSaveAndRemoveWithPrefix err = metaKv.MultiSave(prepareTests) require.NoError(t, err) @@ -597,7 +568,7 @@ func TestEmbedEtcd(te *testing.T) { } for _, test := range multiSaveAndRemoveWithPrefixTests { - k, _, err = metaKv.LoadWithPrefix(test.loadPrefix) + k, _, err := metaKv.LoadWithPrefix(test.loadPrefix) assert.NoError(t, err) assert.Equal(t, test.lengthBeforeRemove, len(k)) @@ -628,40 +599,6 @@ func TestEmbedEtcd(te *testing.T) { "x/den/2": 
[]byte("200"), } - err = metaKv.MultiSaveBytes(prepareTests) - require.NoError(t, err) - - multiRemoveWithPrefixTests := []struct { - prefix []string - - testKey string - expectedValue []byte - }{ - {[]string{"x/abc"}, "x/abc/1", nil}, - {[]string{}, "x/abc/2", nil}, - {[]string{}, "x/def/1", []byte("10")}, - {[]string{}, "x/def/2", []byte("20")}, - {[]string{}, "x/den/1", []byte("100")}, - {[]string{}, "x/den/2", []byte("200")}, - {[]string{}, "not-exist", nil}, - {[]string{"x/def", "x/den"}, "x/def/1", nil}, - {[]string{}, "x/def/1", nil}, - {[]string{}, "x/def/2", nil}, - {[]string{}, "x/den/1", nil}, - {[]string{}, "x/den/2", nil}, - {[]string{}, "not-exist", nil}, - } - - for _, test := range multiRemoveWithPrefixTests { - if len(test.prefix) > 0 { - err = metaKv.MultiRemoveWithPrefix(test.prefix) - assert.NoError(t, err) - } - - v, _ := metaKv.LoadBytes(test.testKey) - assert.Equal(t, test.expectedValue, v) - } - k, v, err := metaKv.LoadBytesWithPrefix("/") assert.NoError(t, err) assert.Zero(t, len(k)) @@ -893,3 +830,90 @@ func TestEmbedEtcd(te *testing.T) { assert.False(t, has) }) } + +type EmbedEtcdKVSuite struct { + suite.Suite + + param *paramtable.ComponentParam + + rootPath string + kv kv.MetaKv +} + +func (s *EmbedEtcdKVSuite) SetupSuite() { + te := s.T() + te.Setenv(metricsinfo.DeployModeEnvKey, metricsinfo.StandaloneDeployMode) + param := new(paramtable.ComponentParam) + te.Setenv("etcd.use.embed", "true") + te.Setenv("etcd.config.path", "../../../configs/advanced/etcd.yaml") + + dir := te.TempDir() + te.Setenv("etcd.data.dir", dir) + + param.Init(paramtable.NewBaseTable()) + s.param = param +} + +func (s *EmbedEtcdKVSuite) SetupTest() { + s.rootPath = path.Join("unittest/etcdkv", funcutil.RandomString(8)) + + metaKv, err := embed_etcd_kv.NewMetaKvFactory(s.rootPath, &s.param.EtcdCfg) + s.Require().NoError(err) + s.kv = metaKv +} + +func (s *EmbedEtcdKVSuite) TearDownTest() { + if s.kv != nil { + s.kv.RemoveWithPrefix("") + s.kv.Close() + s.kv = nil 
+ } +} + +func (s *EmbedEtcdKVSuite) TestTxnWithPredicates() { + etcdKV := s.kv + + prepareKV := map[string]string{ + "lease1": "1", + "lease2": "2", + } + + err := etcdKV.MultiSave(prepareKV) + s.Require().NoError(err) + + badPredicate := predicates.NewMockPredicate(s.T()) + badPredicate.EXPECT().Type().Return(0) + badPredicate.EXPECT().Target().Return(predicates.PredTargetValue) + + multiSaveAndRemovePredTests := []struct { + tag string + multiSave map[string]string + preds []predicates.Predicate + expectSuccess bool + }{ + {"predicate_ok", map[string]string{"a": "b"}, []predicates.Predicate{predicates.ValueEqual("lease1", "1")}, true}, + {"predicate_fail", map[string]string{"a": "b"}, []predicates.Predicate{predicates.ValueEqual("lease1", "2")}, false}, + {"bad_predicate", map[string]string{"a": "b"}, []predicates.Predicate{badPredicate}, false}, + } + + for _, test := range multiSaveAndRemovePredTests { + s.Run(test.tag, func() { + err := etcdKV.MultiSaveAndRemove(test.multiSave, nil, test.preds...) + if test.expectSuccess { + s.NoError(err) + } else { + s.Error(err) + } + err = etcdKV.MultiSaveAndRemoveWithPrefix(test.multiSave, nil, test.preds...) 
+ if test.expectSuccess { + s.NoError(err) + } else { + s.Error(err) + } + }) + } +} + +func TestEmbedEtcdKV(t *testing.T) { + suite.Run(t, new(EmbedEtcdKVSuite)) +} diff --git a/internal/kv/etcd/embed_etcd_restart_test.go b/internal/kv/etcd/embed_etcd_restart_test.go index 85d918d8e52ac..97fa8600a33d1 100644 --- a/internal/kv/etcd/embed_etcd_restart_test.go +++ b/internal/kv/etcd/embed_etcd_restart_test.go @@ -36,7 +36,7 @@ func TestEtcdRestartLoad(te *testing.T) { param.Init(paramtable.NewBaseTable()) param.Save("etcd.config.path", "../../../configs/advanced/etcd.yaml") param.Save("etcd.data.dir", etcdDataDir) - //clean up data + // clean up data defer func() { err := os.RemoveAll(etcdDataDir) assert.NoError(te, err) @@ -79,7 +79,7 @@ func TestEtcdRestartLoad(te *testing.T) { embed := metaKv.(*embed_etcd_kv.EmbedEtcdKV) embed.Close() - //restart and check test result + // restart and check test result metaKv, _ = embed_etcd_kv.NewMetaKvFactory(rootPath, ¶m.EtcdCfg) for _, test := range saveAndLoadTests { diff --git a/internal/kv/etcd/etcd_kv.go b/internal/kv/etcd/etcd_kv.go index acb211133fca0..4802ac2e98965 100644 --- a/internal/kv/etcd/etcd_kv.go +++ b/internal/kv/etcd/etcd_kv.go @@ -26,9 +26,10 @@ import ( clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/zap" - "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/internal/kv/predicates" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/timerecord" ) @@ -205,7 +206,7 @@ func (kv *etcdKV) Load(key string) (string, error) { return "", err } if resp.Count <= 0 { - return "", common.NewKeyNotExistError(key) + return "", merr.WrapErrIoKeyNotFound(key) } CheckElapseAndWarn(start, "Slow etcd operation load", zap.String("key", key)) return string(resp.Kvs[0].Value), nil @@ -219,10 +220,10 @@ func (kv *etcdKV) LoadBytes(key string) ([]byte, error) { defer cancel() resp, err := 
kv.getEtcdMeta(ctx, key) if err != nil { - return []byte{}, err + return nil, err } if resp.Count <= 0 { - return []byte{}, common.NewKeyNotExistError(key) + return nil, merr.WrapErrIoKeyNotFound(key) } CheckElapseAndWarn(start, "Slow etcd operation load", zap.String("key", key)) return resp.Kvs[0].Value, nil @@ -443,7 +444,12 @@ func (kv *etcdKV) MultiRemove(keys []string) error { } // MultiSaveAndRemove saves the key-value pairs and removes the keys in a transaction. -func (kv *etcdKV) MultiSaveAndRemove(saves map[string]string, removals []string) error { +func (kv *etcdKV) MultiSaveAndRemove(saves map[string]string, removals []string, preds ...predicates.Predicate) error { + cmps, err := parsePredicates(kv.rootPath, preds...) + if err != nil { + return err + } + start := time.Now() ops := make([]clientv3.Op, 0, len(saves)+len(removals)) var keys []string @@ -459,7 +465,7 @@ func (kv *etcdKV) MultiSaveAndRemove(saves map[string]string, removals []string) ctx, cancel := context.WithTimeout(context.TODO(), RequestTimeout) defer cancel() - _, err := kv.executeTxn(kv.getTxnWithCmp(ctx), ops...) + resp, err := kv.executeTxn(kv.getTxnWithCmp(ctx, cmps...), ops...) if err != nil { log.Warn("Etcd MultiSaveAndRemove error", zap.Any("saves", saves), @@ -467,9 +473,14 @@ func (kv *etcdKV) MultiSaveAndRemove(saves map[string]string, removals []string) zap.Int("saveLength", len(saves)), zap.Int("removeLength", len(removals)), zap.Error(err)) + return err } CheckElapseAndWarn(start, "Slow etcd operation multi save and remove", zap.Strings("keys", keys)) - return err + if !resp.Succeeded { + log.Warn("failed to executeTxn", zap.Any("resp", resp)) + return merr.WrapErrIoFailedReason("failed to execute transaction") + } + return nil } // MultiSaveBytesAndRemove saves the key-value pairs and removes the keys in a transaction. 
@@ -529,27 +540,13 @@ func (kv *etcdKV) WatchWithRevision(key string, revision int64) clientv3.WatchCh return rch } -// MultiRemoveWithPrefix removes the keys with given prefix. -func (kv *etcdKV) MultiRemoveWithPrefix(keys []string) error { - start := time.Now() - ops := make([]clientv3.Op, 0, len(keys)) - for _, k := range keys { - op := clientv3.OpDelete(path.Join(kv.rootPath, k), clientv3.WithPrefix()) - ops = append(ops, op) - } - ctx, cancel := context.WithTimeout(context.TODO(), RequestTimeout) - defer cancel() - - _, err := kv.executeTxn(kv.getTxnWithCmp(ctx), ops...) +// MultiSaveAndRemoveWithPrefix saves kv in @saves and removes the keys with given prefix in @removals. +func (kv *etcdKV) MultiSaveAndRemoveWithPrefix(saves map[string]string, removals []string, preds ...predicates.Predicate) error { + cmps, err := parsePredicates(kv.rootPath, preds...) if err != nil { - log.Warn("Etcd MultiRemoveWithPrefix error", zap.Strings("keys", keys), zap.Int("len", len(keys)), zap.Error(err)) + return err } - CheckElapseAndWarn(start, "Slow etcd operation multi remove with prefix", zap.Strings("keys", keys)) - return err -} -// MultiSaveAndRemoveWithPrefix saves kv in @saves and removes the keys with given prefix in @removals. -func (kv *etcdKV) MultiSaveAndRemoveWithPrefix(saves map[string]string, removals []string) error { start := time.Now() ops := make([]clientv3.Op, 0, len(saves)) var keys []string @@ -565,7 +562,7 @@ func (kv *etcdKV) MultiSaveAndRemoveWithPrefix(saves map[string]string, removals ctx, cancel := context.WithTimeout(context.TODO(), RequestTimeout) defer cancel() - _, err := kv.executeTxn(kv.getTxnWithCmp(ctx), ops...) + resp, err := kv.executeTxn(kv.getTxnWithCmp(ctx, cmps...), ops...) 
if err != nil { log.Warn("Etcd MultiSaveAndRemoveWithPrefix error", zap.Any("saves", saves), @@ -573,9 +570,13 @@ func (kv *etcdKV) MultiSaveAndRemoveWithPrefix(saves map[string]string, removals zap.Int("saveLength", len(saves)), zap.Int("removeLength", len(removals)), zap.Error(err)) + return err } CheckElapseAndWarn(start, "Slow etcd operation multi save and move with prefix", zap.Strings("keys", keys)) - return err + if !resp.Succeeded { + return merr.WrapErrIoFailedReason("failed to execute transaction") + } + return nil } // MultiSaveBytesAndRemoveWithPrefix saves kv in @saves and removes the keys with given prefix in @removals. diff --git a/internal/kv/etcd/etcd_kv_test.go b/internal/kv/etcd/etcd_kv_test.go index 2548c0bf45a7d..76908530fba9d 100644 --- a/internal/kv/etcd/etcd_kv_test.go +++ b/internal/kv/etcd/etcd_kv_test.go @@ -14,11 +14,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -package etcdkv_test +package etcdkv import ( "fmt" "os" + "path" "sort" "testing" "time" @@ -26,9 +27,11 @@ import ( "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + clientv3 "go.etcd.io/etcd/client/v3" "golang.org/x/exp/maps" - etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" + "github.com/milvus-io/milvus/internal/kv/predicates" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -42,7 +45,15 @@ func TestMain(m *testing.M) { os.Exit(code) } -func TestEtcdKV_Load(te *testing.T) { +type EtcdKVSuite struct { + suite.Suite + + rootPath string + etcdCli *clientv3.Client + etcdKV *etcdKV +} + +func (s *EtcdKVSuite) SetupSuite() { etcdCli, err := etcd.GetEtcdClient( Params.EtcdCfg.UseEmbedEtcd.GetAsBool(), Params.EtcdCfg.EtcdUseSSL.GetAsBool(), @@ -51,652 +62,647 @@ func TestEtcdKV_Load(te *testing.T) { 
Params.EtcdCfg.EtcdTLSKey.GetValue(), Params.EtcdCfg.EtcdTLSCACert.GetValue(), Params.EtcdCfg.EtcdTLSMinVersion.GetValue()) - defer etcdCli.Close() - assert.NoError(te, err) - te.Run("etcdKV SaveAndLoad", func(t *testing.T) { - rootPath := "/etcd/test/root/saveandload" - etcdKV := etcdkv.NewEtcdKV(etcdCli, rootPath) - err = etcdKV.RemoveWithPrefix("") - require.NoError(t, err) - - defer etcdKV.Close() - defer etcdKV.RemoveWithPrefix("") - - saveAndLoadTests := []struct { - key string - value string - }{ - {"test1", "value1"}, - {"test2", "value2"}, - {"test1/a", "value_a"}, - {"test1/b", "value_b"}, - } + s.Require().NoError(err) - for i, test := range saveAndLoadTests { - if i < 4 { - err = etcdKV.Save(test.key, test.value) - assert.NoError(t, err) - } - - val, err := etcdKV.Load(test.key) - assert.NoError(t, err) - assert.Equal(t, test.value, val) - } + s.etcdCli = etcdCli +} - invalidLoadTests := []struct { - invalidKey string - }{ - {"t"}, - {"a"}, - {"test1a"}, - } +func (s *EtcdKVSuite) TearDownSuite() { + if s.etcdCli != nil { + s.etcdCli.Close() + } +} - for _, test := range invalidLoadTests { - val, err := etcdKV.Load(test.invalidKey) - assert.Error(t, err) - assert.Zero(t, val) - } +func (s *EtcdKVSuite) SetupTest() { + s.rootPath = path.Join("unittest/etcdkv", funcutil.RandomString(8)) + s.etcdKV = NewEtcdKV(s.etcdCli, s.rootPath) +} - loadPrefixTests := []struct { - prefix string - - expectedKeys []string - expectedValues []string - expectedError error - }{ - {"test", []string{ - etcdKV.GetPath("test1"), - etcdKV.GetPath("test2"), - etcdKV.GetPath("test1/a"), - etcdKV.GetPath("test1/b")}, []string{"value1", "value2", "value_a", "value_b"}, nil}, - {"test1", []string{ - etcdKV.GetPath("test1"), - etcdKV.GetPath("test1/a"), - etcdKV.GetPath("test1/b")}, []string{"value1", "value_a", "value_b"}, nil}, - {"test2", []string{etcdKV.GetPath("test2")}, []string{"value2"}, nil}, - {"", []string{ - etcdKV.GetPath("test1"), - etcdKV.GetPath("test2"), - 
etcdKV.GetPath("test1/a"), - etcdKV.GetPath("test1/b")}, []string{"value1", "value2", "value_a", "value_b"}, nil}, - {"test1/a", []string{etcdKV.GetPath("test1/a")}, []string{"value_a"}, nil}, - {"a", []string{}, []string{}, nil}, - {"root", []string{}, []string{}, nil}, - {"/etcd/test/root", []string{}, []string{}, nil}, - } +func (s *EtcdKVSuite) TearDownTest() { + s.etcdKV.RemoveWithPrefix("") + s.etcdKV.Close() +} - for _, test := range loadPrefixTests { - actualKeys, actualValues, err := etcdKV.LoadWithPrefix(test.prefix) - assert.ElementsMatch(t, test.expectedKeys, actualKeys) - assert.ElementsMatch(t, test.expectedValues, actualValues) - assert.Equal(t, test.expectedError, err) - } +func (s *EtcdKVSuite) TestSaveLoad() { + etcdKV := s.etcdKV + saveAndLoadTests := []struct { + key string + value string + }{ + {"test1", "value1"}, + {"test2", "value2"}, + {"test1/a", "value_a"}, + {"test1/b", "value_b"}, + } - removeTests := []struct { - validKey string - invalidKey string - }{ - {"test1", "abc"}, - {"test1/a", "test1/lskfjal"}, - {"test1/b", "test1/b"}, - {"test2", "-"}, + for i, test := range saveAndLoadTests { + if i < 4 { + err := etcdKV.Save(test.key, test.value) + s.Require().NoError(err) } - for _, test := range removeTests { - err = etcdKV.Remove(test.validKey) - assert.NoError(t, err) + val, err := etcdKV.Load(test.key) + s.Require().NoError(err) + s.Equal(test.value, val) + } - _, err = etcdKV.Load(test.validKey) - assert.Error(t, err) + invalidLoadTests := []struct { + invalidKey string + }{ + {"t"}, + {"a"}, + {"test1a"}, + } - err = etcdKV.Remove(test.validKey) - assert.NoError(t, err) - err = etcdKV.Remove(test.invalidKey) - assert.NoError(t, err) - } - }) + for _, test := range invalidLoadTests { + val, err := etcdKV.Load(test.invalidKey) + s.Error(err) + s.Zero(val) + } - te.Run("etcdKV SaveAndLoadBytes", func(t *testing.T) { - rootPath := "/etcd/test/root/saveandloadbytes" - etcdKV := etcdkv.NewEtcdKV(etcdCli, rootPath) - err = 
etcdKV.RemoveWithPrefix("") - require.NoError(t, err) - - defer etcdKV.Close() - defer etcdKV.RemoveWithPrefix("") - - saveAndLoadTests := []struct { - key string - value string - }{ - {"test1", "value1"}, - {"test2", "value2"}, - {"test1/a", "value_a"}, - {"test1/b", "value_b"}, - } + loadPrefixTests := []struct { + prefix string + + expectedKeys []string + expectedValues []string + expectedError error + }{ + {"test", []string{ + etcdKV.GetPath("test1"), + etcdKV.GetPath("test2"), + etcdKV.GetPath("test1/a"), + etcdKV.GetPath("test1/b"), + }, []string{"value1", "value2", "value_a", "value_b"}, nil}, + {"test1", []string{ + etcdKV.GetPath("test1"), + etcdKV.GetPath("test1/a"), + etcdKV.GetPath("test1/b"), + }, []string{"value1", "value_a", "value_b"}, nil}, + {"test2", []string{etcdKV.GetPath("test2")}, []string{"value2"}, nil}, + {"", []string{ + etcdKV.GetPath("test1"), + etcdKV.GetPath("test2"), + etcdKV.GetPath("test1/a"), + etcdKV.GetPath("test1/b"), + }, []string{"value1", "value2", "value_a", "value_b"}, nil}, + {"test1/a", []string{etcdKV.GetPath("test1/a")}, []string{"value_a"}, nil}, + {"a", []string{}, []string{}, nil}, + {"root", []string{}, []string{}, nil}, + {"/etcd/test/root", []string{}, []string{}, nil}, + } - for i, test := range saveAndLoadTests { - if i < 4 { - err = etcdKV.SaveBytes(test.key, []byte(test.value)) - assert.NoError(t, err) - } + for _, test := range loadPrefixTests { + actualKeys, actualValues, err := etcdKV.LoadWithPrefix(test.prefix) + s.ElementsMatch(test.expectedKeys, actualKeys) + s.ElementsMatch(test.expectedValues, actualValues) + s.Equal(test.expectedError, err) + } - val, err := etcdKV.LoadBytes(test.key) - assert.NoError(t, err) - assert.Equal(t, test.value, string(val)) - } + removeTests := []struct { + validKey string + invalidKey string + }{ + {"test1", "abc"}, + {"test1/a", "test1/lskfjal"}, + {"test1/b", "test1/b"}, + {"test2", "-"}, + } - invalidLoadTests := []struct { - invalidKey string - }{ - {"t"}, - {"a"}, - 
{"test1a"}, - } + for _, test := range removeTests { + err := etcdKV.Remove(test.validKey) + s.NoError(err) - for _, test := range invalidLoadTests { - val, err := etcdKV.LoadBytes(test.invalidKey) - assert.Error(t, err) - assert.Zero(t, string(val)) - } + _, err = etcdKV.Load(test.validKey) + s.Error(err) - loadPrefixTests := []struct { - prefix string - - expectedKeys []string - expectedValues []string - expectedError error - }{ - {"test", []string{ - etcdKV.GetPath("test1"), - etcdKV.GetPath("test2"), - etcdKV.GetPath("test1/a"), - etcdKV.GetPath("test1/b")}, []string{"value1", "value2", "value_a", "value_b"}, nil}, - {"test1", []string{ - etcdKV.GetPath("test1"), - etcdKV.GetPath("test1/a"), - etcdKV.GetPath("test1/b")}, []string{"value1", "value_a", "value_b"}, nil}, - {"test2", []string{etcdKV.GetPath("test2")}, []string{"value2"}, nil}, - {"", []string{ - etcdKV.GetPath("test1"), - etcdKV.GetPath("test2"), - etcdKV.GetPath("test1/a"), - etcdKV.GetPath("test1/b")}, []string{"value1", "value2", "value_a", "value_b"}, nil}, - {"test1/a", []string{etcdKV.GetPath("test1/a")}, []string{"value_a"}, nil}, - {"a", []string{}, []string{}, nil}, - {"root", []string{}, []string{}, nil}, - {"/etcd/test/root", []string{}, []string{}, nil}, - } + err = etcdKV.Remove(test.validKey) + s.NoError(err) + err = etcdKV.Remove(test.invalidKey) + s.NoError(err) + } +} - for _, test := range loadPrefixTests { - actualKeys, actualValues, err := etcdKV.LoadBytesWithPrefix(test.prefix) - actualStringValues := make([]string, len(actualValues)) - for i := range actualValues { - actualStringValues[i] = string(actualValues[i]) - } - assert.ElementsMatch(t, test.expectedKeys, actualKeys) - assert.ElementsMatch(t, test.expectedValues, actualStringValues) - assert.Equal(t, test.expectedError, err) - - actualKeys, actualValues, versions, err := etcdKV.LoadBytesWithPrefix2(test.prefix) - actualStringValues = make([]string, len(actualValues)) - for i := range actualValues { - 
actualStringValues[i] = string(actualValues[i]) - } - assert.ElementsMatch(t, test.expectedKeys, actualKeys) - assert.ElementsMatch(t, test.expectedValues, actualStringValues) - assert.NotZero(t, versions) - assert.Equal(t, test.expectedError, err) - } +func (s *EtcdKVSuite) TestSaveAndLoadBytes() { + etcdKV := s.etcdKV + + saveAndLoadTests := []struct { + key string + value string + }{ + {"test1", "value1"}, + {"test2", "value2"}, + {"test1/a", "value_a"}, + {"test1/b", "value_b"}, + } - removeTests := []struct { - validKey string - invalidKey string - }{ - {"test1", "abc"}, - {"test1/a", "test1/lskfjal"}, - {"test1/b", "test1/b"}, - {"test2", "-"}, + for i, test := range saveAndLoadTests { + if i < 4 { + err := etcdKV.SaveBytes(test.key, []byte(test.value)) + s.Require().NoError(err) } - for _, test := range removeTests { - err = etcdKV.Remove(test.validKey) - assert.NoError(t, err) + val, err := etcdKV.LoadBytes(test.key) + s.NoError(err) + s.Equal(test.value, string(val)) + } - _, err = etcdKV.Load(test.validKey) - assert.Error(t, err) + invalidLoadTests := []struct { + invalidKey string + }{ + {"t"}, + {"a"}, + {"test1a"}, + } - err = etcdKV.Remove(test.validKey) - assert.NoError(t, err) - err = etcdKV.Remove(test.invalidKey) - assert.NoError(t, err) - } - }) + for _, test := range invalidLoadTests { + val, err := etcdKV.LoadBytes(test.invalidKey) + s.Error(err) + s.Zero(string(val)) + } - te.Run("etcdKV LoadBytesWithRevision", func(t *testing.T) { - rootPath := "/etcd/test/root/LoadBytesWithRevision" - etcdKV := etcdkv.NewEtcdKV(etcdCli, rootPath) - - defer etcdKV.Close() - defer etcdKV.RemoveWithPrefix("") - - prepareKV := []struct { - inKey string - inValue string - }{ - {"a", "a_version1"}, - {"b", "b_version2"}, - {"a", "a_version3"}, - {"c", "c_version4"}, - {"a/suba", "a_version5"}, - } + loadPrefixTests := []struct { + prefix string + + expectedKeys []string + expectedValues []string + expectedError error + }{ + {"test", []string{ + 
etcdKV.GetPath("test1"), + etcdKV.GetPath("test2"), + etcdKV.GetPath("test1/a"), + etcdKV.GetPath("test1/b"), + }, []string{"value1", "value2", "value_a", "value_b"}, nil}, + {"test1", []string{ + etcdKV.GetPath("test1"), + etcdKV.GetPath("test1/a"), + etcdKV.GetPath("test1/b"), + }, []string{"value1", "value_a", "value_b"}, nil}, + {"test2", []string{etcdKV.GetPath("test2")}, []string{"value2"}, nil}, + {"", []string{ + etcdKV.GetPath("test1"), + etcdKV.GetPath("test2"), + etcdKV.GetPath("test1/a"), + etcdKV.GetPath("test1/b"), + }, []string{"value1", "value2", "value_a", "value_b"}, nil}, + {"test1/a", []string{etcdKV.GetPath("test1/a")}, []string{"value_a"}, nil}, + {"a", []string{}, []string{}, nil}, + {"root", []string{}, []string{}, nil}, + {"/etcd/test/root", []string{}, []string{}, nil}, + } - for _, test := range prepareKV { - err = etcdKV.SaveBytes(test.inKey, []byte(test.inValue)) - require.NoError(t, err) - } + for _, test := range loadPrefixTests { + actualKeys, actualValues, err := etcdKV.LoadBytesWithPrefix(test.prefix) + actualStringValues := make([]string, len(actualValues)) + for i := range actualValues { + actualStringValues[i] = string(actualValues[i]) + } + s.ElementsMatch(test.expectedKeys, actualKeys) + s.ElementsMatch(test.expectedValues, actualStringValues) + s.Equal(test.expectedError, err) + + actualKeys, actualValues, versions, err := etcdKV.LoadBytesWithPrefix2(test.prefix) + actualStringValues = make([]string, len(actualValues)) + for i := range actualValues { + actualStringValues[i] = string(actualValues[i]) + } + s.ElementsMatch(test.expectedKeys, actualKeys) + s.ElementsMatch(test.expectedValues, actualStringValues) + s.NotZero(versions) + s.Equal(test.expectedError, err) + } - loadWithRevisionTests := []struct { - inKey string + removeTests := []struct { + validKey string + invalidKey string + }{ + {"test1", "abc"}, + {"test1/a", "test1/lskfjal"}, + {"test1/b", "test1/b"}, + {"test2", "-"}, + } - expectedKeyNo int - expectedValues 
[]string - }{ - {"a", 2, []string{"a_version3", "a_version5"}}, - {"b", 1, []string{"b_version2"}}, - {"c", 1, []string{"c_version4"}}, - } + for _, test := range removeTests { + err := etcdKV.Remove(test.validKey) + s.NoError(err) - for _, test := range loadWithRevisionTests { - keys, values, revision, err := etcdKV.LoadBytesWithRevision(test.inKey) - assert.NoError(t, err) - assert.Equal(t, test.expectedKeyNo, len(keys)) - stringValues := make([]string, len(values)) - for i := range values { - stringValues[i] = string(values[i]) - } - assert.ElementsMatch(t, test.expectedValues, stringValues) - assert.NotZero(t, revision) - } + _, err = etcdKV.Load(test.validKey) + s.Error(err) - }) + err = etcdKV.Remove(test.validKey) + s.NoError(err) + err = etcdKV.Remove(test.invalidKey) + s.NoError(err) + } +} - te.Run("etcdKV MultiSaveAndMultiLoad", func(t *testing.T) { - rootPath := "/etcd/test/root/multi_save_and_multi_load" - etcdKV := etcdkv.NewEtcdKV(etcdCli, rootPath) +func (s *EtcdKVSuite) TestLoadBytesWithRevision() { + etcdKV := s.etcdKV + + prepareKV := []struct { + inKey string + inValue string + }{ + {"a", "a_version1"}, + {"b", "b_version2"}, + {"a", "a_version3"}, + {"c", "c_version4"}, + {"a/suba", "a_version5"}, + } - defer etcdKV.Close() - defer etcdKV.RemoveWithPrefix("") + for _, test := range prepareKV { + err := etcdKV.SaveBytes(test.inKey, []byte(test.inValue)) + s.NoError(err) + } - multiSaveTests := map[string]string{ - "key_1": "value_1", - "key_2": "value_2", - "key_3/a": "value_3a", - "multikey_1": "multivalue_1", - "multikey_2": "multivalue_2", - "_": "other", - } + loadWithRevisionTests := []struct { + inKey string - err = etcdKV.MultiSave(multiSaveTests) - assert.NoError(t, err) - for k, v := range multiSaveTests { - actualV, err := etcdKV.Load(k) - assert.NoError(t, err) - assert.Equal(t, v, actualV) - } + expectedKeyNo int + expectedValues []string + }{ + {"a", 2, []string{"a_version3", "a_version5"}}, + {"b", 1, []string{"b_version2"}}, + 
{"c", 1, []string{"c_version4"}}, + } - multiLoadTests := []struct { - inputKeys []string - expectedValues []string - }{ - {[]string{"key_1"}, []string{"value_1"}}, - {[]string{"key_1", "key_2", "key_3/a"}, []string{"value_1", "value_2", "value_3a"}}, - {[]string{"multikey_1", "multikey_2"}, []string{"multivalue_1", "multivalue_2"}}, - {[]string{"_"}, []string{"other"}}, + for _, test := range loadWithRevisionTests { + keys, values, revision, err := etcdKV.LoadBytesWithRevision(test.inKey) + s.NoError(err) + s.Equal(test.expectedKeyNo, len(keys)) + stringValues := make([]string, len(values)) + for i := range values { + stringValues[i] = string(values[i]) } + s.ElementsMatch(test.expectedValues, stringValues) + s.NotZero(revision) + } +} - for _, test := range multiLoadTests { - vs, err := etcdKV.MultiLoad(test.inputKeys) - assert.NoError(t, err) - assert.Equal(t, test.expectedValues, vs) - } +func (s *EtcdKVSuite) TestMultiSaveAndMultiLoad() { + etcdKV := s.etcdKV + multiSaveTests := map[string]string{ + "key_1": "value_1", + "key_2": "value_2", + "key_3/a": "value_3a", + "multikey_1": "multivalue_1", + "multikey_2": "multivalue_2", + "_": "other", + } - invalidMultiLoad := []struct { - invalidKeys []string - expectedValues []string - }{ - {[]string{"a", "key_1"}, []string{"", "value_1"}}, - {[]string{".....", "key_1"}, []string{"", "value_1"}}, - {[]string{"*********"}, []string{""}}, - {[]string{"key_1", "1"}, []string{"value_1", ""}}, - } + err := etcdKV.MultiSave(multiSaveTests) + s.Require().NoError(err) + for k, v := range multiSaveTests { + actualV, err := etcdKV.Load(k) + s.NoError(err) + s.Equal(v, actualV) + } - for _, test := range invalidMultiLoad { - vs, err := etcdKV.MultiLoad(test.invalidKeys) - assert.Error(t, err) - assert.Equal(t, test.expectedValues, vs) - } + multiLoadTests := []struct { + inputKeys []string + expectedValues []string + }{ + {[]string{"key_1"}, []string{"value_1"}}, + {[]string{"key_1", "key_2", "key_3/a"}, []string{"value_1", 
"value_2", "value_3a"}}, + {[]string{"multikey_1", "multikey_2"}, []string{"multivalue_1", "multivalue_2"}}, + {[]string{"_"}, []string{"other"}}, + } - removeWithPrefixTests := []string{ - "key_1", - "multi", - } + for _, test := range multiLoadTests { + vs, err := etcdKV.MultiLoad(test.inputKeys) + s.NoError(err) + s.Equal(test.expectedValues, vs) + } - for _, k := range removeWithPrefixTests { - err = etcdKV.RemoveWithPrefix(k) - assert.NoError(t, err) + invalidMultiLoad := []struct { + invalidKeys []string + expectedValues []string + }{ + {[]string{"a", "key_1"}, []string{"", "value_1"}}, + {[]string{".....", "key_1"}, []string{"", "value_1"}}, + {[]string{"*********"}, []string{""}}, + {[]string{"key_1", "1"}, []string{"value_1", ""}}, + } - ks, vs, err := etcdKV.LoadWithPrefix(k) - assert.Empty(t, ks) - assert.Empty(t, vs) - assert.NoError(t, err) - } + for _, test := range invalidMultiLoad { + vs, err := etcdKV.MultiLoad(test.invalidKeys) + s.Error(err) + s.Equal(test.expectedValues, vs) + } - multiRemoveTests := []string{ - "key_2", - "key_3/a", - "multikey_2", - "_", - } + removeWithPrefixTests := []string{ + "key_1", + "multi", + } - err = etcdKV.MultiRemove(multiRemoveTests) - assert.NoError(t, err) + for _, k := range removeWithPrefixTests { + err = etcdKV.RemoveWithPrefix(k) + s.NoError(err) - ks, vs, err := etcdKV.LoadWithPrefix("") - assert.NoError(t, err) - assert.Empty(t, ks) - assert.Empty(t, vs) - - multiSaveAndRemoveTests := []struct { - multiSaves map[string]string - multiRemoves []string - }{ - {map[string]string{"key_1": "value_1"}, []string{}}, - {map[string]string{"key_2": "value_2"}, []string{"key_1"}}, - {map[string]string{"key_3/a": "value_3a"}, []string{"key_2"}}, - {map[string]string{"multikey_1": "multivalue_1"}, []string{}}, - {map[string]string{"multikey_2": "multivalue_2"}, []string{"multikey_1", "key_3/a"}}, - {make(map[string]string), []string{"multikey_2"}}, - } - for _, test := range multiSaveAndRemoveTests { - err = 
etcdKV.MultiSaveAndRemove(test.multiSaves, test.multiRemoves) - assert.NoError(t, err) - } + ks, vs, err := etcdKV.LoadWithPrefix(k) + s.Empty(ks) + s.Empty(vs) + s.NoError(err) + } - ks, vs, err = etcdKV.LoadWithPrefix("") - assert.NoError(t, err) - assert.Empty(t, ks) - assert.Empty(t, vs) - }) + multiRemoveTests := []string{ + "key_2", + "key_3/a", + "multikey_2", + "_", + } - te.Run("etcdKV MultiSaveBytesAndMultiLoadBytes", func(t *testing.T) { - rootPath := "/etcd/test/root/multi_save_bytes_and_multi_load_bytes" - etcdKV := etcdkv.NewEtcdKV(etcdCli, rootPath) + err = etcdKV.MultiRemove(multiRemoveTests) + s.NoError(err) + + ks, vs, err := etcdKV.LoadWithPrefix("") + s.NoError(err) + s.Empty(ks) + s.Empty(vs) + + multiSaveAndRemoveTests := []struct { + multiSaves map[string]string + multiRemoves []string + }{ + {map[string]string{"key_1": "value_1"}, []string{}}, + {map[string]string{"key_2": "value_2"}, []string{"key_1"}}, + {map[string]string{"key_3/a": "value_3a"}, []string{"key_2"}}, + {map[string]string{"multikey_1": "multivalue_1"}, []string{}}, + {map[string]string{"multikey_2": "multivalue_2"}, []string{"multikey_1", "key_3/a"}}, + {make(map[string]string), []string{"multikey_2"}}, + } + for _, test := range multiSaveAndRemoveTests { + err = etcdKV.MultiSaveAndRemove(test.multiSaves, test.multiRemoves) + s.NoError(err) + } - defer etcdKV.Close() - defer etcdKV.RemoveWithPrefix("") + ks, vs, err = etcdKV.LoadWithPrefix("") + s.NoError(err) + s.Empty(ks) + s.Empty(vs) +} - multiSaveTests := map[string]string{ - "key_1": "value_1", - "key_2": "value_2", - "key_3/a": "value_3a", - "multikey_1": "multivalue_1", - "multikey_2": "multivalue_2", - "_": "other", - } +func (s *EtcdKVSuite) TestMultiSaveBytesAndMultiLoadBytes() { + etcdKV := s.etcdKV + multiSaveTests := map[string]string{ + "key_1": "value_1", + "key_2": "value_2", + "key_3/a": "value_3a", + "multikey_1": "multivalue_1", + "multikey_2": "multivalue_2", + "_": "other", + } - multiSaveBytesTests := 
make(map[string][]byte) - for k, v := range multiSaveTests { - multiSaveBytesTests[k] = []byte(v) - } + multiSaveBytesTests := make(map[string][]byte) + for k, v := range multiSaveTests { + multiSaveBytesTests[k] = []byte(v) + } - err = etcdKV.MultiSaveBytes(multiSaveBytesTests) - assert.NoError(t, err) - for k, v := range multiSaveTests { - actualV, err := etcdKV.LoadBytes(k) - assert.NoError(t, err) - assert.Equal(t, v, string(actualV)) - } + err := etcdKV.MultiSaveBytes(multiSaveBytesTests) + s.Require().NoError(err) + for k, v := range multiSaveTests { + actualV, err := etcdKV.LoadBytes(k) + s.NoError(err) + s.Equal(v, string(actualV)) + } - multiLoadTests := []struct { - inputKeys []string - expectedValues []string - }{ - {[]string{"key_1"}, []string{"value_1"}}, - {[]string{"key_1", "key_2", "key_3/a"}, []string{"value_1", "value_2", "value_3a"}}, - {[]string{"multikey_1", "multikey_2"}, []string{"multivalue_1", "multivalue_2"}}, - {[]string{"_"}, []string{"other"}}, - } + multiLoadTests := []struct { + inputKeys []string + expectedValues []string + }{ + {[]string{"key_1"}, []string{"value_1"}}, + {[]string{"key_1", "key_2", "key_3/a"}, []string{"value_1", "value_2", "value_3a"}}, + {[]string{"multikey_1", "multikey_2"}, []string{"multivalue_1", "multivalue_2"}}, + {[]string{"_"}, []string{"other"}}, + } - for _, test := range multiLoadTests { - vs, err := etcdKV.MultiLoadBytes(test.inputKeys) - stringVs := make([]string, len(vs)) - for i := range vs { - stringVs[i] = string(vs[i]) - } - assert.NoError(t, err) - assert.Equal(t, test.expectedValues, stringVs) + for _, test := range multiLoadTests { + vs, err := etcdKV.MultiLoadBytes(test.inputKeys) + stringVs := make([]string, len(vs)) + for i := range vs { + stringVs[i] = string(vs[i]) } + s.NoError(err) + s.Equal(test.expectedValues, stringVs) + } - invalidMultiLoad := []struct { - invalidKeys []string - expectedValues []string - }{ - {[]string{"a", "key_1"}, []string{"", "value_1"}}, - {[]string{".....", 
"key_1"}, []string{"", "value_1"}}, - {[]string{"*********"}, []string{""}}, - {[]string{"key_1", "1"}, []string{"value_1", ""}}, - } + invalidMultiLoad := []struct { + invalidKeys []string + expectedValues []string + }{ + {[]string{"a", "key_1"}, []string{"", "value_1"}}, + {[]string{".....", "key_1"}, []string{"", "value_1"}}, + {[]string{"*********"}, []string{""}}, + {[]string{"key_1", "1"}, []string{"value_1", ""}}, + } - for _, test := range invalidMultiLoad { - vs, err := etcdKV.MultiLoadBytes(test.invalidKeys) - stringVs := make([]string, len(vs)) - for i := range vs { - stringVs[i] = string(vs[i]) - } - assert.Error(t, err) - assert.Equal(t, test.expectedValues, stringVs) + for _, test := range invalidMultiLoad { + vs, err := etcdKV.MultiLoadBytes(test.invalidKeys) + stringVs := make([]string, len(vs)) + for i := range vs { + stringVs[i] = string(vs[i]) } + s.Error(err) + s.Equal(test.expectedValues, stringVs) + } - removeWithPrefixTests := []string{ - "key_1", - "multi", - } + removeWithPrefixTests := []string{ + "key_1", + "multi", + } - for _, k := range removeWithPrefixTests { - err = etcdKV.RemoveWithPrefix(k) - assert.NoError(t, err) + for _, k := range removeWithPrefixTests { + err = etcdKV.RemoveWithPrefix(k) + s.NoError(err) - ks, vs, err := etcdKV.LoadBytesWithPrefix(k) - assert.Empty(t, ks) - assert.Empty(t, vs) - assert.NoError(t, err) - } + ks, vs, err := etcdKV.LoadBytesWithPrefix(k) + s.Empty(ks) + s.Empty(vs) + s.NoError(err) + } - multiRemoveTests := []string{ - "key_2", - "key_3/a", - "multikey_2", - "_", - } + multiRemoveTests := []string{ + "key_2", + "key_3/a", + "multikey_2", + "_", + } - err = etcdKV.MultiRemove(multiRemoveTests) - assert.NoError(t, err) + err = etcdKV.MultiRemove(multiRemoveTests) + s.NoError(err) + + ks, vs, err := etcdKV.LoadBytesWithPrefix("") + s.NoError(err) + s.Empty(ks) + s.Empty(vs) + + multiSaveAndRemoveTests := []struct { + multiSaves map[string][]byte + multiRemoves []string + }{ + 
{map[string][]byte{"key_1": []byte("value_1")}, []string{}}, + {map[string][]byte{"key_2": []byte("value_2")}, []string{"key_1"}}, + {map[string][]byte{"key_3/a": []byte("value_3a")}, []string{"key_2"}}, + {map[string][]byte{"multikey_1": []byte("multivalue_1")}, []string{}}, + {map[string][]byte{"multikey_2": []byte("multivalue_2")}, []string{"multikey_1", "key_3/a"}}, + {make(map[string][]byte), []string{"multikey_2"}}, + } - ks, vs, err := etcdKV.LoadBytesWithPrefix("") - assert.NoError(t, err) - assert.Empty(t, ks) - assert.Empty(t, vs) - - multiSaveAndRemoveTests := []struct { - multiSaves map[string][]byte - multiRemoves []string - }{ - {map[string][]byte{"key_1": []byte("value_1")}, []string{}}, - {map[string][]byte{"key_2": []byte("value_2")}, []string{"key_1"}}, - {map[string][]byte{"key_3/a": []byte("value_3a")}, []string{"key_2"}}, - {map[string][]byte{"multikey_1": []byte("multivalue_1")}, []string{}}, - {map[string][]byte{"multikey_2": []byte("multivalue_2")}, []string{"multikey_1", "key_3/a"}}, - {make(map[string][]byte), []string{"multikey_2"}}, - } + for _, test := range multiSaveAndRemoveTests { + err = etcdKV.MultiSaveBytesAndRemove(test.multiSaves, test.multiRemoves) + s.NoError(err) + } - for _, test := range multiSaveAndRemoveTests { - err = etcdKV.MultiSaveBytesAndRemove(test.multiSaves, test.multiRemoves) - assert.NoError(t, err) - } + ks, vs, err = etcdKV.LoadBytesWithPrefix("") + s.NoError(err) + s.Empty(ks) + s.Empty(vs) +} - ks, vs, err = etcdKV.LoadBytesWithPrefix("") - assert.NoError(t, err) - assert.Empty(t, ks) - assert.Empty(t, vs) - }) +func (s *EtcdKVSuite) TestTxnWithPredicates() { + etcdKV := s.etcdKV - te.Run("etcdKV MultiRemoveWithPrefix", func(t *testing.T) { - rootPath := "/etcd/test/root/multi_remove_with_prefix" - etcdKV := etcdkv.NewEtcdKV(etcdCli, rootPath) - defer etcdKV.Close() - defer etcdKV.RemoveWithPrefix("") - - prepareTests := map[string]string{ - "x/abc/1": "1", - "x/abc/2": "2", - "x/def/1": "10", - "x/def/2": 
"20", - "x/den/1": "100", - "x/den/2": "200", - } + prepareKV := map[string]string{ + "lease1": "1", + "lease2": "2", + } - err = etcdKV.MultiSave(prepareTests) - require.NoError(t, err) - - multiRemoveWithPrefixTests := []struct { - prefix []string - - testKey string - expectedValue string - }{ - {[]string{"x/abc"}, "x/abc/1", ""}, - {[]string{}, "x/abc/2", ""}, - {[]string{}, "x/def/1", "10"}, - {[]string{}, "x/def/2", "20"}, - {[]string{}, "x/den/1", "100"}, - {[]string{}, "x/den/2", "200"}, - {[]string{}, "not-exist", ""}, - {[]string{"x/def", "x/den"}, "x/def/1", ""}, - {[]string{}, "x/def/1", ""}, - {[]string{}, "x/def/2", ""}, - {[]string{}, "x/den/1", ""}, - {[]string{}, "x/den/2", ""}, - {[]string{}, "not-exist", ""}, - } + err := etcdKV.MultiSave(prepareKV) + s.Require().NoError(err) + + badPredicate := predicates.NewMockPredicate(s.T()) + badPredicate.EXPECT().Type().Return(0) + badPredicate.EXPECT().Target().Return(predicates.PredTargetValue) + + multiSaveAndRemovePredTests := []struct { + tag string + multiSave map[string]string + preds []predicates.Predicate + expectSuccess bool + }{ + {"predicate_ok", map[string]string{"a": "b"}, []predicates.Predicate{predicates.ValueEqual("lease1", "1")}, true}, + {"predicate_fail", map[string]string{"a": "b"}, []predicates.Predicate{predicates.ValueEqual("lease1", "2")}, false}, + {"bad_predicate", map[string]string{"a": "b"}, []predicates.Predicate{badPredicate}, false}, + } - for _, test := range multiRemoveWithPrefixTests { - if len(test.prefix) > 0 { - err = etcdKV.MultiRemoveWithPrefix(test.prefix) - assert.NoError(t, err) + for _, test := range multiSaveAndRemovePredTests { + s.Run(test.tag, func() { + err := etcdKV.MultiSaveAndRemove(test.multiSave, nil, test.preds...) + if test.expectSuccess { + s.NoError(err) + } else { + s.Error(err) + } + err = etcdKV.MultiSaveAndRemoveWithPrefix(test.multiSave, nil, test.preds...) 
+ if test.expectSuccess { + s.NoError(err) + } else { + s.Error(err) } + }) + } +} - v, _ := etcdKV.Load(test.testKey) - assert.Equal(t, test.expectedValue, v) - } +func (s *EtcdKVSuite) TestMultiSaveAndRemoveWithPrefix() { + etcdKV := s.etcdKV - k, v, err := etcdKV.LoadWithPrefix("/") - assert.NoError(t, err) - assert.Zero(t, len(k)) - assert.Zero(t, len(v)) - - // MultiSaveAndRemoveWithPrefix - err = etcdKV.MultiSave(prepareTests) - require.NoError(t, err) - multiSaveAndRemoveWithPrefixTests := []struct { - multiSave map[string]string - prefix []string - - loadPrefix string - lengthBeforeRemove int - lengthAfterRemove int - }{ - {map[string]string{}, []string{"x/abc", "x/def", "x/den"}, "x", 6, 0}, - {map[string]string{"y/a": "vvv", "y/b": "vvv"}, []string{}, "y", 0, 2}, - {map[string]string{"y/c": "vvv"}, []string{}, "y", 2, 3}, - {map[string]string{"p/a": "vvv"}, []string{"y/a", "y"}, "y", 3, 0}, - {map[string]string{}, []string{"p"}, "p", 1, 0}, - } + prepareTests := map[string]string{ + "x/abc/1": "1", + "x/abc/2": "2", + "x/def/1": "10", + "x/def/2": "20", + "x/den/1": "100", + "x/den/2": "200", + } - for _, test := range multiSaveAndRemoveWithPrefixTests { - k, _, err = etcdKV.LoadWithPrefix(test.loadPrefix) - assert.NoError(t, err) - assert.Equal(t, test.lengthBeforeRemove, len(k)) + // MultiSaveAndRemoveWithPrefix + err := etcdKV.MultiSave(prepareTests) + s.Require().NoError(err) + multiSaveAndRemoveWithPrefixTests := []struct { + multiSave map[string]string + prefix []string + + loadPrefix string + lengthBeforeRemove int + lengthAfterRemove int + }{ + {map[string]string{}, []string{"x/abc", "x/def", "x/den"}, "x", 6, 0}, + {map[string]string{"y/a": "vvv", "y/b": "vvv"}, []string{}, "y", 0, 2}, + {map[string]string{"y/c": "vvv"}, []string{}, "y", 2, 3}, + {map[string]string{"p/a": "vvv"}, []string{"y/a", "y"}, "y", 3, 0}, + {map[string]string{}, []string{"p"}, "p", 1, 0}, + } - err = etcdKV.MultiSaveAndRemoveWithPrefix(test.multiSave, test.prefix) - 
assert.NoError(t, err) + for _, test := range multiSaveAndRemoveWithPrefixTests { + k, _, err := etcdKV.LoadWithPrefix(test.loadPrefix) + s.NoError(err) + s.Equal(test.lengthBeforeRemove, len(k)) - k, _, err = etcdKV.LoadWithPrefix(test.loadPrefix) - assert.NoError(t, err) - assert.Equal(t, test.lengthAfterRemove, len(k)) - } - }) + err = etcdKV.MultiSaveAndRemoveWithPrefix(test.multiSave, test.prefix) + s.NoError(err) - te.Run("etcdKV Watch", func(t *testing.T) { - rootPath := "/etcd/test/root/watch" - etcdKV := etcdkv.NewEtcdKV(etcdCli, rootPath) + k, _, err = etcdKV.LoadWithPrefix(test.loadPrefix) + s.NoError(err) + s.Equal(test.lengthAfterRemove, len(k)) + } +} - defer etcdKV.Close() - defer etcdKV.RemoveWithPrefix("") +func (s *EtcdKVSuite) TestWatch() { + etcdKV := s.etcdKV - ch := etcdKV.Watch("x") - resp := <-ch - assert.True(t, resp.Created) + ch := etcdKV.Watch("x") + resp := <-ch + s.True(resp.Created) - ch = etcdKV.WatchWithPrefix("x") - resp = <-ch - assert.True(t, resp.Created) - }) + ch = etcdKV.WatchWithPrefix("x") + resp = <-ch + s.True(resp.Created) +} - te.Run("Etcd Revision Bytes", func(t *testing.T) { - rootPath := "/etcd/test/root/revision_bytes" - etcdKV := etcdkv.NewEtcdKV(etcdCli, rootPath) - defer etcdKV.Close() - defer etcdKV.RemoveWithPrefix("") - - revisionTests := []struct { - inKey string - fistValue []byte - secondValue []byte - }{ - {"a", []byte("v1"), []byte("v11")}, - {"y", []byte("v2"), []byte("v22")}, - {"z", []byte("v3"), []byte("v33")}, - } +func (s *EtcdKVSuite) TestRevisionBytes() { + etcdKV := s.etcdKV + + revisionTests := []struct { + inKey string + fistValue []byte + secondValue []byte + }{ + {"a", []byte("v1"), []byte("v11")}, + {"y", []byte("v2"), []byte("v22")}, + {"z", []byte("v3"), []byte("v33")}, + } - for _, test := range revisionTests { - err = etcdKV.SaveBytes(test.inKey, test.fistValue) - require.NoError(t, err) + for _, test := range revisionTests { + err := etcdKV.SaveBytes(test.inKey, test.fistValue) + 
s.Require().NoError(err) - _, _, revision, _ := etcdKV.LoadBytesWithRevision(test.inKey) - ch := etcdKV.WatchWithRevision(test.inKey, revision+1) + _, _, revision, _ := etcdKV.LoadBytesWithRevision(test.inKey) + ch := etcdKV.WatchWithRevision(test.inKey, revision+1) - err = etcdKV.SaveBytes(test.inKey, test.secondValue) - require.NoError(t, err) + err = etcdKV.SaveBytes(test.inKey, test.secondValue) + s.Require().NoError(err) - resp := <-ch - assert.Equal(t, 1, len(resp.Events)) - assert.Equal(t, string(test.secondValue), string(resp.Events[0].Kv.Value)) - assert.Equal(t, revision+1, resp.Header.Revision) - } + resp := <-ch + s.Equal(1, len(resp.Events)) + s.Equal(string(test.secondValue), string(resp.Events[0].Kv.Value)) + s.Equal(revision+1, resp.Header.Revision) + } - success, err := etcdKV.CompareVersionAndSwapBytes("a/b/c", 0, []byte("1")) - assert.NoError(t, err) - assert.True(t, success) + success, err := etcdKV.CompareVersionAndSwapBytes("a/b/c", 0, []byte("1")) + s.NoError(err) + s.True(success) - value, err := etcdKV.LoadBytes("a/b/c") - assert.NoError(t, err) - assert.Equal(t, string(value), "1") + value, err := etcdKV.LoadBytes("a/b/c") + s.NoError(err) + s.Equal(string(value), "1") - success, err = etcdKV.CompareVersionAndSwapBytes("a/b/c", 0, []byte("1")) - assert.NoError(t, err) - assert.False(t, success) - }) + success, err = etcdKV.CompareVersionAndSwapBytes("a/b/c", 0, []byte("1")) + s.NoError(err) + s.False(success) +} + +func TestEtcdKV(t *testing.T) { + suite.Run(t, new(EtcdKVSuite)) } func Test_WalkWithPagination(t *testing.T) { @@ -712,7 +718,7 @@ func Test_WalkWithPagination(t *testing.T) { assert.NoError(t, err) rootPath := "/etcd/test/root/pagination" - etcdKV := etcdkv.NewEtcdKV(etcdCli, rootPath) + etcdKV := NewEtcdKV(etcdCli, rootPath) defer etcdKV.Close() defer etcdKV.RemoveWithPrefix("") @@ -786,42 +792,42 @@ func Test_WalkWithPagination(t *testing.T) { func TestElapse(t *testing.T) { start := time.Now() - isElapse := 
etcdkv.CheckElapseAndWarn(start, "err message") + isElapse := CheckElapseAndWarn(start, "err message") assert.Equal(t, isElapse, false) time.Sleep(2001 * time.Millisecond) - isElapse = etcdkv.CheckElapseAndWarn(start, "err message") + isElapse = CheckElapseAndWarn(start, "err message") assert.Equal(t, isElapse, true) } func TestCheckValueSizeAndWarn(t *testing.T) { - ret := etcdkv.CheckValueSizeAndWarn("k", "v") + ret := CheckValueSizeAndWarn("k", "v") assert.False(t, ret) v := make([]byte, 1024000) - ret = etcdkv.CheckValueSizeAndWarn("k", v) + ret = CheckValueSizeAndWarn("k", v) assert.True(t, ret) } func TestCheckTnxBytesValueSizeAndWarn(t *testing.T) { kvs := make(map[string][]byte, 0) kvs["k"] = []byte("v") - ret := etcdkv.CheckTnxBytesValueSizeAndWarn(kvs) + ret := CheckTnxBytesValueSizeAndWarn(kvs) assert.False(t, ret) kvs["k"] = make([]byte, 1024000) - ret = etcdkv.CheckTnxBytesValueSizeAndWarn(kvs) + ret = CheckTnxBytesValueSizeAndWarn(kvs) assert.True(t, ret) } func TestCheckTnxStringValueSizeAndWarn(t *testing.T) { kvs := make(map[string]string, 0) kvs["k"] = "v" - ret := etcdkv.CheckTnxStringValueSizeAndWarn(kvs) + ret := CheckTnxStringValueSizeAndWarn(kvs) assert.False(t, ret) kvs["k1"] = funcutil.RandomString(1024000) - ret = etcdkv.CheckTnxStringValueSizeAndWarn(kvs) + ret = CheckTnxStringValueSizeAndWarn(kvs) assert.True(t, ret) } @@ -837,7 +843,7 @@ func TestHas(t *testing.T) { defer etcdCli.Close() assert.NoError(t, err) rootPath := "/etcd/test/root/has" - kv := etcdkv.NewEtcdKV(etcdCli, rootPath) + kv := NewEtcdKV(etcdCli, rootPath) err = kv.RemoveWithPrefix("") require.NoError(t, err) @@ -875,7 +881,7 @@ func TestHasPrefix(t *testing.T) { defer etcdCli.Close() assert.NoError(t, err) rootPath := "/etcd/test/root/hasprefix" - kv := etcdkv.NewEtcdKV(etcdCli, rootPath) + kv := NewEtcdKV(etcdCli, rootPath) err = kv.RemoveWithPrefix("") require.NoError(t, err) diff --git a/internal/kv/etcd/util.go b/internal/kv/etcd/util.go new file mode 100644 index 
0000000000000..6363ddb5f9ad7 --- /dev/null +++ b/internal/kv/etcd/util.go @@ -0,0 +1,42 @@ +package etcdkv + +import ( + "fmt" + "path" + + clientv3 "go.etcd.io/etcd/client/v3" + + "github.com/milvus-io/milvus/internal/kv/predicates" + "github.com/milvus-io/milvus/pkg/util/merr" +) + +func parsePredicates(rootPath string, preds ...predicates.Predicate) ([]clientv3.Cmp, error) { + if len(preds) == 0 { + return []clientv3.Cmp{}, nil + } + result := make([]clientv3.Cmp, 0, len(preds)) + for _, pred := range preds { + switch pred.Target() { + case predicates.PredTargetValue: + pt, err := parsePredicateType(pred.Type()) + if err != nil { + return nil, err + } + cmp := clientv3.Compare(clientv3.Value(path.Join(rootPath, pred.Key())), pt, pred.TargetValue()) + result = append(result, cmp) + default: + return nil, merr.WrapErrParameterInvalid("valid predicate target", fmt.Sprintf("%d", pred.Target())) + } + } + return result, nil +} + +// parsePredicateType parse predicates.PredicateType to clientv3.Result +func parsePredicateType(pt predicates.PredicateType) (string, error) { + switch pt { + case predicates.PredTypeEqual: + return "=", nil + default: + return "", merr.WrapErrParameterInvalid("valid predicate type", fmt.Sprintf("%d", pt)) + } +} diff --git a/internal/kv/etcd/util_test.go b/internal/kv/etcd/util_test.go new file mode 100644 index 0000000000000..331f4845ae486 --- /dev/null +++ b/internal/kv/etcd/util_test.go @@ -0,0 +1,72 @@ +package etcdkv + +import ( + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus/internal/kv/predicates" +) + +type EtcdKVUtilSuite struct { + suite.Suite +} + +func (s *EtcdKVUtilSuite) TestParsePredicateType() { + type testCase struct { + tag string + pt predicates.PredicateType + expectResult string + expectSucceed bool + } + + cases := []testCase{ + {tag: "equal", pt: predicates.PredTypeEqual, expectResult: "=", expectSucceed: true}, + {tag: "zero_value", pt: 0, expectResult: "", expectSucceed: false}, 
+ } + + for _, tc := range cases { + s.Run(tc.tag, func() { + result, err := parsePredicateType(tc.pt) + if tc.expectSucceed { + s.NoError(err) + s.Equal(tc.expectResult, result) + } else { + s.Error(err) + } + }) + } +} + +func (s *EtcdKVUtilSuite) TestParsePredicates() { + type testCase struct { + tag string + input []predicates.Predicate + expectSucceed bool + } + + badPredicate := predicates.NewMockPredicate(s.T()) + badPredicate.EXPECT().Target().Return(0) + + cases := []testCase{ + {tag: "normal_value_equal", input: []predicates.Predicate{predicates.ValueEqual("a", "b")}, expectSucceed: true}, + {tag: "empty_input", input: nil, expectSucceed: true}, + {tag: "bad_predicates", input: []predicates.Predicate{badPredicate}, expectSucceed: false}, + } + + for _, tc := range cases { + s.Run(tc.tag, func() { + result, err := parsePredicates("", tc.input...) + if tc.expectSucceed { + s.NoError(err) + s.Equal(len(tc.input), len(result)) + } else { + s.Error(err) + } + }) + } +} + +func TestEtcdKVUtil(t *testing.T) { + suite.Run(t, new(EtcdKVUtilSuite)) +} diff --git a/internal/kv/kv.go b/internal/kv/kv.go index 53ae73c37877b..14091cdc1e842 100644 --- a/internal/kv/kv.go +++ b/internal/kv/kv.go @@ -19,6 +19,7 @@ package kv import ( clientv3 "go.etcd.io/etcd/client/v3" + "github.com/milvus-io/milvus/internal/kv/predicates" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -57,9 +58,8 @@ type BaseKV interface { //go:generate mockery --name=TxnKV --with-expecter type TxnKV interface { BaseKV - MultiSaveAndRemove(saves map[string]string, removals []string) error - MultiRemoveWithPrefix(keys []string) error - MultiSaveAndRemoveWithPrefix(saves map[string]string, removals []string) error + MultiSaveAndRemove(saves map[string]string, removals []string, preds ...predicates.Predicate) error + MultiSaveAndRemoveWithPrefix(saves map[string]string, removals []string, preds ...predicates.Predicate) error } // MetaKv is TxnKV for metadata. It should save data with lease. 
diff --git a/internal/kv/mem/mem_kv.go b/internal/kv/mem/mem_kv.go index ad9666233660c..d4309e879aec8 100644 --- a/internal/kv/mem/mem_kv.go +++ b/internal/kv/mem/mem_kv.go @@ -20,9 +20,10 @@ import ( "strings" "sync" - "github.com/cockroachdb/errors" "github.com/google/btree" - "github.com/milvus-io/milvus/pkg/common" + + "github.com/milvus-io/milvus/internal/kv/predicates" + "github.com/milvus-io/milvus/pkg/util/merr" ) // MemoryKV implements BaseKv interface and relies on underling btree.BTree. @@ -82,7 +83,7 @@ func (kv *MemoryKV) Load(key string) (string, error) { defer kv.RUnlock() item := kv.tree.Get(memoryKVItem{key: key}) if item == nil { - return "", common.NewKeyNotExistError(key) + return "", merr.WrapErrIoKeyNotFound(key) } return item.(memoryKVItem).value.String(), nil } @@ -93,7 +94,7 @@ func (kv *MemoryKV) LoadBytes(key string) ([]byte, error) { defer kv.RUnlock() item := kv.tree.Get(memoryKVItem{key: key}) if item == nil { - return []byte{}, common.NewKeyNotExistError(key) + return nil, merr.WrapErrIoKeyNotFound(key) } return item.(memoryKVItem).value.ByteSlice(), nil } @@ -217,7 +218,10 @@ func (kv *MemoryKV) MultiRemove(keys []string) error { } // MultiSaveAndRemove saves and removes given key-value pairs in MemoryKV atomicly. 
-func (kv *MemoryKV) MultiSaveAndRemove(saves map[string]string, removals []string) error { +func (kv *MemoryKV) MultiSaveAndRemove(saves map[string]string, removals []string, preds ...predicates.Predicate) error { + if len(preds) > 0 { + return merr.WrapErrServiceUnavailable("predicates not supported") + } kv.Lock() defer kv.Unlock() for key, value := range saves { @@ -282,13 +286,11 @@ func (kv *MemoryKV) LoadBytesWithPrefix(key string) ([]string, [][]byte, error) func (kv *MemoryKV) Close() { } -// MultiRemoveWithPrefix not implemented -func (kv *MemoryKV) MultiRemoveWithPrefix(keys []string) error { - return errors.New("not implement") -} - // MultiSaveAndRemoveWithPrefix saves key-value pairs in @saves, & remove key with prefix in @removals in MemoryKV atomically. -func (kv *MemoryKV) MultiSaveAndRemoveWithPrefix(saves map[string]string, removals []string) error { +func (kv *MemoryKV) MultiSaveAndRemoveWithPrefix(saves map[string]string, removals []string, preds ...predicates.Predicate) error { + if len(preds) > 0 { + return merr.WrapErrServiceUnavailable("predicates not supported") + } kv.Lock() defer kv.Unlock() diff --git a/internal/kv/mem/mem_kv_test.go b/internal/kv/mem/mem_kv_test.go index 79f1651c68b4b..76e8896827f7a 100644 --- a/internal/kv/mem/mem_kv_test.go +++ b/internal/kv/mem/mem_kv_test.go @@ -20,6 +20,9 @@ import ( "testing" "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/internal/kv/predicates" + "github.com/milvus-io/milvus/pkg/util/merr" ) func TestMemoryKV_SaveAndLoadBytes(t *testing.T) { @@ -242,3 +245,16 @@ func TestHasPrefix(t *testing.T) { assert.NoError(t, err) assert.False(t, has) } + +func TestPredicates(t *testing.T) { + kv := NewMemoryKV() + + // predicates not supported for mem kv for now + err := kv.MultiSaveAndRemove(map[string]string{}, []string{}, predicates.ValueEqual("a", "b")) + assert.Error(t, err) + assert.ErrorIs(t, err, merr.ErrServiceUnavailable) + + err = 
kv.MultiSaveAndRemoveWithPrefix(map[string]string{}, []string{}, predicates.ValueEqual("a", "b")) + assert.Error(t, err) + assert.ErrorIs(t, err, merr.ErrServiceUnavailable) +} diff --git a/internal/kv/mocks/meta_kv.go b/internal/kv/mocks/meta_kv.go index b80537e9c7008..5a615ff5250dd 100644 --- a/internal/kv/mocks/meta_kv.go +++ b/internal/kv/mocks/meta_kv.go @@ -2,7 +2,10 @@ package mocks -import mock "github.com/stretchr/testify/mock" +import ( + predicates "github.com/milvus-io/milvus/internal/kv/predicates" + mock "github.com/stretchr/testify/mock" +) // MetaKv is an autogenerated mock type for the MetaKv type type MetaKv struct { @@ -460,48 +463,6 @@ func (_c *MetaKv_MultiRemove_Call) RunAndReturn(run func([]string) error) *MetaK return _c } -// MultiRemoveWithPrefix provides a mock function with given fields: keys -func (_m *MetaKv) MultiRemoveWithPrefix(keys []string) error { - ret := _m.Called(keys) - - var r0 error - if rf, ok := ret.Get(0).(func([]string) error); ok { - r0 = rf(keys) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// MetaKv_MultiRemoveWithPrefix_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'MultiRemoveWithPrefix' -type MetaKv_MultiRemoveWithPrefix_Call struct { - *mock.Call -} - -// MultiRemoveWithPrefix is a helper method to define mock.On call -// - keys []string -func (_e *MetaKv_Expecter) MultiRemoveWithPrefix(keys interface{}) *MetaKv_MultiRemoveWithPrefix_Call { - return &MetaKv_MultiRemoveWithPrefix_Call{Call: _e.mock.On("MultiRemoveWithPrefix", keys)} -} - -func (_c *MetaKv_MultiRemoveWithPrefix_Call) Run(run func(keys []string)) *MetaKv_MultiRemoveWithPrefix_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].([]string)) - }) - return _c -} - -func (_c *MetaKv_MultiRemoveWithPrefix_Call) Return(_a0 error) *MetaKv_MultiRemoveWithPrefix_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *MetaKv_MultiRemoveWithPrefix_Call) RunAndReturn(run func([]string) error) 
*MetaKv_MultiRemoveWithPrefix_Call { - _c.Call.Return(run) - return _c -} - // MultiSave provides a mock function with given fields: kvs func (_m *MetaKv) MultiSave(kvs map[string]string) error { ret := _m.Called(kvs) @@ -544,13 +505,20 @@ func (_c *MetaKv_MultiSave_Call) RunAndReturn(run func(map[string]string) error) return _c } -// MultiSaveAndRemove provides a mock function with given fields: saves, removals -func (_m *MetaKv) MultiSaveAndRemove(saves map[string]string, removals []string) error { - ret := _m.Called(saves, removals) +// MultiSaveAndRemove provides a mock function with given fields: saves, removals, preds +func (_m *MetaKv) MultiSaveAndRemove(saves map[string]string, removals []string, preds ...predicates.Predicate) error { + _va := make([]interface{}, len(preds)) + for _i := range preds { + _va[_i] = preds[_i] + } + var _ca []interface{} + _ca = append(_ca, saves, removals) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) var r0 error - if rf, ok := ret.Get(0).(func(map[string]string, []string) error); ok { - r0 = rf(saves, removals) + if rf, ok := ret.Get(0).(func(map[string]string, []string, ...predicates.Predicate) error); ok { + r0 = rf(saves, removals, preds...) 
} else { r0 = ret.Error(0) } @@ -566,13 +534,21 @@ type MetaKv_MultiSaveAndRemove_Call struct { // MultiSaveAndRemove is a helper method to define mock.On call // - saves map[string]string // - removals []string -func (_e *MetaKv_Expecter) MultiSaveAndRemove(saves interface{}, removals interface{}) *MetaKv_MultiSaveAndRemove_Call { - return &MetaKv_MultiSaveAndRemove_Call{Call: _e.mock.On("MultiSaveAndRemove", saves, removals)} +// - preds ...predicates.Predicate +func (_e *MetaKv_Expecter) MultiSaveAndRemove(saves interface{}, removals interface{}, preds ...interface{}) *MetaKv_MultiSaveAndRemove_Call { + return &MetaKv_MultiSaveAndRemove_Call{Call: _e.mock.On("MultiSaveAndRemove", + append([]interface{}{saves, removals}, preds...)...)} } -func (_c *MetaKv_MultiSaveAndRemove_Call) Run(run func(saves map[string]string, removals []string)) *MetaKv_MultiSaveAndRemove_Call { +func (_c *MetaKv_MultiSaveAndRemove_Call) Run(run func(saves map[string]string, removals []string, preds ...predicates.Predicate)) *MetaKv_MultiSaveAndRemove_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(map[string]string), args[1].([]string)) + variadicArgs := make([]predicates.Predicate, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(predicates.Predicate) + } + } + run(args[0].(map[string]string), args[1].([]string), variadicArgs...) 
}) return _c } @@ -582,18 +558,25 @@ func (_c *MetaKv_MultiSaveAndRemove_Call) Return(_a0 error) *MetaKv_MultiSaveAnd return _c } -func (_c *MetaKv_MultiSaveAndRemove_Call) RunAndReturn(run func(map[string]string, []string) error) *MetaKv_MultiSaveAndRemove_Call { +func (_c *MetaKv_MultiSaveAndRemove_Call) RunAndReturn(run func(map[string]string, []string, ...predicates.Predicate) error) *MetaKv_MultiSaveAndRemove_Call { _c.Call.Return(run) return _c } -// MultiSaveAndRemoveWithPrefix provides a mock function with given fields: saves, removals -func (_m *MetaKv) MultiSaveAndRemoveWithPrefix(saves map[string]string, removals []string) error { - ret := _m.Called(saves, removals) +// MultiSaveAndRemoveWithPrefix provides a mock function with given fields: saves, removals, preds +func (_m *MetaKv) MultiSaveAndRemoveWithPrefix(saves map[string]string, removals []string, preds ...predicates.Predicate) error { + _va := make([]interface{}, len(preds)) + for _i := range preds { + _va[_i] = preds[_i] + } + var _ca []interface{} + _ca = append(_ca, saves, removals) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) var r0 error - if rf, ok := ret.Get(0).(func(map[string]string, []string) error); ok { - r0 = rf(saves, removals) + if rf, ok := ret.Get(0).(func(map[string]string, []string, ...predicates.Predicate) error); ok { + r0 = rf(saves, removals, preds...) 
} else { r0 = ret.Error(0) } @@ -609,13 +592,21 @@ type MetaKv_MultiSaveAndRemoveWithPrefix_Call struct { // MultiSaveAndRemoveWithPrefix is a helper method to define mock.On call // - saves map[string]string // - removals []string -func (_e *MetaKv_Expecter) MultiSaveAndRemoveWithPrefix(saves interface{}, removals interface{}) *MetaKv_MultiSaveAndRemoveWithPrefix_Call { - return &MetaKv_MultiSaveAndRemoveWithPrefix_Call{Call: _e.mock.On("MultiSaveAndRemoveWithPrefix", saves, removals)} +// - preds ...predicates.Predicate +func (_e *MetaKv_Expecter) MultiSaveAndRemoveWithPrefix(saves interface{}, removals interface{}, preds ...interface{}) *MetaKv_MultiSaveAndRemoveWithPrefix_Call { + return &MetaKv_MultiSaveAndRemoveWithPrefix_Call{Call: _e.mock.On("MultiSaveAndRemoveWithPrefix", + append([]interface{}{saves, removals}, preds...)...)} } -func (_c *MetaKv_MultiSaveAndRemoveWithPrefix_Call) Run(run func(saves map[string]string, removals []string)) *MetaKv_MultiSaveAndRemoveWithPrefix_Call { +func (_c *MetaKv_MultiSaveAndRemoveWithPrefix_Call) Run(run func(saves map[string]string, removals []string, preds ...predicates.Predicate)) *MetaKv_MultiSaveAndRemoveWithPrefix_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(map[string]string), args[1].([]string)) + variadicArgs := make([]predicates.Predicate, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(predicates.Predicate) + } + } + run(args[0].(map[string]string), args[1].([]string), variadicArgs...) 
}) return _c } @@ -625,7 +616,7 @@ func (_c *MetaKv_MultiSaveAndRemoveWithPrefix_Call) Return(_a0 error) *MetaKv_Mu return _c } -func (_c *MetaKv_MultiSaveAndRemoveWithPrefix_Call) RunAndReturn(run func(map[string]string, []string) error) *MetaKv_MultiSaveAndRemoveWithPrefix_Call { +func (_c *MetaKv_MultiSaveAndRemoveWithPrefix_Call) RunAndReturn(run func(map[string]string, []string, ...predicates.Predicate) error) *MetaKv_MultiSaveAndRemoveWithPrefix_Call { _c.Call.Return(run) return _c } diff --git a/internal/kv/mocks/txn_kv.go b/internal/kv/mocks/txn_kv.go index 6ce42c2e4fd62..25bbb438ff95b 100644 --- a/internal/kv/mocks/txn_kv.go +++ b/internal/kv/mocks/txn_kv.go @@ -2,7 +2,10 @@ package mocks -import mock "github.com/stretchr/testify/mock" +import ( + predicates "github.com/milvus-io/milvus/internal/kv/predicates" + mock "github.com/stretchr/testify/mock" +) // TxnKV is an autogenerated mock type for the TxnKV type type TxnKV struct { @@ -364,48 +367,6 @@ func (_c *TxnKV_MultiRemove_Call) RunAndReturn(run func([]string) error) *TxnKV_ return _c } -// MultiRemoveWithPrefix provides a mock function with given fields: keys -func (_m *TxnKV) MultiRemoveWithPrefix(keys []string) error { - ret := _m.Called(keys) - - var r0 error - if rf, ok := ret.Get(0).(func([]string) error); ok { - r0 = rf(keys) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// TxnKV_MultiRemoveWithPrefix_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'MultiRemoveWithPrefix' -type TxnKV_MultiRemoveWithPrefix_Call struct { - *mock.Call -} - -// MultiRemoveWithPrefix is a helper method to define mock.On call -// - keys []string -func (_e *TxnKV_Expecter) MultiRemoveWithPrefix(keys interface{}) *TxnKV_MultiRemoveWithPrefix_Call { - return &TxnKV_MultiRemoveWithPrefix_Call{Call: _e.mock.On("MultiRemoveWithPrefix", keys)} -} - -func (_c *TxnKV_MultiRemoveWithPrefix_Call) Run(run func(keys []string)) *TxnKV_MultiRemoveWithPrefix_Call { - 
_c.Call.Run(func(args mock.Arguments) { - run(args[0].([]string)) - }) - return _c -} - -func (_c *TxnKV_MultiRemoveWithPrefix_Call) Return(_a0 error) *TxnKV_MultiRemoveWithPrefix_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *TxnKV_MultiRemoveWithPrefix_Call) RunAndReturn(run func([]string) error) *TxnKV_MultiRemoveWithPrefix_Call { - _c.Call.Return(run) - return _c -} - // MultiSave provides a mock function with given fields: kvs func (_m *TxnKV) MultiSave(kvs map[string]string) error { ret := _m.Called(kvs) @@ -448,13 +409,20 @@ func (_c *TxnKV_MultiSave_Call) RunAndReturn(run func(map[string]string) error) return _c } -// MultiSaveAndRemove provides a mock function with given fields: saves, removals -func (_m *TxnKV) MultiSaveAndRemove(saves map[string]string, removals []string) error { - ret := _m.Called(saves, removals) +// MultiSaveAndRemove provides a mock function with given fields: saves, removals, preds +func (_m *TxnKV) MultiSaveAndRemove(saves map[string]string, removals []string, preds ...predicates.Predicate) error { + _va := make([]interface{}, len(preds)) + for _i := range preds { + _va[_i] = preds[_i] + } + var _ca []interface{} + _ca = append(_ca, saves, removals) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) var r0 error - if rf, ok := ret.Get(0).(func(map[string]string, []string) error); ok { - r0 = rf(saves, removals) + if rf, ok := ret.Get(0).(func(map[string]string, []string, ...predicates.Predicate) error); ok { + r0 = rf(saves, removals, preds...) 
} else { r0 = ret.Error(0) } @@ -470,13 +438,21 @@ type TxnKV_MultiSaveAndRemove_Call struct { // MultiSaveAndRemove is a helper method to define mock.On call // - saves map[string]string // - removals []string -func (_e *TxnKV_Expecter) MultiSaveAndRemove(saves interface{}, removals interface{}) *TxnKV_MultiSaveAndRemove_Call { - return &TxnKV_MultiSaveAndRemove_Call{Call: _e.mock.On("MultiSaveAndRemove", saves, removals)} +// - preds ...predicates.Predicate +func (_e *TxnKV_Expecter) MultiSaveAndRemove(saves interface{}, removals interface{}, preds ...interface{}) *TxnKV_MultiSaveAndRemove_Call { + return &TxnKV_MultiSaveAndRemove_Call{Call: _e.mock.On("MultiSaveAndRemove", + append([]interface{}{saves, removals}, preds...)...)} } -func (_c *TxnKV_MultiSaveAndRemove_Call) Run(run func(saves map[string]string, removals []string)) *TxnKV_MultiSaveAndRemove_Call { +func (_c *TxnKV_MultiSaveAndRemove_Call) Run(run func(saves map[string]string, removals []string, preds ...predicates.Predicate)) *TxnKV_MultiSaveAndRemove_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(map[string]string), args[1].([]string)) + variadicArgs := make([]predicates.Predicate, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(predicates.Predicate) + } + } + run(args[0].(map[string]string), args[1].([]string), variadicArgs...) 
}) return _c } @@ -486,18 +462,25 @@ func (_c *TxnKV_MultiSaveAndRemove_Call) Return(_a0 error) *TxnKV_MultiSaveAndRe return _c } -func (_c *TxnKV_MultiSaveAndRemove_Call) RunAndReturn(run func(map[string]string, []string) error) *TxnKV_MultiSaveAndRemove_Call { +func (_c *TxnKV_MultiSaveAndRemove_Call) RunAndReturn(run func(map[string]string, []string, ...predicates.Predicate) error) *TxnKV_MultiSaveAndRemove_Call { _c.Call.Return(run) return _c } -// MultiSaveAndRemoveWithPrefix provides a mock function with given fields: saves, removals -func (_m *TxnKV) MultiSaveAndRemoveWithPrefix(saves map[string]string, removals []string) error { - ret := _m.Called(saves, removals) +// MultiSaveAndRemoveWithPrefix provides a mock function with given fields: saves, removals, preds +func (_m *TxnKV) MultiSaveAndRemoveWithPrefix(saves map[string]string, removals []string, preds ...predicates.Predicate) error { + _va := make([]interface{}, len(preds)) + for _i := range preds { + _va[_i] = preds[_i] + } + var _ca []interface{} + _ca = append(_ca, saves, removals) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) var r0 error - if rf, ok := ret.Get(0).(func(map[string]string, []string) error); ok { - r0 = rf(saves, removals) + if rf, ok := ret.Get(0).(func(map[string]string, []string, ...predicates.Predicate) error); ok { + r0 = rf(saves, removals, preds...) 
} else { r0 = ret.Error(0) } @@ -513,13 +496,21 @@ type TxnKV_MultiSaveAndRemoveWithPrefix_Call struct { // MultiSaveAndRemoveWithPrefix is a helper method to define mock.On call // - saves map[string]string // - removals []string -func (_e *TxnKV_Expecter) MultiSaveAndRemoveWithPrefix(saves interface{}, removals interface{}) *TxnKV_MultiSaveAndRemoveWithPrefix_Call { - return &TxnKV_MultiSaveAndRemoveWithPrefix_Call{Call: _e.mock.On("MultiSaveAndRemoveWithPrefix", saves, removals)} +// - preds ...predicates.Predicate +func (_e *TxnKV_Expecter) MultiSaveAndRemoveWithPrefix(saves interface{}, removals interface{}, preds ...interface{}) *TxnKV_MultiSaveAndRemoveWithPrefix_Call { + return &TxnKV_MultiSaveAndRemoveWithPrefix_Call{Call: _e.mock.On("MultiSaveAndRemoveWithPrefix", + append([]interface{}{saves, removals}, preds...)...)} } -func (_c *TxnKV_MultiSaveAndRemoveWithPrefix_Call) Run(run func(saves map[string]string, removals []string)) *TxnKV_MultiSaveAndRemoveWithPrefix_Call { +func (_c *TxnKV_MultiSaveAndRemoveWithPrefix_Call) Run(run func(saves map[string]string, removals []string, preds ...predicates.Predicate)) *TxnKV_MultiSaveAndRemoveWithPrefix_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(map[string]string), args[1].([]string)) + variadicArgs := make([]predicates.Predicate, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(predicates.Predicate) + } + } + run(args[0].(map[string]string), args[1].([]string), variadicArgs...) 
}) return _c } @@ -529,7 +520,7 @@ func (_c *TxnKV_MultiSaveAndRemoveWithPrefix_Call) Return(_a0 error) *TxnKV_Mult return _c } -func (_c *TxnKV_MultiSaveAndRemoveWithPrefix_Call) RunAndReturn(run func(map[string]string, []string) error) *TxnKV_MultiSaveAndRemoveWithPrefix_Call { +func (_c *TxnKV_MultiSaveAndRemoveWithPrefix_Call) RunAndReturn(run func(map[string]string, []string, ...predicates.Predicate) error) *TxnKV_MultiSaveAndRemoveWithPrefix_Call { _c.Call.Return(run) return _c } diff --git a/internal/kv/mocks/watch_kv.go b/internal/kv/mocks/watch_kv.go index a133029bedb93..c49ff4a924c7e 100644 --- a/internal/kv/mocks/watch_kv.go +++ b/internal/kv/mocks/watch_kv.go @@ -6,6 +6,8 @@ import ( clientv3 "go.etcd.io/etcd/client/v3" mock "github.com/stretchr/testify/mock" + + predicates "github.com/milvus-io/milvus/internal/kv/predicates" ) // WatchKV is an autogenerated mock type for the WatchKV type @@ -464,48 +466,6 @@ func (_c *WatchKV_MultiRemove_Call) RunAndReturn(run func([]string) error) *Watc return _c } -// MultiRemoveWithPrefix provides a mock function with given fields: keys -func (_m *WatchKV) MultiRemoveWithPrefix(keys []string) error { - ret := _m.Called(keys) - - var r0 error - if rf, ok := ret.Get(0).(func([]string) error); ok { - r0 = rf(keys) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// WatchKV_MultiRemoveWithPrefix_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'MultiRemoveWithPrefix' -type WatchKV_MultiRemoveWithPrefix_Call struct { - *mock.Call -} - -// MultiRemoveWithPrefix is a helper method to define mock.On call -// - keys []string -func (_e *WatchKV_Expecter) MultiRemoveWithPrefix(keys interface{}) *WatchKV_MultiRemoveWithPrefix_Call { - return &WatchKV_MultiRemoveWithPrefix_Call{Call: _e.mock.On("MultiRemoveWithPrefix", keys)} -} - -func (_c *WatchKV_MultiRemoveWithPrefix_Call) Run(run func(keys []string)) *WatchKV_MultiRemoveWithPrefix_Call { - _c.Call.Run(func(args mock.Arguments) 
{ - run(args[0].([]string)) - }) - return _c -} - -func (_c *WatchKV_MultiRemoveWithPrefix_Call) Return(_a0 error) *WatchKV_MultiRemoveWithPrefix_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *WatchKV_MultiRemoveWithPrefix_Call) RunAndReturn(run func([]string) error) *WatchKV_MultiRemoveWithPrefix_Call { - _c.Call.Return(run) - return _c -} - // MultiSave provides a mock function with given fields: kvs func (_m *WatchKV) MultiSave(kvs map[string]string) error { ret := _m.Called(kvs) @@ -548,13 +508,20 @@ func (_c *WatchKV_MultiSave_Call) RunAndReturn(run func(map[string]string) error return _c } -// MultiSaveAndRemove provides a mock function with given fields: saves, removals -func (_m *WatchKV) MultiSaveAndRemove(saves map[string]string, removals []string) error { - ret := _m.Called(saves, removals) +// MultiSaveAndRemove provides a mock function with given fields: saves, removals, preds +func (_m *WatchKV) MultiSaveAndRemove(saves map[string]string, removals []string, preds ...predicates.Predicate) error { + _va := make([]interface{}, len(preds)) + for _i := range preds { + _va[_i] = preds[_i] + } + var _ca []interface{} + _ca = append(_ca, saves, removals) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) var r0 error - if rf, ok := ret.Get(0).(func(map[string]string, []string) error); ok { - r0 = rf(saves, removals) + if rf, ok := ret.Get(0).(func(map[string]string, []string, ...predicates.Predicate) error); ok { + r0 = rf(saves, removals, preds...) 
} else { r0 = ret.Error(0) } @@ -570,13 +537,21 @@ type WatchKV_MultiSaveAndRemove_Call struct { // MultiSaveAndRemove is a helper method to define mock.On call // - saves map[string]string // - removals []string -func (_e *WatchKV_Expecter) MultiSaveAndRemove(saves interface{}, removals interface{}) *WatchKV_MultiSaveAndRemove_Call { - return &WatchKV_MultiSaveAndRemove_Call{Call: _e.mock.On("MultiSaveAndRemove", saves, removals)} +// - preds ...predicates.Predicate +func (_e *WatchKV_Expecter) MultiSaveAndRemove(saves interface{}, removals interface{}, preds ...interface{}) *WatchKV_MultiSaveAndRemove_Call { + return &WatchKV_MultiSaveAndRemove_Call{Call: _e.mock.On("MultiSaveAndRemove", + append([]interface{}{saves, removals}, preds...)...)} } -func (_c *WatchKV_MultiSaveAndRemove_Call) Run(run func(saves map[string]string, removals []string)) *WatchKV_MultiSaveAndRemove_Call { +func (_c *WatchKV_MultiSaveAndRemove_Call) Run(run func(saves map[string]string, removals []string, preds ...predicates.Predicate)) *WatchKV_MultiSaveAndRemove_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(map[string]string), args[1].([]string)) + variadicArgs := make([]predicates.Predicate, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(predicates.Predicate) + } + } + run(args[0].(map[string]string), args[1].([]string), variadicArgs...) 
}) return _c } @@ -586,18 +561,25 @@ func (_c *WatchKV_MultiSaveAndRemove_Call) Return(_a0 error) *WatchKV_MultiSaveA return _c } -func (_c *WatchKV_MultiSaveAndRemove_Call) RunAndReturn(run func(map[string]string, []string) error) *WatchKV_MultiSaveAndRemove_Call { +func (_c *WatchKV_MultiSaveAndRemove_Call) RunAndReturn(run func(map[string]string, []string, ...predicates.Predicate) error) *WatchKV_MultiSaveAndRemove_Call { _c.Call.Return(run) return _c } -// MultiSaveAndRemoveWithPrefix provides a mock function with given fields: saves, removals -func (_m *WatchKV) MultiSaveAndRemoveWithPrefix(saves map[string]string, removals []string) error { - ret := _m.Called(saves, removals) +// MultiSaveAndRemoveWithPrefix provides a mock function with given fields: saves, removals, preds +func (_m *WatchKV) MultiSaveAndRemoveWithPrefix(saves map[string]string, removals []string, preds ...predicates.Predicate) error { + _va := make([]interface{}, len(preds)) + for _i := range preds { + _va[_i] = preds[_i] + } + var _ca []interface{} + _ca = append(_ca, saves, removals) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) var r0 error - if rf, ok := ret.Get(0).(func(map[string]string, []string) error); ok { - r0 = rf(saves, removals) + if rf, ok := ret.Get(0).(func(map[string]string, []string, ...predicates.Predicate) error); ok { + r0 = rf(saves, removals, preds...) 
} else { r0 = ret.Error(0) } @@ -613,13 +595,21 @@ type WatchKV_MultiSaveAndRemoveWithPrefix_Call struct { // MultiSaveAndRemoveWithPrefix is a helper method to define mock.On call // - saves map[string]string // - removals []string -func (_e *WatchKV_Expecter) MultiSaveAndRemoveWithPrefix(saves interface{}, removals interface{}) *WatchKV_MultiSaveAndRemoveWithPrefix_Call { - return &WatchKV_MultiSaveAndRemoveWithPrefix_Call{Call: _e.mock.On("MultiSaveAndRemoveWithPrefix", saves, removals)} +// - preds ...predicates.Predicate +func (_e *WatchKV_Expecter) MultiSaveAndRemoveWithPrefix(saves interface{}, removals interface{}, preds ...interface{}) *WatchKV_MultiSaveAndRemoveWithPrefix_Call { + return &WatchKV_MultiSaveAndRemoveWithPrefix_Call{Call: _e.mock.On("MultiSaveAndRemoveWithPrefix", + append([]interface{}{saves, removals}, preds...)...)} } -func (_c *WatchKV_MultiSaveAndRemoveWithPrefix_Call) Run(run func(saves map[string]string, removals []string)) *WatchKV_MultiSaveAndRemoveWithPrefix_Call { +func (_c *WatchKV_MultiSaveAndRemoveWithPrefix_Call) Run(run func(saves map[string]string, removals []string, preds ...predicates.Predicate)) *WatchKV_MultiSaveAndRemoveWithPrefix_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(map[string]string), args[1].([]string)) + variadicArgs := make([]predicates.Predicate, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(predicates.Predicate) + } + } + run(args[0].(map[string]string), args[1].([]string), variadicArgs...) 
}) return _c } @@ -629,7 +619,7 @@ func (_c *WatchKV_MultiSaveAndRemoveWithPrefix_Call) Return(_a0 error) *WatchKV_ return _c } -func (_c *WatchKV_MultiSaveAndRemoveWithPrefix_Call) RunAndReturn(run func(map[string]string, []string) error) *WatchKV_MultiSaveAndRemoveWithPrefix_Call { +func (_c *WatchKV_MultiSaveAndRemoveWithPrefix_Call) RunAndReturn(run func(map[string]string, []string, ...predicates.Predicate) error) *WatchKV_MultiSaveAndRemoveWithPrefix_Call { _c.Call.Return(run) return _c } diff --git a/internal/kv/predicates/mock_predicate.go b/internal/kv/predicates/mock_predicate.go new file mode 100644 index 0000000000000..1183a47881fc6 --- /dev/null +++ b/internal/kv/predicates/mock_predicate.go @@ -0,0 +1,240 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package predicates + +import mock "github.com/stretchr/testify/mock" + +// MockPredicate is an autogenerated mock type for the Predicate type +type MockPredicate struct { + mock.Mock +} + +type MockPredicate_Expecter struct { + mock *mock.Mock +} + +func (_m *MockPredicate) EXPECT() *MockPredicate_Expecter { + return &MockPredicate_Expecter{mock: &_m.Mock} +} + +// IsTrue provides a mock function with given fields: _a0 +func (_m *MockPredicate) IsTrue(_a0 interface{}) bool { + ret := _m.Called(_a0) + + var r0 bool + if rf, ok := ret.Get(0).(func(interface{}) bool); ok { + r0 = rf(_a0) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// MockPredicate_IsTrue_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsTrue' +type MockPredicate_IsTrue_Call struct { + *mock.Call +} + +// IsTrue is a helper method to define mock.On call +// - _a0 interface{} +func (_e *MockPredicate_Expecter) IsTrue(_a0 interface{}) *MockPredicate_IsTrue_Call { + return &MockPredicate_IsTrue_Call{Call: _e.mock.On("IsTrue", _a0)} +} + +func (_c *MockPredicate_IsTrue_Call) Run(run func(_a0 interface{})) *MockPredicate_IsTrue_Call { + _c.Call.Run(func(args mock.Arguments) { + 
run(args[0].(interface{})) + }) + return _c +} + +func (_c *MockPredicate_IsTrue_Call) Return(_a0 bool) *MockPredicate_IsTrue_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockPredicate_IsTrue_Call) RunAndReturn(run func(interface{}) bool) *MockPredicate_IsTrue_Call { + _c.Call.Return(run) + return _c +} + +// Key provides a mock function with given fields: +func (_m *MockPredicate) Key() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// MockPredicate_Key_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Key' +type MockPredicate_Key_Call struct { + *mock.Call +} + +// Key is a helper method to define mock.On call +func (_e *MockPredicate_Expecter) Key() *MockPredicate_Key_Call { + return &MockPredicate_Key_Call{Call: _e.mock.On("Key")} +} + +func (_c *MockPredicate_Key_Call) Run(run func()) *MockPredicate_Key_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockPredicate_Key_Call) Return(_a0 string) *MockPredicate_Key_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockPredicate_Key_Call) RunAndReturn(run func() string) *MockPredicate_Key_Call { + _c.Call.Return(run) + return _c +} + +// Target provides a mock function with given fields: +func (_m *MockPredicate) Target() PredicateTarget { + ret := _m.Called() + + var r0 PredicateTarget + if rf, ok := ret.Get(0).(func() PredicateTarget); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(PredicateTarget) + } + + return r0 +} + +// MockPredicate_Target_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Target' +type MockPredicate_Target_Call struct { + *mock.Call +} + +// Target is a helper method to define mock.On call +func (_e *MockPredicate_Expecter) Target() *MockPredicate_Target_Call { + return &MockPredicate_Target_Call{Call: _e.mock.On("Target")} +} 
+ +func (_c *MockPredicate_Target_Call) Run(run func()) *MockPredicate_Target_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockPredicate_Target_Call) Return(_a0 PredicateTarget) *MockPredicate_Target_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockPredicate_Target_Call) RunAndReturn(run func() PredicateTarget) *MockPredicate_Target_Call { + _c.Call.Return(run) + return _c +} + +// TargetValue provides a mock function with given fields: +func (_m *MockPredicate) TargetValue() interface{} { + ret := _m.Called() + + var r0 interface{} + if rf, ok := ret.Get(0).(func() interface{}); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(interface{}) + } + } + + return r0 +} + +// MockPredicate_TargetValue_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TargetValue' +type MockPredicate_TargetValue_Call struct { + *mock.Call +} + +// TargetValue is a helper method to define mock.On call +func (_e *MockPredicate_Expecter) TargetValue() *MockPredicate_TargetValue_Call { + return &MockPredicate_TargetValue_Call{Call: _e.mock.On("TargetValue")} +} + +func (_c *MockPredicate_TargetValue_Call) Run(run func()) *MockPredicate_TargetValue_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockPredicate_TargetValue_Call) Return(_a0 interface{}) *MockPredicate_TargetValue_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockPredicate_TargetValue_Call) RunAndReturn(run func() interface{}) *MockPredicate_TargetValue_Call { + _c.Call.Return(run) + return _c +} + +// Type provides a mock function with given fields: +func (_m *MockPredicate) Type() PredicateType { + ret := _m.Called() + + var r0 PredicateType + if rf, ok := ret.Get(0).(func() PredicateType); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(PredicateType) + } + + return r0 +} + +// MockPredicate_Type_Call is a *mock.Call that shadows Run/Return methods with 
type explicit version for method 'Type' +type MockPredicate_Type_Call struct { + *mock.Call +} + +// Type is a helper method to define mock.On call +func (_e *MockPredicate_Expecter) Type() *MockPredicate_Type_Call { + return &MockPredicate_Type_Call{Call: _e.mock.On("Type")} +} + +func (_c *MockPredicate_Type_Call) Run(run func()) *MockPredicate_Type_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockPredicate_Type_Call) Return(_a0 PredicateType) *MockPredicate_Type_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockPredicate_Type_Call) RunAndReturn(run func() PredicateType) *MockPredicate_Type_Call { + _c.Call.Return(run) + return _c +} + +// NewMockPredicate creates a new instance of MockPredicate. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockPredicate(t interface { + mock.TestingT + Cleanup(func()) +}) *MockPredicate { + mock := &MockPredicate{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/kv/predicates/predicate.go b/internal/kv/predicates/predicate.go new file mode 100644 index 0000000000000..7f1a234888196 --- /dev/null +++ b/internal/kv/predicates/predicate.go @@ -0,0 +1,73 @@ +package predicates + +// PredicateTarget is enum for Predicate target type. +type PredicateTarget int32 + +const ( + // PredTargetValue is predicate target for key-value perid + PredTargetValue PredicateTarget = iota + 1 +) + +type PredicateType int32 + +const ( + PredTypeEqual PredicateType = iota + 1 +) + +// Predicate provides interface for kv predicate. 
+type Predicate interface { + Target() PredicateTarget + Type() PredicateType + IsTrue(any) bool + Key() string + TargetValue() any +} + +type valuePredicate struct { + k, v string + pt PredicateType +} + +func (p *valuePredicate) Target() PredicateTarget { + return PredTargetValue +} + +func (p *valuePredicate) Type() PredicateType { + return p.pt +} + +func (p *valuePredicate) IsTrue(target any) bool { + switch v := target.(type) { + case string: + return predicateValue(p.pt, v, p.v) + case []byte: + return predicateValue(p.pt, string(v), p.v) + default: + return false + } +} + +func (p *valuePredicate) Key() string { + return p.k +} + +func (p *valuePredicate) TargetValue() any { + return p.v +} + +func predicateValue[T comparable](pt PredicateType, v1, v2 T) bool { + switch pt { + case PredTypeEqual: + return v1 == v2 + default: + return false + } +} + +func ValueEqual(k, v string) Predicate { + return &valuePredicate{ + k: k, + v: v, + pt: PredTypeEqual, + } +} diff --git a/internal/kv/predicates/predicate_test.go b/internal/kv/predicates/predicate_test.go new file mode 100644 index 0000000000000..774cd56c62b8a --- /dev/null +++ b/internal/kv/predicates/predicate_test.go @@ -0,0 +1,33 @@ +package predicates + +import ( + "testing" + + "github.com/stretchr/testify/suite" +) + +type PredicateSuite struct { + suite.Suite +} + +func (s *PredicateSuite) TestValueEqual() { + p := ValueEqual("key", "value") + s.Equal("key", p.Key()) + s.Equal("value", p.TargetValue()) + s.Equal(PredTargetValue, p.Target()) + s.Equal(PredTypeEqual, p.Type()) + s.True(p.IsTrue("value")) + s.False(p.IsTrue("not_value")) + s.True(p.IsTrue([]byte("value"))) + s.False(p.IsTrue(1)) +} + +func (s *PredicateSuite) TestPredicateValue() { + s.True(predicateValue(PredTypeEqual, 1, 1)) + s.False(predicateValue(PredTypeEqual, 1, 2)) + s.False(predicateValue(0, 1, 1)) +} + +func TestPredicates(t *testing.T) { + suite.Run(t, new(PredicateSuite)) +} diff --git a/internal/kv/rocksdb/rocksdb_kv.go 
b/internal/kv/rocksdb/rocksdb_kv.go index ffeb3ac0d9f93..f8854138910ac 100644 --- a/internal/kv/rocksdb/rocksdb_kv.go +++ b/internal/kv/rocksdb/rocksdb_kv.go @@ -23,6 +23,8 @@ import ( "github.com/tecbot/gorocksdb" "github.com/milvus-io/milvus/internal/kv" + "github.com/milvus-io/milvus/internal/kv/predicates" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -389,7 +391,10 @@ func (kv *RocksdbKV) MultiRemove(keys []string) error { } // MultiSaveAndRemove provides a transaction to execute a batch of operations -func (kv *RocksdbKV) MultiSaveAndRemove(saves map[string]string, removals []string) error { +func (kv *RocksdbKV) MultiSaveAndRemove(saves map[string]string, removals []string, preds ...predicates.Predicate) error { + if len(preds) > 0 { + return merr.WrapErrServiceUnavailable("predicates not supported") + } if kv.DB == nil { return errors.New("Rocksdb instance is nil when do MultiSaveAndRemove") } @@ -420,20 +425,11 @@ func (kv *RocksdbKV) DeleteRange(startKey, endKey string) error { return err } -// MultiRemoveWithPrefix is used to remove a batch of key-values with the same prefix -func (kv *RocksdbKV) MultiRemoveWithPrefix(prefixes []string) error { - if kv.DB == nil { - return errors.New("rocksdb instance is nil when do RemoveWithPrefix") - } - writeBatch := gorocksdb.NewWriteBatch() - defer writeBatch.Destroy() - kv.prepareRemovePrefix(prefixes, writeBatch) - err := kv.DB.Write(kv.WriteOptions, writeBatch) - return err -} - // MultiSaveAndRemoveWithPrefix is used to execute a batch operators with the same prefix -func (kv *RocksdbKV) MultiSaveAndRemoveWithPrefix(saves map[string]string, removals []string) error { +func (kv *RocksdbKV) MultiSaveAndRemoveWithPrefix(saves map[string]string, removals []string, preds ...predicates.Predicate) error { + if len(preds) > 0 { + return merr.WrapErrServiceUnavailable("predicates not supported") + } if kv.DB == nil { return errors.New("Rocksdb instance is nil when do 
MultiSaveAndRemove") } diff --git a/internal/kv/rocksdb/rocksdb_kv_test.go b/internal/kv/rocksdb/rocksdb_kv_test.go index 50e15160537b6..b1d07010b8ecd 100644 --- a/internal/kv/rocksdb/rocksdb_kv_test.go +++ b/internal/kv/rocksdb/rocksdb_kv_test.go @@ -23,8 +23,11 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/milvus-io/milvus/internal/kv/predicates" rocksdbkv "github.com/milvus-io/milvus/internal/kv/rocksdb" + "github.com/milvus-io/milvus/pkg/util/merr" ) func TestRocksdbKV(t *testing.T) { @@ -188,7 +191,6 @@ func TestRocksdbKV_Prefix(t *testing.T) { val, err = rocksdbKV.Load("abcd") assert.NoError(t, err) assert.Equal(t, val, "123") - } func TestRocksdbKV_Txn(t *testing.T) { @@ -215,30 +217,6 @@ func TestRocksdbKV_Txn(t *testing.T) { assert.Equal(t, len(keys), 3) assert.Equal(t, len(vals), 3) - removePrefix := []string{"abc", "abd"} - rocksdbKV.MultiRemoveWithPrefix(removePrefix) - - keys, vals, err = rocksdbKV.LoadWithPrefix("") - assert.NoError(t, err) - assert.Equal(t, len(keys), 0) - assert.Equal(t, len(vals), 0) - - err = rocksdbKV.MultiSave(kvs) - assert.NoError(t, err) - keys, vals, err = rocksdbKV.LoadWithPrefix("") - assert.NoError(t, err) - assert.Equal(t, len(keys), 3) - assert.Equal(t, len(vals), 3) - - // test delete the whole table - removePrefix = []string{"", "hello"} - rocksdbKV.MultiRemoveWithPrefix(removePrefix) - - keys, vals, err = rocksdbKV.LoadWithPrefix("") - assert.NoError(t, err) - assert.Equal(t, len(keys), 0) - assert.Equal(t, len(vals), 0) - err = rocksdbKV.MultiSave(kvs) assert.NoError(t, err) keys, vals, err = rocksdbKV.LoadWithPrefix("") @@ -247,7 +225,7 @@ func TestRocksdbKV_Txn(t *testing.T) { assert.Equal(t, len(vals), 3) // test remove and save - removePrefix = []string{"abc", "abd"} + removePrefix := []string{"abc", "abd"} kvs2 := map[string]string{ "abfad": "12345", } @@ -389,3 +367,20 @@ func TestHasPrefix(t *testing.T) { assert.NoError(t, err) assert.False(t, 
has) } + +func TestPredicates(t *testing.T) { + dir := t.TempDir() + db, err := rocksdbkv.NewRocksdbKV(dir) + + require.NoError(t, err) + defer db.Close() + defer db.RemoveWithPrefix("") + + err = db.MultiSaveAndRemove(map[string]string{}, []string{}, predicates.ValueEqual("a", "b")) + assert.Error(t, err) + assert.ErrorIs(t, err, merr.ErrServiceUnavailable) + + err = db.MultiSaveAndRemoveWithPrefix(map[string]string{}, []string{}, predicates.ValueEqual("a", "b")) + assert.Error(t, err) + assert.ErrorIs(t, err, merr.ErrServiceUnavailable) +} diff --git a/internal/kv/tikv/main_test.go b/internal/kv/tikv/main_test.go new file mode 100644 index 0000000000000..f22cb1705a3b0 --- /dev/null +++ b/internal/kv/tikv/main_test.go @@ -0,0 +1,96 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tikv + +import ( + "context" + "os" + "testing" + + "github.com/tikv/client-go/v2/rawkv" + "github.com/tikv/client-go/v2/testutils" + tilib "github.com/tikv/client-go/v2/tikv" + "github.com/tikv/client-go/v2/txnkv" + + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +var ( + txnClient *txnkv.Client + rawClient *rawkv.Client +) + +// creates a local TiKV Store for testing purpose. 
// setupLocalTiKV creates in-process mock TiKV stores (both the txn and the
// raw client) for testing purposes.
func setupLocalTiKV() {
	setupLocalTxn()
	setupLocalRaw()
}

// setupLocalTxn points the package-level txnClient at a mock-TiKV-backed
// transactional store; panics on any setup failure (test bootstrap only).
func setupLocalTxn() {
	client, cluster, pdClient, err := testutils.NewMockTiKV("", nil)
	if err != nil {
		panic(err)
	}
	testutils.BootstrapWithSingleStore(cluster)
	store, err := tilib.NewTestTiKVStore(client, pdClient, nil, nil, 0)
	if err != nil {
		panic(err)
	}
	txnClient = &txnkv.Client{KVStore: store}
}

// setupLocalRaw points the package-level rawClient at a mock-backed raw KV
// client, assembled through the rawkv.ClientProbe test hooks.
func setupLocalRaw() {
	client, cluster, pdClient, err := testutils.NewMockTiKV("", nil)
	if err != nil {
		panic(err)
	}
	testutils.BootstrapWithSingleStore(cluster)
	rawClient = &rawkv.Client{}
	p := rawkv.ClientProbe{Client: rawClient}
	p.SetPDClient(pdClient)
	p.SetRegionCache(tilib.NewRegionCache(pdClient))
	p.SetRPCClient(client)
}

// Connects to a remote TiKV service for testing purpose. By default, it assumes the TiKV is from localhost.
func setupRemoteTiKV() {
	pdsn := "127.0.0.1:2379"
	var err error
	txnClient, err = txnkv.NewClient([]string{pdsn})
	if err != nil {
		panic(err)
	}
	rawClient, err = rawkv.NewClientWithOpts(context.Background(), []string{pdsn})
	if err != nil {
		panic(err)
	}
}

// setupTiKV selects between the remote (localhost PD) backend and the
// in-process mock backend.
func setupTiKV(useRemote bool) {
	if useRemote {
		setupRemoteTiKV()
	} else {
		setupLocalTiKV()
	}
}

// TestMain initializes the paramtable configuration and a local mock TiKV
// before running the package tests.
func TestMain(m *testing.M) {
	paramtable.Init()
	setupTiKV(false)
	code := m.Run()
	os.Exit(code)
}
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tikv + +import ( + "bytes" + "context" + "fmt" + "math" + "path" + "time" + + "github.com/cockroachdb/errors" + tikverr "github.com/tikv/client-go/v2/error" + tikv "github.com/tikv/client-go/v2/kv" + "github.com/tikv/client-go/v2/txnkv" + "github.com/tikv/client-go/v2/txnkv/transaction" + "github.com/tikv/client-go/v2/txnkv/txnsnapshot" + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/kv" + "github.com/milvus-io/milvus/internal/kv/predicates" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/timerecord" +) + +// A quick note is that we are using loggingErr at our outermost scope in order to perform logging + +const ( + // We are using a Snapshot instead of transaction when doing read only operations due to the + // lower overhead (50% less overhead in small tests). In order to guarantee the latest values are + // grabbed at call, we can set the TS to be the max uint64. + MaxSnapshotTS = uint64(math.MaxUint64) + // Whether to enable transaction rollback if a transaction fails. Due to TiKV transactions being + // optimistic by default, a rollback does not need to be done and the transaction can just be + // discarded. Discarding saw a small bump in performance on small scale tests. + EnableRollback = false + // This empty value is what we are reserving within TiKv to represent an empty string value. 
	// TiKV does not allow storing empty values for keys which is something we do in Milvus, so
	// to get over this we are using the reserved keyword as placeholder.
	EmptyValueString = "__milvus_reserved_empty_tikv_value_DO_NOT_USE"
)

// Params exposes the component configuration used to size scans and timeouts.
var Params *paramtable.ComponentParam = paramtable.Get()

// For reads by prefix we can customize the scan size to increase/decrease rpc calls.
// Initialized from configuration in NewTiKV.
var SnapshotScanSize int

// RequestTimeout is the default timeout for tikv request.
// Initialized from configuration in NewTiKV.
var RequestTimeout time.Duration

// EmptyValueByte is the byte form of the reserved empty-value placeholder.
var EmptyValueByte = []byte(EmptyValueString)

// tiTxnBegin starts a new transaction on the given client.
func tiTxnBegin(txn *txnkv.Client) (*transaction.KVTxn, error) {
	return txn.Begin()
}

// tiTxnCommit commits txn under ctx.
func tiTxnCommit(ctx context.Context, txn *transaction.KVTxn) error {
	return txn.Commit(ctx)
}

// tiTxnSnapshot opens a read snapshot at MaxSnapshotTS (i.e. the latest
// data) with the given scan batch size.
func tiTxnSnapshot(txn *txnkv.Client, paginationSize int) *txnsnapshot.KVSnapshot {
	ss := txn.GetSnapshot(MaxSnapshotTS)
	ss.SetScanBatchSize(paginationSize)
	return ss
}

// Function variables for transaction/snapshot creation — presumably
// overridable hooks for test stubbing (TODO confirm against tests).
var (
	beginTxn    = tiTxnBegin
	commitTxn   = tiTxnCommit
	getSnapshot = tiTxnSnapshot
)

// implementation assertion
var _ kv.MetaKv = (*txnTiKV)(nil)

// txnTiKV implements MetaKv and TxnKV interface. It supports processing multiple kvs within one transaction.
type txnTiKV struct {
	txn      *txnkv.Client // underlying TiKV transactional client
	rootPath string        // prefix joined onto every key before access
}

// NewTiKV creates a new txnTiKV client rooted at rootPath.
// Side effect: initializes the package-level SnapshotScanSize and
// RequestTimeout from configuration.
func NewTiKV(txn *txnkv.Client, rootPath string) *txnTiKV {
	SnapshotScanSize = Params.TiKVCfg.SnapshotScanSize.GetAsInt()
	RequestTimeout = Params.TiKVCfg.RequestTimeout.GetAsDuration(time.Millisecond)
	kv := &txnTiKV{
		txn:      txn,
		rootPath: rootPath,
	}
	return kv
}

// Close closes the connection to TiKV.
func (kv *txnTiKV) Close() {
	log.Info("txnTiKV closed", zap.String("path", kv.rootPath))
}

// GetPath returns the full (rootPath-joined) path of the key/prefix.
func (kv *txnTiKV) GetPath(key string) string {
	return path.Join(kv.rootPath, key)
}
Deferred functions evaluate their args immediately. +func logWarnOnFailure(err *error, msg string, fields ...zap.Field) { + if *err != nil { + fields = append(fields, zap.Error(*err)) + log.Warn(msg, fields...) + } +} + +// Has returns if a key exists. +func (kv *txnTiKV) Has(key string) (bool, error) { + start := time.Now() + key = path.Join(kv.rootPath, key) + ctx, cancel := context.WithTimeout(context.Background(), RequestTimeout) + defer cancel() + + var loggingErr error + defer logWarnOnFailure(&loggingErr, "txnTiKV Has() error", zap.String("key", key)) + + _, err := kv.getTiKVMeta(ctx, key) + if err != nil { + // Dont error out if not present unless failed call to tikv + if errors.Is(err, merr.ErrIoKeyNotFound) { + return false, nil + } + loggingErr = errors.Wrap(err, fmt.Sprintf("Failed to read key: %s", key)) + return false, loggingErr + } + CheckElapseAndWarn(start, "Slow txnTiKV Has() operation", zap.String("key", key)) + return true, nil +} + +func rollbackOnFailure(err *error, txn *transaction.KVTxn) { + if *err != nil && EnableRollback == true { + txn.Rollback() + } +} + +// HasPrefix returns if a key prefix exists. 
+func (kv *txnTiKV) HasPrefix(prefix string) (bool, error) { + start := time.Now() + prefix = path.Join(kv.rootPath, prefix) + + var loggingErr error + defer logWarnOnFailure(&loggingErr, "txnTiKV HasPrefix() error", zap.String("prefix", prefix)) + + ss := getSnapshot(kv.txn, SnapshotScanSize) + + // Retrieve bounding keys for prefix + startKey := []byte(prefix) + endKey := tikv.PrefixNextKey([]byte(prefix)) + + iter, err := ss.Iter(startKey, endKey) + if err != nil { + loggingErr = errors.Wrap(err, fmt.Sprintf("Failed to create iterator for prefix: %s", prefix)) + return false, loggingErr + } + defer iter.Close() + + r := false + // Iterater only needs to check the first key-value pair + if iter.Valid() { + r = true + } + CheckElapseAndWarn(start, "Slow txnTiKV HasPrefix() operation", zap.String("prefix", prefix)) + return r, nil +} + +// Load returns value of the key. +func (kv *txnTiKV) Load(key string) (string, error) { + start := time.Now() + key = path.Join(kv.rootPath, key) + ctx, cancel := context.WithTimeout(context.Background(), RequestTimeout) + defer cancel() + + var loggingErr error + defer logWarnOnFailure(&loggingErr, "txnTiKV Load() error", zap.String("key", key)) + + val, err := kv.getTiKVMeta(ctx, key) + if err != nil { + if errors.Is(err, merr.ErrIoKeyNotFound) { + loggingErr = err + } else { + loggingErr = errors.Wrap(err, fmt.Sprintf("Failed to read key %s", key)) + } + return "", loggingErr + } + CheckElapseAndWarn(start, "Slow txnTiKV Load() operation", zap.String("key", key)) + return val, nil +} + +func batchConvertFromString(prefix string, keys []string) [][]byte { + output := make([][]byte, len(keys)) + for i := 0; i < len(keys); i++ { + keys[i] = path.Join(prefix, keys[i]) + output[i] = []byte(keys[i]) + } + return output +} + +// MultiLoad gets the values of input keys in a transaction. 
// MultiLoad gets the values of input keys in a transaction.
// Returns one value per input key; if any key is missing, a non-nil error
// listing the missing keys is returned ALONGSIDE the (partial) values.
func (kv *txnTiKV) MultiLoad(keys []string) ([]string, error) {
	start := time.Now()
	ctx, cancel := context.WithTimeout(context.Background(), RequestTimeout)
	defer cancel()

	var loggingErr error
	defer logWarnOnFailure(&loggingErr, "txnTiKV MultiLoad() error", zap.Strings("keys", keys))

	// Convert from []string to [][]byte.
	// Side effect: batchConvertFromString rewrites `keys` in place to the
	// rootPath-joined form, so the keyMap lookup below matches BatchGet's keys.
	byteKeys := batchConvertFromString(kv.rootPath, keys)

	// Since only reading, use Snapshot for less overhead
	ss := getSnapshot(kv.txn, SnapshotScanSize)

	keyMap, err := ss.BatchGet(ctx, byteKeys)
	if err != nil {
		loggingErr = errors.Wrap(err, "Failed ss.BatchGet() for MultiLoad")
		return nil, loggingErr
	}

	missingValues := []string{}
	validValues := []string{}

	// Loop through keys and build valid/invalid slices
	for _, k := range keys {
		v, ok := keyMap[k]
		if !ok {
			missingValues = append(missingValues, k)
		}
		// NOTE(review): a missing key still falls through here, so its
		// (nil) value is converted and appended to validValues — looks like
		// a missing `continue`; confirm callers depend on the placeholder.
		// Check if empty value placeholder
		strVal := convertEmptyByteToString(v)
		validValues = append(validValues, strVal)
	}
	if len(missingValues) != 0 {
		loggingErr = fmt.Errorf("There are invalid keys: %s", missingValues)
	}

	CheckElapseAndWarn(start, "Slow txnTiKV MultiLoad() operation", zap.Any("keys", keys))
	return validValues, loggingErr
}
+func (kv *txnTiKV) LoadWithPrefix(prefix string) ([]string, []string, error) { + start := time.Now() + prefix = path.Join(kv.rootPath, prefix) + + var loggingErr error + defer logWarnOnFailure(&loggingErr, "txnTiKV LoadWithPrefix() error", zap.String("prefix", prefix)) + + ss := getSnapshot(kv.txn, SnapshotScanSize) + + // Retrieve key-value pairs with the specified prefix + startKey := []byte(prefix) + endKey := tikv.PrefixNextKey([]byte(prefix)) + iter, err := ss.Iter(startKey, endKey) + if err != nil { + loggingErr = errors.Wrap(err, fmt.Sprintf("Failed to create iterater for LoadWithPrefix() for prefix: %s", prefix)) + return nil, nil, loggingErr + } + defer iter.Close() + + var keys []string + var values []string + + // Iterate over the key-value pairs + for iter.Valid() { + val := iter.Value() + // Check if empty value placeholder + strVal := convertEmptyByteToString(val) + keys = append(keys, string(iter.Key())) + values = append(values, strVal) + err = iter.Next() + if err != nil { + loggingErr = errors.Wrap(err, fmt.Sprintf("Failed to iterate for LoadWithPrefix() for prefix: %s", prefix)) + return nil, nil, loggingErr + } + } + CheckElapseAndWarn(start, "Slow txnTiKV LoadWithPrefix() operation", zap.String("prefix", prefix)) + return keys, values, nil +} + +// Save saves the input key-value pair. +func (kv *txnTiKV) Save(key, value string) error { + key = path.Join(kv.rootPath, key) + ctx, cancel := context.WithTimeout(context.Background(), RequestTimeout) + defer cancel() + + var loggingErr error + defer logWarnOnFailure(&loggingErr, "txnTiKV Save() error", zap.String("key", key), zap.String("value", value)) + + loggingErr = kv.putTiKVMeta(ctx, key, value) + return loggingErr +} + +// MultiSave saves the input key-value pairs in transaction manner. 
+func (kv *txnTiKV) MultiSave(kvs map[string]string) error { + start := time.Now() + ctx, cancel := context.WithTimeout(context.Background(), RequestTimeout) + defer cancel() + + var loggingErr error + defer logWarnOnFailure(&loggingErr, "txnTiKV MultiSave() error", zap.Any("kvs", kvs), zap.Int("len", len(kvs))) + + txn, err := beginTxn(kv.txn) + if err != nil { + loggingErr = errors.Wrap(err, "Failed to create txn for MultiSave") + return loggingErr + } + + // Defer a rollback only if the transaction hasn't been committed + defer rollbackOnFailure(&loggingErr, txn) + + for key, value := range kvs { + key = path.Join(kv.rootPath, key) + // Check if value is empty or taking reserved EmptyValue + byteValue, err := convertEmptyStringToByte(value) + if err != nil { + loggingErr = errors.Wrap(err, fmt.Sprintf("Failed to cast to byte (%s:%s) for MultiSave()", key, value)) + return loggingErr + } + // Save the value within a transaction + err = txn.Set([]byte(key), byteValue) + if err != nil { + loggingErr = errors.Wrap(err, fmt.Sprintf("Failed to set (%s:%s) for MultiSave()", key, value)) + return loggingErr + } + } + err = kv.executeTxn(ctx, txn) + if err != nil { + loggingErr = errors.Wrap(err, "Failed to commit for MultiSave()") + return loggingErr + } + CheckElapseAndWarn(start, "Slow txnTiKV MultiSave() operation", zap.Any("kvs", kvs)) + return nil +} + +// Remove removes the input key. +func (kv *txnTiKV) Remove(key string) error { + key = path.Join(kv.rootPath, key) + ctx, cancel := context.WithTimeout(context.Background(), RequestTimeout) + defer cancel() + + var loggingErr error + defer logWarnOnFailure(&loggingErr, "txnTiKV Remove() error", zap.String("key", key)) + + loggingErr = kv.removeTiKVMeta(ctx, key) + return loggingErr +} + +// MultiRemove removes the input keys in transaction manner. 
+func (kv *txnTiKV) MultiRemove(keys []string) error { + start := time.Now() + ctx, cancel := context.WithTimeout(context.Background(), RequestTimeout) + defer cancel() + + var loggingErr error + defer logWarnOnFailure(&loggingErr, "txnTiKV MultiRemove() error", zap.Strings("keys", keys), zap.Int("len", len(keys))) + + txn, err := beginTxn(kv.txn) + if err != nil { + loggingErr = errors.Wrap(err, "Failed to create txn for MultiRemove") + return loggingErr + } + + // Defer a rollback only if the transaction hasn't been committed + defer rollbackOnFailure(&loggingErr, txn) + + for _, key := range keys { + key = path.Join(kv.rootPath, key) + loggingErr = txn.Delete([]byte(key)) + if err != nil { + loggingErr = errors.Wrap(err, fmt.Sprintf("Failed to delete %s for MultiRemove", key)) + return loggingErr + } + } + + err = kv.executeTxn(ctx, txn) + if err != nil { + loggingErr = errors.Wrap(err, "Failed to commit for MultiRemove()") + return loggingErr + } + CheckElapseAndWarn(start, "Slow txnTiKV MultiRemove() operation", zap.Strings("keys", keys)) + return nil +} + +// RemoveWithPrefix removes the keys for the given prefix. +func (kv *txnTiKV) RemoveWithPrefix(prefix string) error { + start := time.Now() + prefix = path.Join(kv.rootPath, prefix) + ctx, cancel := context.WithTimeout(context.Background(), RequestTimeout) + defer cancel() + + var loggingErr error + defer logWarnOnFailure(&loggingErr, "txnTiKV RemoveWithPrefix() error", zap.String("prefix", prefix)) + + startKey := []byte(prefix) + endKey := tikv.PrefixNextKey(startKey) + _, err := kv.txn.DeleteRange(ctx, startKey, endKey, 1) + if err != nil { + loggingErr = errors.Wrap(err, "Failed to DeleteRange for RemoveWithPrefix") + return loggingErr + } + CheckElapseAndWarn(start, "Slow txnTiKV RemoveWithPrefix() operation", zap.String("prefix", prefix)) + return nil +} + +// MultiSaveAndRemove saves the key-value pairs and removes the keys in a transaction. 
+func (kv *txnTiKV) MultiSaveAndRemove(saves map[string]string, removals []string, preds ...predicates.Predicate) error { + start := time.Now() + ctx, cancel := context.WithTimeout(context.Background(), RequestTimeout) + defer cancel() + + var loggingErr error + defer logWarnOnFailure(&loggingErr, "txnTiKV MultiSaveAndRemove error", zap.Any("saves", saves), zap.Strings("removes", removals), zap.Int("saveLength", len(saves)), zap.Int("removeLength", len(removals))) + + txn, err := beginTxn(kv.txn) + if err != nil { + loggingErr = errors.Wrap(err, "Failed to create txn for MultiSaveAndRemove") + return loggingErr + } + + // Defer a rollback only if the transaction hasn't been committed + defer rollbackOnFailure(&loggingErr, txn) + + for _, pred := range preds { + key := path.Join(kv.rootPath, pred.Key()) + val, err := txn.Get(ctx, []byte(key)) + if err != nil { + loggingErr = errors.Wrap(err, fmt.Sprintf("failed to read predicate target (%s:%v) for MultiSaveAndRemove", pred.Key(), pred.TargetValue())) + return loggingErr + } + if !pred.IsTrue(val) { + loggingErr = merr.WrapErrIoFailedReason("failed to meet predicate", fmt.Sprintf("key=%s, value=%v", pred.Key(), pred.TargetValue())) + return loggingErr + } + } + + for key, value := range saves { + key = path.Join(kv.rootPath, key) + // Check if value is empty or taking reserved EmptyValue + byteValue, err := convertEmptyStringToByte(value) + if err != nil { + loggingErr = errors.Wrap(err, fmt.Sprintf("Failed to cast to byte (%s:%s) for MultiSaveAndRemove", key, value)) + return loggingErr + } + err = txn.Set([]byte(key), byteValue) + if err != nil { + loggingErr = errors.Wrap(err, fmt.Sprintf("Failed to set (%s:%s) for MultiSaveAndRemove", key, value)) + return loggingErr + } + } + + for _, key := range removals { + key = path.Join(kv.rootPath, key) + if err = txn.Delete([]byte(key)); err != nil { + loggingErr = errors.Wrap(err, fmt.Sprintf("Failed to delete %s for MultiSaveAndRemove", key)) + return loggingErr + } + 
} + + err = kv.executeTxn(ctx, txn) + if err != nil { + loggingErr = errors.Wrap(err, "Failed to commit for MultiSaveAndRemove") + return loggingErr + } + CheckElapseAndWarn(start, "Slow txnTiKV MultiSaveAndRemove() operation", zap.Any("saves", saves), zap.Strings("removals", removals)) + return nil +} + +// MultiSaveAndRemoveWithPrefix saves kv in @saves and removes the keys with given prefix in @removals. +func (kv *txnTiKV) MultiSaveAndRemoveWithPrefix(saves map[string]string, removals []string, preds ...predicates.Predicate) error { + start := time.Now() + ctx, cancel := context.WithTimeout(context.Background(), RequestTimeout) + defer cancel() + + var loggingErr error + defer logWarnOnFailure(&loggingErr, "txnTiKV MultiSaveAndRemoveWithPrefix() error", zap.Any("saves", saves), zap.Strings("removes", removals), zap.Int("saveLength", len(saves)), zap.Int("removeLength", len(removals))) + + txn, err := beginTxn(kv.txn) + if err != nil { + loggingErr = errors.Wrap(err, "Failed to create txn for MultiSaveAndRemoveWithPrefix") + return loggingErr + } + + // Defer a rollback only if the transaction hasn't been committed + defer rollbackOnFailure(&loggingErr, txn) + + for _, pred := range preds { + key := path.Join(kv.rootPath, pred.Key()) + val, err := txn.Get(ctx, []byte(key)) + if err != nil { + loggingErr = errors.Wrap(err, fmt.Sprintf("failed to read predicate target (%s:%v) for MultiSaveAndRemove", pred.Key(), pred.TargetValue())) + return loggingErr + } + if !pred.IsTrue(val) { + loggingErr = merr.WrapErrIoFailedReason("failed to meet predicate", fmt.Sprintf("key=%s, value=%v", pred.Key(), pred.TargetValue())) + return loggingErr + } + } + + // Save key-value pairs + for key, value := range saves { + key = path.Join(kv.rootPath, key) + // Check if value is empty or taking reserved EmptyValue + byteValue, err := convertEmptyStringToByte(value) + if err != nil { + loggingErr = errors.Wrap(err, fmt.Sprintf("Failed to cast to byte (%s:%s) for 
MultiSaveAndRemoveWithPrefix()", key, value)) + return loggingErr + } + err = txn.Set([]byte(key), byteValue) + if err != nil { + loggingErr = errors.Wrap(err, fmt.Sprintf("Failed to set (%s:%s) for MultiSaveAndRemoveWithPrefix()", key, value)) + return loggingErr + } + } + // Remove keys with prefix + for _, prefix := range removals { + prefix = path.Join(kv.rootPath, prefix) + // Get the start and end keys for the prefix range + startKey := []byte(prefix) + endKey := tikv.PrefixNextKey([]byte(prefix)) + + // Use Scan to iterate over keys in the prefix range + iter, err := txn.Iter(startKey, endKey) + if err != nil { + loggingErr = errors.Wrap(err, fmt.Sprintf("Failed to create iterater for %s during MultiSaveAndRemoveWithPrefix()", prefix)) + return loggingErr + } + + // Iterate over keys and delete them + for iter.Valid() { + key := iter.Key() + err = txn.Delete(key) + if loggingErr != nil { + loggingErr = errors.Wrap(err, fmt.Sprintf("Failed to delete %s for MultiSaveAndRemoveWithPrefix", string(key))) + return loggingErr + } + + // Move the iterator to the next key + err = iter.Next() + if err != nil { + loggingErr = errors.Wrap(err, fmt.Sprintf("Failed to move Iterator after key %s for MultiSaveAndRemoveWithPrefix", string(key))) + return loggingErr + } + } + } + err = kv.executeTxn(ctx, txn) + if err != nil { + loggingErr = errors.Wrap(err, "Failed to commit for MultiSaveAndRemoveWithPrefix") + return loggingErr + } + CheckElapseAndWarn(start, "Slow txnTiKV MultiSaveAndRemoveWithPrefix() operation", zap.Any("saves", saves), zap.Strings("removals", removals)) + return nil +} + +// WalkWithPrefix visits each kv with input prefix and apply given fn to it. 
+func (kv *txnTiKV) WalkWithPrefix(prefix string, paginationSize int, fn func([]byte, []byte) error) error { + start := time.Now() + prefix = path.Join(kv.rootPath, prefix) + + var loggingErr error + defer logWarnOnFailure(&loggingErr, "txnTiKV WalkWithPagination error", zap.String("prefix", prefix)) + + // Since only reading, use Snapshot for less overhead + ss := getSnapshot(kv.txn, paginationSize) + + // Retrieve key-value pairs with the specified prefix + startKey := []byte(prefix) + endKey := tikv.PrefixNextKey([]byte(prefix)) + iter, err := ss.Iter(startKey, endKey) + if err != nil { + loggingErr = errors.Wrap(err, fmt.Sprintf("Failed to create iterater for %s during WalkWithPrefix", prefix)) + return loggingErr + } + defer iter.Close() + + // Iterate over the key-value pairs + for iter.Valid() { + // Grab value for empty check + byteVal := iter.Value() + // Check if empty val and replace with placeholder + if isEmptyByte(byteVal) { + byteVal = []byte{} + } + err = fn(iter.Key(), byteVal) + if err != nil { + loggingErr = errors.Wrap(err, fmt.Sprintf("Failed to apply fn to (%s;%s)", string(iter.Key()), string(byteVal))) + return loggingErr + } + err = iter.Next() + if err != nil { + loggingErr = errors.Wrap(err, fmt.Sprintf("Failed to move Iterator after key %s for WalkWithPrefix", string(iter.Key()))) + return loggingErr + } + } + CheckElapseAndWarn(start, "Slow txnTiKV WalkWithPagination() operation", zap.String("prefix", prefix)) + return nil +} + +func (kv *txnTiKV) executeTxn(ctx context.Context, txn *transaction.KVTxn) error { + start := timerecord.NewTimeRecorder("executeTxn") + + elapsed := start.ElapseSpan() + metrics.MetaOpCounter.WithLabelValues(metrics.MetaTxnLabel, metrics.TotalLabel).Inc() + err := commitTxn(ctx, txn) + if err == nil { + metrics.MetaRequestLatency.WithLabelValues(metrics.MetaTxnLabel).Observe(float64(elapsed.Milliseconds())) + metrics.MetaOpCounter.WithLabelValues(metrics.MetaTxnLabel, metrics.SuccessLabel).Inc() + } else { + 
metrics.MetaOpCounter.WithLabelValues(metrics.MetaTxnLabel, metrics.FailLabel).Inc() + } + + return err +} + +func (kv *txnTiKV) getTiKVMeta(ctx context.Context, key string) (string, error) { + ctx1, cancel := context.WithTimeout(ctx, RequestTimeout) + defer cancel() + + start := timerecord.NewTimeRecorder("getTiKVMeta") + + ss := getSnapshot(kv.txn, SnapshotScanSize) + + val, err := ss.Get(ctx1, []byte(key)) + if err != nil { + // Log key read fail + metrics.MetaOpCounter.WithLabelValues(metrics.MetaGetLabel, metrics.FailLabel).Inc() + if err == tikverr.ErrNotExist { + // If key is missing + return "", merr.WrapErrIoKeyNotFound(key) + } + // If call to tikv fails + return "", errors.Wrap(err, fmt.Sprintf("Failed to get value for key %s in getTiKVMeta", key)) + } + + // Check if value is the empty placeholder + strVal := convertEmptyByteToString(val) + + elapsed := start.ElapseSpan() + + metrics.MetaOpCounter.WithLabelValues(metrics.MetaGetLabel, metrics.TotalLabel).Inc() + metrics.MetaKvSize.WithLabelValues(metrics.MetaGetLabel).Observe(float64(len(val))) + metrics.MetaRequestLatency.WithLabelValues(metrics.MetaGetLabel).Observe(float64(elapsed.Milliseconds())) + metrics.MetaOpCounter.WithLabelValues(metrics.MetaGetLabel, metrics.SuccessLabel).Inc() + + return strVal, nil +} + +func (kv *txnTiKV) putTiKVMeta(ctx context.Context, key, val string) error { + ctx1, cancel := context.WithTimeout(ctx, RequestTimeout) + defer cancel() + + start := timerecord.NewTimeRecorder("putTiKVMeta") + + txn, err := beginTxn(kv.txn) + if err != nil { + return errors.Wrap(err, "Failed to build transaction for putTiKVMeta") + } + // Defer a rollback only if the transaction hasn't been committed + defer rollbackOnFailure(&err, txn) + + // Check if the value being written needs to be empty placeholder + byteValue, err := convertEmptyStringToByte(val) + if err != nil { + return errors.Wrap(err, fmt.Sprintf("Failed to cast to byte (%s:%s) for putTiKVMeta", key, val)) + } + err = 
txn.Set([]byte(key), byteValue) + if err != nil { + return errors.Wrap(err, fmt.Sprintf("Failed to set value for key %s in putTiKVMeta", key)) + } + err = commitTxn(ctx1, txn) + + elapsed := start.ElapseSpan() + metrics.MetaOpCounter.WithLabelValues(metrics.MetaPutLabel, metrics.TotalLabel).Inc() + if err == nil { + metrics.MetaKvSize.WithLabelValues(metrics.MetaPutLabel).Observe(float64(len(byteValue))) + metrics.MetaRequestLatency.WithLabelValues(metrics.MetaPutLabel).Observe(float64(elapsed.Milliseconds())) + metrics.MetaOpCounter.WithLabelValues(metrics.MetaPutLabel, metrics.SuccessLabel).Inc() + } else { + metrics.MetaOpCounter.WithLabelValues(metrics.MetaPutLabel, metrics.FailLabel).Inc() + } + + return err +} + +func (kv *txnTiKV) removeTiKVMeta(ctx context.Context, key string) error { + ctx1, cancel := context.WithTimeout(ctx, RequestTimeout) + defer cancel() + + start := timerecord.NewTimeRecorder("removeTiKVMeta") + + txn, err := beginTxn(kv.txn) + if err != nil { + return errors.Wrap(err, "Failed to build transaction for removeTiKVMeta") + } + // Defer a rollback only if the transaction hasn't been committed + defer rollbackOnFailure(&err, txn) + + err = txn.Delete([]byte(key)) + if err != nil { + return errors.Wrap(err, fmt.Sprintf("Failed to remove key %s in removeTiKVMeta", key)) + } + err = commitTxn(ctx1, txn) + + elapsed := start.ElapseSpan() + metrics.MetaOpCounter.WithLabelValues(metrics.MetaRemoveLabel, metrics.TotalLabel).Inc() + + if err == nil { + metrics.MetaRequestLatency.WithLabelValues(metrics.MetaRemoveLabel).Observe(float64(elapsed.Milliseconds())) + metrics.MetaOpCounter.WithLabelValues(metrics.MetaRemoveLabel, metrics.SuccessLabel).Inc() + } else { + metrics.MetaOpCounter.WithLabelValues(metrics.MetaRemoveLabel, metrics.FailLabel).Inc() + } + + return err +} + +func (kv *txnTiKV) CompareVersionAndSwap(key string, version int64, target string) (bool, error) { + err := fmt.Errorf("Unimplemented! 
CompareVersionAndSwap is under deprecation") + logWarnOnFailure(&err, "Unimplemented") + return false, err +} + +// CheckElapseAndWarn checks the elapsed time and warns if it is too long. +func CheckElapseAndWarn(start time.Time, message string, fields ...zap.Field) bool { + elapsed := time.Since(start) + if elapsed.Milliseconds() > 2000 { + log.Warn(message, append([]zap.Field{zap.String("time spent", elapsed.String())}, fields...)...) + return true + } + return false +} + +// Since TiKV cannot store empty key values, we assign them a placeholder held by EmptyValue. +// Upon loading, we need to check if the returned value is the placeholder. +func isEmptyByte(value []byte) bool { + return bytes.Equal(value, EmptyValueByte) || len(value) == 0 +} + +// Return an empty string if the value is the Empty placeholder, else return actual string value. +func convertEmptyByteToString(value []byte) string { + if isEmptyByte(value) { + return "" + } + return string(value) +} + +// Convert string into EmptyValue if empty else cast to []byte. Will throw error if value is equal +// to the EmptyValueString. +func convertEmptyStringToByte(value string) ([]byte, error) { + if len(value) == 0 { + return EmptyValueByte, nil + } else if value == EmptyValueString { + return nil, fmt.Errorf("Value for key is reserved by EmptyValue: %s", EmptyValueString) + } else { + return []byte(value), nil + } +} diff --git a/internal/kv/tikv/txn_tikv_test.go b/internal/kv/tikv/txn_tikv_test.go new file mode 100644 index 0000000000000..6a7dcad0f19e9 --- /dev/null +++ b/internal/kv/tikv/txn_tikv_test.go @@ -0,0 +1,642 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tikv + +import ( + "context" + "fmt" + "sort" + "testing" + "time" + + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tikv/client-go/v2/txnkv" + "github.com/tikv/client-go/v2/txnkv/transaction" + "golang.org/x/exp/maps" + + "github.com/milvus-io/milvus/internal/kv/predicates" +) + +func TestTiKVLoad(te *testing.T) { + te.Run("kv SaveAndLoad", func(t *testing.T) { + rootPath := "/tikv/test/root/saveandload" + kv := NewTiKV(txnClient, rootPath) + err := kv.RemoveWithPrefix("") + require.NoError(t, err) + + defer kv.Close() + defer kv.RemoveWithPrefix("") + + saveAndLoadTests := []struct { + key string + value string + }{ + {"test1", "value1"}, + {"test2", "value2"}, + {"test1/a", "value_a"}, + {"test1/b", "value_b"}, + } + + for i, test := range saveAndLoadTests { + if i < 4 { + err = kv.Save(test.key, test.value) + assert.NoError(t, err) + } + + val, err := kv.Load(test.key) + assert.NoError(t, err) + assert.Equal(t, test.value, val) + } + + invalidLoadTests := []struct { + invalidKey string + }{ + {"t"}, + {"a"}, + {"test1a"}, + } + + for _, test := range invalidLoadTests { + val, err := kv.Load(test.invalidKey) + assert.Error(t, err) + assert.Zero(t, val) + } + + loadPrefixTests := []struct { + prefix string + + expectedKeys []string + expectedValues []string + expectedError error + }{ + {"test", []string{ + kv.GetPath("test1"), + kv.GetPath("test2"), + kv.GetPath("test1/a"), + kv.GetPath("test1/b"), + }, []string{"value1", "value2", "value_a", "value_b"}, nil}, + 
{"test1", []string{ + kv.GetPath("test1"), + kv.GetPath("test1/a"), + kv.GetPath("test1/b"), + }, []string{"value1", "value_a", "value_b"}, nil}, + {"test2", []string{kv.GetPath("test2")}, []string{"value2"}, nil}, + {"", []string{ + kv.GetPath("test1"), + kv.GetPath("test2"), + kv.GetPath("test1/a"), + kv.GetPath("test1/b"), + }, []string{"value1", "value2", "value_a", "value_b"}, nil}, + {"test1/a", []string{kv.GetPath("test1/a")}, []string{"value_a"}, nil}, + {"a", []string{}, []string{}, nil}, + {"root", []string{}, []string{}, nil}, + {"/tikv/test/root", []string{}, []string{}, nil}, + } + + for _, test := range loadPrefixTests { + actualKeys, actualValues, err := kv.LoadWithPrefix(test.prefix) + assert.ElementsMatch(t, test.expectedKeys, actualKeys) + assert.ElementsMatch(t, test.expectedValues, actualValues) + assert.Equal(t, test.expectedError, err) + } + + removeTests := []struct { + validKey string + invalidKey string + }{ + {"test1", "abc"}, + {"test1/a", "test1/lskfjal"}, + {"test1/b", "test1/b"}, + {"test2", "-"}, + } + + for _, test := range removeTests { + err = kv.Remove(test.validKey) + assert.NoError(t, err) + + _, err = kv.Load(test.validKey) + assert.Error(t, err) + + err = kv.Remove(test.validKey) + assert.NoError(t, err) + err = kv.Remove(test.invalidKey) + assert.NoError(t, err) + } + }) + + te.Run("kv MultiSaveAndMultiLoad", func(t *testing.T) { + rootPath := "/tikv/test/root/multi_save_and_multi_load" + kv := NewTiKV(txnClient, rootPath) + + defer kv.Close() + defer kv.RemoveWithPrefix("") + + multiSaveTests := map[string]string{ + "key_1": "value_1", + "key_2": "value_2", + "key_3/a": "value_3a", + "multikey_1": "multivalue_1", + "multikey_2": "multivalue_2", + "_": "other", + } + + err := kv.MultiSave(multiSaveTests) + assert.NoError(t, err) + for k, v := range multiSaveTests { + actualV, err := kv.Load(k) + assert.NoError(t, err) + assert.Equal(t, v, actualV) + } + + multiLoadTests := []struct { + inputKeys []string + expectedValues 
[]string + }{ + {[]string{"key_1"}, []string{"value_1"}}, + {[]string{"key_1", "key_2", "key_3/a"}, []string{"value_1", "value_2", "value_3a"}}, + {[]string{"multikey_1", "multikey_2"}, []string{"multivalue_1", "multivalue_2"}}, + {[]string{"_"}, []string{"other"}}, + } + + for _, test := range multiLoadTests { + vs, err := kv.MultiLoad(test.inputKeys) + assert.NoError(t, err) + assert.Equal(t, test.expectedValues, vs) + } + + invalidMultiLoad := []struct { + invalidKeys []string + expectedValues []string + }{ + {[]string{"a", "key_1"}, []string{"", "value_1"}}, + {[]string{".....", "key_1"}, []string{"", "value_1"}}, + {[]string{"*********"}, []string{""}}, + {[]string{"key_1", "1"}, []string{"value_1", ""}}, + } + + for _, test := range invalidMultiLoad { + vs, err := kv.MultiLoad(test.invalidKeys) + assert.Error(t, err) + assert.Equal(t, test.expectedValues, vs) + } + + removeWithPrefixTests := []string{ + "key_1", + "multi", + } + + for _, k := range removeWithPrefixTests { + err = kv.RemoveWithPrefix(k) + assert.NoError(t, err) + + ks, vs, err := kv.LoadWithPrefix(k) + assert.Empty(t, ks) + assert.Empty(t, vs) + assert.NoError(t, err) + } + + multiRemoveTests := []string{ + "key_2", + "key_3/a", + "multikey_2", + "_", + } + + err = kv.MultiRemove(multiRemoveTests) + assert.NoError(t, err) + + ks, vs, err := kv.LoadWithPrefix("") + assert.NoError(t, err) + assert.Empty(t, ks) + assert.Empty(t, vs) + + multiSaveAndRemoveTests := []struct { + multiSaves map[string]string + multiRemoves []string + }{ + {map[string]string{"key_1": "value_1"}, []string{}}, + {map[string]string{"key_2": "value_2"}, []string{"key_1"}}, + {map[string]string{"key_3/a": "value_3a"}, []string{"key_2"}}, + {map[string]string{"multikey_1": "multivalue_1"}, []string{}}, + {map[string]string{"multikey_2": "multivalue_2"}, []string{"multikey_1", "key_3/a"}}, + {make(map[string]string), []string{"multikey_2"}}, + } + for _, test := range multiSaveAndRemoveTests { + err = 
kv.MultiSaveAndRemove(test.multiSaves, test.multiRemoves) + assert.NoError(t, err) + } + + ks, vs, err = kv.LoadWithPrefix("") + assert.NoError(t, err) + assert.Empty(t, ks) + assert.Empty(t, vs) + }) + + te.Run("kv MultiSaveAndRemoveWithPrefix", func(t *testing.T) { + rootPath := "/tikv/test/root/multi_remove_with_prefix" + kv := NewTiKV(txnClient, rootPath) + defer kv.Close() + defer kv.RemoveWithPrefix("") + + prepareTests := map[string]string{ + "x/abc/1": "1", + "x/abc/2": "2", + "x/def/1": "10", + "x/def/2": "20", + "x/den/1": "100", + "x/den/2": "200", + } + + // MultiSaveAndRemoveWithPrefix + err := kv.MultiSave(prepareTests) + require.NoError(t, err) + multiSaveAndRemoveWithPrefixTests := []struct { + multiSave map[string]string + prefix []string + + loadPrefix string + lengthBeforeRemove int + lengthAfterRemove int + }{ + {map[string]string{}, []string{"x/abc", "x/def", "x/den"}, "x", 6, 0}, + {map[string]string{"y/a": "vvv", "y/b": "vvv"}, []string{}, "y", 0, 2}, + {map[string]string{"y/c": "vvv"}, []string{}, "y", 2, 3}, + {map[string]string{"p/a": "vvv"}, []string{"y/a", "y"}, "y", 3, 0}, + {map[string]string{}, []string{"p"}, "p", 1, 0}, + } + + for _, test := range multiSaveAndRemoveWithPrefixTests { + k, _, err := kv.LoadWithPrefix(test.loadPrefix) + assert.NoError(t, err) + assert.Equal(t, test.lengthBeforeRemove, len(k)) + + err = kv.MultiSaveAndRemoveWithPrefix(test.multiSave, test.prefix) + assert.NoError(t, err) + + k, _, err = kv.LoadWithPrefix(test.loadPrefix) + assert.NoError(t, err) + assert.Equal(t, test.lengthAfterRemove, len(k)) + } + }) + + te.Run("kv failed to start txn", func(t *testing.T) { + rootPath := "/tikv/test/root/start_exn" + kv := NewTiKV(txnClient, rootPath) + defer kv.Close() + + beginTxn = func(txn *txnkv.Client) (*transaction.KVTxn, error) { + return nil, fmt.Errorf("bad txn!") + } + defer func() { + beginTxn = tiTxnBegin + }() + err := kv.Save("key1", "v1") + assert.Error(t, err) + err = 
kv.MultiSave(map[string]string{"A/100": "v1"}) + assert.Error(t, err) + err = kv.Remove("key1") + assert.Error(t, err) + err = kv.MultiRemove([]string{"key_1", "key_2"}) + assert.Error(t, err) + err = kv.MultiSaveAndRemove(map[string]string{"key_1": "value_1"}, []string{}) + assert.Error(t, err) + err = kv.MultiSaveAndRemoveWithPrefix(map[string]string{"y/c": "vvv"}, []string{"/"}) + assert.Error(t, err) + }) + + te.Run("kv failed to commit txn", func(t *testing.T) { + rootPath := "/tikv/test/root/commit_txn" + kv := NewTiKV(txnClient, rootPath) + defer kv.Close() + + commitTxn = func(ctx context.Context, txn *transaction.KVTxn) error { + return fmt.Errorf("bad txn commit!") + } + defer func() { + commitTxn = tiTxnCommit + }() + var err error + err = kv.Save("key1", "v1") + assert.Error(t, err) + err = kv.MultiSave(map[string]string{"A/100": "v1"}) + assert.Error(t, err) + err = kv.Remove("key1") + assert.Error(t, err) + err = kv.MultiRemove([]string{"key_1", "key_2"}) + assert.Error(t, err) + err = kv.MultiSaveAndRemove(map[string]string{"key_1": "value_1"}, []string{}) + assert.Error(t, err) + err = kv.MultiSaveAndRemoveWithPrefix(map[string]string{"y/c": "vvv"}, []string{"/"}) + assert.Error(t, err) + }) +} + +func TestWalkWithPagination(t *testing.T) { + rootPath := "/tikv/test/root/pagination" + kv := NewTiKV(txnClient, rootPath) + + defer kv.Close() + defer kv.RemoveWithPrefix("") + + kvs := map[string]string{ + "A/100": "v1", + "AA/100": "v2", + "AB/100": "v3", + "AB/2/100": "v4", + "B/100": "v5", + } + + err := kv.MultiSave(kvs) + assert.NoError(t, err) + for k, v := range kvs { + actualV, err := kv.Load(k) + assert.NoError(t, err) + assert.Equal(t, v, actualV) + } + + t.Run("apply function error ", func(t *testing.T) { + err = kv.WalkWithPrefix("A", 5, func(key []byte, value []byte) error { + return errors.New("error") + }) + assert.Error(t, err) + }) + + t.Run("get with non-exist prefix ", func(t *testing.T) { + err = kv.WalkWithPrefix("non-exist-prefix", 
5, func(key []byte, value []byte) error { + return nil + }) + assert.NoError(t, err) + }) + + t.Run("with different pagination", func(t *testing.T) { + testFn := func(pagination int) { + expected := map[string]string{ + "A/100": "v1", + "AA/100": "v2", + "AB/100": "v3", + "AB/2/100": "v4", + } + + expectedKeys := maps.Keys(expected) + sort.Strings(expectedKeys) + + ret := make(map[string]string) + actualKeys := make([]string, 0) + + err = kv.WalkWithPrefix("A", pagination, func(key []byte, value []byte) error { + k := string(key) + k = k[len(rootPath)+1:] + ret[k] = string(value) + actualKeys = append(actualKeys, k) + return nil + }) + + assert.NoError(t, err) + assert.Equal(t, expected, ret, fmt.Errorf("pagination: %d", pagination)) + // Ignore the order. + assert.ElementsMatch(t, expectedKeys, actualKeys, fmt.Errorf("pagination: %d", pagination)) + } + + for p := -1; p < 6; p++ { + testFn(p) + } + testFn(-100) + testFn(100) + }) +} + +func TestElapse(t *testing.T) { + start := time.Now() + isElapse := CheckElapseAndWarn(start, "err message") + assert.Equal(t, isElapse, false) + + time.Sleep(2001 * time.Millisecond) + isElapse = CheckElapseAndWarn(start, "err message") + assert.Equal(t, isElapse, true) +} + +func TestHas(t *testing.T) { + rootPath := "/tikv/test/root/pagination" + kv := NewTiKV(txnClient, rootPath) + err := kv.RemoveWithPrefix("") + require.NoError(t, err) + + defer kv.Close() + defer kv.RemoveWithPrefix("") + + has, err := kv.Has("key1") + assert.NoError(t, err) + assert.False(t, has) + + err = kv.Save("key1", "value1") + assert.NoError(t, err) + + err = kv.Save("key1", EmptyValueString) + assert.Error(t, err) + + has, err = kv.Has("key1") + assert.NoError(t, err) + assert.True(t, has) + + err = kv.Remove("key1") + assert.NoError(t, err) + + has, err = kv.Has("key1") + assert.NoError(t, err) + assert.False(t, has) +} + +func TestHasPrefix(t *testing.T) { + rootPath := "/etcd/test/root/hasprefix" + kv := NewTiKV(txnClient, rootPath) + err := 
kv.RemoveWithPrefix("") + require.NoError(t, err) + + defer kv.Close() + defer kv.RemoveWithPrefix("") + + has, err := kv.HasPrefix("key") + assert.NoError(t, err) + assert.False(t, has) + + err = kv.Save("key1", "value1") + assert.NoError(t, err) + + has, err = kv.HasPrefix("key") + assert.NoError(t, err) + assert.True(t, has) + + err = kv.Remove("key1") + assert.NoError(t, err) + + has, err = kv.HasPrefix("key") + assert.NoError(t, err) + assert.False(t, has) +} + +func TestEmptyKey(t *testing.T) { + rootPath := "/etcd/test/root/loadempty" + kv := NewTiKV(txnClient, rootPath) + err := kv.RemoveWithPrefix("") + require.NoError(t, err) + + defer kv.Close() + defer kv.RemoveWithPrefix("") + + has, err := kv.HasPrefix("key") + assert.NoError(t, err) + assert.False(t, has) + + err = kv.Save("key", "") + assert.NoError(t, err) + + has, err = kv.HasPrefix("key") + assert.NoError(t, err) + assert.True(t, has) + + val, err := kv.Load("key") + assert.NoError(t, err) + assert.Equal(t, val, "") + + _, vals, err := kv.LoadWithPrefix("key") + assert.NoError(t, err) + assert.Equal(t, vals[0], "") + + vals, err = kv.MultiLoad([]string{"key"}) + assert.NoError(t, err) + assert.Equal(t, vals[0], "") + + var res string + nothing := func(key, val []byte) error { + res = string(val) + return nil + } + + err = kv.WalkWithPrefix("", 1, nothing) + assert.NoError(t, err) + assert.Equal(t, res, "") + + multiSaveTests := map[string]string{ + "key1": "", + } + err = kv.MultiSave(multiSaveTests) + assert.NoError(t, err) + val, err = kv.Load("key1") + assert.NoError(t, err) + assert.Equal(t, val, "") + + multiSaveTests = map[string]string{ + "key2": "", + } + err = kv.MultiSaveAndRemove(multiSaveTests, nil) + assert.NoError(t, err) + val, err = kv.Load("key2") + assert.NoError(t, err) + assert.Equal(t, val, "") + + multiSaveTests = map[string]string{ + "key3": "", + } + err = kv.MultiSaveAndRemoveWithPrefix(multiSaveTests, nil) + assert.NoError(t, err) + val, err = kv.Load("key3") + 
assert.NoError(t, err) + assert.Equal(t, val, "") +} + +func TestScanSize(t *testing.T) { + scanSize := SnapshotScanSize + kv := NewTiKV(txnClient, "/") + err := kv.RemoveWithPrefix("") + require.NoError(t, err) + + defer kv.Close() + defer kv.RemoveWithPrefix("") + + // Test total > scansize + keyMap := map[string]string{} + for i := 1; i <= scanSize+100; i++ { + a := fmt.Sprintf("%v", i) + keyMap[a] = a + } + + err = kv.MultiSave(keyMap) + assert.NoError(t, err) + + keys, _, err := kv.LoadWithPrefix("") + assert.NoError(t, err) + assert.Equal(t, len(keys), scanSize+100) + + err = kv.RemoveWithPrefix("") + require.NoError(t, err) +} + +func TestTiKVUnimplemented(t *testing.T) { + kv := NewTiKV(txnClient, "/") + err := kv.RemoveWithPrefix("") + require.NoError(t, err) + + defer kv.Close() + defer kv.RemoveWithPrefix("") + + _, err = kv.CompareVersionAndSwap("k", 1, "target") + assert.Error(t, err) +} + +func TestTxnWithPredicates(t *testing.T) { + kv := NewTiKV(txnClient, "/") + err := kv.RemoveWithPrefix("") + require.NoError(t, err) + + prepareKV := map[string]string{ + "lease1": "1", + "lease2": "2", + } + + err = kv.MultiSave(prepareKV) + require.NoError(t, err) + + multiSaveAndRemovePredTests := []struct { + tag string + multiSave map[string]string + preds []predicates.Predicate + expectSuccess bool + }{ + {"predicate_ok", map[string]string{"a": "b"}, []predicates.Predicate{predicates.ValueEqual("lease1", "1")}, true}, + {"predicate_fail", map[string]string{"a": "b"}, []predicates.Predicate{predicates.ValueEqual("lease1", "2")}, false}, + } + + for _, test := range multiSaveAndRemovePredTests { + t.Run(test.tag, func(t *testing.T) { + err := kv.MultiSaveAndRemove(test.multiSave, nil, test.preds...) + t.Log(err) + if test.expectSuccess { + assert.NoError(t, err) + } else { + assert.Error(t, err) + } + err = kv.MultiSaveAndRemoveWithPrefix(test.multiSave, nil, test.preds...) 
+ if test.expectSuccess { + assert.NoError(t, err) + } else { + assert.Error(t, err) + } + }) + } +} diff --git a/internal/metastore/kv/datacoord/kv_catalog.go b/internal/metastore/kv/datacoord/kv_catalog.go index 52249431bf7c8..b1864d43dcbfe 100644 --- a/internal/metastore/kv/datacoord/kv_catalog.go +++ b/internal/metastore/kv/datacoord/kv_catalog.go @@ -78,7 +78,7 @@ func (kc *Catalog) ListSegments(ctx context.Context) ([]*datapb.SegmentInfo, err }) } - //execute list segment meta + // execute list segment meta executeFn(storage.InsertBinlog, insertLogs) executeFn(storage.DeleteBinlog, deltaLogs) executeFn(storage.StatsBinlog, statsLogs) @@ -211,7 +211,8 @@ func (kc *Catalog) listBinlogs(binlogType storage.BinlogType) (map[typeutil.Uniq } func (kc *Catalog) applyBinlogInfo(segments []*datapb.SegmentInfo, insertLogs, deltaLogs, - statsLogs map[typeutil.UniqueID][]*datapb.FieldBinlog) { + statsLogs map[typeutil.UniqueID][]*datapb.FieldBinlog, +) { for _, segmentInfo := range segments { if len(segmentInfo.Binlogs) == 0 { segmentInfo.Binlogs = insertLogs[segmentInfo.ID] @@ -529,7 +530,8 @@ func (kc *Catalog) DropChannelCheckpoint(ctx context.Context, vChannel string) e } func (kc *Catalog) getBinlogsWithPrefix(binlogType storage.BinlogType, collectionID, partitionID, - segmentID typeutil.UniqueID) ([]string, []string, error) { + segmentID typeutil.UniqueID, +) ([]string, []string, error) { var binlogPrefix string switch binlogType { case storage.InsertBinlog: @@ -727,18 +729,13 @@ func (kc *Catalog) GcConfirm(ctx context.Context, collectionID, partitionID type } func fillLogPathByLogID(chunkManagerRootPath string, binlogType storage.BinlogType, collectionID, partitionID, - segmentID typeutil.UniqueID, fieldBinlog *datapb.FieldBinlog) error { + segmentID typeutil.UniqueID, fieldBinlog *datapb.FieldBinlog, +) { for _, binlog := range fieldBinlog.Binlogs { - path, err := buildLogPath(chunkManagerRootPath, binlogType, collectionID, partitionID, + path := 
buildLogPath(chunkManagerRootPath, binlogType, collectionID, partitionID, segmentID, fieldBinlog.GetFieldID(), binlog.GetLogID()) - if err != nil { - return err - } - binlog.LogPath = path } - - return nil } func fillLogIDByLogPath(multiFieldBinlogs ...[]*datapb.FieldBinlog) error { @@ -765,42 +762,85 @@ func fillLogIDByLogPath(multiFieldBinlogs ...[]*datapb.FieldBinlog) error { return nil } +func CompressBinLog(fieldBinLogs []*datapb.FieldBinlog) ([]*datapb.FieldBinlog, error) { + compressedFieldBinLogs := make([]*datapb.FieldBinlog, 0) + for _, fieldBinLog := range fieldBinLogs { + compressedFieldBinLog := &datapb.FieldBinlog{} + compressedFieldBinLog.FieldID = fieldBinLog.FieldID + for _, binlog := range fieldBinLog.Binlogs { + logPath := binlog.LogPath + idx := strings.LastIndex(logPath, "/") + if idx == -1 { + return nil, fmt.Errorf("invalid binlog path: %s", logPath) + } + logPathStr := logPath[(idx + 1):] + logID, err := strconv.ParseInt(logPathStr, 10, 64) + if err != nil { + return nil, err + } + binlog := &datapb.Binlog{ + EntriesNum: binlog.EntriesNum, + // remove timestamp since it's not necessary + LogSize: binlog.LogSize, + LogID: logID, + } + compressedFieldBinLog.Binlogs = append(compressedFieldBinLog.Binlogs, binlog) + } + compressedFieldBinLogs = append(compressedFieldBinLogs, compressedFieldBinLog) + } + return compressedFieldBinLogs, nil +} + +func DecompressBinLog(path string, info *datapb.SegmentInfo) error { + for _, fieldBinLogs := range info.GetBinlogs() { + fillLogPathByLogID(path, storage.InsertBinlog, info.CollectionID, info.PartitionID, info.ID, fieldBinLogs) + } + + for _, deltaLogs := range info.GetDeltalogs() { + fillLogPathByLogID(path, storage.DeleteBinlog, info.CollectionID, info.PartitionID, info.ID, deltaLogs) + } + + for _, statsLogs := range info.GetStatslogs() { + fillLogPathByLogID(path, storage.StatsBinlog, info.CollectionID, info.PartitionID, info.ID, statsLogs) + } + return nil +} + // build a binlog path on the storage 
by metadata -func buildLogPath(chunkManagerRootPath string, binlogType storage.BinlogType, collectionID, partitionID, segmentID, filedID, logID typeutil.UniqueID) (string, error) { +func buildLogPath(chunkManagerRootPath string, binlogType storage.BinlogType, collectionID, partitionID, segmentID, filedID, logID typeutil.UniqueID) string { switch binlogType { case storage.InsertBinlog: - path := metautil.BuildInsertLogPath(chunkManagerRootPath, collectionID, partitionID, segmentID, filedID, logID) - return path, nil + return metautil.BuildInsertLogPath(chunkManagerRootPath, collectionID, partitionID, segmentID, filedID, logID) case storage.DeleteBinlog: - path := metautil.BuildDeltaLogPath(chunkManagerRootPath, collectionID, partitionID, segmentID, logID) - return path, nil + return metautil.BuildDeltaLogPath(chunkManagerRootPath, collectionID, partitionID, segmentID, logID) case storage.StatsBinlog: - path := metautil.BuildStatsLogPath(chunkManagerRootPath, collectionID, partitionID, segmentID, filedID, logID) - return path, nil - default: - return "", fmt.Errorf("invalid binlog type: %d", binlogType) + return metautil.BuildStatsLogPath(chunkManagerRootPath, collectionID, partitionID, segmentID, filedID, logID) } + // should not happen + log.Panic("invalid binlog type", zap.Any("type", binlogType)) + return "" } -func checkBinlogs(binlogType storage.BinlogType, segmentID typeutil.UniqueID, logs []*datapb.FieldBinlog) { - check := func(getSegmentID func(logPath string) typeutil.UniqueID) { +func checkBinlogs(binlogType storage.BinlogType, segmentID typeutil.UniqueID, logs []*datapb.FieldBinlog) error { + check := func(getSegmentID func(logPath string) typeutil.UniqueID) error { for _, fieldBinlog := range logs { for _, binlog := range fieldBinlog.Binlogs { if segmentID != getSegmentID(binlog.LogPath) { - log.Panic("the segment path doesn't match the segmentID", zap.Int64("segmentID", segmentID), zap.String("path", binlog.LogPath)) + return fmt.Errorf("the segment 
path doesn't match the segmentID, segmentID %d, path %s", segmentID, binlog.LogPath) + } + } + } + return nil + } switch binlogType { case storage.InsertBinlog: - check(metautil.GetSegmentIDFromInsertLogPath) + return check(metautil.GetSegmentIDFromInsertLogPath) case storage.DeleteBinlog: - check(metautil.GetSegmentIDFromDeltaLogPath) + return check(metautil.GetSegmentIDFromDeltaLogPath) case storage.StatsBinlog: - check(metautil.GetSegmentIDFromStatsLogPath) + return check(metautil.GetSegmentIDFromStatsLogPath) default: - log.Panic("invalid binlog type") + return fmt.Errorf("the segment path doesn't match the segmentID, segmentID %d, type %d", segmentID, binlogType) } } @@ -815,11 +855,20 @@ func hasSepcialStatslog(logs *datapb.FieldBinlog) bool { } func buildBinlogKvsWithLogID(collectionID, partitionID, segmentID typeutil.UniqueID, - binlogs, deltalogs, statslogs []*datapb.FieldBinlog, ignoreNumberCheck bool) (map[string]string, error) { - - checkBinlogs(storage.InsertBinlog, segmentID, binlogs) + binlogs, deltalogs, statslogs []*datapb.FieldBinlog, ignoreNumberCheck bool, +) (map[string]string, error) { + err := checkBinlogs(storage.InsertBinlog, segmentID, binlogs) + if err != nil { + return nil, err + } - checkBinlogs(storage.DeleteBinlog, segmentID, deltalogs) + err = checkBinlogs(storage.DeleteBinlog, segmentID, deltalogs) + if err != nil { + return nil, err + } - checkBinlogs(storage.StatsBinlog, segmentID, statslogs) + err = checkBinlogs(storage.StatsBinlog, segmentID, statslogs) + if err != nil { + return nil, err + } // check stats log and bin log size match // num of stats log may one more than num of binlogs if segment flushed and merged stats log if !ignoreNumberCheck && len(binlogs) != 0 && len(statslogs) != 0 && !hasSepcialStatslog(statslogs[0]) { diff --git a/internal/metastore/kv/datacoord/kv_catalog_test.go b/internal/metastore/kv/datacoord/kv_catalog_test.go index 856706de1f11f..a87c2507ee0db 100644 --- a/internal/metastore/kv/datacoord/kv_catalog_test.go +++ b/internal/metastore/kv/datacoord/kv_catalog_test.go @@ -122,7 +122,8 @@ var ( { EntriesNum: 5, LogPath: deltalogPath, - }}, 
+ }, + }, }, } statslogs = []*datapb.FieldBinlog{ @@ -258,7 +259,6 @@ func Test_ListSegments(t *testing.T) { } if strings.HasPrefix(k3, s) { return f([]byte(k3), []byte(savedKvs[k3])) - } return errors.New("should not reach here") }) @@ -277,9 +277,9 @@ func Test_AddSegments(t *testing.T) { metakv.EXPECT().MultiSave(mock.Anything).Return(errors.New("error")).Maybe() catalog := NewCatalog(metakv, rootPath, "") - assert.Panics(t, func() { - catalog.AddSegment(context.TODO(), invalidSegment) - }) + + err := catalog.AddSegment(context.TODO(), invalidSegment) + assert.Error(t, err) }) t.Run("save error", func(t *testing.T) { @@ -327,11 +327,10 @@ func Test_AlterSegments(t *testing.T) { metakv.EXPECT().MultiSave(mock.Anything).Return(errors.New("error")).Maybe() catalog := NewCatalog(metakv, rootPath, "") - assert.Panics(t, func() { - catalog.AlterSegments(context.TODO(), []*datapb.SegmentInfo{invalidSegment}, metastore.BinlogsIncrement{ - Segment: invalidSegment, - }) + err := catalog.AlterSegments(context.TODO(), []*datapb.SegmentInfo{invalidSegment}, metastore.BinlogsIncrement{ + Segment: invalidSegment, }) + assert.Error(t, err) }) t.Run("save error", func(t *testing.T) { @@ -372,10 +371,6 @@ func Test_AlterSegments(t *testing.T) { opGroupCount := 0 metakv := mocks.NewMetaKv(t) metakv.EXPECT().MultiSave(mock.Anything).RunAndReturn(func(m map[string]string) error { - var ks []string - for k := range m { - ks = append(ks, k) - } maps.Copy(savedKvs, m) opGroupCount++ return nil @@ -1063,6 +1058,54 @@ func TestCatalog_DropSegmentIndex(t *testing.T) { }) } +func TestCatalog_Compress(t *testing.T) { + segmentInfo := getSegment(rootPath, 0, 1, 2, 3, 10000) + val, err := proto.Marshal(segmentInfo) + assert.NoError(t, err) + + compressedSegmentInfo := proto.Clone(segmentInfo).(*datapb.SegmentInfo) + compressedSegmentInfo.Binlogs, err = CompressBinLog(compressedSegmentInfo.Binlogs) + assert.NoError(t, err) + compressedSegmentInfo.Deltalogs, err = 
CompressBinLog(compressedSegmentInfo.Deltalogs) + assert.NoError(t, err) + compressedSegmentInfo.Statslogs, err = CompressBinLog(compressedSegmentInfo.Statslogs) + assert.NoError(t, err) + + valCompressed, err := proto.Marshal(compressedSegmentInfo) + assert.NoError(t, err) + + assert.True(t, len(valCompressed) < len(val)) + + // make sure the compressed info decompresses back to the original + unmarshaledSegmentInfo := &datapb.SegmentInfo{} + proto.Unmarshal(val, unmarshaledSegmentInfo) + + unmarshaledSegmentInfoCompressed := &datapb.SegmentInfo{} + proto.Unmarshal(valCompressed, unmarshaledSegmentInfoCompressed) + DecompressBinLog(rootPath, unmarshaledSegmentInfoCompressed) + + assert.Equal(t, len(unmarshaledSegmentInfo.GetBinlogs()), len(unmarshaledSegmentInfoCompressed.GetBinlogs())) + for i := 0; i < 1000; i++ { + assert.Equal(t, unmarshaledSegmentInfo.GetBinlogs()[0].Binlogs[i].LogPath, unmarshaledSegmentInfoCompressed.GetBinlogs()[0].Binlogs[i].LogPath) + } + + // test compress error path + fakeBinlogs := make([]*datapb.Binlog, 1) + fakeBinlogs[0] = &datapb.Binlog{ + EntriesNum: 10000, + LogPath: "test", + } + fieldBinLogs := make([]*datapb.FieldBinlog, 1) + fieldBinLogs[0] = &datapb.FieldBinlog{ + FieldID: 106, + Binlogs: fakeBinlogs, + } + compressedSegmentInfo.Binlogs, err = CompressBinLog(fieldBinLogs) + assert.Error(t, err) + + // test decompress error path +} + func BenchmarkCatalog_List1000Segments(b *testing.B) { paramtable.Init() etcdCli, err := etcd.GetEtcdClient( @@ -1140,9 +1183,62 @@ func addSegment(rootPath string, collectionID, partitionID, segmentID, fieldID i { EntriesNum: 5, LogPath: metautil.BuildDeltaLogPath(rootPath, collectionID, partitionID, segmentID, int64(rand.Int())), - }}, + }, + }, }, } + + statslogs = []*datapb.FieldBinlog{ + { + FieldID: 1, + Binlogs: []*datapb.Binlog{ + { + EntriesNum: 5, + LogPath: metautil.BuildStatsLogPath(rootPath, collectionID, partitionID, segmentID, fieldID, int64(rand.Int())), + }, + }, + }, + } + + return &datapb.SegmentInfo{ + ID: segmentID, 
+ CollectionID: collectionID, + PartitionID: partitionID, + NumOfRows: 10000, + State: commonpb.SegmentState_Flushed, + Binlogs: binlogs, + Deltalogs: deltalogs, + Statslogs: statslogs, + } +} + +func getSegment(rootPath string, collectionID, partitionID, segmentID, fieldID int64, binlogNum int) *datapb.SegmentInfo { + binLogPaths := make([]*datapb.Binlog, binlogNum) + for i := 0; i < binlogNum; i++ { + binLogPaths[i] = &datapb.Binlog{ + EntriesNum: 10000, + LogPath: metautil.BuildInsertLogPath(rootPath, collectionID, partitionID, segmentID, fieldID, int64(i)), + } + } + binlogs = []*datapb.FieldBinlog{ + { + FieldID: fieldID, + Binlogs: binLogPaths, + }, + } + + deltalogs = []*datapb.FieldBinlog{ + { + FieldID: fieldID, + Binlogs: []*datapb.Binlog{ + { + EntriesNum: 5, + LogPath: metautil.BuildDeltaLogPath(rootPath, collectionID, partitionID, segmentID, int64(rand.Int())), + }, + }, + }, + } + statslogs = []*datapb.FieldBinlog{ { FieldID: 1, diff --git a/internal/metastore/kv/querycoord/kv_catalog.go b/internal/metastore/kv/querycoord/kv_catalog.go index 75753ec237969..eb1dce1e74349 100644 --- a/internal/metastore/kv/querycoord/kv_catalog.go +++ b/internal/metastore/kv/querycoord/kv_catalog.go @@ -12,9 +12,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" ) -var ( - ErrInvalidKey = errors.New("invalid load info key") -) +var ErrInvalidKey = errors.New("invalid load info key") const ( CollectionLoadInfoPrefix = "querycoord-collection-loadinfo" diff --git a/internal/metastore/kv/rootcoord/kv_catalog.go b/internal/metastore/kv/rootcoord/kv_catalog.go index ea2b028ec5eec..fdbd695228043 100644 --- a/internal/metastore/kv/rootcoord/kv_catalog.go +++ b/internal/metastore/kv/rootcoord/kv_catalog.go @@ -5,7 +5,10 @@ import ( "encoding/json" "fmt" + "github.com/cockroachdb/errors" "github.com/golang/protobuf/proto" + "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" 
"github.com/milvus-io/milvus/internal/kv" @@ -21,7 +24,6 @@ import ( "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/typeutil" - "go.uber.org/zap" ) const ( @@ -341,8 +343,8 @@ func (kc *Catalog) listFieldsAfter210(ctx context.Context, collectionID typeutil } func (kc *Catalog) appendPartitionAndFieldsInfo(ctx context.Context, collMeta *pb.CollectionInfo, - ts typeutil.Timestamp) (*model.Collection, error) { - + ts typeutil.Timestamp, +) (*model.Collection, error) { collection := model.UnmarshalCollectionModel(collMeta) if !partitionVersionAfter210(collMeta) && !fieldVersionAfter210(collMeta) { @@ -382,7 +384,7 @@ func (kc *Catalog) GetCredential(ctx context.Context, username string) (*model.C k := fmt.Sprintf("%s/%s", CredentialPrefix, username) v, err := kc.Txn.Load(k) if err != nil { - if common.IsKeyNotExistError(err) { + if errors.Is(err, merr.ErrIoKeyNotFound) { log.Debug("not found the user", zap.String("key", k)) } else { log.Warn("get credential meta fail", zap.String("key", k), zap.Error(err)) @@ -521,6 +523,11 @@ func dropPartition(collMeta *pb.CollectionInfo, partitionID typeutil.UniqueID) { func (kc *Catalog) DropPartition(ctx context.Context, dbID int64, collectionID typeutil.UniqueID, partitionID typeutil.UniqueID, ts typeutil.Timestamp) error { collMeta, err := kc.loadCollection(ctx, dbID, collectionID, ts) + if errors.Is(err, merr.ErrCollectionNotFound) { + // collection's gc happened before partition's. 
+ return nil + } + if err != nil { return err } @@ -542,7 +549,7 @@ func (kc *Catalog) DropPartition(ctx context.Context, dbID int64, collectionID t func (kc *Catalog) DropCredential(ctx context.Context, username string) error { k := fmt.Sprintf("%s/%s", CredentialPrefix, username) userResults, err := kc.ListUser(ctx, util.DefaultTenant, &milvuspb.UserEntity{Name: username}, true) - if err != nil && !common.IsKeyNotExistError(err) { + if err != nil && !errors.Is(err, merr.ErrIoKeyNotFound) { log.Warn("fail to list user", zap.String("key", k), zap.Error(err)) return err } @@ -593,7 +600,7 @@ func (kc *Catalog) GetCollectionByName(ctx context.Context, dbID int64, collecti } } - return nil, merr.WrapErrCollectionNotFoundWithDB(dbID, collectionName, fmt.Sprintf("timestample = %d", ts)) + return nil, merr.WrapErrCollectionNotFoundWithDB(dbID, collectionName, fmt.Sprintf("timestamp = %d", ts)) } func (kc *Catalog) ListCollections(ctx context.Context, dbID int64, ts typeutil.Timestamp) ([]*model.Collection, error) { @@ -719,7 +726,7 @@ func (kc *Catalog) ListCredentials(ctx context.Context) ([]string, error) { func (kc *Catalog) save(k string) error { var err error - if _, err = kc.Txn.Load(k); err != nil && !common.IsKeyNotExistError(err) { + if _, err = kc.Txn.Load(k); err != nil && !errors.Is(err, merr.ErrIoKeyNotFound) { return err } if err == nil { @@ -731,10 +738,10 @@ func (kc *Catalog) save(k string) error { func (kc *Catalog) remove(k string) error { var err error - if _, err = kc.Txn.Load(k); err != nil && !common.IsKeyNotExistError(err) { + if _, err = kc.Txn.Load(k); err != nil && !errors.Is(err, merr.ErrIoKeyNotFound) { return err } - if err != nil && common.IsKeyNotExistError(err) { + if err != nil && errors.Is(err, merr.ErrIoKeyNotFound) { return common.NewIgnorableError(fmt.Errorf("the key[%s] isn't existed", k)) } return kc.Txn.Remove(k) @@ -752,7 +759,7 @@ func (kc *Catalog) CreateRole(ctx context.Context, tenant string, entity *milvus func (kc *Catalog) 
DropRole(ctx context.Context, tenant string, roleName string) error { k := funcutil.HandleTenantForEtcdKey(RolePrefix, tenant, roleName) roleResults, err := kc.ListRole(ctx, tenant, &milvuspb.RoleEntity{Name: roleName}, true) - if err != nil && !common.IsKeyNotExistError(err) { + if err != nil && !errors.Is(err, merr.ErrIoKeyNotFound) { log.Warn("fail to list role", zap.String("key", k), zap.Error(err)) return err } @@ -961,12 +968,12 @@ func (kc *Catalog) AlterGrant(ctx context.Context, tenant string, entity *milvus } else { log.Warn("fail to load grant privilege entity", zap.String("key", k), zap.Any("type", operateType), zap.Error(err)) if funcutil.IsRevoke(operateType) { - if common.IsKeyNotExistError(err) { + if errors.Is(err, merr.ErrIoKeyNotFound) { return common.NewIgnorableError(fmt.Errorf("the grant[%s] isn't existed", k)) } return err } - if !common.IsKeyNotExistError(err) { + if !errors.Is(err, merr.ErrIoKeyNotFound) { return err } @@ -982,7 +989,7 @@ func (kc *Catalog) AlterGrant(ctx context.Context, tenant string, entity *milvus _, err = kc.Txn.Load(k) if err != nil { log.Warn("fail to load the grantee id", zap.String("key", k), zap.Error(err)) - if !common.IsKeyNotExistError(err) { + if !errors.Is(err, merr.ErrIoKeyNotFound) { log.Warn("fail to load the grantee id", zap.String("key", k), zap.Error(err)) return err } diff --git a/internal/metastore/kv/rootcoord/kv_catalog_test.go b/internal/metastore/kv/rootcoord/kv_catalog_test.go index 70b622f32274f..54c22f9c08db9 100644 --- a/internal/metastore/kv/rootcoord/kv_catalog_test.go +++ b/internal/metastore/kv/rootcoord/kv_catalog_test.go @@ -10,6 +10,12 @@ import ( "github.com/cockroachdb/errors" "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "go.uber.org/atomic" + "golang.org/x/exp/maps" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" 
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" @@ -23,12 +29,8 @@ import ( "github.com/milvus-io/milvus/pkg/util" "github.com/milvus-io/milvus/pkg/util/crypto" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/typeutil" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" - "go.uber.org/atomic" - "golang.org/x/exp/maps" ) var ( @@ -664,10 +666,9 @@ func TestCatalog_DropPartitionV2(t *testing.T) { t.Run("failed to load collection", func(t *testing.T) { ctx := context.Background() - snapshot := kv.NewMockSnapshotKV() - snapshot.LoadFunc = func(key string, ts typeutil.Timestamp) (string, error) { - return "", errors.New("mock") - } + snapshot := mocks.NewSnapShotKV(t) + snapshot.On("Load", + mock.Anything, mock.Anything).Return("not in codec format", nil) kc := Catalog{Snapshot: snapshot} @@ -675,6 +676,19 @@ func TestCatalog_DropPartitionV2(t *testing.T) { assert.Error(t, err) }) + t.Run("failed to load collection, no key found", func(t *testing.T) { + ctx := context.Background() + + snapshot := mocks.NewSnapShotKV(t) + snapshot.On("Load", + mock.Anything, mock.Anything).Return("", merr.WrapErrIoKeyNotFound("partition")) + + kc := Catalog{Snapshot: snapshot} + + err := kc.DropPartition(ctx, 0, 100, 101, 0) + assert.NoError(t, err) + }) + t.Run("partition version after 210", func(t *testing.T) { ctx := context.Background() @@ -1423,7 +1437,6 @@ func TestRBAC_Credential(t *testing.T) { fmt.Sprintf("%s/%s", CredentialPrefix, "user3"), "random", } - } return nil }, @@ -1481,7 +1494,7 @@ func TestRBAC_Role(t *testing.T) { otherError = fmt.Errorf("mock load error") ) - kvmock.EXPECT().Load(notExistKey).Return("", common.NewKeyNotExistError(notExistKey)).Once() + kvmock.EXPECT().Load(notExistKey).Return("", merr.WrapErrIoKeyNotFound(notExistKey)).Once() kvmock.EXPECT().Load(errorKey).Return("", otherError).Once() 
kvmock.EXPECT().Load(mock.Anything).Return("", nil).Once() kvmock.EXPECT().Remove(mock.Anything).Call.Return(nil).Once() @@ -1525,7 +1538,7 @@ func TestRBAC_Role(t *testing.T) { otherError = fmt.Errorf("mock load error") ) - kvmock.EXPECT().Load(notExistKey).Return("", common.NewKeyNotExistError(notExistKey)).Once() + kvmock.EXPECT().Load(notExistKey).Return("", merr.WrapErrIoKeyNotFound(notExistKey)).Once() kvmock.EXPECT().Load(errorKey).Return("", otherError).Once() kvmock.EXPECT().Load(mock.Anything).Return("", nil).Once() kvmock.EXPECT().Save(mock.Anything, mock.Anything).Call.Return(nil).Once() @@ -1572,7 +1585,7 @@ func TestRBAC_Role(t *testing.T) { otherError = fmt.Errorf("mock load error") ) - kvmock.EXPECT().Load(notExistPath).Return("", common.NewKeyNotExistError(notExistName)).Once() + kvmock.EXPECT().Load(notExistPath).Return("", merr.WrapErrIoKeyNotFound(notExistName)).Once() kvmock.EXPECT().Load(errorPath).Return("", otherError).Once() kvmock.EXPECT().Load(mock.Anything).Return("", nil).Once() kvmock.EXPECT().Save(mock.Anything, mock.Anything).Call.Return(nil).Once() @@ -1679,7 +1692,7 @@ func TestRBAC_Role(t *testing.T) { kvmock.EXPECT().Load(errorRoleSavepath).Return("", nil) // Catalog.save() returns nil - kvmock.EXPECT().Load(noErrorRoleSavepath).Return("", common.NewKeyNotExistError(noErrorRoleSavepath)) + kvmock.EXPECT().Load(noErrorRoleSavepath).Return("", merr.WrapErrIoKeyNotFound(noErrorRoleSavepath)) // Catalog.remove() returns error kvmock.EXPECT().Load(errorRoleRemovepath).Return("", errors.New("not exists")) @@ -1717,9 +1730,7 @@ func TestRBAC_Role(t *testing.T) { }) t.Run("test ListRole", func(t *testing.T) { - var ( - loadWithPrefixReturn atomic.Bool - ) + var loadWithPrefixReturn atomic.Bool t.Run("test entity!=nil", func(t *testing.T) { var ( @@ -1984,10 +1995,8 @@ func TestRBAC_Role(t *testing.T) { assert.Error(t, err) assert.Empty(t, res) } - }) } - }) }) t.Run("test ListUserRole", func(t *testing.T) { @@ -2100,19 +2109,19 @@ func 
TestRBAC_Grant(t *testing.T) { }) kvmock.EXPECT().Load(keyNotExistRoleKey).Call. Return("", func(key string) error { - return common.NewKeyNotExistError(key) + return merr.WrapErrIoKeyNotFound(key) }) kvmock.EXPECT().Load(keyNotExistRoleKeyWithDb).Call. Return("", func(key string) error { - return common.NewKeyNotExistError(key) + return merr.WrapErrIoKeyNotFound(key) }) kvmock.EXPECT().Load(errorSaveRoleKey).Call. Return("", func(key string) error { - return common.NewKeyNotExistError(key) + return merr.WrapErrIoKeyNotFound(key) }) kvmock.EXPECT().Load(errorSaveRoleKeyWithDb).Call. Return("", func(key string) error { - return common.NewKeyNotExistError(key) + return merr.WrapErrIoKeyNotFound(key) }) kvmock.EXPECT().Save(keyNotExistRoleKeyWithDb, mock.Anything).Return(nil) kvmock.EXPECT().Save(errorSaveRoleKeyWithDb, mock.Anything).Return(errors.New("mock save error role")) @@ -2130,11 +2139,11 @@ func TestRBAC_Grant(t *testing.T) { }) kvmock.EXPECT().Load(keyNotExistPrivilegeKey).Call. Return("", func(key string) error { - return common.NewKeyNotExistError(key) + return merr.WrapErrIoKeyNotFound(key) }) kvmock.EXPECT().Load(keyNotExistPrivilegeKey2WithDb).Call. 
Return("", func(key string) error { - return common.NewKeyNotExistError(key) + return merr.WrapErrIoKeyNotFound(key) }) kvmock.EXPECT().Load(mock.Anything).Call.Return("", nil) @@ -2176,7 +2185,8 @@ func TestRBAC_Grant(t *testing.T) { DbName: util.DefaultDBName, Grantor: &milvuspb.GrantorEntity{ User: &milvuspb.UserEntity{Name: test.userName}, - Privilege: &milvuspb.PrivilegeEntity{Name: test.privilegeName}}, + Privilege: &milvuspb.PrivilegeEntity{Name: test.privilegeName}, + }, }, milvuspb.OperatePrivilegeType_Grant) if test.isValid { @@ -2195,7 +2205,7 @@ func TestRBAC_Grant(t *testing.T) { }) t.Run("test Revoke", func(t *testing.T) { - var invalidPrivilegeRemove = "p-remove" + invalidPrivilegeRemove := "p-remove" invalidPrivilegeRemoveKey := funcutil.HandleTenantForEtcdKey(GranteeIDPrefix, tenant, fmt.Sprintf("%s/%s", validRoleValue, invalidPrivilegeRemove)) kvmock.EXPECT().Load(invalidPrivilegeRemoveKey).Call.Return("", nil) @@ -2233,7 +2243,8 @@ func TestRBAC_Grant(t *testing.T) { DbName: util.DefaultDBName, Grantor: &milvuspb.GrantorEntity{ User: &milvuspb.UserEntity{Name: test.userName}, - Privilege: &milvuspb.PrivilegeEntity{Name: test.privilegeName}}, + Privilege: &milvuspb.PrivilegeEntity{Name: test.privilegeName}, + }, }, milvuspb.OperatePrivilegeType_Revoke) if test.isValid { @@ -2306,7 +2317,6 @@ func TestRBAC_Grant(t *testing.T) { kvmock.EXPECT().LoadWithPrefix(invalidRoleKey).Call.Return(nil, nil, errors.New("mock loadWithPrefix error")) kvmock.EXPECT().LoadWithPrefix(mock.Anything).Call.Return( func(key string) []string { - // Mock kv_catalog.go:ListGrant:L871 if strings.Contains(key, GranteeIDPrefix) { return []string{ @@ -2344,21 +2354,25 @@ func TestRBAC_Grant(t *testing.T) { {false, &milvuspb.GrantEntity{ Object: &milvuspb.ObjectEntity{Name: "random"}, ObjectName: "random2", - Role: &milvuspb.RoleEntity{Name: "role1"}}, "valid role with not exist entity"}, + Role: &milvuspb.RoleEntity{Name: "role1"}, + }, "valid role with not exist entity"}, 
{true, &milvuspb.GrantEntity{ Object: &milvuspb.ObjectEntity{Name: "obj1"}, ObjectName: "obj_name1", - Role: &milvuspb.RoleEntity{Name: "role1"}}, "valid role with valid entity"}, + Role: &milvuspb.RoleEntity{Name: "role1"}, + }, "valid role with valid entity"}, {true, &milvuspb.GrantEntity{ Object: &milvuspb.ObjectEntity{Name: "obj1"}, ObjectName: "obj_name2", DbName: "foo", - Role: &milvuspb.RoleEntity{Name: "role1"}}, "valid role and dbName with valid entity"}, + Role: &milvuspb.RoleEntity{Name: "role1"}, + }, "valid role and dbName with valid entity"}, {false, &milvuspb.GrantEntity{ Object: &milvuspb.ObjectEntity{Name: "obj1"}, ObjectName: "obj_name2", DbName: "foo2", - Role: &milvuspb.RoleEntity{Name: "role1"}}, "valid role and invalid dbName with valid entity"}, + Role: &milvuspb.RoleEntity{Name: "role1"}, + }, "valid role and invalid dbName with valid entity"}, } for _, test := range tests { @@ -2375,7 +2389,6 @@ func TestRBAC_Grant(t *testing.T) { } else { assert.Error(t, err) } - }) } }) diff --git a/internal/metastore/kv/rootcoord/suffix_snapshot.go b/internal/metastore/kv/rootcoord/suffix_snapshot.go index 2f9cc48c46244..45171a97ae0f0 100644 --- a/internal/metastore/kv/rootcoord/suffix_snapshot.go +++ b/internal/metastore/kv/rootcoord/suffix_snapshot.go @@ -28,8 +28,6 @@ import ( "time" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus/pkg/util/tsoutil" - "go.uber.org/zap" "github.com/milvus-io/milvus/internal/kv" @@ -37,6 +35,7 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/retry" + "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -254,7 +253,7 @@ func binarySearchRecords(records []tsv, ts typeutil.Timestamp) (string, bool) { i, j := 0, len(records) for i+1 < j { k := (i + j) / 2 - //log.Warn("", zap.Int("i", i), zap.Int("j", j), zap.Int("k", k)) + // log.Warn("", zap.Int("i", i), zap.Int("j", j), 
zap.Int("k", k)) if records[k].ts == ts { return records[k].value, true } @@ -362,7 +361,7 @@ func (ss *SuffixSnapshot) Load(key string, ts typeutil.Timestamp) (string, error // 3. find i which ts[i] <= ts && ts[i+1] > ts // corner cases like len(records)==0, ts < records[0].ts is covered in binarySearch - //binary search + // binary search value, found := binarySearchRecords(records, ts) if !found { log.Warn("not found") @@ -435,8 +434,8 @@ func (ss *SuffixSnapshot) LoadWithPrefix(key string, ts typeutil.Timestamp) ([]s // ts 0 case shall be treated as fetch latest/current value if ts == 0 { keys, values, err := ss.MetaKv.LoadWithPrefix(key) - fks := keys[:0] //make([]string, 0, len(keys)) - fvs := values[:0] //make([]string, 0, len(values)) + fks := keys[:0] // make([]string, 0, len(keys)) + fvs := values[:0] // make([]string, 0, len(values)) // hide rootPrefix from return value for i, k := range keys { // filters tombstone @@ -494,7 +493,6 @@ func (ss *SuffixSnapshot) LoadWithPrefix(key string, ts typeutil.Timestamp) ([]s latestOriginalKey = curOriginalKey return nil }) - if err != nil { return nil, nil, err } @@ -656,7 +654,6 @@ func (ss *SuffixSnapshot) removeExpiredKvs(now time.Time) error { return nil }) - if err != nil { return err } diff --git a/internal/metastore/kv/rootcoord/suffix_snapshot_test.go b/internal/metastore/kv/rootcoord/suffix_snapshot_test.go index 0c3d4538b1b96..5efc00680def2 100644 --- a/internal/metastore/kv/rootcoord/suffix_snapshot_test.go +++ b/internal/metastore/kv/rootcoord/suffix_snapshot_test.go @@ -24,8 +24,6 @@ import ( "time" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus/pkg/util/tsoutil" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -34,12 +32,11 @@ import ( "github.com/milvus-io/milvus/internal/kv/mocks" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/paramtable" + 
"github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/milvus-io/milvus/pkg/util/typeutil" ) -var ( - snapshotPrefix = "snapshots" -) +var snapshotPrefix = "snapshots" var Params = paramtable.Get() @@ -272,7 +269,6 @@ func Test_SuffixSnaphotIsTSOfKey(t *testing.T) { assert.EqualValues(t, c.expected, ts) assert.Equal(t, c.shouldFound, found) } - } func Test_SuffixSnapshotLoad(t *testing.T) { diff --git a/internal/metastore/model/alias_test.go b/internal/metastore/model/alias_test.go index ab9d9061ba645..172f5d6dce3ea 100644 --- a/internal/metastore/model/alias_test.go +++ b/internal/metastore/model/alias_test.go @@ -3,8 +3,9 @@ package model import ( "testing" - "github.com/milvus-io/milvus/internal/proto/etcdpb" "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/internal/proto/etcdpb" ) func TestAlias_Available(t *testing.T) { diff --git a/internal/metastore/model/collection.go b/internal/metastore/model/collection.go index 5b03ad1976010..66acf68cf248c 100644 --- a/internal/metastore/model/collection.go +++ b/internal/metastore/model/collection.go @@ -1,11 +1,12 @@ package model import ( + "github.com/samber/lo" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" pb "github.com/milvus-io/milvus/internal/proto/etcdpb" "github.com/milvus-io/milvus/pkg/common" - "github.com/samber/lo" ) type Collection struct { diff --git a/internal/metastore/model/collection_test.go b/internal/metastore/model/collection_test.go index 75c627e092363..7ddde61e9f495 100644 --- a/internal/metastore/model/collection_test.go +++ b/internal/metastore/model/collection_test.go @@ -3,11 +3,12 @@ package model import ( "testing" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" pb "github.com/milvus-io/milvus/internal/proto/etcdpb" "github.com/milvus-io/milvus/pkg/common" - "github.com/stretchr/testify/assert" ) 
var ( diff --git a/internal/metastore/model/field.go b/internal/metastore/model/field.go index 029bc825b784f..10d44604d2406 100644 --- a/internal/metastore/model/field.go +++ b/internal/metastore/model/field.go @@ -1,10 +1,9 @@ package model import ( - "github.com/milvus-io/milvus/pkg/common" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/common" ) type Field struct { @@ -20,6 +19,7 @@ type Field struct { IsDynamic bool IsPartitionKey bool // partition key mode, multi logic partitions share a physical partition DefaultValue *schemapb.ValueField + ElementType schemapb.DataType } func (f *Field) Available() bool { @@ -40,6 +40,7 @@ func (f *Field) Clone() *Field { IsDynamic: f.IsDynamic, IsPartitionKey: f.IsPartitionKey, DefaultValue: f.DefaultValue, + ElementType: f.ElementType, } } @@ -67,7 +68,8 @@ func (f *Field) Equal(other Field) bool { f.AutoID == other.AutoID && f.IsPartitionKey == other.IsPartitionKey && f.IsDynamic == other.IsDynamic && - f.DefaultValue == other.DefaultValue + f.DefaultValue == other.DefaultValue && + f.ElementType == other.ElementType } func CheckFieldsEqual(fieldsA, fieldsB []*Field) bool { @@ -100,6 +102,7 @@ func MarshalFieldModel(field *Field) *schemapb.FieldSchema { IsDynamic: field.IsDynamic, IsPartitionKey: field.IsPartitionKey, DefaultValue: field.DefaultValue, + ElementType: field.ElementType, } } @@ -132,6 +135,7 @@ func UnmarshalFieldModel(fieldSchema *schemapb.FieldSchema) *Field { IsDynamic: fieldSchema.IsDynamic, IsPartitionKey: fieldSchema.IsPartitionKey, DefaultValue: fieldSchema.DefaultValue, + ElementType: fieldSchema.ElementType, } } diff --git a/internal/metastore/model/field_test.go b/internal/metastore/model/field_test.go index c519e9915a1c6..b72c71ec08079 100644 --- a/internal/metastore/model/field_test.go +++ b/internal/metastore/model/field_test.go @@ -3,10 +3,10 @@ package model import ( "testing" - 
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/stretchr/testify/assert" ) var ( diff --git a/internal/metastore/model/index.go b/internal/metastore/model/index.go index 052cbb0f91352..ffbecd13a5167 100644 --- a/internal/metastore/model/index.go +++ b/internal/metastore/model/index.go @@ -2,6 +2,7 @@ package model import ( "github.com/golang/protobuf/proto" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/proto/indexpb" ) diff --git a/internal/metastore/model/index_test.go b/internal/metastore/model/index_test.go index 11b5a026195a2..10cde6c136612 100644 --- a/internal/metastore/model/index_test.go +++ b/internal/metastore/model/index_test.go @@ -3,11 +3,10 @@ package model import ( "testing" - "github.com/milvus-io/milvus/internal/proto/indexpb" - "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus/internal/proto/indexpb" ) var ( diff --git a/internal/metastore/model/segment.go b/internal/metastore/model/segment.go index ca428a4787119..5c119ec2ad26b 100644 --- a/internal/metastore/model/segment.go +++ b/internal/metastore/model/segment.go @@ -15,7 +15,7 @@ type Segment struct { CompactionFrom []int64 CreatedByCompaction bool SegmentState commonpb.SegmentState - //IndexInfos []*SegmentIndex + // IndexInfos []*SegmentIndex ReplicaIds []int64 NodeIds []int64 } diff --git a/internal/metastore/model/segment_index.go b/internal/metastore/model/segment_index.go index fc82880a68445..3125b0106c333 100644 --- a/internal/metastore/model/segment_index.go +++ b/internal/metastore/model/segment_index.go @@ -22,7 +22,8 @@ type SegmentIndex struct { IndexFileKeys []string IndexSize uint64 // deprecated - WriteHandoff bool + WriteHandoff bool + CurrentIndexVersion int32 } func 
UnmarshalSegmentIndexModel(segIndex *indexpb.SegmentIndex) *SegmentIndex { @@ -31,21 +32,22 @@ func UnmarshalSegmentIndexModel(segIndex *indexpb.SegmentIndex) *SegmentIndex { } return &SegmentIndex{ - SegmentID: segIndex.SegmentID, - CollectionID: segIndex.CollectionID, - PartitionID: segIndex.PartitionID, - NumRows: segIndex.NumRows, - IndexID: segIndex.IndexID, - BuildID: segIndex.BuildID, - NodeID: segIndex.NodeID, - IndexState: segIndex.State, - FailReason: segIndex.FailReason, - IndexVersion: segIndex.IndexVersion, - IsDeleted: segIndex.Deleted, - CreateTime: segIndex.CreateTime, - IndexFileKeys: common.CloneStringList(segIndex.IndexFileKeys), - IndexSize: segIndex.SerializeSize, - WriteHandoff: segIndex.WriteHandoff, + SegmentID: segIndex.SegmentID, + CollectionID: segIndex.CollectionID, + PartitionID: segIndex.PartitionID, + NumRows: segIndex.NumRows, + IndexID: segIndex.IndexID, + BuildID: segIndex.BuildID, + NodeID: segIndex.NodeID, + IndexState: segIndex.State, + FailReason: segIndex.FailReason, + IndexVersion: segIndex.IndexVersion, + IsDeleted: segIndex.Deleted, + CreateTime: segIndex.CreateTime, + IndexFileKeys: common.CloneStringList(segIndex.IndexFileKeys), + IndexSize: segIndex.SerializeSize, + WriteHandoff: segIndex.WriteHandoff, + CurrentIndexVersion: segIndex.GetCurrentIndexVersion(), } } @@ -55,40 +57,42 @@ func MarshalSegmentIndexModel(segIdx *SegmentIndex) *indexpb.SegmentIndex { } return &indexpb.SegmentIndex{ - CollectionID: segIdx.CollectionID, - PartitionID: segIdx.PartitionID, - SegmentID: segIdx.SegmentID, - NumRows: segIdx.NumRows, - IndexID: segIdx.IndexID, - BuildID: segIdx.BuildID, - NodeID: segIdx.NodeID, - State: segIdx.IndexState, - FailReason: segIdx.FailReason, - IndexVersion: segIdx.IndexVersion, - IndexFileKeys: common.CloneStringList(segIdx.IndexFileKeys), - Deleted: segIdx.IsDeleted, - CreateTime: segIdx.CreateTime, - SerializeSize: segIdx.IndexSize, - WriteHandoff: segIdx.WriteHandoff, + CollectionID: segIdx.CollectionID, + 
PartitionID: segIdx.PartitionID, + SegmentID: segIdx.SegmentID, + NumRows: segIdx.NumRows, + IndexID: segIdx.IndexID, + BuildID: segIdx.BuildID, + NodeID: segIdx.NodeID, + State: segIdx.IndexState, + FailReason: segIdx.FailReason, + IndexVersion: segIdx.IndexVersion, + IndexFileKeys: common.CloneStringList(segIdx.IndexFileKeys), + Deleted: segIdx.IsDeleted, + CreateTime: segIdx.CreateTime, + SerializeSize: segIdx.IndexSize, + WriteHandoff: segIdx.WriteHandoff, + CurrentIndexVersion: segIdx.CurrentIndexVersion, } } func CloneSegmentIndex(segIndex *SegmentIndex) *SegmentIndex { return &SegmentIndex{ - SegmentID: segIndex.SegmentID, - CollectionID: segIndex.CollectionID, - PartitionID: segIndex.PartitionID, - NumRows: segIndex.NumRows, - IndexID: segIndex.IndexID, - BuildID: segIndex.BuildID, - NodeID: segIndex.NodeID, - IndexState: segIndex.IndexState, - FailReason: segIndex.FailReason, - IndexVersion: segIndex.IndexVersion, - IsDeleted: segIndex.IsDeleted, - CreateTime: segIndex.CreateTime, - IndexFileKeys: common.CloneStringList(segIndex.IndexFileKeys), - IndexSize: segIndex.IndexSize, - WriteHandoff: segIndex.WriteHandoff, + SegmentID: segIndex.SegmentID, + CollectionID: segIndex.CollectionID, + PartitionID: segIndex.PartitionID, + NumRows: segIndex.NumRows, + IndexID: segIndex.IndexID, + BuildID: segIndex.BuildID, + NodeID: segIndex.NodeID, + IndexState: segIndex.IndexState, + FailReason: segIndex.FailReason, + IndexVersion: segIndex.IndexVersion, + IsDeleted: segIndex.IsDeleted, + CreateTime: segIndex.CreateTime, + IndexFileKeys: common.CloneStringList(segIndex.IndexFileKeys), + IndexSize: segIndex.IndexSize, + WriteHandoff: segIndex.WriteHandoff, + CurrentIndexVersion: segIndex.CurrentIndexVersion, } } diff --git a/internal/metastore/model/segment_index_test.go b/internal/metastore/model/segment_index_test.go index 8a86cae6419e8..a3c056ec6e1b0 100644 --- a/internal/metastore/model/segment_index_test.go +++ b/internal/metastore/model/segment_index_test.go @@ 
-3,11 +3,10 @@ package model import ( "testing" - "github.com/milvus-io/milvus/internal/proto/indexpb" - "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus/internal/proto/indexpb" ) var ( diff --git a/internal/mocks/mock_datacoord.go b/internal/mocks/mock_datacoord.go index a31eb8845ad99..ebe92e4976bbd 100644 --- a/internal/mocks/mock_datacoord.go +++ b/internal/mocks/mock_datacoord.go @@ -18,6 +18,8 @@ import ( mock "github.com/stretchr/testify/mock" + txnkv "github.com/tikv/client-go/v2/txnkv" + types "github.com/milvus-io/milvus/internal/types" ) @@ -34,17 +36,17 @@ func (_m *MockDataCoord) EXPECT() *MockDataCoord_Expecter { return &MockDataCoord_Expecter{mock: &_m.Mock} } -// AssignSegmentID provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest) (*datapb.AssignSegmentIDResponse, error) { - ret := _m.Called(ctx, req) +// AssignSegmentID provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) AssignSegmentID(_a0 context.Context, _a1 *datapb.AssignSegmentIDRequest) (*datapb.AssignSegmentIDResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *datapb.AssignSegmentIDResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *datapb.AssignSegmentIDRequest) (*datapb.AssignSegmentIDResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *datapb.AssignSegmentIDRequest) *datapb.AssignSegmentIDResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*datapb.AssignSegmentIDResponse) @@ -52,7 +54,7 @@ func (_m *MockDataCoord) AssignSegmentID(ctx context.Context, req *datapb.Assign } if rf, ok := ret.Get(1).(func(context.Context, *datapb.AssignSegmentIDRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -66,13 +68,13 @@ type 
MockDataCoord_AssignSegmentID_Call struct { } // AssignSegmentID is a helper method to define mock.On call -// - ctx context.Context -// - req *datapb.AssignSegmentIDRequest -func (_e *MockDataCoord_Expecter) AssignSegmentID(ctx interface{}, req interface{}) *MockDataCoord_AssignSegmentID_Call { - return &MockDataCoord_AssignSegmentID_Call{Call: _e.mock.On("AssignSegmentID", ctx, req)} +// - _a0 context.Context +// - _a1 *datapb.AssignSegmentIDRequest +func (_e *MockDataCoord_Expecter) AssignSegmentID(_a0 interface{}, _a1 interface{}) *MockDataCoord_AssignSegmentID_Call { + return &MockDataCoord_AssignSegmentID_Call{Call: _e.mock.On("AssignSegmentID", _a0, _a1)} } -func (_c *MockDataCoord_AssignSegmentID_Call) Run(run func(ctx context.Context, req *datapb.AssignSegmentIDRequest)) *MockDataCoord_AssignSegmentID_Call { +func (_c *MockDataCoord_AssignSegmentID_Call) Run(run func(_a0 context.Context, _a1 *datapb.AssignSegmentIDRequest)) *MockDataCoord_AssignSegmentID_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*datapb.AssignSegmentIDRequest)) }) @@ -89,17 +91,17 @@ func (_c *MockDataCoord_AssignSegmentID_Call) RunAndReturn(run func(context.Cont return _c } -// BroadcastAlteredCollection provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) BroadcastAlteredCollection(ctx context.Context, req *datapb.AlterCollectionRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// BroadcastAlteredCollection provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) BroadcastAlteredCollection(_a0 context.Context, _a1 *datapb.AlterCollectionRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *datapb.AlterCollectionRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *datapb.AlterCollectionRequest) 
*commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -107,7 +109,7 @@ func (_m *MockDataCoord) BroadcastAlteredCollection(ctx context.Context, req *da } if rf, ok := ret.Get(1).(func(context.Context, *datapb.AlterCollectionRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -121,13 +123,13 @@ type MockDataCoord_BroadcastAlteredCollection_Call struct { } // BroadcastAlteredCollection is a helper method to define mock.On call -// - ctx context.Context -// - req *datapb.AlterCollectionRequest -func (_e *MockDataCoord_Expecter) BroadcastAlteredCollection(ctx interface{}, req interface{}) *MockDataCoord_BroadcastAlteredCollection_Call { - return &MockDataCoord_BroadcastAlteredCollection_Call{Call: _e.mock.On("BroadcastAlteredCollection", ctx, req)} +// - _a0 context.Context +// - _a1 *datapb.AlterCollectionRequest +func (_e *MockDataCoord_Expecter) BroadcastAlteredCollection(_a0 interface{}, _a1 interface{}) *MockDataCoord_BroadcastAlteredCollection_Call { + return &MockDataCoord_BroadcastAlteredCollection_Call{Call: _e.mock.On("BroadcastAlteredCollection", _a0, _a1)} } -func (_c *MockDataCoord_BroadcastAlteredCollection_Call) Run(run func(ctx context.Context, req *datapb.AlterCollectionRequest)) *MockDataCoord_BroadcastAlteredCollection_Call { +func (_c *MockDataCoord_BroadcastAlteredCollection_Call) Run(run func(_a0 context.Context, _a1 *datapb.AlterCollectionRequest)) *MockDataCoord_BroadcastAlteredCollection_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*datapb.AlterCollectionRequest)) }) @@ -144,17 +146,17 @@ func (_c *MockDataCoord_BroadcastAlteredCollection_Call) RunAndReturn(run func(c return _c } -// CheckHealth provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { - 
ret := _m.Called(ctx, req) +// CheckHealth provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) CheckHealth(_a0 context.Context, _a1 *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.CheckHealthResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CheckHealthRequest) *milvuspb.CheckHealthResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.CheckHealthResponse) @@ -162,7 +164,7 @@ func (_m *MockDataCoord) CheckHealth(ctx context.Context, req *milvuspb.CheckHea } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CheckHealthRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -176,13 +178,13 @@ type MockDataCoord_CheckHealth_Call struct { } // CheckHealth is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.CheckHealthRequest -func (_e *MockDataCoord_Expecter) CheckHealth(ctx interface{}, req interface{}) *MockDataCoord_CheckHealth_Call { - return &MockDataCoord_CheckHealth_Call{Call: _e.mock.On("CheckHealth", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.CheckHealthRequest +func (_e *MockDataCoord_Expecter) CheckHealth(_a0 interface{}, _a1 interface{}) *MockDataCoord_CheckHealth_Call { + return &MockDataCoord_CheckHealth_Call{Call: _e.mock.On("CheckHealth", _a0, _a1)} } -func (_c *MockDataCoord_CheckHealth_Call) Run(run func(ctx context.Context, req *milvuspb.CheckHealthRequest)) *MockDataCoord_CheckHealth_Call { +func (_c *MockDataCoord_CheckHealth_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.CheckHealthRequest)) *MockDataCoord_CheckHealth_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), 
args[1].(*milvuspb.CheckHealthRequest)) }) @@ -199,17 +201,17 @@ func (_c *MockDataCoord_CheckHealth_Call) RunAndReturn(run func(context.Context, return _c } -// CreateIndex provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) CreateIndex(ctx context.Context, req *indexpb.CreateIndexRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// CreateIndex provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) CreateIndex(_a0 context.Context, _a1 *indexpb.CreateIndexRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *indexpb.CreateIndexRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *indexpb.CreateIndexRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -217,7 +219,7 @@ func (_m *MockDataCoord) CreateIndex(ctx context.Context, req *indexpb.CreateInd } if rf, ok := ret.Get(1).(func(context.Context, *indexpb.CreateIndexRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -231,13 +233,13 @@ type MockDataCoord_CreateIndex_Call struct { } // CreateIndex is a helper method to define mock.On call -// - ctx context.Context -// - req *indexpb.CreateIndexRequest -func (_e *MockDataCoord_Expecter) CreateIndex(ctx interface{}, req interface{}) *MockDataCoord_CreateIndex_Call { - return &MockDataCoord_CreateIndex_Call{Call: _e.mock.On("CreateIndex", ctx, req)} +// - _a0 context.Context +// - _a1 *indexpb.CreateIndexRequest +func (_e *MockDataCoord_Expecter) CreateIndex(_a0 interface{}, _a1 interface{}) *MockDataCoord_CreateIndex_Call { + return &MockDataCoord_CreateIndex_Call{Call: _e.mock.On("CreateIndex", _a0, _a1)} } -func (_c *MockDataCoord_CreateIndex_Call) Run(run func(ctx context.Context, req 
*indexpb.CreateIndexRequest)) *MockDataCoord_CreateIndex_Call { +func (_c *MockDataCoord_CreateIndex_Call) Run(run func(_a0 context.Context, _a1 *indexpb.CreateIndexRequest)) *MockDataCoord_CreateIndex_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*indexpb.CreateIndexRequest)) }) @@ -254,17 +256,17 @@ func (_c *MockDataCoord_CreateIndex_Call) RunAndReturn(run func(context.Context, return _c } -// DescribeIndex provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) DescribeIndex(ctx context.Context, req *indexpb.DescribeIndexRequest) (*indexpb.DescribeIndexResponse, error) { - ret := _m.Called(ctx, req) +// DescribeIndex provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) DescribeIndex(_a0 context.Context, _a1 *indexpb.DescribeIndexRequest) (*indexpb.DescribeIndexResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *indexpb.DescribeIndexResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *indexpb.DescribeIndexRequest) (*indexpb.DescribeIndexResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *indexpb.DescribeIndexRequest) *indexpb.DescribeIndexResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*indexpb.DescribeIndexResponse) @@ -272,7 +274,7 @@ func (_m *MockDataCoord) DescribeIndex(ctx context.Context, req *indexpb.Describ } if rf, ok := ret.Get(1).(func(context.Context, *indexpb.DescribeIndexRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -286,13 +288,13 @@ type MockDataCoord_DescribeIndex_Call struct { } // DescribeIndex is a helper method to define mock.On call -// - ctx context.Context -// - req *indexpb.DescribeIndexRequest -func (_e *MockDataCoord_Expecter) DescribeIndex(ctx interface{}, req interface{}) *MockDataCoord_DescribeIndex_Call { - return &MockDataCoord_DescribeIndex_Call{Call: 
_e.mock.On("DescribeIndex", ctx, req)} +// - _a0 context.Context +// - _a1 *indexpb.DescribeIndexRequest +func (_e *MockDataCoord_Expecter) DescribeIndex(_a0 interface{}, _a1 interface{}) *MockDataCoord_DescribeIndex_Call { + return &MockDataCoord_DescribeIndex_Call{Call: _e.mock.On("DescribeIndex", _a0, _a1)} } -func (_c *MockDataCoord_DescribeIndex_Call) Run(run func(ctx context.Context, req *indexpb.DescribeIndexRequest)) *MockDataCoord_DescribeIndex_Call { +func (_c *MockDataCoord_DescribeIndex_Call) Run(run func(_a0 context.Context, _a1 *indexpb.DescribeIndexRequest)) *MockDataCoord_DescribeIndex_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*indexpb.DescribeIndexRequest)) }) @@ -309,17 +311,17 @@ func (_c *MockDataCoord_DescribeIndex_Call) RunAndReturn(run func(context.Contex return _c } -// DropIndex provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) DropIndex(ctx context.Context, req *indexpb.DropIndexRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// DropIndex provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) DropIndex(_a0 context.Context, _a1 *indexpb.DropIndexRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *indexpb.DropIndexRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *indexpb.DropIndexRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -327,7 +329,7 @@ func (_m *MockDataCoord) DropIndex(ctx context.Context, req *indexpb.DropIndexRe } if rf, ok := ret.Get(1).(func(context.Context, *indexpb.DropIndexRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -341,13 +343,13 @@ type MockDataCoord_DropIndex_Call struct { } // 
DropIndex is a helper method to define mock.On call -// - ctx context.Context -// - req *indexpb.DropIndexRequest -func (_e *MockDataCoord_Expecter) DropIndex(ctx interface{}, req interface{}) *MockDataCoord_DropIndex_Call { - return &MockDataCoord_DropIndex_Call{Call: _e.mock.On("DropIndex", ctx, req)} +// - _a0 context.Context +// - _a1 *indexpb.DropIndexRequest +func (_e *MockDataCoord_Expecter) DropIndex(_a0 interface{}, _a1 interface{}) *MockDataCoord_DropIndex_Call { + return &MockDataCoord_DropIndex_Call{Call: _e.mock.On("DropIndex", _a0, _a1)} } -func (_c *MockDataCoord_DropIndex_Call) Run(run func(ctx context.Context, req *indexpb.DropIndexRequest)) *MockDataCoord_DropIndex_Call { +func (_c *MockDataCoord_DropIndex_Call) Run(run func(_a0 context.Context, _a1 *indexpb.DropIndexRequest)) *MockDataCoord_DropIndex_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*indexpb.DropIndexRequest)) }) @@ -364,17 +366,17 @@ func (_c *MockDataCoord_DropIndex_Call) RunAndReturn(run func(context.Context, * return _c } -// DropVirtualChannel provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) DropVirtualChannel(ctx context.Context, req *datapb.DropVirtualChannelRequest) (*datapb.DropVirtualChannelResponse, error) { - ret := _m.Called(ctx, req) +// DropVirtualChannel provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) DropVirtualChannel(_a0 context.Context, _a1 *datapb.DropVirtualChannelRequest) (*datapb.DropVirtualChannelResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *datapb.DropVirtualChannelResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *datapb.DropVirtualChannelRequest) (*datapb.DropVirtualChannelResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *datapb.DropVirtualChannelRequest) *datapb.DropVirtualChannelResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != 
nil { r0 = ret.Get(0).(*datapb.DropVirtualChannelResponse) @@ -382,7 +384,7 @@ func (_m *MockDataCoord) DropVirtualChannel(ctx context.Context, req *datapb.Dro } if rf, ok := ret.Get(1).(func(context.Context, *datapb.DropVirtualChannelRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -396,13 +398,13 @@ type MockDataCoord_DropVirtualChannel_Call struct { } // DropVirtualChannel is a helper method to define mock.On call -// - ctx context.Context -// - req *datapb.DropVirtualChannelRequest -func (_e *MockDataCoord_Expecter) DropVirtualChannel(ctx interface{}, req interface{}) *MockDataCoord_DropVirtualChannel_Call { - return &MockDataCoord_DropVirtualChannel_Call{Call: _e.mock.On("DropVirtualChannel", ctx, req)} +// - _a0 context.Context +// - _a1 *datapb.DropVirtualChannelRequest +func (_e *MockDataCoord_Expecter) DropVirtualChannel(_a0 interface{}, _a1 interface{}) *MockDataCoord_DropVirtualChannel_Call { + return &MockDataCoord_DropVirtualChannel_Call{Call: _e.mock.On("DropVirtualChannel", _a0, _a1)} } -func (_c *MockDataCoord_DropVirtualChannel_Call) Run(run func(ctx context.Context, req *datapb.DropVirtualChannelRequest)) *MockDataCoord_DropVirtualChannel_Call { +func (_c *MockDataCoord_DropVirtualChannel_Call) Run(run func(_a0 context.Context, _a1 *datapb.DropVirtualChannelRequest)) *MockDataCoord_DropVirtualChannel_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*datapb.DropVirtualChannelRequest)) }) @@ -419,17 +421,17 @@ func (_c *MockDataCoord_DropVirtualChannel_Call) RunAndReturn(run func(context.C return _c } -// Flush provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) Flush(ctx context.Context, req *datapb.FlushRequest) (*datapb.FlushResponse, error) { - ret := _m.Called(ctx, req) +// Flush provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) Flush(_a0 context.Context, _a1 *datapb.FlushRequest) (*datapb.FlushResponse, 
error) { + ret := _m.Called(_a0, _a1) var r0 *datapb.FlushResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *datapb.FlushRequest) (*datapb.FlushResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *datapb.FlushRequest) *datapb.FlushResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*datapb.FlushResponse) @@ -437,7 +439,7 @@ func (_m *MockDataCoord) Flush(ctx context.Context, req *datapb.FlushRequest) (* } if rf, ok := ret.Get(1).(func(context.Context, *datapb.FlushRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -451,13 +453,13 @@ type MockDataCoord_Flush_Call struct { } // Flush is a helper method to define mock.On call -// - ctx context.Context -// - req *datapb.FlushRequest -func (_e *MockDataCoord_Expecter) Flush(ctx interface{}, req interface{}) *MockDataCoord_Flush_Call { - return &MockDataCoord_Flush_Call{Call: _e.mock.On("Flush", ctx, req)} +// - _a0 context.Context +// - _a1 *datapb.FlushRequest +func (_e *MockDataCoord_Expecter) Flush(_a0 interface{}, _a1 interface{}) *MockDataCoord_Flush_Call { + return &MockDataCoord_Flush_Call{Call: _e.mock.On("Flush", _a0, _a1)} } -func (_c *MockDataCoord_Flush_Call) Run(run func(ctx context.Context, req *datapb.FlushRequest)) *MockDataCoord_Flush_Call { +func (_c *MockDataCoord_Flush_Call) Run(run func(_a0 context.Context, _a1 *datapb.FlushRequest)) *MockDataCoord_Flush_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*datapb.FlushRequest)) }) @@ -474,17 +476,17 @@ func (_c *MockDataCoord_Flush_Call) RunAndReturn(run func(context.Context, *data return _c } -// GcConfirm provides a mock function with given fields: ctx, request -func (_m *MockDataCoord) GcConfirm(ctx context.Context, request *datapb.GcConfirmRequest) (*datapb.GcConfirmResponse, error) { - ret := _m.Called(ctx, request) +// GcConfirm 
provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) GcConfirm(_a0 context.Context, _a1 *datapb.GcConfirmRequest) (*datapb.GcConfirmResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *datapb.GcConfirmResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *datapb.GcConfirmRequest) (*datapb.GcConfirmResponse, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *datapb.GcConfirmRequest) *datapb.GcConfirmResponse); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*datapb.GcConfirmResponse) @@ -492,7 +494,7 @@ func (_m *MockDataCoord) GcConfirm(ctx context.Context, request *datapb.GcConfir } if rf, ok := ret.Get(1).(func(context.Context, *datapb.GcConfirmRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -506,13 +508,13 @@ type MockDataCoord_GcConfirm_Call struct { } // GcConfirm is a helper method to define mock.On call -// - ctx context.Context -// - request *datapb.GcConfirmRequest -func (_e *MockDataCoord_Expecter) GcConfirm(ctx interface{}, request interface{}) *MockDataCoord_GcConfirm_Call { - return &MockDataCoord_GcConfirm_Call{Call: _e.mock.On("GcConfirm", ctx, request)} +// - _a0 context.Context +// - _a1 *datapb.GcConfirmRequest +func (_e *MockDataCoord_Expecter) GcConfirm(_a0 interface{}, _a1 interface{}) *MockDataCoord_GcConfirm_Call { + return &MockDataCoord_GcConfirm_Call{Call: _e.mock.On("GcConfirm", _a0, _a1)} } -func (_c *MockDataCoord_GcConfirm_Call) Run(run func(ctx context.Context, request *datapb.GcConfirmRequest)) *MockDataCoord_GcConfirm_Call { +func (_c *MockDataCoord_GcConfirm_Call) Run(run func(_a0 context.Context, _a1 *datapb.GcConfirmRequest)) *MockDataCoord_GcConfirm_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*datapb.GcConfirmRequest)) }) @@ -529,17 +531,17 @@ func (_c *MockDataCoord_GcConfirm_Call) 
RunAndReturn(run func(context.Context, * return _c } -// GetCollectionStatistics provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) GetCollectionStatistics(ctx context.Context, req *datapb.GetCollectionStatisticsRequest) (*datapb.GetCollectionStatisticsResponse, error) { - ret := _m.Called(ctx, req) +// GetCollectionStatistics provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) GetCollectionStatistics(_a0 context.Context, _a1 *datapb.GetCollectionStatisticsRequest) (*datapb.GetCollectionStatisticsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *datapb.GetCollectionStatisticsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetCollectionStatisticsRequest) (*datapb.GetCollectionStatisticsResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetCollectionStatisticsRequest) *datapb.GetCollectionStatisticsResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*datapb.GetCollectionStatisticsResponse) @@ -547,7 +549,7 @@ func (_m *MockDataCoord) GetCollectionStatistics(ctx context.Context, req *datap } if rf, ok := ret.Get(1).(func(context.Context, *datapb.GetCollectionStatisticsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -561,13 +563,13 @@ type MockDataCoord_GetCollectionStatistics_Call struct { } // GetCollectionStatistics is a helper method to define mock.On call -// - ctx context.Context -// - req *datapb.GetCollectionStatisticsRequest -func (_e *MockDataCoord_Expecter) GetCollectionStatistics(ctx interface{}, req interface{}) *MockDataCoord_GetCollectionStatistics_Call { - return &MockDataCoord_GetCollectionStatistics_Call{Call: _e.mock.On("GetCollectionStatistics", ctx, req)} +// - _a0 context.Context +// - _a1 *datapb.GetCollectionStatisticsRequest +func (_e *MockDataCoord_Expecter) GetCollectionStatistics(_a0 
interface{}, _a1 interface{}) *MockDataCoord_GetCollectionStatistics_Call { + return &MockDataCoord_GetCollectionStatistics_Call{Call: _e.mock.On("GetCollectionStatistics", _a0, _a1)} } -func (_c *MockDataCoord_GetCollectionStatistics_Call) Run(run func(ctx context.Context, req *datapb.GetCollectionStatisticsRequest)) *MockDataCoord_GetCollectionStatistics_Call { +func (_c *MockDataCoord_GetCollectionStatistics_Call) Run(run func(_a0 context.Context, _a1 *datapb.GetCollectionStatisticsRequest)) *MockDataCoord_GetCollectionStatistics_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*datapb.GetCollectionStatisticsRequest)) }) @@ -584,17 +586,17 @@ func (_c *MockDataCoord_GetCollectionStatistics_Call) RunAndReturn(run func(cont return _c } -// GetCompactionState provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) GetCompactionState(ctx context.Context, req *milvuspb.GetCompactionStateRequest) (*milvuspb.GetCompactionStateResponse, error) { - ret := _m.Called(ctx, req) +// GetCompactionState provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) GetCompactionState(_a0 context.Context, _a1 *milvuspb.GetCompactionStateRequest) (*milvuspb.GetCompactionStateResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.GetCompactionStateResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetCompactionStateRequest) (*milvuspb.GetCompactionStateResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetCompactionStateRequest) *milvuspb.GetCompactionStateResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.GetCompactionStateResponse) @@ -602,7 +604,7 @@ func (_m *MockDataCoord) GetCompactionState(ctx context.Context, req *milvuspb.G } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetCompactionStateRequest) error); ok { 
- r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -616,13 +618,13 @@ type MockDataCoord_GetCompactionState_Call struct { } // GetCompactionState is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.GetCompactionStateRequest -func (_e *MockDataCoord_Expecter) GetCompactionState(ctx interface{}, req interface{}) *MockDataCoord_GetCompactionState_Call { - return &MockDataCoord_GetCompactionState_Call{Call: _e.mock.On("GetCompactionState", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.GetCompactionStateRequest +func (_e *MockDataCoord_Expecter) GetCompactionState(_a0 interface{}, _a1 interface{}) *MockDataCoord_GetCompactionState_Call { + return &MockDataCoord_GetCompactionState_Call{Call: _e.mock.On("GetCompactionState", _a0, _a1)} } -func (_c *MockDataCoord_GetCompactionState_Call) Run(run func(ctx context.Context, req *milvuspb.GetCompactionStateRequest)) *MockDataCoord_GetCompactionState_Call { +func (_c *MockDataCoord_GetCompactionState_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetCompactionStateRequest)) *MockDataCoord_GetCompactionState_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.GetCompactionStateRequest)) }) @@ -639,17 +641,17 @@ func (_c *MockDataCoord_GetCompactionState_Call) RunAndReturn(run func(context.C return _c } -// GetCompactionStateWithPlans provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) GetCompactionStateWithPlans(ctx context.Context, req *milvuspb.GetCompactionPlansRequest) (*milvuspb.GetCompactionPlansResponse, error) { - ret := _m.Called(ctx, req) +// GetCompactionStateWithPlans provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) GetCompactionStateWithPlans(_a0 context.Context, _a1 *milvuspb.GetCompactionPlansRequest) (*milvuspb.GetCompactionPlansResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.GetCompactionPlansResponse var r1 error if rf, 
ok := ret.Get(0).(func(context.Context, *milvuspb.GetCompactionPlansRequest) (*milvuspb.GetCompactionPlansResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetCompactionPlansRequest) *milvuspb.GetCompactionPlansResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.GetCompactionPlansResponse) @@ -657,7 +659,7 @@ func (_m *MockDataCoord) GetCompactionStateWithPlans(ctx context.Context, req *m } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetCompactionPlansRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -671,13 +673,13 @@ type MockDataCoord_GetCompactionStateWithPlans_Call struct { } // GetCompactionStateWithPlans is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.GetCompactionPlansRequest -func (_e *MockDataCoord_Expecter) GetCompactionStateWithPlans(ctx interface{}, req interface{}) *MockDataCoord_GetCompactionStateWithPlans_Call { - return &MockDataCoord_GetCompactionStateWithPlans_Call{Call: _e.mock.On("GetCompactionStateWithPlans", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.GetCompactionPlansRequest +func (_e *MockDataCoord_Expecter) GetCompactionStateWithPlans(_a0 interface{}, _a1 interface{}) *MockDataCoord_GetCompactionStateWithPlans_Call { + return &MockDataCoord_GetCompactionStateWithPlans_Call{Call: _e.mock.On("GetCompactionStateWithPlans", _a0, _a1)} } -func (_c *MockDataCoord_GetCompactionStateWithPlans_Call) Run(run func(ctx context.Context, req *milvuspb.GetCompactionPlansRequest)) *MockDataCoord_GetCompactionStateWithPlans_Call { +func (_c *MockDataCoord_GetCompactionStateWithPlans_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetCompactionPlansRequest)) *MockDataCoord_GetCompactionStateWithPlans_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), 
args[1].(*milvuspb.GetCompactionPlansRequest)) }) @@ -694,25 +696,25 @@ func (_c *MockDataCoord_GetCompactionStateWithPlans_Call) RunAndReturn(run func( return _c } -// GetComponentStates provides a mock function with given fields: ctx -func (_m *MockDataCoord) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { - ret := _m.Called(ctx) +// GetComponentStates provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) GetComponentStates(_a0 context.Context, _a1 *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.ComponentStates var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (*milvuspb.ComponentStates, error)); ok { - return rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error)); ok { + return rf(_a0, _a1) } - if rf, ok := ret.Get(0).(func(context.Context) *milvuspb.ComponentStates); ok { - r0 = rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetComponentStatesRequest) *milvuspb.ComponentStates); ok { + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.ComponentStates) } } - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetComponentStatesRequest) error); ok { + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -726,14 +728,15 @@ type MockDataCoord_GetComponentStates_Call struct { } // GetComponentStates is a helper method to define mock.On call -// - ctx context.Context -func (_e *MockDataCoord_Expecter) GetComponentStates(ctx interface{}) *MockDataCoord_GetComponentStates_Call { - return &MockDataCoord_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", ctx)} +// - _a0 context.Context +// - _a1 *milvuspb.GetComponentStatesRequest +func (_e *MockDataCoord_Expecter) GetComponentStates(_a0 interface{}, _a1 interface{}) 
*MockDataCoord_GetComponentStates_Call { + return &MockDataCoord_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", _a0, _a1)} } -func (_c *MockDataCoord_GetComponentStates_Call) Run(run func(ctx context.Context)) *MockDataCoord_GetComponentStates_Call { +func (_c *MockDataCoord_GetComponentStates_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetComponentStatesRequest)) *MockDataCoord_GetComponentStates_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context)) + run(args[0].(context.Context), args[1].(*milvuspb.GetComponentStatesRequest)) }) return _c } @@ -743,22 +746,22 @@ func (_c *MockDataCoord_GetComponentStates_Call) Return(_a0 *milvuspb.ComponentS return _c } -func (_c *MockDataCoord_GetComponentStates_Call) RunAndReturn(run func(context.Context) (*milvuspb.ComponentStates, error)) *MockDataCoord_GetComponentStates_Call { +func (_c *MockDataCoord_GetComponentStates_Call) RunAndReturn(run func(context.Context, *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error)) *MockDataCoord_GetComponentStates_Call { _c.Call.Return(run) return _c } -// GetFlushAllState provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) GetFlushAllState(ctx context.Context, req *milvuspb.GetFlushAllStateRequest) (*milvuspb.GetFlushAllStateResponse, error) { - ret := _m.Called(ctx, req) +// GetFlushAllState provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) GetFlushAllState(_a0 context.Context, _a1 *milvuspb.GetFlushAllStateRequest) (*milvuspb.GetFlushAllStateResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.GetFlushAllStateResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetFlushAllStateRequest) (*milvuspb.GetFlushAllStateResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetFlushAllStateRequest) *milvuspb.GetFlushAllStateResponse); ok { - r0 = rf(ctx, req) 
+ r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.GetFlushAllStateResponse) @@ -766,7 +769,7 @@ func (_m *MockDataCoord) GetFlushAllState(ctx context.Context, req *milvuspb.Get } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetFlushAllStateRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -780,13 +783,13 @@ type MockDataCoord_GetFlushAllState_Call struct { } // GetFlushAllState is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.GetFlushAllStateRequest -func (_e *MockDataCoord_Expecter) GetFlushAllState(ctx interface{}, req interface{}) *MockDataCoord_GetFlushAllState_Call { - return &MockDataCoord_GetFlushAllState_Call{Call: _e.mock.On("GetFlushAllState", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.GetFlushAllStateRequest +func (_e *MockDataCoord_Expecter) GetFlushAllState(_a0 interface{}, _a1 interface{}) *MockDataCoord_GetFlushAllState_Call { + return &MockDataCoord_GetFlushAllState_Call{Call: _e.mock.On("GetFlushAllState", _a0, _a1)} } -func (_c *MockDataCoord_GetFlushAllState_Call) Run(run func(ctx context.Context, req *milvuspb.GetFlushAllStateRequest)) *MockDataCoord_GetFlushAllState_Call { +func (_c *MockDataCoord_GetFlushAllState_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetFlushAllStateRequest)) *MockDataCoord_GetFlushAllState_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.GetFlushAllStateRequest)) }) @@ -803,25 +806,25 @@ func (_c *MockDataCoord_GetFlushAllState_Call) RunAndReturn(run func(context.Con return _c } -// GetFlushState provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) GetFlushState(ctx context.Context, req *milvuspb.GetFlushStateRequest) (*milvuspb.GetFlushStateResponse, error) { - ret := _m.Called(ctx, req) +// GetFlushState provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) GetFlushState(_a0 
context.Context, _a1 *datapb.GetFlushStateRequest) (*milvuspb.GetFlushStateResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.GetFlushStateResponse var r1 error - if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetFlushStateRequest) (*milvuspb.GetFlushStateResponse, error)); ok { - return rf(ctx, req) + if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetFlushStateRequest) (*milvuspb.GetFlushStateResponse, error)); ok { + return rf(_a0, _a1) } - if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetFlushStateRequest) *milvuspb.GetFlushStateResponse); ok { - r0 = rf(ctx, req) + if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetFlushStateRequest) *milvuspb.GetFlushStateResponse); ok { + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.GetFlushStateResponse) } } - if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetFlushStateRequest) error); ok { - r1 = rf(ctx, req) + if rf, ok := ret.Get(1).(func(context.Context, *datapb.GetFlushStateRequest) error); ok { + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -835,15 +838,15 @@ type MockDataCoord_GetFlushState_Call struct { } // GetFlushState is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.GetFlushStateRequest -func (_e *MockDataCoord_Expecter) GetFlushState(ctx interface{}, req interface{}) *MockDataCoord_GetFlushState_Call { - return &MockDataCoord_GetFlushState_Call{Call: _e.mock.On("GetFlushState", ctx, req)} +// - _a0 context.Context +// - _a1 *datapb.GetFlushStateRequest +func (_e *MockDataCoord_Expecter) GetFlushState(_a0 interface{}, _a1 interface{}) *MockDataCoord_GetFlushState_Call { + return &MockDataCoord_GetFlushState_Call{Call: _e.mock.On("GetFlushState", _a0, _a1)} } -func (_c *MockDataCoord_GetFlushState_Call) Run(run func(ctx context.Context, req *milvuspb.GetFlushStateRequest)) *MockDataCoord_GetFlushState_Call { +func (_c *MockDataCoord_GetFlushState_Call) Run(run func(_a0 
context.Context, _a1 *datapb.GetFlushStateRequest)) *MockDataCoord_GetFlushState_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(*milvuspb.GetFlushStateRequest)) + run(args[0].(context.Context), args[1].(*datapb.GetFlushStateRequest)) }) return _c } @@ -853,22 +856,22 @@ func (_c *MockDataCoord_GetFlushState_Call) Return(_a0 *milvuspb.GetFlushStateRe return _c } -func (_c *MockDataCoord_GetFlushState_Call) RunAndReturn(run func(context.Context, *milvuspb.GetFlushStateRequest) (*milvuspb.GetFlushStateResponse, error)) *MockDataCoord_GetFlushState_Call { +func (_c *MockDataCoord_GetFlushState_Call) RunAndReturn(run func(context.Context, *datapb.GetFlushStateRequest) (*milvuspb.GetFlushStateResponse, error)) *MockDataCoord_GetFlushState_Call { _c.Call.Return(run) return _c } -// GetFlushedSegments provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) GetFlushedSegments(ctx context.Context, req *datapb.GetFlushedSegmentsRequest) (*datapb.GetFlushedSegmentsResponse, error) { - ret := _m.Called(ctx, req) +// GetFlushedSegments provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) GetFlushedSegments(_a0 context.Context, _a1 *datapb.GetFlushedSegmentsRequest) (*datapb.GetFlushedSegmentsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *datapb.GetFlushedSegmentsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetFlushedSegmentsRequest) (*datapb.GetFlushedSegmentsResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetFlushedSegmentsRequest) *datapb.GetFlushedSegmentsResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*datapb.GetFlushedSegmentsResponse) @@ -876,7 +879,7 @@ func (_m *MockDataCoord) GetFlushedSegments(ctx context.Context, req *datapb.Get } if rf, ok := ret.Get(1).(func(context.Context, 
*datapb.GetFlushedSegmentsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -890,13 +893,13 @@ type MockDataCoord_GetFlushedSegments_Call struct { } // GetFlushedSegments is a helper method to define mock.On call -// - ctx context.Context -// - req *datapb.GetFlushedSegmentsRequest -func (_e *MockDataCoord_Expecter) GetFlushedSegments(ctx interface{}, req interface{}) *MockDataCoord_GetFlushedSegments_Call { - return &MockDataCoord_GetFlushedSegments_Call{Call: _e.mock.On("GetFlushedSegments", ctx, req)} +// - _a0 context.Context +// - _a1 *datapb.GetFlushedSegmentsRequest +func (_e *MockDataCoord_Expecter) GetFlushedSegments(_a0 interface{}, _a1 interface{}) *MockDataCoord_GetFlushedSegments_Call { + return &MockDataCoord_GetFlushedSegments_Call{Call: _e.mock.On("GetFlushedSegments", _a0, _a1)} } -func (_c *MockDataCoord_GetFlushedSegments_Call) Run(run func(ctx context.Context, req *datapb.GetFlushedSegmentsRequest)) *MockDataCoord_GetFlushedSegments_Call { +func (_c *MockDataCoord_GetFlushedSegments_Call) Run(run func(_a0 context.Context, _a1 *datapb.GetFlushedSegmentsRequest)) *MockDataCoord_GetFlushedSegments_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*datapb.GetFlushedSegmentsRequest)) }) @@ -913,17 +916,17 @@ func (_c *MockDataCoord_GetFlushedSegments_Call) RunAndReturn(run func(context.C return _c } -// GetIndexBuildProgress provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) GetIndexBuildProgress(ctx context.Context, req *indexpb.GetIndexBuildProgressRequest) (*indexpb.GetIndexBuildProgressResponse, error) { - ret := _m.Called(ctx, req) +// GetIndexBuildProgress provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) GetIndexBuildProgress(_a0 context.Context, _a1 *indexpb.GetIndexBuildProgressRequest) (*indexpb.GetIndexBuildProgressResponse, error) { + ret := _m.Called(_a0, _a1) var r0 
*indexpb.GetIndexBuildProgressResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *indexpb.GetIndexBuildProgressRequest) (*indexpb.GetIndexBuildProgressResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *indexpb.GetIndexBuildProgressRequest) *indexpb.GetIndexBuildProgressResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*indexpb.GetIndexBuildProgressResponse) @@ -931,7 +934,7 @@ func (_m *MockDataCoord) GetIndexBuildProgress(ctx context.Context, req *indexpb } if rf, ok := ret.Get(1).(func(context.Context, *indexpb.GetIndexBuildProgressRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -945,13 +948,13 @@ type MockDataCoord_GetIndexBuildProgress_Call struct { } // GetIndexBuildProgress is a helper method to define mock.On call -// - ctx context.Context -// - req *indexpb.GetIndexBuildProgressRequest -func (_e *MockDataCoord_Expecter) GetIndexBuildProgress(ctx interface{}, req interface{}) *MockDataCoord_GetIndexBuildProgress_Call { - return &MockDataCoord_GetIndexBuildProgress_Call{Call: _e.mock.On("GetIndexBuildProgress", ctx, req)} +// - _a0 context.Context +// - _a1 *indexpb.GetIndexBuildProgressRequest +func (_e *MockDataCoord_Expecter) GetIndexBuildProgress(_a0 interface{}, _a1 interface{}) *MockDataCoord_GetIndexBuildProgress_Call { + return &MockDataCoord_GetIndexBuildProgress_Call{Call: _e.mock.On("GetIndexBuildProgress", _a0, _a1)} } -func (_c *MockDataCoord_GetIndexBuildProgress_Call) Run(run func(ctx context.Context, req *indexpb.GetIndexBuildProgressRequest)) *MockDataCoord_GetIndexBuildProgress_Call { +func (_c *MockDataCoord_GetIndexBuildProgress_Call) Run(run func(_a0 context.Context, _a1 *indexpb.GetIndexBuildProgressRequest)) *MockDataCoord_GetIndexBuildProgress_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), 
args[1].(*indexpb.GetIndexBuildProgressRequest)) }) @@ -968,17 +971,17 @@ func (_c *MockDataCoord_GetIndexBuildProgress_Call) RunAndReturn(run func(contex return _c } -// GetIndexInfos provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) GetIndexInfos(ctx context.Context, req *indexpb.GetIndexInfoRequest) (*indexpb.GetIndexInfoResponse, error) { - ret := _m.Called(ctx, req) +// GetIndexInfos provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) GetIndexInfos(_a0 context.Context, _a1 *indexpb.GetIndexInfoRequest) (*indexpb.GetIndexInfoResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *indexpb.GetIndexInfoResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *indexpb.GetIndexInfoRequest) (*indexpb.GetIndexInfoResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *indexpb.GetIndexInfoRequest) *indexpb.GetIndexInfoResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*indexpb.GetIndexInfoResponse) @@ -986,7 +989,7 @@ func (_m *MockDataCoord) GetIndexInfos(ctx context.Context, req *indexpb.GetInde } if rf, ok := ret.Get(1).(func(context.Context, *indexpb.GetIndexInfoRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1000,13 +1003,13 @@ type MockDataCoord_GetIndexInfos_Call struct { } // GetIndexInfos is a helper method to define mock.On call -// - ctx context.Context -// - req *indexpb.GetIndexInfoRequest -func (_e *MockDataCoord_Expecter) GetIndexInfos(ctx interface{}, req interface{}) *MockDataCoord_GetIndexInfos_Call { - return &MockDataCoord_GetIndexInfos_Call{Call: _e.mock.On("GetIndexInfos", ctx, req)} +// - _a0 context.Context +// - _a1 *indexpb.GetIndexInfoRequest +func (_e *MockDataCoord_Expecter) GetIndexInfos(_a0 interface{}, _a1 interface{}) *MockDataCoord_GetIndexInfos_Call { + return &MockDataCoord_GetIndexInfos_Call{Call: 
_e.mock.On("GetIndexInfos", _a0, _a1)} } -func (_c *MockDataCoord_GetIndexInfos_Call) Run(run func(ctx context.Context, req *indexpb.GetIndexInfoRequest)) *MockDataCoord_GetIndexInfos_Call { +func (_c *MockDataCoord_GetIndexInfos_Call) Run(run func(_a0 context.Context, _a1 *indexpb.GetIndexInfoRequest)) *MockDataCoord_GetIndexInfos_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*indexpb.GetIndexInfoRequest)) }) @@ -1023,17 +1026,17 @@ func (_c *MockDataCoord_GetIndexInfos_Call) RunAndReturn(run func(context.Contex return _c } -// GetIndexState provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) GetIndexState(ctx context.Context, req *indexpb.GetIndexStateRequest) (*indexpb.GetIndexStateResponse, error) { - ret := _m.Called(ctx, req) +// GetIndexState provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) GetIndexState(_a0 context.Context, _a1 *indexpb.GetIndexStateRequest) (*indexpb.GetIndexStateResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *indexpb.GetIndexStateResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *indexpb.GetIndexStateRequest) (*indexpb.GetIndexStateResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *indexpb.GetIndexStateRequest) *indexpb.GetIndexStateResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*indexpb.GetIndexStateResponse) @@ -1041,7 +1044,7 @@ func (_m *MockDataCoord) GetIndexState(ctx context.Context, req *indexpb.GetInde } if rf, ok := ret.Get(1).(func(context.Context, *indexpb.GetIndexStateRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1055,13 +1058,13 @@ type MockDataCoord_GetIndexState_Call struct { } // GetIndexState is a helper method to define mock.On call -// - ctx context.Context -// - req *indexpb.GetIndexStateRequest -func (_e 
*MockDataCoord_Expecter) GetIndexState(ctx interface{}, req interface{}) *MockDataCoord_GetIndexState_Call { - return &MockDataCoord_GetIndexState_Call{Call: _e.mock.On("GetIndexState", ctx, req)} +// - _a0 context.Context +// - _a1 *indexpb.GetIndexStateRequest +func (_e *MockDataCoord_Expecter) GetIndexState(_a0 interface{}, _a1 interface{}) *MockDataCoord_GetIndexState_Call { + return &MockDataCoord_GetIndexState_Call{Call: _e.mock.On("GetIndexState", _a0, _a1)} } -func (_c *MockDataCoord_GetIndexState_Call) Run(run func(ctx context.Context, req *indexpb.GetIndexStateRequest)) *MockDataCoord_GetIndexState_Call { +func (_c *MockDataCoord_GetIndexState_Call) Run(run func(_a0 context.Context, _a1 *indexpb.GetIndexStateRequest)) *MockDataCoord_GetIndexState_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*indexpb.GetIndexStateRequest)) }) @@ -1078,17 +1081,17 @@ func (_c *MockDataCoord_GetIndexState_Call) RunAndReturn(run func(context.Contex return _c } -// GetIndexStatistics provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) GetIndexStatistics(ctx context.Context, req *indexpb.GetIndexStatisticsRequest) (*indexpb.GetIndexStatisticsResponse, error) { - ret := _m.Called(ctx, req) +// GetIndexStatistics provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) GetIndexStatistics(_a0 context.Context, _a1 *indexpb.GetIndexStatisticsRequest) (*indexpb.GetIndexStatisticsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *indexpb.GetIndexStatisticsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *indexpb.GetIndexStatisticsRequest) (*indexpb.GetIndexStatisticsResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *indexpb.GetIndexStatisticsRequest) *indexpb.GetIndexStatisticsResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = 
ret.Get(0).(*indexpb.GetIndexStatisticsResponse) @@ -1096,7 +1099,7 @@ func (_m *MockDataCoord) GetIndexStatistics(ctx context.Context, req *indexpb.Ge } if rf, ok := ret.Get(1).(func(context.Context, *indexpb.GetIndexStatisticsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1110,13 +1113,13 @@ type MockDataCoord_GetIndexStatistics_Call struct { } // GetIndexStatistics is a helper method to define mock.On call -// - ctx context.Context -// - req *indexpb.GetIndexStatisticsRequest -func (_e *MockDataCoord_Expecter) GetIndexStatistics(ctx interface{}, req interface{}) *MockDataCoord_GetIndexStatistics_Call { - return &MockDataCoord_GetIndexStatistics_Call{Call: _e.mock.On("GetIndexStatistics", ctx, req)} +// - _a0 context.Context +// - _a1 *indexpb.GetIndexStatisticsRequest +func (_e *MockDataCoord_Expecter) GetIndexStatistics(_a0 interface{}, _a1 interface{}) *MockDataCoord_GetIndexStatistics_Call { + return &MockDataCoord_GetIndexStatistics_Call{Call: _e.mock.On("GetIndexStatistics", _a0, _a1)} } -func (_c *MockDataCoord_GetIndexStatistics_Call) Run(run func(ctx context.Context, req *indexpb.GetIndexStatisticsRequest)) *MockDataCoord_GetIndexStatistics_Call { +func (_c *MockDataCoord_GetIndexStatistics_Call) Run(run func(_a0 context.Context, _a1 *indexpb.GetIndexStatisticsRequest)) *MockDataCoord_GetIndexStatistics_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*indexpb.GetIndexStatisticsRequest)) }) @@ -1133,17 +1136,17 @@ func (_c *MockDataCoord_GetIndexStatistics_Call) RunAndReturn(run func(context.C return _c } -// GetInsertBinlogPaths provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) GetInsertBinlogPaths(ctx context.Context, req *datapb.GetInsertBinlogPathsRequest) (*datapb.GetInsertBinlogPathsResponse, error) { - ret := _m.Called(ctx, req) +// GetInsertBinlogPaths provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) 
GetInsertBinlogPaths(_a0 context.Context, _a1 *datapb.GetInsertBinlogPathsRequest) (*datapb.GetInsertBinlogPathsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *datapb.GetInsertBinlogPathsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetInsertBinlogPathsRequest) (*datapb.GetInsertBinlogPathsResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetInsertBinlogPathsRequest) *datapb.GetInsertBinlogPathsResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*datapb.GetInsertBinlogPathsResponse) @@ -1151,7 +1154,7 @@ func (_m *MockDataCoord) GetInsertBinlogPaths(ctx context.Context, req *datapb.G } if rf, ok := ret.Get(1).(func(context.Context, *datapb.GetInsertBinlogPathsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1165,13 +1168,13 @@ type MockDataCoord_GetInsertBinlogPaths_Call struct { } // GetInsertBinlogPaths is a helper method to define mock.On call -// - ctx context.Context -// - req *datapb.GetInsertBinlogPathsRequest -func (_e *MockDataCoord_Expecter) GetInsertBinlogPaths(ctx interface{}, req interface{}) *MockDataCoord_GetInsertBinlogPaths_Call { - return &MockDataCoord_GetInsertBinlogPaths_Call{Call: _e.mock.On("GetInsertBinlogPaths", ctx, req)} +// - _a0 context.Context +// - _a1 *datapb.GetInsertBinlogPathsRequest +func (_e *MockDataCoord_Expecter) GetInsertBinlogPaths(_a0 interface{}, _a1 interface{}) *MockDataCoord_GetInsertBinlogPaths_Call { + return &MockDataCoord_GetInsertBinlogPaths_Call{Call: _e.mock.On("GetInsertBinlogPaths", _a0, _a1)} } -func (_c *MockDataCoord_GetInsertBinlogPaths_Call) Run(run func(ctx context.Context, req *datapb.GetInsertBinlogPathsRequest)) *MockDataCoord_GetInsertBinlogPaths_Call { +func (_c *MockDataCoord_GetInsertBinlogPaths_Call) Run(run func(_a0 context.Context, _a1 *datapb.GetInsertBinlogPathsRequest)) 
*MockDataCoord_GetInsertBinlogPaths_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*datapb.GetInsertBinlogPathsRequest)) }) @@ -1188,17 +1191,17 @@ func (_c *MockDataCoord_GetInsertBinlogPaths_Call) RunAndReturn(run func(context return _c } -// GetMetrics provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - ret := _m.Called(ctx, req) +// GetMetrics provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) GetMetrics(_a0 context.Context, _a1 *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.GetMetricsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest) *milvuspb.GetMetricsResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.GetMetricsResponse) @@ -1206,7 +1209,7 @@ func (_m *MockDataCoord) GetMetrics(ctx context.Context, req *milvuspb.GetMetric } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetMetricsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1220,13 +1223,13 @@ type MockDataCoord_GetMetrics_Call struct { } // GetMetrics is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.GetMetricsRequest -func (_e *MockDataCoord_Expecter) GetMetrics(ctx interface{}, req interface{}) *MockDataCoord_GetMetrics_Call { - return &MockDataCoord_GetMetrics_Call{Call: _e.mock.On("GetMetrics", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.GetMetricsRequest +func (_e *MockDataCoord_Expecter) GetMetrics(_a0 interface{}, _a1 interface{}) 
*MockDataCoord_GetMetrics_Call { + return &MockDataCoord_GetMetrics_Call{Call: _e.mock.On("GetMetrics", _a0, _a1)} } -func (_c *MockDataCoord_GetMetrics_Call) Run(run func(ctx context.Context, req *milvuspb.GetMetricsRequest)) *MockDataCoord_GetMetrics_Call { +func (_c *MockDataCoord_GetMetrics_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetMetricsRequest)) *MockDataCoord_GetMetrics_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.GetMetricsRequest)) }) @@ -1243,17 +1246,17 @@ func (_c *MockDataCoord_GetMetrics_Call) RunAndReturn(run func(context.Context, return _c } -// GetPartitionStatistics provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) GetPartitionStatistics(ctx context.Context, req *datapb.GetPartitionStatisticsRequest) (*datapb.GetPartitionStatisticsResponse, error) { - ret := _m.Called(ctx, req) +// GetPartitionStatistics provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) GetPartitionStatistics(_a0 context.Context, _a1 *datapb.GetPartitionStatisticsRequest) (*datapb.GetPartitionStatisticsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *datapb.GetPartitionStatisticsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetPartitionStatisticsRequest) (*datapb.GetPartitionStatisticsResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetPartitionStatisticsRequest) *datapb.GetPartitionStatisticsResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*datapb.GetPartitionStatisticsResponse) @@ -1261,7 +1264,7 @@ func (_m *MockDataCoord) GetPartitionStatistics(ctx context.Context, req *datapb } if rf, ok := ret.Get(1).(func(context.Context, *datapb.GetPartitionStatisticsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1275,13 +1278,13 @@ type 
MockDataCoord_GetPartitionStatistics_Call struct { } // GetPartitionStatistics is a helper method to define mock.On call -// - ctx context.Context -// - req *datapb.GetPartitionStatisticsRequest -func (_e *MockDataCoord_Expecter) GetPartitionStatistics(ctx interface{}, req interface{}) *MockDataCoord_GetPartitionStatistics_Call { - return &MockDataCoord_GetPartitionStatistics_Call{Call: _e.mock.On("GetPartitionStatistics", ctx, req)} +// - _a0 context.Context +// - _a1 *datapb.GetPartitionStatisticsRequest +func (_e *MockDataCoord_Expecter) GetPartitionStatistics(_a0 interface{}, _a1 interface{}) *MockDataCoord_GetPartitionStatistics_Call { + return &MockDataCoord_GetPartitionStatistics_Call{Call: _e.mock.On("GetPartitionStatistics", _a0, _a1)} } -func (_c *MockDataCoord_GetPartitionStatistics_Call) Run(run func(ctx context.Context, req *datapb.GetPartitionStatisticsRequest)) *MockDataCoord_GetPartitionStatistics_Call { +func (_c *MockDataCoord_GetPartitionStatistics_Call) Run(run func(_a0 context.Context, _a1 *datapb.GetPartitionStatisticsRequest)) *MockDataCoord_GetPartitionStatistics_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*datapb.GetPartitionStatisticsRequest)) }) @@ -1298,17 +1301,17 @@ func (_c *MockDataCoord_GetPartitionStatistics_Call) RunAndReturn(run func(conte return _c } -// GetRecoveryInfo provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInfoRequest) (*datapb.GetRecoveryInfoResponse, error) { - ret := _m.Called(ctx, req) +// GetRecoveryInfo provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) GetRecoveryInfo(_a0 context.Context, _a1 *datapb.GetRecoveryInfoRequest) (*datapb.GetRecoveryInfoResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *datapb.GetRecoveryInfoResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetRecoveryInfoRequest) 
(*datapb.GetRecoveryInfoResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetRecoveryInfoRequest) *datapb.GetRecoveryInfoResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*datapb.GetRecoveryInfoResponse) @@ -1316,7 +1319,7 @@ func (_m *MockDataCoord) GetRecoveryInfo(ctx context.Context, req *datapb.GetRec } if rf, ok := ret.Get(1).(func(context.Context, *datapb.GetRecoveryInfoRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1330,13 +1333,13 @@ type MockDataCoord_GetRecoveryInfo_Call struct { } // GetRecoveryInfo is a helper method to define mock.On call -// - ctx context.Context -// - req *datapb.GetRecoveryInfoRequest -func (_e *MockDataCoord_Expecter) GetRecoveryInfo(ctx interface{}, req interface{}) *MockDataCoord_GetRecoveryInfo_Call { - return &MockDataCoord_GetRecoveryInfo_Call{Call: _e.mock.On("GetRecoveryInfo", ctx, req)} +// - _a0 context.Context +// - _a1 *datapb.GetRecoveryInfoRequest +func (_e *MockDataCoord_Expecter) GetRecoveryInfo(_a0 interface{}, _a1 interface{}) *MockDataCoord_GetRecoveryInfo_Call { + return &MockDataCoord_GetRecoveryInfo_Call{Call: _e.mock.On("GetRecoveryInfo", _a0, _a1)} } -func (_c *MockDataCoord_GetRecoveryInfo_Call) Run(run func(ctx context.Context, req *datapb.GetRecoveryInfoRequest)) *MockDataCoord_GetRecoveryInfo_Call { +func (_c *MockDataCoord_GetRecoveryInfo_Call) Run(run func(_a0 context.Context, _a1 *datapb.GetRecoveryInfoRequest)) *MockDataCoord_GetRecoveryInfo_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*datapb.GetRecoveryInfoRequest)) }) @@ -1353,17 +1356,17 @@ func (_c *MockDataCoord_GetRecoveryInfo_Call) RunAndReturn(run func(context.Cont return _c } -// GetRecoveryInfoV2 provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) GetRecoveryInfoV2(ctx context.Context, req 
*datapb.GetRecoveryInfoRequestV2) (*datapb.GetRecoveryInfoResponseV2, error) { - ret := _m.Called(ctx, req) +// GetRecoveryInfoV2 provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) GetRecoveryInfoV2(_a0 context.Context, _a1 *datapb.GetRecoveryInfoRequestV2) (*datapb.GetRecoveryInfoResponseV2, error) { + ret := _m.Called(_a0, _a1) var r0 *datapb.GetRecoveryInfoResponseV2 var r1 error if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetRecoveryInfoRequestV2) (*datapb.GetRecoveryInfoResponseV2, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetRecoveryInfoRequestV2) *datapb.GetRecoveryInfoResponseV2); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*datapb.GetRecoveryInfoResponseV2) @@ -1371,7 +1374,7 @@ func (_m *MockDataCoord) GetRecoveryInfoV2(ctx context.Context, req *datapb.GetR } if rf, ok := ret.Get(1).(func(context.Context, *datapb.GetRecoveryInfoRequestV2) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1385,13 +1388,13 @@ type MockDataCoord_GetRecoveryInfoV2_Call struct { } // GetRecoveryInfoV2 is a helper method to define mock.On call -// - ctx context.Context -// - req *datapb.GetRecoveryInfoRequestV2 -func (_e *MockDataCoord_Expecter) GetRecoveryInfoV2(ctx interface{}, req interface{}) *MockDataCoord_GetRecoveryInfoV2_Call { - return &MockDataCoord_GetRecoveryInfoV2_Call{Call: _e.mock.On("GetRecoveryInfoV2", ctx, req)} +// - _a0 context.Context +// - _a1 *datapb.GetRecoveryInfoRequestV2 +func (_e *MockDataCoord_Expecter) GetRecoveryInfoV2(_a0 interface{}, _a1 interface{}) *MockDataCoord_GetRecoveryInfoV2_Call { + return &MockDataCoord_GetRecoveryInfoV2_Call{Call: _e.mock.On("GetRecoveryInfoV2", _a0, _a1)} } -func (_c *MockDataCoord_GetRecoveryInfoV2_Call) Run(run func(ctx context.Context, req *datapb.GetRecoveryInfoRequestV2)) *MockDataCoord_GetRecoveryInfoV2_Call { +func 
(_c *MockDataCoord_GetRecoveryInfoV2_Call) Run(run func(_a0 context.Context, _a1 *datapb.GetRecoveryInfoRequestV2)) *MockDataCoord_GetRecoveryInfoV2_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*datapb.GetRecoveryInfoRequestV2)) }) @@ -1408,17 +1411,17 @@ func (_c *MockDataCoord_GetRecoveryInfoV2_Call) RunAndReturn(run func(context.Co return _c } -// GetSegmentIndexState provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) GetSegmentIndexState(ctx context.Context, req *indexpb.GetSegmentIndexStateRequest) (*indexpb.GetSegmentIndexStateResponse, error) { - ret := _m.Called(ctx, req) +// GetSegmentIndexState provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) GetSegmentIndexState(_a0 context.Context, _a1 *indexpb.GetSegmentIndexStateRequest) (*indexpb.GetSegmentIndexStateResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *indexpb.GetSegmentIndexStateResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *indexpb.GetSegmentIndexStateRequest) (*indexpb.GetSegmentIndexStateResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *indexpb.GetSegmentIndexStateRequest) *indexpb.GetSegmentIndexStateResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*indexpb.GetSegmentIndexStateResponse) @@ -1426,7 +1429,7 @@ func (_m *MockDataCoord) GetSegmentIndexState(ctx context.Context, req *indexpb. 
} if rf, ok := ret.Get(1).(func(context.Context, *indexpb.GetSegmentIndexStateRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1440,13 +1443,13 @@ type MockDataCoord_GetSegmentIndexState_Call struct { } // GetSegmentIndexState is a helper method to define mock.On call -// - ctx context.Context -// - req *indexpb.GetSegmentIndexStateRequest -func (_e *MockDataCoord_Expecter) GetSegmentIndexState(ctx interface{}, req interface{}) *MockDataCoord_GetSegmentIndexState_Call { - return &MockDataCoord_GetSegmentIndexState_Call{Call: _e.mock.On("GetSegmentIndexState", ctx, req)} +// - _a0 context.Context +// - _a1 *indexpb.GetSegmentIndexStateRequest +func (_e *MockDataCoord_Expecter) GetSegmentIndexState(_a0 interface{}, _a1 interface{}) *MockDataCoord_GetSegmentIndexState_Call { + return &MockDataCoord_GetSegmentIndexState_Call{Call: _e.mock.On("GetSegmentIndexState", _a0, _a1)} } -func (_c *MockDataCoord_GetSegmentIndexState_Call) Run(run func(ctx context.Context, req *indexpb.GetSegmentIndexStateRequest)) *MockDataCoord_GetSegmentIndexState_Call { +func (_c *MockDataCoord_GetSegmentIndexState_Call) Run(run func(_a0 context.Context, _a1 *indexpb.GetSegmentIndexStateRequest)) *MockDataCoord_GetSegmentIndexState_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*indexpb.GetSegmentIndexStateRequest)) }) @@ -1463,17 +1466,17 @@ func (_c *MockDataCoord_GetSegmentIndexState_Call) RunAndReturn(run func(context return _c } -// GetSegmentInfo provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) GetSegmentInfo(ctx context.Context, req *datapb.GetSegmentInfoRequest) (*datapb.GetSegmentInfoResponse, error) { - ret := _m.Called(ctx, req) +// GetSegmentInfo provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) GetSegmentInfo(_a0 context.Context, _a1 *datapb.GetSegmentInfoRequest) (*datapb.GetSegmentInfoResponse, error) { + ret := _m.Called(_a0, _a1) var 
r0 *datapb.GetSegmentInfoResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetSegmentInfoRequest) (*datapb.GetSegmentInfoResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetSegmentInfoRequest) *datapb.GetSegmentInfoResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*datapb.GetSegmentInfoResponse) @@ -1481,7 +1484,7 @@ func (_m *MockDataCoord) GetSegmentInfo(ctx context.Context, req *datapb.GetSegm } if rf, ok := ret.Get(1).(func(context.Context, *datapb.GetSegmentInfoRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1495,13 +1498,13 @@ type MockDataCoord_GetSegmentInfo_Call struct { } // GetSegmentInfo is a helper method to define mock.On call -// - ctx context.Context -// - req *datapb.GetSegmentInfoRequest -func (_e *MockDataCoord_Expecter) GetSegmentInfo(ctx interface{}, req interface{}) *MockDataCoord_GetSegmentInfo_Call { - return &MockDataCoord_GetSegmentInfo_Call{Call: _e.mock.On("GetSegmentInfo", ctx, req)} +// - _a0 context.Context +// - _a1 *datapb.GetSegmentInfoRequest +func (_e *MockDataCoord_Expecter) GetSegmentInfo(_a0 interface{}, _a1 interface{}) *MockDataCoord_GetSegmentInfo_Call { + return &MockDataCoord_GetSegmentInfo_Call{Call: _e.mock.On("GetSegmentInfo", _a0, _a1)} } -func (_c *MockDataCoord_GetSegmentInfo_Call) Run(run func(ctx context.Context, req *datapb.GetSegmentInfoRequest)) *MockDataCoord_GetSegmentInfo_Call { +func (_c *MockDataCoord_GetSegmentInfo_Call) Run(run func(_a0 context.Context, _a1 *datapb.GetSegmentInfoRequest)) *MockDataCoord_GetSegmentInfo_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*datapb.GetSegmentInfoRequest)) }) @@ -1518,25 +1521,25 @@ func (_c *MockDataCoord_GetSegmentInfo_Call) RunAndReturn(run func(context.Conte return _c } -// GetSegmentInfoChannel provides a mock function 
with given fields: ctx -func (_m *MockDataCoord) GetSegmentInfoChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret := _m.Called(ctx) +// GetSegmentInfoChannel provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) GetSegmentInfoChannel(_a0 context.Context, _a1 *datapb.GetSegmentInfoChannelRequest) (*milvuspb.StringResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.StringResponse var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (*milvuspb.StringResponse, error)); ok { - return rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetSegmentInfoChannelRequest) (*milvuspb.StringResponse, error)); ok { + return rf(_a0, _a1) } - if rf, ok := ret.Get(0).(func(context.Context) *milvuspb.StringResponse); ok { - r0 = rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetSegmentInfoChannelRequest) *milvuspb.StringResponse); ok { + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.StringResponse) } } - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) + if rf, ok := ret.Get(1).(func(context.Context, *datapb.GetSegmentInfoChannelRequest) error); ok { + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1550,14 +1553,15 @@ type MockDataCoord_GetSegmentInfoChannel_Call struct { } // GetSegmentInfoChannel is a helper method to define mock.On call -// - ctx context.Context -func (_e *MockDataCoord_Expecter) GetSegmentInfoChannel(ctx interface{}) *MockDataCoord_GetSegmentInfoChannel_Call { - return &MockDataCoord_GetSegmentInfoChannel_Call{Call: _e.mock.On("GetSegmentInfoChannel", ctx)} +// - _a0 context.Context +// - _a1 *datapb.GetSegmentInfoChannelRequest +func (_e *MockDataCoord_Expecter) GetSegmentInfoChannel(_a0 interface{}, _a1 interface{}) *MockDataCoord_GetSegmentInfoChannel_Call { + return &MockDataCoord_GetSegmentInfoChannel_Call{Call: _e.mock.On("GetSegmentInfoChannel", _a0, _a1)} } -func (_c 
*MockDataCoord_GetSegmentInfoChannel_Call) Run(run func(ctx context.Context)) *MockDataCoord_GetSegmentInfoChannel_Call { +func (_c *MockDataCoord_GetSegmentInfoChannel_Call) Run(run func(_a0 context.Context, _a1 *datapb.GetSegmentInfoChannelRequest)) *MockDataCoord_GetSegmentInfoChannel_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context)) + run(args[0].(context.Context), args[1].(*datapb.GetSegmentInfoChannelRequest)) }) return _c } @@ -1567,22 +1571,22 @@ func (_c *MockDataCoord_GetSegmentInfoChannel_Call) Return(_a0 *milvuspb.StringR return _c } -func (_c *MockDataCoord_GetSegmentInfoChannel_Call) RunAndReturn(run func(context.Context) (*milvuspb.StringResponse, error)) *MockDataCoord_GetSegmentInfoChannel_Call { +func (_c *MockDataCoord_GetSegmentInfoChannel_Call) RunAndReturn(run func(context.Context, *datapb.GetSegmentInfoChannelRequest) (*milvuspb.StringResponse, error)) *MockDataCoord_GetSegmentInfoChannel_Call { _c.Call.Return(run) return _c } -// GetSegmentStates provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) GetSegmentStates(ctx context.Context, req *datapb.GetSegmentStatesRequest) (*datapb.GetSegmentStatesResponse, error) { - ret := _m.Called(ctx, req) +// GetSegmentStates provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) GetSegmentStates(_a0 context.Context, _a1 *datapb.GetSegmentStatesRequest) (*datapb.GetSegmentStatesResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *datapb.GetSegmentStatesResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetSegmentStatesRequest) (*datapb.GetSegmentStatesResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetSegmentStatesRequest) *datapb.GetSegmentStatesResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*datapb.GetSegmentStatesResponse) @@ -1590,7 +1594,7 @@ func (_m 
*MockDataCoord) GetSegmentStates(ctx context.Context, req *datapb.GetSe } if rf, ok := ret.Get(1).(func(context.Context, *datapb.GetSegmentStatesRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1604,13 +1608,13 @@ type MockDataCoord_GetSegmentStates_Call struct { } // GetSegmentStates is a helper method to define mock.On call -// - ctx context.Context -// - req *datapb.GetSegmentStatesRequest -func (_e *MockDataCoord_Expecter) GetSegmentStates(ctx interface{}, req interface{}) *MockDataCoord_GetSegmentStates_Call { - return &MockDataCoord_GetSegmentStates_Call{Call: _e.mock.On("GetSegmentStates", ctx, req)} +// - _a0 context.Context +// - _a1 *datapb.GetSegmentStatesRequest +func (_e *MockDataCoord_Expecter) GetSegmentStates(_a0 interface{}, _a1 interface{}) *MockDataCoord_GetSegmentStates_Call { + return &MockDataCoord_GetSegmentStates_Call{Call: _e.mock.On("GetSegmentStates", _a0, _a1)} } -func (_c *MockDataCoord_GetSegmentStates_Call) Run(run func(ctx context.Context, req *datapb.GetSegmentStatesRequest)) *MockDataCoord_GetSegmentStates_Call { +func (_c *MockDataCoord_GetSegmentStates_Call) Run(run func(_a0 context.Context, _a1 *datapb.GetSegmentStatesRequest)) *MockDataCoord_GetSegmentStates_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*datapb.GetSegmentStatesRequest)) }) @@ -1627,17 +1631,17 @@ func (_c *MockDataCoord_GetSegmentStates_Call) RunAndReturn(run func(context.Con return _c } -// GetSegmentsByStates provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) GetSegmentsByStates(ctx context.Context, req *datapb.GetSegmentsByStatesRequest) (*datapb.GetSegmentsByStatesResponse, error) { - ret := _m.Called(ctx, req) +// GetSegmentsByStates provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) GetSegmentsByStates(_a0 context.Context, _a1 *datapb.GetSegmentsByStatesRequest) (*datapb.GetSegmentsByStatesResponse, error) { + ret 
:= _m.Called(_a0, _a1) var r0 *datapb.GetSegmentsByStatesResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetSegmentsByStatesRequest) (*datapb.GetSegmentsByStatesResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetSegmentsByStatesRequest) *datapb.GetSegmentsByStatesResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*datapb.GetSegmentsByStatesResponse) @@ -1645,7 +1649,7 @@ func (_m *MockDataCoord) GetSegmentsByStates(ctx context.Context, req *datapb.Ge } if rf, ok := ret.Get(1).(func(context.Context, *datapb.GetSegmentsByStatesRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1659,13 +1663,13 @@ type MockDataCoord_GetSegmentsByStates_Call struct { } // GetSegmentsByStates is a helper method to define mock.On call -// - ctx context.Context -// - req *datapb.GetSegmentsByStatesRequest -func (_e *MockDataCoord_Expecter) GetSegmentsByStates(ctx interface{}, req interface{}) *MockDataCoord_GetSegmentsByStates_Call { - return &MockDataCoord_GetSegmentsByStates_Call{Call: _e.mock.On("GetSegmentsByStates", ctx, req)} +// - _a0 context.Context +// - _a1 *datapb.GetSegmentsByStatesRequest +func (_e *MockDataCoord_Expecter) GetSegmentsByStates(_a0 interface{}, _a1 interface{}) *MockDataCoord_GetSegmentsByStates_Call { + return &MockDataCoord_GetSegmentsByStates_Call{Call: _e.mock.On("GetSegmentsByStates", _a0, _a1)} } -func (_c *MockDataCoord_GetSegmentsByStates_Call) Run(run func(ctx context.Context, req *datapb.GetSegmentsByStatesRequest)) *MockDataCoord_GetSegmentsByStates_Call { +func (_c *MockDataCoord_GetSegmentsByStates_Call) Run(run func(_a0 context.Context, _a1 *datapb.GetSegmentsByStatesRequest)) *MockDataCoord_GetSegmentsByStates_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*datapb.GetSegmentsByStatesRequest)) }) @@ -1682,25 
+1686,25 @@ func (_c *MockDataCoord_GetSegmentsByStates_Call) RunAndReturn(run func(context. return _c } -// GetStatisticsChannel provides a mock function with given fields: ctx -func (_m *MockDataCoord) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret := _m.Called(ctx) +// GetStatisticsChannel provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) GetStatisticsChannel(_a0 context.Context, _a1 *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.StringResponse var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (*milvuspb.StringResponse, error)); ok { - return rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error)); ok { + return rf(_a0, _a1) } - if rf, ok := ret.Get(0).(func(context.Context) *milvuspb.StringResponse); ok { - r0 = rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetStatisticsChannelRequest) *milvuspb.StringResponse); ok { + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.StringResponse) } } - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.GetStatisticsChannelRequest) error); ok { + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1714,14 +1718,15 @@ type MockDataCoord_GetStatisticsChannel_Call struct { } // GetStatisticsChannel is a helper method to define mock.On call -// - ctx context.Context -func (_e *MockDataCoord_Expecter) GetStatisticsChannel(ctx interface{}) *MockDataCoord_GetStatisticsChannel_Call { - return &MockDataCoord_GetStatisticsChannel_Call{Call: _e.mock.On("GetStatisticsChannel", ctx)} +// - _a0 context.Context +// - _a1 *internalpb.GetStatisticsChannelRequest +func (_e *MockDataCoord_Expecter) GetStatisticsChannel(_a0 interface{}, _a1 interface{}) 
*MockDataCoord_GetStatisticsChannel_Call { + return &MockDataCoord_GetStatisticsChannel_Call{Call: _e.mock.On("GetStatisticsChannel", _a0, _a1)} } -func (_c *MockDataCoord_GetStatisticsChannel_Call) Run(run func(ctx context.Context)) *MockDataCoord_GetStatisticsChannel_Call { +func (_c *MockDataCoord_GetStatisticsChannel_Call) Run(run func(_a0 context.Context, _a1 *internalpb.GetStatisticsChannelRequest)) *MockDataCoord_GetStatisticsChannel_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context)) + run(args[0].(context.Context), args[1].(*internalpb.GetStatisticsChannelRequest)) }) return _c } @@ -1731,30 +1736,30 @@ func (_c *MockDataCoord_GetStatisticsChannel_Call) Return(_a0 *milvuspb.StringRe return _c } -func (_c *MockDataCoord_GetStatisticsChannel_Call) RunAndReturn(run func(context.Context) (*milvuspb.StringResponse, error)) *MockDataCoord_GetStatisticsChannel_Call { +func (_c *MockDataCoord_GetStatisticsChannel_Call) RunAndReturn(run func(context.Context, *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error)) *MockDataCoord_GetStatisticsChannel_Call { _c.Call.Return(run) return _c } -// GetTimeTickChannel provides a mock function with given fields: ctx -func (_m *MockDataCoord) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret := _m.Called(ctx) +// GetTimeTickChannel provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) GetTimeTickChannel(_a0 context.Context, _a1 *internalpb.GetTimeTickChannelRequest) (*milvuspb.StringResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.StringResponse var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (*milvuspb.StringResponse, error)); ok { - return rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetTimeTickChannelRequest) (*milvuspb.StringResponse, error)); ok { + return rf(_a0, _a1) } - if rf, ok := ret.Get(0).(func(context.Context) *milvuspb.StringResponse); ok { - r0 = rf(ctx) 
+ if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetTimeTickChannelRequest) *milvuspb.StringResponse); ok { + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.StringResponse) } } - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.GetTimeTickChannelRequest) error); ok { + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1768,14 +1773,15 @@ type MockDataCoord_GetTimeTickChannel_Call struct { } // GetTimeTickChannel is a helper method to define mock.On call -// - ctx context.Context -func (_e *MockDataCoord_Expecter) GetTimeTickChannel(ctx interface{}) *MockDataCoord_GetTimeTickChannel_Call { - return &MockDataCoord_GetTimeTickChannel_Call{Call: _e.mock.On("GetTimeTickChannel", ctx)} +// - _a0 context.Context +// - _a1 *internalpb.GetTimeTickChannelRequest +func (_e *MockDataCoord_Expecter) GetTimeTickChannel(_a0 interface{}, _a1 interface{}) *MockDataCoord_GetTimeTickChannel_Call { + return &MockDataCoord_GetTimeTickChannel_Call{Call: _e.mock.On("GetTimeTickChannel", _a0, _a1)} } -func (_c *MockDataCoord_GetTimeTickChannel_Call) Run(run func(ctx context.Context)) *MockDataCoord_GetTimeTickChannel_Call { +func (_c *MockDataCoord_GetTimeTickChannel_Call) Run(run func(_a0 context.Context, _a1 *internalpb.GetTimeTickChannelRequest)) *MockDataCoord_GetTimeTickChannel_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context)) + run(args[0].(context.Context), args[1].(*internalpb.GetTimeTickChannelRequest)) }) return _c } @@ -1785,22 +1791,22 @@ func (_c *MockDataCoord_GetTimeTickChannel_Call) Return(_a0 *milvuspb.StringResp return _c } -func (_c *MockDataCoord_GetTimeTickChannel_Call) RunAndReturn(run func(context.Context) (*milvuspb.StringResponse, error)) *MockDataCoord_GetTimeTickChannel_Call { +func (_c *MockDataCoord_GetTimeTickChannel_Call) RunAndReturn(run func(context.Context, *internalpb.GetTimeTickChannelRequest) 
(*milvuspb.StringResponse, error)) *MockDataCoord_GetTimeTickChannel_Call { _c.Call.Return(run) return _c } -// Import provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) Import(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { - ret := _m.Called(ctx, req) +// Import provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) Import(_a0 context.Context, _a1 *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *datapb.ImportTaskResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *datapb.ImportTaskRequest) *datapb.ImportTaskResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*datapb.ImportTaskResponse) @@ -1808,7 +1814,7 @@ func (_m *MockDataCoord) Import(ctx context.Context, req *datapb.ImportTaskReque } if rf, ok := ret.Get(1).(func(context.Context, *datapb.ImportTaskRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1822,13 +1828,13 @@ type MockDataCoord_Import_Call struct { } // Import is a helper method to define mock.On call -// - ctx context.Context -// - req *datapb.ImportTaskRequest -func (_e *MockDataCoord_Expecter) Import(ctx interface{}, req interface{}) *MockDataCoord_Import_Call { - return &MockDataCoord_Import_Call{Call: _e.mock.On("Import", ctx, req)} +// - _a0 context.Context +// - _a1 *datapb.ImportTaskRequest +func (_e *MockDataCoord_Expecter) Import(_a0 interface{}, _a1 interface{}) *MockDataCoord_Import_Call { + return &MockDataCoord_Import_Call{Call: _e.mock.On("Import", _a0, _a1)} } -func (_c *MockDataCoord_Import_Call) Run(run func(ctx context.Context, req *datapb.ImportTaskRequest)) *MockDataCoord_Import_Call { +func (_c 
*MockDataCoord_Import_Call) Run(run func(_a0 context.Context, _a1 *datapb.ImportTaskRequest)) *MockDataCoord_Import_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*datapb.ImportTaskRequest)) }) @@ -1886,17 +1892,17 @@ func (_c *MockDataCoord_Init_Call) RunAndReturn(run func() error) *MockDataCoord return _c } -// ManualCompaction provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) ManualCompaction(ctx context.Context, req *milvuspb.ManualCompactionRequest) (*milvuspb.ManualCompactionResponse, error) { - ret := _m.Called(ctx, req) +// ManualCompaction provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) ManualCompaction(_a0 context.Context, _a1 *milvuspb.ManualCompactionRequest) (*milvuspb.ManualCompactionResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.ManualCompactionResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ManualCompactionRequest) (*milvuspb.ManualCompactionResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ManualCompactionRequest) *milvuspb.ManualCompactionResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.ManualCompactionResponse) @@ -1904,7 +1910,7 @@ func (_m *MockDataCoord) ManualCompaction(ctx context.Context, req *milvuspb.Man } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ManualCompactionRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1918,13 +1924,13 @@ type MockDataCoord_ManualCompaction_Call struct { } // ManualCompaction is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.ManualCompactionRequest -func (_e *MockDataCoord_Expecter) ManualCompaction(ctx interface{}, req interface{}) *MockDataCoord_ManualCompaction_Call { - return &MockDataCoord_ManualCompaction_Call{Call: 
_e.mock.On("ManualCompaction", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.ManualCompactionRequest +func (_e *MockDataCoord_Expecter) ManualCompaction(_a0 interface{}, _a1 interface{}) *MockDataCoord_ManualCompaction_Call { + return &MockDataCoord_ManualCompaction_Call{Call: _e.mock.On("ManualCompaction", _a0, _a1)} } -func (_c *MockDataCoord_ManualCompaction_Call) Run(run func(ctx context.Context, req *milvuspb.ManualCompactionRequest)) *MockDataCoord_ManualCompaction_Call { +func (_c *MockDataCoord_ManualCompaction_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.ManualCompactionRequest)) *MockDataCoord_ManualCompaction_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.ManualCompactionRequest)) }) @@ -1941,17 +1947,17 @@ func (_c *MockDataCoord_ManualCompaction_Call) RunAndReturn(run func(context.Con return _c } -// MarkSegmentsDropped provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) MarkSegmentsDropped(ctx context.Context, req *datapb.MarkSegmentsDroppedRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// MarkSegmentsDropped provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) MarkSegmentsDropped(_a0 context.Context, _a1 *datapb.MarkSegmentsDroppedRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *datapb.MarkSegmentsDroppedRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *datapb.MarkSegmentsDroppedRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -1959,7 +1965,7 @@ func (_m *MockDataCoord) MarkSegmentsDropped(ctx context.Context, req *datapb.Ma } if rf, ok := ret.Get(1).(func(context.Context, *datapb.MarkSegmentsDroppedRequest) error); ok { - r1 = 
rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1973,13 +1979,13 @@ type MockDataCoord_MarkSegmentsDropped_Call struct { } // MarkSegmentsDropped is a helper method to define mock.On call -// - ctx context.Context -// - req *datapb.MarkSegmentsDroppedRequest -func (_e *MockDataCoord_Expecter) MarkSegmentsDropped(ctx interface{}, req interface{}) *MockDataCoord_MarkSegmentsDropped_Call { - return &MockDataCoord_MarkSegmentsDropped_Call{Call: _e.mock.On("MarkSegmentsDropped", ctx, req)} +// - _a0 context.Context +// - _a1 *datapb.MarkSegmentsDroppedRequest +func (_e *MockDataCoord_Expecter) MarkSegmentsDropped(_a0 interface{}, _a1 interface{}) *MockDataCoord_MarkSegmentsDropped_Call { + return &MockDataCoord_MarkSegmentsDropped_Call{Call: _e.mock.On("MarkSegmentsDropped", _a0, _a1)} } -func (_c *MockDataCoord_MarkSegmentsDropped_Call) Run(run func(ctx context.Context, req *datapb.MarkSegmentsDroppedRequest)) *MockDataCoord_MarkSegmentsDropped_Call { +func (_c *MockDataCoord_MarkSegmentsDropped_Call) Run(run func(_a0 context.Context, _a1 *datapb.MarkSegmentsDroppedRequest)) *MockDataCoord_MarkSegmentsDropped_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*datapb.MarkSegmentsDroppedRequest)) }) @@ -2037,17 +2043,17 @@ func (_c *MockDataCoord_Register_Call) RunAndReturn(run func() error) *MockDataC return _c } -// ReportDataNodeTtMsgs provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) ReportDataNodeTtMsgs(ctx context.Context, req *datapb.ReportDataNodeTtMsgsRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// ReportDataNodeTtMsgs provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) ReportDataNodeTtMsgs(_a0 context.Context, _a1 *datapb.ReportDataNodeTtMsgsRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *datapb.ReportDataNodeTtMsgsRequest) 
(*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *datapb.ReportDataNodeTtMsgsRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -2055,7 +2061,7 @@ func (_m *MockDataCoord) ReportDataNodeTtMsgs(ctx context.Context, req *datapb.R } if rf, ok := ret.Get(1).(func(context.Context, *datapb.ReportDataNodeTtMsgsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2069,13 +2075,13 @@ type MockDataCoord_ReportDataNodeTtMsgs_Call struct { } // ReportDataNodeTtMsgs is a helper method to define mock.On call -// - ctx context.Context -// - req *datapb.ReportDataNodeTtMsgsRequest -func (_e *MockDataCoord_Expecter) ReportDataNodeTtMsgs(ctx interface{}, req interface{}) *MockDataCoord_ReportDataNodeTtMsgs_Call { - return &MockDataCoord_ReportDataNodeTtMsgs_Call{Call: _e.mock.On("ReportDataNodeTtMsgs", ctx, req)} +// - _a0 context.Context +// - _a1 *datapb.ReportDataNodeTtMsgsRequest +func (_e *MockDataCoord_Expecter) ReportDataNodeTtMsgs(_a0 interface{}, _a1 interface{}) *MockDataCoord_ReportDataNodeTtMsgs_Call { + return &MockDataCoord_ReportDataNodeTtMsgs_Call{Call: _e.mock.On("ReportDataNodeTtMsgs", _a0, _a1)} } -func (_c *MockDataCoord_ReportDataNodeTtMsgs_Call) Run(run func(ctx context.Context, req *datapb.ReportDataNodeTtMsgsRequest)) *MockDataCoord_ReportDataNodeTtMsgs_Call { +func (_c *MockDataCoord_ReportDataNodeTtMsgs_Call) Run(run func(_a0 context.Context, _a1 *datapb.ReportDataNodeTtMsgsRequest)) *MockDataCoord_ReportDataNodeTtMsgs_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*datapb.ReportDataNodeTtMsgsRequest)) }) @@ -2092,17 +2098,17 @@ func (_c *MockDataCoord_ReportDataNodeTtMsgs_Call) RunAndReturn(run func(context return _c } -// SaveBinlogPaths provides a mock function with given fields: ctx, req -func (_m 
*MockDataCoord) SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPathsRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// SaveBinlogPaths provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) SaveBinlogPaths(_a0 context.Context, _a1 *datapb.SaveBinlogPathsRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *datapb.SaveBinlogPathsRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *datapb.SaveBinlogPathsRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -2110,7 +2116,7 @@ func (_m *MockDataCoord) SaveBinlogPaths(ctx context.Context, req *datapb.SaveBi } if rf, ok := ret.Get(1).(func(context.Context, *datapb.SaveBinlogPathsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2124,13 +2130,13 @@ type MockDataCoord_SaveBinlogPaths_Call struct { } // SaveBinlogPaths is a helper method to define mock.On call -// - ctx context.Context -// - req *datapb.SaveBinlogPathsRequest -func (_e *MockDataCoord_Expecter) SaveBinlogPaths(ctx interface{}, req interface{}) *MockDataCoord_SaveBinlogPaths_Call { - return &MockDataCoord_SaveBinlogPaths_Call{Call: _e.mock.On("SaveBinlogPaths", ctx, req)} +// - _a0 context.Context +// - _a1 *datapb.SaveBinlogPathsRequest +func (_e *MockDataCoord_Expecter) SaveBinlogPaths(_a0 interface{}, _a1 interface{}) *MockDataCoord_SaveBinlogPaths_Call { + return &MockDataCoord_SaveBinlogPaths_Call{Call: _e.mock.On("SaveBinlogPaths", _a0, _a1)} } -func (_c *MockDataCoord_SaveBinlogPaths_Call) Run(run func(ctx context.Context, req *datapb.SaveBinlogPathsRequest)) *MockDataCoord_SaveBinlogPaths_Call { +func (_c *MockDataCoord_SaveBinlogPaths_Call) Run(run func(_a0 context.Context, _a1 
*datapb.SaveBinlogPathsRequest)) *MockDataCoord_SaveBinlogPaths_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*datapb.SaveBinlogPathsRequest)) }) @@ -2147,17 +2153,17 @@ func (_c *MockDataCoord_SaveBinlogPaths_Call) RunAndReturn(run func(context.Cont return _c } -// SaveImportSegment provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) SaveImportSegment(ctx context.Context, req *datapb.SaveImportSegmentRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// SaveImportSegment provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) SaveImportSegment(_a0 context.Context, _a1 *datapb.SaveImportSegmentRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *datapb.SaveImportSegmentRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *datapb.SaveImportSegmentRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -2165,7 +2171,7 @@ func (_m *MockDataCoord) SaveImportSegment(ctx context.Context, req *datapb.Save } if rf, ok := ret.Get(1).(func(context.Context, *datapb.SaveImportSegmentRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2179,13 +2185,13 @@ type MockDataCoord_SaveImportSegment_Call struct { } // SaveImportSegment is a helper method to define mock.On call -// - ctx context.Context -// - req *datapb.SaveImportSegmentRequest -func (_e *MockDataCoord_Expecter) SaveImportSegment(ctx interface{}, req interface{}) *MockDataCoord_SaveImportSegment_Call { - return &MockDataCoord_SaveImportSegment_Call{Call: _e.mock.On("SaveImportSegment", ctx, req)} +// - _a0 context.Context +// - _a1 *datapb.SaveImportSegmentRequest +func (_e *MockDataCoord_Expecter) 
SaveImportSegment(_a0 interface{}, _a1 interface{}) *MockDataCoord_SaveImportSegment_Call { + return &MockDataCoord_SaveImportSegment_Call{Call: _e.mock.On("SaveImportSegment", _a0, _a1)} } -func (_c *MockDataCoord_SaveImportSegment_Call) Run(run func(ctx context.Context, req *datapb.SaveImportSegmentRequest)) *MockDataCoord_SaveImportSegment_Call { +func (_c *MockDataCoord_SaveImportSegment_Call) Run(run func(_a0 context.Context, _a1 *datapb.SaveImportSegmentRequest)) *MockDataCoord_SaveImportSegment_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*datapb.SaveImportSegmentRequest)) }) @@ -2236,7 +2242,7 @@ func (_c *MockDataCoord_SetAddress_Call) RunAndReturn(run func(string)) *MockDat } // SetDataNodeCreator provides a mock function with given fields: _a0 -func (_m *MockDataCoord) SetDataNodeCreator(_a0 func(context.Context, string, int64) (types.DataNode, error)) { +func (_m *MockDataCoord) SetDataNodeCreator(_a0 func(context.Context, string, int64) (types.DataNodeClient, error)) { _m.Called(_a0) } @@ -2246,14 +2252,14 @@ type MockDataCoord_SetDataNodeCreator_Call struct { } // SetDataNodeCreator is a helper method to define mock.On call -// - _a0 func(context.Context , string , int64)(types.DataNode , error) +// - _a0 func(context.Context , string , int64)(types.DataNodeClient , error) func (_e *MockDataCoord_Expecter) SetDataNodeCreator(_a0 interface{}) *MockDataCoord_SetDataNodeCreator_Call { return &MockDataCoord_SetDataNodeCreator_Call{Call: _e.mock.On("SetDataNodeCreator", _a0)} } -func (_c *MockDataCoord_SetDataNodeCreator_Call) Run(run func(_a0 func(context.Context, string, int64) (types.DataNode, error))) *MockDataCoord_SetDataNodeCreator_Call { +func (_c *MockDataCoord_SetDataNodeCreator_Call) Run(run func(_a0 func(context.Context, string, int64) (types.DataNodeClient, error))) *MockDataCoord_SetDataNodeCreator_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(func(context.Context, string, int64) 
(types.DataNode, error))) + run(args[0].(func(context.Context, string, int64) (types.DataNodeClient, error))) }) return _c } @@ -2263,7 +2269,7 @@ func (_c *MockDataCoord_SetDataNodeCreator_Call) Return() *MockDataCoord_SetData return _c } -func (_c *MockDataCoord_SetDataNodeCreator_Call) RunAndReturn(run func(func(context.Context, string, int64) (types.DataNode, error))) *MockDataCoord_SetDataNodeCreator_Call { +func (_c *MockDataCoord_SetDataNodeCreator_Call) RunAndReturn(run func(func(context.Context, string, int64) (types.DataNodeClient, error))) *MockDataCoord_SetDataNodeCreator_Call { _c.Call.Return(run) return _c } @@ -2302,7 +2308,7 @@ func (_c *MockDataCoord_SetEtcdClient_Call) RunAndReturn(run func(*clientv3.Clie } // SetIndexNodeCreator provides a mock function with given fields: _a0 -func (_m *MockDataCoord) SetIndexNodeCreator(_a0 func(context.Context, string, int64) (types.IndexNode, error)) { +func (_m *MockDataCoord) SetIndexNodeCreator(_a0 func(context.Context, string, int64) (types.IndexNodeClient, error)) { _m.Called(_a0) } @@ -2312,14 +2318,14 @@ type MockDataCoord_SetIndexNodeCreator_Call struct { } // SetIndexNodeCreator is a helper method to define mock.On call -// - _a0 func(context.Context , string , int64)(types.IndexNode , error) +// - _a0 func(context.Context , string , int64)(types.IndexNodeClient , error) func (_e *MockDataCoord_Expecter) SetIndexNodeCreator(_a0 interface{}) *MockDataCoord_SetIndexNodeCreator_Call { return &MockDataCoord_SetIndexNodeCreator_Call{Call: _e.mock.On("SetIndexNodeCreator", _a0)} } -func (_c *MockDataCoord_SetIndexNodeCreator_Call) Run(run func(_a0 func(context.Context, string, int64) (types.IndexNode, error))) *MockDataCoord_SetIndexNodeCreator_Call { +func (_c *MockDataCoord_SetIndexNodeCreator_Call) Run(run func(_a0 func(context.Context, string, int64) (types.IndexNodeClient, error))) *MockDataCoord_SetIndexNodeCreator_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(func(context.Context, 
string, int64) (types.IndexNode, error))) + run(args[0].(func(context.Context, string, int64) (types.IndexNodeClient, error))) }) return _c } @@ -2329,55 +2335,55 @@ func (_c *MockDataCoord_SetIndexNodeCreator_Call) Return() *MockDataCoord_SetInd return _c } -func (_c *MockDataCoord_SetIndexNodeCreator_Call) RunAndReturn(run func(func(context.Context, string, int64) (types.IndexNode, error))) *MockDataCoord_SetIndexNodeCreator_Call { +func (_c *MockDataCoord_SetIndexNodeCreator_Call) RunAndReturn(run func(func(context.Context, string, int64) (types.IndexNodeClient, error))) *MockDataCoord_SetIndexNodeCreator_Call { _c.Call.Return(run) return _c } -// SetRootCoord provides a mock function with given fields: rootCoord -func (_m *MockDataCoord) SetRootCoord(rootCoord types.RootCoord) { +// SetRootCoordClient provides a mock function with given fields: rootCoord +func (_m *MockDataCoord) SetRootCoordClient(rootCoord types.RootCoordClient) { _m.Called(rootCoord) } -// MockDataCoord_SetRootCoord_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetRootCoord' -type MockDataCoord_SetRootCoord_Call struct { +// MockDataCoord_SetRootCoordClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetRootCoordClient' +type MockDataCoord_SetRootCoordClient_Call struct { *mock.Call } -// SetRootCoord is a helper method to define mock.On call -// - rootCoord types.RootCoord -func (_e *MockDataCoord_Expecter) SetRootCoord(rootCoord interface{}) *MockDataCoord_SetRootCoord_Call { - return &MockDataCoord_SetRootCoord_Call{Call: _e.mock.On("SetRootCoord", rootCoord)} +// SetRootCoordClient is a helper method to define mock.On call +// - rootCoord types.RootCoordClient +func (_e *MockDataCoord_Expecter) SetRootCoordClient(rootCoord interface{}) *MockDataCoord_SetRootCoordClient_Call { + return &MockDataCoord_SetRootCoordClient_Call{Call: _e.mock.On("SetRootCoordClient", rootCoord)} } -func (_c 
*MockDataCoord_SetRootCoord_Call) Run(run func(rootCoord types.RootCoord)) *MockDataCoord_SetRootCoord_Call { +func (_c *MockDataCoord_SetRootCoordClient_Call) Run(run func(rootCoord types.RootCoordClient)) *MockDataCoord_SetRootCoordClient_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(types.RootCoord)) + run(args[0].(types.RootCoordClient)) }) return _c } -func (_c *MockDataCoord_SetRootCoord_Call) Return() *MockDataCoord_SetRootCoord_Call { +func (_c *MockDataCoord_SetRootCoordClient_Call) Return() *MockDataCoord_SetRootCoordClient_Call { _c.Call.Return() return _c } -func (_c *MockDataCoord_SetRootCoord_Call) RunAndReturn(run func(types.RootCoord)) *MockDataCoord_SetRootCoord_Call { +func (_c *MockDataCoord_SetRootCoordClient_Call) RunAndReturn(run func(types.RootCoordClient)) *MockDataCoord_SetRootCoordClient_Call { _c.Call.Return(run) return _c } -// SetSegmentState provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) SetSegmentState(ctx context.Context, req *datapb.SetSegmentStateRequest) (*datapb.SetSegmentStateResponse, error) { - ret := _m.Called(ctx, req) +// SetSegmentState provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) SetSegmentState(_a0 context.Context, _a1 *datapb.SetSegmentStateRequest) (*datapb.SetSegmentStateResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *datapb.SetSegmentStateResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *datapb.SetSegmentStateRequest) (*datapb.SetSegmentStateResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *datapb.SetSegmentStateRequest) *datapb.SetSegmentStateResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*datapb.SetSegmentStateResponse) @@ -2385,7 +2391,7 @@ func (_m *MockDataCoord) SetSegmentState(ctx context.Context, req *datapb.SetSeg } if rf, ok := ret.Get(1).(func(context.Context, 
*datapb.SetSegmentStateRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2399,13 +2405,13 @@ type MockDataCoord_SetSegmentState_Call struct { } // SetSegmentState is a helper method to define mock.On call -// - ctx context.Context -// - req *datapb.SetSegmentStateRequest -func (_e *MockDataCoord_Expecter) SetSegmentState(ctx interface{}, req interface{}) *MockDataCoord_SetSegmentState_Call { - return &MockDataCoord_SetSegmentState_Call{Call: _e.mock.On("SetSegmentState", ctx, req)} +// - _a0 context.Context +// - _a1 *datapb.SetSegmentStateRequest +func (_e *MockDataCoord_Expecter) SetSegmentState(_a0 interface{}, _a1 interface{}) *MockDataCoord_SetSegmentState_Call { + return &MockDataCoord_SetSegmentState_Call{Call: _e.mock.On("SetSegmentState", _a0, _a1)} } -func (_c *MockDataCoord_SetSegmentState_Call) Run(run func(ctx context.Context, req *datapb.SetSegmentStateRequest)) *MockDataCoord_SetSegmentState_Call { +func (_c *MockDataCoord_SetSegmentState_Call) Run(run func(_a0 context.Context, _a1 *datapb.SetSegmentStateRequest)) *MockDataCoord_SetSegmentState_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*datapb.SetSegmentStateRequest)) }) @@ -2422,17 +2428,50 @@ func (_c *MockDataCoord_SetSegmentState_Call) RunAndReturn(run func(context.Cont return _c } -// ShowConfigurations provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { - ret := _m.Called(ctx, req) +// SetTiKVClient provides a mock function with given fields: client +func (_m *MockDataCoord) SetTiKVClient(client *txnkv.Client) { + _m.Called(client) +} + +// MockDataCoord_SetTiKVClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetTiKVClient' +type MockDataCoord_SetTiKVClient_Call struct { + *mock.Call +} + +// 
SetTiKVClient is a helper method to define mock.On call +// - client *txnkv.Client +func (_e *MockDataCoord_Expecter) SetTiKVClient(client interface{}) *MockDataCoord_SetTiKVClient_Call { + return &MockDataCoord_SetTiKVClient_Call{Call: _e.mock.On("SetTiKVClient", client)} +} + +func (_c *MockDataCoord_SetTiKVClient_Call) Run(run func(client *txnkv.Client)) *MockDataCoord_SetTiKVClient_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*txnkv.Client)) + }) + return _c +} + +func (_c *MockDataCoord_SetTiKVClient_Call) Return() *MockDataCoord_SetTiKVClient_Call { + _c.Call.Return() + return _c +} + +func (_c *MockDataCoord_SetTiKVClient_Call) RunAndReturn(run func(*txnkv.Client)) *MockDataCoord_SetTiKVClient_Call { + _c.Call.Return(run) + return _c +} + +// ShowConfigurations provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) ShowConfigurations(_a0 context.Context, _a1 *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *internalpb.ShowConfigurationsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ShowConfigurationsRequest) *internalpb.ShowConfigurationsResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*internalpb.ShowConfigurationsResponse) @@ -2440,7 +2479,7 @@ func (_m *MockDataCoord) ShowConfigurations(ctx context.Context, req *internalpb } if rf, ok := ret.Get(1).(func(context.Context, *internalpb.ShowConfigurationsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2454,13 +2493,13 @@ type MockDataCoord_ShowConfigurations_Call struct { } // ShowConfigurations is a helper method to define mock.On call -// - ctx context.Context -// - req 
*internalpb.ShowConfigurationsRequest -func (_e *MockDataCoord_Expecter) ShowConfigurations(ctx interface{}, req interface{}) *MockDataCoord_ShowConfigurations_Call { - return &MockDataCoord_ShowConfigurations_Call{Call: _e.mock.On("ShowConfigurations", ctx, req)} +// - _a0 context.Context +// - _a1 *internalpb.ShowConfigurationsRequest +func (_e *MockDataCoord_Expecter) ShowConfigurations(_a0 interface{}, _a1 interface{}) *MockDataCoord_ShowConfigurations_Call { + return &MockDataCoord_ShowConfigurations_Call{Call: _e.mock.On("ShowConfigurations", _a0, _a1)} } -func (_c *MockDataCoord_ShowConfigurations_Call) Run(run func(ctx context.Context, req *internalpb.ShowConfigurationsRequest)) *MockDataCoord_ShowConfigurations_Call { +func (_c *MockDataCoord_ShowConfigurations_Call) Run(run func(_a0 context.Context, _a1 *internalpb.ShowConfigurationsRequest)) *MockDataCoord_ShowConfigurations_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*internalpb.ShowConfigurationsRequest)) }) @@ -2559,17 +2598,17 @@ func (_c *MockDataCoord_Stop_Call) RunAndReturn(run func() error) *MockDataCoord return _c } -// UnsetIsImportingState provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) UnsetIsImportingState(ctx context.Context, req *datapb.UnsetIsImportingStateRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// UnsetIsImportingState provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) UnsetIsImportingState(_a0 context.Context, _a1 *datapb.UnsetIsImportingStateRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *datapb.UnsetIsImportingStateRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *datapb.UnsetIsImportingStateRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { 
if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -2577,7 +2616,7 @@ func (_m *MockDataCoord) UnsetIsImportingState(ctx context.Context, req *datapb. } if rf, ok := ret.Get(1).(func(context.Context, *datapb.UnsetIsImportingStateRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2591,13 +2630,13 @@ type MockDataCoord_UnsetIsImportingState_Call struct { } // UnsetIsImportingState is a helper method to define mock.On call -// - ctx context.Context -// - req *datapb.UnsetIsImportingStateRequest -func (_e *MockDataCoord_Expecter) UnsetIsImportingState(ctx interface{}, req interface{}) *MockDataCoord_UnsetIsImportingState_Call { - return &MockDataCoord_UnsetIsImportingState_Call{Call: _e.mock.On("UnsetIsImportingState", ctx, req)} +// - _a0 context.Context +// - _a1 *datapb.UnsetIsImportingStateRequest +func (_e *MockDataCoord_Expecter) UnsetIsImportingState(_a0 interface{}, _a1 interface{}) *MockDataCoord_UnsetIsImportingState_Call { + return &MockDataCoord_UnsetIsImportingState_Call{Call: _e.mock.On("UnsetIsImportingState", _a0, _a1)} } -func (_c *MockDataCoord_UnsetIsImportingState_Call) Run(run func(ctx context.Context, req *datapb.UnsetIsImportingStateRequest)) *MockDataCoord_UnsetIsImportingState_Call { +func (_c *MockDataCoord_UnsetIsImportingState_Call) Run(run func(_a0 context.Context, _a1 *datapb.UnsetIsImportingStateRequest)) *MockDataCoord_UnsetIsImportingState_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*datapb.UnsetIsImportingStateRequest)) }) @@ -2614,17 +2653,17 @@ func (_c *MockDataCoord_UnsetIsImportingState_Call) RunAndReturn(run func(contex return _c } -// UpdateChannelCheckpoint provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) UpdateChannelCheckpoint(ctx context.Context, req *datapb.UpdateChannelCheckpointRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// UpdateChannelCheckpoint provides a mock function 
with given fields: _a0, _a1 +func (_m *MockDataCoord) UpdateChannelCheckpoint(_a0 context.Context, _a1 *datapb.UpdateChannelCheckpointRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *datapb.UpdateChannelCheckpointRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *datapb.UpdateChannelCheckpointRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -2632,7 +2671,7 @@ func (_m *MockDataCoord) UpdateChannelCheckpoint(ctx context.Context, req *datap } if rf, ok := ret.Get(1).(func(context.Context, *datapb.UpdateChannelCheckpointRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2646,13 +2685,13 @@ type MockDataCoord_UpdateChannelCheckpoint_Call struct { } // UpdateChannelCheckpoint is a helper method to define mock.On call -// - ctx context.Context -// - req *datapb.UpdateChannelCheckpointRequest -func (_e *MockDataCoord_Expecter) UpdateChannelCheckpoint(ctx interface{}, req interface{}) *MockDataCoord_UpdateChannelCheckpoint_Call { - return &MockDataCoord_UpdateChannelCheckpoint_Call{Call: _e.mock.On("UpdateChannelCheckpoint", ctx, req)} +// - _a0 context.Context +// - _a1 *datapb.UpdateChannelCheckpointRequest +func (_e *MockDataCoord_Expecter) UpdateChannelCheckpoint(_a0 interface{}, _a1 interface{}) *MockDataCoord_UpdateChannelCheckpoint_Call { + return &MockDataCoord_UpdateChannelCheckpoint_Call{Call: _e.mock.On("UpdateChannelCheckpoint", _a0, _a1)} } -func (_c *MockDataCoord_UpdateChannelCheckpoint_Call) Run(run func(ctx context.Context, req *datapb.UpdateChannelCheckpointRequest)) *MockDataCoord_UpdateChannelCheckpoint_Call { +func (_c *MockDataCoord_UpdateChannelCheckpoint_Call) Run(run func(_a0 context.Context, _a1 
*datapb.UpdateChannelCheckpointRequest)) *MockDataCoord_UpdateChannelCheckpoint_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*datapb.UpdateChannelCheckpointRequest)) }) @@ -2669,17 +2708,17 @@ func (_c *MockDataCoord_UpdateChannelCheckpoint_Call) RunAndReturn(run func(cont return _c } -// UpdateSegmentStatistics provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) UpdateSegmentStatistics(ctx context.Context, req *datapb.UpdateSegmentStatisticsRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// UpdateSegmentStatistics provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) UpdateSegmentStatistics(_a0 context.Context, _a1 *datapb.UpdateSegmentStatisticsRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *datapb.UpdateSegmentStatisticsRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *datapb.UpdateSegmentStatisticsRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -2687,7 +2726,7 @@ func (_m *MockDataCoord) UpdateSegmentStatistics(ctx context.Context, req *datap } if rf, ok := ret.Get(1).(func(context.Context, *datapb.UpdateSegmentStatisticsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2701,13 +2740,13 @@ type MockDataCoord_UpdateSegmentStatistics_Call struct { } // UpdateSegmentStatistics is a helper method to define mock.On call -// - ctx context.Context -// - req *datapb.UpdateSegmentStatisticsRequest -func (_e *MockDataCoord_Expecter) UpdateSegmentStatistics(ctx interface{}, req interface{}) *MockDataCoord_UpdateSegmentStatistics_Call { - return &MockDataCoord_UpdateSegmentStatistics_Call{Call: _e.mock.On("UpdateSegmentStatistics", ctx, 
req)} +// - _a0 context.Context +// - _a1 *datapb.UpdateSegmentStatisticsRequest +func (_e *MockDataCoord_Expecter) UpdateSegmentStatistics(_a0 interface{}, _a1 interface{}) *MockDataCoord_UpdateSegmentStatistics_Call { + return &MockDataCoord_UpdateSegmentStatistics_Call{Call: _e.mock.On("UpdateSegmentStatistics", _a0, _a1)} } -func (_c *MockDataCoord_UpdateSegmentStatistics_Call) Run(run func(ctx context.Context, req *datapb.UpdateSegmentStatisticsRequest)) *MockDataCoord_UpdateSegmentStatistics_Call { +func (_c *MockDataCoord_UpdateSegmentStatistics_Call) Run(run func(_a0 context.Context, _a1 *datapb.UpdateSegmentStatisticsRequest)) *MockDataCoord_UpdateSegmentStatistics_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*datapb.UpdateSegmentStatisticsRequest)) }) @@ -2724,17 +2763,17 @@ func (_c *MockDataCoord_UpdateSegmentStatistics_Call) RunAndReturn(run func(cont return _c } -// WatchChannels provides a mock function with given fields: ctx, req -func (_m *MockDataCoord) WatchChannels(ctx context.Context, req *datapb.WatchChannelsRequest) (*datapb.WatchChannelsResponse, error) { - ret := _m.Called(ctx, req) +// WatchChannels provides a mock function with given fields: _a0, _a1 +func (_m *MockDataCoord) WatchChannels(_a0 context.Context, _a1 *datapb.WatchChannelsRequest) (*datapb.WatchChannelsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *datapb.WatchChannelsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *datapb.WatchChannelsRequest) (*datapb.WatchChannelsResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *datapb.WatchChannelsRequest) *datapb.WatchChannelsResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*datapb.WatchChannelsResponse) @@ -2742,7 +2781,7 @@ func (_m *MockDataCoord) WatchChannels(ctx context.Context, req *datapb.WatchCha } if rf, ok := 
ret.Get(1).(func(context.Context, *datapb.WatchChannelsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2756,13 +2795,13 @@ type MockDataCoord_WatchChannels_Call struct { } // WatchChannels is a helper method to define mock.On call -// - ctx context.Context -// - req *datapb.WatchChannelsRequest -func (_e *MockDataCoord_Expecter) WatchChannels(ctx interface{}, req interface{}) *MockDataCoord_WatchChannels_Call { - return &MockDataCoord_WatchChannels_Call{Call: _e.mock.On("WatchChannels", ctx, req)} +// - _a0 context.Context +// - _a1 *datapb.WatchChannelsRequest +func (_e *MockDataCoord_Expecter) WatchChannels(_a0 interface{}, _a1 interface{}) *MockDataCoord_WatchChannels_Call { + return &MockDataCoord_WatchChannels_Call{Call: _e.mock.On("WatchChannels", _a0, _a1)} } -func (_c *MockDataCoord_WatchChannels_Call) Run(run func(ctx context.Context, req *datapb.WatchChannelsRequest)) *MockDataCoord_WatchChannels_Call { +func (_c *MockDataCoord_WatchChannels_Call) Run(run func(_a0 context.Context, _a1 *datapb.WatchChannelsRequest)) *MockDataCoord_WatchChannels_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*datapb.WatchChannelsRequest)) }) diff --git a/internal/mocks/mock_datacoord_client.go b/internal/mocks/mock_datacoord_client.go new file mode 100644 index 0000000000000..ab7b31ab64eec --- /dev/null +++ b/internal/mocks/mock_datacoord_client.go @@ -0,0 +1,3169 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. 
+ +package mocks + +import ( + context "context" + + commonpb "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + + datapb "github.com/milvus-io/milvus/internal/proto/datapb" + + grpc "google.golang.org/grpc" + + indexpb "github.com/milvus-io/milvus/internal/proto/indexpb" + + internalpb "github.com/milvus-io/milvus/internal/proto/internalpb" + + milvuspb "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + + mock "github.com/stretchr/testify/mock" +) + +// MockDataCoordClient is an autogenerated mock type for the DataCoordClient type +type MockDataCoordClient struct { + mock.Mock +} + +type MockDataCoordClient_Expecter struct { + mock *mock.Mock +} + +func (_m *MockDataCoordClient) EXPECT() *MockDataCoordClient_Expecter { + return &MockDataCoordClient_Expecter{mock: &_m.Mock} +} + +// AssignSegmentID provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) AssignSegmentID(ctx context.Context, in *datapb.AssignSegmentIDRequest, opts ...grpc.CallOption) (*datapb.AssignSegmentIDResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *datapb.AssignSegmentIDResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.AssignSegmentIDRequest, ...grpc.CallOption) (*datapb.AssignSegmentIDResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.AssignSegmentIDRequest, ...grpc.CallOption) *datapb.AssignSegmentIDResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*datapb.AssignSegmentIDResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.AssignSegmentIDRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_AssignSegmentID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AssignSegmentID' +type MockDataCoordClient_AssignSegmentID_Call struct { + *mock.Call +} + +// AssignSegmentID is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.AssignSegmentIDRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) AssignSegmentID(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_AssignSegmentID_Call { + return &MockDataCoordClient_AssignSegmentID_Call{Call: _e.mock.On("AssignSegmentID", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_AssignSegmentID_Call) Run(run func(ctx context.Context, in *datapb.AssignSegmentIDRequest, opts ...grpc.CallOption)) *MockDataCoordClient_AssignSegmentID_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.AssignSegmentIDRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_AssignSegmentID_Call) Return(_a0 *datapb.AssignSegmentIDResponse, _a1 error) *MockDataCoordClient_AssignSegmentID_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_AssignSegmentID_Call) RunAndReturn(run func(context.Context, *datapb.AssignSegmentIDRequest, ...grpc.CallOption) (*datapb.AssignSegmentIDResponse, error)) *MockDataCoordClient_AssignSegmentID_Call { + _c.Call.Return(run) + return _c +} + +// BroadcastAlteredCollection provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) BroadcastAlteredCollection(ctx context.Context, in *datapb.AlterCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.AlterCollectionRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.AlterCollectionRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.AlterCollectionRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_BroadcastAlteredCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BroadcastAlteredCollection' +type MockDataCoordClient_BroadcastAlteredCollection_Call struct { + *mock.Call +} + +// BroadcastAlteredCollection is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.AlterCollectionRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) BroadcastAlteredCollection(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_BroadcastAlteredCollection_Call { + return &MockDataCoordClient_BroadcastAlteredCollection_Call{Call: _e.mock.On("BroadcastAlteredCollection", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_BroadcastAlteredCollection_Call) Run(run func(ctx context.Context, in *datapb.AlterCollectionRequest, opts ...grpc.CallOption)) *MockDataCoordClient_BroadcastAlteredCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.AlterCollectionRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_BroadcastAlteredCollection_Call) Return(_a0 *commonpb.Status, _a1 error) *MockDataCoordClient_BroadcastAlteredCollection_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_BroadcastAlteredCollection_Call) RunAndReturn(run func(context.Context, *datapb.AlterCollectionRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockDataCoordClient_BroadcastAlteredCollection_Call { + _c.Call.Return(run) + return _c +} + +// CheckHealth provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) CheckHealth(ctx context.Context, in *milvuspb.CheckHealthRequest, opts ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.CheckHealthResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CheckHealthRequest, ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CheckHealthRequest, ...grpc.CallOption) *milvuspb.CheckHealthResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.CheckHealthResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CheckHealthRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_CheckHealth_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckHealth' +type MockDataCoordClient_CheckHealth_Call struct { + *mock.Call +} + +// CheckHealth is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.CheckHealthRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) CheckHealth(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_CheckHealth_Call { + return &MockDataCoordClient_CheckHealth_Call{Call: _e.mock.On("CheckHealth", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_CheckHealth_Call) Run(run func(ctx context.Context, in *milvuspb.CheckHealthRequest, opts ...grpc.CallOption)) *MockDataCoordClient_CheckHealth_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.CheckHealthRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_CheckHealth_Call) Return(_a0 *milvuspb.CheckHealthResponse, _a1 error) *MockDataCoordClient_CheckHealth_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_CheckHealth_Call) RunAndReturn(run func(context.Context, *milvuspb.CheckHealthRequest, ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error)) *MockDataCoordClient_CheckHealth_Call { + _c.Call.Return(run) + return _c +} + +// Close provides a mock function with given fields: +func (_m *MockDataCoordClient) Close() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockDataCoordClient_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type MockDataCoordClient_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *MockDataCoordClient_Expecter) Close() *MockDataCoordClient_Close_Call { + return &MockDataCoordClient_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *MockDataCoordClient_Close_Call) Run(run func()) *MockDataCoordClient_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockDataCoordClient_Close_Call) Return(_a0 error) *MockDataCoordClient_Close_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockDataCoordClient_Close_Call) RunAndReturn(run func() error) *MockDataCoordClient_Close_Call { + _c.Call.Return(run) + return _c +} + +// CreateIndex provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) CreateIndex(ctx context.Context, in *indexpb.CreateIndexRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) 
+ + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.CreateIndexRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.CreateIndexRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *indexpb.CreateIndexRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_CreateIndex_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateIndex' +type MockDataCoordClient_CreateIndex_Call struct { + *mock.Call +} + +// CreateIndex is a helper method to define mock.On call +// - ctx context.Context +// - in *indexpb.CreateIndexRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) CreateIndex(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_CreateIndex_Call { + return &MockDataCoordClient_CreateIndex_Call{Call: _e.mock.On("CreateIndex", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_CreateIndex_Call) Run(run func(ctx context.Context, in *indexpb.CreateIndexRequest, opts ...grpc.CallOption)) *MockDataCoordClient_CreateIndex_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*indexpb.CreateIndexRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_CreateIndex_Call) Return(_a0 *commonpb.Status, _a1 error) *MockDataCoordClient_CreateIndex_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_CreateIndex_Call) RunAndReturn(run func(context.Context, *indexpb.CreateIndexRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockDataCoordClient_CreateIndex_Call { + _c.Call.Return(run) + return _c +} + +// DescribeIndex provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) DescribeIndex(ctx context.Context, in *indexpb.DescribeIndexRequest, opts ...grpc.CallOption) (*indexpb.DescribeIndexResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *indexpb.DescribeIndexResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.DescribeIndexRequest, ...grpc.CallOption) (*indexpb.DescribeIndexResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.DescribeIndexRequest, ...grpc.CallOption) *indexpb.DescribeIndexResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*indexpb.DescribeIndexResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *indexpb.DescribeIndexRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_DescribeIndex_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DescribeIndex' +type MockDataCoordClient_DescribeIndex_Call struct { + *mock.Call +} + +// DescribeIndex is a helper method to define mock.On call +// - ctx context.Context +// - in *indexpb.DescribeIndexRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) DescribeIndex(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_DescribeIndex_Call { + return &MockDataCoordClient_DescribeIndex_Call{Call: _e.mock.On("DescribeIndex", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_DescribeIndex_Call) Run(run func(ctx context.Context, in *indexpb.DescribeIndexRequest, opts ...grpc.CallOption)) *MockDataCoordClient_DescribeIndex_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*indexpb.DescribeIndexRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_DescribeIndex_Call) Return(_a0 *indexpb.DescribeIndexResponse, _a1 error) *MockDataCoordClient_DescribeIndex_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_DescribeIndex_Call) RunAndReturn(run func(context.Context, *indexpb.DescribeIndexRequest, ...grpc.CallOption) (*indexpb.DescribeIndexResponse, error)) *MockDataCoordClient_DescribeIndex_Call { + _c.Call.Return(run) + return _c +} + +// DropIndex provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) DropIndex(ctx context.Context, in *indexpb.DropIndexRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.DropIndexRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.DropIndexRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *indexpb.DropIndexRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_DropIndex_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DropIndex' +type MockDataCoordClient_DropIndex_Call struct { + *mock.Call +} + +// DropIndex is a helper method to define mock.On call +// - ctx context.Context +// - in *indexpb.DropIndexRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) DropIndex(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_DropIndex_Call { + return &MockDataCoordClient_DropIndex_Call{Call: _e.mock.On("DropIndex", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_DropIndex_Call) Run(run func(ctx context.Context, in *indexpb.DropIndexRequest, opts ...grpc.CallOption)) *MockDataCoordClient_DropIndex_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*indexpb.DropIndexRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockDataCoordClient_DropIndex_Call) Return(_a0 *commonpb.Status, _a1 error) *MockDataCoordClient_DropIndex_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_DropIndex_Call) RunAndReturn(run func(context.Context, *indexpb.DropIndexRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockDataCoordClient_DropIndex_Call { + _c.Call.Return(run) + return _c +} + +// DropVirtualChannel provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) DropVirtualChannel(ctx context.Context, in *datapb.DropVirtualChannelRequest, opts ...grpc.CallOption) (*datapb.DropVirtualChannelResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) 
+ ret := _m.Called(_ca...) + + var r0 *datapb.DropVirtualChannelResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.DropVirtualChannelRequest, ...grpc.CallOption) (*datapb.DropVirtualChannelResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.DropVirtualChannelRequest, ...grpc.CallOption) *datapb.DropVirtualChannelResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*datapb.DropVirtualChannelResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.DropVirtualChannelRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_DropVirtualChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DropVirtualChannel' +type MockDataCoordClient_DropVirtualChannel_Call struct { + *mock.Call +} + +// DropVirtualChannel is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.DropVirtualChannelRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) DropVirtualChannel(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_DropVirtualChannel_Call { + return &MockDataCoordClient_DropVirtualChannel_Call{Call: _e.mock.On("DropVirtualChannel", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_DropVirtualChannel_Call) Run(run func(ctx context.Context, in *datapb.DropVirtualChannelRequest, opts ...grpc.CallOption)) *MockDataCoordClient_DropVirtualChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.DropVirtualChannelRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_DropVirtualChannel_Call) Return(_a0 *datapb.DropVirtualChannelResponse, _a1 error) *MockDataCoordClient_DropVirtualChannel_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_DropVirtualChannel_Call) RunAndReturn(run func(context.Context, *datapb.DropVirtualChannelRequest, ...grpc.CallOption) (*datapb.DropVirtualChannelResponse, error)) *MockDataCoordClient_DropVirtualChannel_Call { + _c.Call.Return(run) + return _c +} + +// Flush provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) Flush(ctx context.Context, in *datapb.FlushRequest, opts ...grpc.CallOption) (*datapb.FlushResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *datapb.FlushResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.FlushRequest, ...grpc.CallOption) (*datapb.FlushResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.FlushRequest, ...grpc.CallOption) *datapb.FlushResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*datapb.FlushResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.FlushRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_Flush_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Flush' +type MockDataCoordClient_Flush_Call struct { + *mock.Call +} + +// Flush is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.FlushRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) Flush(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_Flush_Call { + return &MockDataCoordClient_Flush_Call{Call: _e.mock.On("Flush", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_Flush_Call) Run(run func(ctx context.Context, in *datapb.FlushRequest, opts ...grpc.CallOption)) *MockDataCoordClient_Flush_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.FlushRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockDataCoordClient_Flush_Call) Return(_a0 *datapb.FlushResponse, _a1 error) *MockDataCoordClient_Flush_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_Flush_Call) RunAndReturn(run func(context.Context, *datapb.FlushRequest, ...grpc.CallOption) (*datapb.FlushResponse, error)) *MockDataCoordClient_Flush_Call { + _c.Call.Return(run) + return _c +} + +// GcConfirm provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) GcConfirm(ctx context.Context, in *datapb.GcConfirmRequest, opts ...grpc.CallOption) (*datapb.GcConfirmResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) 
+ + var r0 *datapb.GcConfirmResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.GcConfirmRequest, ...grpc.CallOption) (*datapb.GcConfirmResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.GcConfirmRequest, ...grpc.CallOption) *datapb.GcConfirmResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*datapb.GcConfirmResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.GcConfirmRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_GcConfirm_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GcConfirm' +type MockDataCoordClient_GcConfirm_Call struct { + *mock.Call +} + +// GcConfirm is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.GcConfirmRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) GcConfirm(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_GcConfirm_Call { + return &MockDataCoordClient_GcConfirm_Call{Call: _e.mock.On("GcConfirm", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_GcConfirm_Call) Run(run func(ctx context.Context, in *datapb.GcConfirmRequest, opts ...grpc.CallOption)) *MockDataCoordClient_GcConfirm_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.GcConfirmRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_GcConfirm_Call) Return(_a0 *datapb.GcConfirmResponse, _a1 error) *MockDataCoordClient_GcConfirm_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_GcConfirm_Call) RunAndReturn(run func(context.Context, *datapb.GcConfirmRequest, ...grpc.CallOption) (*datapb.GcConfirmResponse, error)) *MockDataCoordClient_GcConfirm_Call { + _c.Call.Return(run) + return _c +} + +// GetCollectionStatistics provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) GetCollectionStatistics(ctx context.Context, in *datapb.GetCollectionStatisticsRequest, opts ...grpc.CallOption) (*datapb.GetCollectionStatisticsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *datapb.GetCollectionStatisticsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetCollectionStatisticsRequest, ...grpc.CallOption) (*datapb.GetCollectionStatisticsResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetCollectionStatisticsRequest, ...grpc.CallOption) *datapb.GetCollectionStatisticsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*datapb.GetCollectionStatisticsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.GetCollectionStatisticsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_GetCollectionStatistics_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCollectionStatistics' +type MockDataCoordClient_GetCollectionStatistics_Call struct { + *mock.Call +} + +// GetCollectionStatistics is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.GetCollectionStatisticsRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) GetCollectionStatistics(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_GetCollectionStatistics_Call { + return &MockDataCoordClient_GetCollectionStatistics_Call{Call: _e.mock.On("GetCollectionStatistics", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_GetCollectionStatistics_Call) Run(run func(ctx context.Context, in *datapb.GetCollectionStatisticsRequest, opts ...grpc.CallOption)) *MockDataCoordClient_GetCollectionStatistics_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.GetCollectionStatisticsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_GetCollectionStatistics_Call) Return(_a0 *datapb.GetCollectionStatisticsResponse, _a1 error) *MockDataCoordClient_GetCollectionStatistics_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_GetCollectionStatistics_Call) RunAndReturn(run func(context.Context, *datapb.GetCollectionStatisticsRequest, ...grpc.CallOption) (*datapb.GetCollectionStatisticsResponse, error)) *MockDataCoordClient_GetCollectionStatistics_Call { + _c.Call.Return(run) + return _c +} + +// GetCompactionState provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) GetCompactionState(ctx context.Context, in *milvuspb.GetCompactionStateRequest, opts ...grpc.CallOption) (*milvuspb.GetCompactionStateResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.GetCompactionStateResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetCompactionStateRequest, ...grpc.CallOption) (*milvuspb.GetCompactionStateResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetCompactionStateRequest, ...grpc.CallOption) *milvuspb.GetCompactionStateResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.GetCompactionStateResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetCompactionStateRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_GetCompactionState_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCompactionState' +type MockDataCoordClient_GetCompactionState_Call struct { + *mock.Call +} + +// GetCompactionState is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.GetCompactionStateRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) GetCompactionState(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_GetCompactionState_Call { + return &MockDataCoordClient_GetCompactionState_Call{Call: _e.mock.On("GetCompactionState", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_GetCompactionState_Call) Run(run func(ctx context.Context, in *milvuspb.GetCompactionStateRequest, opts ...grpc.CallOption)) *MockDataCoordClient_GetCompactionState_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.GetCompactionStateRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_GetCompactionState_Call) Return(_a0 *milvuspb.GetCompactionStateResponse, _a1 error) *MockDataCoordClient_GetCompactionState_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_GetCompactionState_Call) RunAndReturn(run func(context.Context, *milvuspb.GetCompactionStateRequest, ...grpc.CallOption) (*milvuspb.GetCompactionStateResponse, error)) *MockDataCoordClient_GetCompactionState_Call { + _c.Call.Return(run) + return _c +} + +// GetCompactionStateWithPlans provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) GetCompactionStateWithPlans(ctx context.Context, in *milvuspb.GetCompactionPlansRequest, opts ...grpc.CallOption) (*milvuspb.GetCompactionPlansResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.GetCompactionPlansResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetCompactionPlansRequest, ...grpc.CallOption) (*milvuspb.GetCompactionPlansResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetCompactionPlansRequest, ...grpc.CallOption) *milvuspb.GetCompactionPlansResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.GetCompactionPlansResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetCompactionPlansRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_GetCompactionStateWithPlans_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCompactionStateWithPlans' +type MockDataCoordClient_GetCompactionStateWithPlans_Call struct { + *mock.Call +} + +// GetCompactionStateWithPlans is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.GetCompactionPlansRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) GetCompactionStateWithPlans(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_GetCompactionStateWithPlans_Call { + return &MockDataCoordClient_GetCompactionStateWithPlans_Call{Call: _e.mock.On("GetCompactionStateWithPlans", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_GetCompactionStateWithPlans_Call) Run(run func(ctx context.Context, in *milvuspb.GetCompactionPlansRequest, opts ...grpc.CallOption)) *MockDataCoordClient_GetCompactionStateWithPlans_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.GetCompactionPlansRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_GetCompactionStateWithPlans_Call) Return(_a0 *milvuspb.GetCompactionPlansResponse, _a1 error) *MockDataCoordClient_GetCompactionStateWithPlans_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_GetCompactionStateWithPlans_Call) RunAndReturn(run func(context.Context, *milvuspb.GetCompactionPlansRequest, ...grpc.CallOption) (*milvuspb.GetCompactionPlansResponse, error)) *MockDataCoordClient_GetCompactionStateWithPlans_Call { + _c.Call.Return(run) + return _c +} + +// GetComponentStates provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) GetComponentStates(ctx context.Context, in *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.ComponentStates + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetComponentStatesRequest, ...grpc.CallOption) (*milvuspb.ComponentStates, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetComponentStatesRequest, ...grpc.CallOption) *milvuspb.ComponentStates); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.ComponentStates) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetComponentStatesRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_GetComponentStates_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetComponentStates' +type MockDataCoordClient_GetComponentStates_Call struct { + *mock.Call +} + +// GetComponentStates is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.GetComponentStatesRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) GetComponentStates(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_GetComponentStates_Call { + return &MockDataCoordClient_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_GetComponentStates_Call) Run(run func(ctx context.Context, in *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption)) *MockDataCoordClient_GetComponentStates_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.GetComponentStatesRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_GetComponentStates_Call) Return(_a0 *milvuspb.ComponentStates, _a1 error) *MockDataCoordClient_GetComponentStates_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_GetComponentStates_Call) RunAndReturn(run func(context.Context, *milvuspb.GetComponentStatesRequest, ...grpc.CallOption) (*milvuspb.ComponentStates, error)) *MockDataCoordClient_GetComponentStates_Call { + _c.Call.Return(run) + return _c +} + +// GetFlushAllState provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) GetFlushAllState(ctx context.Context, in *milvuspb.GetFlushAllStateRequest, opts ...grpc.CallOption) (*milvuspb.GetFlushAllStateResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.GetFlushAllStateResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetFlushAllStateRequest, ...grpc.CallOption) (*milvuspb.GetFlushAllStateResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetFlushAllStateRequest, ...grpc.CallOption) *milvuspb.GetFlushAllStateResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.GetFlushAllStateResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetFlushAllStateRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_GetFlushAllState_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetFlushAllState' +type MockDataCoordClient_GetFlushAllState_Call struct { + *mock.Call +} + +// GetFlushAllState is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.GetFlushAllStateRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) GetFlushAllState(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_GetFlushAllState_Call { + return &MockDataCoordClient_GetFlushAllState_Call{Call: _e.mock.On("GetFlushAllState", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_GetFlushAllState_Call) Run(run func(ctx context.Context, in *milvuspb.GetFlushAllStateRequest, opts ...grpc.CallOption)) *MockDataCoordClient_GetFlushAllState_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.GetFlushAllStateRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_GetFlushAllState_Call) Return(_a0 *milvuspb.GetFlushAllStateResponse, _a1 error) *MockDataCoordClient_GetFlushAllState_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_GetFlushAllState_Call) RunAndReturn(run func(context.Context, *milvuspb.GetFlushAllStateRequest, ...grpc.CallOption) (*milvuspb.GetFlushAllStateResponse, error)) *MockDataCoordClient_GetFlushAllState_Call { + _c.Call.Return(run) + return _c +} + +// GetFlushState provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) GetFlushState(ctx context.Context, in *datapb.GetFlushStateRequest, opts ...grpc.CallOption) (*milvuspb.GetFlushStateResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.GetFlushStateResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetFlushStateRequest, ...grpc.CallOption) (*milvuspb.GetFlushStateResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetFlushStateRequest, ...grpc.CallOption) *milvuspb.GetFlushStateResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.GetFlushStateResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.GetFlushStateRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_GetFlushState_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetFlushState' +type MockDataCoordClient_GetFlushState_Call struct { + *mock.Call +} + +// GetFlushState is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.GetFlushStateRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) GetFlushState(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_GetFlushState_Call { + return &MockDataCoordClient_GetFlushState_Call{Call: _e.mock.On("GetFlushState", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_GetFlushState_Call) Run(run func(ctx context.Context, in *datapb.GetFlushStateRequest, opts ...grpc.CallOption)) *MockDataCoordClient_GetFlushState_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.GetFlushStateRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_GetFlushState_Call) Return(_a0 *milvuspb.GetFlushStateResponse, _a1 error) *MockDataCoordClient_GetFlushState_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_GetFlushState_Call) RunAndReturn(run func(context.Context, *datapb.GetFlushStateRequest, ...grpc.CallOption) (*milvuspb.GetFlushStateResponse, error)) *MockDataCoordClient_GetFlushState_Call { + _c.Call.Return(run) + return _c +} + +// GetFlushedSegments provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) GetFlushedSegments(ctx context.Context, in *datapb.GetFlushedSegmentsRequest, opts ...grpc.CallOption) (*datapb.GetFlushedSegmentsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *datapb.GetFlushedSegmentsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetFlushedSegmentsRequest, ...grpc.CallOption) (*datapb.GetFlushedSegmentsResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetFlushedSegmentsRequest, ...grpc.CallOption) *datapb.GetFlushedSegmentsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*datapb.GetFlushedSegmentsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.GetFlushedSegmentsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_GetFlushedSegments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetFlushedSegments' +type MockDataCoordClient_GetFlushedSegments_Call struct { + *mock.Call +} + +// GetFlushedSegments is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.GetFlushedSegmentsRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) GetFlushedSegments(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_GetFlushedSegments_Call { + return &MockDataCoordClient_GetFlushedSegments_Call{Call: _e.mock.On("GetFlushedSegments", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_GetFlushedSegments_Call) Run(run func(ctx context.Context, in *datapb.GetFlushedSegmentsRequest, opts ...grpc.CallOption)) *MockDataCoordClient_GetFlushedSegments_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.GetFlushedSegmentsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_GetFlushedSegments_Call) Return(_a0 *datapb.GetFlushedSegmentsResponse, _a1 error) *MockDataCoordClient_GetFlushedSegments_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_GetFlushedSegments_Call) RunAndReturn(run func(context.Context, *datapb.GetFlushedSegmentsRequest, ...grpc.CallOption) (*datapb.GetFlushedSegmentsResponse, error)) *MockDataCoordClient_GetFlushedSegments_Call { + _c.Call.Return(run) + return _c +} + +// GetIndexBuildProgress provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) GetIndexBuildProgress(ctx context.Context, in *indexpb.GetIndexBuildProgressRequest, opts ...grpc.CallOption) (*indexpb.GetIndexBuildProgressResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *indexpb.GetIndexBuildProgressResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.GetIndexBuildProgressRequest, ...grpc.CallOption) (*indexpb.GetIndexBuildProgressResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.GetIndexBuildProgressRequest, ...grpc.CallOption) *indexpb.GetIndexBuildProgressResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*indexpb.GetIndexBuildProgressResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *indexpb.GetIndexBuildProgressRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_GetIndexBuildProgress_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetIndexBuildProgress' +type MockDataCoordClient_GetIndexBuildProgress_Call struct { + *mock.Call +} + +// GetIndexBuildProgress is a helper method to define mock.On call +// - ctx context.Context +// - in *indexpb.GetIndexBuildProgressRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) GetIndexBuildProgress(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_GetIndexBuildProgress_Call { + return &MockDataCoordClient_GetIndexBuildProgress_Call{Call: _e.mock.On("GetIndexBuildProgress", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_GetIndexBuildProgress_Call) Run(run func(ctx context.Context, in *indexpb.GetIndexBuildProgressRequest, opts ...grpc.CallOption)) *MockDataCoordClient_GetIndexBuildProgress_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*indexpb.GetIndexBuildProgressRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_GetIndexBuildProgress_Call) Return(_a0 *indexpb.GetIndexBuildProgressResponse, _a1 error) *MockDataCoordClient_GetIndexBuildProgress_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_GetIndexBuildProgress_Call) RunAndReturn(run func(context.Context, *indexpb.GetIndexBuildProgressRequest, ...grpc.CallOption) (*indexpb.GetIndexBuildProgressResponse, error)) *MockDataCoordClient_GetIndexBuildProgress_Call { + _c.Call.Return(run) + return _c +} + +// GetIndexInfos provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) GetIndexInfos(ctx context.Context, in *indexpb.GetIndexInfoRequest, opts ...grpc.CallOption) (*indexpb.GetIndexInfoResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *indexpb.GetIndexInfoResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.GetIndexInfoRequest, ...grpc.CallOption) (*indexpb.GetIndexInfoResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.GetIndexInfoRequest, ...grpc.CallOption) *indexpb.GetIndexInfoResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*indexpb.GetIndexInfoResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *indexpb.GetIndexInfoRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_GetIndexInfos_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetIndexInfos' +type MockDataCoordClient_GetIndexInfos_Call struct { + *mock.Call +} + +// GetIndexInfos is a helper method to define mock.On call +// - ctx context.Context +// - in *indexpb.GetIndexInfoRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) GetIndexInfos(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_GetIndexInfos_Call { + return &MockDataCoordClient_GetIndexInfos_Call{Call: _e.mock.On("GetIndexInfos", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_GetIndexInfos_Call) Run(run func(ctx context.Context, in *indexpb.GetIndexInfoRequest, opts ...grpc.CallOption)) *MockDataCoordClient_GetIndexInfos_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*indexpb.GetIndexInfoRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_GetIndexInfos_Call) Return(_a0 *indexpb.GetIndexInfoResponse, _a1 error) *MockDataCoordClient_GetIndexInfos_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_GetIndexInfos_Call) RunAndReturn(run func(context.Context, *indexpb.GetIndexInfoRequest, ...grpc.CallOption) (*indexpb.GetIndexInfoResponse, error)) *MockDataCoordClient_GetIndexInfos_Call { + _c.Call.Return(run) + return _c +} + +// GetIndexState provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) GetIndexState(ctx context.Context, in *indexpb.GetIndexStateRequest, opts ...grpc.CallOption) (*indexpb.GetIndexStateResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *indexpb.GetIndexStateResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.GetIndexStateRequest, ...grpc.CallOption) (*indexpb.GetIndexStateResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.GetIndexStateRequest, ...grpc.CallOption) *indexpb.GetIndexStateResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*indexpb.GetIndexStateResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *indexpb.GetIndexStateRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_GetIndexState_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetIndexState' +type MockDataCoordClient_GetIndexState_Call struct { + *mock.Call +} + +// GetIndexState is a helper method to define mock.On call +// - ctx context.Context +// - in *indexpb.GetIndexStateRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) GetIndexState(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_GetIndexState_Call { + return &MockDataCoordClient_GetIndexState_Call{Call: _e.mock.On("GetIndexState", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_GetIndexState_Call) Run(run func(ctx context.Context, in *indexpb.GetIndexStateRequest, opts ...grpc.CallOption)) *MockDataCoordClient_GetIndexState_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*indexpb.GetIndexStateRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_GetIndexState_Call) Return(_a0 *indexpb.GetIndexStateResponse, _a1 error) *MockDataCoordClient_GetIndexState_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_GetIndexState_Call) RunAndReturn(run func(context.Context, *indexpb.GetIndexStateRequest, ...grpc.CallOption) (*indexpb.GetIndexStateResponse, error)) *MockDataCoordClient_GetIndexState_Call { + _c.Call.Return(run) + return _c +} + +// GetIndexStatistics provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) GetIndexStatistics(ctx context.Context, in *indexpb.GetIndexStatisticsRequest, opts ...grpc.CallOption) (*indexpb.GetIndexStatisticsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *indexpb.GetIndexStatisticsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.GetIndexStatisticsRequest, ...grpc.CallOption) (*indexpb.GetIndexStatisticsResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.GetIndexStatisticsRequest, ...grpc.CallOption) *indexpb.GetIndexStatisticsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*indexpb.GetIndexStatisticsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *indexpb.GetIndexStatisticsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_GetIndexStatistics_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetIndexStatistics' +type MockDataCoordClient_GetIndexStatistics_Call struct { + *mock.Call +} + +// GetIndexStatistics is a helper method to define mock.On call +// - ctx context.Context +// - in *indexpb.GetIndexStatisticsRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) GetIndexStatistics(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_GetIndexStatistics_Call { + return &MockDataCoordClient_GetIndexStatistics_Call{Call: _e.mock.On("GetIndexStatistics", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_GetIndexStatistics_Call) Run(run func(ctx context.Context, in *indexpb.GetIndexStatisticsRequest, opts ...grpc.CallOption)) *MockDataCoordClient_GetIndexStatistics_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*indexpb.GetIndexStatisticsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_GetIndexStatistics_Call) Return(_a0 *indexpb.GetIndexStatisticsResponse, _a1 error) *MockDataCoordClient_GetIndexStatistics_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_GetIndexStatistics_Call) RunAndReturn(run func(context.Context, *indexpb.GetIndexStatisticsRequest, ...grpc.CallOption) (*indexpb.GetIndexStatisticsResponse, error)) *MockDataCoordClient_GetIndexStatistics_Call { + _c.Call.Return(run) + return _c +} + +// GetInsertBinlogPaths provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) GetInsertBinlogPaths(ctx context.Context, in *datapb.GetInsertBinlogPathsRequest, opts ...grpc.CallOption) (*datapb.GetInsertBinlogPathsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *datapb.GetInsertBinlogPathsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetInsertBinlogPathsRequest, ...grpc.CallOption) (*datapb.GetInsertBinlogPathsResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetInsertBinlogPathsRequest, ...grpc.CallOption) *datapb.GetInsertBinlogPathsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*datapb.GetInsertBinlogPathsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.GetInsertBinlogPathsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_GetInsertBinlogPaths_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetInsertBinlogPaths' +type MockDataCoordClient_GetInsertBinlogPaths_Call struct { + *mock.Call +} + +// GetInsertBinlogPaths is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.GetInsertBinlogPathsRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) GetInsertBinlogPaths(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_GetInsertBinlogPaths_Call { + return &MockDataCoordClient_GetInsertBinlogPaths_Call{Call: _e.mock.On("GetInsertBinlogPaths", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_GetInsertBinlogPaths_Call) Run(run func(ctx context.Context, in *datapb.GetInsertBinlogPathsRequest, opts ...grpc.CallOption)) *MockDataCoordClient_GetInsertBinlogPaths_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.GetInsertBinlogPathsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_GetInsertBinlogPaths_Call) Return(_a0 *datapb.GetInsertBinlogPathsResponse, _a1 error) *MockDataCoordClient_GetInsertBinlogPaths_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_GetInsertBinlogPaths_Call) RunAndReturn(run func(context.Context, *datapb.GetInsertBinlogPathsRequest, ...grpc.CallOption) (*datapb.GetInsertBinlogPathsResponse, error)) *MockDataCoordClient_GetInsertBinlogPaths_Call { + _c.Call.Return(run) + return _c +} + +// GetMetrics provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.GetMetricsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest, ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest, ...grpc.CallOption) *milvuspb.GetMetricsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.GetMetricsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetMetricsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_GetMetrics_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetMetrics' +type MockDataCoordClient_GetMetrics_Call struct { + *mock.Call +} + +// GetMetrics is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.GetMetricsRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) GetMetrics(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_GetMetrics_Call { + return &MockDataCoordClient_GetMetrics_Call{Call: _e.mock.On("GetMetrics", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_GetMetrics_Call) Run(run func(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption)) *MockDataCoordClient_GetMetrics_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.GetMetricsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_GetMetrics_Call) Return(_a0 *milvuspb.GetMetricsResponse, _a1 error) *MockDataCoordClient_GetMetrics_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_GetMetrics_Call) RunAndReturn(run func(context.Context, *milvuspb.GetMetricsRequest, ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error)) *MockDataCoordClient_GetMetrics_Call { + _c.Call.Return(run) + return _c +} + +// GetPartitionStatistics provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) GetPartitionStatistics(ctx context.Context, in *datapb.GetPartitionStatisticsRequest, opts ...grpc.CallOption) (*datapb.GetPartitionStatisticsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *datapb.GetPartitionStatisticsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetPartitionStatisticsRequest, ...grpc.CallOption) (*datapb.GetPartitionStatisticsResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetPartitionStatisticsRequest, ...grpc.CallOption) *datapb.GetPartitionStatisticsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*datapb.GetPartitionStatisticsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.GetPartitionStatisticsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_GetPartitionStatistics_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetPartitionStatistics' +type MockDataCoordClient_GetPartitionStatistics_Call struct { + *mock.Call +} + +// GetPartitionStatistics is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.GetPartitionStatisticsRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) GetPartitionStatistics(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_GetPartitionStatistics_Call { + return &MockDataCoordClient_GetPartitionStatistics_Call{Call: _e.mock.On("GetPartitionStatistics", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_GetPartitionStatistics_Call) Run(run func(ctx context.Context, in *datapb.GetPartitionStatisticsRequest, opts ...grpc.CallOption)) *MockDataCoordClient_GetPartitionStatistics_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.GetPartitionStatisticsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_GetPartitionStatistics_Call) Return(_a0 *datapb.GetPartitionStatisticsResponse, _a1 error) *MockDataCoordClient_GetPartitionStatistics_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_GetPartitionStatistics_Call) RunAndReturn(run func(context.Context, *datapb.GetPartitionStatisticsRequest, ...grpc.CallOption) (*datapb.GetPartitionStatisticsResponse, error)) *MockDataCoordClient_GetPartitionStatistics_Call { + _c.Call.Return(run) + return _c +} + +// GetRecoveryInfo provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) GetRecoveryInfo(ctx context.Context, in *datapb.GetRecoveryInfoRequest, opts ...grpc.CallOption) (*datapb.GetRecoveryInfoResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *datapb.GetRecoveryInfoResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetRecoveryInfoRequest, ...grpc.CallOption) (*datapb.GetRecoveryInfoResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetRecoveryInfoRequest, ...grpc.CallOption) *datapb.GetRecoveryInfoResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*datapb.GetRecoveryInfoResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.GetRecoveryInfoRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_GetRecoveryInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetRecoveryInfo' +type MockDataCoordClient_GetRecoveryInfo_Call struct { + *mock.Call +} + +// GetRecoveryInfo is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.GetRecoveryInfoRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) GetRecoveryInfo(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_GetRecoveryInfo_Call { + return &MockDataCoordClient_GetRecoveryInfo_Call{Call: _e.mock.On("GetRecoveryInfo", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_GetRecoveryInfo_Call) Run(run func(ctx context.Context, in *datapb.GetRecoveryInfoRequest, opts ...grpc.CallOption)) *MockDataCoordClient_GetRecoveryInfo_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.GetRecoveryInfoRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_GetRecoveryInfo_Call) Return(_a0 *datapb.GetRecoveryInfoResponse, _a1 error) *MockDataCoordClient_GetRecoveryInfo_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_GetRecoveryInfo_Call) RunAndReturn(run func(context.Context, *datapb.GetRecoveryInfoRequest, ...grpc.CallOption) (*datapb.GetRecoveryInfoResponse, error)) *MockDataCoordClient_GetRecoveryInfo_Call { + _c.Call.Return(run) + return _c +} + +// GetRecoveryInfoV2 provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) GetRecoveryInfoV2(ctx context.Context, in *datapb.GetRecoveryInfoRequestV2, opts ...grpc.CallOption) (*datapb.GetRecoveryInfoResponseV2, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *datapb.GetRecoveryInfoResponseV2 + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetRecoveryInfoRequestV2, ...grpc.CallOption) (*datapb.GetRecoveryInfoResponseV2, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetRecoveryInfoRequestV2, ...grpc.CallOption) *datapb.GetRecoveryInfoResponseV2); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*datapb.GetRecoveryInfoResponseV2) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.GetRecoveryInfoRequestV2, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_GetRecoveryInfoV2_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetRecoveryInfoV2' +type MockDataCoordClient_GetRecoveryInfoV2_Call struct { + *mock.Call +} + +// GetRecoveryInfoV2 is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.GetRecoveryInfoRequestV2 +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) GetRecoveryInfoV2(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_GetRecoveryInfoV2_Call { + return &MockDataCoordClient_GetRecoveryInfoV2_Call{Call: _e.mock.On("GetRecoveryInfoV2", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_GetRecoveryInfoV2_Call) Run(run func(ctx context.Context, in *datapb.GetRecoveryInfoRequestV2, opts ...grpc.CallOption)) *MockDataCoordClient_GetRecoveryInfoV2_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.GetRecoveryInfoRequestV2), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_GetRecoveryInfoV2_Call) Return(_a0 *datapb.GetRecoveryInfoResponseV2, _a1 error) *MockDataCoordClient_GetRecoveryInfoV2_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_GetRecoveryInfoV2_Call) RunAndReturn(run func(context.Context, *datapb.GetRecoveryInfoRequestV2, ...grpc.CallOption) (*datapb.GetRecoveryInfoResponseV2, error)) *MockDataCoordClient_GetRecoveryInfoV2_Call { + _c.Call.Return(run) + return _c +} + +// GetSegmentIndexState provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) GetSegmentIndexState(ctx context.Context, in *indexpb.GetSegmentIndexStateRequest, opts ...grpc.CallOption) (*indexpb.GetSegmentIndexStateResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *indexpb.GetSegmentIndexStateResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.GetSegmentIndexStateRequest, ...grpc.CallOption) (*indexpb.GetSegmentIndexStateResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.GetSegmentIndexStateRequest, ...grpc.CallOption) *indexpb.GetSegmentIndexStateResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*indexpb.GetSegmentIndexStateResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *indexpb.GetSegmentIndexStateRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_GetSegmentIndexState_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSegmentIndexState' +type MockDataCoordClient_GetSegmentIndexState_Call struct { + *mock.Call +} + +// GetSegmentIndexState is a helper method to define mock.On call +// - ctx context.Context +// - in *indexpb.GetSegmentIndexStateRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) GetSegmentIndexState(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_GetSegmentIndexState_Call { + return &MockDataCoordClient_GetSegmentIndexState_Call{Call: _e.mock.On("GetSegmentIndexState", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_GetSegmentIndexState_Call) Run(run func(ctx context.Context, in *indexpb.GetSegmentIndexStateRequest, opts ...grpc.CallOption)) *MockDataCoordClient_GetSegmentIndexState_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*indexpb.GetSegmentIndexStateRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_GetSegmentIndexState_Call) Return(_a0 *indexpb.GetSegmentIndexStateResponse, _a1 error) *MockDataCoordClient_GetSegmentIndexState_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_GetSegmentIndexState_Call) RunAndReturn(run func(context.Context, *indexpb.GetSegmentIndexStateRequest, ...grpc.CallOption) (*indexpb.GetSegmentIndexStateResponse, error)) *MockDataCoordClient_GetSegmentIndexState_Call { + _c.Call.Return(run) + return _c +} + +// GetSegmentInfo provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) GetSegmentInfo(ctx context.Context, in *datapb.GetSegmentInfoRequest, opts ...grpc.CallOption) (*datapb.GetSegmentInfoResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *datapb.GetSegmentInfoResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetSegmentInfoRequest, ...grpc.CallOption) (*datapb.GetSegmentInfoResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetSegmentInfoRequest, ...grpc.CallOption) *datapb.GetSegmentInfoResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*datapb.GetSegmentInfoResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.GetSegmentInfoRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_GetSegmentInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSegmentInfo' +type MockDataCoordClient_GetSegmentInfo_Call struct { + *mock.Call +} + +// GetSegmentInfo is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.GetSegmentInfoRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) GetSegmentInfo(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_GetSegmentInfo_Call { + return &MockDataCoordClient_GetSegmentInfo_Call{Call: _e.mock.On("GetSegmentInfo", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_GetSegmentInfo_Call) Run(run func(ctx context.Context, in *datapb.GetSegmentInfoRequest, opts ...grpc.CallOption)) *MockDataCoordClient_GetSegmentInfo_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.GetSegmentInfoRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_GetSegmentInfo_Call) Return(_a0 *datapb.GetSegmentInfoResponse, _a1 error) *MockDataCoordClient_GetSegmentInfo_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_GetSegmentInfo_Call) RunAndReturn(run func(context.Context, *datapb.GetSegmentInfoRequest, ...grpc.CallOption) (*datapb.GetSegmentInfoResponse, error)) *MockDataCoordClient_GetSegmentInfo_Call { + _c.Call.Return(run) + return _c +} + +// GetSegmentInfoChannel provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) GetSegmentInfoChannel(ctx context.Context, in *datapb.GetSegmentInfoChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.StringResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetSegmentInfoChannelRequest, ...grpc.CallOption) (*milvuspb.StringResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetSegmentInfoChannelRequest, ...grpc.CallOption) *milvuspb.StringResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.StringResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.GetSegmentInfoChannelRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_GetSegmentInfoChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSegmentInfoChannel' +type MockDataCoordClient_GetSegmentInfoChannel_Call struct { + *mock.Call +} + +// GetSegmentInfoChannel is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.GetSegmentInfoChannelRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) GetSegmentInfoChannel(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_GetSegmentInfoChannel_Call { + return &MockDataCoordClient_GetSegmentInfoChannel_Call{Call: _e.mock.On("GetSegmentInfoChannel", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_GetSegmentInfoChannel_Call) Run(run func(ctx context.Context, in *datapb.GetSegmentInfoChannelRequest, opts ...grpc.CallOption)) *MockDataCoordClient_GetSegmentInfoChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.GetSegmentInfoChannelRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_GetSegmentInfoChannel_Call) Return(_a0 *milvuspb.StringResponse, _a1 error) *MockDataCoordClient_GetSegmentInfoChannel_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_GetSegmentInfoChannel_Call) RunAndReturn(run func(context.Context, *datapb.GetSegmentInfoChannelRequest, ...grpc.CallOption) (*milvuspb.StringResponse, error)) *MockDataCoordClient_GetSegmentInfoChannel_Call { + _c.Call.Return(run) + return _c +} + +// GetSegmentStates provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) GetSegmentStates(ctx context.Context, in *datapb.GetSegmentStatesRequest, opts ...grpc.CallOption) (*datapb.GetSegmentStatesResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *datapb.GetSegmentStatesResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetSegmentStatesRequest, ...grpc.CallOption) (*datapb.GetSegmentStatesResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetSegmentStatesRequest, ...grpc.CallOption) *datapb.GetSegmentStatesResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*datapb.GetSegmentStatesResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.GetSegmentStatesRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_GetSegmentStates_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSegmentStates' +type MockDataCoordClient_GetSegmentStates_Call struct { + *mock.Call +} + +// GetSegmentStates is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.GetSegmentStatesRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) GetSegmentStates(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_GetSegmentStates_Call { + return &MockDataCoordClient_GetSegmentStates_Call{Call: _e.mock.On("GetSegmentStates", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_GetSegmentStates_Call) Run(run func(ctx context.Context, in *datapb.GetSegmentStatesRequest, opts ...grpc.CallOption)) *MockDataCoordClient_GetSegmentStates_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.GetSegmentStatesRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_GetSegmentStates_Call) Return(_a0 *datapb.GetSegmentStatesResponse, _a1 error) *MockDataCoordClient_GetSegmentStates_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_GetSegmentStates_Call) RunAndReturn(run func(context.Context, *datapb.GetSegmentStatesRequest, ...grpc.CallOption) (*datapb.GetSegmentStatesResponse, error)) *MockDataCoordClient_GetSegmentStates_Call { + _c.Call.Return(run) + return _c +} + +// GetSegmentsByStates provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) GetSegmentsByStates(ctx context.Context, in *datapb.GetSegmentsByStatesRequest, opts ...grpc.CallOption) (*datapb.GetSegmentsByStatesResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *datapb.GetSegmentsByStatesResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetSegmentsByStatesRequest, ...grpc.CallOption) (*datapb.GetSegmentsByStatesResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.GetSegmentsByStatesRequest, ...grpc.CallOption) *datapb.GetSegmentsByStatesResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*datapb.GetSegmentsByStatesResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.GetSegmentsByStatesRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_GetSegmentsByStates_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSegmentsByStates' +type MockDataCoordClient_GetSegmentsByStates_Call struct { + *mock.Call +} + +// GetSegmentsByStates is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.GetSegmentsByStatesRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) GetSegmentsByStates(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_GetSegmentsByStates_Call { + return &MockDataCoordClient_GetSegmentsByStates_Call{Call: _e.mock.On("GetSegmentsByStates", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_GetSegmentsByStates_Call) Run(run func(ctx context.Context, in *datapb.GetSegmentsByStatesRequest, opts ...grpc.CallOption)) *MockDataCoordClient_GetSegmentsByStates_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.GetSegmentsByStatesRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_GetSegmentsByStates_Call) Return(_a0 *datapb.GetSegmentsByStatesResponse, _a1 error) *MockDataCoordClient_GetSegmentsByStates_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_GetSegmentsByStates_Call) RunAndReturn(run func(context.Context, *datapb.GetSegmentsByStatesRequest, ...grpc.CallOption) (*datapb.GetSegmentsByStatesResponse, error)) *MockDataCoordClient_GetSegmentsByStates_Call { + _c.Call.Return(run) + return _c +} + +// GetStatisticsChannel provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.StringResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetStatisticsChannelRequest, ...grpc.CallOption) (*milvuspb.StringResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetStatisticsChannelRequest, ...grpc.CallOption) *milvuspb.StringResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.StringResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.GetStatisticsChannelRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_GetStatisticsChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetStatisticsChannel' +type MockDataCoordClient_GetStatisticsChannel_Call struct { + *mock.Call +} + +// GetStatisticsChannel is a helper method to define mock.On call +// - ctx context.Context +// - in *internalpb.GetStatisticsChannelRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) GetStatisticsChannel(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_GetStatisticsChannel_Call { + return &MockDataCoordClient_GetStatisticsChannel_Call{Call: _e.mock.On("GetStatisticsChannel", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_GetStatisticsChannel_Call) Run(run func(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption)) *MockDataCoordClient_GetStatisticsChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*internalpb.GetStatisticsChannelRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_GetStatisticsChannel_Call) Return(_a0 *milvuspb.StringResponse, _a1 error) *MockDataCoordClient_GetStatisticsChannel_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_GetStatisticsChannel_Call) RunAndReturn(run func(context.Context, *internalpb.GetStatisticsChannelRequest, ...grpc.CallOption) (*milvuspb.StringResponse, error)) *MockDataCoordClient_GetStatisticsChannel_Call { + _c.Call.Return(run) + return _c +} + +// GetTimeTickChannel provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) GetTimeTickChannel(ctx context.Context, in *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.StringResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetTimeTickChannelRequest, ...grpc.CallOption) (*milvuspb.StringResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetTimeTickChannelRequest, ...grpc.CallOption) *milvuspb.StringResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.StringResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.GetTimeTickChannelRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_GetTimeTickChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetTimeTickChannel' +type MockDataCoordClient_GetTimeTickChannel_Call struct { + *mock.Call +} + +// GetTimeTickChannel is a helper method to define mock.On call +// - ctx context.Context +// - in *internalpb.GetTimeTickChannelRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) GetTimeTickChannel(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_GetTimeTickChannel_Call { + return &MockDataCoordClient_GetTimeTickChannel_Call{Call: _e.mock.On("GetTimeTickChannel", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_GetTimeTickChannel_Call) Run(run func(ctx context.Context, in *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption)) *MockDataCoordClient_GetTimeTickChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*internalpb.GetTimeTickChannelRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_GetTimeTickChannel_Call) Return(_a0 *milvuspb.StringResponse, _a1 error) *MockDataCoordClient_GetTimeTickChannel_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_GetTimeTickChannel_Call) RunAndReturn(run func(context.Context, *internalpb.GetTimeTickChannelRequest, ...grpc.CallOption) (*milvuspb.StringResponse, error)) *MockDataCoordClient_GetTimeTickChannel_Call { + _c.Call.Return(run) + return _c +} + +// Import provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) Import(ctx context.Context, in *datapb.ImportTaskRequest, opts ...grpc.CallOption) (*datapb.ImportTaskResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *datapb.ImportTaskResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.ImportTaskRequest, ...grpc.CallOption) (*datapb.ImportTaskResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.ImportTaskRequest, ...grpc.CallOption) *datapb.ImportTaskResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*datapb.ImportTaskResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.ImportTaskRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_Import_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Import' +type MockDataCoordClient_Import_Call struct { + *mock.Call +} + +// Import is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.ImportTaskRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) Import(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_Import_Call { + return &MockDataCoordClient_Import_Call{Call: _e.mock.On("Import", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_Import_Call) Run(run func(ctx context.Context, in *datapb.ImportTaskRequest, opts ...grpc.CallOption)) *MockDataCoordClient_Import_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.ImportTaskRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockDataCoordClient_Import_Call) Return(_a0 *datapb.ImportTaskResponse, _a1 error) *MockDataCoordClient_Import_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_Import_Call) RunAndReturn(run func(context.Context, *datapb.ImportTaskRequest, ...grpc.CallOption) (*datapb.ImportTaskResponse, error)) *MockDataCoordClient_Import_Call { + _c.Call.Return(run) + return _c +} + +// ManualCompaction provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) ManualCompaction(ctx context.Context, in *milvuspb.ManualCompactionRequest, opts ...grpc.CallOption) (*milvuspb.ManualCompactionResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) 
+ + var r0 *milvuspb.ManualCompactionResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ManualCompactionRequest, ...grpc.CallOption) (*milvuspb.ManualCompactionResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ManualCompactionRequest, ...grpc.CallOption) *milvuspb.ManualCompactionResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.ManualCompactionResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ManualCompactionRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_ManualCompaction_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ManualCompaction' +type MockDataCoordClient_ManualCompaction_Call struct { + *mock.Call +} + +// ManualCompaction is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.ManualCompactionRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) ManualCompaction(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_ManualCompaction_Call { + return &MockDataCoordClient_ManualCompaction_Call{Call: _e.mock.On("ManualCompaction", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_ManualCompaction_Call) Run(run func(ctx context.Context, in *milvuspb.ManualCompactionRequest, opts ...grpc.CallOption)) *MockDataCoordClient_ManualCompaction_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.ManualCompactionRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_ManualCompaction_Call) Return(_a0 *milvuspb.ManualCompactionResponse, _a1 error) *MockDataCoordClient_ManualCompaction_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_ManualCompaction_Call) RunAndReturn(run func(context.Context, *milvuspb.ManualCompactionRequest, ...grpc.CallOption) (*milvuspb.ManualCompactionResponse, error)) *MockDataCoordClient_ManualCompaction_Call { + _c.Call.Return(run) + return _c +} + +// MarkSegmentsDropped provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) MarkSegmentsDropped(ctx context.Context, in *datapb.MarkSegmentsDroppedRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.MarkSegmentsDroppedRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.MarkSegmentsDroppedRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.MarkSegmentsDroppedRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_MarkSegmentsDropped_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'MarkSegmentsDropped' +type MockDataCoordClient_MarkSegmentsDropped_Call struct { + *mock.Call +} + +// MarkSegmentsDropped is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.MarkSegmentsDroppedRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) MarkSegmentsDropped(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_MarkSegmentsDropped_Call { + return &MockDataCoordClient_MarkSegmentsDropped_Call{Call: _e.mock.On("MarkSegmentsDropped", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_MarkSegmentsDropped_Call) Run(run func(ctx context.Context, in *datapb.MarkSegmentsDroppedRequest, opts ...grpc.CallOption)) *MockDataCoordClient_MarkSegmentsDropped_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.MarkSegmentsDroppedRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_MarkSegmentsDropped_Call) Return(_a0 *commonpb.Status, _a1 error) *MockDataCoordClient_MarkSegmentsDropped_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_MarkSegmentsDropped_Call) RunAndReturn(run func(context.Context, *datapb.MarkSegmentsDroppedRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockDataCoordClient_MarkSegmentsDropped_Call { + _c.Call.Return(run) + return _c +} + +// ReportDataNodeTtMsgs provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) ReportDataNodeTtMsgs(ctx context.Context, in *datapb.ReportDataNodeTtMsgsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.ReportDataNodeTtMsgsRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.ReportDataNodeTtMsgsRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.ReportDataNodeTtMsgsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_ReportDataNodeTtMsgs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReportDataNodeTtMsgs' +type MockDataCoordClient_ReportDataNodeTtMsgs_Call struct { + *mock.Call +} + +// ReportDataNodeTtMsgs is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.ReportDataNodeTtMsgsRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) ReportDataNodeTtMsgs(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_ReportDataNodeTtMsgs_Call { + return &MockDataCoordClient_ReportDataNodeTtMsgs_Call{Call: _e.mock.On("ReportDataNodeTtMsgs", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_ReportDataNodeTtMsgs_Call) Run(run func(ctx context.Context, in *datapb.ReportDataNodeTtMsgsRequest, opts ...grpc.CallOption)) *MockDataCoordClient_ReportDataNodeTtMsgs_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.ReportDataNodeTtMsgsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_ReportDataNodeTtMsgs_Call) Return(_a0 *commonpb.Status, _a1 error) *MockDataCoordClient_ReportDataNodeTtMsgs_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_ReportDataNodeTtMsgs_Call) RunAndReturn(run func(context.Context, *datapb.ReportDataNodeTtMsgsRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockDataCoordClient_ReportDataNodeTtMsgs_Call { + _c.Call.Return(run) + return _c +} + +// SaveBinlogPaths provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) SaveBinlogPaths(ctx context.Context, in *datapb.SaveBinlogPathsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.SaveBinlogPathsRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.SaveBinlogPathsRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.SaveBinlogPathsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_SaveBinlogPaths_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SaveBinlogPaths' +type MockDataCoordClient_SaveBinlogPaths_Call struct { + *mock.Call +} + +// SaveBinlogPaths is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.SaveBinlogPathsRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) SaveBinlogPaths(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_SaveBinlogPaths_Call { + return &MockDataCoordClient_SaveBinlogPaths_Call{Call: _e.mock.On("SaveBinlogPaths", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_SaveBinlogPaths_Call) Run(run func(ctx context.Context, in *datapb.SaveBinlogPathsRequest, opts ...grpc.CallOption)) *MockDataCoordClient_SaveBinlogPaths_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.SaveBinlogPathsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_SaveBinlogPaths_Call) Return(_a0 *commonpb.Status, _a1 error) *MockDataCoordClient_SaveBinlogPaths_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_SaveBinlogPaths_Call) RunAndReturn(run func(context.Context, *datapb.SaveBinlogPathsRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockDataCoordClient_SaveBinlogPaths_Call { + _c.Call.Return(run) + return _c +} + +// SaveImportSegment provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) SaveImportSegment(ctx context.Context, in *datapb.SaveImportSegmentRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.SaveImportSegmentRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.SaveImportSegmentRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.SaveImportSegmentRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_SaveImportSegment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SaveImportSegment' +type MockDataCoordClient_SaveImportSegment_Call struct { + *mock.Call +} + +// SaveImportSegment is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.SaveImportSegmentRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) SaveImportSegment(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_SaveImportSegment_Call { + return &MockDataCoordClient_SaveImportSegment_Call{Call: _e.mock.On("SaveImportSegment", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_SaveImportSegment_Call) Run(run func(ctx context.Context, in *datapb.SaveImportSegmentRequest, opts ...grpc.CallOption)) *MockDataCoordClient_SaveImportSegment_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.SaveImportSegmentRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_SaveImportSegment_Call) Return(_a0 *commonpb.Status, _a1 error) *MockDataCoordClient_SaveImportSegment_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_SaveImportSegment_Call) RunAndReturn(run func(context.Context, *datapb.SaveImportSegmentRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockDataCoordClient_SaveImportSegment_Call { + _c.Call.Return(run) + return _c +} + +// SetSegmentState provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) SetSegmentState(ctx context.Context, in *datapb.SetSegmentStateRequest, opts ...grpc.CallOption) (*datapb.SetSegmentStateResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *datapb.SetSegmentStateResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.SetSegmentStateRequest, ...grpc.CallOption) (*datapb.SetSegmentStateResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.SetSegmentStateRequest, ...grpc.CallOption) *datapb.SetSegmentStateResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*datapb.SetSegmentStateResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.SetSegmentStateRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_SetSegmentState_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetSegmentState' +type MockDataCoordClient_SetSegmentState_Call struct { + *mock.Call +} + +// SetSegmentState is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.SetSegmentStateRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) SetSegmentState(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_SetSegmentState_Call { + return &MockDataCoordClient_SetSegmentState_Call{Call: _e.mock.On("SetSegmentState", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_SetSegmentState_Call) Run(run func(ctx context.Context, in *datapb.SetSegmentStateRequest, opts ...grpc.CallOption)) *MockDataCoordClient_SetSegmentState_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.SetSegmentStateRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_SetSegmentState_Call) Return(_a0 *datapb.SetSegmentStateResponse, _a1 error) *MockDataCoordClient_SetSegmentState_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_SetSegmentState_Call) RunAndReturn(run func(context.Context, *datapb.SetSegmentStateRequest, ...grpc.CallOption) (*datapb.SetSegmentStateResponse, error)) *MockDataCoordClient_SetSegmentState_Call { + _c.Call.Return(run) + return _c +} + +// ShowConfigurations provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) ShowConfigurations(ctx context.Context, in *internalpb.ShowConfigurationsRequest, opts ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *internalpb.ShowConfigurationsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ShowConfigurationsRequest, ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ShowConfigurationsRequest, ...grpc.CallOption) *internalpb.ShowConfigurationsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*internalpb.ShowConfigurationsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.ShowConfigurationsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_ShowConfigurations_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ShowConfigurations' +type MockDataCoordClient_ShowConfigurations_Call struct { + *mock.Call +} + +// ShowConfigurations is a helper method to define mock.On call +// - ctx context.Context +// - in *internalpb.ShowConfigurationsRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) ShowConfigurations(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_ShowConfigurations_Call { + return &MockDataCoordClient_ShowConfigurations_Call{Call: _e.mock.On("ShowConfigurations", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_ShowConfigurations_Call) Run(run func(ctx context.Context, in *internalpb.ShowConfigurationsRequest, opts ...grpc.CallOption)) *MockDataCoordClient_ShowConfigurations_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*internalpb.ShowConfigurationsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_ShowConfigurations_Call) Return(_a0 *internalpb.ShowConfigurationsResponse, _a1 error) *MockDataCoordClient_ShowConfigurations_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_ShowConfigurations_Call) RunAndReturn(run func(context.Context, *internalpb.ShowConfigurationsRequest, ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error)) *MockDataCoordClient_ShowConfigurations_Call { + _c.Call.Return(run) + return _c +} + +// UnsetIsImportingState provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) UnsetIsImportingState(ctx context.Context, in *datapb.UnsetIsImportingStateRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.UnsetIsImportingStateRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.UnsetIsImportingStateRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.UnsetIsImportingStateRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_UnsetIsImportingState_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UnsetIsImportingState' +type MockDataCoordClient_UnsetIsImportingState_Call struct { + *mock.Call +} + +// UnsetIsImportingState is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.UnsetIsImportingStateRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) UnsetIsImportingState(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_UnsetIsImportingState_Call { + return &MockDataCoordClient_UnsetIsImportingState_Call{Call: _e.mock.On("UnsetIsImportingState", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_UnsetIsImportingState_Call) Run(run func(ctx context.Context, in *datapb.UnsetIsImportingStateRequest, opts ...grpc.CallOption)) *MockDataCoordClient_UnsetIsImportingState_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.UnsetIsImportingStateRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_UnsetIsImportingState_Call) Return(_a0 *commonpb.Status, _a1 error) *MockDataCoordClient_UnsetIsImportingState_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_UnsetIsImportingState_Call) RunAndReturn(run func(context.Context, *datapb.UnsetIsImportingStateRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockDataCoordClient_UnsetIsImportingState_Call { + _c.Call.Return(run) + return _c +} + +// UpdateChannelCheckpoint provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) UpdateChannelCheckpoint(ctx context.Context, in *datapb.UpdateChannelCheckpointRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.UpdateChannelCheckpointRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.UpdateChannelCheckpointRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.UpdateChannelCheckpointRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_UpdateChannelCheckpoint_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateChannelCheckpoint' +type MockDataCoordClient_UpdateChannelCheckpoint_Call struct { + *mock.Call +} + +// UpdateChannelCheckpoint is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.UpdateChannelCheckpointRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) UpdateChannelCheckpoint(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_UpdateChannelCheckpoint_Call { + return &MockDataCoordClient_UpdateChannelCheckpoint_Call{Call: _e.mock.On("UpdateChannelCheckpoint", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_UpdateChannelCheckpoint_Call) Run(run func(ctx context.Context, in *datapb.UpdateChannelCheckpointRequest, opts ...grpc.CallOption)) *MockDataCoordClient_UpdateChannelCheckpoint_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.UpdateChannelCheckpointRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_UpdateChannelCheckpoint_Call) Return(_a0 *commonpb.Status, _a1 error) *MockDataCoordClient_UpdateChannelCheckpoint_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_UpdateChannelCheckpoint_Call) RunAndReturn(run func(context.Context, *datapb.UpdateChannelCheckpointRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockDataCoordClient_UpdateChannelCheckpoint_Call { + _c.Call.Return(run) + return _c +} + +// UpdateSegmentStatistics provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) UpdateSegmentStatistics(ctx context.Context, in *datapb.UpdateSegmentStatisticsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.UpdateSegmentStatisticsRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.UpdateSegmentStatisticsRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.UpdateSegmentStatisticsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_UpdateSegmentStatistics_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateSegmentStatistics' +type MockDataCoordClient_UpdateSegmentStatistics_Call struct { + *mock.Call +} + +// UpdateSegmentStatistics is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.UpdateSegmentStatisticsRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) UpdateSegmentStatistics(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_UpdateSegmentStatistics_Call { + return &MockDataCoordClient_UpdateSegmentStatistics_Call{Call: _e.mock.On("UpdateSegmentStatistics", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_UpdateSegmentStatistics_Call) Run(run func(ctx context.Context, in *datapb.UpdateSegmentStatisticsRequest, opts ...grpc.CallOption)) *MockDataCoordClient_UpdateSegmentStatistics_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.UpdateSegmentStatisticsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataCoordClient_UpdateSegmentStatistics_Call) Return(_a0 *commonpb.Status, _a1 error) *MockDataCoordClient_UpdateSegmentStatistics_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_UpdateSegmentStatistics_Call) RunAndReturn(run func(context.Context, *datapb.UpdateSegmentStatisticsRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockDataCoordClient_UpdateSegmentStatistics_Call { + _c.Call.Return(run) + return _c +} + +// WatchChannels provides a mock function with given fields: ctx, in, opts +func (_m *MockDataCoordClient) WatchChannels(ctx context.Context, in *datapb.WatchChannelsRequest, opts ...grpc.CallOption) (*datapb.WatchChannelsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *datapb.WatchChannelsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.WatchChannelsRequest, ...grpc.CallOption) (*datapb.WatchChannelsResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.WatchChannelsRequest, ...grpc.CallOption) *datapb.WatchChannelsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*datapb.WatchChannelsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.WatchChannelsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataCoordClient_WatchChannels_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WatchChannels' +type MockDataCoordClient_WatchChannels_Call struct { + *mock.Call +} + +// WatchChannels is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.WatchChannelsRequest +// - opts ...grpc.CallOption +func (_e *MockDataCoordClient_Expecter) WatchChannels(ctx interface{}, in interface{}, opts ...interface{}) *MockDataCoordClient_WatchChannels_Call { + return &MockDataCoordClient_WatchChannels_Call{Call: _e.mock.On("WatchChannels", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataCoordClient_WatchChannels_Call) Run(run func(ctx context.Context, in *datapb.WatchChannelsRequest, opts ...grpc.CallOption)) *MockDataCoordClient_WatchChannels_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.WatchChannelsRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockDataCoordClient_WatchChannels_Call) Return(_a0 *datapb.WatchChannelsResponse, _a1 error) *MockDataCoordClient_WatchChannels_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataCoordClient_WatchChannels_Call) RunAndReturn(run func(context.Context, *datapb.WatchChannelsRequest, ...grpc.CallOption) (*datapb.WatchChannelsResponse, error)) *MockDataCoordClient_WatchChannels_Call { + _c.Call.Return(run) + return _c +} + +// NewMockDataCoordClient creates a new instance of MockDataCoordClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. 
+func NewMockDataCoordClient(t interface { + mock.TestingT + Cleanup(func()) +}) *MockDataCoordClient { + mock := &MockDataCoordClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/mocks/mock_datanode.go b/internal/mocks/mock_datanode.go index 2a02b4e48a5b7..64d298a1378da 100644 --- a/internal/mocks/mock_datanode.go +++ b/internal/mocks/mock_datanode.go @@ -32,17 +32,17 @@ func (_m *MockDataNode) EXPECT() *MockDataNode_Expecter { return &MockDataNode_Expecter{mock: &_m.Mock} } -// AddImportSegment provides a mock function with given fields: ctx, req -func (_m *MockDataNode) AddImportSegment(ctx context.Context, req *datapb.AddImportSegmentRequest) (*datapb.AddImportSegmentResponse, error) { - ret := _m.Called(ctx, req) +// AddImportSegment provides a mock function with given fields: _a0, _a1 +func (_m *MockDataNode) AddImportSegment(_a0 context.Context, _a1 *datapb.AddImportSegmentRequest) (*datapb.AddImportSegmentResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *datapb.AddImportSegmentResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *datapb.AddImportSegmentRequest) (*datapb.AddImportSegmentResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *datapb.AddImportSegmentRequest) *datapb.AddImportSegmentResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*datapb.AddImportSegmentResponse) @@ -50,7 +50,7 @@ func (_m *MockDataNode) AddImportSegment(ctx context.Context, req *datapb.AddImp } if rf, ok := ret.Get(1).(func(context.Context, *datapb.AddImportSegmentRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -64,13 +64,13 @@ type MockDataNode_AddImportSegment_Call struct { } // AddImportSegment is a helper method to define mock.On call -// - ctx context.Context -// - req *datapb.AddImportSegmentRequest -func (_e 
*MockDataNode_Expecter) AddImportSegment(ctx interface{}, req interface{}) *MockDataNode_AddImportSegment_Call { - return &MockDataNode_AddImportSegment_Call{Call: _e.mock.On("AddImportSegment", ctx, req)} +// - _a0 context.Context +// - _a1 *datapb.AddImportSegmentRequest +func (_e *MockDataNode_Expecter) AddImportSegment(_a0 interface{}, _a1 interface{}) *MockDataNode_AddImportSegment_Call { + return &MockDataNode_AddImportSegment_Call{Call: _e.mock.On("AddImportSegment", _a0, _a1)} } -func (_c *MockDataNode_AddImportSegment_Call) Run(run func(ctx context.Context, req *datapb.AddImportSegmentRequest)) *MockDataNode_AddImportSegment_Call { +func (_c *MockDataNode_AddImportSegment_Call) Run(run func(_a0 context.Context, _a1 *datapb.AddImportSegmentRequest)) *MockDataNode_AddImportSegment_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*datapb.AddImportSegmentRequest)) }) @@ -87,17 +87,72 @@ func (_c *MockDataNode_AddImportSegment_Call) RunAndReturn(run func(context.Cont return _c } -// Compaction provides a mock function with given fields: ctx, req -func (_m *MockDataNode) Compaction(ctx context.Context, req *datapb.CompactionPlan) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// CheckChannelOperationProgress provides a mock function with given fields: _a0, _a1 +func (_m *MockDataNode) CheckChannelOperationProgress(_a0 context.Context, _a1 *datapb.ChannelWatchInfo) (*datapb.ChannelOperationProgressResponse, error) { + ret := _m.Called(_a0, _a1) + + var r0 *datapb.ChannelOperationProgressResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.ChannelWatchInfo) (*datapb.ChannelOperationProgressResponse, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.ChannelWatchInfo) *datapb.ChannelOperationProgressResponse); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*datapb.ChannelOperationProgressResponse) + } + } + + if 
rf, ok := ret.Get(1).(func(context.Context, *datapb.ChannelWatchInfo) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataNode_CheckChannelOperationProgress_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckChannelOperationProgress' +type MockDataNode_CheckChannelOperationProgress_Call struct { + *mock.Call +} + +// CheckChannelOperationProgress is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *datapb.ChannelWatchInfo +func (_e *MockDataNode_Expecter) CheckChannelOperationProgress(_a0 interface{}, _a1 interface{}) *MockDataNode_CheckChannelOperationProgress_Call { + return &MockDataNode_CheckChannelOperationProgress_Call{Call: _e.mock.On("CheckChannelOperationProgress", _a0, _a1)} +} + +func (_c *MockDataNode_CheckChannelOperationProgress_Call) Run(run func(_a0 context.Context, _a1 *datapb.ChannelWatchInfo)) *MockDataNode_CheckChannelOperationProgress_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*datapb.ChannelWatchInfo)) + }) + return _c +} + +func (_c *MockDataNode_CheckChannelOperationProgress_Call) Return(_a0 *datapb.ChannelOperationProgressResponse, _a1 error) *MockDataNode_CheckChannelOperationProgress_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataNode_CheckChannelOperationProgress_Call) RunAndReturn(run func(context.Context, *datapb.ChannelWatchInfo) (*datapb.ChannelOperationProgressResponse, error)) *MockDataNode_CheckChannelOperationProgress_Call { + _c.Call.Return(run) + return _c +} + +// Compaction provides a mock function with given fields: _a0, _a1 +func (_m *MockDataNode) Compaction(_a0 context.Context, _a1 *datapb.CompactionPlan) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *datapb.CompactionPlan) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return 
rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *datapb.CompactionPlan) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -105,7 +160,7 @@ func (_m *MockDataNode) Compaction(ctx context.Context, req *datapb.CompactionPl } if rf, ok := ret.Get(1).(func(context.Context, *datapb.CompactionPlan) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -119,13 +174,13 @@ type MockDataNode_Compaction_Call struct { } // Compaction is a helper method to define mock.On call -// - ctx context.Context -// - req *datapb.CompactionPlan -func (_e *MockDataNode_Expecter) Compaction(ctx interface{}, req interface{}) *MockDataNode_Compaction_Call { - return &MockDataNode_Compaction_Call{Call: _e.mock.On("Compaction", ctx, req)} +// - _a0 context.Context +// - _a1 *datapb.CompactionPlan +func (_e *MockDataNode_Expecter) Compaction(_a0 interface{}, _a1 interface{}) *MockDataNode_Compaction_Call { + return &MockDataNode_Compaction_Call{Call: _e.mock.On("Compaction", _a0, _a1)} } -func (_c *MockDataNode_Compaction_Call) Run(run func(ctx context.Context, req *datapb.CompactionPlan)) *MockDataNode_Compaction_Call { +func (_c *MockDataNode_Compaction_Call) Run(run func(_a0 context.Context, _a1 *datapb.CompactionPlan)) *MockDataNode_Compaction_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*datapb.CompactionPlan)) }) @@ -142,17 +197,72 @@ func (_c *MockDataNode_Compaction_Call) RunAndReturn(run func(context.Context, * return _c } -// FlushSegments provides a mock function with given fields: ctx, req -func (_m *MockDataNode) FlushSegments(ctx context.Context, req *datapb.FlushSegmentsRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// FlushChannels provides a mock function with given fields: _a0, _a1 +func (_m *MockDataNode) FlushChannels(_a0 context.Context, _a1 *datapb.FlushChannelsRequest) 
(*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.FlushChannelsRequest) (*commonpb.Status, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.FlushChannelsRequest) *commonpb.Status); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.FlushChannelsRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataNode_FlushChannels_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'FlushChannels' +type MockDataNode_FlushChannels_Call struct { + *mock.Call +} + +// FlushChannels is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *datapb.FlushChannelsRequest +func (_e *MockDataNode_Expecter) FlushChannels(_a0 interface{}, _a1 interface{}) *MockDataNode_FlushChannels_Call { + return &MockDataNode_FlushChannels_Call{Call: _e.mock.On("FlushChannels", _a0, _a1)} +} + +func (_c *MockDataNode_FlushChannels_Call) Run(run func(_a0 context.Context, _a1 *datapb.FlushChannelsRequest)) *MockDataNode_FlushChannels_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*datapb.FlushChannelsRequest)) + }) + return _c +} + +func (_c *MockDataNode_FlushChannels_Call) Return(_a0 *commonpb.Status, _a1 error) *MockDataNode_FlushChannels_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataNode_FlushChannels_Call) RunAndReturn(run func(context.Context, *datapb.FlushChannelsRequest) (*commonpb.Status, error)) *MockDataNode_FlushChannels_Call { + _c.Call.Return(run) + return _c +} + +// FlushSegments provides a mock function with given fields: _a0, _a1 +func (_m *MockDataNode) FlushSegments(_a0 context.Context, _a1 *datapb.FlushSegmentsRequest) (*commonpb.Status, 
error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *datapb.FlushSegmentsRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *datapb.FlushSegmentsRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -160,7 +270,7 @@ func (_m *MockDataNode) FlushSegments(ctx context.Context, req *datapb.FlushSegm } if rf, ok := ret.Get(1).(func(context.Context, *datapb.FlushSegmentsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -174,13 +284,13 @@ type MockDataNode_FlushSegments_Call struct { } // FlushSegments is a helper method to define mock.On call -// - ctx context.Context -// - req *datapb.FlushSegmentsRequest -func (_e *MockDataNode_Expecter) FlushSegments(ctx interface{}, req interface{}) *MockDataNode_FlushSegments_Call { - return &MockDataNode_FlushSegments_Call{Call: _e.mock.On("FlushSegments", ctx, req)} +// - _a0 context.Context +// - _a1 *datapb.FlushSegmentsRequest +func (_e *MockDataNode_Expecter) FlushSegments(_a0 interface{}, _a1 interface{}) *MockDataNode_FlushSegments_Call { + return &MockDataNode_FlushSegments_Call{Call: _e.mock.On("FlushSegments", _a0, _a1)} } -func (_c *MockDataNode_FlushSegments_Call) Run(run func(ctx context.Context, req *datapb.FlushSegmentsRequest)) *MockDataNode_FlushSegments_Call { +func (_c *MockDataNode_FlushSegments_Call) Run(run func(_a0 context.Context, _a1 *datapb.FlushSegmentsRequest)) *MockDataNode_FlushSegments_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*datapb.FlushSegmentsRequest)) }) @@ -238,17 +348,17 @@ func (_c *MockDataNode_GetAddress_Call) RunAndReturn(run func() string) *MockDat return _c } -// GetCompactionState provides a mock function with given fields: ctx, req -func (_m *MockDataNode) 
GetCompactionState(ctx context.Context, req *datapb.CompactionStateRequest) (*datapb.CompactionStateResponse, error) { - ret := _m.Called(ctx, req) +// GetCompactionState provides a mock function with given fields: _a0, _a1 +func (_m *MockDataNode) GetCompactionState(_a0 context.Context, _a1 *datapb.CompactionStateRequest) (*datapb.CompactionStateResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *datapb.CompactionStateResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *datapb.CompactionStateRequest) (*datapb.CompactionStateResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *datapb.CompactionStateRequest) *datapb.CompactionStateResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*datapb.CompactionStateResponse) @@ -256,7 +366,7 @@ func (_m *MockDataNode) GetCompactionState(ctx context.Context, req *datapb.Comp } if rf, ok := ret.Get(1).(func(context.Context, *datapb.CompactionStateRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -270,13 +380,13 @@ type MockDataNode_GetCompactionState_Call struct { } // GetCompactionState is a helper method to define mock.On call -// - ctx context.Context -// - req *datapb.CompactionStateRequest -func (_e *MockDataNode_Expecter) GetCompactionState(ctx interface{}, req interface{}) *MockDataNode_GetCompactionState_Call { - return &MockDataNode_GetCompactionState_Call{Call: _e.mock.On("GetCompactionState", ctx, req)} +// - _a0 context.Context +// - _a1 *datapb.CompactionStateRequest +func (_e *MockDataNode_Expecter) GetCompactionState(_a0 interface{}, _a1 interface{}) *MockDataNode_GetCompactionState_Call { + return &MockDataNode_GetCompactionState_Call{Call: _e.mock.On("GetCompactionState", _a0, _a1)} } -func (_c *MockDataNode_GetCompactionState_Call) Run(run func(ctx context.Context, req *datapb.CompactionStateRequest)) 
*MockDataNode_GetCompactionState_Call { +func (_c *MockDataNode_GetCompactionState_Call) Run(run func(_a0 context.Context, _a1 *datapb.CompactionStateRequest)) *MockDataNode_GetCompactionState_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*datapb.CompactionStateRequest)) }) @@ -293,25 +403,25 @@ func (_c *MockDataNode_GetCompactionState_Call) RunAndReturn(run func(context.Co return _c } -// GetComponentStates provides a mock function with given fields: ctx -func (_m *MockDataNode) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { - ret := _m.Called(ctx) +// GetComponentStates provides a mock function with given fields: _a0, _a1 +func (_m *MockDataNode) GetComponentStates(_a0 context.Context, _a1 *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.ComponentStates var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (*milvuspb.ComponentStates, error)); ok { - return rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error)); ok { + return rf(_a0, _a1) } - if rf, ok := ret.Get(0).(func(context.Context) *milvuspb.ComponentStates); ok { - r0 = rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetComponentStatesRequest) *milvuspb.ComponentStates); ok { + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.ComponentStates) } } - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetComponentStatesRequest) error); ok { + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -325,14 +435,15 @@ type MockDataNode_GetComponentStates_Call struct { } // GetComponentStates is a helper method to define mock.On call -// - ctx context.Context -func (_e *MockDataNode_Expecter) GetComponentStates(ctx interface{}) *MockDataNode_GetComponentStates_Call { - 
return &MockDataNode_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", ctx)} +// - _a0 context.Context +// - _a1 *milvuspb.GetComponentStatesRequest +func (_e *MockDataNode_Expecter) GetComponentStates(_a0 interface{}, _a1 interface{}) *MockDataNode_GetComponentStates_Call { + return &MockDataNode_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", _a0, _a1)} } -func (_c *MockDataNode_GetComponentStates_Call) Run(run func(ctx context.Context)) *MockDataNode_GetComponentStates_Call { +func (_c *MockDataNode_GetComponentStates_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetComponentStatesRequest)) *MockDataNode_GetComponentStates_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context)) + run(args[0].(context.Context), args[1].(*milvuspb.GetComponentStatesRequest)) }) return _c } @@ -342,22 +453,22 @@ func (_c *MockDataNode_GetComponentStates_Call) Return(_a0 *milvuspb.ComponentSt return _c } -func (_c *MockDataNode_GetComponentStates_Call) RunAndReturn(run func(context.Context) (*milvuspb.ComponentStates, error)) *MockDataNode_GetComponentStates_Call { +func (_c *MockDataNode_GetComponentStates_Call) RunAndReturn(run func(context.Context, *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error)) *MockDataNode_GetComponentStates_Call { _c.Call.Return(run) return _c } -// GetMetrics provides a mock function with given fields: ctx, req -func (_m *MockDataNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - ret := _m.Called(ctx, req) +// GetMetrics provides a mock function with given fields: _a0, _a1 +func (_m *MockDataNode) GetMetrics(_a0 context.Context, _a1 *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.GetMetricsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error)); ok { - return rf(ctx, 
req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest) *milvuspb.GetMetricsResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.GetMetricsResponse) @@ -365,7 +476,7 @@ func (_m *MockDataNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetrics } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetMetricsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -379,13 +490,13 @@ type MockDataNode_GetMetrics_Call struct { } // GetMetrics is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.GetMetricsRequest -func (_e *MockDataNode_Expecter) GetMetrics(ctx interface{}, req interface{}) *MockDataNode_GetMetrics_Call { - return &MockDataNode_GetMetrics_Call{Call: _e.mock.On("GetMetrics", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.GetMetricsRequest +func (_e *MockDataNode_Expecter) GetMetrics(_a0 interface{}, _a1 interface{}) *MockDataNode_GetMetrics_Call { + return &MockDataNode_GetMetrics_Call{Call: _e.mock.On("GetMetrics", _a0, _a1)} } -func (_c *MockDataNode_GetMetrics_Call) Run(run func(ctx context.Context, req *milvuspb.GetMetricsRequest)) *MockDataNode_GetMetrics_Call { +func (_c *MockDataNode_GetMetrics_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetMetricsRequest)) *MockDataNode_GetMetrics_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.GetMetricsRequest)) }) @@ -443,25 +554,25 @@ func (_c *MockDataNode_GetStateCode_Call) RunAndReturn(run func() commonpb.State return _c } -// GetStatisticsChannel provides a mock function with given fields: ctx -func (_m *MockDataNode) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret := _m.Called(ctx) +// GetStatisticsChannel provides a mock function with given fields: _a0, _a1 +func (_m *MockDataNode) GetStatisticsChannel(_a0 
context.Context, _a1 *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.StringResponse var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (*milvuspb.StringResponse, error)); ok { - return rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error)); ok { + return rf(_a0, _a1) } - if rf, ok := ret.Get(0).(func(context.Context) *milvuspb.StringResponse); ok { - r0 = rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetStatisticsChannelRequest) *milvuspb.StringResponse); ok { + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.StringResponse) } } - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.GetStatisticsChannelRequest) error); ok { + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -475,14 +586,15 @@ type MockDataNode_GetStatisticsChannel_Call struct { } // GetStatisticsChannel is a helper method to define mock.On call -// - ctx context.Context -func (_e *MockDataNode_Expecter) GetStatisticsChannel(ctx interface{}) *MockDataNode_GetStatisticsChannel_Call { - return &MockDataNode_GetStatisticsChannel_Call{Call: _e.mock.On("GetStatisticsChannel", ctx)} +// - _a0 context.Context +// - _a1 *internalpb.GetStatisticsChannelRequest +func (_e *MockDataNode_Expecter) GetStatisticsChannel(_a0 interface{}, _a1 interface{}) *MockDataNode_GetStatisticsChannel_Call { + return &MockDataNode_GetStatisticsChannel_Call{Call: _e.mock.On("GetStatisticsChannel", _a0, _a1)} } -func (_c *MockDataNode_GetStatisticsChannel_Call) Run(run func(ctx context.Context)) *MockDataNode_GetStatisticsChannel_Call { +func (_c *MockDataNode_GetStatisticsChannel_Call) Run(run func(_a0 context.Context, _a1 *internalpb.GetStatisticsChannelRequest)) *MockDataNode_GetStatisticsChannel_Call { _c.Call.Run(func(args 
mock.Arguments) { - run(args[0].(context.Context)) + run(args[0].(context.Context), args[1].(*internalpb.GetStatisticsChannelRequest)) }) return _c } @@ -492,22 +604,22 @@ func (_c *MockDataNode_GetStatisticsChannel_Call) Return(_a0 *milvuspb.StringRes return _c } -func (_c *MockDataNode_GetStatisticsChannel_Call) RunAndReturn(run func(context.Context) (*milvuspb.StringResponse, error)) *MockDataNode_GetStatisticsChannel_Call { +func (_c *MockDataNode_GetStatisticsChannel_Call) RunAndReturn(run func(context.Context, *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error)) *MockDataNode_GetStatisticsChannel_Call { _c.Call.Return(run) return _c } -// Import provides a mock function with given fields: ctx, req -func (_m *MockDataNode) Import(ctx context.Context, req *datapb.ImportTaskRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// Import provides a mock function with given fields: _a0, _a1 +func (_m *MockDataNode) Import(_a0 context.Context, _a1 *datapb.ImportTaskRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *datapb.ImportTaskRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *datapb.ImportTaskRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -515,7 +627,7 @@ func (_m *MockDataNode) Import(ctx context.Context, req *datapb.ImportTaskReques } if rf, ok := ret.Get(1).(func(context.Context, *datapb.ImportTaskRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -529,13 +641,13 @@ type MockDataNode_Import_Call struct { } // Import is a helper method to define mock.On call -// - ctx context.Context -// - req *datapb.ImportTaskRequest -func (_e *MockDataNode_Expecter) Import(ctx interface{}, req interface{}) 
*MockDataNode_Import_Call { - return &MockDataNode_Import_Call{Call: _e.mock.On("Import", ctx, req)} +// - _a0 context.Context +// - _a1 *datapb.ImportTaskRequest +func (_e *MockDataNode_Expecter) Import(_a0 interface{}, _a1 interface{}) *MockDataNode_Import_Call { + return &MockDataNode_Import_Call{Call: _e.mock.On("Import", _a0, _a1)} } -func (_c *MockDataNode_Import_Call) Run(run func(ctx context.Context, req *datapb.ImportTaskRequest)) *MockDataNode_Import_Call { +func (_c *MockDataNode_Import_Call) Run(run func(_a0 context.Context, _a1 *datapb.ImportTaskRequest)) *MockDataNode_Import_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*datapb.ImportTaskRequest)) }) @@ -593,6 +705,61 @@ func (_c *MockDataNode_Init_Call) RunAndReturn(run func() error) *MockDataNode_I return _c } +// NotifyChannelOperation provides a mock function with given fields: _a0, _a1 +func (_m *MockDataNode) NotifyChannelOperation(_a0 context.Context, _a1 *datapb.ChannelOperationsRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.ChannelOperationsRequest) (*commonpb.Status, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.ChannelOperationsRequest) *commonpb.Status); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.ChannelOperationsRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataNode_NotifyChannelOperation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'NotifyChannelOperation' +type MockDataNode_NotifyChannelOperation_Call struct { + *mock.Call +} + +// NotifyChannelOperation is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 
*datapb.ChannelOperationsRequest +func (_e *MockDataNode_Expecter) NotifyChannelOperation(_a0 interface{}, _a1 interface{}) *MockDataNode_NotifyChannelOperation_Call { + return &MockDataNode_NotifyChannelOperation_Call{Call: _e.mock.On("NotifyChannelOperation", _a0, _a1)} +} + +func (_c *MockDataNode_NotifyChannelOperation_Call) Run(run func(_a0 context.Context, _a1 *datapb.ChannelOperationsRequest)) *MockDataNode_NotifyChannelOperation_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*datapb.ChannelOperationsRequest)) + }) + return _c +} + +func (_c *MockDataNode_NotifyChannelOperation_Call) Return(_a0 *commonpb.Status, _a1 error) *MockDataNode_NotifyChannelOperation_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataNode_NotifyChannelOperation_Call) RunAndReturn(run func(context.Context, *datapb.ChannelOperationsRequest) (*commonpb.Status, error)) *MockDataNode_NotifyChannelOperation_Call { + _c.Call.Return(run) + return _c +} + // Register provides a mock function with given fields: func (_m *MockDataNode) Register() error { ret := _m.Called() @@ -634,17 +801,17 @@ func (_c *MockDataNode_Register_Call) RunAndReturn(run func() error) *MockDataNo return _c } -// ResendSegmentStats provides a mock function with given fields: ctx, req -func (_m *MockDataNode) ResendSegmentStats(ctx context.Context, req *datapb.ResendSegmentStatsRequest) (*datapb.ResendSegmentStatsResponse, error) { - ret := _m.Called(ctx, req) +// ResendSegmentStats provides a mock function with given fields: _a0, _a1 +func (_m *MockDataNode) ResendSegmentStats(_a0 context.Context, _a1 *datapb.ResendSegmentStatsRequest) (*datapb.ResendSegmentStatsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *datapb.ResendSegmentStatsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *datapb.ResendSegmentStatsRequest) (*datapb.ResendSegmentStatsResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := 
ret.Get(0).(func(context.Context, *datapb.ResendSegmentStatsRequest) *datapb.ResendSegmentStatsResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*datapb.ResendSegmentStatsResponse) @@ -652,7 +819,7 @@ func (_m *MockDataNode) ResendSegmentStats(ctx context.Context, req *datapb.Rese } if rf, ok := ret.Get(1).(func(context.Context, *datapb.ResendSegmentStatsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -666,13 +833,13 @@ type MockDataNode_ResendSegmentStats_Call struct { } // ResendSegmentStats is a helper method to define mock.On call -// - ctx context.Context -// - req *datapb.ResendSegmentStatsRequest -func (_e *MockDataNode_Expecter) ResendSegmentStats(ctx interface{}, req interface{}) *MockDataNode_ResendSegmentStats_Call { - return &MockDataNode_ResendSegmentStats_Call{Call: _e.mock.On("ResendSegmentStats", ctx, req)} +// - _a0 context.Context +// - _a1 *datapb.ResendSegmentStatsRequest +func (_e *MockDataNode_Expecter) ResendSegmentStats(_a0 interface{}, _a1 interface{}) *MockDataNode_ResendSegmentStats_Call { + return &MockDataNode_ResendSegmentStats_Call{Call: _e.mock.On("ResendSegmentStats", _a0, _a1)} } -func (_c *MockDataNode_ResendSegmentStats_Call) Run(run func(ctx context.Context, req *datapb.ResendSegmentStatsRequest)) *MockDataNode_ResendSegmentStats_Call { +func (_c *MockDataNode_ResendSegmentStats_Call) Run(run func(_a0 context.Context, _a1 *datapb.ResendSegmentStatsRequest)) *MockDataNode_ResendSegmentStats_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*datapb.ResendSegmentStatsRequest)) }) @@ -722,12 +889,12 @@ func (_c *MockDataNode_SetAddress_Call) RunAndReturn(run func(string)) *MockData return _c } -// SetDataCoord provides a mock function with given fields: dataCoord -func (_m *MockDataNode) SetDataCoord(dataCoord types.DataCoord) error { +// SetDataCoordClient provides a mock function with given 
fields: dataCoord +func (_m *MockDataNode) SetDataCoordClient(dataCoord types.DataCoordClient) error { ret := _m.Called(dataCoord) var r0 error - if rf, ok := ret.Get(0).(func(types.DataCoord) error); ok { + if rf, ok := ret.Get(0).(func(types.DataCoordClient) error); ok { r0 = rf(dataCoord) } else { r0 = ret.Error(0) @@ -736,30 +903,30 @@ func (_m *MockDataNode) SetDataCoord(dataCoord types.DataCoord) error { return r0 } -// MockDataNode_SetDataCoord_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetDataCoord' -type MockDataNode_SetDataCoord_Call struct { +// MockDataNode_SetDataCoordClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetDataCoordClient' +type MockDataNode_SetDataCoordClient_Call struct { *mock.Call } -// SetDataCoord is a helper method to define mock.On call -// - dataCoord types.DataCoord -func (_e *MockDataNode_Expecter) SetDataCoord(dataCoord interface{}) *MockDataNode_SetDataCoord_Call { - return &MockDataNode_SetDataCoord_Call{Call: _e.mock.On("SetDataCoord", dataCoord)} +// SetDataCoordClient is a helper method to define mock.On call +// - dataCoord types.DataCoordClient +func (_e *MockDataNode_Expecter) SetDataCoordClient(dataCoord interface{}) *MockDataNode_SetDataCoordClient_Call { + return &MockDataNode_SetDataCoordClient_Call{Call: _e.mock.On("SetDataCoordClient", dataCoord)} } -func (_c *MockDataNode_SetDataCoord_Call) Run(run func(dataCoord types.DataCoord)) *MockDataNode_SetDataCoord_Call { +func (_c *MockDataNode_SetDataCoordClient_Call) Run(run func(dataCoord types.DataCoordClient)) *MockDataNode_SetDataCoordClient_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(types.DataCoord)) + run(args[0].(types.DataCoordClient)) }) return _c } -func (_c *MockDataNode_SetDataCoord_Call) Return(_a0 error) *MockDataNode_SetDataCoord_Call { +func (_c *MockDataNode_SetDataCoordClient_Call) Return(_a0 error) 
*MockDataNode_SetDataCoordClient_Call { _c.Call.Return(_a0) return _c } -func (_c *MockDataNode_SetDataCoord_Call) RunAndReturn(run func(types.DataCoord) error) *MockDataNode_SetDataCoord_Call { +func (_c *MockDataNode_SetDataCoordClient_Call) RunAndReturn(run func(types.DataCoordClient) error) *MockDataNode_SetDataCoordClient_Call { _c.Call.Return(run) return _c } @@ -797,12 +964,12 @@ func (_c *MockDataNode_SetEtcdClient_Call) RunAndReturn(run func(*clientv3.Clien return _c } -// SetRootCoord provides a mock function with given fields: rootCoord -func (_m *MockDataNode) SetRootCoord(rootCoord types.RootCoord) error { +// SetRootCoordClient provides a mock function with given fields: rootCoord +func (_m *MockDataNode) SetRootCoordClient(rootCoord types.RootCoordClient) error { ret := _m.Called(rootCoord) var r0 error - if rf, ok := ret.Get(0).(func(types.RootCoord) error); ok { + if rf, ok := ret.Get(0).(func(types.RootCoordClient) error); ok { r0 = rf(rootCoord) } else { r0 = ret.Error(0) @@ -811,45 +978,45 @@ func (_m *MockDataNode) SetRootCoord(rootCoord types.RootCoord) error { return r0 } -// MockDataNode_SetRootCoord_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetRootCoord' -type MockDataNode_SetRootCoord_Call struct { +// MockDataNode_SetRootCoordClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetRootCoordClient' +type MockDataNode_SetRootCoordClient_Call struct { *mock.Call } -// SetRootCoord is a helper method to define mock.On call -// - rootCoord types.RootCoord -func (_e *MockDataNode_Expecter) SetRootCoord(rootCoord interface{}) *MockDataNode_SetRootCoord_Call { - return &MockDataNode_SetRootCoord_Call{Call: _e.mock.On("SetRootCoord", rootCoord)} +// SetRootCoordClient is a helper method to define mock.On call +// - rootCoord types.RootCoordClient +func (_e *MockDataNode_Expecter) SetRootCoordClient(rootCoord interface{}) 
*MockDataNode_SetRootCoordClient_Call { + return &MockDataNode_SetRootCoordClient_Call{Call: _e.mock.On("SetRootCoordClient", rootCoord)} } -func (_c *MockDataNode_SetRootCoord_Call) Run(run func(rootCoord types.RootCoord)) *MockDataNode_SetRootCoord_Call { +func (_c *MockDataNode_SetRootCoordClient_Call) Run(run func(rootCoord types.RootCoordClient)) *MockDataNode_SetRootCoordClient_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(types.RootCoord)) + run(args[0].(types.RootCoordClient)) }) return _c } -func (_c *MockDataNode_SetRootCoord_Call) Return(_a0 error) *MockDataNode_SetRootCoord_Call { +func (_c *MockDataNode_SetRootCoordClient_Call) Return(_a0 error) *MockDataNode_SetRootCoordClient_Call { _c.Call.Return(_a0) return _c } -func (_c *MockDataNode_SetRootCoord_Call) RunAndReturn(run func(types.RootCoord) error) *MockDataNode_SetRootCoord_Call { +func (_c *MockDataNode_SetRootCoordClient_Call) RunAndReturn(run func(types.RootCoordClient) error) *MockDataNode_SetRootCoordClient_Call { _c.Call.Return(run) return _c } -// ShowConfigurations provides a mock function with given fields: ctx, req -func (_m *MockDataNode) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { - ret := _m.Called(ctx, req) +// ShowConfigurations provides a mock function with given fields: _a0, _a1 +func (_m *MockDataNode) ShowConfigurations(_a0 context.Context, _a1 *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *internalpb.ShowConfigurationsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ShowConfigurationsRequest) *internalpb.ShowConfigurationsResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } 
else { if ret.Get(0) != nil { r0 = ret.Get(0).(*internalpb.ShowConfigurationsResponse) @@ -857,7 +1024,7 @@ func (_m *MockDataNode) ShowConfigurations(ctx context.Context, req *internalpb. } if rf, ok := ret.Get(1).(func(context.Context, *internalpb.ShowConfigurationsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -871,13 +1038,13 @@ type MockDataNode_ShowConfigurations_Call struct { } // ShowConfigurations is a helper method to define mock.On call -// - ctx context.Context -// - req *internalpb.ShowConfigurationsRequest -func (_e *MockDataNode_Expecter) ShowConfigurations(ctx interface{}, req interface{}) *MockDataNode_ShowConfigurations_Call { - return &MockDataNode_ShowConfigurations_Call{Call: _e.mock.On("ShowConfigurations", ctx, req)} +// - _a0 context.Context +// - _a1 *internalpb.ShowConfigurationsRequest +func (_e *MockDataNode_Expecter) ShowConfigurations(_a0 interface{}, _a1 interface{}) *MockDataNode_ShowConfigurations_Call { + return &MockDataNode_ShowConfigurations_Call{Call: _e.mock.On("ShowConfigurations", _a0, _a1)} } -func (_c *MockDataNode_ShowConfigurations_Call) Run(run func(ctx context.Context, req *internalpb.ShowConfigurationsRequest)) *MockDataNode_ShowConfigurations_Call { +func (_c *MockDataNode_ShowConfigurations_Call) Run(run func(_a0 context.Context, _a1 *internalpb.ShowConfigurationsRequest)) *MockDataNode_ShowConfigurations_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*internalpb.ShowConfigurationsRequest)) }) @@ -976,17 +1143,17 @@ func (_c *MockDataNode_Stop_Call) RunAndReturn(run func() error) *MockDataNode_S return _c } -// SyncSegments provides a mock function with given fields: ctx, req -func (_m *MockDataNode) SyncSegments(ctx context.Context, req *datapb.SyncSegmentsRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// SyncSegments provides a mock function with given fields: _a0, _a1 +func (_m *MockDataNode) 
SyncSegments(_a0 context.Context, _a1 *datapb.SyncSegmentsRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *datapb.SyncSegmentsRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *datapb.SyncSegmentsRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -994,7 +1161,7 @@ func (_m *MockDataNode) SyncSegments(ctx context.Context, req *datapb.SyncSegmen } if rf, ok := ret.Get(1).(func(context.Context, *datapb.SyncSegmentsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1008,13 +1175,13 @@ type MockDataNode_SyncSegments_Call struct { } // SyncSegments is a helper method to define mock.On call -// - ctx context.Context -// - req *datapb.SyncSegmentsRequest -func (_e *MockDataNode_Expecter) SyncSegments(ctx interface{}, req interface{}) *MockDataNode_SyncSegments_Call { - return &MockDataNode_SyncSegments_Call{Call: _e.mock.On("SyncSegments", ctx, req)} +// - _a0 context.Context +// - _a1 *datapb.SyncSegmentsRequest +func (_e *MockDataNode_Expecter) SyncSegments(_a0 interface{}, _a1 interface{}) *MockDataNode_SyncSegments_Call { + return &MockDataNode_SyncSegments_Call{Call: _e.mock.On("SyncSegments", _a0, _a1)} } -func (_c *MockDataNode_SyncSegments_Call) Run(run func(ctx context.Context, req *datapb.SyncSegmentsRequest)) *MockDataNode_SyncSegments_Call { +func (_c *MockDataNode_SyncSegments_Call) Run(run func(_a0 context.Context, _a1 *datapb.SyncSegmentsRequest)) *MockDataNode_SyncSegments_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*datapb.SyncSegmentsRequest)) }) @@ -1064,17 +1231,17 @@ func (_c *MockDataNode_UpdateStateCode_Call) RunAndReturn(run func(commonpb.Stat return _c } -// WatchDmChannels provides a mock 
function with given fields: ctx, req -func (_m *MockDataNode) WatchDmChannels(ctx context.Context, req *datapb.WatchDmChannelsRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// WatchDmChannels provides a mock function with given fields: _a0, _a1 +func (_m *MockDataNode) WatchDmChannels(_a0 context.Context, _a1 *datapb.WatchDmChannelsRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *datapb.WatchDmChannelsRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *datapb.WatchDmChannelsRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -1082,7 +1249,7 @@ func (_m *MockDataNode) WatchDmChannels(ctx context.Context, req *datapb.WatchDm } if rf, ok := ret.Get(1).(func(context.Context, *datapb.WatchDmChannelsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1096,13 +1263,13 @@ type MockDataNode_WatchDmChannels_Call struct { } // WatchDmChannels is a helper method to define mock.On call -// - ctx context.Context -// - req *datapb.WatchDmChannelsRequest -func (_e *MockDataNode_Expecter) WatchDmChannels(ctx interface{}, req interface{}) *MockDataNode_WatchDmChannels_Call { - return &MockDataNode_WatchDmChannels_Call{Call: _e.mock.On("WatchDmChannels", ctx, req)} +// - _a0 context.Context +// - _a1 *datapb.WatchDmChannelsRequest +func (_e *MockDataNode_Expecter) WatchDmChannels(_a0 interface{}, _a1 interface{}) *MockDataNode_WatchDmChannels_Call { + return &MockDataNode_WatchDmChannels_Call{Call: _e.mock.On("WatchDmChannels", _a0, _a1)} } -func (_c *MockDataNode_WatchDmChannels_Call) Run(run func(ctx context.Context, req *datapb.WatchDmChannelsRequest)) *MockDataNode_WatchDmChannels_Call { +func (_c *MockDataNode_WatchDmChannels_Call) Run(run 
func(_a0 context.Context, _a1 *datapb.WatchDmChannelsRequest)) *MockDataNode_WatchDmChannels_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*datapb.WatchDmChannelsRequest)) }) diff --git a/internal/mocks/mock_datanode_client.go b/internal/mocks/mock_datanode_client.go new file mode 100644 index 0000000000000..76131bacf781b --- /dev/null +++ b/internal/mocks/mock_datanode_client.go @@ -0,0 +1,1137 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package mocks + +import ( + context "context" + + commonpb "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + + datapb "github.com/milvus-io/milvus/internal/proto/datapb" + + grpc "google.golang.org/grpc" + + internalpb "github.com/milvus-io/milvus/internal/proto/internalpb" + + milvuspb "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + + mock "github.com/stretchr/testify/mock" +) + +// MockDataNodeClient is an autogenerated mock type for the DataNodeClient type +type MockDataNodeClient struct { + mock.Mock +} + +type MockDataNodeClient_Expecter struct { + mock *mock.Mock +} + +func (_m *MockDataNodeClient) EXPECT() *MockDataNodeClient_Expecter { + return &MockDataNodeClient_Expecter{mock: &_m.Mock} +} + +// AddImportSegment provides a mock function with given fields: ctx, in, opts +func (_m *MockDataNodeClient) AddImportSegment(ctx context.Context, in *datapb.AddImportSegmentRequest, opts ...grpc.CallOption) (*datapb.AddImportSegmentResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *datapb.AddImportSegmentResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.AddImportSegmentRequest, ...grpc.CallOption) (*datapb.AddImportSegmentResponse, error)); ok { + return rf(ctx, in, opts...) 
+ } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.AddImportSegmentRequest, ...grpc.CallOption) *datapb.AddImportSegmentResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*datapb.AddImportSegmentResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.AddImportSegmentRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataNodeClient_AddImportSegment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddImportSegment' +type MockDataNodeClient_AddImportSegment_Call struct { + *mock.Call +} + +// AddImportSegment is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.AddImportSegmentRequest +// - opts ...grpc.CallOption +func (_e *MockDataNodeClient_Expecter) AddImportSegment(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_AddImportSegment_Call { + return &MockDataNodeClient_AddImportSegment_Call{Call: _e.mock.On("AddImportSegment", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataNodeClient_AddImportSegment_Call) Run(run func(ctx context.Context, in *datapb.AddImportSegmentRequest, opts ...grpc.CallOption)) *MockDataNodeClient_AddImportSegment_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.AddImportSegmentRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataNodeClient_AddImportSegment_Call) Return(_a0 *datapb.AddImportSegmentResponse, _a1 error) *MockDataNodeClient_AddImportSegment_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataNodeClient_AddImportSegment_Call) RunAndReturn(run func(context.Context, *datapb.AddImportSegmentRequest, ...grpc.CallOption) (*datapb.AddImportSegmentResponse, error)) *MockDataNodeClient_AddImportSegment_Call { + _c.Call.Return(run) + return _c +} + +// CheckChannelOperationProgress provides a mock function with given fields: ctx, in, opts +func (_m *MockDataNodeClient) CheckChannelOperationProgress(ctx context.Context, in *datapb.ChannelWatchInfo, opts ...grpc.CallOption) (*datapb.ChannelOperationProgressResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *datapb.ChannelOperationProgressResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.ChannelWatchInfo, ...grpc.CallOption) (*datapb.ChannelOperationProgressResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.ChannelWatchInfo, ...grpc.CallOption) *datapb.ChannelOperationProgressResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*datapb.ChannelOperationProgressResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.ChannelWatchInfo, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataNodeClient_CheckChannelOperationProgress_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckChannelOperationProgress' +type MockDataNodeClient_CheckChannelOperationProgress_Call struct { + *mock.Call +} + +// CheckChannelOperationProgress is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.ChannelWatchInfo +// - opts ...grpc.CallOption +func (_e *MockDataNodeClient_Expecter) CheckChannelOperationProgress(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_CheckChannelOperationProgress_Call { + return &MockDataNodeClient_CheckChannelOperationProgress_Call{Call: _e.mock.On("CheckChannelOperationProgress", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataNodeClient_CheckChannelOperationProgress_Call) Run(run func(ctx context.Context, in *datapb.ChannelWatchInfo, opts ...grpc.CallOption)) *MockDataNodeClient_CheckChannelOperationProgress_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.ChannelWatchInfo), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataNodeClient_CheckChannelOperationProgress_Call) Return(_a0 *datapb.ChannelOperationProgressResponse, _a1 error) *MockDataNodeClient_CheckChannelOperationProgress_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataNodeClient_CheckChannelOperationProgress_Call) RunAndReturn(run func(context.Context, *datapb.ChannelWatchInfo, ...grpc.CallOption) (*datapb.ChannelOperationProgressResponse, error)) *MockDataNodeClient_CheckChannelOperationProgress_Call { + _c.Call.Return(run) + return _c +} + +// Close provides a mock function with given fields: +func (_m *MockDataNodeClient) Close() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockDataNodeClient_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type MockDataNodeClient_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *MockDataNodeClient_Expecter) Close() *MockDataNodeClient_Close_Call { + return &MockDataNodeClient_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *MockDataNodeClient_Close_Call) Run(run func()) *MockDataNodeClient_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockDataNodeClient_Close_Call) Return(_a0 error) *MockDataNodeClient_Close_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockDataNodeClient_Close_Call) RunAndReturn(run func() error) *MockDataNodeClient_Close_Call { + _c.Call.Return(run) + return _c +} + +// Compaction provides a mock function with given fields: ctx, in, opts +func (_m *MockDataNodeClient) Compaction(ctx context.Context, in *datapb.CompactionPlan, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + 
_ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.CompactionPlan, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.CompactionPlan, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.CompactionPlan, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataNodeClient_Compaction_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Compaction' +type MockDataNodeClient_Compaction_Call struct { + *mock.Call +} + +// Compaction is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.CompactionPlan +// - opts ...grpc.CallOption +func (_e *MockDataNodeClient_Expecter) Compaction(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_Compaction_Call { + return &MockDataNodeClient_Compaction_Call{Call: _e.mock.On("Compaction", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataNodeClient_Compaction_Call) Run(run func(ctx context.Context, in *datapb.CompactionPlan, opts ...grpc.CallOption)) *MockDataNodeClient_Compaction_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.CompactionPlan), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataNodeClient_Compaction_Call) Return(_a0 *commonpb.Status, _a1 error) *MockDataNodeClient_Compaction_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataNodeClient_Compaction_Call) RunAndReturn(run func(context.Context, *datapb.CompactionPlan, ...grpc.CallOption) (*commonpb.Status, error)) *MockDataNodeClient_Compaction_Call { + _c.Call.Return(run) + return _c +} + +// FlushChannels provides a mock function with given fields: ctx, in, opts +func (_m *MockDataNodeClient) FlushChannels(ctx context.Context, in *datapb.FlushChannelsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.FlushChannelsRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.FlushChannelsRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.FlushChannelsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataNodeClient_FlushChannels_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'FlushChannels' +type MockDataNodeClient_FlushChannels_Call struct { + *mock.Call +} + +// FlushChannels is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.FlushChannelsRequest +// - opts ...grpc.CallOption +func (_e *MockDataNodeClient_Expecter) FlushChannels(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_FlushChannels_Call { + return &MockDataNodeClient_FlushChannels_Call{Call: _e.mock.On("FlushChannels", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataNodeClient_FlushChannels_Call) Run(run func(ctx context.Context, in *datapb.FlushChannelsRequest, opts ...grpc.CallOption)) *MockDataNodeClient_FlushChannels_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.FlushChannelsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataNodeClient_FlushChannels_Call) Return(_a0 *commonpb.Status, _a1 error) *MockDataNodeClient_FlushChannels_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataNodeClient_FlushChannels_Call) RunAndReturn(run func(context.Context, *datapb.FlushChannelsRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockDataNodeClient_FlushChannels_Call { + _c.Call.Return(run) + return _c +} + +// FlushSegments provides a mock function with given fields: ctx, in, opts +func (_m *MockDataNodeClient) FlushSegments(ctx context.Context, in *datapb.FlushSegmentsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.FlushSegmentsRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.FlushSegmentsRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.FlushSegmentsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataNodeClient_FlushSegments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'FlushSegments' +type MockDataNodeClient_FlushSegments_Call struct { + *mock.Call +} + +// FlushSegments is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.FlushSegmentsRequest +// - opts ...grpc.CallOption +func (_e *MockDataNodeClient_Expecter) FlushSegments(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_FlushSegments_Call { + return &MockDataNodeClient_FlushSegments_Call{Call: _e.mock.On("FlushSegments", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataNodeClient_FlushSegments_Call) Run(run func(ctx context.Context, in *datapb.FlushSegmentsRequest, opts ...grpc.CallOption)) *MockDataNodeClient_FlushSegments_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.FlushSegmentsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataNodeClient_FlushSegments_Call) Return(_a0 *commonpb.Status, _a1 error) *MockDataNodeClient_FlushSegments_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataNodeClient_FlushSegments_Call) RunAndReturn(run func(context.Context, *datapb.FlushSegmentsRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockDataNodeClient_FlushSegments_Call { + _c.Call.Return(run) + return _c +} + +// GetCompactionState provides a mock function with given fields: ctx, in, opts +func (_m *MockDataNodeClient) GetCompactionState(ctx context.Context, in *datapb.CompactionStateRequest, opts ...grpc.CallOption) (*datapb.CompactionStateResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *datapb.CompactionStateResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.CompactionStateRequest, ...grpc.CallOption) (*datapb.CompactionStateResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.CompactionStateRequest, ...grpc.CallOption) *datapb.CompactionStateResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*datapb.CompactionStateResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.CompactionStateRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataNodeClient_GetCompactionState_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCompactionState' +type MockDataNodeClient_GetCompactionState_Call struct { + *mock.Call +} + +// GetCompactionState is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.CompactionStateRequest +// - opts ...grpc.CallOption +func (_e *MockDataNodeClient_Expecter) GetCompactionState(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_GetCompactionState_Call { + return &MockDataNodeClient_GetCompactionState_Call{Call: _e.mock.On("GetCompactionState", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataNodeClient_GetCompactionState_Call) Run(run func(ctx context.Context, in *datapb.CompactionStateRequest, opts ...grpc.CallOption)) *MockDataNodeClient_GetCompactionState_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.CompactionStateRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataNodeClient_GetCompactionState_Call) Return(_a0 *datapb.CompactionStateResponse, _a1 error) *MockDataNodeClient_GetCompactionState_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataNodeClient_GetCompactionState_Call) RunAndReturn(run func(context.Context, *datapb.CompactionStateRequest, ...grpc.CallOption) (*datapb.CompactionStateResponse, error)) *MockDataNodeClient_GetCompactionState_Call { + _c.Call.Return(run) + return _c +} + +// GetComponentStates provides a mock function with given fields: ctx, in, opts +func (_m *MockDataNodeClient) GetComponentStates(ctx context.Context, in *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.ComponentStates + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetComponentStatesRequest, ...grpc.CallOption) (*milvuspb.ComponentStates, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetComponentStatesRequest, ...grpc.CallOption) *milvuspb.ComponentStates); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.ComponentStates) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetComponentStatesRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataNodeClient_GetComponentStates_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetComponentStates' +type MockDataNodeClient_GetComponentStates_Call struct { + *mock.Call +} + +// GetComponentStates is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.GetComponentStatesRequest +// - opts ...grpc.CallOption +func (_e *MockDataNodeClient_Expecter) GetComponentStates(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_GetComponentStates_Call { + return &MockDataNodeClient_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataNodeClient_GetComponentStates_Call) Run(run func(ctx context.Context, in *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption)) *MockDataNodeClient_GetComponentStates_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.GetComponentStatesRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataNodeClient_GetComponentStates_Call) Return(_a0 *milvuspb.ComponentStates, _a1 error) *MockDataNodeClient_GetComponentStates_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataNodeClient_GetComponentStates_Call) RunAndReturn(run func(context.Context, *milvuspb.GetComponentStatesRequest, ...grpc.CallOption) (*milvuspb.ComponentStates, error)) *MockDataNodeClient_GetComponentStates_Call { + _c.Call.Return(run) + return _c +} + +// GetMetrics provides a mock function with given fields: ctx, in, opts +func (_m *MockDataNodeClient) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.GetMetricsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest, ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest, ...grpc.CallOption) *milvuspb.GetMetricsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.GetMetricsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetMetricsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataNodeClient_GetMetrics_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetMetrics' +type MockDataNodeClient_GetMetrics_Call struct { + *mock.Call +} + +// GetMetrics is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.GetMetricsRequest +// - opts ...grpc.CallOption +func (_e *MockDataNodeClient_Expecter) GetMetrics(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_GetMetrics_Call { + return &MockDataNodeClient_GetMetrics_Call{Call: _e.mock.On("GetMetrics", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataNodeClient_GetMetrics_Call) Run(run func(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption)) *MockDataNodeClient_GetMetrics_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.GetMetricsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataNodeClient_GetMetrics_Call) Return(_a0 *milvuspb.GetMetricsResponse, _a1 error) *MockDataNodeClient_GetMetrics_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataNodeClient_GetMetrics_Call) RunAndReturn(run func(context.Context, *milvuspb.GetMetricsRequest, ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error)) *MockDataNodeClient_GetMetrics_Call { + _c.Call.Return(run) + return _c +} + +// GetStatisticsChannel provides a mock function with given fields: ctx, in, opts +func (_m *MockDataNodeClient) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.StringResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetStatisticsChannelRequest, ...grpc.CallOption) (*milvuspb.StringResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetStatisticsChannelRequest, ...grpc.CallOption) *milvuspb.StringResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.StringResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.GetStatisticsChannelRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataNodeClient_GetStatisticsChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetStatisticsChannel' +type MockDataNodeClient_GetStatisticsChannel_Call struct { + *mock.Call +} + +// GetStatisticsChannel is a helper method to define mock.On call +// - ctx context.Context +// - in *internalpb.GetStatisticsChannelRequest +// - opts ...grpc.CallOption +func (_e *MockDataNodeClient_Expecter) GetStatisticsChannel(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_GetStatisticsChannel_Call { + return &MockDataNodeClient_GetStatisticsChannel_Call{Call: _e.mock.On("GetStatisticsChannel", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataNodeClient_GetStatisticsChannel_Call) Run(run func(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption)) *MockDataNodeClient_GetStatisticsChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*internalpb.GetStatisticsChannelRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataNodeClient_GetStatisticsChannel_Call) Return(_a0 *milvuspb.StringResponse, _a1 error) *MockDataNodeClient_GetStatisticsChannel_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataNodeClient_GetStatisticsChannel_Call) RunAndReturn(run func(context.Context, *internalpb.GetStatisticsChannelRequest, ...grpc.CallOption) (*milvuspb.StringResponse, error)) *MockDataNodeClient_GetStatisticsChannel_Call { + _c.Call.Return(run) + return _c +} + +// Import provides a mock function with given fields: ctx, in, opts +func (_m *MockDataNodeClient) Import(ctx context.Context, in *datapb.ImportTaskRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.ImportTaskRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.ImportTaskRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.ImportTaskRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataNodeClient_Import_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Import' +type MockDataNodeClient_Import_Call struct { + *mock.Call +} + +// Import is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.ImportTaskRequest +// - opts ...grpc.CallOption +func (_e *MockDataNodeClient_Expecter) Import(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_Import_Call { + return &MockDataNodeClient_Import_Call{Call: _e.mock.On("Import", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataNodeClient_Import_Call) Run(run func(ctx context.Context, in *datapb.ImportTaskRequest, opts ...grpc.CallOption)) *MockDataNodeClient_Import_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.ImportTaskRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockDataNodeClient_Import_Call) Return(_a0 *commonpb.Status, _a1 error) *MockDataNodeClient_Import_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataNodeClient_Import_Call) RunAndReturn(run func(context.Context, *datapb.ImportTaskRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockDataNodeClient_Import_Call { + _c.Call.Return(run) + return _c +} + +// NotifyChannelOperation provides a mock function with given fields: ctx, in, opts +func (_m *MockDataNodeClient) NotifyChannelOperation(ctx context.Context, in *datapb.ChannelOperationsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) 
+ + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.ChannelOperationsRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.ChannelOperationsRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.ChannelOperationsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataNodeClient_NotifyChannelOperation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'NotifyChannelOperation' +type MockDataNodeClient_NotifyChannelOperation_Call struct { + *mock.Call +} + +// NotifyChannelOperation is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.ChannelOperationsRequest +// - opts ...grpc.CallOption +func (_e *MockDataNodeClient_Expecter) NotifyChannelOperation(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_NotifyChannelOperation_Call { + return &MockDataNodeClient_NotifyChannelOperation_Call{Call: _e.mock.On("NotifyChannelOperation", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataNodeClient_NotifyChannelOperation_Call) Run(run func(ctx context.Context, in *datapb.ChannelOperationsRequest, opts ...grpc.CallOption)) *MockDataNodeClient_NotifyChannelOperation_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.ChannelOperationsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataNodeClient_NotifyChannelOperation_Call) Return(_a0 *commonpb.Status, _a1 error) *MockDataNodeClient_NotifyChannelOperation_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataNodeClient_NotifyChannelOperation_Call) RunAndReturn(run func(context.Context, *datapb.ChannelOperationsRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockDataNodeClient_NotifyChannelOperation_Call { + _c.Call.Return(run) + return _c +} + +// ResendSegmentStats provides a mock function with given fields: ctx, in, opts +func (_m *MockDataNodeClient) ResendSegmentStats(ctx context.Context, in *datapb.ResendSegmentStatsRequest, opts ...grpc.CallOption) (*datapb.ResendSegmentStatsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *datapb.ResendSegmentStatsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.ResendSegmentStatsRequest, ...grpc.CallOption) (*datapb.ResendSegmentStatsResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.ResendSegmentStatsRequest, ...grpc.CallOption) *datapb.ResendSegmentStatsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*datapb.ResendSegmentStatsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.ResendSegmentStatsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataNodeClient_ResendSegmentStats_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ResendSegmentStats' +type MockDataNodeClient_ResendSegmentStats_Call struct { + *mock.Call +} + +// ResendSegmentStats is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.ResendSegmentStatsRequest +// - opts ...grpc.CallOption +func (_e *MockDataNodeClient_Expecter) ResendSegmentStats(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_ResendSegmentStats_Call { + return &MockDataNodeClient_ResendSegmentStats_Call{Call: _e.mock.On("ResendSegmentStats", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataNodeClient_ResendSegmentStats_Call) Run(run func(ctx context.Context, in *datapb.ResendSegmentStatsRequest, opts ...grpc.CallOption)) *MockDataNodeClient_ResendSegmentStats_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.ResendSegmentStatsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataNodeClient_ResendSegmentStats_Call) Return(_a0 *datapb.ResendSegmentStatsResponse, _a1 error) *MockDataNodeClient_ResendSegmentStats_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataNodeClient_ResendSegmentStats_Call) RunAndReturn(run func(context.Context, *datapb.ResendSegmentStatsRequest, ...grpc.CallOption) (*datapb.ResendSegmentStatsResponse, error)) *MockDataNodeClient_ResendSegmentStats_Call { + _c.Call.Return(run) + return _c +} + +// ShowConfigurations provides a mock function with given fields: ctx, in, opts +func (_m *MockDataNodeClient) ShowConfigurations(ctx context.Context, in *internalpb.ShowConfigurationsRequest, opts ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *internalpb.ShowConfigurationsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ShowConfigurationsRequest, ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ShowConfigurationsRequest, ...grpc.CallOption) *internalpb.ShowConfigurationsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*internalpb.ShowConfigurationsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.ShowConfigurationsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataNodeClient_ShowConfigurations_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ShowConfigurations' +type MockDataNodeClient_ShowConfigurations_Call struct { + *mock.Call +} + +// ShowConfigurations is a helper method to define mock.On call +// - ctx context.Context +// - in *internalpb.ShowConfigurationsRequest +// - opts ...grpc.CallOption +func (_e *MockDataNodeClient_Expecter) ShowConfigurations(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_ShowConfigurations_Call { + return &MockDataNodeClient_ShowConfigurations_Call{Call: _e.mock.On("ShowConfigurations", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataNodeClient_ShowConfigurations_Call) Run(run func(ctx context.Context, in *internalpb.ShowConfigurationsRequest, opts ...grpc.CallOption)) *MockDataNodeClient_ShowConfigurations_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*internalpb.ShowConfigurationsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataNodeClient_ShowConfigurations_Call) Return(_a0 *internalpb.ShowConfigurationsResponse, _a1 error) *MockDataNodeClient_ShowConfigurations_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataNodeClient_ShowConfigurations_Call) RunAndReturn(run func(context.Context, *internalpb.ShowConfigurationsRequest, ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error)) *MockDataNodeClient_ShowConfigurations_Call { + _c.Call.Return(run) + return _c +} + +// SyncSegments provides a mock function with given fields: ctx, in, opts +func (_m *MockDataNodeClient) SyncSegments(ctx context.Context, in *datapb.SyncSegmentsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.SyncSegmentsRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.SyncSegmentsRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.SyncSegmentsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataNodeClient_SyncSegments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SyncSegments' +type MockDataNodeClient_SyncSegments_Call struct { + *mock.Call +} + +// SyncSegments is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.SyncSegmentsRequest +// - opts ...grpc.CallOption +func (_e *MockDataNodeClient_Expecter) SyncSegments(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_SyncSegments_Call { + return &MockDataNodeClient_SyncSegments_Call{Call: _e.mock.On("SyncSegments", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataNodeClient_SyncSegments_Call) Run(run func(ctx context.Context, in *datapb.SyncSegmentsRequest, opts ...grpc.CallOption)) *MockDataNodeClient_SyncSegments_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.SyncSegmentsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockDataNodeClient_SyncSegments_Call) Return(_a0 *commonpb.Status, _a1 error) *MockDataNodeClient_SyncSegments_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataNodeClient_SyncSegments_Call) RunAndReturn(run func(context.Context, *datapb.SyncSegmentsRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockDataNodeClient_SyncSegments_Call { + _c.Call.Return(run) + return _c +} + +// WatchDmChannels provides a mock function with given fields: ctx, in, opts +func (_m *MockDataNodeClient) WatchDmChannels(ctx context.Context, in *datapb.WatchDmChannelsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.WatchDmChannelsRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *datapb.WatchDmChannelsRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *datapb.WatchDmChannelsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDataNodeClient_WatchDmChannels_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WatchDmChannels' +type MockDataNodeClient_WatchDmChannels_Call struct { + *mock.Call +} + +// WatchDmChannels is a helper method to define mock.On call +// - ctx context.Context +// - in *datapb.WatchDmChannelsRequest +// - opts ...grpc.CallOption +func (_e *MockDataNodeClient_Expecter) WatchDmChannels(ctx interface{}, in interface{}, opts ...interface{}) *MockDataNodeClient_WatchDmChannels_Call { + return &MockDataNodeClient_WatchDmChannels_Call{Call: _e.mock.On("WatchDmChannels", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockDataNodeClient_WatchDmChannels_Call) Run(run func(ctx context.Context, in *datapb.WatchDmChannelsRequest, opts ...grpc.CallOption)) *MockDataNodeClient_WatchDmChannels_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*datapb.WatchDmChannelsRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockDataNodeClient_WatchDmChannels_Call) Return(_a0 *commonpb.Status, _a1 error) *MockDataNodeClient_WatchDmChannels_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDataNodeClient_WatchDmChannels_Call) RunAndReturn(run func(context.Context, *datapb.WatchDmChannelsRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockDataNodeClient_WatchDmChannels_Call { + _c.Call.Return(run) + return _c +} + +// NewMockDataNodeClient creates a new instance of MockDataNodeClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. 
+func NewMockDataNodeClient(t interface { + mock.TestingT + Cleanup(func()) +}) *MockDataNodeClient { + mock := &MockDataNodeClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/mocks/mock_indexnode.go b/internal/mocks/mock_indexnode.go index 3354e4b560d6a..bcd158dc7b34e 100644 --- a/internal/mocks/mock_indexnode.go +++ b/internal/mocks/mock_indexnode.go @@ -181,25 +181,25 @@ func (_c *MockIndexNode_GetAddress_Call) RunAndReturn(run func() string) *MockIn return _c } -// GetComponentStates provides a mock function with given fields: ctx -func (_m *MockIndexNode) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { - ret := _m.Called(ctx) +// GetComponentStates provides a mock function with given fields: _a0, _a1 +func (_m *MockIndexNode) GetComponentStates(_a0 context.Context, _a1 *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.ComponentStates var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (*milvuspb.ComponentStates, error)); ok { - return rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error)); ok { + return rf(_a0, _a1) } - if rf, ok := ret.Get(0).(func(context.Context) *milvuspb.ComponentStates); ok { - r0 = rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetComponentStatesRequest) *milvuspb.ComponentStates); ok { + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.ComponentStates) } } - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetComponentStatesRequest) error); ok { + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -213,14 +213,15 @@ type MockIndexNode_GetComponentStates_Call struct { } // GetComponentStates is a helper method to define mock.On call -// - ctx 
context.Context -func (_e *MockIndexNode_Expecter) GetComponentStates(ctx interface{}) *MockIndexNode_GetComponentStates_Call { - return &MockIndexNode_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", ctx)} +// - _a0 context.Context +// - _a1 *milvuspb.GetComponentStatesRequest +func (_e *MockIndexNode_Expecter) GetComponentStates(_a0 interface{}, _a1 interface{}) *MockIndexNode_GetComponentStates_Call { + return &MockIndexNode_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", _a0, _a1)} } -func (_c *MockIndexNode_GetComponentStates_Call) Run(run func(ctx context.Context)) *MockIndexNode_GetComponentStates_Call { +func (_c *MockIndexNode_GetComponentStates_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetComponentStatesRequest)) *MockIndexNode_GetComponentStates_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context)) + run(args[0].(context.Context), args[1].(*milvuspb.GetComponentStatesRequest)) }) return _c } @@ -230,7 +231,7 @@ func (_c *MockIndexNode_GetComponentStates_Call) Return(_a0 *milvuspb.ComponentS return _c } -func (_c *MockIndexNode_GetComponentStates_Call) RunAndReturn(run func(context.Context) (*milvuspb.ComponentStates, error)) *MockIndexNode_GetComponentStates_Call { +func (_c *MockIndexNode_GetComponentStates_Call) RunAndReturn(run func(context.Context, *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error)) *MockIndexNode_GetComponentStates_Call { _c.Call.Return(run) return _c } @@ -290,17 +291,17 @@ func (_c *MockIndexNode_GetJobStats_Call) RunAndReturn(run func(context.Context, return _c } -// GetMetrics provides a mock function with given fields: ctx, req -func (_m *MockIndexNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - ret := _m.Called(ctx, req) +// GetMetrics provides a mock function with given fields: _a0, _a1 +func (_m *MockIndexNode) GetMetrics(_a0 context.Context, _a1 *milvuspb.GetMetricsRequest) 
(*milvuspb.GetMetricsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.GetMetricsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest) *milvuspb.GetMetricsResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.GetMetricsResponse) @@ -308,7 +309,7 @@ func (_m *MockIndexNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetric } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetMetricsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -322,13 +323,13 @@ type MockIndexNode_GetMetrics_Call struct { } // GetMetrics is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.GetMetricsRequest -func (_e *MockIndexNode_Expecter) GetMetrics(ctx interface{}, req interface{}) *MockIndexNode_GetMetrics_Call { - return &MockIndexNode_GetMetrics_Call{Call: _e.mock.On("GetMetrics", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.GetMetricsRequest +func (_e *MockIndexNode_Expecter) GetMetrics(_a0 interface{}, _a1 interface{}) *MockIndexNode_GetMetrics_Call { + return &MockIndexNode_GetMetrics_Call{Call: _e.mock.On("GetMetrics", _a0, _a1)} } -func (_c *MockIndexNode_GetMetrics_Call) Run(run func(ctx context.Context, req *milvuspb.GetMetricsRequest)) *MockIndexNode_GetMetrics_Call { +func (_c *MockIndexNode_GetMetrics_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetMetricsRequest)) *MockIndexNode_GetMetrics_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.GetMetricsRequest)) }) @@ -345,25 +346,25 @@ func (_c *MockIndexNode_GetMetrics_Call) RunAndReturn(run func(context.Context, return _c } -// GetStatisticsChannel provides a mock function with given 
fields: ctx -func (_m *MockIndexNode) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret := _m.Called(ctx) +// GetStatisticsChannel provides a mock function with given fields: _a0, _a1 +func (_m *MockIndexNode) GetStatisticsChannel(_a0 context.Context, _a1 *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.StringResponse var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (*milvuspb.StringResponse, error)); ok { - return rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error)); ok { + return rf(_a0, _a1) } - if rf, ok := ret.Get(0).(func(context.Context) *milvuspb.StringResponse); ok { - r0 = rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetStatisticsChannelRequest) *milvuspb.StringResponse); ok { + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.StringResponse) } } - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.GetStatisticsChannelRequest) error); ok { + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -377,14 +378,15 @@ type MockIndexNode_GetStatisticsChannel_Call struct { } // GetStatisticsChannel is a helper method to define mock.On call -// - ctx context.Context -func (_e *MockIndexNode_Expecter) GetStatisticsChannel(ctx interface{}) *MockIndexNode_GetStatisticsChannel_Call { - return &MockIndexNode_GetStatisticsChannel_Call{Call: _e.mock.On("GetStatisticsChannel", ctx)} +// - _a0 context.Context +// - _a1 *internalpb.GetStatisticsChannelRequest +func (_e *MockIndexNode_Expecter) GetStatisticsChannel(_a0 interface{}, _a1 interface{}) *MockIndexNode_GetStatisticsChannel_Call { + return &MockIndexNode_GetStatisticsChannel_Call{Call: _e.mock.On("GetStatisticsChannel", _a0, _a1)} } -func (_c *MockIndexNode_GetStatisticsChannel_Call) 
Run(run func(ctx context.Context)) *MockIndexNode_GetStatisticsChannel_Call { +func (_c *MockIndexNode_GetStatisticsChannel_Call) Run(run func(_a0 context.Context, _a1 *internalpb.GetStatisticsChannelRequest)) *MockIndexNode_GetStatisticsChannel_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context)) + run(args[0].(context.Context), args[1].(*internalpb.GetStatisticsChannelRequest)) }) return _c } @@ -394,7 +396,7 @@ func (_c *MockIndexNode_GetStatisticsChannel_Call) Return(_a0 *milvuspb.StringRe return _c } -func (_c *MockIndexNode_GetStatisticsChannel_Call) RunAndReturn(run func(context.Context) (*milvuspb.StringResponse, error)) *MockIndexNode_GetStatisticsChannel_Call { +func (_c *MockIndexNode_GetStatisticsChannel_Call) RunAndReturn(run func(context.Context, *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error)) *MockIndexNode_GetStatisticsChannel_Call { _c.Call.Return(run) return _c } @@ -602,17 +604,17 @@ func (_c *MockIndexNode_SetEtcdClient_Call) RunAndReturn(run func(*clientv3.Clie return _c } -// ShowConfigurations provides a mock function with given fields: ctx, req -func (_m *MockIndexNode) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { - ret := _m.Called(ctx, req) +// ShowConfigurations provides a mock function with given fields: _a0, _a1 +func (_m *MockIndexNode) ShowConfigurations(_a0 context.Context, _a1 *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *internalpb.ShowConfigurationsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ShowConfigurationsRequest) *internalpb.ShowConfigurationsResponse); ok { - r0 = rf(ctx, req) + r0 = 
rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*internalpb.ShowConfigurationsResponse) @@ -620,7 +622,7 @@ func (_m *MockIndexNode) ShowConfigurations(ctx context.Context, req *internalpb } if rf, ok := ret.Get(1).(func(context.Context, *internalpb.ShowConfigurationsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -634,13 +636,13 @@ type MockIndexNode_ShowConfigurations_Call struct { } // ShowConfigurations is a helper method to define mock.On call -// - ctx context.Context -// - req *internalpb.ShowConfigurationsRequest -func (_e *MockIndexNode_Expecter) ShowConfigurations(ctx interface{}, req interface{}) *MockIndexNode_ShowConfigurations_Call { - return &MockIndexNode_ShowConfigurations_Call{Call: _e.mock.On("ShowConfigurations", ctx, req)} +// - _a0 context.Context +// - _a1 *internalpb.ShowConfigurationsRequest +func (_e *MockIndexNode_Expecter) ShowConfigurations(_a0 interface{}, _a1 interface{}) *MockIndexNode_ShowConfigurations_Call { + return &MockIndexNode_ShowConfigurations_Call{Call: _e.mock.On("ShowConfigurations", _a0, _a1)} } -func (_c *MockIndexNode_ShowConfigurations_Call) Run(run func(ctx context.Context, req *internalpb.ShowConfigurationsRequest)) *MockIndexNode_ShowConfigurations_Call { +func (_c *MockIndexNode_ShowConfigurations_Call) Run(run func(_a0 context.Context, _a1 *internalpb.ShowConfigurationsRequest)) *MockIndexNode_ShowConfigurations_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*internalpb.ShowConfigurationsRequest)) }) diff --git a/internal/mocks/mock_indexnode_client.go b/internal/mocks/mock_indexnode_client.go new file mode 100644 index 0000000000000..1e30de98ac1bd --- /dev/null +++ b/internal/mocks/mock_indexnode_client.go @@ -0,0 +1,647 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. 
+ +package mocks + +import ( + context "context" + + commonpb "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + + grpc "google.golang.org/grpc" + + indexpb "github.com/milvus-io/milvus/internal/proto/indexpb" + + internalpb "github.com/milvus-io/milvus/internal/proto/internalpb" + + milvuspb "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + + mock "github.com/stretchr/testify/mock" +) + +// MockIndexNodeClient is an autogenerated mock type for the IndexNodeClient type +type MockIndexNodeClient struct { + mock.Mock +} + +type MockIndexNodeClient_Expecter struct { + mock *mock.Mock +} + +func (_m *MockIndexNodeClient) EXPECT() *MockIndexNodeClient_Expecter { + return &MockIndexNodeClient_Expecter{mock: &_m.Mock} +} + +// Close provides a mock function with given fields: +func (_m *MockIndexNodeClient) Close() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockIndexNodeClient_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type MockIndexNodeClient_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *MockIndexNodeClient_Expecter) Close() *MockIndexNodeClient_Close_Call { + return &MockIndexNodeClient_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *MockIndexNodeClient_Close_Call) Run(run func()) *MockIndexNodeClient_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockIndexNodeClient_Close_Call) Return(_a0 error) *MockIndexNodeClient_Close_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockIndexNodeClient_Close_Call) RunAndReturn(run func() error) *MockIndexNodeClient_Close_Call { + _c.Call.Return(run) + return _c +} + +// CreateJob provides a mock function with given fields: ctx, in, opts +func (_m *MockIndexNodeClient) CreateJob(ctx context.Context, in 
*indexpb.CreateJobRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.CreateJobRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.CreateJobRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *indexpb.CreateJobRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockIndexNodeClient_CreateJob_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateJob' +type MockIndexNodeClient_CreateJob_Call struct { + *mock.Call +} + +// CreateJob is a helper method to define mock.On call +// - ctx context.Context +// - in *indexpb.CreateJobRequest +// - opts ...grpc.CallOption +func (_e *MockIndexNodeClient_Expecter) CreateJob(ctx interface{}, in interface{}, opts ...interface{}) *MockIndexNodeClient_CreateJob_Call { + return &MockIndexNodeClient_CreateJob_Call{Call: _e.mock.On("CreateJob", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockIndexNodeClient_CreateJob_Call) Run(run func(ctx context.Context, in *indexpb.CreateJobRequest, opts ...grpc.CallOption)) *MockIndexNodeClient_CreateJob_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*indexpb.CreateJobRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockIndexNodeClient_CreateJob_Call) Return(_a0 *commonpb.Status, _a1 error) *MockIndexNodeClient_CreateJob_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockIndexNodeClient_CreateJob_Call) RunAndReturn(run func(context.Context, *indexpb.CreateJobRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockIndexNodeClient_CreateJob_Call { + _c.Call.Return(run) + return _c +} + +// DropJobs provides a mock function with given fields: ctx, in, opts +func (_m *MockIndexNodeClient) DropJobs(ctx context.Context, in *indexpb.DropJobsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.DropJobsRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.DropJobsRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *indexpb.DropJobsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockIndexNodeClient_DropJobs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DropJobs' +type MockIndexNodeClient_DropJobs_Call struct { + *mock.Call +} + +// DropJobs is a helper method to define mock.On call +// - ctx context.Context +// - in *indexpb.DropJobsRequest +// - opts ...grpc.CallOption +func (_e *MockIndexNodeClient_Expecter) DropJobs(ctx interface{}, in interface{}, opts ...interface{}) *MockIndexNodeClient_DropJobs_Call { + return &MockIndexNodeClient_DropJobs_Call{Call: _e.mock.On("DropJobs", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockIndexNodeClient_DropJobs_Call) Run(run func(ctx context.Context, in *indexpb.DropJobsRequest, opts ...grpc.CallOption)) *MockIndexNodeClient_DropJobs_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*indexpb.DropJobsRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockIndexNodeClient_DropJobs_Call) Return(_a0 *commonpb.Status, _a1 error) *MockIndexNodeClient_DropJobs_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockIndexNodeClient_DropJobs_Call) RunAndReturn(run func(context.Context, *indexpb.DropJobsRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockIndexNodeClient_DropJobs_Call { + _c.Call.Return(run) + return _c +} + +// GetComponentStates provides a mock function with given fields: ctx, in, opts +func (_m *MockIndexNodeClient) GetComponentStates(ctx context.Context, in *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) 
+ + var r0 *milvuspb.ComponentStates + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetComponentStatesRequest, ...grpc.CallOption) (*milvuspb.ComponentStates, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetComponentStatesRequest, ...grpc.CallOption) *milvuspb.ComponentStates); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.ComponentStates) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetComponentStatesRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockIndexNodeClient_GetComponentStates_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetComponentStates' +type MockIndexNodeClient_GetComponentStates_Call struct { + *mock.Call +} + +// GetComponentStates is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.GetComponentStatesRequest +// - opts ...grpc.CallOption +func (_e *MockIndexNodeClient_Expecter) GetComponentStates(ctx interface{}, in interface{}, opts ...interface{}) *MockIndexNodeClient_GetComponentStates_Call { + return &MockIndexNodeClient_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockIndexNodeClient_GetComponentStates_Call) Run(run func(ctx context.Context, in *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption)) *MockIndexNodeClient_GetComponentStates_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.GetComponentStatesRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockIndexNodeClient_GetComponentStates_Call) Return(_a0 *milvuspb.ComponentStates, _a1 error) *MockIndexNodeClient_GetComponentStates_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockIndexNodeClient_GetComponentStates_Call) RunAndReturn(run func(context.Context, *milvuspb.GetComponentStatesRequest, ...grpc.CallOption) (*milvuspb.ComponentStates, error)) *MockIndexNodeClient_GetComponentStates_Call { + _c.Call.Return(run) + return _c +} + +// GetJobStats provides a mock function with given fields: ctx, in, opts +func (_m *MockIndexNodeClient) GetJobStats(ctx context.Context, in *indexpb.GetJobStatsRequest, opts ...grpc.CallOption) (*indexpb.GetJobStatsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *indexpb.GetJobStatsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.GetJobStatsRequest, ...grpc.CallOption) (*indexpb.GetJobStatsResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.GetJobStatsRequest, ...grpc.CallOption) *indexpb.GetJobStatsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*indexpb.GetJobStatsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *indexpb.GetJobStatsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockIndexNodeClient_GetJobStats_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetJobStats' +type MockIndexNodeClient_GetJobStats_Call struct { + *mock.Call +} + +// GetJobStats is a helper method to define mock.On call +// - ctx context.Context +// - in *indexpb.GetJobStatsRequest +// - opts ...grpc.CallOption +func (_e *MockIndexNodeClient_Expecter) GetJobStats(ctx interface{}, in interface{}, opts ...interface{}) *MockIndexNodeClient_GetJobStats_Call { + return &MockIndexNodeClient_GetJobStats_Call{Call: _e.mock.On("GetJobStats", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockIndexNodeClient_GetJobStats_Call) Run(run func(ctx context.Context, in *indexpb.GetJobStatsRequest, opts ...grpc.CallOption)) *MockIndexNodeClient_GetJobStats_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*indexpb.GetJobStatsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockIndexNodeClient_GetJobStats_Call) Return(_a0 *indexpb.GetJobStatsResponse, _a1 error) *MockIndexNodeClient_GetJobStats_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockIndexNodeClient_GetJobStats_Call) RunAndReturn(run func(context.Context, *indexpb.GetJobStatsRequest, ...grpc.CallOption) (*indexpb.GetJobStatsResponse, error)) *MockIndexNodeClient_GetJobStats_Call { + _c.Call.Return(run) + return _c +} + +// GetMetrics provides a mock function with given fields: ctx, in, opts +func (_m *MockIndexNodeClient) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.GetMetricsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest, ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest, ...grpc.CallOption) *milvuspb.GetMetricsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.GetMetricsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetMetricsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockIndexNodeClient_GetMetrics_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetMetrics' +type MockIndexNodeClient_GetMetrics_Call struct { + *mock.Call +} + +// GetMetrics is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.GetMetricsRequest +// - opts ...grpc.CallOption +func (_e *MockIndexNodeClient_Expecter) GetMetrics(ctx interface{}, in interface{}, opts ...interface{}) *MockIndexNodeClient_GetMetrics_Call { + return &MockIndexNodeClient_GetMetrics_Call{Call: _e.mock.On("GetMetrics", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockIndexNodeClient_GetMetrics_Call) Run(run func(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption)) *MockIndexNodeClient_GetMetrics_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.GetMetricsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockIndexNodeClient_GetMetrics_Call) Return(_a0 *milvuspb.GetMetricsResponse, _a1 error) *MockIndexNodeClient_GetMetrics_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockIndexNodeClient_GetMetrics_Call) RunAndReturn(run func(context.Context, *milvuspb.GetMetricsRequest, ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error)) *MockIndexNodeClient_GetMetrics_Call { + _c.Call.Return(run) + return _c +} + +// GetStatisticsChannel provides a mock function with given fields: ctx, in, opts +func (_m *MockIndexNodeClient) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.StringResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetStatisticsChannelRequest, ...grpc.CallOption) (*milvuspb.StringResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetStatisticsChannelRequest, ...grpc.CallOption) *milvuspb.StringResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.StringResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.GetStatisticsChannelRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockIndexNodeClient_GetStatisticsChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetStatisticsChannel' +type MockIndexNodeClient_GetStatisticsChannel_Call struct { + *mock.Call +} + +// GetStatisticsChannel is a helper method to define mock.On call +// - ctx context.Context +// - in *internalpb.GetStatisticsChannelRequest +// - opts ...grpc.CallOption +func (_e *MockIndexNodeClient_Expecter) GetStatisticsChannel(ctx interface{}, in interface{}, opts ...interface{}) *MockIndexNodeClient_GetStatisticsChannel_Call { + return &MockIndexNodeClient_GetStatisticsChannel_Call{Call: _e.mock.On("GetStatisticsChannel", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockIndexNodeClient_GetStatisticsChannel_Call) Run(run func(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption)) *MockIndexNodeClient_GetStatisticsChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*internalpb.GetStatisticsChannelRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockIndexNodeClient_GetStatisticsChannel_Call) Return(_a0 *milvuspb.StringResponse, _a1 error) *MockIndexNodeClient_GetStatisticsChannel_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockIndexNodeClient_GetStatisticsChannel_Call) RunAndReturn(run func(context.Context, *internalpb.GetStatisticsChannelRequest, ...grpc.CallOption) (*milvuspb.StringResponse, error)) *MockIndexNodeClient_GetStatisticsChannel_Call { + _c.Call.Return(run) + return _c +} + +// QueryJobs provides a mock function with given fields: ctx, in, opts +func (_m *MockIndexNodeClient) QueryJobs(ctx context.Context, in *indexpb.QueryJobsRequest, opts ...grpc.CallOption) (*indexpb.QueryJobsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *indexpb.QueryJobsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.QueryJobsRequest, ...grpc.CallOption) (*indexpb.QueryJobsResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *indexpb.QueryJobsRequest, ...grpc.CallOption) *indexpb.QueryJobsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*indexpb.QueryJobsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *indexpb.QueryJobsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockIndexNodeClient_QueryJobs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'QueryJobs' +type MockIndexNodeClient_QueryJobs_Call struct { + *mock.Call +} + +// QueryJobs is a helper method to define mock.On call +// - ctx context.Context +// - in *indexpb.QueryJobsRequest +// - opts ...grpc.CallOption +func (_e *MockIndexNodeClient_Expecter) QueryJobs(ctx interface{}, in interface{}, opts ...interface{}) *MockIndexNodeClient_QueryJobs_Call { + return &MockIndexNodeClient_QueryJobs_Call{Call: _e.mock.On("QueryJobs", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockIndexNodeClient_QueryJobs_Call) Run(run func(ctx context.Context, in *indexpb.QueryJobsRequest, opts ...grpc.CallOption)) *MockIndexNodeClient_QueryJobs_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*indexpb.QueryJobsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockIndexNodeClient_QueryJobs_Call) Return(_a0 *indexpb.QueryJobsResponse, _a1 error) *MockIndexNodeClient_QueryJobs_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockIndexNodeClient_QueryJobs_Call) RunAndReturn(run func(context.Context, *indexpb.QueryJobsRequest, ...grpc.CallOption) (*indexpb.QueryJobsResponse, error)) *MockIndexNodeClient_QueryJobs_Call { + _c.Call.Return(run) + return _c +} + +// ShowConfigurations provides a mock function with given fields: ctx, in, opts +func (_m *MockIndexNodeClient) ShowConfigurations(ctx context.Context, in *internalpb.ShowConfigurationsRequest, opts ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *internalpb.ShowConfigurationsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ShowConfigurationsRequest, ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ShowConfigurationsRequest, ...grpc.CallOption) *internalpb.ShowConfigurationsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*internalpb.ShowConfigurationsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.ShowConfigurationsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockIndexNodeClient_ShowConfigurations_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ShowConfigurations' +type MockIndexNodeClient_ShowConfigurations_Call struct { + *mock.Call +} + +// ShowConfigurations is a helper method to define mock.On call +// - ctx context.Context +// - in *internalpb.ShowConfigurationsRequest +// - opts ...grpc.CallOption +func (_e *MockIndexNodeClient_Expecter) ShowConfigurations(ctx interface{}, in interface{}, opts ...interface{}) *MockIndexNodeClient_ShowConfigurations_Call { + return &MockIndexNodeClient_ShowConfigurations_Call{Call: _e.mock.On("ShowConfigurations", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockIndexNodeClient_ShowConfigurations_Call) Run(run func(ctx context.Context, in *internalpb.ShowConfigurationsRequest, opts ...grpc.CallOption)) *MockIndexNodeClient_ShowConfigurations_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*internalpb.ShowConfigurationsRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockIndexNodeClient_ShowConfigurations_Call) Return(_a0 *internalpb.ShowConfigurationsResponse, _a1 error) *MockIndexNodeClient_ShowConfigurations_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockIndexNodeClient_ShowConfigurations_Call) RunAndReturn(run func(context.Context, *internalpb.ShowConfigurationsRequest, ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error)) *MockIndexNodeClient_ShowConfigurations_Call { + _c.Call.Return(run) + return _c +} + +// NewMockIndexNodeClient creates a new instance of MockIndexNodeClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. 
+// The first argument is typically a *testing.T value. +func NewMockIndexNodeClient(t interface { + mock.TestingT + Cleanup(func()) +}) *MockIndexNodeClient { + mock := &MockIndexNodeClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/mocks/mock_proxy.go b/internal/mocks/mock_proxy.go index 0db3a60152ef7..706b2fe25c7c7 100644 --- a/internal/mocks/mock_proxy.go +++ b/internal/mocks/mock_proxy.go @@ -8,6 +8,8 @@ import ( commonpb "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" clientv3 "go.etcd.io/etcd/client/v3" + federpb "github.com/milvus-io/milvus-proto/go-api/v2/federpb" + internalpb "github.com/milvus-io/milvus/internal/proto/internalpb" milvuspb "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" @@ -32,17 +34,17 @@ func (_m *MockProxy) EXPECT() *MockProxy_Expecter { return &MockProxy_Expecter{mock: &_m.Mock} } -// AllocTimestamp provides a mock function with given fields: ctx, req -func (_m *MockProxy) AllocTimestamp(ctx context.Context, req *milvuspb.AllocTimestampRequest) (*milvuspb.AllocTimestampResponse, error) { - ret := _m.Called(ctx, req) +// AllocTimestamp provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) AllocTimestamp(_a0 context.Context, _a1 *milvuspb.AllocTimestampRequest) (*milvuspb.AllocTimestampResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.AllocTimestampResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.AllocTimestampRequest) (*milvuspb.AllocTimestampResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.AllocTimestampRequest) *milvuspb.AllocTimestampResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.AllocTimestampResponse) @@ -50,7 +52,7 @@ func (_m *MockProxy) AllocTimestamp(ctx context.Context, req *milvuspb.AllocTime } if rf, ok := 
ret.Get(1).(func(context.Context, *milvuspb.AllocTimestampRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -64,13 +66,13 @@ type MockProxy_AllocTimestamp_Call struct { } // AllocTimestamp is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.AllocTimestampRequest -func (_e *MockProxy_Expecter) AllocTimestamp(ctx interface{}, req interface{}) *MockProxy_AllocTimestamp_Call { - return &MockProxy_AllocTimestamp_Call{Call: _e.mock.On("AllocTimestamp", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.AllocTimestampRequest +func (_e *MockProxy_Expecter) AllocTimestamp(_a0 interface{}, _a1 interface{}) *MockProxy_AllocTimestamp_Call { + return &MockProxy_AllocTimestamp_Call{Call: _e.mock.On("AllocTimestamp", _a0, _a1)} } -func (_c *MockProxy_AllocTimestamp_Call) Run(run func(ctx context.Context, req *milvuspb.AllocTimestampRequest)) *MockProxy_AllocTimestamp_Call { +func (_c *MockProxy_AllocTimestamp_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.AllocTimestampRequest)) *MockProxy_AllocTimestamp_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.AllocTimestampRequest)) }) @@ -87,17 +89,17 @@ func (_c *MockProxy_AllocTimestamp_Call) RunAndReturn(run func(context.Context, return _c } -// AlterAlias provides a mock function with given fields: ctx, request -func (_m *MockProxy) AlterAlias(ctx context.Context, request *milvuspb.AlterAliasRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, request) +// AlterAlias provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) AlterAlias(_a0 context.Context, _a1 *milvuspb.AlterAliasRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.AlterAliasRequest) (*commonpb.Status, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := 
ret.Get(0).(func(context.Context, *milvuspb.AlterAliasRequest) *commonpb.Status); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -105,7 +107,7 @@ func (_m *MockProxy) AlterAlias(ctx context.Context, request *milvuspb.AlterAlia } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.AlterAliasRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -119,13 +121,13 @@ type MockProxy_AlterAlias_Call struct { } // AlterAlias is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.AlterAliasRequest -func (_e *MockProxy_Expecter) AlterAlias(ctx interface{}, request interface{}) *MockProxy_AlterAlias_Call { - return &MockProxy_AlterAlias_Call{Call: _e.mock.On("AlterAlias", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.AlterAliasRequest +func (_e *MockProxy_Expecter) AlterAlias(_a0 interface{}, _a1 interface{}) *MockProxy_AlterAlias_Call { + return &MockProxy_AlterAlias_Call{Call: _e.mock.On("AlterAlias", _a0, _a1)} } -func (_c *MockProxy_AlterAlias_Call) Run(run func(ctx context.Context, request *milvuspb.AlterAliasRequest)) *MockProxy_AlterAlias_Call { +func (_c *MockProxy_AlterAlias_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.AlterAliasRequest)) *MockProxy_AlterAlias_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.AlterAliasRequest)) }) @@ -142,17 +144,17 @@ func (_c *MockProxy_AlterAlias_Call) RunAndReturn(run func(context.Context, *mil return _c } -// AlterCollection provides a mock function with given fields: ctx, request -func (_m *MockProxy) AlterCollection(ctx context.Context, request *milvuspb.AlterCollectionRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, request) +// AlterCollection provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) AlterCollection(_a0 context.Context, _a1 
*milvuspb.AlterCollectionRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.AlterCollectionRequest) (*commonpb.Status, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.AlterCollectionRequest) *commonpb.Status); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -160,7 +162,7 @@ func (_m *MockProxy) AlterCollection(ctx context.Context, request *milvuspb.Alte } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.AlterCollectionRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -174,13 +176,13 @@ type MockProxy_AlterCollection_Call struct { } // AlterCollection is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.AlterCollectionRequest -func (_e *MockProxy_Expecter) AlterCollection(ctx interface{}, request interface{}) *MockProxy_AlterCollection_Call { - return &MockProxy_AlterCollection_Call{Call: _e.mock.On("AlterCollection", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.AlterCollectionRequest +func (_e *MockProxy_Expecter) AlterCollection(_a0 interface{}, _a1 interface{}) *MockProxy_AlterCollection_Call { + return &MockProxy_AlterCollection_Call{Call: _e.mock.On("AlterCollection", _a0, _a1)} } -func (_c *MockProxy_AlterCollection_Call) Run(run func(ctx context.Context, request *milvuspb.AlterCollectionRequest)) *MockProxy_AlterCollection_Call { +func (_c *MockProxy_AlterCollection_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.AlterCollectionRequest)) *MockProxy_AlterCollection_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.AlterCollectionRequest)) }) @@ -197,17 +199,17 @@ func (_c *MockProxy_AlterCollection_Call) RunAndReturn(run func(context.Context, return 
_c } -// CalcDistance provides a mock function with given fields: ctx, request -func (_m *MockProxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDistanceRequest) (*milvuspb.CalcDistanceResults, error) { - ret := _m.Called(ctx, request) +// CalcDistance provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) CalcDistance(_a0 context.Context, _a1 *milvuspb.CalcDistanceRequest) (*milvuspb.CalcDistanceResults, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.CalcDistanceResults var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CalcDistanceRequest) (*milvuspb.CalcDistanceResults, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CalcDistanceRequest) *milvuspb.CalcDistanceResults); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.CalcDistanceResults) @@ -215,7 +217,7 @@ func (_m *MockProxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDis } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CalcDistanceRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -229,13 +231,13 @@ type MockProxy_CalcDistance_Call struct { } // CalcDistance is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.CalcDistanceRequest -func (_e *MockProxy_Expecter) CalcDistance(ctx interface{}, request interface{}) *MockProxy_CalcDistance_Call { - return &MockProxy_CalcDistance_Call{Call: _e.mock.On("CalcDistance", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.CalcDistanceRequest +func (_e *MockProxy_Expecter) CalcDistance(_a0 interface{}, _a1 interface{}) *MockProxy_CalcDistance_Call { + return &MockProxy_CalcDistance_Call{Call: _e.mock.On("CalcDistance", _a0, _a1)} } -func (_c *MockProxy_CalcDistance_Call) Run(run func(ctx context.Context, request *milvuspb.CalcDistanceRequest)) 
*MockProxy_CalcDistance_Call { +func (_c *MockProxy_CalcDistance_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.CalcDistanceRequest)) *MockProxy_CalcDistance_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.CalcDistanceRequest)) }) @@ -252,17 +254,17 @@ func (_c *MockProxy_CalcDistance_Call) RunAndReturn(run func(context.Context, *m return _c } -// CheckHealth provides a mock function with given fields: ctx, req -func (_m *MockProxy) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { - ret := _m.Called(ctx, req) +// CheckHealth provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) CheckHealth(_a0 context.Context, _a1 *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.CheckHealthResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CheckHealthRequest) *milvuspb.CheckHealthResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.CheckHealthResponse) @@ -270,7 +272,7 @@ func (_m *MockProxy) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthR } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CheckHealthRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -284,13 +286,13 @@ type MockProxy_CheckHealth_Call struct { } // CheckHealth is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.CheckHealthRequest -func (_e *MockProxy_Expecter) CheckHealth(ctx interface{}, req interface{}) *MockProxy_CheckHealth_Call { - return &MockProxy_CheckHealth_Call{Call: _e.mock.On("CheckHealth", ctx, req)} +// - _a0 context.Context +// - _a1 
*milvuspb.CheckHealthRequest +func (_e *MockProxy_Expecter) CheckHealth(_a0 interface{}, _a1 interface{}) *MockProxy_CheckHealth_Call { + return &MockProxy_CheckHealth_Call{Call: _e.mock.On("CheckHealth", _a0, _a1)} } -func (_c *MockProxy_CheckHealth_Call) Run(run func(ctx context.Context, req *milvuspb.CheckHealthRequest)) *MockProxy_CheckHealth_Call { +func (_c *MockProxy_CheckHealth_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.CheckHealthRequest)) *MockProxy_CheckHealth_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.CheckHealthRequest)) }) @@ -307,17 +309,17 @@ func (_c *MockProxy_CheckHealth_Call) RunAndReturn(run func(context.Context, *mi return _c } -// Connect provides a mock function with given fields: ctx, req -func (_m *MockProxy) Connect(ctx context.Context, req *milvuspb.ConnectRequest) (*milvuspb.ConnectResponse, error) { - ret := _m.Called(ctx, req) +// Connect provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) Connect(_a0 context.Context, _a1 *milvuspb.ConnectRequest) (*milvuspb.ConnectResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.ConnectResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ConnectRequest) (*milvuspb.ConnectResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ConnectRequest) *milvuspb.ConnectResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.ConnectResponse) @@ -325,7 +327,7 @@ func (_m *MockProxy) Connect(ctx context.Context, req *milvuspb.ConnectRequest) } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ConnectRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -339,13 +341,13 @@ type MockProxy_Connect_Call struct { } // Connect is a helper method to define mock.On call -// - ctx context.Context -// - req 
*milvuspb.ConnectRequest -func (_e *MockProxy_Expecter) Connect(ctx interface{}, req interface{}) *MockProxy_Connect_Call { - return &MockProxy_Connect_Call{Call: _e.mock.On("Connect", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.ConnectRequest +func (_e *MockProxy_Expecter) Connect(_a0 interface{}, _a1 interface{}) *MockProxy_Connect_Call { + return &MockProxy_Connect_Call{Call: _e.mock.On("Connect", _a0, _a1)} } -func (_c *MockProxy_Connect_Call) Run(run func(ctx context.Context, req *milvuspb.ConnectRequest)) *MockProxy_Connect_Call { +func (_c *MockProxy_Connect_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.ConnectRequest)) *MockProxy_Connect_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.ConnectRequest)) }) @@ -362,17 +364,17 @@ func (_c *MockProxy_Connect_Call) RunAndReturn(run func(context.Context, *milvus return _c } -// CreateAlias provides a mock function with given fields: ctx, request -func (_m *MockProxy) CreateAlias(ctx context.Context, request *milvuspb.CreateAliasRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, request) +// CreateAlias provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) CreateAlias(_a0 context.Context, _a1 *milvuspb.CreateAliasRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateAliasRequest) (*commonpb.Status, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateAliasRequest) *commonpb.Status); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -380,7 +382,7 @@ func (_m *MockProxy) CreateAlias(ctx context.Context, request *milvuspb.CreateAl } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CreateAliasRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } 
else { r1 = ret.Error(1) } @@ -394,13 +396,13 @@ type MockProxy_CreateAlias_Call struct { } // CreateAlias is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.CreateAliasRequest -func (_e *MockProxy_Expecter) CreateAlias(ctx interface{}, request interface{}) *MockProxy_CreateAlias_Call { - return &MockProxy_CreateAlias_Call{Call: _e.mock.On("CreateAlias", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.CreateAliasRequest +func (_e *MockProxy_Expecter) CreateAlias(_a0 interface{}, _a1 interface{}) *MockProxy_CreateAlias_Call { + return &MockProxy_CreateAlias_Call{Call: _e.mock.On("CreateAlias", _a0, _a1)} } -func (_c *MockProxy_CreateAlias_Call) Run(run func(ctx context.Context, request *milvuspb.CreateAliasRequest)) *MockProxy_CreateAlias_Call { +func (_c *MockProxy_CreateAlias_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.CreateAliasRequest)) *MockProxy_CreateAlias_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.CreateAliasRequest)) }) @@ -417,17 +419,17 @@ func (_c *MockProxy_CreateAlias_Call) RunAndReturn(run func(context.Context, *mi return _c } -// CreateCollection provides a mock function with given fields: ctx, request -func (_m *MockProxy) CreateCollection(ctx context.Context, request *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, request) +// CreateCollection provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) CreateCollection(_a0 context.Context, _a1 *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateCollectionRequest) (*commonpb.Status, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateCollectionRequest) *commonpb.Status); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) 
} else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -435,7 +437,7 @@ func (_m *MockProxy) CreateCollection(ctx context.Context, request *milvuspb.Cre } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CreateCollectionRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -449,13 +451,13 @@ type MockProxy_CreateCollection_Call struct { } // CreateCollection is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.CreateCollectionRequest -func (_e *MockProxy_Expecter) CreateCollection(ctx interface{}, request interface{}) *MockProxy_CreateCollection_Call { - return &MockProxy_CreateCollection_Call{Call: _e.mock.On("CreateCollection", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.CreateCollectionRequest +func (_e *MockProxy_Expecter) CreateCollection(_a0 interface{}, _a1 interface{}) *MockProxy_CreateCollection_Call { + return &MockProxy_CreateCollection_Call{Call: _e.mock.On("CreateCollection", _a0, _a1)} } -func (_c *MockProxy_CreateCollection_Call) Run(run func(ctx context.Context, request *milvuspb.CreateCollectionRequest)) *MockProxy_CreateCollection_Call { +func (_c *MockProxy_CreateCollection_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.CreateCollectionRequest)) *MockProxy_CreateCollection_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.CreateCollectionRequest)) }) @@ -472,17 +474,17 @@ func (_c *MockProxy_CreateCollection_Call) RunAndReturn(run func(context.Context return _c } -// CreateCredential provides a mock function with given fields: ctx, req -func (_m *MockProxy) CreateCredential(ctx context.Context, req *milvuspb.CreateCredentialRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// CreateCredential provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) CreateCredential(_a0 context.Context, _a1 *milvuspb.CreateCredentialRequest) 
(*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateCredentialRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateCredentialRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -490,7 +492,7 @@ func (_m *MockProxy) CreateCredential(ctx context.Context, req *milvuspb.CreateC } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CreateCredentialRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -504,13 +506,13 @@ type MockProxy_CreateCredential_Call struct { } // CreateCredential is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.CreateCredentialRequest -func (_e *MockProxy_Expecter) CreateCredential(ctx interface{}, req interface{}) *MockProxy_CreateCredential_Call { - return &MockProxy_CreateCredential_Call{Call: _e.mock.On("CreateCredential", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.CreateCredentialRequest +func (_e *MockProxy_Expecter) CreateCredential(_a0 interface{}, _a1 interface{}) *MockProxy_CreateCredential_Call { + return &MockProxy_CreateCredential_Call{Call: _e.mock.On("CreateCredential", _a0, _a1)} } -func (_c *MockProxy_CreateCredential_Call) Run(run func(ctx context.Context, req *milvuspb.CreateCredentialRequest)) *MockProxy_CreateCredential_Call { +func (_c *MockProxy_CreateCredential_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.CreateCredentialRequest)) *MockProxy_CreateCredential_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.CreateCredentialRequest)) }) @@ -527,17 +529,17 @@ func (_c *MockProxy_CreateCredential_Call) RunAndReturn(run func(context.Context return _c } -// CreateDatabase provides a mock 
function with given fields: ctx, req -func (_m *MockProxy) CreateDatabase(ctx context.Context, req *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// CreateDatabase provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) CreateDatabase(_a0 context.Context, _a1 *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateDatabaseRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -545,7 +547,7 @@ func (_m *MockProxy) CreateDatabase(ctx context.Context, req *milvuspb.CreateDat } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CreateDatabaseRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -559,13 +561,13 @@ type MockProxy_CreateDatabase_Call struct { } // CreateDatabase is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.CreateDatabaseRequest -func (_e *MockProxy_Expecter) CreateDatabase(ctx interface{}, req interface{}) *MockProxy_CreateDatabase_Call { - return &MockProxy_CreateDatabase_Call{Call: _e.mock.On("CreateDatabase", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.CreateDatabaseRequest +func (_e *MockProxy_Expecter) CreateDatabase(_a0 interface{}, _a1 interface{}) *MockProxy_CreateDatabase_Call { + return &MockProxy_CreateDatabase_Call{Call: _e.mock.On("CreateDatabase", _a0, _a1)} } -func (_c *MockProxy_CreateDatabase_Call) Run(run func(ctx context.Context, req *milvuspb.CreateDatabaseRequest)) *MockProxy_CreateDatabase_Call { +func (_c *MockProxy_CreateDatabase_Call) Run(run func(_a0 context.Context, _a1 
*milvuspb.CreateDatabaseRequest)) *MockProxy_CreateDatabase_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.CreateDatabaseRequest)) }) @@ -582,17 +584,17 @@ func (_c *MockProxy_CreateDatabase_Call) RunAndReturn(run func(context.Context, return _c } -// CreateIndex provides a mock function with given fields: ctx, request -func (_m *MockProxy) CreateIndex(ctx context.Context, request *milvuspb.CreateIndexRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, request) +// CreateIndex provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) CreateIndex(_a0 context.Context, _a1 *milvuspb.CreateIndexRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateIndexRequest) (*commonpb.Status, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateIndexRequest) *commonpb.Status); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -600,7 +602,7 @@ func (_m *MockProxy) CreateIndex(ctx context.Context, request *milvuspb.CreateIn } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CreateIndexRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -614,13 +616,13 @@ type MockProxy_CreateIndex_Call struct { } // CreateIndex is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.CreateIndexRequest -func (_e *MockProxy_Expecter) CreateIndex(ctx interface{}, request interface{}) *MockProxy_CreateIndex_Call { - return &MockProxy_CreateIndex_Call{Call: _e.mock.On("CreateIndex", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.CreateIndexRequest +func (_e *MockProxy_Expecter) CreateIndex(_a0 interface{}, _a1 interface{}) *MockProxy_CreateIndex_Call { + return 
&MockProxy_CreateIndex_Call{Call: _e.mock.On("CreateIndex", _a0, _a1)} } -func (_c *MockProxy_CreateIndex_Call) Run(run func(ctx context.Context, request *milvuspb.CreateIndexRequest)) *MockProxy_CreateIndex_Call { +func (_c *MockProxy_CreateIndex_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.CreateIndexRequest)) *MockProxy_CreateIndex_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.CreateIndexRequest)) }) @@ -637,17 +639,17 @@ func (_c *MockProxy_CreateIndex_Call) RunAndReturn(run func(context.Context, *mi return _c } -// CreatePartition provides a mock function with given fields: ctx, request -func (_m *MockProxy) CreatePartition(ctx context.Context, request *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, request) +// CreatePartition provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) CreatePartition(_a0 context.Context, _a1 *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreatePartitionRequest) (*commonpb.Status, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreatePartitionRequest) *commonpb.Status); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -655,7 +657,7 @@ func (_m *MockProxy) CreatePartition(ctx context.Context, request *milvuspb.Crea } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CreatePartitionRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -669,13 +671,13 @@ type MockProxy_CreatePartition_Call struct { } // CreatePartition is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.CreatePartitionRequest -func (_e *MockProxy_Expecter) CreatePartition(ctx 
interface{}, request interface{}) *MockProxy_CreatePartition_Call { - return &MockProxy_CreatePartition_Call{Call: _e.mock.On("CreatePartition", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.CreatePartitionRequest +func (_e *MockProxy_Expecter) CreatePartition(_a0 interface{}, _a1 interface{}) *MockProxy_CreatePartition_Call { + return &MockProxy_CreatePartition_Call{Call: _e.mock.On("CreatePartition", _a0, _a1)} } -func (_c *MockProxy_CreatePartition_Call) Run(run func(ctx context.Context, request *milvuspb.CreatePartitionRequest)) *MockProxy_CreatePartition_Call { +func (_c *MockProxy_CreatePartition_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.CreatePartitionRequest)) *MockProxy_CreatePartition_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.CreatePartitionRequest)) }) @@ -692,17 +694,17 @@ func (_c *MockProxy_CreatePartition_Call) RunAndReturn(run func(context.Context, return _c } -// CreateResourceGroup provides a mock function with given fields: ctx, req -func (_m *MockProxy) CreateResourceGroup(ctx context.Context, req *milvuspb.CreateResourceGroupRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// CreateResourceGroup provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) CreateResourceGroup(_a0 context.Context, _a1 *milvuspb.CreateResourceGroupRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateResourceGroupRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateResourceGroupRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -710,7 +712,7 @@ func (_m *MockProxy) CreateResourceGroup(ctx context.Context, req *milvuspb.Crea } if rf, ok := 
ret.Get(1).(func(context.Context, *milvuspb.CreateResourceGroupRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -724,13 +726,13 @@ type MockProxy_CreateResourceGroup_Call struct { } // CreateResourceGroup is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.CreateResourceGroupRequest -func (_e *MockProxy_Expecter) CreateResourceGroup(ctx interface{}, req interface{}) *MockProxy_CreateResourceGroup_Call { - return &MockProxy_CreateResourceGroup_Call{Call: _e.mock.On("CreateResourceGroup", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.CreateResourceGroupRequest +func (_e *MockProxy_Expecter) CreateResourceGroup(_a0 interface{}, _a1 interface{}) *MockProxy_CreateResourceGroup_Call { + return &MockProxy_CreateResourceGroup_Call{Call: _e.mock.On("CreateResourceGroup", _a0, _a1)} } -func (_c *MockProxy_CreateResourceGroup_Call) Run(run func(ctx context.Context, req *milvuspb.CreateResourceGroupRequest)) *MockProxy_CreateResourceGroup_Call { +func (_c *MockProxy_CreateResourceGroup_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.CreateResourceGroupRequest)) *MockProxy_CreateResourceGroup_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.CreateResourceGroupRequest)) }) @@ -747,17 +749,17 @@ func (_c *MockProxy_CreateResourceGroup_Call) RunAndReturn(run func(context.Cont return _c } -// CreateRole provides a mock function with given fields: ctx, req -func (_m *MockProxy) CreateRole(ctx context.Context, req *milvuspb.CreateRoleRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// CreateRole provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) CreateRole(_a0 context.Context, _a1 *milvuspb.CreateRoleRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateRoleRequest) (*commonpb.Status, 
error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateRoleRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -765,7 +767,7 @@ func (_m *MockProxy) CreateRole(ctx context.Context, req *milvuspb.CreateRoleReq } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CreateRoleRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -779,13 +781,13 @@ type MockProxy_CreateRole_Call struct { } // CreateRole is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.CreateRoleRequest -func (_e *MockProxy_Expecter) CreateRole(ctx interface{}, req interface{}) *MockProxy_CreateRole_Call { - return &MockProxy_CreateRole_Call{Call: _e.mock.On("CreateRole", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.CreateRoleRequest +func (_e *MockProxy_Expecter) CreateRole(_a0 interface{}, _a1 interface{}) *MockProxy_CreateRole_Call { + return &MockProxy_CreateRole_Call{Call: _e.mock.On("CreateRole", _a0, _a1)} } -func (_c *MockProxy_CreateRole_Call) Run(run func(ctx context.Context, req *milvuspb.CreateRoleRequest)) *MockProxy_CreateRole_Call { +func (_c *MockProxy_CreateRole_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.CreateRoleRequest)) *MockProxy_CreateRole_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.CreateRoleRequest)) }) @@ -802,17 +804,17 @@ func (_c *MockProxy_CreateRole_Call) RunAndReturn(run func(context.Context, *mil return _c } -// Delete provides a mock function with given fields: ctx, request -func (_m *MockProxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest) (*milvuspb.MutationResult, error) { - ret := _m.Called(ctx, request) +// Delete provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) Delete(_a0 context.Context, _a1 
*milvuspb.DeleteRequest) (*milvuspb.MutationResult, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.MutationResult var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DeleteRequest) (*milvuspb.MutationResult, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DeleteRequest) *milvuspb.MutationResult); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.MutationResult) @@ -820,7 +822,7 @@ func (_m *MockProxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DeleteRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -834,13 +836,13 @@ type MockProxy_Delete_Call struct { } // Delete is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.DeleteRequest -func (_e *MockProxy_Expecter) Delete(ctx interface{}, request interface{}) *MockProxy_Delete_Call { - return &MockProxy_Delete_Call{Call: _e.mock.On("Delete", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.DeleteRequest +func (_e *MockProxy_Expecter) Delete(_a0 interface{}, _a1 interface{}) *MockProxy_Delete_Call { + return &MockProxy_Delete_Call{Call: _e.mock.On("Delete", _a0, _a1)} } -func (_c *MockProxy_Delete_Call) Run(run func(ctx context.Context, request *milvuspb.DeleteRequest)) *MockProxy_Delete_Call { +func (_c *MockProxy_Delete_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.DeleteRequest)) *MockProxy_Delete_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.DeleteRequest)) }) @@ -857,17 +859,17 @@ func (_c *MockProxy_Delete_Call) RunAndReturn(run func(context.Context, *milvusp return _c } -// DeleteCredential provides a mock function with given fields: ctx, req -func (_m *MockProxy) DeleteCredential(ctx context.Context, req 
*milvuspb.DeleteCredentialRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// DeleteCredential provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) DeleteCredential(_a0 context.Context, _a1 *milvuspb.DeleteCredentialRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DeleteCredentialRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DeleteCredentialRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -875,7 +877,7 @@ func (_m *MockProxy) DeleteCredential(ctx context.Context, req *milvuspb.DeleteC } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DeleteCredentialRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -889,13 +891,13 @@ type MockProxy_DeleteCredential_Call struct { } // DeleteCredential is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.DeleteCredentialRequest -func (_e *MockProxy_Expecter) DeleteCredential(ctx interface{}, req interface{}) *MockProxy_DeleteCredential_Call { - return &MockProxy_DeleteCredential_Call{Call: _e.mock.On("DeleteCredential", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.DeleteCredentialRequest +func (_e *MockProxy_Expecter) DeleteCredential(_a0 interface{}, _a1 interface{}) *MockProxy_DeleteCredential_Call { + return &MockProxy_DeleteCredential_Call{Call: _e.mock.On("DeleteCredential", _a0, _a1)} } -func (_c *MockProxy_DeleteCredential_Call) Run(run func(ctx context.Context, req *milvuspb.DeleteCredentialRequest)) *MockProxy_DeleteCredential_Call { +func (_c *MockProxy_DeleteCredential_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.DeleteCredentialRequest)) *MockProxy_DeleteCredential_Call { 
_c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.DeleteCredentialRequest)) }) @@ -912,17 +914,72 @@ func (_c *MockProxy_DeleteCredential_Call) RunAndReturn(run func(context.Context return _c } -// DescribeCollection provides a mock function with given fields: ctx, request -func (_m *MockProxy) DescribeCollection(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { - ret := _m.Called(ctx, request) +// DescribeAlias provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) DescribeAlias(_a0 context.Context, _a1 *milvuspb.DescribeAliasRequest) (*milvuspb.DescribeAliasResponse, error) { + ret := _m.Called(_a0, _a1) + + var r0 *milvuspb.DescribeAliasResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DescribeAliasRequest) (*milvuspb.DescribeAliasResponse, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DescribeAliasRequest) *milvuspb.DescribeAliasResponse); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.DescribeAliasResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DescribeAliasRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockProxy_DescribeAlias_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DescribeAlias' +type MockProxy_DescribeAlias_Call struct { + *mock.Call +} + +// DescribeAlias is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *milvuspb.DescribeAliasRequest +func (_e *MockProxy_Expecter) DescribeAlias(_a0 interface{}, _a1 interface{}) *MockProxy_DescribeAlias_Call { + return &MockProxy_DescribeAlias_Call{Call: _e.mock.On("DescribeAlias", _a0, _a1)} +} + +func (_c *MockProxy_DescribeAlias_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.DescribeAliasRequest)) 
*MockProxy_DescribeAlias_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.DescribeAliasRequest)) + }) + return _c +} + +func (_c *MockProxy_DescribeAlias_Call) Return(_a0 *milvuspb.DescribeAliasResponse, _a1 error) *MockProxy_DescribeAlias_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockProxy_DescribeAlias_Call) RunAndReturn(run func(context.Context, *milvuspb.DescribeAliasRequest) (*milvuspb.DescribeAliasResponse, error)) *MockProxy_DescribeAlias_Call { + _c.Call.Return(run) + return _c +} + +// DescribeCollection provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) DescribeCollection(_a0 context.Context, _a1 *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.DescribeCollectionResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DescribeCollectionRequest) *milvuspb.DescribeCollectionResponse); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.DescribeCollectionResponse) @@ -930,7 +987,7 @@ func (_m *MockProxy) DescribeCollection(ctx context.Context, request *milvuspb.D } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DescribeCollectionRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -944,13 +1001,13 @@ type MockProxy_DescribeCollection_Call struct { } // DescribeCollection is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.DescribeCollectionRequest -func (_e *MockProxy_Expecter) DescribeCollection(ctx interface{}, request interface{}) *MockProxy_DescribeCollection_Call { - return &MockProxy_DescribeCollection_Call{Call: 
_e.mock.On("DescribeCollection", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.DescribeCollectionRequest +func (_e *MockProxy_Expecter) DescribeCollection(_a0 interface{}, _a1 interface{}) *MockProxy_DescribeCollection_Call { + return &MockProxy_DescribeCollection_Call{Call: _e.mock.On("DescribeCollection", _a0, _a1)} } -func (_c *MockProxy_DescribeCollection_Call) Run(run func(ctx context.Context, request *milvuspb.DescribeCollectionRequest)) *MockProxy_DescribeCollection_Call { +func (_c *MockProxy_DescribeCollection_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.DescribeCollectionRequest)) *MockProxy_DescribeCollection_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.DescribeCollectionRequest)) }) @@ -967,17 +1024,17 @@ func (_c *MockProxy_DescribeCollection_Call) RunAndReturn(run func(context.Conte return _c } -// DescribeIndex provides a mock function with given fields: ctx, request -func (_m *MockProxy) DescribeIndex(ctx context.Context, request *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error) { - ret := _m.Called(ctx, request) +// DescribeIndex provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) DescribeIndex(_a0 context.Context, _a1 *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.DescribeIndexResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DescribeIndexRequest) *milvuspb.DescribeIndexResponse); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.DescribeIndexResponse) @@ -985,7 +1042,7 @@ func (_m *MockProxy) DescribeIndex(ctx context.Context, request *milvuspb.Descri } if rf, ok := 
ret.Get(1).(func(context.Context, *milvuspb.DescribeIndexRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -999,13 +1056,13 @@ type MockProxy_DescribeIndex_Call struct { } // DescribeIndex is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.DescribeIndexRequest -func (_e *MockProxy_Expecter) DescribeIndex(ctx interface{}, request interface{}) *MockProxy_DescribeIndex_Call { - return &MockProxy_DescribeIndex_Call{Call: _e.mock.On("DescribeIndex", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.DescribeIndexRequest +func (_e *MockProxy_Expecter) DescribeIndex(_a0 interface{}, _a1 interface{}) *MockProxy_DescribeIndex_Call { + return &MockProxy_DescribeIndex_Call{Call: _e.mock.On("DescribeIndex", _a0, _a1)} } -func (_c *MockProxy_DescribeIndex_Call) Run(run func(ctx context.Context, request *milvuspb.DescribeIndexRequest)) *MockProxy_DescribeIndex_Call { +func (_c *MockProxy_DescribeIndex_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.DescribeIndexRequest)) *MockProxy_DescribeIndex_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.DescribeIndexRequest)) }) @@ -1022,17 +1079,17 @@ func (_c *MockProxy_DescribeIndex_Call) RunAndReturn(run func(context.Context, * return _c } -// DescribeResourceGroup provides a mock function with given fields: ctx, req -func (_m *MockProxy) DescribeResourceGroup(ctx context.Context, req *milvuspb.DescribeResourceGroupRequest) (*milvuspb.DescribeResourceGroupResponse, error) { - ret := _m.Called(ctx, req) +// DescribeResourceGroup provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) DescribeResourceGroup(_a0 context.Context, _a1 *milvuspb.DescribeResourceGroupRequest) (*milvuspb.DescribeResourceGroupResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.DescribeResourceGroupResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, 
*milvuspb.DescribeResourceGroupRequest) (*milvuspb.DescribeResourceGroupResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DescribeResourceGroupRequest) *milvuspb.DescribeResourceGroupResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.DescribeResourceGroupResponse) @@ -1040,7 +1097,7 @@ func (_m *MockProxy) DescribeResourceGroup(ctx context.Context, req *milvuspb.De } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DescribeResourceGroupRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1054,13 +1111,13 @@ type MockProxy_DescribeResourceGroup_Call struct { } // DescribeResourceGroup is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.DescribeResourceGroupRequest -func (_e *MockProxy_Expecter) DescribeResourceGroup(ctx interface{}, req interface{}) *MockProxy_DescribeResourceGroup_Call { - return &MockProxy_DescribeResourceGroup_Call{Call: _e.mock.On("DescribeResourceGroup", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.DescribeResourceGroupRequest +func (_e *MockProxy_Expecter) DescribeResourceGroup(_a0 interface{}, _a1 interface{}) *MockProxy_DescribeResourceGroup_Call { + return &MockProxy_DescribeResourceGroup_Call{Call: _e.mock.On("DescribeResourceGroup", _a0, _a1)} } -func (_c *MockProxy_DescribeResourceGroup_Call) Run(run func(ctx context.Context, req *milvuspb.DescribeResourceGroupRequest)) *MockProxy_DescribeResourceGroup_Call { +func (_c *MockProxy_DescribeResourceGroup_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.DescribeResourceGroupRequest)) *MockProxy_DescribeResourceGroup_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.DescribeResourceGroupRequest)) }) @@ -1077,17 +1134,72 @@ func (_c *MockProxy_DescribeResourceGroup_Call) RunAndReturn(run func(context.Co 
return _c } -// DropAlias provides a mock function with given fields: ctx, request -func (_m *MockProxy) DropAlias(ctx context.Context, request *milvuspb.DropAliasRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, request) +// DescribeSegmentIndexData provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) DescribeSegmentIndexData(_a0 context.Context, _a1 *federpb.DescribeSegmentIndexDataRequest) (*federpb.DescribeSegmentIndexDataResponse, error) { + ret := _m.Called(_a0, _a1) + + var r0 *federpb.DescribeSegmentIndexDataResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *federpb.DescribeSegmentIndexDataRequest) (*federpb.DescribeSegmentIndexDataResponse, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *federpb.DescribeSegmentIndexDataRequest) *federpb.DescribeSegmentIndexDataResponse); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*federpb.DescribeSegmentIndexDataResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *federpb.DescribeSegmentIndexDataRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockProxy_DescribeSegmentIndexData_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DescribeSegmentIndexData' +type MockProxy_DescribeSegmentIndexData_Call struct { + *mock.Call +} + +// DescribeSegmentIndexData is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *federpb.DescribeSegmentIndexDataRequest +func (_e *MockProxy_Expecter) DescribeSegmentIndexData(_a0 interface{}, _a1 interface{}) *MockProxy_DescribeSegmentIndexData_Call { + return &MockProxy_DescribeSegmentIndexData_Call{Call: _e.mock.On("DescribeSegmentIndexData", _a0, _a1)} +} + +func (_c *MockProxy_DescribeSegmentIndexData_Call) Run(run func(_a0 context.Context, _a1 *federpb.DescribeSegmentIndexDataRequest)) *MockProxy_DescribeSegmentIndexData_Call 
{ + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*federpb.DescribeSegmentIndexDataRequest)) + }) + return _c +} + +func (_c *MockProxy_DescribeSegmentIndexData_Call) Return(_a0 *federpb.DescribeSegmentIndexDataResponse, _a1 error) *MockProxy_DescribeSegmentIndexData_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockProxy_DescribeSegmentIndexData_Call) RunAndReturn(run func(context.Context, *federpb.DescribeSegmentIndexDataRequest) (*federpb.DescribeSegmentIndexDataResponse, error)) *MockProxy_DescribeSegmentIndexData_Call { + _c.Call.Return(run) + return _c +} + +// DropAlias provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) DropAlias(_a0 context.Context, _a1 *milvuspb.DropAliasRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropAliasRequest) (*commonpb.Status, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropAliasRequest) *commonpb.Status); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -1095,7 +1207,7 @@ func (_m *MockProxy) DropAlias(ctx context.Context, request *milvuspb.DropAliasR } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DropAliasRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1109,13 +1221,13 @@ type MockProxy_DropAlias_Call struct { } // DropAlias is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.DropAliasRequest -func (_e *MockProxy_Expecter) DropAlias(ctx interface{}, request interface{}) *MockProxy_DropAlias_Call { - return &MockProxy_DropAlias_Call{Call: _e.mock.On("DropAlias", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.DropAliasRequest +func (_e *MockProxy_Expecter) DropAlias(_a0 interface{}, 
_a1 interface{}) *MockProxy_DropAlias_Call { + return &MockProxy_DropAlias_Call{Call: _e.mock.On("DropAlias", _a0, _a1)} } -func (_c *MockProxy_DropAlias_Call) Run(run func(ctx context.Context, request *milvuspb.DropAliasRequest)) *MockProxy_DropAlias_Call { +func (_c *MockProxy_DropAlias_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.DropAliasRequest)) *MockProxy_DropAlias_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.DropAliasRequest)) }) @@ -1132,17 +1244,17 @@ func (_c *MockProxy_DropAlias_Call) RunAndReturn(run func(context.Context, *milv return _c } -// DropCollection provides a mock function with given fields: ctx, request -func (_m *MockProxy) DropCollection(ctx context.Context, request *milvuspb.DropCollectionRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, request) +// DropCollection provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) DropCollection(_a0 context.Context, _a1 *milvuspb.DropCollectionRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropCollectionRequest) (*commonpb.Status, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropCollectionRequest) *commonpb.Status); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -1150,7 +1262,7 @@ func (_m *MockProxy) DropCollection(ctx context.Context, request *milvuspb.DropC } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DropCollectionRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1164,13 +1276,13 @@ type MockProxy_DropCollection_Call struct { } // DropCollection is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.DropCollectionRequest -func (_e *MockProxy_Expecter) 
DropCollection(ctx interface{}, request interface{}) *MockProxy_DropCollection_Call { - return &MockProxy_DropCollection_Call{Call: _e.mock.On("DropCollection", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.DropCollectionRequest +func (_e *MockProxy_Expecter) DropCollection(_a0 interface{}, _a1 interface{}) *MockProxy_DropCollection_Call { + return &MockProxy_DropCollection_Call{Call: _e.mock.On("DropCollection", _a0, _a1)} } -func (_c *MockProxy_DropCollection_Call) Run(run func(ctx context.Context, request *milvuspb.DropCollectionRequest)) *MockProxy_DropCollection_Call { +func (_c *MockProxy_DropCollection_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.DropCollectionRequest)) *MockProxy_DropCollection_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.DropCollectionRequest)) }) @@ -1187,17 +1299,17 @@ func (_c *MockProxy_DropCollection_Call) RunAndReturn(run func(context.Context, return _c } -// DropDatabase provides a mock function with given fields: ctx, req -func (_m *MockProxy) DropDatabase(ctx context.Context, req *milvuspb.DropDatabaseRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// DropDatabase provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) DropDatabase(_a0 context.Context, _a1 *milvuspb.DropDatabaseRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropDatabaseRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropDatabaseRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -1205,7 +1317,7 @@ func (_m *MockProxy) DropDatabase(ctx context.Context, req *milvuspb.DropDatabas } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DropDatabaseRequest) 
error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1219,13 +1331,13 @@ type MockProxy_DropDatabase_Call struct { } // DropDatabase is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.DropDatabaseRequest -func (_e *MockProxy_Expecter) DropDatabase(ctx interface{}, req interface{}) *MockProxy_DropDatabase_Call { - return &MockProxy_DropDatabase_Call{Call: _e.mock.On("DropDatabase", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.DropDatabaseRequest +func (_e *MockProxy_Expecter) DropDatabase(_a0 interface{}, _a1 interface{}) *MockProxy_DropDatabase_Call { + return &MockProxy_DropDatabase_Call{Call: _e.mock.On("DropDatabase", _a0, _a1)} } -func (_c *MockProxy_DropDatabase_Call) Run(run func(ctx context.Context, req *milvuspb.DropDatabaseRequest)) *MockProxy_DropDatabase_Call { +func (_c *MockProxy_DropDatabase_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.DropDatabaseRequest)) *MockProxy_DropDatabase_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.DropDatabaseRequest)) }) @@ -1242,17 +1354,17 @@ func (_c *MockProxy_DropDatabase_Call) RunAndReturn(run func(context.Context, *m return _c } -// DropIndex provides a mock function with given fields: ctx, request -func (_m *MockProxy) DropIndex(ctx context.Context, request *milvuspb.DropIndexRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, request) +// DropIndex provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) DropIndex(_a0 context.Context, _a1 *milvuspb.DropIndexRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropIndexRequest) (*commonpb.Status, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropIndexRequest) *commonpb.Status); ok { - r0 = rf(ctx, request) + r0 = 
rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -1260,7 +1372,7 @@ func (_m *MockProxy) DropIndex(ctx context.Context, request *milvuspb.DropIndexR } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DropIndexRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1274,13 +1386,13 @@ type MockProxy_DropIndex_Call struct { } // DropIndex is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.DropIndexRequest -func (_e *MockProxy_Expecter) DropIndex(ctx interface{}, request interface{}) *MockProxy_DropIndex_Call { - return &MockProxy_DropIndex_Call{Call: _e.mock.On("DropIndex", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.DropIndexRequest +func (_e *MockProxy_Expecter) DropIndex(_a0 interface{}, _a1 interface{}) *MockProxy_DropIndex_Call { + return &MockProxy_DropIndex_Call{Call: _e.mock.On("DropIndex", _a0, _a1)} } -func (_c *MockProxy_DropIndex_Call) Run(run func(ctx context.Context, request *milvuspb.DropIndexRequest)) *MockProxy_DropIndex_Call { +func (_c *MockProxy_DropIndex_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.DropIndexRequest)) *MockProxy_DropIndex_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.DropIndexRequest)) }) @@ -1297,17 +1409,17 @@ func (_c *MockProxy_DropIndex_Call) RunAndReturn(run func(context.Context, *milv return _c } -// DropPartition provides a mock function with given fields: ctx, request -func (_m *MockProxy) DropPartition(ctx context.Context, request *milvuspb.DropPartitionRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, request) +// DropPartition provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) DropPartition(_a0 context.Context, _a1 *milvuspb.DropPartitionRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, 
*milvuspb.DropPartitionRequest) (*commonpb.Status, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropPartitionRequest) *commonpb.Status); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -1315,7 +1427,7 @@ func (_m *MockProxy) DropPartition(ctx context.Context, request *milvuspb.DropPa } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DropPartitionRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1329,13 +1441,13 @@ type MockProxy_DropPartition_Call struct { } // DropPartition is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.DropPartitionRequest -func (_e *MockProxy_Expecter) DropPartition(ctx interface{}, request interface{}) *MockProxy_DropPartition_Call { - return &MockProxy_DropPartition_Call{Call: _e.mock.On("DropPartition", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.DropPartitionRequest +func (_e *MockProxy_Expecter) DropPartition(_a0 interface{}, _a1 interface{}) *MockProxy_DropPartition_Call { + return &MockProxy_DropPartition_Call{Call: _e.mock.On("DropPartition", _a0, _a1)} } -func (_c *MockProxy_DropPartition_Call) Run(run func(ctx context.Context, request *milvuspb.DropPartitionRequest)) *MockProxy_DropPartition_Call { +func (_c *MockProxy_DropPartition_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.DropPartitionRequest)) *MockProxy_DropPartition_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.DropPartitionRequest)) }) @@ -1352,17 +1464,17 @@ func (_c *MockProxy_DropPartition_Call) RunAndReturn(run func(context.Context, * return _c } -// DropResourceGroup provides a mock function with given fields: ctx, req -func (_m *MockProxy) DropResourceGroup(ctx context.Context, req *milvuspb.DropResourceGroupRequest) (*commonpb.Status, error) { - ret 
:= _m.Called(ctx, req) +// DropResourceGroup provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) DropResourceGroup(_a0 context.Context, _a1 *milvuspb.DropResourceGroupRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropResourceGroupRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropResourceGroupRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -1370,7 +1482,7 @@ func (_m *MockProxy) DropResourceGroup(ctx context.Context, req *milvuspb.DropRe } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DropResourceGroupRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1384,13 +1496,13 @@ type MockProxy_DropResourceGroup_Call struct { } // DropResourceGroup is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.DropResourceGroupRequest -func (_e *MockProxy_Expecter) DropResourceGroup(ctx interface{}, req interface{}) *MockProxy_DropResourceGroup_Call { - return &MockProxy_DropResourceGroup_Call{Call: _e.mock.On("DropResourceGroup", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.DropResourceGroupRequest +func (_e *MockProxy_Expecter) DropResourceGroup(_a0 interface{}, _a1 interface{}) *MockProxy_DropResourceGroup_Call { + return &MockProxy_DropResourceGroup_Call{Call: _e.mock.On("DropResourceGroup", _a0, _a1)} } -func (_c *MockProxy_DropResourceGroup_Call) Run(run func(ctx context.Context, req *milvuspb.DropResourceGroupRequest)) *MockProxy_DropResourceGroup_Call { +func (_c *MockProxy_DropResourceGroup_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.DropResourceGroupRequest)) *MockProxy_DropResourceGroup_Call { _c.Call.Run(func(args mock.Arguments) { 
run(args[0].(context.Context), args[1].(*milvuspb.DropResourceGroupRequest)) }) @@ -1407,17 +1519,17 @@ func (_c *MockProxy_DropResourceGroup_Call) RunAndReturn(run func(context.Contex return _c } -// DropRole provides a mock function with given fields: ctx, req -func (_m *MockProxy) DropRole(ctx context.Context, req *milvuspb.DropRoleRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// DropRole provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) DropRole(_a0 context.Context, _a1 *milvuspb.DropRoleRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropRoleRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropRoleRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -1425,7 +1537,7 @@ func (_m *MockProxy) DropRole(ctx context.Context, req *milvuspb.DropRoleRequest } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DropRoleRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1439,13 +1551,13 @@ type MockProxy_DropRole_Call struct { } // DropRole is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.DropRoleRequest -func (_e *MockProxy_Expecter) DropRole(ctx interface{}, req interface{}) *MockProxy_DropRole_Call { - return &MockProxy_DropRole_Call{Call: _e.mock.On("DropRole", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.DropRoleRequest +func (_e *MockProxy_Expecter) DropRole(_a0 interface{}, _a1 interface{}) *MockProxy_DropRole_Call { + return &MockProxy_DropRole_Call{Call: _e.mock.On("DropRole", _a0, _a1)} } -func (_c *MockProxy_DropRole_Call) Run(run func(ctx context.Context, req *milvuspb.DropRoleRequest)) *MockProxy_DropRole_Call { +func (_c 
*MockProxy_DropRole_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.DropRoleRequest)) *MockProxy_DropRole_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.DropRoleRequest)) }) @@ -1462,17 +1574,17 @@ func (_c *MockProxy_DropRole_Call) RunAndReturn(run func(context.Context, *milvu return _c } -// Dummy provides a mock function with given fields: ctx, request -func (_m *MockProxy) Dummy(ctx context.Context, request *milvuspb.DummyRequest) (*milvuspb.DummyResponse, error) { - ret := _m.Called(ctx, request) +// Dummy provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) Dummy(_a0 context.Context, _a1 *milvuspb.DummyRequest) (*milvuspb.DummyResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.DummyResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DummyRequest) (*milvuspb.DummyResponse, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DummyRequest) *milvuspb.DummyResponse); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.DummyResponse) @@ -1480,7 +1592,7 @@ func (_m *MockProxy) Dummy(ctx context.Context, request *milvuspb.DummyRequest) } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DummyRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1494,13 +1606,13 @@ type MockProxy_Dummy_Call struct { } // Dummy is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.DummyRequest -func (_e *MockProxy_Expecter) Dummy(ctx interface{}, request interface{}) *MockProxy_Dummy_Call { - return &MockProxy_Dummy_Call{Call: _e.mock.On("Dummy", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.DummyRequest +func (_e *MockProxy_Expecter) Dummy(_a0 interface{}, _a1 interface{}) *MockProxy_Dummy_Call { + return &MockProxy_Dummy_Call{Call: 
_e.mock.On("Dummy", _a0, _a1)} } -func (_c *MockProxy_Dummy_Call) Run(run func(ctx context.Context, request *milvuspb.DummyRequest)) *MockProxy_Dummy_Call { +func (_c *MockProxy_Dummy_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.DummyRequest)) *MockProxy_Dummy_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.DummyRequest)) }) @@ -1517,17 +1629,17 @@ func (_c *MockProxy_Dummy_Call) RunAndReturn(run func(context.Context, *milvuspb return _c } -// Flush provides a mock function with given fields: ctx, request -func (_m *MockProxy) Flush(ctx context.Context, request *milvuspb.FlushRequest) (*milvuspb.FlushResponse, error) { - ret := _m.Called(ctx, request) +// Flush provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) Flush(_a0 context.Context, _a1 *milvuspb.FlushRequest) (*milvuspb.FlushResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.FlushResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.FlushRequest) (*milvuspb.FlushResponse, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.FlushRequest) *milvuspb.FlushResponse); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.FlushResponse) @@ -1535,7 +1647,7 @@ func (_m *MockProxy) Flush(ctx context.Context, request *milvuspb.FlushRequest) } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.FlushRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1549,13 +1661,13 @@ type MockProxy_Flush_Call struct { } // Flush is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.FlushRequest -func (_e *MockProxy_Expecter) Flush(ctx interface{}, request interface{}) *MockProxy_Flush_Call { - return &MockProxy_Flush_Call{Call: _e.mock.On("Flush", ctx, request)} +// - _a0 context.Context +// - _a1 
*milvuspb.FlushRequest +func (_e *MockProxy_Expecter) Flush(_a0 interface{}, _a1 interface{}) *MockProxy_Flush_Call { + return &MockProxy_Flush_Call{Call: _e.mock.On("Flush", _a0, _a1)} } -func (_c *MockProxy_Flush_Call) Run(run func(ctx context.Context, request *milvuspb.FlushRequest)) *MockProxy_Flush_Call { +func (_c *MockProxy_Flush_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.FlushRequest)) *MockProxy_Flush_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.FlushRequest)) }) @@ -1572,17 +1684,17 @@ func (_c *MockProxy_Flush_Call) RunAndReturn(run func(context.Context, *milvuspb return _c } -// FlushAll provides a mock function with given fields: ctx, request -func (_m *MockProxy) FlushAll(ctx context.Context, request *milvuspb.FlushAllRequest) (*milvuspb.FlushAllResponse, error) { - ret := _m.Called(ctx, request) +// FlushAll provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) FlushAll(_a0 context.Context, _a1 *milvuspb.FlushAllRequest) (*milvuspb.FlushAllResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.FlushAllResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.FlushAllRequest) (*milvuspb.FlushAllResponse, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.FlushAllRequest) *milvuspb.FlushAllResponse); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.FlushAllResponse) @@ -1590,7 +1702,7 @@ func (_m *MockProxy) FlushAll(ctx context.Context, request *milvuspb.FlushAllReq } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.FlushAllRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1604,13 +1716,13 @@ type MockProxy_FlushAll_Call struct { } // FlushAll is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.FlushAllRequest -func 
(_e *MockProxy_Expecter) FlushAll(ctx interface{}, request interface{}) *MockProxy_FlushAll_Call { - return &MockProxy_FlushAll_Call{Call: _e.mock.On("FlushAll", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.FlushAllRequest +func (_e *MockProxy_Expecter) FlushAll(_a0 interface{}, _a1 interface{}) *MockProxy_FlushAll_Call { + return &MockProxy_FlushAll_Call{Call: _e.mock.On("FlushAll", _a0, _a1)} } -func (_c *MockProxy_FlushAll_Call) Run(run func(ctx context.Context, request *milvuspb.FlushAllRequest)) *MockProxy_FlushAll_Call { +func (_c *MockProxy_FlushAll_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.FlushAllRequest)) *MockProxy_FlushAll_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.FlushAllRequest)) }) @@ -1668,17 +1780,17 @@ func (_c *MockProxy_GetAddress_Call) RunAndReturn(run func() string) *MockProxy_ return _c } -// GetCollectionStatistics provides a mock function with given fields: ctx, request -func (_m *MockProxy) GetCollectionStatistics(ctx context.Context, request *milvuspb.GetCollectionStatisticsRequest) (*milvuspb.GetCollectionStatisticsResponse, error) { - ret := _m.Called(ctx, request) +// GetCollectionStatistics provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) GetCollectionStatistics(_a0 context.Context, _a1 *milvuspb.GetCollectionStatisticsRequest) (*milvuspb.GetCollectionStatisticsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.GetCollectionStatisticsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetCollectionStatisticsRequest) (*milvuspb.GetCollectionStatisticsResponse, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetCollectionStatisticsRequest) *milvuspb.GetCollectionStatisticsResponse); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = 
ret.Get(0).(*milvuspb.GetCollectionStatisticsResponse) @@ -1686,7 +1798,7 @@ func (_m *MockProxy) GetCollectionStatistics(ctx context.Context, request *milvu } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetCollectionStatisticsRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1700,13 +1812,13 @@ type MockProxy_GetCollectionStatistics_Call struct { } // GetCollectionStatistics is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.GetCollectionStatisticsRequest -func (_e *MockProxy_Expecter) GetCollectionStatistics(ctx interface{}, request interface{}) *MockProxy_GetCollectionStatistics_Call { - return &MockProxy_GetCollectionStatistics_Call{Call: _e.mock.On("GetCollectionStatistics", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.GetCollectionStatisticsRequest +func (_e *MockProxy_Expecter) GetCollectionStatistics(_a0 interface{}, _a1 interface{}) *MockProxy_GetCollectionStatistics_Call { + return &MockProxy_GetCollectionStatistics_Call{Call: _e.mock.On("GetCollectionStatistics", _a0, _a1)} } -func (_c *MockProxy_GetCollectionStatistics_Call) Run(run func(ctx context.Context, request *milvuspb.GetCollectionStatisticsRequest)) *MockProxy_GetCollectionStatistics_Call { +func (_c *MockProxy_GetCollectionStatistics_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetCollectionStatisticsRequest)) *MockProxy_GetCollectionStatistics_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.GetCollectionStatisticsRequest)) }) @@ -1723,17 +1835,17 @@ func (_c *MockProxy_GetCollectionStatistics_Call) RunAndReturn(run func(context. 
return _c } -// GetCompactionState provides a mock function with given fields: ctx, req -func (_m *MockProxy) GetCompactionState(ctx context.Context, req *milvuspb.GetCompactionStateRequest) (*milvuspb.GetCompactionStateResponse, error) { - ret := _m.Called(ctx, req) +// GetCompactionState provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) GetCompactionState(_a0 context.Context, _a1 *milvuspb.GetCompactionStateRequest) (*milvuspb.GetCompactionStateResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.GetCompactionStateResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetCompactionStateRequest) (*milvuspb.GetCompactionStateResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetCompactionStateRequest) *milvuspb.GetCompactionStateResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.GetCompactionStateResponse) @@ -1741,7 +1853,7 @@ func (_m *MockProxy) GetCompactionState(ctx context.Context, req *milvuspb.GetCo } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetCompactionStateRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1755,13 +1867,13 @@ type MockProxy_GetCompactionState_Call struct { } // GetCompactionState is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.GetCompactionStateRequest -func (_e *MockProxy_Expecter) GetCompactionState(ctx interface{}, req interface{}) *MockProxy_GetCompactionState_Call { - return &MockProxy_GetCompactionState_Call{Call: _e.mock.On("GetCompactionState", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.GetCompactionStateRequest +func (_e *MockProxy_Expecter) GetCompactionState(_a0 interface{}, _a1 interface{}) *MockProxy_GetCompactionState_Call { + return &MockProxy_GetCompactionState_Call{Call: _e.mock.On("GetCompactionState", _a0, 
_a1)} } -func (_c *MockProxy_GetCompactionState_Call) Run(run func(ctx context.Context, req *milvuspb.GetCompactionStateRequest)) *MockProxy_GetCompactionState_Call { +func (_c *MockProxy_GetCompactionState_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetCompactionStateRequest)) *MockProxy_GetCompactionState_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.GetCompactionStateRequest)) }) @@ -1778,17 +1890,17 @@ func (_c *MockProxy_GetCompactionState_Call) RunAndReturn(run func(context.Conte return _c } -// GetCompactionStateWithPlans provides a mock function with given fields: ctx, req -func (_m *MockProxy) GetCompactionStateWithPlans(ctx context.Context, req *milvuspb.GetCompactionPlansRequest) (*milvuspb.GetCompactionPlansResponse, error) { - ret := _m.Called(ctx, req) +// GetCompactionStateWithPlans provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) GetCompactionStateWithPlans(_a0 context.Context, _a1 *milvuspb.GetCompactionPlansRequest) (*milvuspb.GetCompactionPlansResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.GetCompactionPlansResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetCompactionPlansRequest) (*milvuspb.GetCompactionPlansResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetCompactionPlansRequest) *milvuspb.GetCompactionPlansResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.GetCompactionPlansResponse) @@ -1796,7 +1908,7 @@ func (_m *MockProxy) GetCompactionStateWithPlans(ctx context.Context, req *milvu } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetCompactionPlansRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1810,13 +1922,13 @@ type MockProxy_GetCompactionStateWithPlans_Call struct { } // GetCompactionStateWithPlans is a 
helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.GetCompactionPlansRequest -func (_e *MockProxy_Expecter) GetCompactionStateWithPlans(ctx interface{}, req interface{}) *MockProxy_GetCompactionStateWithPlans_Call { - return &MockProxy_GetCompactionStateWithPlans_Call{Call: _e.mock.On("GetCompactionStateWithPlans", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.GetCompactionPlansRequest +func (_e *MockProxy_Expecter) GetCompactionStateWithPlans(_a0 interface{}, _a1 interface{}) *MockProxy_GetCompactionStateWithPlans_Call { + return &MockProxy_GetCompactionStateWithPlans_Call{Call: _e.mock.On("GetCompactionStateWithPlans", _a0, _a1)} } -func (_c *MockProxy_GetCompactionStateWithPlans_Call) Run(run func(ctx context.Context, req *milvuspb.GetCompactionPlansRequest)) *MockProxy_GetCompactionStateWithPlans_Call { +func (_c *MockProxy_GetCompactionStateWithPlans_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetCompactionPlansRequest)) *MockProxy_GetCompactionStateWithPlans_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.GetCompactionPlansRequest)) }) @@ -1833,25 +1945,25 @@ func (_c *MockProxy_GetCompactionStateWithPlans_Call) RunAndReturn(run func(cont return _c } -// GetComponentStates provides a mock function with given fields: ctx -func (_m *MockProxy) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { - ret := _m.Called(ctx) +// GetComponentStates provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) GetComponentStates(_a0 context.Context, _a1 *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.ComponentStates var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (*milvuspb.ComponentStates, error)); ok { - return rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error)); ok { + 
return rf(_a0, _a1) } - if rf, ok := ret.Get(0).(func(context.Context) *milvuspb.ComponentStates); ok { - r0 = rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetComponentStatesRequest) *milvuspb.ComponentStates); ok { + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.ComponentStates) } } - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetComponentStatesRequest) error); ok { + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1865,14 +1977,15 @@ type MockProxy_GetComponentStates_Call struct { } // GetComponentStates is a helper method to define mock.On call -// - ctx context.Context -func (_e *MockProxy_Expecter) GetComponentStates(ctx interface{}) *MockProxy_GetComponentStates_Call { - return &MockProxy_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", ctx)} +// - _a0 context.Context +// - _a1 *milvuspb.GetComponentStatesRequest +func (_e *MockProxy_Expecter) GetComponentStates(_a0 interface{}, _a1 interface{}) *MockProxy_GetComponentStates_Call { + return &MockProxy_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", _a0, _a1)} } -func (_c *MockProxy_GetComponentStates_Call) Run(run func(ctx context.Context)) *MockProxy_GetComponentStates_Call { +func (_c *MockProxy_GetComponentStates_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetComponentStatesRequest)) *MockProxy_GetComponentStates_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context)) + run(args[0].(context.Context), args[1].(*milvuspb.GetComponentStatesRequest)) }) return _c } @@ -1882,22 +1995,22 @@ func (_c *MockProxy_GetComponentStates_Call) Return(_a0 *milvuspb.ComponentState return _c } -func (_c *MockProxy_GetComponentStates_Call) RunAndReturn(run func(context.Context) (*milvuspb.ComponentStates, error)) *MockProxy_GetComponentStates_Call { +func (_c *MockProxy_GetComponentStates_Call) RunAndReturn(run 
func(context.Context, *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error)) *MockProxy_GetComponentStates_Call { _c.Call.Return(run) return _c } -// GetDdChannel provides a mock function with given fields: ctx, request -func (_m *MockProxy) GetDdChannel(ctx context.Context, request *internalpb.GetDdChannelRequest) (*milvuspb.StringResponse, error) { - ret := _m.Called(ctx, request) +// GetDdChannel provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) GetDdChannel(_a0 context.Context, _a1 *internalpb.GetDdChannelRequest) (*milvuspb.StringResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.StringResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetDdChannelRequest) (*milvuspb.StringResponse, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetDdChannelRequest) *milvuspb.StringResponse); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.StringResponse) @@ -1905,7 +2018,7 @@ func (_m *MockProxy) GetDdChannel(ctx context.Context, request *internalpb.GetDd } if rf, ok := ret.Get(1).(func(context.Context, *internalpb.GetDdChannelRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1919,13 +2032,13 @@ type MockProxy_GetDdChannel_Call struct { } // GetDdChannel is a helper method to define mock.On call -// - ctx context.Context -// - request *internalpb.GetDdChannelRequest -func (_e *MockProxy_Expecter) GetDdChannel(ctx interface{}, request interface{}) *MockProxy_GetDdChannel_Call { - return &MockProxy_GetDdChannel_Call{Call: _e.mock.On("GetDdChannel", ctx, request)} +// - _a0 context.Context +// - _a1 *internalpb.GetDdChannelRequest +func (_e *MockProxy_Expecter) GetDdChannel(_a0 interface{}, _a1 interface{}) *MockProxy_GetDdChannel_Call { + return &MockProxy_GetDdChannel_Call{Call: _e.mock.On("GetDdChannel", _a0, 
_a1)} } -func (_c *MockProxy_GetDdChannel_Call) Run(run func(ctx context.Context, request *internalpb.GetDdChannelRequest)) *MockProxy_GetDdChannel_Call { +func (_c *MockProxy_GetDdChannel_Call) Run(run func(_a0 context.Context, _a1 *internalpb.GetDdChannelRequest)) *MockProxy_GetDdChannel_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*internalpb.GetDdChannelRequest)) }) @@ -1942,17 +2055,17 @@ func (_c *MockProxy_GetDdChannel_Call) RunAndReturn(run func(context.Context, *i return _c } -// GetFlushAllState provides a mock function with given fields: ctx, req -func (_m *MockProxy) GetFlushAllState(ctx context.Context, req *milvuspb.GetFlushAllStateRequest) (*milvuspb.GetFlushAllStateResponse, error) { - ret := _m.Called(ctx, req) +// GetFlushAllState provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) GetFlushAllState(_a0 context.Context, _a1 *milvuspb.GetFlushAllStateRequest) (*milvuspb.GetFlushAllStateResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.GetFlushAllStateResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetFlushAllStateRequest) (*milvuspb.GetFlushAllStateResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetFlushAllStateRequest) *milvuspb.GetFlushAllStateResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.GetFlushAllStateResponse) @@ -1960,7 +2073,7 @@ func (_m *MockProxy) GetFlushAllState(ctx context.Context, req *milvuspb.GetFlus } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetFlushAllStateRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1974,13 +2087,13 @@ type MockProxy_GetFlushAllState_Call struct { } // GetFlushAllState is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.GetFlushAllStateRequest -func (_e 
*MockProxy_Expecter) GetFlushAllState(ctx interface{}, req interface{}) *MockProxy_GetFlushAllState_Call { - return &MockProxy_GetFlushAllState_Call{Call: _e.mock.On("GetFlushAllState", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.GetFlushAllStateRequest +func (_e *MockProxy_Expecter) GetFlushAllState(_a0 interface{}, _a1 interface{}) *MockProxy_GetFlushAllState_Call { + return &MockProxy_GetFlushAllState_Call{Call: _e.mock.On("GetFlushAllState", _a0, _a1)} } -func (_c *MockProxy_GetFlushAllState_Call) Run(run func(ctx context.Context, req *milvuspb.GetFlushAllStateRequest)) *MockProxy_GetFlushAllState_Call { +func (_c *MockProxy_GetFlushAllState_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetFlushAllStateRequest)) *MockProxy_GetFlushAllState_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.GetFlushAllStateRequest)) }) @@ -1997,17 +2110,17 @@ func (_c *MockProxy_GetFlushAllState_Call) RunAndReturn(run func(context.Context return _c } -// GetFlushState provides a mock function with given fields: ctx, req -func (_m *MockProxy) GetFlushState(ctx context.Context, req *milvuspb.GetFlushStateRequest) (*milvuspb.GetFlushStateResponse, error) { - ret := _m.Called(ctx, req) +// GetFlushState provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) GetFlushState(_a0 context.Context, _a1 *milvuspb.GetFlushStateRequest) (*milvuspb.GetFlushStateResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.GetFlushStateResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetFlushStateRequest) (*milvuspb.GetFlushStateResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetFlushStateRequest) *milvuspb.GetFlushStateResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.GetFlushStateResponse) @@ -2015,7 +2128,7 @@ func (_m *MockProxy) 
GetFlushState(ctx context.Context, req *milvuspb.GetFlushSt } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetFlushStateRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2029,13 +2142,13 @@ type MockProxy_GetFlushState_Call struct { } // GetFlushState is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.GetFlushStateRequest -func (_e *MockProxy_Expecter) GetFlushState(ctx interface{}, req interface{}) *MockProxy_GetFlushState_Call { - return &MockProxy_GetFlushState_Call{Call: _e.mock.On("GetFlushState", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.GetFlushStateRequest +func (_e *MockProxy_Expecter) GetFlushState(_a0 interface{}, _a1 interface{}) *MockProxy_GetFlushState_Call { + return &MockProxy_GetFlushState_Call{Call: _e.mock.On("GetFlushState", _a0, _a1)} } -func (_c *MockProxy_GetFlushState_Call) Run(run func(ctx context.Context, req *milvuspb.GetFlushStateRequest)) *MockProxy_GetFlushState_Call { +func (_c *MockProxy_GetFlushState_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetFlushStateRequest)) *MockProxy_GetFlushState_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.GetFlushStateRequest)) }) @@ -2052,17 +2165,17 @@ func (_c *MockProxy_GetFlushState_Call) RunAndReturn(run func(context.Context, * return _c } -// GetImportState provides a mock function with given fields: ctx, req -func (_m *MockProxy) GetImportState(ctx context.Context, req *milvuspb.GetImportStateRequest) (*milvuspb.GetImportStateResponse, error) { - ret := _m.Called(ctx, req) +// GetImportState provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) GetImportState(_a0 context.Context, _a1 *milvuspb.GetImportStateRequest) (*milvuspb.GetImportStateResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.GetImportStateResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, 
*milvuspb.GetImportStateRequest) (*milvuspb.GetImportStateResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetImportStateRequest) *milvuspb.GetImportStateResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.GetImportStateResponse) @@ -2070,7 +2183,7 @@ func (_m *MockProxy) GetImportState(ctx context.Context, req *milvuspb.GetImport } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetImportStateRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2084,13 +2197,13 @@ type MockProxy_GetImportState_Call struct { } // GetImportState is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.GetImportStateRequest -func (_e *MockProxy_Expecter) GetImportState(ctx interface{}, req interface{}) *MockProxy_GetImportState_Call { - return &MockProxy_GetImportState_Call{Call: _e.mock.On("GetImportState", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.GetImportStateRequest +func (_e *MockProxy_Expecter) GetImportState(_a0 interface{}, _a1 interface{}) *MockProxy_GetImportState_Call { + return &MockProxy_GetImportState_Call{Call: _e.mock.On("GetImportState", _a0, _a1)} } -func (_c *MockProxy_GetImportState_Call) Run(run func(ctx context.Context, req *milvuspb.GetImportStateRequest)) *MockProxy_GetImportState_Call { +func (_c *MockProxy_GetImportState_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetImportStateRequest)) *MockProxy_GetImportState_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.GetImportStateRequest)) }) @@ -2107,17 +2220,17 @@ func (_c *MockProxy_GetImportState_Call) RunAndReturn(run func(context.Context, return _c } -// GetIndexBuildProgress provides a mock function with given fields: ctx, request -func (_m *MockProxy) GetIndexBuildProgress(ctx context.Context, request 
*milvuspb.GetIndexBuildProgressRequest) (*milvuspb.GetIndexBuildProgressResponse, error) { - ret := _m.Called(ctx, request) +// GetIndexBuildProgress provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) GetIndexBuildProgress(_a0 context.Context, _a1 *milvuspb.GetIndexBuildProgressRequest) (*milvuspb.GetIndexBuildProgressResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.GetIndexBuildProgressResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetIndexBuildProgressRequest) (*milvuspb.GetIndexBuildProgressResponse, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetIndexBuildProgressRequest) *milvuspb.GetIndexBuildProgressResponse); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.GetIndexBuildProgressResponse) @@ -2125,7 +2238,7 @@ func (_m *MockProxy) GetIndexBuildProgress(ctx context.Context, request *milvusp } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetIndexBuildProgressRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2139,13 +2252,13 @@ type MockProxy_GetIndexBuildProgress_Call struct { } // GetIndexBuildProgress is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.GetIndexBuildProgressRequest -func (_e *MockProxy_Expecter) GetIndexBuildProgress(ctx interface{}, request interface{}) *MockProxy_GetIndexBuildProgress_Call { - return &MockProxy_GetIndexBuildProgress_Call{Call: _e.mock.On("GetIndexBuildProgress", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.GetIndexBuildProgressRequest +func (_e *MockProxy_Expecter) GetIndexBuildProgress(_a0 interface{}, _a1 interface{}) *MockProxy_GetIndexBuildProgress_Call { + return &MockProxy_GetIndexBuildProgress_Call{Call: _e.mock.On("GetIndexBuildProgress", _a0, _a1)} } -func (_c 
*MockProxy_GetIndexBuildProgress_Call) Run(run func(ctx context.Context, request *milvuspb.GetIndexBuildProgressRequest)) *MockProxy_GetIndexBuildProgress_Call { +func (_c *MockProxy_GetIndexBuildProgress_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetIndexBuildProgressRequest)) *MockProxy_GetIndexBuildProgress_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.GetIndexBuildProgressRequest)) }) @@ -2162,17 +2275,17 @@ func (_c *MockProxy_GetIndexBuildProgress_Call) RunAndReturn(run func(context.Co return _c } -// GetIndexState provides a mock function with given fields: ctx, request -func (_m *MockProxy) GetIndexState(ctx context.Context, request *milvuspb.GetIndexStateRequest) (*milvuspb.GetIndexStateResponse, error) { - ret := _m.Called(ctx, request) +// GetIndexState provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) GetIndexState(_a0 context.Context, _a1 *milvuspb.GetIndexStateRequest) (*milvuspb.GetIndexStateResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.GetIndexStateResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetIndexStateRequest) (*milvuspb.GetIndexStateResponse, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetIndexStateRequest) *milvuspb.GetIndexStateResponse); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.GetIndexStateResponse) @@ -2180,7 +2293,7 @@ func (_m *MockProxy) GetIndexState(ctx context.Context, request *milvuspb.GetInd } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetIndexStateRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2194,13 +2307,13 @@ type MockProxy_GetIndexState_Call struct { } // GetIndexState is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.GetIndexStateRequest 
-func (_e *MockProxy_Expecter) GetIndexState(ctx interface{}, request interface{}) *MockProxy_GetIndexState_Call { - return &MockProxy_GetIndexState_Call{Call: _e.mock.On("GetIndexState", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.GetIndexStateRequest +func (_e *MockProxy_Expecter) GetIndexState(_a0 interface{}, _a1 interface{}) *MockProxy_GetIndexState_Call { + return &MockProxy_GetIndexState_Call{Call: _e.mock.On("GetIndexState", _a0, _a1)} } -func (_c *MockProxy_GetIndexState_Call) Run(run func(ctx context.Context, request *milvuspb.GetIndexStateRequest)) *MockProxy_GetIndexState_Call { +func (_c *MockProxy_GetIndexState_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetIndexStateRequest)) *MockProxy_GetIndexState_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.GetIndexStateRequest)) }) @@ -2217,17 +2330,17 @@ func (_c *MockProxy_GetIndexState_Call) RunAndReturn(run func(context.Context, * return _c } -// GetIndexStatistics provides a mock function with given fields: ctx, request -func (_m *MockProxy) GetIndexStatistics(ctx context.Context, request *milvuspb.GetIndexStatisticsRequest) (*milvuspb.GetIndexStatisticsResponse, error) { - ret := _m.Called(ctx, request) +// GetIndexStatistics provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) GetIndexStatistics(_a0 context.Context, _a1 *milvuspb.GetIndexStatisticsRequest) (*milvuspb.GetIndexStatisticsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.GetIndexStatisticsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetIndexStatisticsRequest) (*milvuspb.GetIndexStatisticsResponse, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetIndexStatisticsRequest) *milvuspb.GetIndexStatisticsResponse); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = 
ret.Get(0).(*milvuspb.GetIndexStatisticsResponse) @@ -2235,7 +2348,7 @@ func (_m *MockProxy) GetIndexStatistics(ctx context.Context, request *milvuspb.G } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetIndexStatisticsRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2249,13 +2362,13 @@ type MockProxy_GetIndexStatistics_Call struct { } // GetIndexStatistics is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.GetIndexStatisticsRequest -func (_e *MockProxy_Expecter) GetIndexStatistics(ctx interface{}, request interface{}) *MockProxy_GetIndexStatistics_Call { - return &MockProxy_GetIndexStatistics_Call{Call: _e.mock.On("GetIndexStatistics", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.GetIndexStatisticsRequest +func (_e *MockProxy_Expecter) GetIndexStatistics(_a0 interface{}, _a1 interface{}) *MockProxy_GetIndexStatistics_Call { + return &MockProxy_GetIndexStatistics_Call{Call: _e.mock.On("GetIndexStatistics", _a0, _a1)} } -func (_c *MockProxy_GetIndexStatistics_Call) Run(run func(ctx context.Context, request *milvuspb.GetIndexStatisticsRequest)) *MockProxy_GetIndexStatistics_Call { +func (_c *MockProxy_GetIndexStatistics_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetIndexStatisticsRequest)) *MockProxy_GetIndexStatistics_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.GetIndexStatisticsRequest)) }) @@ -2272,17 +2385,17 @@ func (_c *MockProxy_GetIndexStatistics_Call) RunAndReturn(run func(context.Conte return _c } -// GetLoadState provides a mock function with given fields: ctx, request -func (_m *MockProxy) GetLoadState(ctx context.Context, request *milvuspb.GetLoadStateRequest) (*milvuspb.GetLoadStateResponse, error) { - ret := _m.Called(ctx, request) +// GetLoadState provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) GetLoadState(_a0 context.Context, _a1 
*milvuspb.GetLoadStateRequest) (*milvuspb.GetLoadStateResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.GetLoadStateResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetLoadStateRequest) (*milvuspb.GetLoadStateResponse, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetLoadStateRequest) *milvuspb.GetLoadStateResponse); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.GetLoadStateResponse) @@ -2290,7 +2403,7 @@ func (_m *MockProxy) GetLoadState(ctx context.Context, request *milvuspb.GetLoad } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetLoadStateRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2304,13 +2417,13 @@ type MockProxy_GetLoadState_Call struct { } // GetLoadState is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.GetLoadStateRequest -func (_e *MockProxy_Expecter) GetLoadState(ctx interface{}, request interface{}) *MockProxy_GetLoadState_Call { - return &MockProxy_GetLoadState_Call{Call: _e.mock.On("GetLoadState", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.GetLoadStateRequest +func (_e *MockProxy_Expecter) GetLoadState(_a0 interface{}, _a1 interface{}) *MockProxy_GetLoadState_Call { + return &MockProxy_GetLoadState_Call{Call: _e.mock.On("GetLoadState", _a0, _a1)} } -func (_c *MockProxy_GetLoadState_Call) Run(run func(ctx context.Context, request *milvuspb.GetLoadStateRequest)) *MockProxy_GetLoadState_Call { +func (_c *MockProxy_GetLoadState_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetLoadStateRequest)) *MockProxy_GetLoadState_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.GetLoadStateRequest)) }) @@ -2327,17 +2440,17 @@ func (_c *MockProxy_GetLoadState_Call) RunAndReturn(run func(context.Context, *m 
return _c } -// GetLoadingProgress provides a mock function with given fields: ctx, request -func (_m *MockProxy) GetLoadingProgress(ctx context.Context, request *milvuspb.GetLoadingProgressRequest) (*milvuspb.GetLoadingProgressResponse, error) { - ret := _m.Called(ctx, request) +// GetLoadingProgress provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) GetLoadingProgress(_a0 context.Context, _a1 *milvuspb.GetLoadingProgressRequest) (*milvuspb.GetLoadingProgressResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.GetLoadingProgressResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetLoadingProgressRequest) (*milvuspb.GetLoadingProgressResponse, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetLoadingProgressRequest) *milvuspb.GetLoadingProgressResponse); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.GetLoadingProgressResponse) @@ -2345,7 +2458,7 @@ func (_m *MockProxy) GetLoadingProgress(ctx context.Context, request *milvuspb.G } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetLoadingProgressRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2359,13 +2472,13 @@ type MockProxy_GetLoadingProgress_Call struct { } // GetLoadingProgress is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.GetLoadingProgressRequest -func (_e *MockProxy_Expecter) GetLoadingProgress(ctx interface{}, request interface{}) *MockProxy_GetLoadingProgress_Call { - return &MockProxy_GetLoadingProgress_Call{Call: _e.mock.On("GetLoadingProgress", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.GetLoadingProgressRequest +func (_e *MockProxy_Expecter) GetLoadingProgress(_a0 interface{}, _a1 interface{}) *MockProxy_GetLoadingProgress_Call { + return &MockProxy_GetLoadingProgress_Call{Call: 
_e.mock.On("GetLoadingProgress", _a0, _a1)} } -func (_c *MockProxy_GetLoadingProgress_Call) Run(run func(ctx context.Context, request *milvuspb.GetLoadingProgressRequest)) *MockProxy_GetLoadingProgress_Call { +func (_c *MockProxy_GetLoadingProgress_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetLoadingProgressRequest)) *MockProxy_GetLoadingProgress_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.GetLoadingProgressRequest)) }) @@ -2382,17 +2495,17 @@ func (_c *MockProxy_GetLoadingProgress_Call) RunAndReturn(run func(context.Conte return _c } -// GetMetrics provides a mock function with given fields: ctx, request -func (_m *MockProxy) GetMetrics(ctx context.Context, request *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - ret := _m.Called(ctx, request) +// GetMetrics provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) GetMetrics(_a0 context.Context, _a1 *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.GetMetricsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest) *milvuspb.GetMetricsResponse); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.GetMetricsResponse) @@ -2400,7 +2513,7 @@ func (_m *MockProxy) GetMetrics(ctx context.Context, request *milvuspb.GetMetric } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetMetricsRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2414,13 +2527,13 @@ type MockProxy_GetMetrics_Call struct { } // GetMetrics is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.GetMetricsRequest -func (_e 
*MockProxy_Expecter) GetMetrics(ctx interface{}, request interface{}) *MockProxy_GetMetrics_Call { - return &MockProxy_GetMetrics_Call{Call: _e.mock.On("GetMetrics", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.GetMetricsRequest +func (_e *MockProxy_Expecter) GetMetrics(_a0 interface{}, _a1 interface{}) *MockProxy_GetMetrics_Call { + return &MockProxy_GetMetrics_Call{Call: _e.mock.On("GetMetrics", _a0, _a1)} } -func (_c *MockProxy_GetMetrics_Call) Run(run func(ctx context.Context, request *milvuspb.GetMetricsRequest)) *MockProxy_GetMetrics_Call { +func (_c *MockProxy_GetMetrics_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetMetricsRequest)) *MockProxy_GetMetrics_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.GetMetricsRequest)) }) @@ -2437,17 +2550,17 @@ func (_c *MockProxy_GetMetrics_Call) RunAndReturn(run func(context.Context, *mil return _c } -// GetPartitionStatistics provides a mock function with given fields: ctx, request -func (_m *MockProxy) GetPartitionStatistics(ctx context.Context, request *milvuspb.GetPartitionStatisticsRequest) (*milvuspb.GetPartitionStatisticsResponse, error) { - ret := _m.Called(ctx, request) +// GetPartitionStatistics provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) GetPartitionStatistics(_a0 context.Context, _a1 *milvuspb.GetPartitionStatisticsRequest) (*milvuspb.GetPartitionStatisticsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.GetPartitionStatisticsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetPartitionStatisticsRequest) (*milvuspb.GetPartitionStatisticsResponse, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetPartitionStatisticsRequest) *milvuspb.GetPartitionStatisticsResponse); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = 
ret.Get(0).(*milvuspb.GetPartitionStatisticsResponse) @@ -2455,7 +2568,7 @@ func (_m *MockProxy) GetPartitionStatistics(ctx context.Context, request *milvus } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetPartitionStatisticsRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2469,13 +2582,13 @@ type MockProxy_GetPartitionStatistics_Call struct { } // GetPartitionStatistics is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.GetPartitionStatisticsRequest -func (_e *MockProxy_Expecter) GetPartitionStatistics(ctx interface{}, request interface{}) *MockProxy_GetPartitionStatistics_Call { - return &MockProxy_GetPartitionStatistics_Call{Call: _e.mock.On("GetPartitionStatistics", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.GetPartitionStatisticsRequest +func (_e *MockProxy_Expecter) GetPartitionStatistics(_a0 interface{}, _a1 interface{}) *MockProxy_GetPartitionStatistics_Call { + return &MockProxy_GetPartitionStatistics_Call{Call: _e.mock.On("GetPartitionStatistics", _a0, _a1)} } -func (_c *MockProxy_GetPartitionStatistics_Call) Run(run func(ctx context.Context, request *milvuspb.GetPartitionStatisticsRequest)) *MockProxy_GetPartitionStatistics_Call { +func (_c *MockProxy_GetPartitionStatistics_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetPartitionStatisticsRequest)) *MockProxy_GetPartitionStatistics_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.GetPartitionStatisticsRequest)) }) @@ -2492,17 +2605,17 @@ func (_c *MockProxy_GetPartitionStatistics_Call) RunAndReturn(run func(context.C return _c } -// GetPersistentSegmentInfo provides a mock function with given fields: ctx, request -func (_m *MockProxy) GetPersistentSegmentInfo(ctx context.Context, request *milvuspb.GetPersistentSegmentInfoRequest) (*milvuspb.GetPersistentSegmentInfoResponse, error) { - ret := _m.Called(ctx, request) +// 
GetPersistentSegmentInfo provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) GetPersistentSegmentInfo(_a0 context.Context, _a1 *milvuspb.GetPersistentSegmentInfoRequest) (*milvuspb.GetPersistentSegmentInfoResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.GetPersistentSegmentInfoResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetPersistentSegmentInfoRequest) (*milvuspb.GetPersistentSegmentInfoResponse, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetPersistentSegmentInfoRequest) *milvuspb.GetPersistentSegmentInfoResponse); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.GetPersistentSegmentInfoResponse) @@ -2510,7 +2623,7 @@ func (_m *MockProxy) GetPersistentSegmentInfo(ctx context.Context, request *milv } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetPersistentSegmentInfoRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2524,13 +2637,13 @@ type MockProxy_GetPersistentSegmentInfo_Call struct { } // GetPersistentSegmentInfo is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.GetPersistentSegmentInfoRequest -func (_e *MockProxy_Expecter) GetPersistentSegmentInfo(ctx interface{}, request interface{}) *MockProxy_GetPersistentSegmentInfo_Call { - return &MockProxy_GetPersistentSegmentInfo_Call{Call: _e.mock.On("GetPersistentSegmentInfo", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.GetPersistentSegmentInfoRequest +func (_e *MockProxy_Expecter) GetPersistentSegmentInfo(_a0 interface{}, _a1 interface{}) *MockProxy_GetPersistentSegmentInfo_Call { + return &MockProxy_GetPersistentSegmentInfo_Call{Call: _e.mock.On("GetPersistentSegmentInfo", _a0, _a1)} } -func (_c *MockProxy_GetPersistentSegmentInfo_Call) Run(run func(ctx context.Context, request 
*milvuspb.GetPersistentSegmentInfoRequest)) *MockProxy_GetPersistentSegmentInfo_Call { +func (_c *MockProxy_GetPersistentSegmentInfo_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetPersistentSegmentInfoRequest)) *MockProxy_GetPersistentSegmentInfo_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.GetPersistentSegmentInfoRequest)) }) @@ -2547,17 +2660,17 @@ func (_c *MockProxy_GetPersistentSegmentInfo_Call) RunAndReturn(run func(context return _c } -// GetProxyMetrics provides a mock function with given fields: ctx, request -func (_m *MockProxy) GetProxyMetrics(ctx context.Context, request *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - ret := _m.Called(ctx, request) +// GetProxyMetrics provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) GetProxyMetrics(_a0 context.Context, _a1 *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.GetMetricsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest) *milvuspb.GetMetricsResponse); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.GetMetricsResponse) @@ -2565,7 +2678,7 @@ func (_m *MockProxy) GetProxyMetrics(ctx context.Context, request *milvuspb.GetM } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetMetricsRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2579,13 +2692,13 @@ type MockProxy_GetProxyMetrics_Call struct { } // GetProxyMetrics is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.GetMetricsRequest -func (_e *MockProxy_Expecter) GetProxyMetrics(ctx interface{}, request interface{}) 
*MockProxy_GetProxyMetrics_Call { - return &MockProxy_GetProxyMetrics_Call{Call: _e.mock.On("GetProxyMetrics", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.GetMetricsRequest +func (_e *MockProxy_Expecter) GetProxyMetrics(_a0 interface{}, _a1 interface{}) *MockProxy_GetProxyMetrics_Call { + return &MockProxy_GetProxyMetrics_Call{Call: _e.mock.On("GetProxyMetrics", _a0, _a1)} } -func (_c *MockProxy_GetProxyMetrics_Call) Run(run func(ctx context.Context, request *milvuspb.GetMetricsRequest)) *MockProxy_GetProxyMetrics_Call { +func (_c *MockProxy_GetProxyMetrics_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetMetricsRequest)) *MockProxy_GetProxyMetrics_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.GetMetricsRequest)) }) @@ -2602,17 +2715,17 @@ func (_c *MockProxy_GetProxyMetrics_Call) RunAndReturn(run func(context.Context, return _c } -// GetQuerySegmentInfo provides a mock function with given fields: ctx, request -func (_m *MockProxy) GetQuerySegmentInfo(ctx context.Context, request *milvuspb.GetQuerySegmentInfoRequest) (*milvuspb.GetQuerySegmentInfoResponse, error) { - ret := _m.Called(ctx, request) +// GetQuerySegmentInfo provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) GetQuerySegmentInfo(_a0 context.Context, _a1 *milvuspb.GetQuerySegmentInfoRequest) (*milvuspb.GetQuerySegmentInfoResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.GetQuerySegmentInfoResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetQuerySegmentInfoRequest) (*milvuspb.GetQuerySegmentInfoResponse, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetQuerySegmentInfoRequest) *milvuspb.GetQuerySegmentInfoResponse); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.GetQuerySegmentInfoResponse) @@ -2620,7 +2733,7 @@ func (_m 
*MockProxy) GetQuerySegmentInfo(ctx context.Context, request *milvuspb. } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetQuerySegmentInfoRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2634,13 +2747,13 @@ type MockProxy_GetQuerySegmentInfo_Call struct { } // GetQuerySegmentInfo is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.GetQuerySegmentInfoRequest -func (_e *MockProxy_Expecter) GetQuerySegmentInfo(ctx interface{}, request interface{}) *MockProxy_GetQuerySegmentInfo_Call { - return &MockProxy_GetQuerySegmentInfo_Call{Call: _e.mock.On("GetQuerySegmentInfo", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.GetQuerySegmentInfoRequest +func (_e *MockProxy_Expecter) GetQuerySegmentInfo(_a0 interface{}, _a1 interface{}) *MockProxy_GetQuerySegmentInfo_Call { + return &MockProxy_GetQuerySegmentInfo_Call{Call: _e.mock.On("GetQuerySegmentInfo", _a0, _a1)} } -func (_c *MockProxy_GetQuerySegmentInfo_Call) Run(run func(ctx context.Context, request *milvuspb.GetQuerySegmentInfoRequest)) *MockProxy_GetQuerySegmentInfo_Call { +func (_c *MockProxy_GetQuerySegmentInfo_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetQuerySegmentInfoRequest)) *MockProxy_GetQuerySegmentInfo_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.GetQuerySegmentInfoRequest)) }) @@ -2710,17 +2823,17 @@ func (_c *MockProxy_GetRateLimiter_Call) RunAndReturn(run func() (types.Limiter, return _c } -// GetReplicas provides a mock function with given fields: ctx, req -func (_m *MockProxy) GetReplicas(ctx context.Context, req *milvuspb.GetReplicasRequest) (*milvuspb.GetReplicasResponse, error) { - ret := _m.Called(ctx, req) +// GetReplicas provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) GetReplicas(_a0 context.Context, _a1 *milvuspb.GetReplicasRequest) (*milvuspb.GetReplicasResponse, error) { + ret := 
_m.Called(_a0, _a1) var r0 *milvuspb.GetReplicasResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetReplicasRequest) (*milvuspb.GetReplicasResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetReplicasRequest) *milvuspb.GetReplicasResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.GetReplicasResponse) @@ -2728,7 +2841,7 @@ func (_m *MockProxy) GetReplicas(ctx context.Context, req *milvuspb.GetReplicasR } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetReplicasRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2742,13 +2855,13 @@ type MockProxy_GetReplicas_Call struct { } // GetReplicas is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.GetReplicasRequest -func (_e *MockProxy_Expecter) GetReplicas(ctx interface{}, req interface{}) *MockProxy_GetReplicas_Call { - return &MockProxy_GetReplicas_Call{Call: _e.mock.On("GetReplicas", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.GetReplicasRequest +func (_e *MockProxy_Expecter) GetReplicas(_a0 interface{}, _a1 interface{}) *MockProxy_GetReplicas_Call { + return &MockProxy_GetReplicas_Call{Call: _e.mock.On("GetReplicas", _a0, _a1)} } -func (_c *MockProxy_GetReplicas_Call) Run(run func(ctx context.Context, req *milvuspb.GetReplicasRequest)) *MockProxy_GetReplicas_Call { +func (_c *MockProxy_GetReplicas_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetReplicasRequest)) *MockProxy_GetReplicas_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.GetReplicasRequest)) }) @@ -2765,25 +2878,25 @@ func (_c *MockProxy_GetReplicas_Call) RunAndReturn(run func(context.Context, *mi return _c } -// GetStatisticsChannel provides a mock function with given fields: ctx -func (_m *MockProxy) GetStatisticsChannel(ctx 
context.Context) (*milvuspb.StringResponse, error) { - ret := _m.Called(ctx) +// GetStatisticsChannel provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) GetStatisticsChannel(_a0 context.Context, _a1 *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.StringResponse var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (*milvuspb.StringResponse, error)); ok { - return rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error)); ok { + return rf(_a0, _a1) } - if rf, ok := ret.Get(0).(func(context.Context) *milvuspb.StringResponse); ok { - r0 = rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetStatisticsChannelRequest) *milvuspb.StringResponse); ok { + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.StringResponse) } } - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.GetStatisticsChannelRequest) error); ok { + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2797,14 +2910,15 @@ type MockProxy_GetStatisticsChannel_Call struct { } // GetStatisticsChannel is a helper method to define mock.On call -// - ctx context.Context -func (_e *MockProxy_Expecter) GetStatisticsChannel(ctx interface{}) *MockProxy_GetStatisticsChannel_Call { - return &MockProxy_GetStatisticsChannel_Call{Call: _e.mock.On("GetStatisticsChannel", ctx)} +// - _a0 context.Context +// - _a1 *internalpb.GetStatisticsChannelRequest +func (_e *MockProxy_Expecter) GetStatisticsChannel(_a0 interface{}, _a1 interface{}) *MockProxy_GetStatisticsChannel_Call { + return &MockProxy_GetStatisticsChannel_Call{Call: _e.mock.On("GetStatisticsChannel", _a0, _a1)} } -func (_c *MockProxy_GetStatisticsChannel_Call) Run(run func(ctx context.Context)) *MockProxy_GetStatisticsChannel_Call { +func (_c 
*MockProxy_GetStatisticsChannel_Call) Run(run func(_a0 context.Context, _a1 *internalpb.GetStatisticsChannelRequest)) *MockProxy_GetStatisticsChannel_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context)) + run(args[0].(context.Context), args[1].(*internalpb.GetStatisticsChannelRequest)) }) return _c } @@ -2814,22 +2928,77 @@ func (_c *MockProxy_GetStatisticsChannel_Call) Return(_a0 *milvuspb.StringRespon return _c } -func (_c *MockProxy_GetStatisticsChannel_Call) RunAndReturn(run func(context.Context) (*milvuspb.StringResponse, error)) *MockProxy_GetStatisticsChannel_Call { +func (_c *MockProxy_GetStatisticsChannel_Call) RunAndReturn(run func(context.Context, *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error)) *MockProxy_GetStatisticsChannel_Call { + _c.Call.Return(run) + return _c +} + +// GetVersion provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) GetVersion(_a0 context.Context, _a1 *milvuspb.GetVersionRequest) (*milvuspb.GetVersionResponse, error) { + ret := _m.Called(_a0, _a1) + + var r0 *milvuspb.GetVersionResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetVersionRequest) (*milvuspb.GetVersionResponse, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetVersionRequest) *milvuspb.GetVersionResponse); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.GetVersionResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetVersionRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockProxy_GetVersion_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetVersion' +type MockProxy_GetVersion_Call struct { + *mock.Call +} + +// GetVersion is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *milvuspb.GetVersionRequest +func (_e 
*MockProxy_Expecter) GetVersion(_a0 interface{}, _a1 interface{}) *MockProxy_GetVersion_Call { + return &MockProxy_GetVersion_Call{Call: _e.mock.On("GetVersion", _a0, _a1)} +} + +func (_c *MockProxy_GetVersion_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetVersionRequest)) *MockProxy_GetVersion_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.GetVersionRequest)) + }) + return _c +} + +func (_c *MockProxy_GetVersion_Call) Return(_a0 *milvuspb.GetVersionResponse, _a1 error) *MockProxy_GetVersion_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockProxy_GetVersion_Call) RunAndReturn(run func(context.Context, *milvuspb.GetVersionRequest) (*milvuspb.GetVersionResponse, error)) *MockProxy_GetVersion_Call { _c.Call.Return(run) return _c } -// HasCollection provides a mock function with given fields: ctx, request -func (_m *MockProxy) HasCollection(ctx context.Context, request *milvuspb.HasCollectionRequest) (*milvuspb.BoolResponse, error) { - ret := _m.Called(ctx, request) +// HasCollection provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) HasCollection(_a0 context.Context, _a1 *milvuspb.HasCollectionRequest) (*milvuspb.BoolResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.BoolResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.HasCollectionRequest) (*milvuspb.BoolResponse, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.HasCollectionRequest) *milvuspb.BoolResponse); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.BoolResponse) @@ -2837,7 +3006,7 @@ func (_m *MockProxy) HasCollection(ctx context.Context, request *milvuspb.HasCol } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.HasCollectionRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ 
-2851,13 +3020,13 @@ type MockProxy_HasCollection_Call struct { } // HasCollection is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.HasCollectionRequest -func (_e *MockProxy_Expecter) HasCollection(ctx interface{}, request interface{}) *MockProxy_HasCollection_Call { - return &MockProxy_HasCollection_Call{Call: _e.mock.On("HasCollection", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.HasCollectionRequest +func (_e *MockProxy_Expecter) HasCollection(_a0 interface{}, _a1 interface{}) *MockProxy_HasCollection_Call { + return &MockProxy_HasCollection_Call{Call: _e.mock.On("HasCollection", _a0, _a1)} } -func (_c *MockProxy_HasCollection_Call) Run(run func(ctx context.Context, request *milvuspb.HasCollectionRequest)) *MockProxy_HasCollection_Call { +func (_c *MockProxy_HasCollection_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.HasCollectionRequest)) *MockProxy_HasCollection_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.HasCollectionRequest)) }) @@ -2874,17 +3043,17 @@ func (_c *MockProxy_HasCollection_Call) RunAndReturn(run func(context.Context, * return _c } -// HasPartition provides a mock function with given fields: ctx, request -func (_m *MockProxy) HasPartition(ctx context.Context, request *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error) { - ret := _m.Called(ctx, request) +// HasPartition provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) HasPartition(_a0 context.Context, _a1 *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.BoolResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.HasPartitionRequest) *milvuspb.BoolResponse); ok { - r0 = rf(ctx, request) + r0 = 
rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.BoolResponse) @@ -2892,7 +3061,7 @@ func (_m *MockProxy) HasPartition(ctx context.Context, request *milvuspb.HasPart } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.HasPartitionRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2906,13 +3075,13 @@ type MockProxy_HasPartition_Call struct { } // HasPartition is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.HasPartitionRequest -func (_e *MockProxy_Expecter) HasPartition(ctx interface{}, request interface{}) *MockProxy_HasPartition_Call { - return &MockProxy_HasPartition_Call{Call: _e.mock.On("HasPartition", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.HasPartitionRequest +func (_e *MockProxy_Expecter) HasPartition(_a0 interface{}, _a1 interface{}) *MockProxy_HasPartition_Call { + return &MockProxy_HasPartition_Call{Call: _e.mock.On("HasPartition", _a0, _a1)} } -func (_c *MockProxy_HasPartition_Call) Run(run func(ctx context.Context, request *milvuspb.HasPartitionRequest)) *MockProxy_HasPartition_Call { +func (_c *MockProxy_HasPartition_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.HasPartitionRequest)) *MockProxy_HasPartition_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.HasPartitionRequest)) }) @@ -2929,17 +3098,17 @@ func (_c *MockProxy_HasPartition_Call) RunAndReturn(run func(context.Context, *m return _c } -// Import provides a mock function with given fields: ctx, req -func (_m *MockProxy) Import(ctx context.Context, req *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) { - ret := _m.Called(ctx, req) +// Import provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) Import(_a0 context.Context, _a1 *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.ImportResponse var r1 error if rf, ok 
:= ret.Get(0).(func(context.Context, *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ImportRequest) *milvuspb.ImportResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.ImportResponse) @@ -2947,7 +3116,7 @@ func (_m *MockProxy) Import(ctx context.Context, req *milvuspb.ImportRequest) (* } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ImportRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2961,13 +3130,13 @@ type MockProxy_Import_Call struct { } // Import is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.ImportRequest -func (_e *MockProxy_Expecter) Import(ctx interface{}, req interface{}) *MockProxy_Import_Call { - return &MockProxy_Import_Call{Call: _e.mock.On("Import", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.ImportRequest +func (_e *MockProxy_Expecter) Import(_a0 interface{}, _a1 interface{}) *MockProxy_Import_Call { + return &MockProxy_Import_Call{Call: _e.mock.On("Import", _a0, _a1)} } -func (_c *MockProxy_Import_Call) Run(run func(ctx context.Context, req *milvuspb.ImportRequest)) *MockProxy_Import_Call { +func (_c *MockProxy_Import_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.ImportRequest)) *MockProxy_Import_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.ImportRequest)) }) @@ -3025,17 +3194,17 @@ func (_c *MockProxy_Init_Call) RunAndReturn(run func() error) *MockProxy_Init_Ca return _c } -// Insert provides a mock function with given fields: ctx, request -func (_m *MockProxy) Insert(ctx context.Context, request *milvuspb.InsertRequest) (*milvuspb.MutationResult, error) { - ret := _m.Called(ctx, request) +// Insert provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) Insert(_a0 
context.Context, _a1 *milvuspb.InsertRequest) (*milvuspb.MutationResult, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.MutationResult var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.InsertRequest) (*milvuspb.MutationResult, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.InsertRequest) *milvuspb.MutationResult); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.MutationResult) @@ -3043,7 +3212,7 @@ func (_m *MockProxy) Insert(ctx context.Context, request *milvuspb.InsertRequest } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.InsertRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -3057,13 +3226,13 @@ type MockProxy_Insert_Call struct { } // Insert is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.InsertRequest -func (_e *MockProxy_Expecter) Insert(ctx interface{}, request interface{}) *MockProxy_Insert_Call { - return &MockProxy_Insert_Call{Call: _e.mock.On("Insert", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.InsertRequest +func (_e *MockProxy_Expecter) Insert(_a0 interface{}, _a1 interface{}) *MockProxy_Insert_Call { + return &MockProxy_Insert_Call{Call: _e.mock.On("Insert", _a0, _a1)} } -func (_c *MockProxy_Insert_Call) Run(run func(ctx context.Context, request *milvuspb.InsertRequest)) *MockProxy_Insert_Call { +func (_c *MockProxy_Insert_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.InsertRequest)) *MockProxy_Insert_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.InsertRequest)) }) @@ -3080,17 +3249,17 @@ func (_c *MockProxy_Insert_Call) RunAndReturn(run func(context.Context, *milvusp return _c } -// InvalidateCollectionMetaCache provides a mock function with given fields: ctx, request -func (_m *MockProxy) 
InvalidateCollectionMetaCache(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, request) +// InvalidateCollectionMetaCache provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) InvalidateCollectionMetaCache(_a0 context.Context, _a1 *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateCollMetaCacheRequest) *commonpb.Status); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -3098,7 +3267,7 @@ func (_m *MockProxy) InvalidateCollectionMetaCache(ctx context.Context, request } if rf, ok := ret.Get(1).(func(context.Context, *proxypb.InvalidateCollMetaCacheRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -3112,13 +3281,13 @@ type MockProxy_InvalidateCollectionMetaCache_Call struct { } // InvalidateCollectionMetaCache is a helper method to define mock.On call -// - ctx context.Context -// - request *proxypb.InvalidateCollMetaCacheRequest -func (_e *MockProxy_Expecter) InvalidateCollectionMetaCache(ctx interface{}, request interface{}) *MockProxy_InvalidateCollectionMetaCache_Call { - return &MockProxy_InvalidateCollectionMetaCache_Call{Call: _e.mock.On("InvalidateCollectionMetaCache", ctx, request)} +// - _a0 context.Context +// - _a1 *proxypb.InvalidateCollMetaCacheRequest +func (_e *MockProxy_Expecter) InvalidateCollectionMetaCache(_a0 interface{}, _a1 interface{}) *MockProxy_InvalidateCollectionMetaCache_Call { + return &MockProxy_InvalidateCollectionMetaCache_Call{Call: _e.mock.On("InvalidateCollectionMetaCache", _a0, _a1)} } -func (_c 
*MockProxy_InvalidateCollectionMetaCache_Call) Run(run func(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest)) *MockProxy_InvalidateCollectionMetaCache_Call { +func (_c *MockProxy_InvalidateCollectionMetaCache_Call) Run(run func(_a0 context.Context, _a1 *proxypb.InvalidateCollMetaCacheRequest)) *MockProxy_InvalidateCollectionMetaCache_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*proxypb.InvalidateCollMetaCacheRequest)) }) @@ -3135,17 +3304,17 @@ func (_c *MockProxy_InvalidateCollectionMetaCache_Call) RunAndReturn(run func(co return _c } -// InvalidateCredentialCache provides a mock function with given fields: ctx, request -func (_m *MockProxy) InvalidateCredentialCache(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, request) +// InvalidateCredentialCache provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) InvalidateCredentialCache(_a0 context.Context, _a1 *proxypb.InvalidateCredCacheRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateCredCacheRequest) (*commonpb.Status, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateCredCacheRequest) *commonpb.Status); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -3153,7 +3322,7 @@ func (_m *MockProxy) InvalidateCredentialCache(ctx context.Context, request *pro } if rf, ok := ret.Get(1).(func(context.Context, *proxypb.InvalidateCredCacheRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -3167,13 +3336,13 @@ type MockProxy_InvalidateCredentialCache_Call struct { } // InvalidateCredentialCache is a helper method to define mock.On call -// - ctx context.Context -// - 
request *proxypb.InvalidateCredCacheRequest -func (_e *MockProxy_Expecter) InvalidateCredentialCache(ctx interface{}, request interface{}) *MockProxy_InvalidateCredentialCache_Call { - return &MockProxy_InvalidateCredentialCache_Call{Call: _e.mock.On("InvalidateCredentialCache", ctx, request)} +// - _a0 context.Context +// - _a1 *proxypb.InvalidateCredCacheRequest +func (_e *MockProxy_Expecter) InvalidateCredentialCache(_a0 interface{}, _a1 interface{}) *MockProxy_InvalidateCredentialCache_Call { + return &MockProxy_InvalidateCredentialCache_Call{Call: _e.mock.On("InvalidateCredentialCache", _a0, _a1)} } -func (_c *MockProxy_InvalidateCredentialCache_Call) Run(run func(ctx context.Context, request *proxypb.InvalidateCredCacheRequest)) *MockProxy_InvalidateCredentialCache_Call { +func (_c *MockProxy_InvalidateCredentialCache_Call) Run(run func(_a0 context.Context, _a1 *proxypb.InvalidateCredCacheRequest)) *MockProxy_InvalidateCredentialCache_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*proxypb.InvalidateCredCacheRequest)) }) @@ -3190,17 +3359,72 @@ func (_c *MockProxy_InvalidateCredentialCache_Call) RunAndReturn(run func(contex return _c } -// ListClientInfos provides a mock function with given fields: ctx, req -func (_m *MockProxy) ListClientInfos(ctx context.Context, req *proxypb.ListClientInfosRequest) (*proxypb.ListClientInfosResponse, error) { - ret := _m.Called(ctx, req) +// ListAliases provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) ListAliases(_a0 context.Context, _a1 *milvuspb.ListAliasesRequest) (*milvuspb.ListAliasesResponse, error) { + ret := _m.Called(_a0, _a1) + + var r0 *milvuspb.ListAliasesResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ListAliasesRequest) (*milvuspb.ListAliasesResponse, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ListAliasesRequest) *milvuspb.ListAliasesResponse); ok { + r0 = 
rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.ListAliasesResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ListAliasesRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockProxy_ListAliases_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListAliases' +type MockProxy_ListAliases_Call struct { + *mock.Call +} + +// ListAliases is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *milvuspb.ListAliasesRequest +func (_e *MockProxy_Expecter) ListAliases(_a0 interface{}, _a1 interface{}) *MockProxy_ListAliases_Call { + return &MockProxy_ListAliases_Call{Call: _e.mock.On("ListAliases", _a0, _a1)} +} + +func (_c *MockProxy_ListAliases_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.ListAliasesRequest)) *MockProxy_ListAliases_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.ListAliasesRequest)) + }) + return _c +} + +func (_c *MockProxy_ListAliases_Call) Return(_a0 *milvuspb.ListAliasesResponse, _a1 error) *MockProxy_ListAliases_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockProxy_ListAliases_Call) RunAndReturn(run func(context.Context, *milvuspb.ListAliasesRequest) (*milvuspb.ListAliasesResponse, error)) *MockProxy_ListAliases_Call { + _c.Call.Return(run) + return _c +} + +// ListClientInfos provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) ListClientInfos(_a0 context.Context, _a1 *proxypb.ListClientInfosRequest) (*proxypb.ListClientInfosResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *proxypb.ListClientInfosResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *proxypb.ListClientInfosRequest) (*proxypb.ListClientInfosResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *proxypb.ListClientInfosRequest) 
*proxypb.ListClientInfosResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*proxypb.ListClientInfosResponse) @@ -3208,7 +3432,7 @@ func (_m *MockProxy) ListClientInfos(ctx context.Context, req *proxypb.ListClien } if rf, ok := ret.Get(1).(func(context.Context, *proxypb.ListClientInfosRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -3222,13 +3446,13 @@ type MockProxy_ListClientInfos_Call struct { } // ListClientInfos is a helper method to define mock.On call -// - ctx context.Context -// - req *proxypb.ListClientInfosRequest -func (_e *MockProxy_Expecter) ListClientInfos(ctx interface{}, req interface{}) *MockProxy_ListClientInfos_Call { - return &MockProxy_ListClientInfos_Call{Call: _e.mock.On("ListClientInfos", ctx, req)} +// - _a0 context.Context +// - _a1 *proxypb.ListClientInfosRequest +func (_e *MockProxy_Expecter) ListClientInfos(_a0 interface{}, _a1 interface{}) *MockProxy_ListClientInfos_Call { + return &MockProxy_ListClientInfos_Call{Call: _e.mock.On("ListClientInfos", _a0, _a1)} } -func (_c *MockProxy_ListClientInfos_Call) Run(run func(ctx context.Context, req *proxypb.ListClientInfosRequest)) *MockProxy_ListClientInfos_Call { +func (_c *MockProxy_ListClientInfos_Call) Run(run func(_a0 context.Context, _a1 *proxypb.ListClientInfosRequest)) *MockProxy_ListClientInfos_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*proxypb.ListClientInfosRequest)) }) @@ -3245,17 +3469,17 @@ func (_c *MockProxy_ListClientInfos_Call) RunAndReturn(run func(context.Context, return _c } -// ListCredUsers provides a mock function with given fields: ctx, req -func (_m *MockProxy) ListCredUsers(ctx context.Context, req *milvuspb.ListCredUsersRequest) (*milvuspb.ListCredUsersResponse, error) { - ret := _m.Called(ctx, req) +// ListCredUsers provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) ListCredUsers(_a0 
context.Context, _a1 *milvuspb.ListCredUsersRequest) (*milvuspb.ListCredUsersResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.ListCredUsersResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ListCredUsersRequest) (*milvuspb.ListCredUsersResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ListCredUsersRequest) *milvuspb.ListCredUsersResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.ListCredUsersResponse) @@ -3263,7 +3487,7 @@ func (_m *MockProxy) ListCredUsers(ctx context.Context, req *milvuspb.ListCredUs } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ListCredUsersRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -3277,13 +3501,13 @@ type MockProxy_ListCredUsers_Call struct { } // ListCredUsers is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.ListCredUsersRequest -func (_e *MockProxy_Expecter) ListCredUsers(ctx interface{}, req interface{}) *MockProxy_ListCredUsers_Call { - return &MockProxy_ListCredUsers_Call{Call: _e.mock.On("ListCredUsers", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.ListCredUsersRequest +func (_e *MockProxy_Expecter) ListCredUsers(_a0 interface{}, _a1 interface{}) *MockProxy_ListCredUsers_Call { + return &MockProxy_ListCredUsers_Call{Call: _e.mock.On("ListCredUsers", _a0, _a1)} } -func (_c *MockProxy_ListCredUsers_Call) Run(run func(ctx context.Context, req *milvuspb.ListCredUsersRequest)) *MockProxy_ListCredUsers_Call { +func (_c *MockProxy_ListCredUsers_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.ListCredUsersRequest)) *MockProxy_ListCredUsers_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.ListCredUsersRequest)) }) @@ -3300,17 +3524,17 @@ func (_c *MockProxy_ListCredUsers_Call) RunAndReturn(run 
func(context.Context, * return _c } -// ListDatabases provides a mock function with given fields: ctx, req -func (_m *MockProxy) ListDatabases(ctx context.Context, req *milvuspb.ListDatabasesRequest) (*milvuspb.ListDatabasesResponse, error) { - ret := _m.Called(ctx, req) +// ListDatabases provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) ListDatabases(_a0 context.Context, _a1 *milvuspb.ListDatabasesRequest) (*milvuspb.ListDatabasesResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.ListDatabasesResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ListDatabasesRequest) (*milvuspb.ListDatabasesResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ListDatabasesRequest) *milvuspb.ListDatabasesResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.ListDatabasesResponse) @@ -3318,7 +3542,7 @@ func (_m *MockProxy) ListDatabases(ctx context.Context, req *milvuspb.ListDataba } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ListDatabasesRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -3332,13 +3556,13 @@ type MockProxy_ListDatabases_Call struct { } // ListDatabases is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.ListDatabasesRequest -func (_e *MockProxy_Expecter) ListDatabases(ctx interface{}, req interface{}) *MockProxy_ListDatabases_Call { - return &MockProxy_ListDatabases_Call{Call: _e.mock.On("ListDatabases", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.ListDatabasesRequest +func (_e *MockProxy_Expecter) ListDatabases(_a0 interface{}, _a1 interface{}) *MockProxy_ListDatabases_Call { + return &MockProxy_ListDatabases_Call{Call: _e.mock.On("ListDatabases", _a0, _a1)} } -func (_c *MockProxy_ListDatabases_Call) Run(run func(ctx context.Context, req 
*milvuspb.ListDatabasesRequest)) *MockProxy_ListDatabases_Call { +func (_c *MockProxy_ListDatabases_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.ListDatabasesRequest)) *MockProxy_ListDatabases_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.ListDatabasesRequest)) }) @@ -3355,17 +3579,17 @@ func (_c *MockProxy_ListDatabases_Call) RunAndReturn(run func(context.Context, * return _c } -// ListImportTasks provides a mock function with given fields: ctx, req -func (_m *MockProxy) ListImportTasks(ctx context.Context, req *milvuspb.ListImportTasksRequest) (*milvuspb.ListImportTasksResponse, error) { - ret := _m.Called(ctx, req) +// ListImportTasks provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) ListImportTasks(_a0 context.Context, _a1 *milvuspb.ListImportTasksRequest) (*milvuspb.ListImportTasksResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.ListImportTasksResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ListImportTasksRequest) (*milvuspb.ListImportTasksResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ListImportTasksRequest) *milvuspb.ListImportTasksResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.ListImportTasksResponse) @@ -3373,7 +3597,7 @@ func (_m *MockProxy) ListImportTasks(ctx context.Context, req *milvuspb.ListImpo } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ListImportTasksRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -3387,13 +3611,13 @@ type MockProxy_ListImportTasks_Call struct { } // ListImportTasks is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.ListImportTasksRequest -func (_e *MockProxy_Expecter) ListImportTasks(ctx interface{}, req interface{}) *MockProxy_ListImportTasks_Call { - return 
&MockProxy_ListImportTasks_Call{Call: _e.mock.On("ListImportTasks", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.ListImportTasksRequest +func (_e *MockProxy_Expecter) ListImportTasks(_a0 interface{}, _a1 interface{}) *MockProxy_ListImportTasks_Call { + return &MockProxy_ListImportTasks_Call{Call: _e.mock.On("ListImportTasks", _a0, _a1)} } -func (_c *MockProxy_ListImportTasks_Call) Run(run func(ctx context.Context, req *milvuspb.ListImportTasksRequest)) *MockProxy_ListImportTasks_Call { +func (_c *MockProxy_ListImportTasks_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.ListImportTasksRequest)) *MockProxy_ListImportTasks_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.ListImportTasksRequest)) }) @@ -3410,17 +3634,72 @@ func (_c *MockProxy_ListImportTasks_Call) RunAndReturn(run func(context.Context, return _c } -// ListResourceGroups provides a mock function with given fields: ctx, req -func (_m *MockProxy) ListResourceGroups(ctx context.Context, req *milvuspb.ListResourceGroupsRequest) (*milvuspb.ListResourceGroupsResponse, error) { - ret := _m.Called(ctx, req) +// ListIndexedSegment provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) ListIndexedSegment(_a0 context.Context, _a1 *federpb.ListIndexedSegmentRequest) (*federpb.ListIndexedSegmentResponse, error) { + ret := _m.Called(_a0, _a1) + + var r0 *federpb.ListIndexedSegmentResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *federpb.ListIndexedSegmentRequest) (*federpb.ListIndexedSegmentResponse, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *federpb.ListIndexedSegmentRequest) *federpb.ListIndexedSegmentResponse); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*federpb.ListIndexedSegmentResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *federpb.ListIndexedSegmentRequest) error); ok { + r1 = rf(_a0, _a1) + } else 
{ + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockProxy_ListIndexedSegment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListIndexedSegment' +type MockProxy_ListIndexedSegment_Call struct { + *mock.Call +} + +// ListIndexedSegment is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *federpb.ListIndexedSegmentRequest +func (_e *MockProxy_Expecter) ListIndexedSegment(_a0 interface{}, _a1 interface{}) *MockProxy_ListIndexedSegment_Call { + return &MockProxy_ListIndexedSegment_Call{Call: _e.mock.On("ListIndexedSegment", _a0, _a1)} +} + +func (_c *MockProxy_ListIndexedSegment_Call) Run(run func(_a0 context.Context, _a1 *federpb.ListIndexedSegmentRequest)) *MockProxy_ListIndexedSegment_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*federpb.ListIndexedSegmentRequest)) + }) + return _c +} + +func (_c *MockProxy_ListIndexedSegment_Call) Return(_a0 *federpb.ListIndexedSegmentResponse, _a1 error) *MockProxy_ListIndexedSegment_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockProxy_ListIndexedSegment_Call) RunAndReturn(run func(context.Context, *federpb.ListIndexedSegmentRequest) (*federpb.ListIndexedSegmentResponse, error)) *MockProxy_ListIndexedSegment_Call { + _c.Call.Return(run) + return _c +} + +// ListResourceGroups provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) ListResourceGroups(_a0 context.Context, _a1 *milvuspb.ListResourceGroupsRequest) (*milvuspb.ListResourceGroupsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.ListResourceGroupsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ListResourceGroupsRequest) (*milvuspb.ListResourceGroupsResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ListResourceGroupsRequest) *milvuspb.ListResourceGroupsResponse); ok { - r0 = rf(ctx, req) + r0 = 
rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.ListResourceGroupsResponse) @@ -3428,7 +3707,7 @@ func (_m *MockProxy) ListResourceGroups(ctx context.Context, req *milvuspb.ListR } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ListResourceGroupsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -3442,13 +3721,13 @@ type MockProxy_ListResourceGroups_Call struct { } // ListResourceGroups is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.ListResourceGroupsRequest -func (_e *MockProxy_Expecter) ListResourceGroups(ctx interface{}, req interface{}) *MockProxy_ListResourceGroups_Call { - return &MockProxy_ListResourceGroups_Call{Call: _e.mock.On("ListResourceGroups", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.ListResourceGroupsRequest +func (_e *MockProxy_Expecter) ListResourceGroups(_a0 interface{}, _a1 interface{}) *MockProxy_ListResourceGroups_Call { + return &MockProxy_ListResourceGroups_Call{Call: _e.mock.On("ListResourceGroups", _a0, _a1)} } -func (_c *MockProxy_ListResourceGroups_Call) Run(run func(ctx context.Context, req *milvuspb.ListResourceGroupsRequest)) *MockProxy_ListResourceGroups_Call { +func (_c *MockProxy_ListResourceGroups_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.ListResourceGroupsRequest)) *MockProxy_ListResourceGroups_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.ListResourceGroupsRequest)) }) @@ -3465,17 +3744,17 @@ func (_c *MockProxy_ListResourceGroups_Call) RunAndReturn(run func(context.Conte return _c } -// LoadBalance provides a mock function with given fields: ctx, request -func (_m *MockProxy) LoadBalance(ctx context.Context, request *milvuspb.LoadBalanceRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, request) +// LoadBalance provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) LoadBalance(_a0 context.Context, _a1 
*milvuspb.LoadBalanceRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.LoadBalanceRequest) (*commonpb.Status, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.LoadBalanceRequest) *commonpb.Status); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -3483,7 +3762,7 @@ func (_m *MockProxy) LoadBalance(ctx context.Context, request *milvuspb.LoadBala } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.LoadBalanceRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -3497,13 +3776,13 @@ type MockProxy_LoadBalance_Call struct { } // LoadBalance is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.LoadBalanceRequest -func (_e *MockProxy_Expecter) LoadBalance(ctx interface{}, request interface{}) *MockProxy_LoadBalance_Call { - return &MockProxy_LoadBalance_Call{Call: _e.mock.On("LoadBalance", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.LoadBalanceRequest +func (_e *MockProxy_Expecter) LoadBalance(_a0 interface{}, _a1 interface{}) *MockProxy_LoadBalance_Call { + return &MockProxy_LoadBalance_Call{Call: _e.mock.On("LoadBalance", _a0, _a1)} } -func (_c *MockProxy_LoadBalance_Call) Run(run func(ctx context.Context, request *milvuspb.LoadBalanceRequest)) *MockProxy_LoadBalance_Call { +func (_c *MockProxy_LoadBalance_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.LoadBalanceRequest)) *MockProxy_LoadBalance_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.LoadBalanceRequest)) }) @@ -3520,17 +3799,17 @@ func (_c *MockProxy_LoadBalance_Call) RunAndReturn(run func(context.Context, *mi return _c } -// LoadCollection provides a mock function with given fields: ctx, request -func 
(_m *MockProxy) LoadCollection(ctx context.Context, request *milvuspb.LoadCollectionRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, request) +// LoadCollection provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) LoadCollection(_a0 context.Context, _a1 *milvuspb.LoadCollectionRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.LoadCollectionRequest) (*commonpb.Status, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.LoadCollectionRequest) *commonpb.Status); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -3538,7 +3817,7 @@ func (_m *MockProxy) LoadCollection(ctx context.Context, request *milvuspb.LoadC } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.LoadCollectionRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -3552,13 +3831,13 @@ type MockProxy_LoadCollection_Call struct { } // LoadCollection is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.LoadCollectionRequest -func (_e *MockProxy_Expecter) LoadCollection(ctx interface{}, request interface{}) *MockProxy_LoadCollection_Call { - return &MockProxy_LoadCollection_Call{Call: _e.mock.On("LoadCollection", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.LoadCollectionRequest +func (_e *MockProxy_Expecter) LoadCollection(_a0 interface{}, _a1 interface{}) *MockProxy_LoadCollection_Call { + return &MockProxy_LoadCollection_Call{Call: _e.mock.On("LoadCollection", _a0, _a1)} } -func (_c *MockProxy_LoadCollection_Call) Run(run func(ctx context.Context, request *milvuspb.LoadCollectionRequest)) *MockProxy_LoadCollection_Call { +func (_c *MockProxy_LoadCollection_Call) Run(run func(_a0 context.Context, _a1 
*milvuspb.LoadCollectionRequest)) *MockProxy_LoadCollection_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.LoadCollectionRequest)) }) @@ -3575,17 +3854,17 @@ func (_c *MockProxy_LoadCollection_Call) RunAndReturn(run func(context.Context, return _c } -// LoadPartitions provides a mock function with given fields: ctx, request -func (_m *MockProxy) LoadPartitions(ctx context.Context, request *milvuspb.LoadPartitionsRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, request) +// LoadPartitions provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) LoadPartitions(_a0 context.Context, _a1 *milvuspb.LoadPartitionsRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.LoadPartitionsRequest) (*commonpb.Status, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.LoadPartitionsRequest) *commonpb.Status); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -3593,7 +3872,7 @@ func (_m *MockProxy) LoadPartitions(ctx context.Context, request *milvuspb.LoadP } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.LoadPartitionsRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -3607,13 +3886,13 @@ type MockProxy_LoadPartitions_Call struct { } // LoadPartitions is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.LoadPartitionsRequest -func (_e *MockProxy_Expecter) LoadPartitions(ctx interface{}, request interface{}) *MockProxy_LoadPartitions_Call { - return &MockProxy_LoadPartitions_Call{Call: _e.mock.On("LoadPartitions", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.LoadPartitionsRequest +func (_e *MockProxy_Expecter) LoadPartitions(_a0 interface{}, _a1 
interface{}) *MockProxy_LoadPartitions_Call { + return &MockProxy_LoadPartitions_Call{Call: _e.mock.On("LoadPartitions", _a0, _a1)} } -func (_c *MockProxy_LoadPartitions_Call) Run(run func(ctx context.Context, request *milvuspb.LoadPartitionsRequest)) *MockProxy_LoadPartitions_Call { +func (_c *MockProxy_LoadPartitions_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.LoadPartitionsRequest)) *MockProxy_LoadPartitions_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.LoadPartitionsRequest)) }) @@ -3630,17 +3909,17 @@ func (_c *MockProxy_LoadPartitions_Call) RunAndReturn(run func(context.Context, return _c } -// ManualCompaction provides a mock function with given fields: ctx, req -func (_m *MockProxy) ManualCompaction(ctx context.Context, req *milvuspb.ManualCompactionRequest) (*milvuspb.ManualCompactionResponse, error) { - ret := _m.Called(ctx, req) +// ManualCompaction provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) ManualCompaction(_a0 context.Context, _a1 *milvuspb.ManualCompactionRequest) (*milvuspb.ManualCompactionResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.ManualCompactionResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ManualCompactionRequest) (*milvuspb.ManualCompactionResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ManualCompactionRequest) *milvuspb.ManualCompactionResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.ManualCompactionResponse) @@ -3648,7 +3927,7 @@ func (_m *MockProxy) ManualCompaction(ctx context.Context, req *milvuspb.ManualC } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ManualCompactionRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -3662,13 +3941,13 @@ type MockProxy_ManualCompaction_Call struct { } // 
ManualCompaction is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.ManualCompactionRequest -func (_e *MockProxy_Expecter) ManualCompaction(ctx interface{}, req interface{}) *MockProxy_ManualCompaction_Call { - return &MockProxy_ManualCompaction_Call{Call: _e.mock.On("ManualCompaction", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.ManualCompactionRequest +func (_e *MockProxy_Expecter) ManualCompaction(_a0 interface{}, _a1 interface{}) *MockProxy_ManualCompaction_Call { + return &MockProxy_ManualCompaction_Call{Call: _e.mock.On("ManualCompaction", _a0, _a1)} } -func (_c *MockProxy_ManualCompaction_Call) Run(run func(ctx context.Context, req *milvuspb.ManualCompactionRequest)) *MockProxy_ManualCompaction_Call { +func (_c *MockProxy_ManualCompaction_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.ManualCompactionRequest)) *MockProxy_ManualCompaction_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.ManualCompactionRequest)) }) @@ -3685,17 +3964,17 @@ func (_c *MockProxy_ManualCompaction_Call) RunAndReturn(run func(context.Context return _c } -// OperatePrivilege provides a mock function with given fields: ctx, req -func (_m *MockProxy) OperatePrivilege(ctx context.Context, req *milvuspb.OperatePrivilegeRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// OperatePrivilege provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) OperatePrivilege(_a0 context.Context, _a1 *milvuspb.OperatePrivilegeRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.OperatePrivilegeRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.OperatePrivilegeRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = 
ret.Get(0).(*commonpb.Status) @@ -3703,7 +3982,7 @@ func (_m *MockProxy) OperatePrivilege(ctx context.Context, req *milvuspb.Operate } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.OperatePrivilegeRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -3717,13 +3996,13 @@ type MockProxy_OperatePrivilege_Call struct { } // OperatePrivilege is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.OperatePrivilegeRequest -func (_e *MockProxy_Expecter) OperatePrivilege(ctx interface{}, req interface{}) *MockProxy_OperatePrivilege_Call { - return &MockProxy_OperatePrivilege_Call{Call: _e.mock.On("OperatePrivilege", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.OperatePrivilegeRequest +func (_e *MockProxy_Expecter) OperatePrivilege(_a0 interface{}, _a1 interface{}) *MockProxy_OperatePrivilege_Call { + return &MockProxy_OperatePrivilege_Call{Call: _e.mock.On("OperatePrivilege", _a0, _a1)} } -func (_c *MockProxy_OperatePrivilege_Call) Run(run func(ctx context.Context, req *milvuspb.OperatePrivilegeRequest)) *MockProxy_OperatePrivilege_Call { +func (_c *MockProxy_OperatePrivilege_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.OperatePrivilegeRequest)) *MockProxy_OperatePrivilege_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.OperatePrivilegeRequest)) }) @@ -3740,17 +4019,17 @@ func (_c *MockProxy_OperatePrivilege_Call) RunAndReturn(run func(context.Context return _c } -// OperateUserRole provides a mock function with given fields: ctx, req -func (_m *MockProxy) OperateUserRole(ctx context.Context, req *milvuspb.OperateUserRoleRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// OperateUserRole provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) OperateUserRole(_a0 context.Context, _a1 *milvuspb.OperateUserRoleRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 
*commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.OperateUserRoleRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.OperateUserRoleRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -3758,7 +4037,7 @@ func (_m *MockProxy) OperateUserRole(ctx context.Context, req *milvuspb.OperateU } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.OperateUserRoleRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -3772,13 +4051,13 @@ type MockProxy_OperateUserRole_Call struct { } // OperateUserRole is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.OperateUserRoleRequest -func (_e *MockProxy_Expecter) OperateUserRole(ctx interface{}, req interface{}) *MockProxy_OperateUserRole_Call { - return &MockProxy_OperateUserRole_Call{Call: _e.mock.On("OperateUserRole", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.OperateUserRoleRequest +func (_e *MockProxy_Expecter) OperateUserRole(_a0 interface{}, _a1 interface{}) *MockProxy_OperateUserRole_Call { + return &MockProxy_OperateUserRole_Call{Call: _e.mock.On("OperateUserRole", _a0, _a1)} } -func (_c *MockProxy_OperateUserRole_Call) Run(run func(ctx context.Context, req *milvuspb.OperateUserRoleRequest)) *MockProxy_OperateUserRole_Call { +func (_c *MockProxy_OperateUserRole_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.OperateUserRoleRequest)) *MockProxy_OperateUserRole_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.OperateUserRoleRequest)) }) @@ -3795,17 +4074,17 @@ func (_c *MockProxy_OperateUserRole_Call) RunAndReturn(run func(context.Context, return _c } -// Query provides a mock function with given fields: ctx, request -func (_m *MockProxy) Query(ctx context.Context, 
request *milvuspb.QueryRequest) (*milvuspb.QueryResults, error) { - ret := _m.Called(ctx, request) +// Query provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) Query(_a0 context.Context, _a1 *milvuspb.QueryRequest) (*milvuspb.QueryResults, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.QueryResults var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.QueryRequest) (*milvuspb.QueryResults, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.QueryRequest) *milvuspb.QueryResults); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.QueryResults) @@ -3813,7 +4092,7 @@ func (_m *MockProxy) Query(ctx context.Context, request *milvuspb.QueryRequest) } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.QueryRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -3827,13 +4106,13 @@ type MockProxy_Query_Call struct { } // Query is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.QueryRequest -func (_e *MockProxy_Expecter) Query(ctx interface{}, request interface{}) *MockProxy_Query_Call { - return &MockProxy_Query_Call{Call: _e.mock.On("Query", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.QueryRequest +func (_e *MockProxy_Expecter) Query(_a0 interface{}, _a1 interface{}) *MockProxy_Query_Call { + return &MockProxy_Query_Call{Call: _e.mock.On("Query", _a0, _a1)} } -func (_c *MockProxy_Query_Call) Run(run func(ctx context.Context, request *milvuspb.QueryRequest)) *MockProxy_Query_Call { +func (_c *MockProxy_Query_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.QueryRequest)) *MockProxy_Query_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.QueryRequest)) }) @@ -3850,17 +4129,17 @@ func (_c *MockProxy_Query_Call) RunAndReturn(run 
func(context.Context, *milvuspb return _c } -// RefreshPolicyInfoCache provides a mock function with given fields: ctx, req -func (_m *MockProxy) RefreshPolicyInfoCache(ctx context.Context, req *proxypb.RefreshPolicyInfoCacheRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// RefreshPolicyInfoCache provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) RefreshPolicyInfoCache(_a0 context.Context, _a1 *proxypb.RefreshPolicyInfoCacheRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *proxypb.RefreshPolicyInfoCacheRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *proxypb.RefreshPolicyInfoCacheRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -3868,7 +4147,7 @@ func (_m *MockProxy) RefreshPolicyInfoCache(ctx context.Context, req *proxypb.Re } if rf, ok := ret.Get(1).(func(context.Context, *proxypb.RefreshPolicyInfoCacheRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -3882,13 +4161,13 @@ type MockProxy_RefreshPolicyInfoCache_Call struct { } // RefreshPolicyInfoCache is a helper method to define mock.On call -// - ctx context.Context -// - req *proxypb.RefreshPolicyInfoCacheRequest -func (_e *MockProxy_Expecter) RefreshPolicyInfoCache(ctx interface{}, req interface{}) *MockProxy_RefreshPolicyInfoCache_Call { - return &MockProxy_RefreshPolicyInfoCache_Call{Call: _e.mock.On("RefreshPolicyInfoCache", ctx, req)} +// - _a0 context.Context +// - _a1 *proxypb.RefreshPolicyInfoCacheRequest +func (_e *MockProxy_Expecter) RefreshPolicyInfoCache(_a0 interface{}, _a1 interface{}) *MockProxy_RefreshPolicyInfoCache_Call { + return &MockProxy_RefreshPolicyInfoCache_Call{Call: _e.mock.On("RefreshPolicyInfoCache", _a0, _a1)} } -func 
(_c *MockProxy_RefreshPolicyInfoCache_Call) Run(run func(ctx context.Context, req *proxypb.RefreshPolicyInfoCacheRequest)) *MockProxy_RefreshPolicyInfoCache_Call { +func (_c *MockProxy_RefreshPolicyInfoCache_Call) Run(run func(_a0 context.Context, _a1 *proxypb.RefreshPolicyInfoCacheRequest)) *MockProxy_RefreshPolicyInfoCache_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*proxypb.RefreshPolicyInfoCacheRequest)) }) @@ -3946,17 +4225,17 @@ func (_c *MockProxy_Register_Call) RunAndReturn(run func() error) *MockProxy_Reg return _c } -// RegisterLink provides a mock function with given fields: ctx, request -func (_m *MockProxy) RegisterLink(ctx context.Context, request *milvuspb.RegisterLinkRequest) (*milvuspb.RegisterLinkResponse, error) { - ret := _m.Called(ctx, request) +// RegisterLink provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) RegisterLink(_a0 context.Context, _a1 *milvuspb.RegisterLinkRequest) (*milvuspb.RegisterLinkResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.RegisterLinkResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.RegisterLinkRequest) (*milvuspb.RegisterLinkResponse, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.RegisterLinkRequest) *milvuspb.RegisterLinkResponse); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.RegisterLinkResponse) @@ -3964,7 +4243,7 @@ func (_m *MockProxy) RegisterLink(ctx context.Context, request *milvuspb.Registe } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.RegisterLinkRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -3978,13 +4257,13 @@ type MockProxy_RegisterLink_Call struct { } // RegisterLink is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.RegisterLinkRequest -func (_e 
*MockProxy_Expecter) RegisterLink(ctx interface{}, request interface{}) *MockProxy_RegisterLink_Call { - return &MockProxy_RegisterLink_Call{Call: _e.mock.On("RegisterLink", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.RegisterLinkRequest +func (_e *MockProxy_Expecter) RegisterLink(_a0 interface{}, _a1 interface{}) *MockProxy_RegisterLink_Call { + return &MockProxy_RegisterLink_Call{Call: _e.mock.On("RegisterLink", _a0, _a1)} } -func (_c *MockProxy_RegisterLink_Call) Run(run func(ctx context.Context, request *milvuspb.RegisterLinkRequest)) *MockProxy_RegisterLink_Call { +func (_c *MockProxy_RegisterLink_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.RegisterLinkRequest)) *MockProxy_RegisterLink_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.RegisterLinkRequest)) }) @@ -4001,17 +4280,17 @@ func (_c *MockProxy_RegisterLink_Call) RunAndReturn(run func(context.Context, *m return _c } -// ReleaseCollection provides a mock function with given fields: ctx, request -func (_m *MockProxy) ReleaseCollection(ctx context.Context, request *milvuspb.ReleaseCollectionRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, request) +// ReleaseCollection provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) ReleaseCollection(_a0 context.Context, _a1 *milvuspb.ReleaseCollectionRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ReleaseCollectionRequest) (*commonpb.Status, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ReleaseCollectionRequest) *commonpb.Status); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -4019,7 +4298,7 @@ func (_m *MockProxy) ReleaseCollection(ctx context.Context, request *milvuspb.Re } if rf, ok := 
ret.Get(1).(func(context.Context, *milvuspb.ReleaseCollectionRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -4033,13 +4312,13 @@ type MockProxy_ReleaseCollection_Call struct { } // ReleaseCollection is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.ReleaseCollectionRequest -func (_e *MockProxy_Expecter) ReleaseCollection(ctx interface{}, request interface{}) *MockProxy_ReleaseCollection_Call { - return &MockProxy_ReleaseCollection_Call{Call: _e.mock.On("ReleaseCollection", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.ReleaseCollectionRequest +func (_e *MockProxy_Expecter) ReleaseCollection(_a0 interface{}, _a1 interface{}) *MockProxy_ReleaseCollection_Call { + return &MockProxy_ReleaseCollection_Call{Call: _e.mock.On("ReleaseCollection", _a0, _a1)} } -func (_c *MockProxy_ReleaseCollection_Call) Run(run func(ctx context.Context, request *milvuspb.ReleaseCollectionRequest)) *MockProxy_ReleaseCollection_Call { +func (_c *MockProxy_ReleaseCollection_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.ReleaseCollectionRequest)) *MockProxy_ReleaseCollection_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.ReleaseCollectionRequest)) }) @@ -4056,17 +4335,17 @@ func (_c *MockProxy_ReleaseCollection_Call) RunAndReturn(run func(context.Contex return _c } -// ReleasePartitions provides a mock function with given fields: ctx, request -func (_m *MockProxy) ReleasePartitions(ctx context.Context, request *milvuspb.ReleasePartitionsRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, request) +// ReleasePartitions provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) ReleasePartitions(_a0 context.Context, _a1 *milvuspb.ReleasePartitionsRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, 
*milvuspb.ReleasePartitionsRequest) (*commonpb.Status, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ReleasePartitionsRequest) *commonpb.Status); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -4074,7 +4353,7 @@ func (_m *MockProxy) ReleasePartitions(ctx context.Context, request *milvuspb.Re } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ReleasePartitionsRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -4088,13 +4367,13 @@ type MockProxy_ReleasePartitions_Call struct { } // ReleasePartitions is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.ReleasePartitionsRequest -func (_e *MockProxy_Expecter) ReleasePartitions(ctx interface{}, request interface{}) *MockProxy_ReleasePartitions_Call { - return &MockProxy_ReleasePartitions_Call{Call: _e.mock.On("ReleasePartitions", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.ReleasePartitionsRequest +func (_e *MockProxy_Expecter) ReleasePartitions(_a0 interface{}, _a1 interface{}) *MockProxy_ReleasePartitions_Call { + return &MockProxy_ReleasePartitions_Call{Call: _e.mock.On("ReleasePartitions", _a0, _a1)} } -func (_c *MockProxy_ReleasePartitions_Call) Run(run func(ctx context.Context, request *milvuspb.ReleasePartitionsRequest)) *MockProxy_ReleasePartitions_Call { +func (_c *MockProxy_ReleasePartitions_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.ReleasePartitionsRequest)) *MockProxy_ReleasePartitions_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.ReleasePartitionsRequest)) }) @@ -4111,17 +4390,17 @@ func (_c *MockProxy_ReleasePartitions_Call) RunAndReturn(run func(context.Contex return _c } -// RenameCollection provides a mock function with given fields: ctx, req -func (_m *MockProxy) RenameCollection(ctx 
context.Context, req *milvuspb.RenameCollectionRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// RenameCollection provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) RenameCollection(_a0 context.Context, _a1 *milvuspb.RenameCollectionRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.RenameCollectionRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.RenameCollectionRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -4129,7 +4408,7 @@ func (_m *MockProxy) RenameCollection(ctx context.Context, req *milvuspb.RenameC } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.RenameCollectionRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -4143,13 +4422,13 @@ type MockProxy_RenameCollection_Call struct { } // RenameCollection is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.RenameCollectionRequest -func (_e *MockProxy_Expecter) RenameCollection(ctx interface{}, req interface{}) *MockProxy_RenameCollection_Call { - return &MockProxy_RenameCollection_Call{Call: _e.mock.On("RenameCollection", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.RenameCollectionRequest +func (_e *MockProxy_Expecter) RenameCollection(_a0 interface{}, _a1 interface{}) *MockProxy_RenameCollection_Call { + return &MockProxy_RenameCollection_Call{Call: _e.mock.On("RenameCollection", _a0, _a1)} } -func (_c *MockProxy_RenameCollection_Call) Run(run func(ctx context.Context, req *milvuspb.RenameCollectionRequest)) *MockProxy_RenameCollection_Call { +func (_c *MockProxy_RenameCollection_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.RenameCollectionRequest)) 
*MockProxy_RenameCollection_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.RenameCollectionRequest)) }) @@ -4166,17 +4445,72 @@ func (_c *MockProxy_RenameCollection_Call) RunAndReturn(run func(context.Context return _c } -// Search provides a mock function with given fields: ctx, request -func (_m *MockProxy) Search(ctx context.Context, request *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) { - ret := _m.Called(ctx, request) +// ReplicateMessage provides a mock function with given fields: ctx, req +func (_m *MockProxy) ReplicateMessage(ctx context.Context, req *milvuspb.ReplicateMessageRequest) (*milvuspb.ReplicateMessageResponse, error) { + ret := _m.Called(ctx, req) + + var r0 *milvuspb.ReplicateMessageResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ReplicateMessageRequest) (*milvuspb.ReplicateMessageResponse, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ReplicateMessageRequest) *milvuspb.ReplicateMessageResponse); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.ReplicateMessageResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ReplicateMessageRequest) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockProxy_ReplicateMessage_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReplicateMessage' +type MockProxy_ReplicateMessage_Call struct { + *mock.Call +} + +// ReplicateMessage is a helper method to define mock.On call +// - ctx context.Context +// - req *milvuspb.ReplicateMessageRequest +func (_e *MockProxy_Expecter) ReplicateMessage(ctx interface{}, req interface{}) *MockProxy_ReplicateMessage_Call { + return &MockProxy_ReplicateMessage_Call{Call: _e.mock.On("ReplicateMessage", ctx, req)} +} + +func (_c *MockProxy_ReplicateMessage_Call) Run(run func(ctx 
context.Context, req *milvuspb.ReplicateMessageRequest)) *MockProxy_ReplicateMessage_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.ReplicateMessageRequest)) + }) + return _c +} + +func (_c *MockProxy_ReplicateMessage_Call) Return(_a0 *milvuspb.ReplicateMessageResponse, _a1 error) *MockProxy_ReplicateMessage_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockProxy_ReplicateMessage_Call) RunAndReturn(run func(context.Context, *milvuspb.ReplicateMessageRequest) (*milvuspb.ReplicateMessageResponse, error)) *MockProxy_ReplicateMessage_Call { + _c.Call.Return(run) + return _c +} + +// Search provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) Search(_a0 context.Context, _a1 *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.SearchResults var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.SearchRequest) (*milvuspb.SearchResults, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.SearchRequest) *milvuspb.SearchResults); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.SearchResults) @@ -4184,7 +4518,7 @@ func (_m *MockProxy) Search(ctx context.Context, request *milvuspb.SearchRequest } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.SearchRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -4198,13 +4532,13 @@ type MockProxy_Search_Call struct { } // Search is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.SearchRequest -func (_e *MockProxy_Expecter) Search(ctx interface{}, request interface{}) *MockProxy_Search_Call { - return &MockProxy_Search_Call{Call: _e.mock.On("Search", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.SearchRequest +func (_e *MockProxy_Expecter) 
Search(_a0 interface{}, _a1 interface{}) *MockProxy_Search_Call { + return &MockProxy_Search_Call{Call: _e.mock.On("Search", _a0, _a1)} } -func (_c *MockProxy_Search_Call) Run(run func(ctx context.Context, request *milvuspb.SearchRequest)) *MockProxy_Search_Call { +func (_c *MockProxy_Search_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.SearchRequest)) *MockProxy_Search_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.SearchRequest)) }) @@ -4221,17 +4555,17 @@ func (_c *MockProxy_Search_Call) RunAndReturn(run func(context.Context, *milvusp return _c } -// SelectGrant provides a mock function with given fields: ctx, req -func (_m *MockProxy) SelectGrant(ctx context.Context, req *milvuspb.SelectGrantRequest) (*milvuspb.SelectGrantResponse, error) { - ret := _m.Called(ctx, req) +// SelectGrant provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) SelectGrant(_a0 context.Context, _a1 *milvuspb.SelectGrantRequest) (*milvuspb.SelectGrantResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.SelectGrantResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.SelectGrantRequest) (*milvuspb.SelectGrantResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.SelectGrantRequest) *milvuspb.SelectGrantResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.SelectGrantResponse) @@ -4239,7 +4573,7 @@ func (_m *MockProxy) SelectGrant(ctx context.Context, req *milvuspb.SelectGrantR } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.SelectGrantRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -4253,13 +4587,13 @@ type MockProxy_SelectGrant_Call struct { } // SelectGrant is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.SelectGrantRequest -func (_e 
*MockProxy_Expecter) SelectGrant(ctx interface{}, req interface{}) *MockProxy_SelectGrant_Call { - return &MockProxy_SelectGrant_Call{Call: _e.mock.On("SelectGrant", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.SelectGrantRequest +func (_e *MockProxy_Expecter) SelectGrant(_a0 interface{}, _a1 interface{}) *MockProxy_SelectGrant_Call { + return &MockProxy_SelectGrant_Call{Call: _e.mock.On("SelectGrant", _a0, _a1)} } -func (_c *MockProxy_SelectGrant_Call) Run(run func(ctx context.Context, req *milvuspb.SelectGrantRequest)) *MockProxy_SelectGrant_Call { +func (_c *MockProxy_SelectGrant_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.SelectGrantRequest)) *MockProxy_SelectGrant_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.SelectGrantRequest)) }) @@ -4276,17 +4610,17 @@ func (_c *MockProxy_SelectGrant_Call) RunAndReturn(run func(context.Context, *mi return _c } -// SelectRole provides a mock function with given fields: ctx, req -func (_m *MockProxy) SelectRole(ctx context.Context, req *milvuspb.SelectRoleRequest) (*milvuspb.SelectRoleResponse, error) { - ret := _m.Called(ctx, req) +// SelectRole provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) SelectRole(_a0 context.Context, _a1 *milvuspb.SelectRoleRequest) (*milvuspb.SelectRoleResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.SelectRoleResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.SelectRoleRequest) (*milvuspb.SelectRoleResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.SelectRoleRequest) *milvuspb.SelectRoleResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.SelectRoleResponse) @@ -4294,7 +4628,7 @@ func (_m *MockProxy) SelectRole(ctx context.Context, req *milvuspb.SelectRoleReq } if rf, ok := ret.Get(1).(func(context.Context, 
*milvuspb.SelectRoleRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -4308,13 +4642,13 @@ type MockProxy_SelectRole_Call struct { } // SelectRole is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.SelectRoleRequest -func (_e *MockProxy_Expecter) SelectRole(ctx interface{}, req interface{}) *MockProxy_SelectRole_Call { - return &MockProxy_SelectRole_Call{Call: _e.mock.On("SelectRole", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.SelectRoleRequest +func (_e *MockProxy_Expecter) SelectRole(_a0 interface{}, _a1 interface{}) *MockProxy_SelectRole_Call { + return &MockProxy_SelectRole_Call{Call: _e.mock.On("SelectRole", _a0, _a1)} } -func (_c *MockProxy_SelectRole_Call) Run(run func(ctx context.Context, req *milvuspb.SelectRoleRequest)) *MockProxy_SelectRole_Call { +func (_c *MockProxy_SelectRole_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.SelectRoleRequest)) *MockProxy_SelectRole_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.SelectRoleRequest)) }) @@ -4331,17 +4665,17 @@ func (_c *MockProxy_SelectRole_Call) RunAndReturn(run func(context.Context, *mil return _c } -// SelectUser provides a mock function with given fields: ctx, req -func (_m *MockProxy) SelectUser(ctx context.Context, req *milvuspb.SelectUserRequest) (*milvuspb.SelectUserResponse, error) { - ret := _m.Called(ctx, req) +// SelectUser provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) SelectUser(_a0 context.Context, _a1 *milvuspb.SelectUserRequest) (*milvuspb.SelectUserResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.SelectUserResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.SelectUserRequest) (*milvuspb.SelectUserResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.SelectUserRequest) *milvuspb.SelectUserResponse); 
ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.SelectUserResponse) @@ -4349,7 +4683,7 @@ func (_m *MockProxy) SelectUser(ctx context.Context, req *milvuspb.SelectUserReq } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.SelectUserRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -4363,13 +4697,13 @@ type MockProxy_SelectUser_Call struct { } // SelectUser is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.SelectUserRequest -func (_e *MockProxy_Expecter) SelectUser(ctx interface{}, req interface{}) *MockProxy_SelectUser_Call { - return &MockProxy_SelectUser_Call{Call: _e.mock.On("SelectUser", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.SelectUserRequest +func (_e *MockProxy_Expecter) SelectUser(_a0 interface{}, _a1 interface{}) *MockProxy_SelectUser_Call { + return &MockProxy_SelectUser_Call{Call: _e.mock.On("SelectUser", _a0, _a1)} } -func (_c *MockProxy_SelectUser_Call) Run(run func(ctx context.Context, req *milvuspb.SelectUserRequest)) *MockProxy_SelectUser_Call { +func (_c *MockProxy_SelectUser_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.SelectUserRequest)) *MockProxy_SelectUser_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.SelectUserRequest)) }) @@ -4420,7 +4754,7 @@ func (_c *MockProxy_SetAddress_Call) RunAndReturn(run func(string)) *MockProxy_S } // SetDataCoordClient provides a mock function with given fields: dataCoord -func (_m *MockProxy) SetDataCoordClient(dataCoord types.DataCoord) { +func (_m *MockProxy) SetDataCoordClient(dataCoord types.DataCoordClient) { _m.Called(dataCoord) } @@ -4430,14 +4764,14 @@ type MockProxy_SetDataCoordClient_Call struct { } // SetDataCoordClient is a helper method to define mock.On call -// - dataCoord types.DataCoord +// - dataCoord types.DataCoordClient func (_e *MockProxy_Expecter) 
SetDataCoordClient(dataCoord interface{}) *MockProxy_SetDataCoordClient_Call { return &MockProxy_SetDataCoordClient_Call{Call: _e.mock.On("SetDataCoordClient", dataCoord)} } -func (_c *MockProxy_SetDataCoordClient_Call) Run(run func(dataCoord types.DataCoord)) *MockProxy_SetDataCoordClient_Call { +func (_c *MockProxy_SetDataCoordClient_Call) Run(run func(dataCoord types.DataCoordClient)) *MockProxy_SetDataCoordClient_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(types.DataCoord)) + run(args[0].(types.DataCoordClient)) }) return _c } @@ -4447,7 +4781,7 @@ func (_c *MockProxy_SetDataCoordClient_Call) Return() *MockProxy_SetDataCoordCli return _c } -func (_c *MockProxy_SetDataCoordClient_Call) RunAndReturn(run func(types.DataCoord)) *MockProxy_SetDataCoordClient_Call { +func (_c *MockProxy_SetDataCoordClient_Call) RunAndReturn(run func(types.DataCoordClient)) *MockProxy_SetDataCoordClient_Call { _c.Call.Return(run) return _c } @@ -4486,7 +4820,7 @@ func (_c *MockProxy_SetEtcdClient_Call) RunAndReturn(run func(*clientv3.Client)) } // SetQueryCoordClient provides a mock function with given fields: queryCoord -func (_m *MockProxy) SetQueryCoordClient(queryCoord types.QueryCoord) { +func (_m *MockProxy) SetQueryCoordClient(queryCoord types.QueryCoordClient) { _m.Called(queryCoord) } @@ -4496,14 +4830,14 @@ type MockProxy_SetQueryCoordClient_Call struct { } // SetQueryCoordClient is a helper method to define mock.On call -// - queryCoord types.QueryCoord +// - queryCoord types.QueryCoordClient func (_e *MockProxy_Expecter) SetQueryCoordClient(queryCoord interface{}) *MockProxy_SetQueryCoordClient_Call { return &MockProxy_SetQueryCoordClient_Call{Call: _e.mock.On("SetQueryCoordClient", queryCoord)} } -func (_c *MockProxy_SetQueryCoordClient_Call) Run(run func(queryCoord types.QueryCoord)) *MockProxy_SetQueryCoordClient_Call { +func (_c *MockProxy_SetQueryCoordClient_Call) Run(run func(queryCoord types.QueryCoordClient)) *MockProxy_SetQueryCoordClient_Call { 
_c.Call.Run(func(args mock.Arguments) { - run(args[0].(types.QueryCoord)) + run(args[0].(types.QueryCoordClient)) }) return _c } @@ -4513,13 +4847,13 @@ func (_c *MockProxy_SetQueryCoordClient_Call) Return() *MockProxy_SetQueryCoordC return _c } -func (_c *MockProxy_SetQueryCoordClient_Call) RunAndReturn(run func(types.QueryCoord)) *MockProxy_SetQueryCoordClient_Call { +func (_c *MockProxy_SetQueryCoordClient_Call) RunAndReturn(run func(types.QueryCoordClient)) *MockProxy_SetQueryCoordClient_Call { _c.Call.Return(run) return _c } // SetQueryNodeCreator provides a mock function with given fields: _a0 -func (_m *MockProxy) SetQueryNodeCreator(_a0 func(context.Context, string, int64) (types.QueryNode, error)) { +func (_m *MockProxy) SetQueryNodeCreator(_a0 func(context.Context, string, int64) (types.QueryNodeClient, error)) { _m.Called(_a0) } @@ -4529,14 +4863,14 @@ type MockProxy_SetQueryNodeCreator_Call struct { } // SetQueryNodeCreator is a helper method to define mock.On call -// - _a0 func(context.Context , string , int64)(types.QueryNode , error) +// - _a0 func(context.Context , string , int64)(types.QueryNodeClient , error) func (_e *MockProxy_Expecter) SetQueryNodeCreator(_a0 interface{}) *MockProxy_SetQueryNodeCreator_Call { return &MockProxy_SetQueryNodeCreator_Call{Call: _e.mock.On("SetQueryNodeCreator", _a0)} } -func (_c *MockProxy_SetQueryNodeCreator_Call) Run(run func(_a0 func(context.Context, string, int64) (types.QueryNode, error))) *MockProxy_SetQueryNodeCreator_Call { +func (_c *MockProxy_SetQueryNodeCreator_Call) Run(run func(_a0 func(context.Context, string, int64) (types.QueryNodeClient, error))) *MockProxy_SetQueryNodeCreator_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(func(context.Context, string, int64) (types.QueryNode, error))) + run(args[0].(func(context.Context, string, int64) (types.QueryNodeClient, error))) }) return _c } @@ -4546,22 +4880,22 @@ func (_c *MockProxy_SetQueryNodeCreator_Call) Return() 
*MockProxy_SetQueryNodeCr return _c } -func (_c *MockProxy_SetQueryNodeCreator_Call) RunAndReturn(run func(func(context.Context, string, int64) (types.QueryNode, error))) *MockProxy_SetQueryNodeCreator_Call { +func (_c *MockProxy_SetQueryNodeCreator_Call) RunAndReturn(run func(func(context.Context, string, int64) (types.QueryNodeClient, error))) *MockProxy_SetQueryNodeCreator_Call { _c.Call.Return(run) return _c } -// SetRates provides a mock function with given fields: ctx, req -func (_m *MockProxy) SetRates(ctx context.Context, req *proxypb.SetRatesRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// SetRates provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) SetRates(_a0 context.Context, _a1 *proxypb.SetRatesRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *proxypb.SetRatesRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *proxypb.SetRatesRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -4569,7 +4903,7 @@ func (_m *MockProxy) SetRates(ctx context.Context, req *proxypb.SetRatesRequest) } if rf, ok := ret.Get(1).(func(context.Context, *proxypb.SetRatesRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -4583,13 +4917,13 @@ type MockProxy_SetRates_Call struct { } // SetRates is a helper method to define mock.On call -// - ctx context.Context -// - req *proxypb.SetRatesRequest -func (_e *MockProxy_Expecter) SetRates(ctx interface{}, req interface{}) *MockProxy_SetRates_Call { - return &MockProxy_SetRates_Call{Call: _e.mock.On("SetRates", ctx, req)} +// - _a0 context.Context +// - _a1 *proxypb.SetRatesRequest +func (_e *MockProxy_Expecter) SetRates(_a0 interface{}, _a1 interface{}) *MockProxy_SetRates_Call { + 
return &MockProxy_SetRates_Call{Call: _e.mock.On("SetRates", _a0, _a1)} } -func (_c *MockProxy_SetRates_Call) Run(run func(ctx context.Context, req *proxypb.SetRatesRequest)) *MockProxy_SetRates_Call { +func (_c *MockProxy_SetRates_Call) Run(run func(_a0 context.Context, _a1 *proxypb.SetRatesRequest)) *MockProxy_SetRates_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*proxypb.SetRatesRequest)) }) @@ -4607,7 +4941,7 @@ func (_c *MockProxy_SetRates_Call) RunAndReturn(run func(context.Context, *proxy } // SetRootCoordClient provides a mock function with given fields: rootCoord -func (_m *MockProxy) SetRootCoordClient(rootCoord types.RootCoord) { +func (_m *MockProxy) SetRootCoordClient(rootCoord types.RootCoordClient) { _m.Called(rootCoord) } @@ -4617,14 +4951,14 @@ type MockProxy_SetRootCoordClient_Call struct { } // SetRootCoordClient is a helper method to define mock.On call -// - rootCoord types.RootCoord +// - rootCoord types.RootCoordClient func (_e *MockProxy_Expecter) SetRootCoordClient(rootCoord interface{}) *MockProxy_SetRootCoordClient_Call { return &MockProxy_SetRootCoordClient_Call{Call: _e.mock.On("SetRootCoordClient", rootCoord)} } -func (_c *MockProxy_SetRootCoordClient_Call) Run(run func(rootCoord types.RootCoord)) *MockProxy_SetRootCoordClient_Call { +func (_c *MockProxy_SetRootCoordClient_Call) Run(run func(rootCoord types.RootCoordClient)) *MockProxy_SetRootCoordClient_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(types.RootCoord)) + run(args[0].(types.RootCoordClient)) }) return _c } @@ -4634,22 +4968,22 @@ func (_c *MockProxy_SetRootCoordClient_Call) Return() *MockProxy_SetRootCoordCli return _c } -func (_c *MockProxy_SetRootCoordClient_Call) RunAndReturn(run func(types.RootCoord)) *MockProxy_SetRootCoordClient_Call { +func (_c *MockProxy_SetRootCoordClient_Call) RunAndReturn(run func(types.RootCoordClient)) *MockProxy_SetRootCoordClient_Call { _c.Call.Return(run) return _c } -// 
ShowCollections provides a mock function with given fields: ctx, request -func (_m *MockProxy) ShowCollections(ctx context.Context, request *milvuspb.ShowCollectionsRequest) (*milvuspb.ShowCollectionsResponse, error) { - ret := _m.Called(ctx, request) +// ShowCollections provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) ShowCollections(_a0 context.Context, _a1 *milvuspb.ShowCollectionsRequest) (*milvuspb.ShowCollectionsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.ShowCollectionsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ShowCollectionsRequest) (*milvuspb.ShowCollectionsResponse, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ShowCollectionsRequest) *milvuspb.ShowCollectionsResponse); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.ShowCollectionsResponse) @@ -4657,7 +4991,7 @@ func (_m *MockProxy) ShowCollections(ctx context.Context, request *milvuspb.Show } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ShowCollectionsRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -4671,13 +5005,13 @@ type MockProxy_ShowCollections_Call struct { } // ShowCollections is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.ShowCollectionsRequest -func (_e *MockProxy_Expecter) ShowCollections(ctx interface{}, request interface{}) *MockProxy_ShowCollections_Call { - return &MockProxy_ShowCollections_Call{Call: _e.mock.On("ShowCollections", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.ShowCollectionsRequest +func (_e *MockProxy_Expecter) ShowCollections(_a0 interface{}, _a1 interface{}) *MockProxy_ShowCollections_Call { + return &MockProxy_ShowCollections_Call{Call: _e.mock.On("ShowCollections", _a0, _a1)} } -func (_c *MockProxy_ShowCollections_Call) Run(run 
func(ctx context.Context, request *milvuspb.ShowCollectionsRequest)) *MockProxy_ShowCollections_Call { +func (_c *MockProxy_ShowCollections_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.ShowCollectionsRequest)) *MockProxy_ShowCollections_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.ShowCollectionsRequest)) }) @@ -4694,17 +5028,17 @@ func (_c *MockProxy_ShowCollections_Call) RunAndReturn(run func(context.Context, return _c } -// ShowPartitions provides a mock function with given fields: ctx, request -func (_m *MockProxy) ShowPartitions(ctx context.Context, request *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) { - ret := _m.Called(ctx, request) +// ShowPartitions provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) ShowPartitions(_a0 context.Context, _a1 *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.ShowPartitionsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ShowPartitionsRequest) *milvuspb.ShowPartitionsResponse); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.ShowPartitionsResponse) @@ -4712,7 +5046,7 @@ func (_m *MockProxy) ShowPartitions(ctx context.Context, request *milvuspb.ShowP } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ShowPartitionsRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -4726,13 +5060,13 @@ type MockProxy_ShowPartitions_Call struct { } // ShowPartitions is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.ShowPartitionsRequest -func (_e *MockProxy_Expecter) ShowPartitions(ctx interface{}, 
request interface{}) *MockProxy_ShowPartitions_Call { - return &MockProxy_ShowPartitions_Call{Call: _e.mock.On("ShowPartitions", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.ShowPartitionsRequest +func (_e *MockProxy_Expecter) ShowPartitions(_a0 interface{}, _a1 interface{}) *MockProxy_ShowPartitions_Call { + return &MockProxy_ShowPartitions_Call{Call: _e.mock.On("ShowPartitions", _a0, _a1)} } -func (_c *MockProxy_ShowPartitions_Call) Run(run func(ctx context.Context, request *milvuspb.ShowPartitionsRequest)) *MockProxy_ShowPartitions_Call { +func (_c *MockProxy_ShowPartitions_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.ShowPartitionsRequest)) *MockProxy_ShowPartitions_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.ShowPartitionsRequest)) }) @@ -4831,17 +5165,17 @@ func (_c *MockProxy_Stop_Call) RunAndReturn(run func() error) *MockProxy_Stop_Ca return _c } -// TransferNode provides a mock function with given fields: ctx, req -func (_m *MockProxy) TransferNode(ctx context.Context, req *milvuspb.TransferNodeRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// TransferNode provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) TransferNode(_a0 context.Context, _a1 *milvuspb.TransferNodeRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.TransferNodeRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.TransferNodeRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -4849,7 +5183,7 @@ func (_m *MockProxy) TransferNode(ctx context.Context, req *milvuspb.TransferNod } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.TransferNodeRequest) error); ok { - r1 = rf(ctx, req) 
+ r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -4863,13 +5197,13 @@ type MockProxy_TransferNode_Call struct { } // TransferNode is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.TransferNodeRequest -func (_e *MockProxy_Expecter) TransferNode(ctx interface{}, req interface{}) *MockProxy_TransferNode_Call { - return &MockProxy_TransferNode_Call{Call: _e.mock.On("TransferNode", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.TransferNodeRequest +func (_e *MockProxy_Expecter) TransferNode(_a0 interface{}, _a1 interface{}) *MockProxy_TransferNode_Call { + return &MockProxy_TransferNode_Call{Call: _e.mock.On("TransferNode", _a0, _a1)} } -func (_c *MockProxy_TransferNode_Call) Run(run func(ctx context.Context, req *milvuspb.TransferNodeRequest)) *MockProxy_TransferNode_Call { +func (_c *MockProxy_TransferNode_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.TransferNodeRequest)) *MockProxy_TransferNode_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.TransferNodeRequest)) }) @@ -4886,17 +5220,17 @@ func (_c *MockProxy_TransferNode_Call) RunAndReturn(run func(context.Context, *m return _c } -// TransferReplica provides a mock function with given fields: ctx, req -func (_m *MockProxy) TransferReplica(ctx context.Context, req *milvuspb.TransferReplicaRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// TransferReplica provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) TransferReplica(_a0 context.Context, _a1 *milvuspb.TransferReplicaRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.TransferReplicaRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.TransferReplicaRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, 
_a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -4904,7 +5238,7 @@ func (_m *MockProxy) TransferReplica(ctx context.Context, req *milvuspb.Transfer } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.TransferReplicaRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -4918,13 +5252,13 @@ type MockProxy_TransferReplica_Call struct { } // TransferReplica is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.TransferReplicaRequest -func (_e *MockProxy_Expecter) TransferReplica(ctx interface{}, req interface{}) *MockProxy_TransferReplica_Call { - return &MockProxy_TransferReplica_Call{Call: _e.mock.On("TransferReplica", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.TransferReplicaRequest +func (_e *MockProxy_Expecter) TransferReplica(_a0 interface{}, _a1 interface{}) *MockProxy_TransferReplica_Call { + return &MockProxy_TransferReplica_Call{Call: _e.mock.On("TransferReplica", _a0, _a1)} } -func (_c *MockProxy_TransferReplica_Call) Run(run func(ctx context.Context, req *milvuspb.TransferReplicaRequest)) *MockProxy_TransferReplica_Call { +func (_c *MockProxy_TransferReplica_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.TransferReplicaRequest)) *MockProxy_TransferReplica_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.TransferReplicaRequest)) }) @@ -4941,17 +5275,17 @@ func (_c *MockProxy_TransferReplica_Call) RunAndReturn(run func(context.Context, return _c } -// UpdateCredential provides a mock function with given fields: ctx, req -func (_m *MockProxy) UpdateCredential(ctx context.Context, req *milvuspb.UpdateCredentialRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// UpdateCredential provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) UpdateCredential(_a0 context.Context, _a1 *milvuspb.UpdateCredentialRequest) (*commonpb.Status, error) { + ret := 
_m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.UpdateCredentialRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.UpdateCredentialRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -4959,7 +5293,7 @@ func (_m *MockProxy) UpdateCredential(ctx context.Context, req *milvuspb.UpdateC } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.UpdateCredentialRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -4973,13 +5307,13 @@ type MockProxy_UpdateCredential_Call struct { } // UpdateCredential is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.UpdateCredentialRequest -func (_e *MockProxy_Expecter) UpdateCredential(ctx interface{}, req interface{}) *MockProxy_UpdateCredential_Call { - return &MockProxy_UpdateCredential_Call{Call: _e.mock.On("UpdateCredential", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.UpdateCredentialRequest +func (_e *MockProxy_Expecter) UpdateCredential(_a0 interface{}, _a1 interface{}) *MockProxy_UpdateCredential_Call { + return &MockProxy_UpdateCredential_Call{Call: _e.mock.On("UpdateCredential", _a0, _a1)} } -func (_c *MockProxy_UpdateCredential_Call) Run(run func(ctx context.Context, req *milvuspb.UpdateCredentialRequest)) *MockProxy_UpdateCredential_Call { +func (_c *MockProxy_UpdateCredential_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.UpdateCredentialRequest)) *MockProxy_UpdateCredential_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.UpdateCredentialRequest)) }) @@ -4996,17 +5330,17 @@ func (_c *MockProxy_UpdateCredential_Call) RunAndReturn(run func(context.Context return _c } -// UpdateCredentialCache provides a mock function with given 
fields: ctx, request -func (_m *MockProxy) UpdateCredentialCache(ctx context.Context, request *proxypb.UpdateCredCacheRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, request) +// UpdateCredentialCache provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) UpdateCredentialCache(_a0 context.Context, _a1 *proxypb.UpdateCredCacheRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *proxypb.UpdateCredCacheRequest) (*commonpb.Status, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *proxypb.UpdateCredCacheRequest) *commonpb.Status); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -5014,7 +5348,7 @@ func (_m *MockProxy) UpdateCredentialCache(ctx context.Context, request *proxypb } if rf, ok := ret.Get(1).(func(context.Context, *proxypb.UpdateCredCacheRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -5028,13 +5362,13 @@ type MockProxy_UpdateCredentialCache_Call struct { } // UpdateCredentialCache is a helper method to define mock.On call -// - ctx context.Context -// - request *proxypb.UpdateCredCacheRequest -func (_e *MockProxy_Expecter) UpdateCredentialCache(ctx interface{}, request interface{}) *MockProxy_UpdateCredentialCache_Call { - return &MockProxy_UpdateCredentialCache_Call{Call: _e.mock.On("UpdateCredentialCache", ctx, request)} +// - _a0 context.Context +// - _a1 *proxypb.UpdateCredCacheRequest +func (_e *MockProxy_Expecter) UpdateCredentialCache(_a0 interface{}, _a1 interface{}) *MockProxy_UpdateCredentialCache_Call { + return &MockProxy_UpdateCredentialCache_Call{Call: _e.mock.On("UpdateCredentialCache", _a0, _a1)} } -func (_c *MockProxy_UpdateCredentialCache_Call) Run(run func(ctx context.Context, request *proxypb.UpdateCredCacheRequest)) 
*MockProxy_UpdateCredentialCache_Call { +func (_c *MockProxy_UpdateCredentialCache_Call) Run(run func(_a0 context.Context, _a1 *proxypb.UpdateCredCacheRequest)) *MockProxy_UpdateCredentialCache_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*proxypb.UpdateCredCacheRequest)) }) @@ -5084,17 +5418,17 @@ func (_c *MockProxy_UpdateStateCode_Call) RunAndReturn(run func(commonpb.StateCo return _c } -// Upsert provides a mock function with given fields: ctx, request -func (_m *MockProxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest) (*milvuspb.MutationResult, error) { - ret := _m.Called(ctx, request) +// Upsert provides a mock function with given fields: _a0, _a1 +func (_m *MockProxy) Upsert(_a0 context.Context, _a1 *milvuspb.UpsertRequest) (*milvuspb.MutationResult, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.MutationResult var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.UpsertRequest) (*milvuspb.MutationResult, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.UpsertRequest) *milvuspb.MutationResult); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.MutationResult) @@ -5102,7 +5436,7 @@ func (_m *MockProxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.UpsertRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -5116,13 +5450,13 @@ type MockProxy_Upsert_Call struct { } // Upsert is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.UpsertRequest -func (_e *MockProxy_Expecter) Upsert(ctx interface{}, request interface{}) *MockProxy_Upsert_Call { - return &MockProxy_Upsert_Call{Call: _e.mock.On("Upsert", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.UpsertRequest +func (_e 
*MockProxy_Expecter) Upsert(_a0 interface{}, _a1 interface{}) *MockProxy_Upsert_Call { + return &MockProxy_Upsert_Call{Call: _e.mock.On("Upsert", _a0, _a1)} } -func (_c *MockProxy_Upsert_Call) Run(run func(ctx context.Context, request *milvuspb.UpsertRequest)) *MockProxy_Upsert_Call { +func (_c *MockProxy_Upsert_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.UpsertRequest)) *MockProxy_Upsert_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.UpsertRequest)) }) diff --git a/internal/mocks/mock_proxy_client.go b/internal/mocks/mock_proxy_client.go new file mode 100644 index 0000000000000..0d74d3cd46e30 --- /dev/null +++ b/internal/mocks/mock_proxy_client.go @@ -0,0 +1,787 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package mocks + +import ( + context "context" + + commonpb "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + + grpc "google.golang.org/grpc" + + internalpb "github.com/milvus-io/milvus/internal/proto/internalpb" + + milvuspb "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + + mock "github.com/stretchr/testify/mock" + + proxypb "github.com/milvus-io/milvus/internal/proto/proxypb" +) + +// MockProxyClient is an autogenerated mock type for the ProxyClient type +type MockProxyClient struct { + mock.Mock +} + +type MockProxyClient_Expecter struct { + mock *mock.Mock +} + +func (_m *MockProxyClient) EXPECT() *MockProxyClient_Expecter { + return &MockProxyClient_Expecter{mock: &_m.Mock} +} + +// Close provides a mock function with given fields: +func (_m *MockProxyClient) Close() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockProxyClient_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type MockProxyClient_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e 
*MockProxyClient_Expecter) Close() *MockProxyClient_Close_Call { + return &MockProxyClient_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *MockProxyClient_Close_Call) Run(run func()) *MockProxyClient_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockProxyClient_Close_Call) Return(_a0 error) *MockProxyClient_Close_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockProxyClient_Close_Call) RunAndReturn(run func() error) *MockProxyClient_Close_Call { + _c.Call.Return(run) + return _c +} + +// GetComponentStates provides a mock function with given fields: ctx, in, opts +func (_m *MockProxyClient) GetComponentStates(ctx context.Context, in *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.ComponentStates + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetComponentStatesRequest, ...grpc.CallOption) (*milvuspb.ComponentStates, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetComponentStatesRequest, ...grpc.CallOption) *milvuspb.ComponentStates); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.ComponentStates) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetComponentStatesRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockProxyClient_GetComponentStates_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetComponentStates' +type MockProxyClient_GetComponentStates_Call struct { + *mock.Call +} + +// GetComponentStates is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.GetComponentStatesRequest +// - opts ...grpc.CallOption +func (_e *MockProxyClient_Expecter) GetComponentStates(ctx interface{}, in interface{}, opts ...interface{}) *MockProxyClient_GetComponentStates_Call { + return &MockProxyClient_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockProxyClient_GetComponentStates_Call) Run(run func(ctx context.Context, in *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption)) *MockProxyClient_GetComponentStates_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.GetComponentStatesRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockProxyClient_GetComponentStates_Call) Return(_a0 *milvuspb.ComponentStates, _a1 error) *MockProxyClient_GetComponentStates_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockProxyClient_GetComponentStates_Call) RunAndReturn(run func(context.Context, *milvuspb.GetComponentStatesRequest, ...grpc.CallOption) (*milvuspb.ComponentStates, error)) *MockProxyClient_GetComponentStates_Call { + _c.Call.Return(run) + return _c +} + +// GetDdChannel provides a mock function with given fields: ctx, in, opts +func (_m *MockProxyClient) GetDdChannel(ctx context.Context, in *internalpb.GetDdChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.StringResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetDdChannelRequest, ...grpc.CallOption) (*milvuspb.StringResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetDdChannelRequest, ...grpc.CallOption) *milvuspb.StringResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.StringResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.GetDdChannelRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockProxyClient_GetDdChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetDdChannel' +type MockProxyClient_GetDdChannel_Call struct { + *mock.Call +} + +// GetDdChannel is a helper method to define mock.On call +// - ctx context.Context +// - in *internalpb.GetDdChannelRequest +// - opts ...grpc.CallOption +func (_e *MockProxyClient_Expecter) GetDdChannel(ctx interface{}, in interface{}, opts ...interface{}) *MockProxyClient_GetDdChannel_Call { + return &MockProxyClient_GetDdChannel_Call{Call: _e.mock.On("GetDdChannel", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockProxyClient_GetDdChannel_Call) Run(run func(ctx context.Context, in *internalpb.GetDdChannelRequest, opts ...grpc.CallOption)) *MockProxyClient_GetDdChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*internalpb.GetDdChannelRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockProxyClient_GetDdChannel_Call) Return(_a0 *milvuspb.StringResponse, _a1 error) *MockProxyClient_GetDdChannel_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockProxyClient_GetDdChannel_Call) RunAndReturn(run func(context.Context, *internalpb.GetDdChannelRequest, ...grpc.CallOption) (*milvuspb.StringResponse, error)) *MockProxyClient_GetDdChannel_Call { + _c.Call.Return(run) + return _c +} + +// GetProxyMetrics provides a mock function with given fields: ctx, in, opts +func (_m *MockProxyClient) GetProxyMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.GetMetricsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest, ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest, ...grpc.CallOption) *milvuspb.GetMetricsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.GetMetricsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetMetricsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockProxyClient_GetProxyMetrics_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetProxyMetrics' +type MockProxyClient_GetProxyMetrics_Call struct { + *mock.Call +} + +// GetProxyMetrics is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.GetMetricsRequest +// - opts ...grpc.CallOption +func (_e *MockProxyClient_Expecter) GetProxyMetrics(ctx interface{}, in interface{}, opts ...interface{}) *MockProxyClient_GetProxyMetrics_Call { + return &MockProxyClient_GetProxyMetrics_Call{Call: _e.mock.On("GetProxyMetrics", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockProxyClient_GetProxyMetrics_Call) Run(run func(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption)) *MockProxyClient_GetProxyMetrics_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.GetMetricsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockProxyClient_GetProxyMetrics_Call) Return(_a0 *milvuspb.GetMetricsResponse, _a1 error) *MockProxyClient_GetProxyMetrics_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockProxyClient_GetProxyMetrics_Call) RunAndReturn(run func(context.Context, *milvuspb.GetMetricsRequest, ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error)) *MockProxyClient_GetProxyMetrics_Call { + _c.Call.Return(run) + return _c +} + +// GetStatisticsChannel provides a mock function with given fields: ctx, in, opts +func (_m *MockProxyClient) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.StringResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetStatisticsChannelRequest, ...grpc.CallOption) (*milvuspb.StringResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetStatisticsChannelRequest, ...grpc.CallOption) *milvuspb.StringResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.StringResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.GetStatisticsChannelRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockProxyClient_GetStatisticsChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetStatisticsChannel' +type MockProxyClient_GetStatisticsChannel_Call struct { + *mock.Call +} + +// GetStatisticsChannel is a helper method to define mock.On call +// - ctx context.Context +// - in *internalpb.GetStatisticsChannelRequest +// - opts ...grpc.CallOption +func (_e *MockProxyClient_Expecter) GetStatisticsChannel(ctx interface{}, in interface{}, opts ...interface{}) *MockProxyClient_GetStatisticsChannel_Call { + return &MockProxyClient_GetStatisticsChannel_Call{Call: _e.mock.On("GetStatisticsChannel", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockProxyClient_GetStatisticsChannel_Call) Run(run func(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption)) *MockProxyClient_GetStatisticsChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*internalpb.GetStatisticsChannelRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockProxyClient_GetStatisticsChannel_Call) Return(_a0 *milvuspb.StringResponse, _a1 error) *MockProxyClient_GetStatisticsChannel_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockProxyClient_GetStatisticsChannel_Call) RunAndReturn(run func(context.Context, *internalpb.GetStatisticsChannelRequest, ...grpc.CallOption) (*milvuspb.StringResponse, error)) *MockProxyClient_GetStatisticsChannel_Call { + _c.Call.Return(run) + return _c +} + +// InvalidateCollectionMetaCache provides a mock function with given fields: ctx, in, opts +func (_m *MockProxyClient) InvalidateCollectionMetaCache(ctx context.Context, in *proxypb.InvalidateCollMetaCacheRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateCollMetaCacheRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateCollMetaCacheRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *proxypb.InvalidateCollMetaCacheRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockProxyClient_InvalidateCollectionMetaCache_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InvalidateCollectionMetaCache' +type MockProxyClient_InvalidateCollectionMetaCache_Call struct { + *mock.Call +} + +// InvalidateCollectionMetaCache is a helper method to define mock.On call +// - ctx context.Context +// - in *proxypb.InvalidateCollMetaCacheRequest +// - opts ...grpc.CallOption +func (_e *MockProxyClient_Expecter) InvalidateCollectionMetaCache(ctx interface{}, in interface{}, opts ...interface{}) *MockProxyClient_InvalidateCollectionMetaCache_Call { + return &MockProxyClient_InvalidateCollectionMetaCache_Call{Call: _e.mock.On("InvalidateCollectionMetaCache", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockProxyClient_InvalidateCollectionMetaCache_Call) Run(run func(ctx context.Context, in *proxypb.InvalidateCollMetaCacheRequest, opts ...grpc.CallOption)) *MockProxyClient_InvalidateCollectionMetaCache_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*proxypb.InvalidateCollMetaCacheRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockProxyClient_InvalidateCollectionMetaCache_Call) Return(_a0 *commonpb.Status, _a1 error) *MockProxyClient_InvalidateCollectionMetaCache_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockProxyClient_InvalidateCollectionMetaCache_Call) RunAndReturn(run func(context.Context, *proxypb.InvalidateCollMetaCacheRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockProxyClient_InvalidateCollectionMetaCache_Call { + _c.Call.Return(run) + return _c +} + +// InvalidateCredentialCache provides a mock function with given fields: ctx, in, opts +func (_m *MockProxyClient) InvalidateCredentialCache(ctx context.Context, in *proxypb.InvalidateCredCacheRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateCredCacheRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateCredCacheRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *proxypb.InvalidateCredCacheRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockProxyClient_InvalidateCredentialCache_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InvalidateCredentialCache' +type MockProxyClient_InvalidateCredentialCache_Call struct { + *mock.Call +} + +// InvalidateCredentialCache is a helper method to define mock.On call +// - ctx context.Context +// - in *proxypb.InvalidateCredCacheRequest +// - opts ...grpc.CallOption +func (_e *MockProxyClient_Expecter) InvalidateCredentialCache(ctx interface{}, in interface{}, opts ...interface{}) *MockProxyClient_InvalidateCredentialCache_Call { + return &MockProxyClient_InvalidateCredentialCache_Call{Call: _e.mock.On("InvalidateCredentialCache", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockProxyClient_InvalidateCredentialCache_Call) Run(run func(ctx context.Context, in *proxypb.InvalidateCredCacheRequest, opts ...grpc.CallOption)) *MockProxyClient_InvalidateCredentialCache_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*proxypb.InvalidateCredCacheRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockProxyClient_InvalidateCredentialCache_Call) Return(_a0 *commonpb.Status, _a1 error) *MockProxyClient_InvalidateCredentialCache_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockProxyClient_InvalidateCredentialCache_Call) RunAndReturn(run func(context.Context, *proxypb.InvalidateCredCacheRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockProxyClient_InvalidateCredentialCache_Call { + _c.Call.Return(run) + return _c +} + +// ListClientInfos provides a mock function with given fields: ctx, in, opts +func (_m *MockProxyClient) ListClientInfos(ctx context.Context, in *proxypb.ListClientInfosRequest, opts ...grpc.CallOption) (*proxypb.ListClientInfosResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *proxypb.ListClientInfosResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *proxypb.ListClientInfosRequest, ...grpc.CallOption) (*proxypb.ListClientInfosResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *proxypb.ListClientInfosRequest, ...grpc.CallOption) *proxypb.ListClientInfosResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*proxypb.ListClientInfosResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *proxypb.ListClientInfosRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockProxyClient_ListClientInfos_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListClientInfos' +type MockProxyClient_ListClientInfos_Call struct { + *mock.Call +} + +// ListClientInfos is a helper method to define mock.On call +// - ctx context.Context +// - in *proxypb.ListClientInfosRequest +// - opts ...grpc.CallOption +func (_e *MockProxyClient_Expecter) ListClientInfos(ctx interface{}, in interface{}, opts ...interface{}) *MockProxyClient_ListClientInfos_Call { + return &MockProxyClient_ListClientInfos_Call{Call: _e.mock.On("ListClientInfos", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockProxyClient_ListClientInfos_Call) Run(run func(ctx context.Context, in *proxypb.ListClientInfosRequest, opts ...grpc.CallOption)) *MockProxyClient_ListClientInfos_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*proxypb.ListClientInfosRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockProxyClient_ListClientInfos_Call) Return(_a0 *proxypb.ListClientInfosResponse, _a1 error) *MockProxyClient_ListClientInfos_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockProxyClient_ListClientInfos_Call) RunAndReturn(run func(context.Context, *proxypb.ListClientInfosRequest, ...grpc.CallOption) (*proxypb.ListClientInfosResponse, error)) *MockProxyClient_ListClientInfos_Call { + _c.Call.Return(run) + return _c +} + +// RefreshPolicyInfoCache provides a mock function with given fields: ctx, in, opts +func (_m *MockProxyClient) RefreshPolicyInfoCache(ctx context.Context, in *proxypb.RefreshPolicyInfoCacheRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *proxypb.RefreshPolicyInfoCacheRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *proxypb.RefreshPolicyInfoCacheRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *proxypb.RefreshPolicyInfoCacheRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockProxyClient_RefreshPolicyInfoCache_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RefreshPolicyInfoCache' +type MockProxyClient_RefreshPolicyInfoCache_Call struct { + *mock.Call +} + +// RefreshPolicyInfoCache is a helper method to define mock.On call +// - ctx context.Context +// - in *proxypb.RefreshPolicyInfoCacheRequest +// - opts ...grpc.CallOption +func (_e *MockProxyClient_Expecter) RefreshPolicyInfoCache(ctx interface{}, in interface{}, opts ...interface{}) *MockProxyClient_RefreshPolicyInfoCache_Call { + return &MockProxyClient_RefreshPolicyInfoCache_Call{Call: _e.mock.On("RefreshPolicyInfoCache", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockProxyClient_RefreshPolicyInfoCache_Call) Run(run func(ctx context.Context, in *proxypb.RefreshPolicyInfoCacheRequest, opts ...grpc.CallOption)) *MockProxyClient_RefreshPolicyInfoCache_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*proxypb.RefreshPolicyInfoCacheRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockProxyClient_RefreshPolicyInfoCache_Call) Return(_a0 *commonpb.Status, _a1 error) *MockProxyClient_RefreshPolicyInfoCache_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockProxyClient_RefreshPolicyInfoCache_Call) RunAndReturn(run func(context.Context, *proxypb.RefreshPolicyInfoCacheRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockProxyClient_RefreshPolicyInfoCache_Call { + _c.Call.Return(run) + return _c +} + +// SetRates provides a mock function with given fields: ctx, in, opts +func (_m *MockProxyClient) SetRates(ctx context.Context, in *proxypb.SetRatesRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *proxypb.SetRatesRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *proxypb.SetRatesRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *proxypb.SetRatesRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockProxyClient_SetRates_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetRates' +type MockProxyClient_SetRates_Call struct { + *mock.Call +} + +// SetRates is a helper method to define mock.On call +// - ctx context.Context +// - in *proxypb.SetRatesRequest +// - opts ...grpc.CallOption +func (_e *MockProxyClient_Expecter) SetRates(ctx interface{}, in interface{}, opts ...interface{}) *MockProxyClient_SetRates_Call { + return &MockProxyClient_SetRates_Call{Call: _e.mock.On("SetRates", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockProxyClient_SetRates_Call) Run(run func(ctx context.Context, in *proxypb.SetRatesRequest, opts ...grpc.CallOption)) *MockProxyClient_SetRates_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*proxypb.SetRatesRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockProxyClient_SetRates_Call) Return(_a0 *commonpb.Status, _a1 error) *MockProxyClient_SetRates_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockProxyClient_SetRates_Call) RunAndReturn(run func(context.Context, *proxypb.SetRatesRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockProxyClient_SetRates_Call { + _c.Call.Return(run) + return _c +} + +// UpdateCredentialCache provides a mock function with given fields: ctx, in, opts +func (_m *MockProxyClient) UpdateCredentialCache(ctx context.Context, in *proxypb.UpdateCredCacheRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) 
+ + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *proxypb.UpdateCredCacheRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *proxypb.UpdateCredCacheRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *proxypb.UpdateCredCacheRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockProxyClient_UpdateCredentialCache_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateCredentialCache' +type MockProxyClient_UpdateCredentialCache_Call struct { + *mock.Call +} + +// UpdateCredentialCache is a helper method to define mock.On call +// - ctx context.Context +// - in *proxypb.UpdateCredCacheRequest +// - opts ...grpc.CallOption +func (_e *MockProxyClient_Expecter) UpdateCredentialCache(ctx interface{}, in interface{}, opts ...interface{}) *MockProxyClient_UpdateCredentialCache_Call { + return &MockProxyClient_UpdateCredentialCache_Call{Call: _e.mock.On("UpdateCredentialCache", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockProxyClient_UpdateCredentialCache_Call) Run(run func(ctx context.Context, in *proxypb.UpdateCredCacheRequest, opts ...grpc.CallOption)) *MockProxyClient_UpdateCredentialCache_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*proxypb.UpdateCredCacheRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockProxyClient_UpdateCredentialCache_Call) Return(_a0 *commonpb.Status, _a1 error) *MockProxyClient_UpdateCredentialCache_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockProxyClient_UpdateCredentialCache_Call) RunAndReturn(run func(context.Context, *proxypb.UpdateCredCacheRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockProxyClient_UpdateCredentialCache_Call { + _c.Call.Return(run) + return _c +} + +// NewMockProxyClient creates a new instance of MockProxyClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockProxyClient(t interface { + mock.TestingT + Cleanup(func()) +}) *MockProxyClient { + mock := &MockProxyClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/mocks/mock_querycoord.go b/internal/mocks/mock_querycoord.go index 78d66aab99e61..8b632e1b87d9d 100644 --- a/internal/mocks/mock_querycoord.go +++ b/internal/mocks/mock_querycoord.go @@ -16,6 +16,8 @@ import ( querypb "github.com/milvus-io/milvus/internal/proto/querypb" + txnkv "github.com/tikv/client-go/v2/txnkv" + types "github.com/milvus-io/milvus/internal/types" ) @@ -32,17 +34,17 @@ func (_m *MockQueryCoord) EXPECT() *MockQueryCoord_Expecter { return &MockQueryCoord_Expecter{mock: &_m.Mock} } -// CheckHealth provides a mock function with given fields: ctx, req -func (_m *MockQueryCoord) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { - ret := _m.Called(ctx, req) +// CheckHealth provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryCoord) CheckHealth(_a0 context.Context, _a1 *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.CheckHealthResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, 
*milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CheckHealthRequest) *milvuspb.CheckHealthResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.CheckHealthResponse) @@ -50,7 +52,7 @@ func (_m *MockQueryCoord) CheckHealth(ctx context.Context, req *milvuspb.CheckHe } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CheckHealthRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -64,13 +66,13 @@ type MockQueryCoord_CheckHealth_Call struct { } // CheckHealth is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.CheckHealthRequest -func (_e *MockQueryCoord_Expecter) CheckHealth(ctx interface{}, req interface{}) *MockQueryCoord_CheckHealth_Call { - return &MockQueryCoord_CheckHealth_Call{Call: _e.mock.On("CheckHealth", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.CheckHealthRequest +func (_e *MockQueryCoord_Expecter) CheckHealth(_a0 interface{}, _a1 interface{}) *MockQueryCoord_CheckHealth_Call { + return &MockQueryCoord_CheckHealth_Call{Call: _e.mock.On("CheckHealth", _a0, _a1)} } -func (_c *MockQueryCoord_CheckHealth_Call) Run(run func(ctx context.Context, req *milvuspb.CheckHealthRequest)) *MockQueryCoord_CheckHealth_Call { +func (_c *MockQueryCoord_CheckHealth_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.CheckHealthRequest)) *MockQueryCoord_CheckHealth_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.CheckHealthRequest)) }) @@ -87,17 +89,17 @@ func (_c *MockQueryCoord_CheckHealth_Call) RunAndReturn(run func(context.Context return _c } -// CreateResourceGroup provides a mock function with given fields: ctx, req -func (_m *MockQueryCoord) CreateResourceGroup(ctx context.Context, req *milvuspb.CreateResourceGroupRequest) 
(*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// CreateResourceGroup provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryCoord) CreateResourceGroup(_a0 context.Context, _a1 *milvuspb.CreateResourceGroupRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateResourceGroupRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateResourceGroupRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -105,7 +107,7 @@ func (_m *MockQueryCoord) CreateResourceGroup(ctx context.Context, req *milvuspb } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CreateResourceGroupRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -119,13 +121,13 @@ type MockQueryCoord_CreateResourceGroup_Call struct { } // CreateResourceGroup is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.CreateResourceGroupRequest -func (_e *MockQueryCoord_Expecter) CreateResourceGroup(ctx interface{}, req interface{}) *MockQueryCoord_CreateResourceGroup_Call { - return &MockQueryCoord_CreateResourceGroup_Call{Call: _e.mock.On("CreateResourceGroup", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.CreateResourceGroupRequest +func (_e *MockQueryCoord_Expecter) CreateResourceGroup(_a0 interface{}, _a1 interface{}) *MockQueryCoord_CreateResourceGroup_Call { + return &MockQueryCoord_CreateResourceGroup_Call{Call: _e.mock.On("CreateResourceGroup", _a0, _a1)} } -func (_c *MockQueryCoord_CreateResourceGroup_Call) Run(run func(ctx context.Context, req *milvuspb.CreateResourceGroupRequest)) *MockQueryCoord_CreateResourceGroup_Call { +func (_c *MockQueryCoord_CreateResourceGroup_Call) Run(run func(_a0 
context.Context, _a1 *milvuspb.CreateResourceGroupRequest)) *MockQueryCoord_CreateResourceGroup_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.CreateResourceGroupRequest)) }) @@ -142,17 +144,17 @@ func (_c *MockQueryCoord_CreateResourceGroup_Call) RunAndReturn(run func(context return _c } -// DescribeResourceGroup provides a mock function with given fields: ctx, req -func (_m *MockQueryCoord) DescribeResourceGroup(ctx context.Context, req *querypb.DescribeResourceGroupRequest) (*querypb.DescribeResourceGroupResponse, error) { - ret := _m.Called(ctx, req) +// DescribeResourceGroup provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryCoord) DescribeResourceGroup(_a0 context.Context, _a1 *querypb.DescribeResourceGroupRequest) (*querypb.DescribeResourceGroupResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *querypb.DescribeResourceGroupResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *querypb.DescribeResourceGroupRequest) (*querypb.DescribeResourceGroupResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *querypb.DescribeResourceGroupRequest) *querypb.DescribeResourceGroupResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*querypb.DescribeResourceGroupResponse) @@ -160,7 +162,7 @@ func (_m *MockQueryCoord) DescribeResourceGroup(ctx context.Context, req *queryp } if rf, ok := ret.Get(1).(func(context.Context, *querypb.DescribeResourceGroupRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -174,13 +176,13 @@ type MockQueryCoord_DescribeResourceGroup_Call struct { } // DescribeResourceGroup is a helper method to define mock.On call -// - ctx context.Context -// - req *querypb.DescribeResourceGroupRequest -func (_e *MockQueryCoord_Expecter) DescribeResourceGroup(ctx interface{}, req interface{}) 
*MockQueryCoord_DescribeResourceGroup_Call { - return &MockQueryCoord_DescribeResourceGroup_Call{Call: _e.mock.On("DescribeResourceGroup", ctx, req)} +// - _a0 context.Context +// - _a1 *querypb.DescribeResourceGroupRequest +func (_e *MockQueryCoord_Expecter) DescribeResourceGroup(_a0 interface{}, _a1 interface{}) *MockQueryCoord_DescribeResourceGroup_Call { + return &MockQueryCoord_DescribeResourceGroup_Call{Call: _e.mock.On("DescribeResourceGroup", _a0, _a1)} } -func (_c *MockQueryCoord_DescribeResourceGroup_Call) Run(run func(ctx context.Context, req *querypb.DescribeResourceGroupRequest)) *MockQueryCoord_DescribeResourceGroup_Call { +func (_c *MockQueryCoord_DescribeResourceGroup_Call) Run(run func(_a0 context.Context, _a1 *querypb.DescribeResourceGroupRequest)) *MockQueryCoord_DescribeResourceGroup_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*querypb.DescribeResourceGroupRequest)) }) @@ -197,17 +199,17 @@ func (_c *MockQueryCoord_DescribeResourceGroup_Call) RunAndReturn(run func(conte return _c } -// DropResourceGroup provides a mock function with given fields: ctx, req -func (_m *MockQueryCoord) DropResourceGroup(ctx context.Context, req *milvuspb.DropResourceGroupRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// DropResourceGroup provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryCoord) DropResourceGroup(_a0 context.Context, _a1 *milvuspb.DropResourceGroupRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropResourceGroupRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropResourceGroupRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -215,7 +217,7 @@ func (_m *MockQueryCoord) 
DropResourceGroup(ctx context.Context, req *milvuspb.D } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DropResourceGroupRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -229,13 +231,13 @@ type MockQueryCoord_DropResourceGroup_Call struct { } // DropResourceGroup is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.DropResourceGroupRequest -func (_e *MockQueryCoord_Expecter) DropResourceGroup(ctx interface{}, req interface{}) *MockQueryCoord_DropResourceGroup_Call { - return &MockQueryCoord_DropResourceGroup_Call{Call: _e.mock.On("DropResourceGroup", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.DropResourceGroupRequest +func (_e *MockQueryCoord_Expecter) DropResourceGroup(_a0 interface{}, _a1 interface{}) *MockQueryCoord_DropResourceGroup_Call { + return &MockQueryCoord_DropResourceGroup_Call{Call: _e.mock.On("DropResourceGroup", _a0, _a1)} } -func (_c *MockQueryCoord_DropResourceGroup_Call) Run(run func(ctx context.Context, req *milvuspb.DropResourceGroupRequest)) *MockQueryCoord_DropResourceGroup_Call { +func (_c *MockQueryCoord_DropResourceGroup_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.DropResourceGroupRequest)) *MockQueryCoord_DropResourceGroup_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.DropResourceGroupRequest)) }) @@ -252,25 +254,25 @@ func (_c *MockQueryCoord_DropResourceGroup_Call) RunAndReturn(run func(context.C return _c } -// GetComponentStates provides a mock function with given fields: ctx -func (_m *MockQueryCoord) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { - ret := _m.Called(ctx) +// GetComponentStates provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryCoord) GetComponentStates(_a0 context.Context, _a1 *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) { + ret := _m.Called(_a0, _a1) var r0 
*milvuspb.ComponentStates var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (*milvuspb.ComponentStates, error)); ok { - return rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error)); ok { + return rf(_a0, _a1) } - if rf, ok := ret.Get(0).(func(context.Context) *milvuspb.ComponentStates); ok { - r0 = rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetComponentStatesRequest) *milvuspb.ComponentStates); ok { + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.ComponentStates) } } - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetComponentStatesRequest) error); ok { + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -284,14 +286,15 @@ type MockQueryCoord_GetComponentStates_Call struct { } // GetComponentStates is a helper method to define mock.On call -// - ctx context.Context -func (_e *MockQueryCoord_Expecter) GetComponentStates(ctx interface{}) *MockQueryCoord_GetComponentStates_Call { - return &MockQueryCoord_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", ctx)} +// - _a0 context.Context +// - _a1 *milvuspb.GetComponentStatesRequest +func (_e *MockQueryCoord_Expecter) GetComponentStates(_a0 interface{}, _a1 interface{}) *MockQueryCoord_GetComponentStates_Call { + return &MockQueryCoord_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", _a0, _a1)} } -func (_c *MockQueryCoord_GetComponentStates_Call) Run(run func(ctx context.Context)) *MockQueryCoord_GetComponentStates_Call { +func (_c *MockQueryCoord_GetComponentStates_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetComponentStatesRequest)) *MockQueryCoord_GetComponentStates_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context)) + run(args[0].(context.Context), args[1].(*milvuspb.GetComponentStatesRequest)) }) return _c } @@ -301,22 
+304,22 @@ func (_c *MockQueryCoord_GetComponentStates_Call) Return(_a0 *milvuspb.Component return _c } -func (_c *MockQueryCoord_GetComponentStates_Call) RunAndReturn(run func(context.Context) (*milvuspb.ComponentStates, error)) *MockQueryCoord_GetComponentStates_Call { +func (_c *MockQueryCoord_GetComponentStates_Call) RunAndReturn(run func(context.Context, *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error)) *MockQueryCoord_GetComponentStates_Call { _c.Call.Return(run) return _c } -// GetMetrics provides a mock function with given fields: ctx, req -func (_m *MockQueryCoord) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - ret := _m.Called(ctx, req) +// GetMetrics provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryCoord) GetMetrics(_a0 context.Context, _a1 *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.GetMetricsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest) *milvuspb.GetMetricsResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.GetMetricsResponse) @@ -324,7 +327,7 @@ func (_m *MockQueryCoord) GetMetrics(ctx context.Context, req *milvuspb.GetMetri } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetMetricsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -338,13 +341,13 @@ type MockQueryCoord_GetMetrics_Call struct { } // GetMetrics is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.GetMetricsRequest -func (_e *MockQueryCoord_Expecter) GetMetrics(ctx interface{}, req interface{}) *MockQueryCoord_GetMetrics_Call { - return 
&MockQueryCoord_GetMetrics_Call{Call: _e.mock.On("GetMetrics", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.GetMetricsRequest +func (_e *MockQueryCoord_Expecter) GetMetrics(_a0 interface{}, _a1 interface{}) *MockQueryCoord_GetMetrics_Call { + return &MockQueryCoord_GetMetrics_Call{Call: _e.mock.On("GetMetrics", _a0, _a1)} } -func (_c *MockQueryCoord_GetMetrics_Call) Run(run func(ctx context.Context, req *milvuspb.GetMetricsRequest)) *MockQueryCoord_GetMetrics_Call { +func (_c *MockQueryCoord_GetMetrics_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetMetricsRequest)) *MockQueryCoord_GetMetrics_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.GetMetricsRequest)) }) @@ -361,17 +364,17 @@ func (_c *MockQueryCoord_GetMetrics_Call) RunAndReturn(run func(context.Context, return _c } -// GetPartitionStates provides a mock function with given fields: ctx, req -func (_m *MockQueryCoord) GetPartitionStates(ctx context.Context, req *querypb.GetPartitionStatesRequest) (*querypb.GetPartitionStatesResponse, error) { - ret := _m.Called(ctx, req) +// GetPartitionStates provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryCoord) GetPartitionStates(_a0 context.Context, _a1 *querypb.GetPartitionStatesRequest) (*querypb.GetPartitionStatesResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *querypb.GetPartitionStatesResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetPartitionStatesRequest) (*querypb.GetPartitionStatesResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetPartitionStatesRequest) *querypb.GetPartitionStatesResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*querypb.GetPartitionStatesResponse) @@ -379,7 +382,7 @@ func (_m *MockQueryCoord) GetPartitionStates(ctx context.Context, req *querypb.G } if rf, ok := 
ret.Get(1).(func(context.Context, *querypb.GetPartitionStatesRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -393,13 +396,13 @@ type MockQueryCoord_GetPartitionStates_Call struct { } // GetPartitionStates is a helper method to define mock.On call -// - ctx context.Context -// - req *querypb.GetPartitionStatesRequest -func (_e *MockQueryCoord_Expecter) GetPartitionStates(ctx interface{}, req interface{}) *MockQueryCoord_GetPartitionStates_Call { - return &MockQueryCoord_GetPartitionStates_Call{Call: _e.mock.On("GetPartitionStates", ctx, req)} +// - _a0 context.Context +// - _a1 *querypb.GetPartitionStatesRequest +func (_e *MockQueryCoord_Expecter) GetPartitionStates(_a0 interface{}, _a1 interface{}) *MockQueryCoord_GetPartitionStates_Call { + return &MockQueryCoord_GetPartitionStates_Call{Call: _e.mock.On("GetPartitionStates", _a0, _a1)} } -func (_c *MockQueryCoord_GetPartitionStates_Call) Run(run func(ctx context.Context, req *querypb.GetPartitionStatesRequest)) *MockQueryCoord_GetPartitionStates_Call { +func (_c *MockQueryCoord_GetPartitionStates_Call) Run(run func(_a0 context.Context, _a1 *querypb.GetPartitionStatesRequest)) *MockQueryCoord_GetPartitionStates_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*querypb.GetPartitionStatesRequest)) }) @@ -416,17 +419,17 @@ func (_c *MockQueryCoord_GetPartitionStates_Call) RunAndReturn(run func(context. 
return _c } -// GetReplicas provides a mock function with given fields: ctx, req -func (_m *MockQueryCoord) GetReplicas(ctx context.Context, req *milvuspb.GetReplicasRequest) (*milvuspb.GetReplicasResponse, error) { - ret := _m.Called(ctx, req) +// GetReplicas provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryCoord) GetReplicas(_a0 context.Context, _a1 *milvuspb.GetReplicasRequest) (*milvuspb.GetReplicasResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.GetReplicasResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetReplicasRequest) (*milvuspb.GetReplicasResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetReplicasRequest) *milvuspb.GetReplicasResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.GetReplicasResponse) @@ -434,7 +437,7 @@ func (_m *MockQueryCoord) GetReplicas(ctx context.Context, req *milvuspb.GetRepl } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetReplicasRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -448,13 +451,13 @@ type MockQueryCoord_GetReplicas_Call struct { } // GetReplicas is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.GetReplicasRequest -func (_e *MockQueryCoord_Expecter) GetReplicas(ctx interface{}, req interface{}) *MockQueryCoord_GetReplicas_Call { - return &MockQueryCoord_GetReplicas_Call{Call: _e.mock.On("GetReplicas", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.GetReplicasRequest +func (_e *MockQueryCoord_Expecter) GetReplicas(_a0 interface{}, _a1 interface{}) *MockQueryCoord_GetReplicas_Call { + return &MockQueryCoord_GetReplicas_Call{Call: _e.mock.On("GetReplicas", _a0, _a1)} } -func (_c *MockQueryCoord_GetReplicas_Call) Run(run func(ctx context.Context, req *milvuspb.GetReplicasRequest)) 
*MockQueryCoord_GetReplicas_Call { +func (_c *MockQueryCoord_GetReplicas_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetReplicasRequest)) *MockQueryCoord_GetReplicas_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.GetReplicasRequest)) }) @@ -471,17 +474,17 @@ func (_c *MockQueryCoord_GetReplicas_Call) RunAndReturn(run func(context.Context return _c } -// GetSegmentInfo provides a mock function with given fields: ctx, req -func (_m *MockQueryCoord) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) { - ret := _m.Called(ctx, req) +// GetSegmentInfo provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryCoord) GetSegmentInfo(_a0 context.Context, _a1 *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *querypb.GetSegmentInfoResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetSegmentInfoRequest) *querypb.GetSegmentInfoResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*querypb.GetSegmentInfoResponse) @@ -489,7 +492,7 @@ func (_m *MockQueryCoord) GetSegmentInfo(ctx context.Context, req *querypb.GetSe } if rf, ok := ret.Get(1).(func(context.Context, *querypb.GetSegmentInfoRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -503,13 +506,13 @@ type MockQueryCoord_GetSegmentInfo_Call struct { } // GetSegmentInfo is a helper method to define mock.On call -// - ctx context.Context -// - req *querypb.GetSegmentInfoRequest -func (_e *MockQueryCoord_Expecter) GetSegmentInfo(ctx interface{}, req interface{}) *MockQueryCoord_GetSegmentInfo_Call { - return 
&MockQueryCoord_GetSegmentInfo_Call{Call: _e.mock.On("GetSegmentInfo", ctx, req)} +// - _a0 context.Context +// - _a1 *querypb.GetSegmentInfoRequest +func (_e *MockQueryCoord_Expecter) GetSegmentInfo(_a0 interface{}, _a1 interface{}) *MockQueryCoord_GetSegmentInfo_Call { + return &MockQueryCoord_GetSegmentInfo_Call{Call: _e.mock.On("GetSegmentInfo", _a0, _a1)} } -func (_c *MockQueryCoord_GetSegmentInfo_Call) Run(run func(ctx context.Context, req *querypb.GetSegmentInfoRequest)) *MockQueryCoord_GetSegmentInfo_Call { +func (_c *MockQueryCoord_GetSegmentInfo_Call) Run(run func(_a0 context.Context, _a1 *querypb.GetSegmentInfoRequest)) *MockQueryCoord_GetSegmentInfo_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*querypb.GetSegmentInfoRequest)) }) @@ -526,17 +529,17 @@ func (_c *MockQueryCoord_GetSegmentInfo_Call) RunAndReturn(run func(context.Cont return _c } -// GetShardLeaders provides a mock function with given fields: ctx, req -func (_m *MockQueryCoord) GetShardLeaders(ctx context.Context, req *querypb.GetShardLeadersRequest) (*querypb.GetShardLeadersResponse, error) { - ret := _m.Called(ctx, req) +// GetShardLeaders provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryCoord) GetShardLeaders(_a0 context.Context, _a1 *querypb.GetShardLeadersRequest) (*querypb.GetShardLeadersResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *querypb.GetShardLeadersResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetShardLeadersRequest) (*querypb.GetShardLeadersResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetShardLeadersRequest) *querypb.GetShardLeadersResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*querypb.GetShardLeadersResponse) @@ -544,7 +547,7 @@ func (_m *MockQueryCoord) GetShardLeaders(ctx context.Context, req *querypb.GetS } if rf, ok := 
ret.Get(1).(func(context.Context, *querypb.GetShardLeadersRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -558,13 +561,13 @@ type MockQueryCoord_GetShardLeaders_Call struct { } // GetShardLeaders is a helper method to define mock.On call -// - ctx context.Context -// - req *querypb.GetShardLeadersRequest -func (_e *MockQueryCoord_Expecter) GetShardLeaders(ctx interface{}, req interface{}) *MockQueryCoord_GetShardLeaders_Call { - return &MockQueryCoord_GetShardLeaders_Call{Call: _e.mock.On("GetShardLeaders", ctx, req)} +// - _a0 context.Context +// - _a1 *querypb.GetShardLeadersRequest +func (_e *MockQueryCoord_Expecter) GetShardLeaders(_a0 interface{}, _a1 interface{}) *MockQueryCoord_GetShardLeaders_Call { + return &MockQueryCoord_GetShardLeaders_Call{Call: _e.mock.On("GetShardLeaders", _a0, _a1)} } -func (_c *MockQueryCoord_GetShardLeaders_Call) Run(run func(ctx context.Context, req *querypb.GetShardLeadersRequest)) *MockQueryCoord_GetShardLeaders_Call { +func (_c *MockQueryCoord_GetShardLeaders_Call) Run(run func(_a0 context.Context, _a1 *querypb.GetShardLeadersRequest)) *MockQueryCoord_GetShardLeaders_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*querypb.GetShardLeadersRequest)) }) @@ -581,25 +584,25 @@ func (_c *MockQueryCoord_GetShardLeaders_Call) RunAndReturn(run func(context.Con return _c } -// GetStatisticsChannel provides a mock function with given fields: ctx -func (_m *MockQueryCoord) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret := _m.Called(ctx) +// GetStatisticsChannel provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryCoord) GetStatisticsChannel(_a0 context.Context, _a1 *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.StringResponse var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (*milvuspb.StringResponse, error)); ok 
{ - return rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error)); ok { + return rf(_a0, _a1) } - if rf, ok := ret.Get(0).(func(context.Context) *milvuspb.StringResponse); ok { - r0 = rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetStatisticsChannelRequest) *milvuspb.StringResponse); ok { + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.StringResponse) } } - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.GetStatisticsChannelRequest) error); ok { + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -613,14 +616,15 @@ type MockQueryCoord_GetStatisticsChannel_Call struct { } // GetStatisticsChannel is a helper method to define mock.On call -// - ctx context.Context -func (_e *MockQueryCoord_Expecter) GetStatisticsChannel(ctx interface{}) *MockQueryCoord_GetStatisticsChannel_Call { - return &MockQueryCoord_GetStatisticsChannel_Call{Call: _e.mock.On("GetStatisticsChannel", ctx)} +// - _a0 context.Context +// - _a1 *internalpb.GetStatisticsChannelRequest +func (_e *MockQueryCoord_Expecter) GetStatisticsChannel(_a0 interface{}, _a1 interface{}) *MockQueryCoord_GetStatisticsChannel_Call { + return &MockQueryCoord_GetStatisticsChannel_Call{Call: _e.mock.On("GetStatisticsChannel", _a0, _a1)} } -func (_c *MockQueryCoord_GetStatisticsChannel_Call) Run(run func(ctx context.Context)) *MockQueryCoord_GetStatisticsChannel_Call { +func (_c *MockQueryCoord_GetStatisticsChannel_Call) Run(run func(_a0 context.Context, _a1 *internalpb.GetStatisticsChannelRequest)) *MockQueryCoord_GetStatisticsChannel_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context)) + run(args[0].(context.Context), args[1].(*internalpb.GetStatisticsChannelRequest)) }) return _c } @@ -630,30 +634,30 @@ func (_c *MockQueryCoord_GetStatisticsChannel_Call) Return(_a0 
*milvuspb.StringR return _c } -func (_c *MockQueryCoord_GetStatisticsChannel_Call) RunAndReturn(run func(context.Context) (*milvuspb.StringResponse, error)) *MockQueryCoord_GetStatisticsChannel_Call { +func (_c *MockQueryCoord_GetStatisticsChannel_Call) RunAndReturn(run func(context.Context, *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error)) *MockQueryCoord_GetStatisticsChannel_Call { _c.Call.Return(run) return _c } -// GetTimeTickChannel provides a mock function with given fields: ctx -func (_m *MockQueryCoord) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret := _m.Called(ctx) +// GetTimeTickChannel provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryCoord) GetTimeTickChannel(_a0 context.Context, _a1 *internalpb.GetTimeTickChannelRequest) (*milvuspb.StringResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.StringResponse var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (*milvuspb.StringResponse, error)); ok { - return rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetTimeTickChannelRequest) (*milvuspb.StringResponse, error)); ok { + return rf(_a0, _a1) } - if rf, ok := ret.Get(0).(func(context.Context) *milvuspb.StringResponse); ok { - r0 = rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetTimeTickChannelRequest) *milvuspb.StringResponse); ok { + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.StringResponse) } } - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.GetTimeTickChannelRequest) error); ok { + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -667,14 +671,15 @@ type MockQueryCoord_GetTimeTickChannel_Call struct { } // GetTimeTickChannel is a helper method to define mock.On call -// - ctx context.Context -func (_e *MockQueryCoord_Expecter) GetTimeTickChannel(ctx interface{}) 
*MockQueryCoord_GetTimeTickChannel_Call { - return &MockQueryCoord_GetTimeTickChannel_Call{Call: _e.mock.On("GetTimeTickChannel", ctx)} +// - _a0 context.Context +// - _a1 *internalpb.GetTimeTickChannelRequest +func (_e *MockQueryCoord_Expecter) GetTimeTickChannel(_a0 interface{}, _a1 interface{}) *MockQueryCoord_GetTimeTickChannel_Call { + return &MockQueryCoord_GetTimeTickChannel_Call{Call: _e.mock.On("GetTimeTickChannel", _a0, _a1)} } -func (_c *MockQueryCoord_GetTimeTickChannel_Call) Run(run func(ctx context.Context)) *MockQueryCoord_GetTimeTickChannel_Call { +func (_c *MockQueryCoord_GetTimeTickChannel_Call) Run(run func(_a0 context.Context, _a1 *internalpb.GetTimeTickChannelRequest)) *MockQueryCoord_GetTimeTickChannel_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context)) + run(args[0].(context.Context), args[1].(*internalpb.GetTimeTickChannelRequest)) }) return _c } @@ -684,7 +689,7 @@ func (_c *MockQueryCoord_GetTimeTickChannel_Call) Return(_a0 *milvuspb.StringRes return _c } -func (_c *MockQueryCoord_GetTimeTickChannel_Call) RunAndReturn(run func(context.Context) (*milvuspb.StringResponse, error)) *MockQueryCoord_GetTimeTickChannel_Call { +func (_c *MockQueryCoord_GetTimeTickChannel_Call) RunAndReturn(run func(context.Context, *internalpb.GetTimeTickChannelRequest) (*milvuspb.StringResponse, error)) *MockQueryCoord_GetTimeTickChannel_Call { _c.Call.Return(run) return _c } @@ -730,17 +735,17 @@ func (_c *MockQueryCoord_Init_Call) RunAndReturn(run func() error) *MockQueryCoo return _c } -// ListResourceGroups provides a mock function with given fields: ctx, req -func (_m *MockQueryCoord) ListResourceGroups(ctx context.Context, req *milvuspb.ListResourceGroupsRequest) (*milvuspb.ListResourceGroupsResponse, error) { - ret := _m.Called(ctx, req) +// ListResourceGroups provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryCoord) ListResourceGroups(_a0 context.Context, _a1 *milvuspb.ListResourceGroupsRequest) 
(*milvuspb.ListResourceGroupsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.ListResourceGroupsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ListResourceGroupsRequest) (*milvuspb.ListResourceGroupsResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ListResourceGroupsRequest) *milvuspb.ListResourceGroupsResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.ListResourceGroupsResponse) @@ -748,7 +753,7 @@ func (_m *MockQueryCoord) ListResourceGroups(ctx context.Context, req *milvuspb. } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ListResourceGroupsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -762,13 +767,13 @@ type MockQueryCoord_ListResourceGroups_Call struct { } // ListResourceGroups is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.ListResourceGroupsRequest -func (_e *MockQueryCoord_Expecter) ListResourceGroups(ctx interface{}, req interface{}) *MockQueryCoord_ListResourceGroups_Call { - return &MockQueryCoord_ListResourceGroups_Call{Call: _e.mock.On("ListResourceGroups", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.ListResourceGroupsRequest +func (_e *MockQueryCoord_Expecter) ListResourceGroups(_a0 interface{}, _a1 interface{}) *MockQueryCoord_ListResourceGroups_Call { + return &MockQueryCoord_ListResourceGroups_Call{Call: _e.mock.On("ListResourceGroups", _a0, _a1)} } -func (_c *MockQueryCoord_ListResourceGroups_Call) Run(run func(ctx context.Context, req *milvuspb.ListResourceGroupsRequest)) *MockQueryCoord_ListResourceGroups_Call { +func (_c *MockQueryCoord_ListResourceGroups_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.ListResourceGroupsRequest)) *MockQueryCoord_ListResourceGroups_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), 
args[1].(*milvuspb.ListResourceGroupsRequest)) }) @@ -785,17 +790,17 @@ func (_c *MockQueryCoord_ListResourceGroups_Call) RunAndReturn(run func(context. return _c } -// LoadBalance provides a mock function with given fields: ctx, req -func (_m *MockQueryCoord) LoadBalance(ctx context.Context, req *querypb.LoadBalanceRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// LoadBalance provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryCoord) LoadBalance(_a0 context.Context, _a1 *querypb.LoadBalanceRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *querypb.LoadBalanceRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *querypb.LoadBalanceRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -803,7 +808,7 @@ func (_m *MockQueryCoord) LoadBalance(ctx context.Context, req *querypb.LoadBala } if rf, ok := ret.Get(1).(func(context.Context, *querypb.LoadBalanceRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -817,13 +822,13 @@ type MockQueryCoord_LoadBalance_Call struct { } // LoadBalance is a helper method to define mock.On call -// - ctx context.Context -// - req *querypb.LoadBalanceRequest -func (_e *MockQueryCoord_Expecter) LoadBalance(ctx interface{}, req interface{}) *MockQueryCoord_LoadBalance_Call { - return &MockQueryCoord_LoadBalance_Call{Call: _e.mock.On("LoadBalance", ctx, req)} +// - _a0 context.Context +// - _a1 *querypb.LoadBalanceRequest +func (_e *MockQueryCoord_Expecter) LoadBalance(_a0 interface{}, _a1 interface{}) *MockQueryCoord_LoadBalance_Call { + return &MockQueryCoord_LoadBalance_Call{Call: _e.mock.On("LoadBalance", _a0, _a1)} } -func (_c *MockQueryCoord_LoadBalance_Call) Run(run func(ctx 
context.Context, req *querypb.LoadBalanceRequest)) *MockQueryCoord_LoadBalance_Call { +func (_c *MockQueryCoord_LoadBalance_Call) Run(run func(_a0 context.Context, _a1 *querypb.LoadBalanceRequest)) *MockQueryCoord_LoadBalance_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*querypb.LoadBalanceRequest)) }) @@ -840,17 +845,17 @@ func (_c *MockQueryCoord_LoadBalance_Call) RunAndReturn(run func(context.Context return _c } -// LoadCollection provides a mock function with given fields: ctx, req -func (_m *MockQueryCoord) LoadCollection(ctx context.Context, req *querypb.LoadCollectionRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// LoadCollection provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryCoord) LoadCollection(_a0 context.Context, _a1 *querypb.LoadCollectionRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *querypb.LoadCollectionRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *querypb.LoadCollectionRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -858,7 +863,7 @@ func (_m *MockQueryCoord) LoadCollection(ctx context.Context, req *querypb.LoadC } if rf, ok := ret.Get(1).(func(context.Context, *querypb.LoadCollectionRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -872,13 +877,13 @@ type MockQueryCoord_LoadCollection_Call struct { } // LoadCollection is a helper method to define mock.On call -// - ctx context.Context -// - req *querypb.LoadCollectionRequest -func (_e *MockQueryCoord_Expecter) LoadCollection(ctx interface{}, req interface{}) *MockQueryCoord_LoadCollection_Call { - return &MockQueryCoord_LoadCollection_Call{Call: _e.mock.On("LoadCollection", ctx, req)} 
+// - _a0 context.Context +// - _a1 *querypb.LoadCollectionRequest +func (_e *MockQueryCoord_Expecter) LoadCollection(_a0 interface{}, _a1 interface{}) *MockQueryCoord_LoadCollection_Call { + return &MockQueryCoord_LoadCollection_Call{Call: _e.mock.On("LoadCollection", _a0, _a1)} } -func (_c *MockQueryCoord_LoadCollection_Call) Run(run func(ctx context.Context, req *querypb.LoadCollectionRequest)) *MockQueryCoord_LoadCollection_Call { +func (_c *MockQueryCoord_LoadCollection_Call) Run(run func(_a0 context.Context, _a1 *querypb.LoadCollectionRequest)) *MockQueryCoord_LoadCollection_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*querypb.LoadCollectionRequest)) }) @@ -895,17 +900,17 @@ func (_c *MockQueryCoord_LoadCollection_Call) RunAndReturn(run func(context.Cont return _c } -// LoadPartitions provides a mock function with given fields: ctx, req -func (_m *MockQueryCoord) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// LoadPartitions provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryCoord) LoadPartitions(_a0 context.Context, _a1 *querypb.LoadPartitionsRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *querypb.LoadPartitionsRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *querypb.LoadPartitionsRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -913,7 +918,7 @@ func (_m *MockQueryCoord) LoadPartitions(ctx context.Context, req *querypb.LoadP } if rf, ok := ret.Get(1).(func(context.Context, *querypb.LoadPartitionsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -927,13 +932,13 @@ type 
MockQueryCoord_LoadPartitions_Call struct { } // LoadPartitions is a helper method to define mock.On call -// - ctx context.Context -// - req *querypb.LoadPartitionsRequest -func (_e *MockQueryCoord_Expecter) LoadPartitions(ctx interface{}, req interface{}) *MockQueryCoord_LoadPartitions_Call { - return &MockQueryCoord_LoadPartitions_Call{Call: _e.mock.On("LoadPartitions", ctx, req)} +// - _a0 context.Context +// - _a1 *querypb.LoadPartitionsRequest +func (_e *MockQueryCoord_Expecter) LoadPartitions(_a0 interface{}, _a1 interface{}) *MockQueryCoord_LoadPartitions_Call { + return &MockQueryCoord_LoadPartitions_Call{Call: _e.mock.On("LoadPartitions", _a0, _a1)} } -func (_c *MockQueryCoord_LoadPartitions_Call) Run(run func(ctx context.Context, req *querypb.LoadPartitionsRequest)) *MockQueryCoord_LoadPartitions_Call { +func (_c *MockQueryCoord_LoadPartitions_Call) Run(run func(_a0 context.Context, _a1 *querypb.LoadPartitionsRequest)) *MockQueryCoord_LoadPartitions_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*querypb.LoadPartitionsRequest)) }) @@ -991,17 +996,17 @@ func (_c *MockQueryCoord_Register_Call) RunAndReturn(run func() error) *MockQuer return _c } -// ReleaseCollection provides a mock function with given fields: ctx, req -func (_m *MockQueryCoord) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// ReleaseCollection provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryCoord) ReleaseCollection(_a0 context.Context, _a1 *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleaseCollectionRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleaseCollectionRequest) *commonpb.Status); ok { - r0 = 
rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -1009,7 +1014,7 @@ func (_m *MockQueryCoord) ReleaseCollection(ctx context.Context, req *querypb.Re } if rf, ok := ret.Get(1).(func(context.Context, *querypb.ReleaseCollectionRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1023,13 +1028,13 @@ type MockQueryCoord_ReleaseCollection_Call struct { } // ReleaseCollection is a helper method to define mock.On call -// - ctx context.Context -// - req *querypb.ReleaseCollectionRequest -func (_e *MockQueryCoord_Expecter) ReleaseCollection(ctx interface{}, req interface{}) *MockQueryCoord_ReleaseCollection_Call { - return &MockQueryCoord_ReleaseCollection_Call{Call: _e.mock.On("ReleaseCollection", ctx, req)} +// - _a0 context.Context +// - _a1 *querypb.ReleaseCollectionRequest +func (_e *MockQueryCoord_Expecter) ReleaseCollection(_a0 interface{}, _a1 interface{}) *MockQueryCoord_ReleaseCollection_Call { + return &MockQueryCoord_ReleaseCollection_Call{Call: _e.mock.On("ReleaseCollection", _a0, _a1)} } -func (_c *MockQueryCoord_ReleaseCollection_Call) Run(run func(ctx context.Context, req *querypb.ReleaseCollectionRequest)) *MockQueryCoord_ReleaseCollection_Call { +func (_c *MockQueryCoord_ReleaseCollection_Call) Run(run func(_a0 context.Context, _a1 *querypb.ReleaseCollectionRequest)) *MockQueryCoord_ReleaseCollection_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*querypb.ReleaseCollectionRequest)) }) @@ -1046,17 +1051,17 @@ func (_c *MockQueryCoord_ReleaseCollection_Call) RunAndReturn(run func(context.C return _c } -// ReleasePartitions provides a mock function with given fields: ctx, req -func (_m *MockQueryCoord) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// ReleasePartitions provides a mock function with given fields: _a0, _a1 +func (_m 
*MockQueryCoord) ReleasePartitions(_a0 context.Context, _a1 *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleasePartitionsRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleasePartitionsRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -1064,7 +1069,7 @@ func (_m *MockQueryCoord) ReleasePartitions(ctx context.Context, req *querypb.Re } if rf, ok := ret.Get(1).(func(context.Context, *querypb.ReleasePartitionsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1078,13 +1083,13 @@ type MockQueryCoord_ReleasePartitions_Call struct { } // ReleasePartitions is a helper method to define mock.On call -// - ctx context.Context -// - req *querypb.ReleasePartitionsRequest -func (_e *MockQueryCoord_Expecter) ReleasePartitions(ctx interface{}, req interface{}) *MockQueryCoord_ReleasePartitions_Call { - return &MockQueryCoord_ReleasePartitions_Call{Call: _e.mock.On("ReleasePartitions", ctx, req)} +// - _a0 context.Context +// - _a1 *querypb.ReleasePartitionsRequest +func (_e *MockQueryCoord_Expecter) ReleasePartitions(_a0 interface{}, _a1 interface{}) *MockQueryCoord_ReleasePartitions_Call { + return &MockQueryCoord_ReleasePartitions_Call{Call: _e.mock.On("ReleasePartitions", _a0, _a1)} } -func (_c *MockQueryCoord_ReleasePartitions_Call) Run(run func(ctx context.Context, req *querypb.ReleasePartitionsRequest)) *MockQueryCoord_ReleasePartitions_Call { +func (_c *MockQueryCoord_ReleasePartitions_Call) Run(run func(_a0 context.Context, _a1 *querypb.ReleasePartitionsRequest)) *MockQueryCoord_ReleasePartitions_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), 
args[1].(*querypb.ReleasePartitionsRequest)) }) @@ -1134,12 +1139,12 @@ func (_c *MockQueryCoord_SetAddress_Call) RunAndReturn(run func(string)) *MockQu return _c } -// SetDataCoord provides a mock function with given fields: dataCoord -func (_m *MockQueryCoord) SetDataCoord(dataCoord types.DataCoord) error { +// SetDataCoordClient provides a mock function with given fields: dataCoord +func (_m *MockQueryCoord) SetDataCoordClient(dataCoord types.DataCoordClient) error { ret := _m.Called(dataCoord) var r0 error - if rf, ok := ret.Get(0).(func(types.DataCoord) error); ok { + if rf, ok := ret.Get(0).(func(types.DataCoordClient) error); ok { r0 = rf(dataCoord) } else { r0 = ret.Error(0) @@ -1148,30 +1153,30 @@ func (_m *MockQueryCoord) SetDataCoord(dataCoord types.DataCoord) error { return r0 } -// MockQueryCoord_SetDataCoord_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetDataCoord' -type MockQueryCoord_SetDataCoord_Call struct { +// MockQueryCoord_SetDataCoordClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetDataCoordClient' +type MockQueryCoord_SetDataCoordClient_Call struct { *mock.Call } -// SetDataCoord is a helper method to define mock.On call -// - dataCoord types.DataCoord -func (_e *MockQueryCoord_Expecter) SetDataCoord(dataCoord interface{}) *MockQueryCoord_SetDataCoord_Call { - return &MockQueryCoord_SetDataCoord_Call{Call: _e.mock.On("SetDataCoord", dataCoord)} +// SetDataCoordClient is a helper method to define mock.On call +// - dataCoord types.DataCoordClient +func (_e *MockQueryCoord_Expecter) SetDataCoordClient(dataCoord interface{}) *MockQueryCoord_SetDataCoordClient_Call { + return &MockQueryCoord_SetDataCoordClient_Call{Call: _e.mock.On("SetDataCoordClient", dataCoord)} } -func (_c *MockQueryCoord_SetDataCoord_Call) Run(run func(dataCoord types.DataCoord)) *MockQueryCoord_SetDataCoord_Call { +func (_c *MockQueryCoord_SetDataCoordClient_Call) Run(run 
func(dataCoord types.DataCoordClient)) *MockQueryCoord_SetDataCoordClient_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(types.DataCoord)) + run(args[0].(types.DataCoordClient)) }) return _c } -func (_c *MockQueryCoord_SetDataCoord_Call) Return(_a0 error) *MockQueryCoord_SetDataCoord_Call { +func (_c *MockQueryCoord_SetDataCoordClient_Call) Return(_a0 error) *MockQueryCoord_SetDataCoordClient_Call { _c.Call.Return(_a0) return _c } -func (_c *MockQueryCoord_SetDataCoord_Call) RunAndReturn(run func(types.DataCoord) error) *MockQueryCoord_SetDataCoord_Call { +func (_c *MockQueryCoord_SetDataCoordClient_Call) RunAndReturn(run func(types.DataCoordClient) error) *MockQueryCoord_SetDataCoordClient_Call { _c.Call.Return(run) return _c } @@ -1210,7 +1215,7 @@ func (_c *MockQueryCoord_SetEtcdClient_Call) RunAndReturn(run func(*clientv3.Cli } // SetQueryNodeCreator provides a mock function with given fields: _a0 -func (_m *MockQueryCoord) SetQueryNodeCreator(_a0 func(context.Context, string, int64) (types.QueryNode, error)) { +func (_m *MockQueryCoord) SetQueryNodeCreator(_a0 func(context.Context, string, int64) (types.QueryNodeClient, error)) { _m.Called(_a0) } @@ -1220,14 +1225,14 @@ type MockQueryCoord_SetQueryNodeCreator_Call struct { } // SetQueryNodeCreator is a helper method to define mock.On call -// - _a0 func(context.Context , string , int64)(types.QueryNode , error) +// - _a0 func(context.Context , string , int64)(types.QueryNodeClient , error) func (_e *MockQueryCoord_Expecter) SetQueryNodeCreator(_a0 interface{}) *MockQueryCoord_SetQueryNodeCreator_Call { return &MockQueryCoord_SetQueryNodeCreator_Call{Call: _e.mock.On("SetQueryNodeCreator", _a0)} } -func (_c *MockQueryCoord_SetQueryNodeCreator_Call) Run(run func(_a0 func(context.Context, string, int64) (types.QueryNode, error))) *MockQueryCoord_SetQueryNodeCreator_Call { +func (_c *MockQueryCoord_SetQueryNodeCreator_Call) Run(run func(_a0 func(context.Context, string, int64) 
(types.QueryNodeClient, error))) *MockQueryCoord_SetQueryNodeCreator_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(func(context.Context, string, int64) (types.QueryNode, error))) + run(args[0].(func(context.Context, string, int64) (types.QueryNodeClient, error))) }) return _c } @@ -1237,17 +1242,17 @@ func (_c *MockQueryCoord_SetQueryNodeCreator_Call) Return() *MockQueryCoord_SetQ return _c } -func (_c *MockQueryCoord_SetQueryNodeCreator_Call) RunAndReturn(run func(func(context.Context, string, int64) (types.QueryNode, error))) *MockQueryCoord_SetQueryNodeCreator_Call { +func (_c *MockQueryCoord_SetQueryNodeCreator_Call) RunAndReturn(run func(func(context.Context, string, int64) (types.QueryNodeClient, error))) *MockQueryCoord_SetQueryNodeCreator_Call { _c.Call.Return(run) return _c } -// SetRootCoord provides a mock function with given fields: rootCoord -func (_m *MockQueryCoord) SetRootCoord(rootCoord types.RootCoord) error { +// SetRootCoordClient provides a mock function with given fields: rootCoord +func (_m *MockQueryCoord) SetRootCoordClient(rootCoord types.RootCoordClient) error { ret := _m.Called(rootCoord) var r0 error - if rf, ok := ret.Get(0).(func(types.RootCoord) error); ok { + if rf, ok := ret.Get(0).(func(types.RootCoordClient) error); ok { r0 = rf(rootCoord) } else { r0 = ret.Error(0) @@ -1256,45 +1261,78 @@ func (_m *MockQueryCoord) SetRootCoord(rootCoord types.RootCoord) error { return r0 } -// MockQueryCoord_SetRootCoord_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetRootCoord' -type MockQueryCoord_SetRootCoord_Call struct { +// MockQueryCoord_SetRootCoordClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetRootCoordClient' +type MockQueryCoord_SetRootCoordClient_Call struct { *mock.Call } -// SetRootCoord is a helper method to define mock.On call -// - rootCoord types.RootCoord -func (_e *MockQueryCoord_Expecter) 
SetRootCoord(rootCoord interface{}) *MockQueryCoord_SetRootCoord_Call { - return &MockQueryCoord_SetRootCoord_Call{Call: _e.mock.On("SetRootCoord", rootCoord)} +// SetRootCoordClient is a helper method to define mock.On call +// - rootCoord types.RootCoordClient +func (_e *MockQueryCoord_Expecter) SetRootCoordClient(rootCoord interface{}) *MockQueryCoord_SetRootCoordClient_Call { + return &MockQueryCoord_SetRootCoordClient_Call{Call: _e.mock.On("SetRootCoordClient", rootCoord)} } -func (_c *MockQueryCoord_SetRootCoord_Call) Run(run func(rootCoord types.RootCoord)) *MockQueryCoord_SetRootCoord_Call { +func (_c *MockQueryCoord_SetRootCoordClient_Call) Run(run func(rootCoord types.RootCoordClient)) *MockQueryCoord_SetRootCoordClient_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(types.RootCoord)) + run(args[0].(types.RootCoordClient)) }) return _c } -func (_c *MockQueryCoord_SetRootCoord_Call) Return(_a0 error) *MockQueryCoord_SetRootCoord_Call { +func (_c *MockQueryCoord_SetRootCoordClient_Call) Return(_a0 error) *MockQueryCoord_SetRootCoordClient_Call { _c.Call.Return(_a0) return _c } -func (_c *MockQueryCoord_SetRootCoord_Call) RunAndReturn(run func(types.RootCoord) error) *MockQueryCoord_SetRootCoord_Call { +func (_c *MockQueryCoord_SetRootCoordClient_Call) RunAndReturn(run func(types.RootCoordClient) error) *MockQueryCoord_SetRootCoordClient_Call { + _c.Call.Return(run) + return _c +} + +// SetTiKVClient provides a mock function with given fields: client +func (_m *MockQueryCoord) SetTiKVClient(client *txnkv.Client) { + _m.Called(client) +} + +// MockQueryCoord_SetTiKVClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetTiKVClient' +type MockQueryCoord_SetTiKVClient_Call struct { + *mock.Call +} + +// SetTiKVClient is a helper method to define mock.On call +// - client *txnkv.Client +func (_e *MockQueryCoord_Expecter) SetTiKVClient(client interface{}) *MockQueryCoord_SetTiKVClient_Call { + return 
&MockQueryCoord_SetTiKVClient_Call{Call: _e.mock.On("SetTiKVClient", client)} +} + +func (_c *MockQueryCoord_SetTiKVClient_Call) Run(run func(client *txnkv.Client)) *MockQueryCoord_SetTiKVClient_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*txnkv.Client)) + }) + return _c +} + +func (_c *MockQueryCoord_SetTiKVClient_Call) Return() *MockQueryCoord_SetTiKVClient_Call { + _c.Call.Return() + return _c +} + +func (_c *MockQueryCoord_SetTiKVClient_Call) RunAndReturn(run func(*txnkv.Client)) *MockQueryCoord_SetTiKVClient_Call { _c.Call.Return(run) return _c } -// ShowCollections provides a mock function with given fields: ctx, req -func (_m *MockQueryCoord) ShowCollections(ctx context.Context, req *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) { - ret := _m.Called(ctx, req) +// ShowCollections provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryCoord) ShowCollections(_a0 context.Context, _a1 *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *querypb.ShowCollectionsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *querypb.ShowCollectionsRequest) *querypb.ShowCollectionsResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*querypb.ShowCollectionsResponse) @@ -1302,7 +1340,7 @@ func (_m *MockQueryCoord) ShowCollections(ctx context.Context, req *querypb.Show } if rf, ok := ret.Get(1).(func(context.Context, *querypb.ShowCollectionsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1316,13 +1354,13 @@ type MockQueryCoord_ShowCollections_Call struct { } // ShowCollections is a helper method to define mock.On call -// - ctx context.Context -// - req 
*querypb.ShowCollectionsRequest -func (_e *MockQueryCoord_Expecter) ShowCollections(ctx interface{}, req interface{}) *MockQueryCoord_ShowCollections_Call { - return &MockQueryCoord_ShowCollections_Call{Call: _e.mock.On("ShowCollections", ctx, req)} +// - _a0 context.Context +// - _a1 *querypb.ShowCollectionsRequest +func (_e *MockQueryCoord_Expecter) ShowCollections(_a0 interface{}, _a1 interface{}) *MockQueryCoord_ShowCollections_Call { + return &MockQueryCoord_ShowCollections_Call{Call: _e.mock.On("ShowCollections", _a0, _a1)} } -func (_c *MockQueryCoord_ShowCollections_Call) Run(run func(ctx context.Context, req *querypb.ShowCollectionsRequest)) *MockQueryCoord_ShowCollections_Call { +func (_c *MockQueryCoord_ShowCollections_Call) Run(run func(_a0 context.Context, _a1 *querypb.ShowCollectionsRequest)) *MockQueryCoord_ShowCollections_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*querypb.ShowCollectionsRequest)) }) @@ -1339,17 +1377,17 @@ func (_c *MockQueryCoord_ShowCollections_Call) RunAndReturn(run func(context.Con return _c } -// ShowConfigurations provides a mock function with given fields: ctx, req -func (_m *MockQueryCoord) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { - ret := _m.Called(ctx, req) +// ShowConfigurations provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryCoord) ShowConfigurations(_a0 context.Context, _a1 *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *internalpb.ShowConfigurationsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ShowConfigurationsRequest) *internalpb.ShowConfigurationsResponse); ok { - 
r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*internalpb.ShowConfigurationsResponse) @@ -1357,7 +1395,7 @@ func (_m *MockQueryCoord) ShowConfigurations(ctx context.Context, req *internalp } if rf, ok := ret.Get(1).(func(context.Context, *internalpb.ShowConfigurationsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1371,13 +1409,13 @@ type MockQueryCoord_ShowConfigurations_Call struct { } // ShowConfigurations is a helper method to define mock.On call -// - ctx context.Context -// - req *internalpb.ShowConfigurationsRequest -func (_e *MockQueryCoord_Expecter) ShowConfigurations(ctx interface{}, req interface{}) *MockQueryCoord_ShowConfigurations_Call { - return &MockQueryCoord_ShowConfigurations_Call{Call: _e.mock.On("ShowConfigurations", ctx, req)} +// - _a0 context.Context +// - _a1 *internalpb.ShowConfigurationsRequest +func (_e *MockQueryCoord_Expecter) ShowConfigurations(_a0 interface{}, _a1 interface{}) *MockQueryCoord_ShowConfigurations_Call { + return &MockQueryCoord_ShowConfigurations_Call{Call: _e.mock.On("ShowConfigurations", _a0, _a1)} } -func (_c *MockQueryCoord_ShowConfigurations_Call) Run(run func(ctx context.Context, req *internalpb.ShowConfigurationsRequest)) *MockQueryCoord_ShowConfigurations_Call { +func (_c *MockQueryCoord_ShowConfigurations_Call) Run(run func(_a0 context.Context, _a1 *internalpb.ShowConfigurationsRequest)) *MockQueryCoord_ShowConfigurations_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*internalpb.ShowConfigurationsRequest)) }) @@ -1394,17 +1432,17 @@ func (_c *MockQueryCoord_ShowConfigurations_Call) RunAndReturn(run func(context. 
return _c } -// ShowPartitions provides a mock function with given fields: ctx, req -func (_m *MockQueryCoord) ShowPartitions(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) { - ret := _m.Called(ctx, req) +// ShowPartitions provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryCoord) ShowPartitions(_a0 context.Context, _a1 *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *querypb.ShowPartitionsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *querypb.ShowPartitionsRequest) *querypb.ShowPartitionsResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*querypb.ShowPartitionsResponse) @@ -1412,7 +1450,7 @@ func (_m *MockQueryCoord) ShowPartitions(ctx context.Context, req *querypb.ShowP } if rf, ok := ret.Get(1).(func(context.Context, *querypb.ShowPartitionsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1426,13 +1464,13 @@ type MockQueryCoord_ShowPartitions_Call struct { } // ShowPartitions is a helper method to define mock.On call -// - ctx context.Context -// - req *querypb.ShowPartitionsRequest -func (_e *MockQueryCoord_Expecter) ShowPartitions(ctx interface{}, req interface{}) *MockQueryCoord_ShowPartitions_Call { - return &MockQueryCoord_ShowPartitions_Call{Call: _e.mock.On("ShowPartitions", ctx, req)} +// - _a0 context.Context +// - _a1 *querypb.ShowPartitionsRequest +func (_e *MockQueryCoord_Expecter) ShowPartitions(_a0 interface{}, _a1 interface{}) *MockQueryCoord_ShowPartitions_Call { + return &MockQueryCoord_ShowPartitions_Call{Call: _e.mock.On("ShowPartitions", _a0, _a1)} } -func (_c *MockQueryCoord_ShowPartitions_Call) Run(run func(ctx 
context.Context, req *querypb.ShowPartitionsRequest)) *MockQueryCoord_ShowPartitions_Call { +func (_c *MockQueryCoord_ShowPartitions_Call) Run(run func(_a0 context.Context, _a1 *querypb.ShowPartitionsRequest)) *MockQueryCoord_ShowPartitions_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*querypb.ShowPartitionsRequest)) }) @@ -1531,17 +1569,17 @@ func (_c *MockQueryCoord_Stop_Call) RunAndReturn(run func() error) *MockQueryCoo return _c } -// SyncNewCreatedPartition provides a mock function with given fields: ctx, req -func (_m *MockQueryCoord) SyncNewCreatedPartition(ctx context.Context, req *querypb.SyncNewCreatedPartitionRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// SyncNewCreatedPartition provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryCoord) SyncNewCreatedPartition(_a0 context.Context, _a1 *querypb.SyncNewCreatedPartitionRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *querypb.SyncNewCreatedPartitionRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *querypb.SyncNewCreatedPartitionRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -1549,7 +1587,7 @@ func (_m *MockQueryCoord) SyncNewCreatedPartition(ctx context.Context, req *quer } if rf, ok := ret.Get(1).(func(context.Context, *querypb.SyncNewCreatedPartitionRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1563,13 +1601,13 @@ type MockQueryCoord_SyncNewCreatedPartition_Call struct { } // SyncNewCreatedPartition is a helper method to define mock.On call -// - ctx context.Context -// - req *querypb.SyncNewCreatedPartitionRequest -func (_e *MockQueryCoord_Expecter) SyncNewCreatedPartition(ctx interface{}, req 
interface{}) *MockQueryCoord_SyncNewCreatedPartition_Call { - return &MockQueryCoord_SyncNewCreatedPartition_Call{Call: _e.mock.On("SyncNewCreatedPartition", ctx, req)} +// - _a0 context.Context +// - _a1 *querypb.SyncNewCreatedPartitionRequest +func (_e *MockQueryCoord_Expecter) SyncNewCreatedPartition(_a0 interface{}, _a1 interface{}) *MockQueryCoord_SyncNewCreatedPartition_Call { + return &MockQueryCoord_SyncNewCreatedPartition_Call{Call: _e.mock.On("SyncNewCreatedPartition", _a0, _a1)} } -func (_c *MockQueryCoord_SyncNewCreatedPartition_Call) Run(run func(ctx context.Context, req *querypb.SyncNewCreatedPartitionRequest)) *MockQueryCoord_SyncNewCreatedPartition_Call { +func (_c *MockQueryCoord_SyncNewCreatedPartition_Call) Run(run func(_a0 context.Context, _a1 *querypb.SyncNewCreatedPartitionRequest)) *MockQueryCoord_SyncNewCreatedPartition_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*querypb.SyncNewCreatedPartitionRequest)) }) @@ -1586,17 +1624,17 @@ func (_c *MockQueryCoord_SyncNewCreatedPartition_Call) RunAndReturn(run func(con return _c } -// TransferNode provides a mock function with given fields: ctx, req -func (_m *MockQueryCoord) TransferNode(ctx context.Context, req *milvuspb.TransferNodeRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// TransferNode provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryCoord) TransferNode(_a0 context.Context, _a1 *milvuspb.TransferNodeRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.TransferNodeRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.TransferNodeRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -1604,7 +1642,7 @@ func (_m 
*MockQueryCoord) TransferNode(ctx context.Context, req *milvuspb.Transf } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.TransferNodeRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1618,13 +1656,13 @@ type MockQueryCoord_TransferNode_Call struct { } // TransferNode is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.TransferNodeRequest -func (_e *MockQueryCoord_Expecter) TransferNode(ctx interface{}, req interface{}) *MockQueryCoord_TransferNode_Call { - return &MockQueryCoord_TransferNode_Call{Call: _e.mock.On("TransferNode", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.TransferNodeRequest +func (_e *MockQueryCoord_Expecter) TransferNode(_a0 interface{}, _a1 interface{}) *MockQueryCoord_TransferNode_Call { + return &MockQueryCoord_TransferNode_Call{Call: _e.mock.On("TransferNode", _a0, _a1)} } -func (_c *MockQueryCoord_TransferNode_Call) Run(run func(ctx context.Context, req *milvuspb.TransferNodeRequest)) *MockQueryCoord_TransferNode_Call { +func (_c *MockQueryCoord_TransferNode_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.TransferNodeRequest)) *MockQueryCoord_TransferNode_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.TransferNodeRequest)) }) @@ -1641,17 +1679,17 @@ func (_c *MockQueryCoord_TransferNode_Call) RunAndReturn(run func(context.Contex return _c } -// TransferReplica provides a mock function with given fields: ctx, req -func (_m *MockQueryCoord) TransferReplica(ctx context.Context, req *querypb.TransferReplicaRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// TransferReplica provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryCoord) TransferReplica(_a0 context.Context, _a1 *querypb.TransferReplicaRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, 
*querypb.TransferReplicaRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *querypb.TransferReplicaRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -1659,7 +1697,7 @@ func (_m *MockQueryCoord) TransferReplica(ctx context.Context, req *querypb.Tran } if rf, ok := ret.Get(1).(func(context.Context, *querypb.TransferReplicaRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1673,13 +1711,13 @@ type MockQueryCoord_TransferReplica_Call struct { } // TransferReplica is a helper method to define mock.On call -// - ctx context.Context -// - req *querypb.TransferReplicaRequest -func (_e *MockQueryCoord_Expecter) TransferReplica(ctx interface{}, req interface{}) *MockQueryCoord_TransferReplica_Call { - return &MockQueryCoord_TransferReplica_Call{Call: _e.mock.On("TransferReplica", ctx, req)} +// - _a0 context.Context +// - _a1 *querypb.TransferReplicaRequest +func (_e *MockQueryCoord_Expecter) TransferReplica(_a0 interface{}, _a1 interface{}) *MockQueryCoord_TransferReplica_Call { + return &MockQueryCoord_TransferReplica_Call{Call: _e.mock.On("TransferReplica", _a0, _a1)} } -func (_c *MockQueryCoord_TransferReplica_Call) Run(run func(ctx context.Context, req *querypb.TransferReplicaRequest)) *MockQueryCoord_TransferReplica_Call { +func (_c *MockQueryCoord_TransferReplica_Call) Run(run func(_a0 context.Context, _a1 *querypb.TransferReplicaRequest)) *MockQueryCoord_TransferReplica_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*querypb.TransferReplicaRequest)) }) diff --git a/internal/mocks/mock_querycoord_client.go b/internal/mocks/mock_querycoord_client.go new file mode 100644 index 0000000000000..95bfe6ad9e254 --- /dev/null +++ b/internal/mocks/mock_querycoord_client.go @@ -0,0 +1,1767 @@ +// Code generated by mockery 
v2.32.4. DO NOT EDIT. + +package mocks + +import ( + context "context" + + commonpb "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + + grpc "google.golang.org/grpc" + + internalpb "github.com/milvus-io/milvus/internal/proto/internalpb" + + milvuspb "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + + mock "github.com/stretchr/testify/mock" + + querypb "github.com/milvus-io/milvus/internal/proto/querypb" +) + +// MockQueryCoordClient is an autogenerated mock type for the QueryCoordClient type +type MockQueryCoordClient struct { + mock.Mock +} + +type MockQueryCoordClient_Expecter struct { + mock *mock.Mock +} + +func (_m *MockQueryCoordClient) EXPECT() *MockQueryCoordClient_Expecter { + return &MockQueryCoordClient_Expecter{mock: &_m.Mock} +} + +// CheckHealth provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) CheckHealth(ctx context.Context, in *milvuspb.CheckHealthRequest, opts ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.CheckHealthResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CheckHealthRequest, ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CheckHealthRequest, ...grpc.CallOption) *milvuspb.CheckHealthResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.CheckHealthResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CheckHealthRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_CheckHealth_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckHealth' +type MockQueryCoordClient_CheckHealth_Call struct { + *mock.Call +} + +// CheckHealth is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.CheckHealthRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) CheckHealth(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_CheckHealth_Call { + return &MockQueryCoordClient_CheckHealth_Call{Call: _e.mock.On("CheckHealth", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_CheckHealth_Call) Run(run func(ctx context.Context, in *milvuspb.CheckHealthRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_CheckHealth_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.CheckHealthRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryCoordClient_CheckHealth_Call) Return(_a0 *milvuspb.CheckHealthResponse, _a1 error) *MockQueryCoordClient_CheckHealth_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_CheckHealth_Call) RunAndReturn(run func(context.Context, *milvuspb.CheckHealthRequest, ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error)) *MockQueryCoordClient_CheckHealth_Call { + _c.Call.Return(run) + return _c +} + +// Close provides a mock function with given fields: +func (_m *MockQueryCoordClient) Close() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockQueryCoordClient_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type MockQueryCoordClient_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *MockQueryCoordClient_Expecter) Close() *MockQueryCoordClient_Close_Call { + return &MockQueryCoordClient_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *MockQueryCoordClient_Close_Call) Run(run func()) *MockQueryCoordClient_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockQueryCoordClient_Close_Call) Return(_a0 error) *MockQueryCoordClient_Close_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockQueryCoordClient_Close_Call) RunAndReturn(run func() error) *MockQueryCoordClient_Close_Call { + _c.Call.Return(run) + return _c +} + +// CreateResourceGroup provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) CreateResourceGroup(ctx context.Context, in *milvuspb.CreateResourceGroupRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, 
_va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateResourceGroupRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateResourceGroupRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CreateResourceGroupRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_CreateResourceGroup_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateResourceGroup' +type MockQueryCoordClient_CreateResourceGroup_Call struct { + *mock.Call +} + +// CreateResourceGroup is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.CreateResourceGroupRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) CreateResourceGroup(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_CreateResourceGroup_Call { + return &MockQueryCoordClient_CreateResourceGroup_Call{Call: _e.mock.On("CreateResourceGroup", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_CreateResourceGroup_Call) Run(run func(ctx context.Context, in *milvuspb.CreateResourceGroupRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_CreateResourceGroup_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.CreateResourceGroupRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryCoordClient_CreateResourceGroup_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoordClient_CreateResourceGroup_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_CreateResourceGroup_Call) RunAndReturn(run func(context.Context, *milvuspb.CreateResourceGroupRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryCoordClient_CreateResourceGroup_Call { + _c.Call.Return(run) + return _c +} + +// DescribeResourceGroup provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) DescribeResourceGroup(ctx context.Context, in *querypb.DescribeResourceGroupRequest, opts ...grpc.CallOption) (*querypb.DescribeResourceGroupResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *querypb.DescribeResourceGroupResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.DescribeResourceGroupRequest, ...grpc.CallOption) (*querypb.DescribeResourceGroupResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.DescribeResourceGroupRequest, ...grpc.CallOption) *querypb.DescribeResourceGroupResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*querypb.DescribeResourceGroupResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.DescribeResourceGroupRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_DescribeResourceGroup_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DescribeResourceGroup' +type MockQueryCoordClient_DescribeResourceGroup_Call struct { + *mock.Call +} + +// DescribeResourceGroup is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.DescribeResourceGroupRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) DescribeResourceGroup(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_DescribeResourceGroup_Call { + return &MockQueryCoordClient_DescribeResourceGroup_Call{Call: _e.mock.On("DescribeResourceGroup", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_DescribeResourceGroup_Call) Run(run func(ctx context.Context, in *querypb.DescribeResourceGroupRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_DescribeResourceGroup_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.DescribeResourceGroupRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryCoordClient_DescribeResourceGroup_Call) Return(_a0 *querypb.DescribeResourceGroupResponse, _a1 error) *MockQueryCoordClient_DescribeResourceGroup_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_DescribeResourceGroup_Call) RunAndReturn(run func(context.Context, *querypb.DescribeResourceGroupRequest, ...grpc.CallOption) (*querypb.DescribeResourceGroupResponse, error)) *MockQueryCoordClient_DescribeResourceGroup_Call { + _c.Call.Return(run) + return _c +} + +// DropResourceGroup provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) DropResourceGroup(ctx context.Context, in *milvuspb.DropResourceGroupRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropResourceGroupRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropResourceGroupRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DropResourceGroupRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_DropResourceGroup_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DropResourceGroup' +type MockQueryCoordClient_DropResourceGroup_Call struct { + *mock.Call +} + +// DropResourceGroup is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.DropResourceGroupRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) DropResourceGroup(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_DropResourceGroup_Call { + return &MockQueryCoordClient_DropResourceGroup_Call{Call: _e.mock.On("DropResourceGroup", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_DropResourceGroup_Call) Run(run func(ctx context.Context, in *milvuspb.DropResourceGroupRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_DropResourceGroup_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.DropResourceGroupRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryCoordClient_DropResourceGroup_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoordClient_DropResourceGroup_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_DropResourceGroup_Call) RunAndReturn(run func(context.Context, *milvuspb.DropResourceGroupRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryCoordClient_DropResourceGroup_Call { + _c.Call.Return(run) + return _c +} + +// GetComponentStates provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) GetComponentStates(ctx context.Context, in *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.ComponentStates + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetComponentStatesRequest, ...grpc.CallOption) (*milvuspb.ComponentStates, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetComponentStatesRequest, ...grpc.CallOption) *milvuspb.ComponentStates); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.ComponentStates) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetComponentStatesRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_GetComponentStates_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetComponentStates' +type MockQueryCoordClient_GetComponentStates_Call struct { + *mock.Call +} + +// GetComponentStates is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.GetComponentStatesRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) GetComponentStates(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_GetComponentStates_Call { + return &MockQueryCoordClient_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_GetComponentStates_Call) Run(run func(ctx context.Context, in *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_GetComponentStates_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.GetComponentStatesRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryCoordClient_GetComponentStates_Call) Return(_a0 *milvuspb.ComponentStates, _a1 error) *MockQueryCoordClient_GetComponentStates_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_GetComponentStates_Call) RunAndReturn(run func(context.Context, *milvuspb.GetComponentStatesRequest, ...grpc.CallOption) (*milvuspb.ComponentStates, error)) *MockQueryCoordClient_GetComponentStates_Call { + _c.Call.Return(run) + return _c +} + +// GetMetrics provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.GetMetricsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest, ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest, ...grpc.CallOption) *milvuspb.GetMetricsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.GetMetricsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetMetricsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_GetMetrics_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetMetrics' +type MockQueryCoordClient_GetMetrics_Call struct { + *mock.Call +} + +// GetMetrics is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.GetMetricsRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) GetMetrics(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_GetMetrics_Call { + return &MockQueryCoordClient_GetMetrics_Call{Call: _e.mock.On("GetMetrics", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_GetMetrics_Call) Run(run func(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_GetMetrics_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.GetMetricsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryCoordClient_GetMetrics_Call) Return(_a0 *milvuspb.GetMetricsResponse, _a1 error) *MockQueryCoordClient_GetMetrics_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_GetMetrics_Call) RunAndReturn(run func(context.Context, *milvuspb.GetMetricsRequest, ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error)) *MockQueryCoordClient_GetMetrics_Call { + _c.Call.Return(run) + return _c +} + +// GetPartitionStates provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) GetPartitionStates(ctx context.Context, in *querypb.GetPartitionStatesRequest, opts ...grpc.CallOption) (*querypb.GetPartitionStatesResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *querypb.GetPartitionStatesResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetPartitionStatesRequest, ...grpc.CallOption) (*querypb.GetPartitionStatesResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetPartitionStatesRequest, ...grpc.CallOption) *querypb.GetPartitionStatesResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*querypb.GetPartitionStatesResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.GetPartitionStatesRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_GetPartitionStates_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetPartitionStates' +type MockQueryCoordClient_GetPartitionStates_Call struct { + *mock.Call +} + +// GetPartitionStates is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.GetPartitionStatesRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) GetPartitionStates(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_GetPartitionStates_Call { + return &MockQueryCoordClient_GetPartitionStates_Call{Call: _e.mock.On("GetPartitionStates", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_GetPartitionStates_Call) Run(run func(ctx context.Context, in *querypb.GetPartitionStatesRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_GetPartitionStates_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.GetPartitionStatesRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryCoordClient_GetPartitionStates_Call) Return(_a0 *querypb.GetPartitionStatesResponse, _a1 error) *MockQueryCoordClient_GetPartitionStates_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_GetPartitionStates_Call) RunAndReturn(run func(context.Context, *querypb.GetPartitionStatesRequest, ...grpc.CallOption) (*querypb.GetPartitionStatesResponse, error)) *MockQueryCoordClient_GetPartitionStates_Call { + _c.Call.Return(run) + return _c +} + +// GetReplicas provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) GetReplicas(ctx context.Context, in *milvuspb.GetReplicasRequest, opts ...grpc.CallOption) (*milvuspb.GetReplicasResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.GetReplicasResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetReplicasRequest, ...grpc.CallOption) (*milvuspb.GetReplicasResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetReplicasRequest, ...grpc.CallOption) *milvuspb.GetReplicasResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.GetReplicasResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetReplicasRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_GetReplicas_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetReplicas' +type MockQueryCoordClient_GetReplicas_Call struct { + *mock.Call +} + +// GetReplicas is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.GetReplicasRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) GetReplicas(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_GetReplicas_Call { + return &MockQueryCoordClient_GetReplicas_Call{Call: _e.mock.On("GetReplicas", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_GetReplicas_Call) Run(run func(ctx context.Context, in *milvuspb.GetReplicasRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_GetReplicas_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.GetReplicasRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryCoordClient_GetReplicas_Call) Return(_a0 *milvuspb.GetReplicasResponse, _a1 error) *MockQueryCoordClient_GetReplicas_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_GetReplicas_Call) RunAndReturn(run func(context.Context, *milvuspb.GetReplicasRequest, ...grpc.CallOption) (*milvuspb.GetReplicasResponse, error)) *MockQueryCoordClient_GetReplicas_Call { + _c.Call.Return(run) + return _c +} + +// GetSegmentInfo provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) GetSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest, opts ...grpc.CallOption) (*querypb.GetSegmentInfoResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *querypb.GetSegmentInfoResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetSegmentInfoRequest, ...grpc.CallOption) (*querypb.GetSegmentInfoResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetSegmentInfoRequest, ...grpc.CallOption) *querypb.GetSegmentInfoResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*querypb.GetSegmentInfoResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.GetSegmentInfoRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_GetSegmentInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSegmentInfo' +type MockQueryCoordClient_GetSegmentInfo_Call struct { + *mock.Call +} + +// GetSegmentInfo is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.GetSegmentInfoRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) GetSegmentInfo(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_GetSegmentInfo_Call { + return &MockQueryCoordClient_GetSegmentInfo_Call{Call: _e.mock.On("GetSegmentInfo", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_GetSegmentInfo_Call) Run(run func(ctx context.Context, in *querypb.GetSegmentInfoRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_GetSegmentInfo_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.GetSegmentInfoRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryCoordClient_GetSegmentInfo_Call) Return(_a0 *querypb.GetSegmentInfoResponse, _a1 error) *MockQueryCoordClient_GetSegmentInfo_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_GetSegmentInfo_Call) RunAndReturn(run func(context.Context, *querypb.GetSegmentInfoRequest, ...grpc.CallOption) (*querypb.GetSegmentInfoResponse, error)) *MockQueryCoordClient_GetSegmentInfo_Call { + _c.Call.Return(run) + return _c +} + +// GetShardLeaders provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) GetShardLeaders(ctx context.Context, in *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *querypb.GetShardLeadersResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetShardLeadersRequest, ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetShardLeadersRequest, ...grpc.CallOption) *querypb.GetShardLeadersResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*querypb.GetShardLeadersResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.GetShardLeadersRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_GetShardLeaders_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetShardLeaders' +type MockQueryCoordClient_GetShardLeaders_Call struct { + *mock.Call +} + +// GetShardLeaders is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.GetShardLeadersRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) GetShardLeaders(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_GetShardLeaders_Call { + return &MockQueryCoordClient_GetShardLeaders_Call{Call: _e.mock.On("GetShardLeaders", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_GetShardLeaders_Call) Run(run func(ctx context.Context, in *querypb.GetShardLeadersRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_GetShardLeaders_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.GetShardLeadersRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryCoordClient_GetShardLeaders_Call) Return(_a0 *querypb.GetShardLeadersResponse, _a1 error) *MockQueryCoordClient_GetShardLeaders_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_GetShardLeaders_Call) RunAndReturn(run func(context.Context, *querypb.GetShardLeadersRequest, ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error)) *MockQueryCoordClient_GetShardLeaders_Call { + _c.Call.Return(run) + return _c +} + +// GetStatisticsChannel provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.StringResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetStatisticsChannelRequest, ...grpc.CallOption) (*milvuspb.StringResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetStatisticsChannelRequest, ...grpc.CallOption) *milvuspb.StringResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.StringResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.GetStatisticsChannelRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_GetStatisticsChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetStatisticsChannel' +type MockQueryCoordClient_GetStatisticsChannel_Call struct { + *mock.Call +} + +// GetStatisticsChannel is a helper method to define mock.On call +// - ctx context.Context +// - in *internalpb.GetStatisticsChannelRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) GetStatisticsChannel(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_GetStatisticsChannel_Call { + return &MockQueryCoordClient_GetStatisticsChannel_Call{Call: _e.mock.On("GetStatisticsChannel", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_GetStatisticsChannel_Call) Run(run func(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_GetStatisticsChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*internalpb.GetStatisticsChannelRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryCoordClient_GetStatisticsChannel_Call) Return(_a0 *milvuspb.StringResponse, _a1 error) *MockQueryCoordClient_GetStatisticsChannel_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_GetStatisticsChannel_Call) RunAndReturn(run func(context.Context, *internalpb.GetStatisticsChannelRequest, ...grpc.CallOption) (*milvuspb.StringResponse, error)) *MockQueryCoordClient_GetStatisticsChannel_Call { + _c.Call.Return(run) + return _c +} + +// GetTimeTickChannel provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) GetTimeTickChannel(ctx context.Context, in *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.StringResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetTimeTickChannelRequest, ...grpc.CallOption) (*milvuspb.StringResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetTimeTickChannelRequest, ...grpc.CallOption) *milvuspb.StringResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.StringResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.GetTimeTickChannelRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_GetTimeTickChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetTimeTickChannel' +type MockQueryCoordClient_GetTimeTickChannel_Call struct { + *mock.Call +} + +// GetTimeTickChannel is a helper method to define mock.On call +// - ctx context.Context +// - in *internalpb.GetTimeTickChannelRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) GetTimeTickChannel(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_GetTimeTickChannel_Call { + return &MockQueryCoordClient_GetTimeTickChannel_Call{Call: _e.mock.On("GetTimeTickChannel", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_GetTimeTickChannel_Call) Run(run func(ctx context.Context, in *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_GetTimeTickChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*internalpb.GetTimeTickChannelRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryCoordClient_GetTimeTickChannel_Call) Return(_a0 *milvuspb.StringResponse, _a1 error) *MockQueryCoordClient_GetTimeTickChannel_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_GetTimeTickChannel_Call) RunAndReturn(run func(context.Context, *internalpb.GetTimeTickChannelRequest, ...grpc.CallOption) (*milvuspb.StringResponse, error)) *MockQueryCoordClient_GetTimeTickChannel_Call { + _c.Call.Return(run) + return _c +} + +// ListResourceGroups provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) ListResourceGroups(ctx context.Context, in *milvuspb.ListResourceGroupsRequest, opts ...grpc.CallOption) (*milvuspb.ListResourceGroupsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.ListResourceGroupsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ListResourceGroupsRequest, ...grpc.CallOption) (*milvuspb.ListResourceGroupsResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ListResourceGroupsRequest, ...grpc.CallOption) *milvuspb.ListResourceGroupsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.ListResourceGroupsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ListResourceGroupsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_ListResourceGroups_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListResourceGroups' +type MockQueryCoordClient_ListResourceGroups_Call struct { + *mock.Call +} + +// ListResourceGroups is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.ListResourceGroupsRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) ListResourceGroups(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_ListResourceGroups_Call { + return &MockQueryCoordClient_ListResourceGroups_Call{Call: _e.mock.On("ListResourceGroups", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_ListResourceGroups_Call) Run(run func(ctx context.Context, in *milvuspb.ListResourceGroupsRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_ListResourceGroups_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.ListResourceGroupsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryCoordClient_ListResourceGroups_Call) Return(_a0 *milvuspb.ListResourceGroupsResponse, _a1 error) *MockQueryCoordClient_ListResourceGroups_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_ListResourceGroups_Call) RunAndReturn(run func(context.Context, *milvuspb.ListResourceGroupsRequest, ...grpc.CallOption) (*milvuspb.ListResourceGroupsResponse, error)) *MockQueryCoordClient_ListResourceGroups_Call { + _c.Call.Return(run) + return _c +} + +// LoadBalance provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) LoadBalance(ctx context.Context, in *querypb.LoadBalanceRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.LoadBalanceRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.LoadBalanceRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.LoadBalanceRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_LoadBalance_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LoadBalance' +type MockQueryCoordClient_LoadBalance_Call struct { + *mock.Call +} + +// LoadBalance is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.LoadBalanceRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) LoadBalance(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_LoadBalance_Call { + return &MockQueryCoordClient_LoadBalance_Call{Call: _e.mock.On("LoadBalance", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_LoadBalance_Call) Run(run func(ctx context.Context, in *querypb.LoadBalanceRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_LoadBalance_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.LoadBalanceRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryCoordClient_LoadBalance_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoordClient_LoadBalance_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_LoadBalance_Call) RunAndReturn(run func(context.Context, *querypb.LoadBalanceRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryCoordClient_LoadBalance_Call { + _c.Call.Return(run) + return _c +} + +// LoadCollection provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) LoadCollection(ctx context.Context, in *querypb.LoadCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.LoadCollectionRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.LoadCollectionRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.LoadCollectionRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_LoadCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LoadCollection' +type MockQueryCoordClient_LoadCollection_Call struct { + *mock.Call +} + +// LoadCollection is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.LoadCollectionRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) LoadCollection(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_LoadCollection_Call { + return &MockQueryCoordClient_LoadCollection_Call{Call: _e.mock.On("LoadCollection", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_LoadCollection_Call) Run(run func(ctx context.Context, in *querypb.LoadCollectionRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_LoadCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.LoadCollectionRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryCoordClient_LoadCollection_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoordClient_LoadCollection_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_LoadCollection_Call) RunAndReturn(run func(context.Context, *querypb.LoadCollectionRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryCoordClient_LoadCollection_Call { + _c.Call.Return(run) + return _c +} + +// LoadPartitions provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) LoadPartitions(ctx context.Context, in *querypb.LoadPartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.LoadPartitionsRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.LoadPartitionsRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.LoadPartitionsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_LoadPartitions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LoadPartitions' +type MockQueryCoordClient_LoadPartitions_Call struct { + *mock.Call +} + +// LoadPartitions is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.LoadPartitionsRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) LoadPartitions(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_LoadPartitions_Call { + return &MockQueryCoordClient_LoadPartitions_Call{Call: _e.mock.On("LoadPartitions", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_LoadPartitions_Call) Run(run func(ctx context.Context, in *querypb.LoadPartitionsRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_LoadPartitions_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.LoadPartitionsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryCoordClient_LoadPartitions_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoordClient_LoadPartitions_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_LoadPartitions_Call) RunAndReturn(run func(context.Context, *querypb.LoadPartitionsRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryCoordClient_LoadPartitions_Call { + _c.Call.Return(run) + return _c +} + +// ReleaseCollection provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) ReleaseCollection(ctx context.Context, in *querypb.ReleaseCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleaseCollectionRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleaseCollectionRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.ReleaseCollectionRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_ReleaseCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReleaseCollection' +type MockQueryCoordClient_ReleaseCollection_Call struct { + *mock.Call +} + +// ReleaseCollection is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.ReleaseCollectionRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) ReleaseCollection(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_ReleaseCollection_Call { + return &MockQueryCoordClient_ReleaseCollection_Call{Call: _e.mock.On("ReleaseCollection", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_ReleaseCollection_Call) Run(run func(ctx context.Context, in *querypb.ReleaseCollectionRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_ReleaseCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.ReleaseCollectionRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryCoordClient_ReleaseCollection_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoordClient_ReleaseCollection_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_ReleaseCollection_Call) RunAndReturn(run func(context.Context, *querypb.ReleaseCollectionRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryCoordClient_ReleaseCollection_Call { + _c.Call.Return(run) + return _c +} + +// ReleasePartitions provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) ReleasePartitions(ctx context.Context, in *querypb.ReleasePartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleasePartitionsRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleasePartitionsRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.ReleasePartitionsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_ReleasePartitions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReleasePartitions' +type MockQueryCoordClient_ReleasePartitions_Call struct { + *mock.Call +} + +// ReleasePartitions is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.ReleasePartitionsRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) ReleasePartitions(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_ReleasePartitions_Call { + return &MockQueryCoordClient_ReleasePartitions_Call{Call: _e.mock.On("ReleasePartitions", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_ReleasePartitions_Call) Run(run func(ctx context.Context, in *querypb.ReleasePartitionsRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_ReleasePartitions_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.ReleasePartitionsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryCoordClient_ReleasePartitions_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoordClient_ReleasePartitions_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_ReleasePartitions_Call) RunAndReturn(run func(context.Context, *querypb.ReleasePartitionsRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryCoordClient_ReleasePartitions_Call { + _c.Call.Return(run) + return _c +} + +// ShowCollections provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) ShowCollections(ctx context.Context, in *querypb.ShowCollectionsRequest, opts ...grpc.CallOption) (*querypb.ShowCollectionsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *querypb.ShowCollectionsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.ShowCollectionsRequest, ...grpc.CallOption) (*querypb.ShowCollectionsResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.ShowCollectionsRequest, ...grpc.CallOption) *querypb.ShowCollectionsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*querypb.ShowCollectionsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.ShowCollectionsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_ShowCollections_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ShowCollections' +type MockQueryCoordClient_ShowCollections_Call struct { + *mock.Call +} + +// ShowCollections is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.ShowCollectionsRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) ShowCollections(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_ShowCollections_Call { + return &MockQueryCoordClient_ShowCollections_Call{Call: _e.mock.On("ShowCollections", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_ShowCollections_Call) Run(run func(ctx context.Context, in *querypb.ShowCollectionsRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_ShowCollections_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.ShowCollectionsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryCoordClient_ShowCollections_Call) Return(_a0 *querypb.ShowCollectionsResponse, _a1 error) *MockQueryCoordClient_ShowCollections_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_ShowCollections_Call) RunAndReturn(run func(context.Context, *querypb.ShowCollectionsRequest, ...grpc.CallOption) (*querypb.ShowCollectionsResponse, error)) *MockQueryCoordClient_ShowCollections_Call { + _c.Call.Return(run) + return _c +} + +// ShowConfigurations provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) ShowConfigurations(ctx context.Context, in *internalpb.ShowConfigurationsRequest, opts ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *internalpb.ShowConfigurationsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ShowConfigurationsRequest, ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ShowConfigurationsRequest, ...grpc.CallOption) *internalpb.ShowConfigurationsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*internalpb.ShowConfigurationsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.ShowConfigurationsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_ShowConfigurations_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ShowConfigurations' +type MockQueryCoordClient_ShowConfigurations_Call struct { + *mock.Call +} + +// ShowConfigurations is a helper method to define mock.On call +// - ctx context.Context +// - in *internalpb.ShowConfigurationsRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) ShowConfigurations(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_ShowConfigurations_Call { + return &MockQueryCoordClient_ShowConfigurations_Call{Call: _e.mock.On("ShowConfigurations", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_ShowConfigurations_Call) Run(run func(ctx context.Context, in *internalpb.ShowConfigurationsRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_ShowConfigurations_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*internalpb.ShowConfigurationsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryCoordClient_ShowConfigurations_Call) Return(_a0 *internalpb.ShowConfigurationsResponse, _a1 error) *MockQueryCoordClient_ShowConfigurations_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_ShowConfigurations_Call) RunAndReturn(run func(context.Context, *internalpb.ShowConfigurationsRequest, ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error)) *MockQueryCoordClient_ShowConfigurations_Call { + _c.Call.Return(run) + return _c +} + +// ShowPartitions provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) ShowPartitions(ctx context.Context, in *querypb.ShowPartitionsRequest, opts ...grpc.CallOption) (*querypb.ShowPartitionsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *querypb.ShowPartitionsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.ShowPartitionsRequest, ...grpc.CallOption) (*querypb.ShowPartitionsResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.ShowPartitionsRequest, ...grpc.CallOption) *querypb.ShowPartitionsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*querypb.ShowPartitionsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.ShowPartitionsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_ShowPartitions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ShowPartitions' +type MockQueryCoordClient_ShowPartitions_Call struct { + *mock.Call +} + +// ShowPartitions is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.ShowPartitionsRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) ShowPartitions(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_ShowPartitions_Call { + return &MockQueryCoordClient_ShowPartitions_Call{Call: _e.mock.On("ShowPartitions", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_ShowPartitions_Call) Run(run func(ctx context.Context, in *querypb.ShowPartitionsRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_ShowPartitions_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.ShowPartitionsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryCoordClient_ShowPartitions_Call) Return(_a0 *querypb.ShowPartitionsResponse, _a1 error) *MockQueryCoordClient_ShowPartitions_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_ShowPartitions_Call) RunAndReturn(run func(context.Context, *querypb.ShowPartitionsRequest, ...grpc.CallOption) (*querypb.ShowPartitionsResponse, error)) *MockQueryCoordClient_ShowPartitions_Call { + _c.Call.Return(run) + return _c +} + +// SyncNewCreatedPartition provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) SyncNewCreatedPartition(ctx context.Context, in *querypb.SyncNewCreatedPartitionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.SyncNewCreatedPartitionRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.SyncNewCreatedPartitionRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.SyncNewCreatedPartitionRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_SyncNewCreatedPartition_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SyncNewCreatedPartition' +type MockQueryCoordClient_SyncNewCreatedPartition_Call struct { + *mock.Call +} + +// SyncNewCreatedPartition is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.SyncNewCreatedPartitionRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) SyncNewCreatedPartition(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_SyncNewCreatedPartition_Call { + return &MockQueryCoordClient_SyncNewCreatedPartition_Call{Call: _e.mock.On("SyncNewCreatedPartition", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_SyncNewCreatedPartition_Call) Run(run func(ctx context.Context, in *querypb.SyncNewCreatedPartitionRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_SyncNewCreatedPartition_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.SyncNewCreatedPartitionRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryCoordClient_SyncNewCreatedPartition_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoordClient_SyncNewCreatedPartition_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_SyncNewCreatedPartition_Call) RunAndReturn(run func(context.Context, *querypb.SyncNewCreatedPartitionRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryCoordClient_SyncNewCreatedPartition_Call { + _c.Call.Return(run) + return _c +} + +// TransferNode provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) TransferNode(ctx context.Context, in *milvuspb.TransferNodeRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.TransferNodeRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.TransferNodeRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.TransferNodeRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_TransferNode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TransferNode' +type MockQueryCoordClient_TransferNode_Call struct { + *mock.Call +} + +// TransferNode is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.TransferNodeRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) TransferNode(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_TransferNode_Call { + return &MockQueryCoordClient_TransferNode_Call{Call: _e.mock.On("TransferNode", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_TransferNode_Call) Run(run func(ctx context.Context, in *milvuspb.TransferNodeRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_TransferNode_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.TransferNodeRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryCoordClient_TransferNode_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoordClient_TransferNode_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_TransferNode_Call) RunAndReturn(run func(context.Context, *milvuspb.TransferNodeRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryCoordClient_TransferNode_Call { + _c.Call.Return(run) + return _c +} + +// TransferReplica provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) TransferReplica(ctx context.Context, in *querypb.TransferReplicaRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.TransferReplicaRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.TransferReplicaRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.TransferReplicaRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_TransferReplica_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TransferReplica' +type MockQueryCoordClient_TransferReplica_Call struct { + *mock.Call +} + +// TransferReplica is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.TransferReplicaRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) TransferReplica(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_TransferReplica_Call { + return &MockQueryCoordClient_TransferReplica_Call{Call: _e.mock.On("TransferReplica", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_TransferReplica_Call) Run(run func(ctx context.Context, in *querypb.TransferReplicaRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_TransferReplica_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.TransferReplicaRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockQueryCoordClient_TransferReplica_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoordClient_TransferReplica_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_TransferReplica_Call) RunAndReturn(run func(context.Context, *querypb.TransferReplicaRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryCoordClient_TransferReplica_Call { + _c.Call.Return(run) + return _c +} + +// NewMockQueryCoordClient creates a new instance of MockQueryCoordClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. 
+func NewMockQueryCoordClient(t interface { + mock.TestingT + Cleanup(func()) +}) *MockQueryCoordClient { + mock := &MockQueryCoordClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/mocks/mock_querynode.go b/internal/mocks/mock_querynode.go index 0314a728c4af2..9723cf21f152f 100644 --- a/internal/mocks/mock_querynode.go +++ b/internal/mocks/mock_querynode.go @@ -126,25 +126,25 @@ func (_c *MockQueryNode_GetAddress_Call) RunAndReturn(run func() string) *MockQu return _c } -// GetComponentStates provides a mock function with given fields: ctx -func (_m *MockQueryNode) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { - ret := _m.Called(ctx) +// GetComponentStates provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryNode) GetComponentStates(_a0 context.Context, _a1 *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.ComponentStates var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (*milvuspb.ComponentStates, error)); ok { - return rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error)); ok { + return rf(_a0, _a1) } - if rf, ok := ret.Get(0).(func(context.Context) *milvuspb.ComponentStates); ok { - r0 = rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetComponentStatesRequest) *milvuspb.ComponentStates); ok { + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.ComponentStates) } } - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetComponentStatesRequest) error); ok { + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -158,14 +158,15 @@ type MockQueryNode_GetComponentStates_Call struct { } // GetComponentStates is a helper method to define mock.On call -// - ctx 
context.Context -func (_e *MockQueryNode_Expecter) GetComponentStates(ctx interface{}) *MockQueryNode_GetComponentStates_Call { - return &MockQueryNode_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", ctx)} +// - _a0 context.Context +// - _a1 *milvuspb.GetComponentStatesRequest +func (_e *MockQueryNode_Expecter) GetComponentStates(_a0 interface{}, _a1 interface{}) *MockQueryNode_GetComponentStates_Call { + return &MockQueryNode_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", _a0, _a1)} } -func (_c *MockQueryNode_GetComponentStates_Call) Run(run func(ctx context.Context)) *MockQueryNode_GetComponentStates_Call { +func (_c *MockQueryNode_GetComponentStates_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetComponentStatesRequest)) *MockQueryNode_GetComponentStates_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context)) + run(args[0].(context.Context), args[1].(*milvuspb.GetComponentStatesRequest)) }) return _c } @@ -175,7 +176,7 @@ func (_c *MockQueryNode_GetComponentStates_Call) Return(_a0 *milvuspb.ComponentS return _c } -func (_c *MockQueryNode_GetComponentStates_Call) RunAndReturn(run func(context.Context) (*milvuspb.ComponentStates, error)) *MockQueryNode_GetComponentStates_Call { +func (_c *MockQueryNode_GetComponentStates_Call) RunAndReturn(run func(context.Context, *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error)) *MockQueryNode_GetComponentStates_Call { _c.Call.Return(run) return _c } @@ -235,17 +236,17 @@ func (_c *MockQueryNode_GetDataDistribution_Call) RunAndReturn(run func(context. 
return _c } -// GetMetrics provides a mock function with given fields: ctx, req -func (_m *MockQueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - ret := _m.Called(ctx, req) +// GetMetrics provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryNode) GetMetrics(_a0 context.Context, _a1 *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.GetMetricsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest) *milvuspb.GetMetricsResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.GetMetricsResponse) @@ -253,7 +254,7 @@ func (_m *MockQueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetric } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetMetricsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -267,13 +268,13 @@ type MockQueryNode_GetMetrics_Call struct { } // GetMetrics is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.GetMetricsRequest -func (_e *MockQueryNode_Expecter) GetMetrics(ctx interface{}, req interface{}) *MockQueryNode_GetMetrics_Call { - return &MockQueryNode_GetMetrics_Call{Call: _e.mock.On("GetMetrics", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.GetMetricsRequest +func (_e *MockQueryNode_Expecter) GetMetrics(_a0 interface{}, _a1 interface{}) *MockQueryNode_GetMetrics_Call { + return &MockQueryNode_GetMetrics_Call{Call: _e.mock.On("GetMetrics", _a0, _a1)} } -func (_c *MockQueryNode_GetMetrics_Call) Run(run func(ctx context.Context, req *milvuspb.GetMetricsRequest)) *MockQueryNode_GetMetrics_Call { +func (_c 
*MockQueryNode_GetMetrics_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetMetricsRequest)) *MockQueryNode_GetMetrics_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.GetMetricsRequest)) }) @@ -290,17 +291,17 @@ func (_c *MockQueryNode_GetMetrics_Call) RunAndReturn(run func(context.Context, return _c } -// GetSegmentInfo provides a mock function with given fields: ctx, req -func (_m *MockQueryNode) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) { - ret := _m.Called(ctx, req) +// GetSegmentInfo provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryNode) GetSegmentInfo(_a0 context.Context, _a1 *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *querypb.GetSegmentInfoResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetSegmentInfoRequest) *querypb.GetSegmentInfoResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*querypb.GetSegmentInfoResponse) @@ -308,7 +309,7 @@ func (_m *MockQueryNode) GetSegmentInfo(ctx context.Context, req *querypb.GetSeg } if rf, ok := ret.Get(1).(func(context.Context, *querypb.GetSegmentInfoRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -322,13 +323,13 @@ type MockQueryNode_GetSegmentInfo_Call struct { } // GetSegmentInfo is a helper method to define mock.On call -// - ctx context.Context -// - req *querypb.GetSegmentInfoRequest -func (_e *MockQueryNode_Expecter) GetSegmentInfo(ctx interface{}, req interface{}) *MockQueryNode_GetSegmentInfo_Call { - return &MockQueryNode_GetSegmentInfo_Call{Call: _e.mock.On("GetSegmentInfo", ctx, req)} +// - _a0 
context.Context +// - _a1 *querypb.GetSegmentInfoRequest +func (_e *MockQueryNode_Expecter) GetSegmentInfo(_a0 interface{}, _a1 interface{}) *MockQueryNode_GetSegmentInfo_Call { + return &MockQueryNode_GetSegmentInfo_Call{Call: _e.mock.On("GetSegmentInfo", _a0, _a1)} } -func (_c *MockQueryNode_GetSegmentInfo_Call) Run(run func(ctx context.Context, req *querypb.GetSegmentInfoRequest)) *MockQueryNode_GetSegmentInfo_Call { +func (_c *MockQueryNode_GetSegmentInfo_Call) Run(run func(_a0 context.Context, _a1 *querypb.GetSegmentInfoRequest)) *MockQueryNode_GetSegmentInfo_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*querypb.GetSegmentInfoRequest)) }) @@ -345,17 +346,17 @@ func (_c *MockQueryNode_GetSegmentInfo_Call) RunAndReturn(run func(context.Conte return _c } -// GetStatistics provides a mock function with given fields: ctx, req -func (_m *MockQueryNode) GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) (*internalpb.GetStatisticsResponse, error) { - ret := _m.Called(ctx, req) +// GetStatistics provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryNode) GetStatistics(_a0 context.Context, _a1 *querypb.GetStatisticsRequest) (*internalpb.GetStatisticsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *internalpb.GetStatisticsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetStatisticsRequest) (*internalpb.GetStatisticsResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetStatisticsRequest) *internalpb.GetStatisticsResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*internalpb.GetStatisticsResponse) @@ -363,7 +364,7 @@ func (_m *MockQueryNode) GetStatistics(ctx context.Context, req *querypb.GetStat } if rf, ok := ret.Get(1).(func(context.Context, *querypb.GetStatisticsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } 
else { r1 = ret.Error(1) } @@ -377,13 +378,13 @@ type MockQueryNode_GetStatistics_Call struct { } // GetStatistics is a helper method to define mock.On call -// - ctx context.Context -// - req *querypb.GetStatisticsRequest -func (_e *MockQueryNode_Expecter) GetStatistics(ctx interface{}, req interface{}) *MockQueryNode_GetStatistics_Call { - return &MockQueryNode_GetStatistics_Call{Call: _e.mock.On("GetStatistics", ctx, req)} +// - _a0 context.Context +// - _a1 *querypb.GetStatisticsRequest +func (_e *MockQueryNode_Expecter) GetStatistics(_a0 interface{}, _a1 interface{}) *MockQueryNode_GetStatistics_Call { + return &MockQueryNode_GetStatistics_Call{Call: _e.mock.On("GetStatistics", _a0, _a1)} } -func (_c *MockQueryNode_GetStatistics_Call) Run(run func(ctx context.Context, req *querypb.GetStatisticsRequest)) *MockQueryNode_GetStatistics_Call { +func (_c *MockQueryNode_GetStatistics_Call) Run(run func(_a0 context.Context, _a1 *querypb.GetStatisticsRequest)) *MockQueryNode_GetStatistics_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*querypb.GetStatisticsRequest)) }) @@ -400,25 +401,25 @@ func (_c *MockQueryNode_GetStatistics_Call) RunAndReturn(run func(context.Contex return _c } -// GetStatisticsChannel provides a mock function with given fields: ctx -func (_m *MockQueryNode) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret := _m.Called(ctx) +// GetStatisticsChannel provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryNode) GetStatisticsChannel(_a0 context.Context, _a1 *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.StringResponse var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (*milvuspb.StringResponse, error)); ok { - return rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error)); ok { + return rf(_a0, _a1) 
} - if rf, ok := ret.Get(0).(func(context.Context) *milvuspb.StringResponse); ok { - r0 = rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetStatisticsChannelRequest) *milvuspb.StringResponse); ok { + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.StringResponse) } } - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.GetStatisticsChannelRequest) error); ok { + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -432,14 +433,15 @@ type MockQueryNode_GetStatisticsChannel_Call struct { } // GetStatisticsChannel is a helper method to define mock.On call -// - ctx context.Context -func (_e *MockQueryNode_Expecter) GetStatisticsChannel(ctx interface{}) *MockQueryNode_GetStatisticsChannel_Call { - return &MockQueryNode_GetStatisticsChannel_Call{Call: _e.mock.On("GetStatisticsChannel", ctx)} +// - _a0 context.Context +// - _a1 *internalpb.GetStatisticsChannelRequest +func (_e *MockQueryNode_Expecter) GetStatisticsChannel(_a0 interface{}, _a1 interface{}) *MockQueryNode_GetStatisticsChannel_Call { + return &MockQueryNode_GetStatisticsChannel_Call{Call: _e.mock.On("GetStatisticsChannel", _a0, _a1)} } -func (_c *MockQueryNode_GetStatisticsChannel_Call) Run(run func(ctx context.Context)) *MockQueryNode_GetStatisticsChannel_Call { +func (_c *MockQueryNode_GetStatisticsChannel_Call) Run(run func(_a0 context.Context, _a1 *internalpb.GetStatisticsChannelRequest)) *MockQueryNode_GetStatisticsChannel_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context)) + run(args[0].(context.Context), args[1].(*internalpb.GetStatisticsChannelRequest)) }) return _c } @@ -449,30 +451,30 @@ func (_c *MockQueryNode_GetStatisticsChannel_Call) Return(_a0 *milvuspb.StringRe return _c } -func (_c *MockQueryNode_GetStatisticsChannel_Call) RunAndReturn(run func(context.Context) (*milvuspb.StringResponse, error)) 
*MockQueryNode_GetStatisticsChannel_Call { +func (_c *MockQueryNode_GetStatisticsChannel_Call) RunAndReturn(run func(context.Context, *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error)) *MockQueryNode_GetStatisticsChannel_Call { _c.Call.Return(run) return _c } -// GetTimeTickChannel provides a mock function with given fields: ctx -func (_m *MockQueryNode) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret := _m.Called(ctx) +// GetTimeTickChannel provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryNode) GetTimeTickChannel(_a0 context.Context, _a1 *internalpb.GetTimeTickChannelRequest) (*milvuspb.StringResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.StringResponse var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (*milvuspb.StringResponse, error)); ok { - return rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetTimeTickChannelRequest) (*milvuspb.StringResponse, error)); ok { + return rf(_a0, _a1) } - if rf, ok := ret.Get(0).(func(context.Context) *milvuspb.StringResponse); ok { - r0 = rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetTimeTickChannelRequest) *milvuspb.StringResponse); ok { + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.StringResponse) } } - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.GetTimeTickChannelRequest) error); ok { + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -486,14 +488,15 @@ type MockQueryNode_GetTimeTickChannel_Call struct { } // GetTimeTickChannel is a helper method to define mock.On call -// - ctx context.Context -func (_e *MockQueryNode_Expecter) GetTimeTickChannel(ctx interface{}) *MockQueryNode_GetTimeTickChannel_Call { - return &MockQueryNode_GetTimeTickChannel_Call{Call: _e.mock.On("GetTimeTickChannel", ctx)} +// - _a0 context.Context +// - _a1 
*internalpb.GetTimeTickChannelRequest +func (_e *MockQueryNode_Expecter) GetTimeTickChannel(_a0 interface{}, _a1 interface{}) *MockQueryNode_GetTimeTickChannel_Call { + return &MockQueryNode_GetTimeTickChannel_Call{Call: _e.mock.On("GetTimeTickChannel", _a0, _a1)} } -func (_c *MockQueryNode_GetTimeTickChannel_Call) Run(run func(ctx context.Context)) *MockQueryNode_GetTimeTickChannel_Call { +func (_c *MockQueryNode_GetTimeTickChannel_Call) Run(run func(_a0 context.Context, _a1 *internalpb.GetTimeTickChannelRequest)) *MockQueryNode_GetTimeTickChannel_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context)) + run(args[0].(context.Context), args[1].(*internalpb.GetTimeTickChannelRequest)) }) return _c } @@ -503,7 +506,7 @@ func (_c *MockQueryNode_GetTimeTickChannel_Call) Return(_a0 *milvuspb.StringResp return _c } -func (_c *MockQueryNode_GetTimeTickChannel_Call) RunAndReturn(run func(context.Context) (*milvuspb.StringResponse, error)) *MockQueryNode_GetTimeTickChannel_Call { +func (_c *MockQueryNode_GetTimeTickChannel_Call) RunAndReturn(run func(context.Context, *internalpb.GetTimeTickChannelRequest) (*milvuspb.StringResponse, error)) *MockQueryNode_GetTimeTickChannel_Call { _c.Call.Return(run) return _c } @@ -549,17 +552,17 @@ func (_c *MockQueryNode_Init_Call) RunAndReturn(run func() error) *MockQueryNode return _c } -// LoadPartitions provides a mock function with given fields: ctx, req -func (_m *MockQueryNode) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// LoadPartitions provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryNode) LoadPartitions(_a0 context.Context, _a1 *querypb.LoadPartitionsRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *querypb.LoadPartitionsRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return 
rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *querypb.LoadPartitionsRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -567,7 +570,7 @@ func (_m *MockQueryNode) LoadPartitions(ctx context.Context, req *querypb.LoadPa } if rf, ok := ret.Get(1).(func(context.Context, *querypb.LoadPartitionsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -581,13 +584,13 @@ type MockQueryNode_LoadPartitions_Call struct { } // LoadPartitions is a helper method to define mock.On call -// - ctx context.Context -// - req *querypb.LoadPartitionsRequest -func (_e *MockQueryNode_Expecter) LoadPartitions(ctx interface{}, req interface{}) *MockQueryNode_LoadPartitions_Call { - return &MockQueryNode_LoadPartitions_Call{Call: _e.mock.On("LoadPartitions", ctx, req)} +// - _a0 context.Context +// - _a1 *querypb.LoadPartitionsRequest +func (_e *MockQueryNode_Expecter) LoadPartitions(_a0 interface{}, _a1 interface{}) *MockQueryNode_LoadPartitions_Call { + return &MockQueryNode_LoadPartitions_Call{Call: _e.mock.On("LoadPartitions", _a0, _a1)} } -func (_c *MockQueryNode_LoadPartitions_Call) Run(run func(ctx context.Context, req *querypb.LoadPartitionsRequest)) *MockQueryNode_LoadPartitions_Call { +func (_c *MockQueryNode_LoadPartitions_Call) Run(run func(_a0 context.Context, _a1 *querypb.LoadPartitionsRequest)) *MockQueryNode_LoadPartitions_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*querypb.LoadPartitionsRequest)) }) @@ -604,17 +607,17 @@ func (_c *MockQueryNode_LoadPartitions_Call) RunAndReturn(run func(context.Conte return _c } -// LoadSegments provides a mock function with given fields: ctx, req -func (_m *MockQueryNode) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// LoadSegments provides a mock function with given fields: 
_a0, _a1 +func (_m *MockQueryNode) LoadSegments(_a0 context.Context, _a1 *querypb.LoadSegmentsRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *querypb.LoadSegmentsRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *querypb.LoadSegmentsRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -622,7 +625,7 @@ func (_m *MockQueryNode) LoadSegments(ctx context.Context, req *querypb.LoadSegm } if rf, ok := ret.Get(1).(func(context.Context, *querypb.LoadSegmentsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -636,13 +639,13 @@ type MockQueryNode_LoadSegments_Call struct { } // LoadSegments is a helper method to define mock.On call -// - ctx context.Context -// - req *querypb.LoadSegmentsRequest -func (_e *MockQueryNode_Expecter) LoadSegments(ctx interface{}, req interface{}) *MockQueryNode_LoadSegments_Call { - return &MockQueryNode_LoadSegments_Call{Call: _e.mock.On("LoadSegments", ctx, req)} +// - _a0 context.Context +// - _a1 *querypb.LoadSegmentsRequest +func (_e *MockQueryNode_Expecter) LoadSegments(_a0 interface{}, _a1 interface{}) *MockQueryNode_LoadSegments_Call { + return &MockQueryNode_LoadSegments_Call{Call: _e.mock.On("LoadSegments", _a0, _a1)} } -func (_c *MockQueryNode_LoadSegments_Call) Run(run func(ctx context.Context, req *querypb.LoadSegmentsRequest)) *MockQueryNode_LoadSegments_Call { +func (_c *MockQueryNode_LoadSegments_Call) Run(run func(_a0 context.Context, _a1 *querypb.LoadSegmentsRequest)) *MockQueryNode_LoadSegments_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*querypb.LoadSegmentsRequest)) }) @@ -659,17 +662,17 @@ func (_c *MockQueryNode_LoadSegments_Call) RunAndReturn(run 
func(context.Context return _c } -// Query provides a mock function with given fields: ctx, req -func (_m *MockQueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) { - ret := _m.Called(ctx, req) +// Query provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryNode) Query(_a0 context.Context, _a1 *querypb.QueryRequest) (*internalpb.RetrieveResults, error) { + ret := _m.Called(_a0, _a1) var r0 *internalpb.RetrieveResults var r1 error if rf, ok := ret.Get(0).(func(context.Context, *querypb.QueryRequest) (*internalpb.RetrieveResults, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *querypb.QueryRequest) *internalpb.RetrieveResults); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*internalpb.RetrieveResults) @@ -677,7 +680,7 @@ func (_m *MockQueryNode) Query(ctx context.Context, req *querypb.QueryRequest) ( } if rf, ok := ret.Get(1).(func(context.Context, *querypb.QueryRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -691,13 +694,13 @@ type MockQueryNode_Query_Call struct { } // Query is a helper method to define mock.On call -// - ctx context.Context -// - req *querypb.QueryRequest -func (_e *MockQueryNode_Expecter) Query(ctx interface{}, req interface{}) *MockQueryNode_Query_Call { - return &MockQueryNode_Query_Call{Call: _e.mock.On("Query", ctx, req)} +// - _a0 context.Context +// - _a1 *querypb.QueryRequest +func (_e *MockQueryNode_Expecter) Query(_a0 interface{}, _a1 interface{}) *MockQueryNode_Query_Call { + return &MockQueryNode_Query_Call{Call: _e.mock.On("Query", _a0, _a1)} } -func (_c *MockQueryNode_Query_Call) Run(run func(ctx context.Context, req *querypb.QueryRequest)) *MockQueryNode_Query_Call { +func (_c *MockQueryNode_Query_Call) Run(run func(_a0 context.Context, _a1 *querypb.QueryRequest)) *MockQueryNode_Query_Call { _c.Call.Run(func(args 
mock.Arguments) { run(args[0].(context.Context), args[1].(*querypb.QueryRequest)) }) @@ -714,17 +717,17 @@ func (_c *MockQueryNode_Query_Call) RunAndReturn(run func(context.Context, *quer return _c } -// QuerySegments provides a mock function with given fields: ctx, req -func (_m *MockQueryNode) QuerySegments(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) { - ret := _m.Called(ctx, req) +// QuerySegments provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryNode) QuerySegments(_a0 context.Context, _a1 *querypb.QueryRequest) (*internalpb.RetrieveResults, error) { + ret := _m.Called(_a0, _a1) var r0 *internalpb.RetrieveResults var r1 error if rf, ok := ret.Get(0).(func(context.Context, *querypb.QueryRequest) (*internalpb.RetrieveResults, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *querypb.QueryRequest) *internalpb.RetrieveResults); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*internalpb.RetrieveResults) @@ -732,7 +735,7 @@ func (_m *MockQueryNode) QuerySegments(ctx context.Context, req *querypb.QueryRe } if rf, ok := ret.Get(1).(func(context.Context, *querypb.QueryRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -746,13 +749,13 @@ type MockQueryNode_QuerySegments_Call struct { } // QuerySegments is a helper method to define mock.On call -// - ctx context.Context -// - req *querypb.QueryRequest -func (_e *MockQueryNode_Expecter) QuerySegments(ctx interface{}, req interface{}) *MockQueryNode_QuerySegments_Call { - return &MockQueryNode_QuerySegments_Call{Call: _e.mock.On("QuerySegments", ctx, req)} +// - _a0 context.Context +// - _a1 *querypb.QueryRequest +func (_e *MockQueryNode_Expecter) QuerySegments(_a0 interface{}, _a1 interface{}) *MockQueryNode_QuerySegments_Call { + return &MockQueryNode_QuerySegments_Call{Call: _e.mock.On("QuerySegments", _a0, _a1)} 
} -func (_c *MockQueryNode_QuerySegments_Call) Run(run func(ctx context.Context, req *querypb.QueryRequest)) *MockQueryNode_QuerySegments_Call { +func (_c *MockQueryNode_QuerySegments_Call) Run(run func(_a0 context.Context, _a1 *querypb.QueryRequest)) *MockQueryNode_QuerySegments_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*querypb.QueryRequest)) }) @@ -769,6 +772,92 @@ func (_c *MockQueryNode_QuerySegments_Call) RunAndReturn(run func(context.Contex return _c } +// QueryStream provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryNode) QueryStream(_a0 *querypb.QueryRequest, _a1 querypb.QueryNode_QueryStreamServer) error { + ret := _m.Called(_a0, _a1) + + var r0 error + if rf, ok := ret.Get(0).(func(*querypb.QueryRequest, querypb.QueryNode_QueryStreamServer) error); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockQueryNode_QueryStream_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'QueryStream' +type MockQueryNode_QueryStream_Call struct { + *mock.Call +} + +// QueryStream is a helper method to define mock.On call +// - _a0 *querypb.QueryRequest +// - _a1 querypb.QueryNode_QueryStreamServer +func (_e *MockQueryNode_Expecter) QueryStream(_a0 interface{}, _a1 interface{}) *MockQueryNode_QueryStream_Call { + return &MockQueryNode_QueryStream_Call{Call: _e.mock.On("QueryStream", _a0, _a1)} +} + +func (_c *MockQueryNode_QueryStream_Call) Run(run func(_a0 *querypb.QueryRequest, _a1 querypb.QueryNode_QueryStreamServer)) *MockQueryNode_QueryStream_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*querypb.QueryRequest), args[1].(querypb.QueryNode_QueryStreamServer)) + }) + return _c +} + +func (_c *MockQueryNode_QueryStream_Call) Return(_a0 error) *MockQueryNode_QueryStream_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockQueryNode_QueryStream_Call) RunAndReturn(run func(*querypb.QueryRequest, 
querypb.QueryNode_QueryStreamServer) error) *MockQueryNode_QueryStream_Call { + _c.Call.Return(run) + return _c +} + +// QueryStreamSegments provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryNode) QueryStreamSegments(_a0 *querypb.QueryRequest, _a1 querypb.QueryNode_QueryStreamSegmentsServer) error { + ret := _m.Called(_a0, _a1) + + var r0 error + if rf, ok := ret.Get(0).(func(*querypb.QueryRequest, querypb.QueryNode_QueryStreamSegmentsServer) error); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockQueryNode_QueryStreamSegments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'QueryStreamSegments' +type MockQueryNode_QueryStreamSegments_Call struct { + *mock.Call +} + +// QueryStreamSegments is a helper method to define mock.On call +// - _a0 *querypb.QueryRequest +// - _a1 querypb.QueryNode_QueryStreamSegmentsServer +func (_e *MockQueryNode_Expecter) QueryStreamSegments(_a0 interface{}, _a1 interface{}) *MockQueryNode_QueryStreamSegments_Call { + return &MockQueryNode_QueryStreamSegments_Call{Call: _e.mock.On("QueryStreamSegments", _a0, _a1)} +} + +func (_c *MockQueryNode_QueryStreamSegments_Call) Run(run func(_a0 *querypb.QueryRequest, _a1 querypb.QueryNode_QueryStreamSegmentsServer)) *MockQueryNode_QueryStreamSegments_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*querypb.QueryRequest), args[1].(querypb.QueryNode_QueryStreamSegmentsServer)) + }) + return _c +} + +func (_c *MockQueryNode_QueryStreamSegments_Call) Return(_a0 error) *MockQueryNode_QueryStreamSegments_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockQueryNode_QueryStreamSegments_Call) RunAndReturn(run func(*querypb.QueryRequest, querypb.QueryNode_QueryStreamSegmentsServer) error) *MockQueryNode_QueryStreamSegments_Call { + _c.Call.Return(run) + return _c +} + // Register provides a mock function with given fields: func (_m *MockQueryNode) Register() error { ret := 
_m.Called() @@ -810,17 +899,17 @@ func (_c *MockQueryNode_Register_Call) RunAndReturn(run func() error) *MockQuery return _c } -// ReleaseCollection provides a mock function with given fields: ctx, req -func (_m *MockQueryNode) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// ReleaseCollection provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryNode) ReleaseCollection(_a0 context.Context, _a1 *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleaseCollectionRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleaseCollectionRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -828,7 +917,7 @@ func (_m *MockQueryNode) ReleaseCollection(ctx context.Context, req *querypb.Rel } if rf, ok := ret.Get(1).(func(context.Context, *querypb.ReleaseCollectionRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -842,13 +931,13 @@ type MockQueryNode_ReleaseCollection_Call struct { } // ReleaseCollection is a helper method to define mock.On call -// - ctx context.Context -// - req *querypb.ReleaseCollectionRequest -func (_e *MockQueryNode_Expecter) ReleaseCollection(ctx interface{}, req interface{}) *MockQueryNode_ReleaseCollection_Call { - return &MockQueryNode_ReleaseCollection_Call{Call: _e.mock.On("ReleaseCollection", ctx, req)} +// - _a0 context.Context +// - _a1 *querypb.ReleaseCollectionRequest +func (_e *MockQueryNode_Expecter) ReleaseCollection(_a0 interface{}, _a1 interface{}) *MockQueryNode_ReleaseCollection_Call { + return &MockQueryNode_ReleaseCollection_Call{Call: _e.mock.On("ReleaseCollection", _a0, 
_a1)} } -func (_c *MockQueryNode_ReleaseCollection_Call) Run(run func(ctx context.Context, req *querypb.ReleaseCollectionRequest)) *MockQueryNode_ReleaseCollection_Call { +func (_c *MockQueryNode_ReleaseCollection_Call) Run(run func(_a0 context.Context, _a1 *querypb.ReleaseCollectionRequest)) *MockQueryNode_ReleaseCollection_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*querypb.ReleaseCollectionRequest)) }) @@ -865,17 +954,17 @@ func (_c *MockQueryNode_ReleaseCollection_Call) RunAndReturn(run func(context.Co return _c } -// ReleasePartitions provides a mock function with given fields: ctx, req -func (_m *MockQueryNode) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// ReleasePartitions provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryNode) ReleasePartitions(_a0 context.Context, _a1 *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleasePartitionsRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleasePartitionsRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -883,7 +972,7 @@ func (_m *MockQueryNode) ReleasePartitions(ctx context.Context, req *querypb.Rel } if rf, ok := ret.Get(1).(func(context.Context, *querypb.ReleasePartitionsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -897,13 +986,13 @@ type MockQueryNode_ReleasePartitions_Call struct { } // ReleasePartitions is a helper method to define mock.On call -// - ctx context.Context -// - req *querypb.ReleasePartitionsRequest -func (_e *MockQueryNode_Expecter) ReleasePartitions(ctx interface{}, req 
interface{}) *MockQueryNode_ReleasePartitions_Call { - return &MockQueryNode_ReleasePartitions_Call{Call: _e.mock.On("ReleasePartitions", ctx, req)} +// - _a0 context.Context +// - _a1 *querypb.ReleasePartitionsRequest +func (_e *MockQueryNode_Expecter) ReleasePartitions(_a0 interface{}, _a1 interface{}) *MockQueryNode_ReleasePartitions_Call { + return &MockQueryNode_ReleasePartitions_Call{Call: _e.mock.On("ReleasePartitions", _a0, _a1)} } -func (_c *MockQueryNode_ReleasePartitions_Call) Run(run func(ctx context.Context, req *querypb.ReleasePartitionsRequest)) *MockQueryNode_ReleasePartitions_Call { +func (_c *MockQueryNode_ReleasePartitions_Call) Run(run func(_a0 context.Context, _a1 *querypb.ReleasePartitionsRequest)) *MockQueryNode_ReleasePartitions_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*querypb.ReleasePartitionsRequest)) }) @@ -920,17 +1009,17 @@ func (_c *MockQueryNode_ReleasePartitions_Call) RunAndReturn(run func(context.Co return _c } -// ReleaseSegments provides a mock function with given fields: ctx, req -func (_m *MockQueryNode) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// ReleaseSegments provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryNode) ReleaseSegments(_a0 context.Context, _a1 *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleaseSegmentsRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -938,7 +1027,7 @@ func (_m *MockQueryNode) ReleaseSegments(ctx context.Context, req *querypb.Relea } if rf, ok := 
ret.Get(1).(func(context.Context, *querypb.ReleaseSegmentsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -952,13 +1041,13 @@ type MockQueryNode_ReleaseSegments_Call struct { } // ReleaseSegments is a helper method to define mock.On call -// - ctx context.Context -// - req *querypb.ReleaseSegmentsRequest -func (_e *MockQueryNode_Expecter) ReleaseSegments(ctx interface{}, req interface{}) *MockQueryNode_ReleaseSegments_Call { - return &MockQueryNode_ReleaseSegments_Call{Call: _e.mock.On("ReleaseSegments", ctx, req)} +// - _a0 context.Context +// - _a1 *querypb.ReleaseSegmentsRequest +func (_e *MockQueryNode_Expecter) ReleaseSegments(_a0 interface{}, _a1 interface{}) *MockQueryNode_ReleaseSegments_Call { + return &MockQueryNode_ReleaseSegments_Call{Call: _e.mock.On("ReleaseSegments", _a0, _a1)} } -func (_c *MockQueryNode_ReleaseSegments_Call) Run(run func(ctx context.Context, req *querypb.ReleaseSegmentsRequest)) *MockQueryNode_ReleaseSegments_Call { +func (_c *MockQueryNode_ReleaseSegments_Call) Run(run func(_a0 context.Context, _a1 *querypb.ReleaseSegmentsRequest)) *MockQueryNode_ReleaseSegments_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*querypb.ReleaseSegmentsRequest)) }) @@ -975,17 +1064,17 @@ func (_c *MockQueryNode_ReleaseSegments_Call) RunAndReturn(run func(context.Cont return _c } -// Search provides a mock function with given fields: ctx, req -func (_m *MockQueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) { - ret := _m.Called(ctx, req) +// Search provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryNode) Search(_a0 context.Context, _a1 *querypb.SearchRequest) (*internalpb.SearchResults, error) { + ret := _m.Called(_a0, _a1) var r0 *internalpb.SearchResults var r1 error if rf, ok := ret.Get(0).(func(context.Context, *querypb.SearchRequest) (*internalpb.SearchResults, error)); ok { - return rf(ctx, 
req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *querypb.SearchRequest) *internalpb.SearchResults); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*internalpb.SearchResults) @@ -993,7 +1082,7 @@ func (_m *MockQueryNode) Search(ctx context.Context, req *querypb.SearchRequest) } if rf, ok := ret.Get(1).(func(context.Context, *querypb.SearchRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1007,13 +1096,13 @@ type MockQueryNode_Search_Call struct { } // Search is a helper method to define mock.On call -// - ctx context.Context -// - req *querypb.SearchRequest -func (_e *MockQueryNode_Expecter) Search(ctx interface{}, req interface{}) *MockQueryNode_Search_Call { - return &MockQueryNode_Search_Call{Call: _e.mock.On("Search", ctx, req)} +// - _a0 context.Context +// - _a1 *querypb.SearchRequest +func (_e *MockQueryNode_Expecter) Search(_a0 interface{}, _a1 interface{}) *MockQueryNode_Search_Call { + return &MockQueryNode_Search_Call{Call: _e.mock.On("Search", _a0, _a1)} } -func (_c *MockQueryNode_Search_Call) Run(run func(ctx context.Context, req *querypb.SearchRequest)) *MockQueryNode_Search_Call { +func (_c *MockQueryNode_Search_Call) Run(run func(_a0 context.Context, _a1 *querypb.SearchRequest)) *MockQueryNode_Search_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*querypb.SearchRequest)) }) @@ -1030,17 +1119,17 @@ func (_c *MockQueryNode_Search_Call) RunAndReturn(run func(context.Context, *que return _c } -// SearchSegments provides a mock function with given fields: ctx, req -func (_m *MockQueryNode) SearchSegments(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) { - ret := _m.Called(ctx, req) +// SearchSegments provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryNode) SearchSegments(_a0 context.Context, _a1 *querypb.SearchRequest) 
(*internalpb.SearchResults, error) { + ret := _m.Called(_a0, _a1) var r0 *internalpb.SearchResults var r1 error if rf, ok := ret.Get(0).(func(context.Context, *querypb.SearchRequest) (*internalpb.SearchResults, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *querypb.SearchRequest) *internalpb.SearchResults); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*internalpb.SearchResults) @@ -1048,7 +1137,7 @@ func (_m *MockQueryNode) SearchSegments(ctx context.Context, req *querypb.Search } if rf, ok := ret.Get(1).(func(context.Context, *querypb.SearchRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1062,13 +1151,13 @@ type MockQueryNode_SearchSegments_Call struct { } // SearchSegments is a helper method to define mock.On call -// - ctx context.Context -// - req *querypb.SearchRequest -func (_e *MockQueryNode_Expecter) SearchSegments(ctx interface{}, req interface{}) *MockQueryNode_SearchSegments_Call { - return &MockQueryNode_SearchSegments_Call{Call: _e.mock.On("SearchSegments", ctx, req)} +// - _a0 context.Context +// - _a1 *querypb.SearchRequest +func (_e *MockQueryNode_Expecter) SearchSegments(_a0 interface{}, _a1 interface{}) *MockQueryNode_SearchSegments_Call { + return &MockQueryNode_SearchSegments_Call{Call: _e.mock.On("SearchSegments", _a0, _a1)} } -func (_c *MockQueryNode_SearchSegments_Call) Run(run func(ctx context.Context, req *querypb.SearchRequest)) *MockQueryNode_SearchSegments_Call { +func (_c *MockQueryNode_SearchSegments_Call) Run(run func(_a0 context.Context, _a1 *querypb.SearchRequest)) *MockQueryNode_SearchSegments_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*querypb.SearchRequest)) }) @@ -1151,17 +1240,17 @@ func (_c *MockQueryNode_SetEtcdClient_Call) RunAndReturn(run func(*clientv3.Clie return _c } -// ShowConfigurations provides a mock function with 
given fields: ctx, req -func (_m *MockQueryNode) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { - ret := _m.Called(ctx, req) +// ShowConfigurations provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryNode) ShowConfigurations(_a0 context.Context, _a1 *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *internalpb.ShowConfigurationsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ShowConfigurationsRequest) *internalpb.ShowConfigurationsResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*internalpb.ShowConfigurationsResponse) @@ -1169,7 +1258,7 @@ func (_m *MockQueryNode) ShowConfigurations(ctx context.Context, req *internalpb } if rf, ok := ret.Get(1).(func(context.Context, *internalpb.ShowConfigurationsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1183,13 +1272,13 @@ type MockQueryNode_ShowConfigurations_Call struct { } // ShowConfigurations is a helper method to define mock.On call -// - ctx context.Context -// - req *internalpb.ShowConfigurationsRequest -func (_e *MockQueryNode_Expecter) ShowConfigurations(ctx interface{}, req interface{}) *MockQueryNode_ShowConfigurations_Call { - return &MockQueryNode_ShowConfigurations_Call{Call: _e.mock.On("ShowConfigurations", ctx, req)} +// - _a0 context.Context +// - _a1 *internalpb.ShowConfigurationsRequest +func (_e *MockQueryNode_Expecter) ShowConfigurations(_a0 interface{}, _a1 interface{}) *MockQueryNode_ShowConfigurations_Call { + return &MockQueryNode_ShowConfigurations_Call{Call: _e.mock.On("ShowConfigurations", _a0, _a1)} 
} -func (_c *MockQueryNode_ShowConfigurations_Call) Run(run func(ctx context.Context, req *internalpb.ShowConfigurationsRequest)) *MockQueryNode_ShowConfigurations_Call { +func (_c *MockQueryNode_ShowConfigurations_Call) Run(run func(_a0 context.Context, _a1 *internalpb.ShowConfigurationsRequest)) *MockQueryNode_ShowConfigurations_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*internalpb.ShowConfigurationsRequest)) }) @@ -1343,17 +1432,17 @@ func (_c *MockQueryNode_SyncDistribution_Call) RunAndReturn(run func(context.Con return _c } -// SyncReplicaSegments provides a mock function with given fields: ctx, req -func (_m *MockQueryNode) SyncReplicaSegments(ctx context.Context, req *querypb.SyncReplicaSegmentsRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// SyncReplicaSegments provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryNode) SyncReplicaSegments(_a0 context.Context, _a1 *querypb.SyncReplicaSegmentsRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *querypb.SyncReplicaSegmentsRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *querypb.SyncReplicaSegmentsRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -1361,7 +1450,7 @@ func (_m *MockQueryNode) SyncReplicaSegments(ctx context.Context, req *querypb.S } if rf, ok := ret.Get(1).(func(context.Context, *querypb.SyncReplicaSegmentsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1375,13 +1464,13 @@ type MockQueryNode_SyncReplicaSegments_Call struct { } // SyncReplicaSegments is a helper method to define mock.On call -// - ctx context.Context -// - req *querypb.SyncReplicaSegmentsRequest -func (_e 
*MockQueryNode_Expecter) SyncReplicaSegments(ctx interface{}, req interface{}) *MockQueryNode_SyncReplicaSegments_Call { - return &MockQueryNode_SyncReplicaSegments_Call{Call: _e.mock.On("SyncReplicaSegments", ctx, req)} +// - _a0 context.Context +// - _a1 *querypb.SyncReplicaSegmentsRequest +func (_e *MockQueryNode_Expecter) SyncReplicaSegments(_a0 interface{}, _a1 interface{}) *MockQueryNode_SyncReplicaSegments_Call { + return &MockQueryNode_SyncReplicaSegments_Call{Call: _e.mock.On("SyncReplicaSegments", _a0, _a1)} } -func (_c *MockQueryNode_SyncReplicaSegments_Call) Run(run func(ctx context.Context, req *querypb.SyncReplicaSegmentsRequest)) *MockQueryNode_SyncReplicaSegments_Call { +func (_c *MockQueryNode_SyncReplicaSegments_Call) Run(run func(_a0 context.Context, _a1 *querypb.SyncReplicaSegmentsRequest)) *MockQueryNode_SyncReplicaSegments_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*querypb.SyncReplicaSegmentsRequest)) }) @@ -1398,17 +1487,17 @@ func (_c *MockQueryNode_SyncReplicaSegments_Call) RunAndReturn(run func(context. 
return _c } -// UnsubDmChannel provides a mock function with given fields: ctx, req -func (_m *MockQueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmChannelRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// UnsubDmChannel provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryNode) UnsubDmChannel(_a0 context.Context, _a1 *querypb.UnsubDmChannelRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *querypb.UnsubDmChannelRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *querypb.UnsubDmChannelRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -1416,7 +1505,7 @@ func (_m *MockQueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubD } if rf, ok := ret.Get(1).(func(context.Context, *querypb.UnsubDmChannelRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1430,13 +1519,13 @@ type MockQueryNode_UnsubDmChannel_Call struct { } // UnsubDmChannel is a helper method to define mock.On call -// - ctx context.Context -// - req *querypb.UnsubDmChannelRequest -func (_e *MockQueryNode_Expecter) UnsubDmChannel(ctx interface{}, req interface{}) *MockQueryNode_UnsubDmChannel_Call { - return &MockQueryNode_UnsubDmChannel_Call{Call: _e.mock.On("UnsubDmChannel", ctx, req)} +// - _a0 context.Context +// - _a1 *querypb.UnsubDmChannelRequest +func (_e *MockQueryNode_Expecter) UnsubDmChannel(_a0 interface{}, _a1 interface{}) *MockQueryNode_UnsubDmChannel_Call { + return &MockQueryNode_UnsubDmChannel_Call{Call: _e.mock.On("UnsubDmChannel", _a0, _a1)} } -func (_c *MockQueryNode_UnsubDmChannel_Call) Run(run func(ctx context.Context, req *querypb.UnsubDmChannelRequest)) *MockQueryNode_UnsubDmChannel_Call { +func (_c 
*MockQueryNode_UnsubDmChannel_Call) Run(run func(_a0 context.Context, _a1 *querypb.UnsubDmChannelRequest)) *MockQueryNode_UnsubDmChannel_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*querypb.UnsubDmChannelRequest)) }) @@ -1486,17 +1575,17 @@ func (_c *MockQueryNode_UpdateStateCode_Call) RunAndReturn(run func(commonpb.Sta return _c } -// WatchDmChannels provides a mock function with given fields: ctx, req -func (_m *MockQueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDmChannelsRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// WatchDmChannels provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryNode) WatchDmChannels(_a0 context.Context, _a1 *querypb.WatchDmChannelsRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *querypb.WatchDmChannelsRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *querypb.WatchDmChannelsRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -1504,7 +1593,7 @@ func (_m *MockQueryNode) WatchDmChannels(ctx context.Context, req *querypb.Watch } if rf, ok := ret.Get(1).(func(context.Context, *querypb.WatchDmChannelsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1518,13 +1607,13 @@ type MockQueryNode_WatchDmChannels_Call struct { } // WatchDmChannels is a helper method to define mock.On call -// - ctx context.Context -// - req *querypb.WatchDmChannelsRequest -func (_e *MockQueryNode_Expecter) WatchDmChannels(ctx interface{}, req interface{}) *MockQueryNode_WatchDmChannels_Call { - return &MockQueryNode_WatchDmChannels_Call{Call: _e.mock.On("WatchDmChannels", ctx, req)} +// - _a0 context.Context +// - _a1 *querypb.WatchDmChannelsRequest 
+func (_e *MockQueryNode_Expecter) WatchDmChannels(_a0 interface{}, _a1 interface{}) *MockQueryNode_WatchDmChannels_Call { + return &MockQueryNode_WatchDmChannels_Call{Call: _e.mock.On("WatchDmChannels", _a0, _a1)} } -func (_c *MockQueryNode_WatchDmChannels_Call) Run(run func(ctx context.Context, req *querypb.WatchDmChannelsRequest)) *MockQueryNode_WatchDmChannels_Call { +func (_c *MockQueryNode_WatchDmChannels_Call) Run(run func(_a0 context.Context, _a1 *querypb.WatchDmChannelsRequest)) *MockQueryNode_WatchDmChannels_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*querypb.WatchDmChannelsRequest)) }) diff --git a/internal/mocks/mock_querynode_client.go b/internal/mocks/mock_querynode_client.go new file mode 100644 index 0000000000000..3621a87884222 --- /dev/null +++ b/internal/mocks/mock_querynode_client.go @@ -0,0 +1,1767 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package mocks + +import ( + context "context" + + commonpb "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + + grpc "google.golang.org/grpc" + + internalpb "github.com/milvus-io/milvus/internal/proto/internalpb" + + milvuspb "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + + mock "github.com/stretchr/testify/mock" + + querypb "github.com/milvus-io/milvus/internal/proto/querypb" +) + +// MockQueryNodeClient is an autogenerated mock type for the QueryNodeClient type +type MockQueryNodeClient struct { + mock.Mock +} + +type MockQueryNodeClient_Expecter struct { + mock *mock.Mock +} + +func (_m *MockQueryNodeClient) EXPECT() *MockQueryNodeClient_Expecter { + return &MockQueryNodeClient_Expecter{mock: &_m.Mock} +} + +// Close provides a mock function with given fields: +func (_m *MockQueryNodeClient) Close() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockQueryNodeClient_Close_Call is a *mock.Call that shadows Run/Return 
methods with type explicit version for method 'Close' +type MockQueryNodeClient_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *MockQueryNodeClient_Expecter) Close() *MockQueryNodeClient_Close_Call { + return &MockQueryNodeClient_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *MockQueryNodeClient_Close_Call) Run(run func()) *MockQueryNodeClient_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockQueryNodeClient_Close_Call) Return(_a0 error) *MockQueryNodeClient_Close_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockQueryNodeClient_Close_Call) RunAndReturn(run func() error) *MockQueryNodeClient_Close_Call { + _c.Call.Return(run) + return _c +} + +// Delete provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryNodeClient) Delete(ctx context.Context, in *querypb.DeleteRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.DeleteRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.DeleteRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.DeleteRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryNodeClient_Delete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Delete' +type MockQueryNodeClient_Delete_Call struct { + *mock.Call +} + +// Delete is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.DeleteRequest +// - opts ...grpc.CallOption +func (_e *MockQueryNodeClient_Expecter) Delete(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryNodeClient_Delete_Call { + return &MockQueryNodeClient_Delete_Call{Call: _e.mock.On("Delete", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryNodeClient_Delete_Call) Run(run func(ctx context.Context, in *querypb.DeleteRequest, opts ...grpc.CallOption)) *MockQueryNodeClient_Delete_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.DeleteRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockQueryNodeClient_Delete_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryNodeClient_Delete_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryNodeClient_Delete_Call) RunAndReturn(run func(context.Context, *querypb.DeleteRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryNodeClient_Delete_Call { + _c.Call.Return(run) + return _c +} + +// GetComponentStates provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryNodeClient) GetComponentStates(ctx context.Context, in *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) 
+ + var r0 *milvuspb.ComponentStates + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetComponentStatesRequest, ...grpc.CallOption) (*milvuspb.ComponentStates, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetComponentStatesRequest, ...grpc.CallOption) *milvuspb.ComponentStates); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.ComponentStates) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetComponentStatesRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryNodeClient_GetComponentStates_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetComponentStates' +type MockQueryNodeClient_GetComponentStates_Call struct { + *mock.Call +} + +// GetComponentStates is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.GetComponentStatesRequest +// - opts ...grpc.CallOption +func (_e *MockQueryNodeClient_Expecter) GetComponentStates(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryNodeClient_GetComponentStates_Call { + return &MockQueryNodeClient_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryNodeClient_GetComponentStates_Call) Run(run func(ctx context.Context, in *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption)) *MockQueryNodeClient_GetComponentStates_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.GetComponentStatesRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryNodeClient_GetComponentStates_Call) Return(_a0 *milvuspb.ComponentStates, _a1 error) *MockQueryNodeClient_GetComponentStates_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryNodeClient_GetComponentStates_Call) RunAndReturn(run func(context.Context, *milvuspb.GetComponentStatesRequest, ...grpc.CallOption) (*milvuspb.ComponentStates, error)) *MockQueryNodeClient_GetComponentStates_Call { + _c.Call.Return(run) + return _c +} + +// GetDataDistribution provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryNodeClient) GetDataDistribution(ctx context.Context, in *querypb.GetDataDistributionRequest, opts ...grpc.CallOption) (*querypb.GetDataDistributionResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *querypb.GetDataDistributionResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetDataDistributionRequest, ...grpc.CallOption) (*querypb.GetDataDistributionResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetDataDistributionRequest, ...grpc.CallOption) *querypb.GetDataDistributionResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*querypb.GetDataDistributionResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.GetDataDistributionRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryNodeClient_GetDataDistribution_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetDataDistribution' +type MockQueryNodeClient_GetDataDistribution_Call struct { + *mock.Call +} + +// GetDataDistribution is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.GetDataDistributionRequest +// - opts ...grpc.CallOption +func (_e *MockQueryNodeClient_Expecter) GetDataDistribution(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryNodeClient_GetDataDistribution_Call { + return &MockQueryNodeClient_GetDataDistribution_Call{Call: _e.mock.On("GetDataDistribution", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryNodeClient_GetDataDistribution_Call) Run(run func(ctx context.Context, in *querypb.GetDataDistributionRequest, opts ...grpc.CallOption)) *MockQueryNodeClient_GetDataDistribution_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.GetDataDistributionRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryNodeClient_GetDataDistribution_Call) Return(_a0 *querypb.GetDataDistributionResponse, _a1 error) *MockQueryNodeClient_GetDataDistribution_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryNodeClient_GetDataDistribution_Call) RunAndReturn(run func(context.Context, *querypb.GetDataDistributionRequest, ...grpc.CallOption) (*querypb.GetDataDistributionResponse, error)) *MockQueryNodeClient_GetDataDistribution_Call { + _c.Call.Return(run) + return _c +} + +// GetMetrics provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryNodeClient) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.GetMetricsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest, ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest, ...grpc.CallOption) *milvuspb.GetMetricsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.GetMetricsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetMetricsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryNodeClient_GetMetrics_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetMetrics' +type MockQueryNodeClient_GetMetrics_Call struct { + *mock.Call +} + +// GetMetrics is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.GetMetricsRequest +// - opts ...grpc.CallOption +func (_e *MockQueryNodeClient_Expecter) GetMetrics(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryNodeClient_GetMetrics_Call { + return &MockQueryNodeClient_GetMetrics_Call{Call: _e.mock.On("GetMetrics", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryNodeClient_GetMetrics_Call) Run(run func(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption)) *MockQueryNodeClient_GetMetrics_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.GetMetricsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryNodeClient_GetMetrics_Call) Return(_a0 *milvuspb.GetMetricsResponse, _a1 error) *MockQueryNodeClient_GetMetrics_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryNodeClient_GetMetrics_Call) RunAndReturn(run func(context.Context, *milvuspb.GetMetricsRequest, ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error)) *MockQueryNodeClient_GetMetrics_Call { + _c.Call.Return(run) + return _c +} + +// GetSegmentInfo provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryNodeClient) GetSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest, opts ...grpc.CallOption) (*querypb.GetSegmentInfoResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *querypb.GetSegmentInfoResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetSegmentInfoRequest, ...grpc.CallOption) (*querypb.GetSegmentInfoResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetSegmentInfoRequest, ...grpc.CallOption) *querypb.GetSegmentInfoResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*querypb.GetSegmentInfoResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.GetSegmentInfoRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryNodeClient_GetSegmentInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSegmentInfo' +type MockQueryNodeClient_GetSegmentInfo_Call struct { + *mock.Call +} + +// GetSegmentInfo is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.GetSegmentInfoRequest +// - opts ...grpc.CallOption +func (_e *MockQueryNodeClient_Expecter) GetSegmentInfo(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryNodeClient_GetSegmentInfo_Call { + return &MockQueryNodeClient_GetSegmentInfo_Call{Call: _e.mock.On("GetSegmentInfo", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryNodeClient_GetSegmentInfo_Call) Run(run func(ctx context.Context, in *querypb.GetSegmentInfoRequest, opts ...grpc.CallOption)) *MockQueryNodeClient_GetSegmentInfo_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.GetSegmentInfoRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryNodeClient_GetSegmentInfo_Call) Return(_a0 *querypb.GetSegmentInfoResponse, _a1 error) *MockQueryNodeClient_GetSegmentInfo_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryNodeClient_GetSegmentInfo_Call) RunAndReturn(run func(context.Context, *querypb.GetSegmentInfoRequest, ...grpc.CallOption) (*querypb.GetSegmentInfoResponse, error)) *MockQueryNodeClient_GetSegmentInfo_Call { + _c.Call.Return(run) + return _c +} + +// GetStatistics provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryNodeClient) GetStatistics(ctx context.Context, in *querypb.GetStatisticsRequest, opts ...grpc.CallOption) (*internalpb.GetStatisticsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *internalpb.GetStatisticsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetStatisticsRequest, ...grpc.CallOption) (*internalpb.GetStatisticsResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetStatisticsRequest, ...grpc.CallOption) *internalpb.GetStatisticsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*internalpb.GetStatisticsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.GetStatisticsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryNodeClient_GetStatistics_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetStatistics' +type MockQueryNodeClient_GetStatistics_Call struct { + *mock.Call +} + +// GetStatistics is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.GetStatisticsRequest +// - opts ...grpc.CallOption +func (_e *MockQueryNodeClient_Expecter) GetStatistics(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryNodeClient_GetStatistics_Call { + return &MockQueryNodeClient_GetStatistics_Call{Call: _e.mock.On("GetStatistics", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryNodeClient_GetStatistics_Call) Run(run func(ctx context.Context, in *querypb.GetStatisticsRequest, opts ...grpc.CallOption)) *MockQueryNodeClient_GetStatistics_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.GetStatisticsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryNodeClient_GetStatistics_Call) Return(_a0 *internalpb.GetStatisticsResponse, _a1 error) *MockQueryNodeClient_GetStatistics_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryNodeClient_GetStatistics_Call) RunAndReturn(run func(context.Context, *querypb.GetStatisticsRequest, ...grpc.CallOption) (*internalpb.GetStatisticsResponse, error)) *MockQueryNodeClient_GetStatistics_Call { + _c.Call.Return(run) + return _c +} + +// GetStatisticsChannel provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryNodeClient) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.StringResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetStatisticsChannelRequest, ...grpc.CallOption) (*milvuspb.StringResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetStatisticsChannelRequest, ...grpc.CallOption) *milvuspb.StringResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.StringResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.GetStatisticsChannelRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryNodeClient_GetStatisticsChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetStatisticsChannel' +type MockQueryNodeClient_GetStatisticsChannel_Call struct { + *mock.Call +} + +// GetStatisticsChannel is a helper method to define mock.On call +// - ctx context.Context +// - in *internalpb.GetStatisticsChannelRequest +// - opts ...grpc.CallOption +func (_e *MockQueryNodeClient_Expecter) GetStatisticsChannel(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryNodeClient_GetStatisticsChannel_Call { + return &MockQueryNodeClient_GetStatisticsChannel_Call{Call: _e.mock.On("GetStatisticsChannel", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryNodeClient_GetStatisticsChannel_Call) Run(run func(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption)) *MockQueryNodeClient_GetStatisticsChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*internalpb.GetStatisticsChannelRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryNodeClient_GetStatisticsChannel_Call) Return(_a0 *milvuspb.StringResponse, _a1 error) *MockQueryNodeClient_GetStatisticsChannel_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryNodeClient_GetStatisticsChannel_Call) RunAndReturn(run func(context.Context, *internalpb.GetStatisticsChannelRequest, ...grpc.CallOption) (*milvuspb.StringResponse, error)) *MockQueryNodeClient_GetStatisticsChannel_Call { + _c.Call.Return(run) + return _c +} + +// GetTimeTickChannel provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryNodeClient) GetTimeTickChannel(ctx context.Context, in *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.StringResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetTimeTickChannelRequest, ...grpc.CallOption) (*milvuspb.StringResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetTimeTickChannelRequest, ...grpc.CallOption) *milvuspb.StringResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.StringResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.GetTimeTickChannelRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryNodeClient_GetTimeTickChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetTimeTickChannel' +type MockQueryNodeClient_GetTimeTickChannel_Call struct { + *mock.Call +} + +// GetTimeTickChannel is a helper method to define mock.On call +// - ctx context.Context +// - in *internalpb.GetTimeTickChannelRequest +// - opts ...grpc.CallOption +func (_e *MockQueryNodeClient_Expecter) GetTimeTickChannel(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryNodeClient_GetTimeTickChannel_Call { + return &MockQueryNodeClient_GetTimeTickChannel_Call{Call: _e.mock.On("GetTimeTickChannel", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryNodeClient_GetTimeTickChannel_Call) Run(run func(ctx context.Context, in *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption)) *MockQueryNodeClient_GetTimeTickChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*internalpb.GetTimeTickChannelRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryNodeClient_GetTimeTickChannel_Call) Return(_a0 *milvuspb.StringResponse, _a1 error) *MockQueryNodeClient_GetTimeTickChannel_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryNodeClient_GetTimeTickChannel_Call) RunAndReturn(run func(context.Context, *internalpb.GetTimeTickChannelRequest, ...grpc.CallOption) (*milvuspb.StringResponse, error)) *MockQueryNodeClient_GetTimeTickChannel_Call { + _c.Call.Return(run) + return _c +} + +// LoadPartitions provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryNodeClient) LoadPartitions(ctx context.Context, in *querypb.LoadPartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.LoadPartitionsRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.LoadPartitionsRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.LoadPartitionsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryNodeClient_LoadPartitions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LoadPartitions' +type MockQueryNodeClient_LoadPartitions_Call struct { + *mock.Call +} + +// LoadPartitions is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.LoadPartitionsRequest +// - opts ...grpc.CallOption +func (_e *MockQueryNodeClient_Expecter) LoadPartitions(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryNodeClient_LoadPartitions_Call { + return &MockQueryNodeClient_LoadPartitions_Call{Call: _e.mock.On("LoadPartitions", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryNodeClient_LoadPartitions_Call) Run(run func(ctx context.Context, in *querypb.LoadPartitionsRequest, opts ...grpc.CallOption)) *MockQueryNodeClient_LoadPartitions_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.LoadPartitionsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryNodeClient_LoadPartitions_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryNodeClient_LoadPartitions_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryNodeClient_LoadPartitions_Call) RunAndReturn(run func(context.Context, *querypb.LoadPartitionsRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryNodeClient_LoadPartitions_Call { + _c.Call.Return(run) + return _c +} + +// LoadSegments provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryNodeClient) LoadSegments(ctx context.Context, in *querypb.LoadSegmentsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.LoadSegmentsRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.LoadSegmentsRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.LoadSegmentsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryNodeClient_LoadSegments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LoadSegments' +type MockQueryNodeClient_LoadSegments_Call struct { + *mock.Call +} + +// LoadSegments is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.LoadSegmentsRequest +// - opts ...grpc.CallOption +func (_e *MockQueryNodeClient_Expecter) LoadSegments(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryNodeClient_LoadSegments_Call { + return &MockQueryNodeClient_LoadSegments_Call{Call: _e.mock.On("LoadSegments", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryNodeClient_LoadSegments_Call) Run(run func(ctx context.Context, in *querypb.LoadSegmentsRequest, opts ...grpc.CallOption)) *MockQueryNodeClient_LoadSegments_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.LoadSegmentsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryNodeClient_LoadSegments_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryNodeClient_LoadSegments_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryNodeClient_LoadSegments_Call) RunAndReturn(run func(context.Context, *querypb.LoadSegmentsRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryNodeClient_LoadSegments_Call { + _c.Call.Return(run) + return _c +} + +// Query provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryNodeClient) Query(ctx context.Context, in *querypb.QueryRequest, opts ...grpc.CallOption) (*internalpb.RetrieveResults, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *internalpb.RetrieveResults + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.QueryRequest, ...grpc.CallOption) (*internalpb.RetrieveResults, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.QueryRequest, ...grpc.CallOption) *internalpb.RetrieveResults); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*internalpb.RetrieveResults) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.QueryRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryNodeClient_Query_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Query' +type MockQueryNodeClient_Query_Call struct { + *mock.Call +} + +// Query is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.QueryRequest +// - opts ...grpc.CallOption +func (_e *MockQueryNodeClient_Expecter) Query(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryNodeClient_Query_Call { + return &MockQueryNodeClient_Query_Call{Call: _e.mock.On("Query", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryNodeClient_Query_Call) Run(run func(ctx context.Context, in *querypb.QueryRequest, opts ...grpc.CallOption)) *MockQueryNodeClient_Query_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.QueryRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockQueryNodeClient_Query_Call) Return(_a0 *internalpb.RetrieveResults, _a1 error) *MockQueryNodeClient_Query_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryNodeClient_Query_Call) RunAndReturn(run func(context.Context, *querypb.QueryRequest, ...grpc.CallOption) (*internalpb.RetrieveResults, error)) *MockQueryNodeClient_Query_Call { + _c.Call.Return(run) + return _c +} + +// QuerySegments provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryNodeClient) QuerySegments(ctx context.Context, in *querypb.QueryRequest, opts ...grpc.CallOption) (*internalpb.RetrieveResults, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) 
+ + var r0 *internalpb.RetrieveResults + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.QueryRequest, ...grpc.CallOption) (*internalpb.RetrieveResults, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.QueryRequest, ...grpc.CallOption) *internalpb.RetrieveResults); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*internalpb.RetrieveResults) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.QueryRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryNodeClient_QuerySegments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'QuerySegments' +type MockQueryNodeClient_QuerySegments_Call struct { + *mock.Call +} + +// QuerySegments is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.QueryRequest +// - opts ...grpc.CallOption +func (_e *MockQueryNodeClient_Expecter) QuerySegments(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryNodeClient_QuerySegments_Call { + return &MockQueryNodeClient_QuerySegments_Call{Call: _e.mock.On("QuerySegments", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryNodeClient_QuerySegments_Call) Run(run func(ctx context.Context, in *querypb.QueryRequest, opts ...grpc.CallOption)) *MockQueryNodeClient_QuerySegments_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.QueryRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryNodeClient_QuerySegments_Call) Return(_a0 *internalpb.RetrieveResults, _a1 error) *MockQueryNodeClient_QuerySegments_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryNodeClient_QuerySegments_Call) RunAndReturn(run func(context.Context, *querypb.QueryRequest, ...grpc.CallOption) (*internalpb.RetrieveResults, error)) *MockQueryNodeClient_QuerySegments_Call { + _c.Call.Return(run) + return _c +} + +// QueryStream provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryNodeClient) QueryStream(ctx context.Context, in *querypb.QueryRequest, opts ...grpc.CallOption) (querypb.QueryNode_QueryStreamClient, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 querypb.QueryNode_QueryStreamClient + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.QueryRequest, ...grpc.CallOption) (querypb.QueryNode_QueryStreamClient, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.QueryRequest, ...grpc.CallOption) querypb.QueryNode_QueryStreamClient); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(querypb.QueryNode_QueryStreamClient) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.QueryRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryNodeClient_QueryStream_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'QueryStream' +type MockQueryNodeClient_QueryStream_Call struct { + *mock.Call +} + +// QueryStream is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.QueryRequest +// - opts ...grpc.CallOption +func (_e *MockQueryNodeClient_Expecter) QueryStream(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryNodeClient_QueryStream_Call { + return &MockQueryNodeClient_QueryStream_Call{Call: _e.mock.On("QueryStream", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryNodeClient_QueryStream_Call) Run(run func(ctx context.Context, in *querypb.QueryRequest, opts ...grpc.CallOption)) *MockQueryNodeClient_QueryStream_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.QueryRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryNodeClient_QueryStream_Call) Return(_a0 querypb.QueryNode_QueryStreamClient, _a1 error) *MockQueryNodeClient_QueryStream_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryNodeClient_QueryStream_Call) RunAndReturn(run func(context.Context, *querypb.QueryRequest, ...grpc.CallOption) (querypb.QueryNode_QueryStreamClient, error)) *MockQueryNodeClient_QueryStream_Call { + _c.Call.Return(run) + return _c +} + +// QueryStreamSegments provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryNodeClient) QueryStreamSegments(ctx context.Context, in *querypb.QueryRequest, opts ...grpc.CallOption) (querypb.QueryNode_QueryStreamSegmentsClient, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 querypb.QueryNode_QueryStreamSegmentsClient + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.QueryRequest, ...grpc.CallOption) (querypb.QueryNode_QueryStreamSegmentsClient, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.QueryRequest, ...grpc.CallOption) querypb.QueryNode_QueryStreamSegmentsClient); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(querypb.QueryNode_QueryStreamSegmentsClient) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.QueryRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryNodeClient_QueryStreamSegments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'QueryStreamSegments' +type MockQueryNodeClient_QueryStreamSegments_Call struct { + *mock.Call +} + +// QueryStreamSegments is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.QueryRequest +// - opts ...grpc.CallOption +func (_e *MockQueryNodeClient_Expecter) QueryStreamSegments(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryNodeClient_QueryStreamSegments_Call { + return &MockQueryNodeClient_QueryStreamSegments_Call{Call: _e.mock.On("QueryStreamSegments", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryNodeClient_QueryStreamSegments_Call) Run(run func(ctx context.Context, in *querypb.QueryRequest, opts ...grpc.CallOption)) *MockQueryNodeClient_QueryStreamSegments_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.QueryRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryNodeClient_QueryStreamSegments_Call) Return(_a0 querypb.QueryNode_QueryStreamSegmentsClient, _a1 error) *MockQueryNodeClient_QueryStreamSegments_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryNodeClient_QueryStreamSegments_Call) RunAndReturn(run func(context.Context, *querypb.QueryRequest, ...grpc.CallOption) (querypb.QueryNode_QueryStreamSegmentsClient, error)) *MockQueryNodeClient_QueryStreamSegments_Call { + _c.Call.Return(run) + return _c +} + +// ReleaseCollection provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryNodeClient) ReleaseCollection(ctx context.Context, in *querypb.ReleaseCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleaseCollectionRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleaseCollectionRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.ReleaseCollectionRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryNodeClient_ReleaseCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReleaseCollection' +type MockQueryNodeClient_ReleaseCollection_Call struct { + *mock.Call +} + +// ReleaseCollection is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.ReleaseCollectionRequest +// - opts ...grpc.CallOption +func (_e *MockQueryNodeClient_Expecter) ReleaseCollection(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryNodeClient_ReleaseCollection_Call { + return &MockQueryNodeClient_ReleaseCollection_Call{Call: _e.mock.On("ReleaseCollection", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryNodeClient_ReleaseCollection_Call) Run(run func(ctx context.Context, in *querypb.ReleaseCollectionRequest, opts ...grpc.CallOption)) *MockQueryNodeClient_ReleaseCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.ReleaseCollectionRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryNodeClient_ReleaseCollection_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryNodeClient_ReleaseCollection_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryNodeClient_ReleaseCollection_Call) RunAndReturn(run func(context.Context, *querypb.ReleaseCollectionRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryNodeClient_ReleaseCollection_Call { + _c.Call.Return(run) + return _c +} + +// ReleasePartitions provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryNodeClient) ReleasePartitions(ctx context.Context, in *querypb.ReleasePartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleasePartitionsRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleasePartitionsRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.ReleasePartitionsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryNodeClient_ReleasePartitions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReleasePartitions' +type MockQueryNodeClient_ReleasePartitions_Call struct { + *mock.Call +} + +// ReleasePartitions is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.ReleasePartitionsRequest +// - opts ...grpc.CallOption +func (_e *MockQueryNodeClient_Expecter) ReleasePartitions(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryNodeClient_ReleasePartitions_Call { + return &MockQueryNodeClient_ReleasePartitions_Call{Call: _e.mock.On("ReleasePartitions", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryNodeClient_ReleasePartitions_Call) Run(run func(ctx context.Context, in *querypb.ReleasePartitionsRequest, opts ...grpc.CallOption)) *MockQueryNodeClient_ReleasePartitions_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.ReleasePartitionsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryNodeClient_ReleasePartitions_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryNodeClient_ReleasePartitions_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryNodeClient_ReleasePartitions_Call) RunAndReturn(run func(context.Context, *querypb.ReleasePartitionsRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryNodeClient_ReleasePartitions_Call { + _c.Call.Return(run) + return _c +} + +// ReleaseSegments provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryNodeClient) ReleaseSegments(ctx context.Context, in *querypb.ReleaseSegmentsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleaseSegmentsRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.ReleaseSegmentsRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.ReleaseSegmentsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryNodeClient_ReleaseSegments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReleaseSegments' +type MockQueryNodeClient_ReleaseSegments_Call struct { + *mock.Call +} + +// ReleaseSegments is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.ReleaseSegmentsRequest +// - opts ...grpc.CallOption +func (_e *MockQueryNodeClient_Expecter) ReleaseSegments(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryNodeClient_ReleaseSegments_Call { + return &MockQueryNodeClient_ReleaseSegments_Call{Call: _e.mock.On("ReleaseSegments", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryNodeClient_ReleaseSegments_Call) Run(run func(ctx context.Context, in *querypb.ReleaseSegmentsRequest, opts ...grpc.CallOption)) *MockQueryNodeClient_ReleaseSegments_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.ReleaseSegmentsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryNodeClient_ReleaseSegments_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryNodeClient_ReleaseSegments_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryNodeClient_ReleaseSegments_Call) RunAndReturn(run func(context.Context, *querypb.ReleaseSegmentsRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryNodeClient_ReleaseSegments_Call { + _c.Call.Return(run) + return _c +} + +// Search provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryNodeClient) Search(ctx context.Context, in *querypb.SearchRequest, opts ...grpc.CallOption) (*internalpb.SearchResults, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *internalpb.SearchResults + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.SearchRequest, ...grpc.CallOption) (*internalpb.SearchResults, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.SearchRequest, ...grpc.CallOption) *internalpb.SearchResults); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*internalpb.SearchResults) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.SearchRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryNodeClient_Search_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Search' +type MockQueryNodeClient_Search_Call struct { + *mock.Call +} + +// Search is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.SearchRequest +// - opts ...grpc.CallOption +func (_e *MockQueryNodeClient_Expecter) Search(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryNodeClient_Search_Call { + return &MockQueryNodeClient_Search_Call{Call: _e.mock.On("Search", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryNodeClient_Search_Call) Run(run func(ctx context.Context, in *querypb.SearchRequest, opts ...grpc.CallOption)) *MockQueryNodeClient_Search_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.SearchRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockQueryNodeClient_Search_Call) Return(_a0 *internalpb.SearchResults, _a1 error) *MockQueryNodeClient_Search_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryNodeClient_Search_Call) RunAndReturn(run func(context.Context, *querypb.SearchRequest, ...grpc.CallOption) (*internalpb.SearchResults, error)) *MockQueryNodeClient_Search_Call { + _c.Call.Return(run) + return _c +} + +// SearchSegments provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryNodeClient) SearchSegments(ctx context.Context, in *querypb.SearchRequest, opts ...grpc.CallOption) (*internalpb.SearchResults, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) 
+ + var r0 *internalpb.SearchResults + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.SearchRequest, ...grpc.CallOption) (*internalpb.SearchResults, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.SearchRequest, ...grpc.CallOption) *internalpb.SearchResults); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*internalpb.SearchResults) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.SearchRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryNodeClient_SearchSegments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SearchSegments' +type MockQueryNodeClient_SearchSegments_Call struct { + *mock.Call +} + +// SearchSegments is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.SearchRequest +// - opts ...grpc.CallOption +func (_e *MockQueryNodeClient_Expecter) SearchSegments(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryNodeClient_SearchSegments_Call { + return &MockQueryNodeClient_SearchSegments_Call{Call: _e.mock.On("SearchSegments", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryNodeClient_SearchSegments_Call) Run(run func(ctx context.Context, in *querypb.SearchRequest, opts ...grpc.CallOption)) *MockQueryNodeClient_SearchSegments_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.SearchRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryNodeClient_SearchSegments_Call) Return(_a0 *internalpb.SearchResults, _a1 error) *MockQueryNodeClient_SearchSegments_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryNodeClient_SearchSegments_Call) RunAndReturn(run func(context.Context, *querypb.SearchRequest, ...grpc.CallOption) (*internalpb.SearchResults, error)) *MockQueryNodeClient_SearchSegments_Call { + _c.Call.Return(run) + return _c +} + +// ShowConfigurations provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryNodeClient) ShowConfigurations(ctx context.Context, in *internalpb.ShowConfigurationsRequest, opts ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *internalpb.ShowConfigurationsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ShowConfigurationsRequest, ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ShowConfigurationsRequest, ...grpc.CallOption) *internalpb.ShowConfigurationsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*internalpb.ShowConfigurationsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.ShowConfigurationsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryNodeClient_ShowConfigurations_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ShowConfigurations' +type MockQueryNodeClient_ShowConfigurations_Call struct { + *mock.Call +} + +// ShowConfigurations is a helper method to define mock.On call +// - ctx context.Context +// - in *internalpb.ShowConfigurationsRequest +// - opts ...grpc.CallOption +func (_e *MockQueryNodeClient_Expecter) ShowConfigurations(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryNodeClient_ShowConfigurations_Call { + return &MockQueryNodeClient_ShowConfigurations_Call{Call: _e.mock.On("ShowConfigurations", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryNodeClient_ShowConfigurations_Call) Run(run func(ctx context.Context, in *internalpb.ShowConfigurationsRequest, opts ...grpc.CallOption)) *MockQueryNodeClient_ShowConfigurations_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*internalpb.ShowConfigurationsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryNodeClient_ShowConfigurations_Call) Return(_a0 *internalpb.ShowConfigurationsResponse, _a1 error) *MockQueryNodeClient_ShowConfigurations_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryNodeClient_ShowConfigurations_Call) RunAndReturn(run func(context.Context, *internalpb.ShowConfigurationsRequest, ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error)) *MockQueryNodeClient_ShowConfigurations_Call { + _c.Call.Return(run) + return _c +} + +// SyncDistribution provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryNodeClient) SyncDistribution(ctx context.Context, in *querypb.SyncDistributionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.SyncDistributionRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.SyncDistributionRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.SyncDistributionRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryNodeClient_SyncDistribution_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SyncDistribution' +type MockQueryNodeClient_SyncDistribution_Call struct { + *mock.Call +} + +// SyncDistribution is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.SyncDistributionRequest +// - opts ...grpc.CallOption +func (_e *MockQueryNodeClient_Expecter) SyncDistribution(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryNodeClient_SyncDistribution_Call { + return &MockQueryNodeClient_SyncDistribution_Call{Call: _e.mock.On("SyncDistribution", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryNodeClient_SyncDistribution_Call) Run(run func(ctx context.Context, in *querypb.SyncDistributionRequest, opts ...grpc.CallOption)) *MockQueryNodeClient_SyncDistribution_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.SyncDistributionRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryNodeClient_SyncDistribution_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryNodeClient_SyncDistribution_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryNodeClient_SyncDistribution_Call) RunAndReturn(run func(context.Context, *querypb.SyncDistributionRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryNodeClient_SyncDistribution_Call { + _c.Call.Return(run) + return _c +} + +// SyncReplicaSegments provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryNodeClient) SyncReplicaSegments(ctx context.Context, in *querypb.SyncReplicaSegmentsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.SyncReplicaSegmentsRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.SyncReplicaSegmentsRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.SyncReplicaSegmentsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryNodeClient_SyncReplicaSegments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SyncReplicaSegments' +type MockQueryNodeClient_SyncReplicaSegments_Call struct { + *mock.Call +} + +// SyncReplicaSegments is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.SyncReplicaSegmentsRequest +// - opts ...grpc.CallOption +func (_e *MockQueryNodeClient_Expecter) SyncReplicaSegments(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryNodeClient_SyncReplicaSegments_Call { + return &MockQueryNodeClient_SyncReplicaSegments_Call{Call: _e.mock.On("SyncReplicaSegments", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryNodeClient_SyncReplicaSegments_Call) Run(run func(ctx context.Context, in *querypb.SyncReplicaSegmentsRequest, opts ...grpc.CallOption)) *MockQueryNodeClient_SyncReplicaSegments_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.SyncReplicaSegmentsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryNodeClient_SyncReplicaSegments_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryNodeClient_SyncReplicaSegments_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryNodeClient_SyncReplicaSegments_Call) RunAndReturn(run func(context.Context, *querypb.SyncReplicaSegmentsRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryNodeClient_SyncReplicaSegments_Call { + _c.Call.Return(run) + return _c +} + +// UnsubDmChannel provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryNodeClient) UnsubDmChannel(ctx context.Context, in *querypb.UnsubDmChannelRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.UnsubDmChannelRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.UnsubDmChannelRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.UnsubDmChannelRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryNodeClient_UnsubDmChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UnsubDmChannel' +type MockQueryNodeClient_UnsubDmChannel_Call struct { + *mock.Call +} + +// UnsubDmChannel is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.UnsubDmChannelRequest +// - opts ...grpc.CallOption +func (_e *MockQueryNodeClient_Expecter) UnsubDmChannel(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryNodeClient_UnsubDmChannel_Call { + return &MockQueryNodeClient_UnsubDmChannel_Call{Call: _e.mock.On("UnsubDmChannel", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryNodeClient_UnsubDmChannel_Call) Run(run func(ctx context.Context, in *querypb.UnsubDmChannelRequest, opts ...grpc.CallOption)) *MockQueryNodeClient_UnsubDmChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.UnsubDmChannelRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockQueryNodeClient_UnsubDmChannel_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryNodeClient_UnsubDmChannel_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryNodeClient_UnsubDmChannel_Call) RunAndReturn(run func(context.Context, *querypb.UnsubDmChannelRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryNodeClient_UnsubDmChannel_Call { + _c.Call.Return(run) + return _c +} + +// WatchDmChannels provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryNodeClient) WatchDmChannels(ctx context.Context, in *querypb.WatchDmChannelsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.WatchDmChannelsRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.WatchDmChannelsRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.WatchDmChannelsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryNodeClient_WatchDmChannels_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WatchDmChannels' +type MockQueryNodeClient_WatchDmChannels_Call struct { + *mock.Call +} + +// WatchDmChannels is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.WatchDmChannelsRequest +// - opts ...grpc.CallOption +func (_e *MockQueryNodeClient_Expecter) WatchDmChannels(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryNodeClient_WatchDmChannels_Call { + return &MockQueryNodeClient_WatchDmChannels_Call{Call: _e.mock.On("WatchDmChannels", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryNodeClient_WatchDmChannels_Call) Run(run func(ctx context.Context, in *querypb.WatchDmChannelsRequest, opts ...grpc.CallOption)) *MockQueryNodeClient_WatchDmChannels_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.WatchDmChannelsRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockQueryNodeClient_WatchDmChannels_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryNodeClient_WatchDmChannels_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryNodeClient_WatchDmChannels_Call) RunAndReturn(run func(context.Context, *querypb.WatchDmChannelsRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryNodeClient_WatchDmChannels_Call { + _c.Call.Return(run) + return _c +} + +// NewMockQueryNodeClient creates a new instance of MockQueryNodeClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. 
+func NewMockQueryNodeClient(t interface { + mock.TestingT + Cleanup(func()) +}) *MockQueryNodeClient { + mock := &MockQueryNodeClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/mocks/mock_rootcoord.go b/internal/mocks/mock_rootcoord.go index 8f3e56785bb14..56dcb5490a22e 100644 --- a/internal/mocks/mock_rootcoord.go +++ b/internal/mocks/mock_rootcoord.go @@ -18,6 +18,8 @@ import ( rootcoordpb "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + txnkv "github.com/tikv/client-go/v2/txnkv" + types "github.com/milvus-io/milvus/internal/types" ) @@ -34,17 +36,17 @@ func (_m *RootCoord) EXPECT() *RootCoord_Expecter { return &RootCoord_Expecter{mock: &_m.Mock} } -// AllocID provides a mock function with given fields: ctx, req -func (_m *RootCoord) AllocID(ctx context.Context, req *rootcoordpb.AllocIDRequest) (*rootcoordpb.AllocIDResponse, error) { - ret := _m.Called(ctx, req) +// AllocID provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) AllocID(_a0 context.Context, _a1 *rootcoordpb.AllocIDRequest) (*rootcoordpb.AllocIDResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *rootcoordpb.AllocIDResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.AllocIDRequest) (*rootcoordpb.AllocIDResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.AllocIDRequest) *rootcoordpb.AllocIDResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*rootcoordpb.AllocIDResponse) @@ -52,7 +54,7 @@ func (_m *RootCoord) AllocID(ctx context.Context, req *rootcoordpb.AllocIDReques } if rf, ok := ret.Get(1).(func(context.Context, *rootcoordpb.AllocIDRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -66,13 +68,13 @@ type RootCoord_AllocID_Call struct { } // AllocID is a helper method to define mock.On call 
-// - ctx context.Context -// - req *rootcoordpb.AllocIDRequest -func (_e *RootCoord_Expecter) AllocID(ctx interface{}, req interface{}) *RootCoord_AllocID_Call { - return &RootCoord_AllocID_Call{Call: _e.mock.On("AllocID", ctx, req)} +// - _a0 context.Context +// - _a1 *rootcoordpb.AllocIDRequest +func (_e *RootCoord_Expecter) AllocID(_a0 interface{}, _a1 interface{}) *RootCoord_AllocID_Call { + return &RootCoord_AllocID_Call{Call: _e.mock.On("AllocID", _a0, _a1)} } -func (_c *RootCoord_AllocID_Call) Run(run func(ctx context.Context, req *rootcoordpb.AllocIDRequest)) *RootCoord_AllocID_Call { +func (_c *RootCoord_AllocID_Call) Run(run func(_a0 context.Context, _a1 *rootcoordpb.AllocIDRequest)) *RootCoord_AllocID_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*rootcoordpb.AllocIDRequest)) }) @@ -89,17 +91,17 @@ func (_c *RootCoord_AllocID_Call) RunAndReturn(run func(context.Context, *rootco return _c } -// AllocTimestamp provides a mock function with given fields: ctx, req -func (_m *RootCoord) AllocTimestamp(ctx context.Context, req *rootcoordpb.AllocTimestampRequest) (*rootcoordpb.AllocTimestampResponse, error) { - ret := _m.Called(ctx, req) +// AllocTimestamp provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) AllocTimestamp(_a0 context.Context, _a1 *rootcoordpb.AllocTimestampRequest) (*rootcoordpb.AllocTimestampResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *rootcoordpb.AllocTimestampResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.AllocTimestampRequest) (*rootcoordpb.AllocTimestampResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.AllocTimestampRequest) *rootcoordpb.AllocTimestampResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*rootcoordpb.AllocTimestampResponse) @@ -107,7 +109,7 @@ func (_m *RootCoord) AllocTimestamp(ctx 
context.Context, req *rootcoordpb.AllocT } if rf, ok := ret.Get(1).(func(context.Context, *rootcoordpb.AllocTimestampRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -121,13 +123,13 @@ type RootCoord_AllocTimestamp_Call struct { } // AllocTimestamp is a helper method to define mock.On call -// - ctx context.Context -// - req *rootcoordpb.AllocTimestampRequest -func (_e *RootCoord_Expecter) AllocTimestamp(ctx interface{}, req interface{}) *RootCoord_AllocTimestamp_Call { - return &RootCoord_AllocTimestamp_Call{Call: _e.mock.On("AllocTimestamp", ctx, req)} +// - _a0 context.Context +// - _a1 *rootcoordpb.AllocTimestampRequest +func (_e *RootCoord_Expecter) AllocTimestamp(_a0 interface{}, _a1 interface{}) *RootCoord_AllocTimestamp_Call { + return &RootCoord_AllocTimestamp_Call{Call: _e.mock.On("AllocTimestamp", _a0, _a1)} } -func (_c *RootCoord_AllocTimestamp_Call) Run(run func(ctx context.Context, req *rootcoordpb.AllocTimestampRequest)) *RootCoord_AllocTimestamp_Call { +func (_c *RootCoord_AllocTimestamp_Call) Run(run func(_a0 context.Context, _a1 *rootcoordpb.AllocTimestampRequest)) *RootCoord_AllocTimestamp_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*rootcoordpb.AllocTimestampRequest)) }) @@ -144,17 +146,17 @@ func (_c *RootCoord_AllocTimestamp_Call) RunAndReturn(run func(context.Context, return _c } -// AlterAlias provides a mock function with given fields: ctx, req -func (_m *RootCoord) AlterAlias(ctx context.Context, req *milvuspb.AlterAliasRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// AlterAlias provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) AlterAlias(_a0 context.Context, _a1 *milvuspb.AlterAliasRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.AlterAliasRequest) (*commonpb.Status, error)); ok { - return rf(ctx, 
req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.AlterAliasRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -162,7 +164,7 @@ func (_m *RootCoord) AlterAlias(ctx context.Context, req *milvuspb.AlterAliasReq } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.AlterAliasRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -176,13 +178,13 @@ type RootCoord_AlterAlias_Call struct { } // AlterAlias is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.AlterAliasRequest -func (_e *RootCoord_Expecter) AlterAlias(ctx interface{}, req interface{}) *RootCoord_AlterAlias_Call { - return &RootCoord_AlterAlias_Call{Call: _e.mock.On("AlterAlias", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.AlterAliasRequest +func (_e *RootCoord_Expecter) AlterAlias(_a0 interface{}, _a1 interface{}) *RootCoord_AlterAlias_Call { + return &RootCoord_AlterAlias_Call{Call: _e.mock.On("AlterAlias", _a0, _a1)} } -func (_c *RootCoord_AlterAlias_Call) Run(run func(ctx context.Context, req *milvuspb.AlterAliasRequest)) *RootCoord_AlterAlias_Call { +func (_c *RootCoord_AlterAlias_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.AlterAliasRequest)) *RootCoord_AlterAlias_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.AlterAliasRequest)) }) @@ -199,17 +201,17 @@ func (_c *RootCoord_AlterAlias_Call) RunAndReturn(run func(context.Context, *mil return _c } -// AlterCollection provides a mock function with given fields: ctx, request -func (_m *RootCoord) AlterCollection(ctx context.Context, request *milvuspb.AlterCollectionRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, request) +// AlterCollection provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) AlterCollection(_a0 context.Context, _a1 
*milvuspb.AlterCollectionRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.AlterCollectionRequest) (*commonpb.Status, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.AlterCollectionRequest) *commonpb.Status); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -217,7 +219,7 @@ func (_m *RootCoord) AlterCollection(ctx context.Context, request *milvuspb.Alte } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.AlterCollectionRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -231,13 +233,13 @@ type RootCoord_AlterCollection_Call struct { } // AlterCollection is a helper method to define mock.On call -// - ctx context.Context -// - request *milvuspb.AlterCollectionRequest -func (_e *RootCoord_Expecter) AlterCollection(ctx interface{}, request interface{}) *RootCoord_AlterCollection_Call { - return &RootCoord_AlterCollection_Call{Call: _e.mock.On("AlterCollection", ctx, request)} +// - _a0 context.Context +// - _a1 *milvuspb.AlterCollectionRequest +func (_e *RootCoord_Expecter) AlterCollection(_a0 interface{}, _a1 interface{}) *RootCoord_AlterCollection_Call { + return &RootCoord_AlterCollection_Call{Call: _e.mock.On("AlterCollection", _a0, _a1)} } -func (_c *RootCoord_AlterCollection_Call) Run(run func(ctx context.Context, request *milvuspb.AlterCollectionRequest)) *RootCoord_AlterCollection_Call { +func (_c *RootCoord_AlterCollection_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.AlterCollectionRequest)) *RootCoord_AlterCollection_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.AlterCollectionRequest)) }) @@ -254,17 +256,17 @@ func (_c *RootCoord_AlterCollection_Call) RunAndReturn(run func(context.Context, return 
_c } -// CheckHealth provides a mock function with given fields: ctx, req -func (_m *RootCoord) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { - ret := _m.Called(ctx, req) +// CheckHealth provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) CheckHealth(_a0 context.Context, _a1 *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.CheckHealthResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CheckHealthRequest) *milvuspb.CheckHealthResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.CheckHealthResponse) @@ -272,7 +274,7 @@ func (_m *RootCoord) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthR } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CheckHealthRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -286,13 +288,13 @@ type RootCoord_CheckHealth_Call struct { } // CheckHealth is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.CheckHealthRequest -func (_e *RootCoord_Expecter) CheckHealth(ctx interface{}, req interface{}) *RootCoord_CheckHealth_Call { - return &RootCoord_CheckHealth_Call{Call: _e.mock.On("CheckHealth", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.CheckHealthRequest +func (_e *RootCoord_Expecter) CheckHealth(_a0 interface{}, _a1 interface{}) *RootCoord_CheckHealth_Call { + return &RootCoord_CheckHealth_Call{Call: _e.mock.On("CheckHealth", _a0, _a1)} } -func (_c *RootCoord_CheckHealth_Call) Run(run func(ctx context.Context, req *milvuspb.CheckHealthRequest)) *RootCoord_CheckHealth_Call { +func (_c *RootCoord_CheckHealth_Call) Run(run func(_a0 
context.Context, _a1 *milvuspb.CheckHealthRequest)) *RootCoord_CheckHealth_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.CheckHealthRequest)) }) @@ -309,17 +311,17 @@ func (_c *RootCoord_CheckHealth_Call) RunAndReturn(run func(context.Context, *mi return _c } -// CreateAlias provides a mock function with given fields: ctx, req -func (_m *RootCoord) CreateAlias(ctx context.Context, req *milvuspb.CreateAliasRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// CreateAlias provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) CreateAlias(_a0 context.Context, _a1 *milvuspb.CreateAliasRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateAliasRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateAliasRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -327,7 +329,7 @@ func (_m *RootCoord) CreateAlias(ctx context.Context, req *milvuspb.CreateAliasR } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CreateAliasRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -341,13 +343,13 @@ type RootCoord_CreateAlias_Call struct { } // CreateAlias is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.CreateAliasRequest -func (_e *RootCoord_Expecter) CreateAlias(ctx interface{}, req interface{}) *RootCoord_CreateAlias_Call { - return &RootCoord_CreateAlias_Call{Call: _e.mock.On("CreateAlias", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.CreateAliasRequest +func (_e *RootCoord_Expecter) CreateAlias(_a0 interface{}, _a1 interface{}) *RootCoord_CreateAlias_Call { + return &RootCoord_CreateAlias_Call{Call: 
_e.mock.On("CreateAlias", _a0, _a1)} } -func (_c *RootCoord_CreateAlias_Call) Run(run func(ctx context.Context, req *milvuspb.CreateAliasRequest)) *RootCoord_CreateAlias_Call { +func (_c *RootCoord_CreateAlias_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.CreateAliasRequest)) *RootCoord_CreateAlias_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.CreateAliasRequest)) }) @@ -364,17 +366,17 @@ func (_c *RootCoord_CreateAlias_Call) RunAndReturn(run func(context.Context, *mi return _c } -// CreateCollection provides a mock function with given fields: ctx, req -func (_m *RootCoord) CreateCollection(ctx context.Context, req *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// CreateCollection provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) CreateCollection(_a0 context.Context, _a1 *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateCollectionRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateCollectionRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -382,7 +384,7 @@ func (_m *RootCoord) CreateCollection(ctx context.Context, req *milvuspb.CreateC } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CreateCollectionRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -396,13 +398,13 @@ type RootCoord_CreateCollection_Call struct { } // CreateCollection is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.CreateCollectionRequest -func (_e *RootCoord_Expecter) CreateCollection(ctx interface{}, req interface{}) *RootCoord_CreateCollection_Call 
{ - return &RootCoord_CreateCollection_Call{Call: _e.mock.On("CreateCollection", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.CreateCollectionRequest +func (_e *RootCoord_Expecter) CreateCollection(_a0 interface{}, _a1 interface{}) *RootCoord_CreateCollection_Call { + return &RootCoord_CreateCollection_Call{Call: _e.mock.On("CreateCollection", _a0, _a1)} } -func (_c *RootCoord_CreateCollection_Call) Run(run func(ctx context.Context, req *milvuspb.CreateCollectionRequest)) *RootCoord_CreateCollection_Call { +func (_c *RootCoord_CreateCollection_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.CreateCollectionRequest)) *RootCoord_CreateCollection_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.CreateCollectionRequest)) }) @@ -419,17 +421,17 @@ func (_c *RootCoord_CreateCollection_Call) RunAndReturn(run func(context.Context return _c } -// CreateCredential provides a mock function with given fields: ctx, req -func (_m *RootCoord) CreateCredential(ctx context.Context, req *internalpb.CredentialInfo) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// CreateCredential provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) CreateCredential(_a0 context.Context, _a1 *internalpb.CredentialInfo) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *internalpb.CredentialInfo) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *internalpb.CredentialInfo) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -437,7 +439,7 @@ func (_m *RootCoord) CreateCredential(ctx context.Context, req *internalpb.Crede } if rf, ok := ret.Get(1).(func(context.Context, *internalpb.CredentialInfo) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = 
ret.Error(1) } @@ -451,13 +453,13 @@ type RootCoord_CreateCredential_Call struct { } // CreateCredential is a helper method to define mock.On call -// - ctx context.Context -// - req *internalpb.CredentialInfo -func (_e *RootCoord_Expecter) CreateCredential(ctx interface{}, req interface{}) *RootCoord_CreateCredential_Call { - return &RootCoord_CreateCredential_Call{Call: _e.mock.On("CreateCredential", ctx, req)} +// - _a0 context.Context +// - _a1 *internalpb.CredentialInfo +func (_e *RootCoord_Expecter) CreateCredential(_a0 interface{}, _a1 interface{}) *RootCoord_CreateCredential_Call { + return &RootCoord_CreateCredential_Call{Call: _e.mock.On("CreateCredential", _a0, _a1)} } -func (_c *RootCoord_CreateCredential_Call) Run(run func(ctx context.Context, req *internalpb.CredentialInfo)) *RootCoord_CreateCredential_Call { +func (_c *RootCoord_CreateCredential_Call) Run(run func(_a0 context.Context, _a1 *internalpb.CredentialInfo)) *RootCoord_CreateCredential_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*internalpb.CredentialInfo)) }) @@ -474,17 +476,17 @@ func (_c *RootCoord_CreateCredential_Call) RunAndReturn(run func(context.Context return _c } -// CreateDatabase provides a mock function with given fields: ctx, req -func (_m *RootCoord) CreateDatabase(ctx context.Context, req *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// CreateDatabase provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) CreateDatabase(_a0 context.Context, _a1 *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateDatabaseRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } 
else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -492,7 +494,7 @@ func (_m *RootCoord) CreateDatabase(ctx context.Context, req *milvuspb.CreateDat } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CreateDatabaseRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -506,13 +508,13 @@ type RootCoord_CreateDatabase_Call struct { } // CreateDatabase is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.CreateDatabaseRequest -func (_e *RootCoord_Expecter) CreateDatabase(ctx interface{}, req interface{}) *RootCoord_CreateDatabase_Call { - return &RootCoord_CreateDatabase_Call{Call: _e.mock.On("CreateDatabase", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.CreateDatabaseRequest +func (_e *RootCoord_Expecter) CreateDatabase(_a0 interface{}, _a1 interface{}) *RootCoord_CreateDatabase_Call { + return &RootCoord_CreateDatabase_Call{Call: _e.mock.On("CreateDatabase", _a0, _a1)} } -func (_c *RootCoord_CreateDatabase_Call) Run(run func(ctx context.Context, req *milvuspb.CreateDatabaseRequest)) *RootCoord_CreateDatabase_Call { +func (_c *RootCoord_CreateDatabase_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.CreateDatabaseRequest)) *RootCoord_CreateDatabase_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.CreateDatabaseRequest)) }) @@ -529,17 +531,17 @@ func (_c *RootCoord_CreateDatabase_Call) RunAndReturn(run func(context.Context, return _c } -// CreatePartition provides a mock function with given fields: ctx, req -func (_m *RootCoord) CreatePartition(ctx context.Context, req *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// CreatePartition provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) CreatePartition(_a0 context.Context, _a1 *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status 
var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreatePartitionRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreatePartitionRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -547,7 +549,7 @@ func (_m *RootCoord) CreatePartition(ctx context.Context, req *milvuspb.CreatePa } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CreatePartitionRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -561,13 +563,13 @@ type RootCoord_CreatePartition_Call struct { } // CreatePartition is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.CreatePartitionRequest -func (_e *RootCoord_Expecter) CreatePartition(ctx interface{}, req interface{}) *RootCoord_CreatePartition_Call { - return &RootCoord_CreatePartition_Call{Call: _e.mock.On("CreatePartition", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.CreatePartitionRequest +func (_e *RootCoord_Expecter) CreatePartition(_a0 interface{}, _a1 interface{}) *RootCoord_CreatePartition_Call { + return &RootCoord_CreatePartition_Call{Call: _e.mock.On("CreatePartition", _a0, _a1)} } -func (_c *RootCoord_CreatePartition_Call) Run(run func(ctx context.Context, req *milvuspb.CreatePartitionRequest)) *RootCoord_CreatePartition_Call { +func (_c *RootCoord_CreatePartition_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.CreatePartitionRequest)) *RootCoord_CreatePartition_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.CreatePartitionRequest)) }) @@ -584,17 +586,17 @@ func (_c *RootCoord_CreatePartition_Call) RunAndReturn(run func(context.Context, return _c } -// CreateRole provides a mock function with given fields: ctx, req -func (_m *RootCoord) CreateRole(ctx context.Context, req 
*milvuspb.CreateRoleRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// CreateRole provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) CreateRole(_a0 context.Context, _a1 *milvuspb.CreateRoleRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateRoleRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateRoleRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -602,7 +604,7 @@ func (_m *RootCoord) CreateRole(ctx context.Context, req *milvuspb.CreateRoleReq } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CreateRoleRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -616,13 +618,13 @@ type RootCoord_CreateRole_Call struct { } // CreateRole is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.CreateRoleRequest -func (_e *RootCoord_Expecter) CreateRole(ctx interface{}, req interface{}) *RootCoord_CreateRole_Call { - return &RootCoord_CreateRole_Call{Call: _e.mock.On("CreateRole", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.CreateRoleRequest +func (_e *RootCoord_Expecter) CreateRole(_a0 interface{}, _a1 interface{}) *RootCoord_CreateRole_Call { + return &RootCoord_CreateRole_Call{Call: _e.mock.On("CreateRole", _a0, _a1)} } -func (_c *RootCoord_CreateRole_Call) Run(run func(ctx context.Context, req *milvuspb.CreateRoleRequest)) *RootCoord_CreateRole_Call { +func (_c *RootCoord_CreateRole_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.CreateRoleRequest)) *RootCoord_CreateRole_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.CreateRoleRequest)) }) @@ -639,17 +641,17 @@ func (_c 
*RootCoord_CreateRole_Call) RunAndReturn(run func(context.Context, *mil return _c } -// DeleteCredential provides a mock function with given fields: ctx, req -func (_m *RootCoord) DeleteCredential(ctx context.Context, req *milvuspb.DeleteCredentialRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// DeleteCredential provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) DeleteCredential(_a0 context.Context, _a1 *milvuspb.DeleteCredentialRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DeleteCredentialRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DeleteCredentialRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -657,7 +659,7 @@ func (_m *RootCoord) DeleteCredential(ctx context.Context, req *milvuspb.DeleteC } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DeleteCredentialRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -671,13 +673,13 @@ type RootCoord_DeleteCredential_Call struct { } // DeleteCredential is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.DeleteCredentialRequest -func (_e *RootCoord_Expecter) DeleteCredential(ctx interface{}, req interface{}) *RootCoord_DeleteCredential_Call { - return &RootCoord_DeleteCredential_Call{Call: _e.mock.On("DeleteCredential", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.DeleteCredentialRequest +func (_e *RootCoord_Expecter) DeleteCredential(_a0 interface{}, _a1 interface{}) *RootCoord_DeleteCredential_Call { + return &RootCoord_DeleteCredential_Call{Call: _e.mock.On("DeleteCredential", _a0, _a1)} } -func (_c *RootCoord_DeleteCredential_Call) Run(run func(ctx context.Context, req 
*milvuspb.DeleteCredentialRequest)) *RootCoord_DeleteCredential_Call { +func (_c *RootCoord_DeleteCredential_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.DeleteCredentialRequest)) *RootCoord_DeleteCredential_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.DeleteCredentialRequest)) }) @@ -694,17 +696,17 @@ func (_c *RootCoord_DeleteCredential_Call) RunAndReturn(run func(context.Context return _c } -// DescribeCollection provides a mock function with given fields: ctx, req -func (_m *RootCoord) DescribeCollection(ctx context.Context, req *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { - ret := _m.Called(ctx, req) +// DescribeCollection provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) DescribeCollection(_a0 context.Context, _a1 *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.DescribeCollectionResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DescribeCollectionRequest) *milvuspb.DescribeCollectionResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.DescribeCollectionResponse) @@ -712,7 +714,7 @@ func (_m *RootCoord) DescribeCollection(ctx context.Context, req *milvuspb.Descr } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DescribeCollectionRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -726,13 +728,13 @@ type RootCoord_DescribeCollection_Call struct { } // DescribeCollection is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.DescribeCollectionRequest -func (_e *RootCoord_Expecter) DescribeCollection(ctx 
interface{}, req interface{}) *RootCoord_DescribeCollection_Call { - return &RootCoord_DescribeCollection_Call{Call: _e.mock.On("DescribeCollection", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.DescribeCollectionRequest +func (_e *RootCoord_Expecter) DescribeCollection(_a0 interface{}, _a1 interface{}) *RootCoord_DescribeCollection_Call { + return &RootCoord_DescribeCollection_Call{Call: _e.mock.On("DescribeCollection", _a0, _a1)} } -func (_c *RootCoord_DescribeCollection_Call) Run(run func(ctx context.Context, req *milvuspb.DescribeCollectionRequest)) *RootCoord_DescribeCollection_Call { +func (_c *RootCoord_DescribeCollection_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.DescribeCollectionRequest)) *RootCoord_DescribeCollection_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.DescribeCollectionRequest)) }) @@ -749,17 +751,17 @@ func (_c *RootCoord_DescribeCollection_Call) RunAndReturn(run func(context.Conte return _c } -// DescribeCollectionInternal provides a mock function with given fields: ctx, req -func (_m *RootCoord) DescribeCollectionInternal(ctx context.Context, req *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { - ret := _m.Called(ctx, req) +// DescribeCollectionInternal provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) DescribeCollectionInternal(_a0 context.Context, _a1 *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.DescribeCollectionResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DescribeCollectionRequest) *milvuspb.DescribeCollectionResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = 
ret.Get(0).(*milvuspb.DescribeCollectionResponse) @@ -767,7 +769,7 @@ func (_m *RootCoord) DescribeCollectionInternal(ctx context.Context, req *milvus } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DescribeCollectionRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -781,13 +783,13 @@ type RootCoord_DescribeCollectionInternal_Call struct { } // DescribeCollectionInternal is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.DescribeCollectionRequest -func (_e *RootCoord_Expecter) DescribeCollectionInternal(ctx interface{}, req interface{}) *RootCoord_DescribeCollectionInternal_Call { - return &RootCoord_DescribeCollectionInternal_Call{Call: _e.mock.On("DescribeCollectionInternal", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.DescribeCollectionRequest +func (_e *RootCoord_Expecter) DescribeCollectionInternal(_a0 interface{}, _a1 interface{}) *RootCoord_DescribeCollectionInternal_Call { + return &RootCoord_DescribeCollectionInternal_Call{Call: _e.mock.On("DescribeCollectionInternal", _a0, _a1)} } -func (_c *RootCoord_DescribeCollectionInternal_Call) Run(run func(ctx context.Context, req *milvuspb.DescribeCollectionRequest)) *RootCoord_DescribeCollectionInternal_Call { +func (_c *RootCoord_DescribeCollectionInternal_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.DescribeCollectionRequest)) *RootCoord_DescribeCollectionInternal_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.DescribeCollectionRequest)) }) @@ -804,17 +806,17 @@ func (_c *RootCoord_DescribeCollectionInternal_Call) RunAndReturn(run func(conte return _c } -// DropAlias provides a mock function with given fields: ctx, req -func (_m *RootCoord) DropAlias(ctx context.Context, req *milvuspb.DropAliasRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// DropAlias provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) 
DropAlias(_a0 context.Context, _a1 *milvuspb.DropAliasRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropAliasRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropAliasRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -822,7 +824,7 @@ func (_m *RootCoord) DropAlias(ctx context.Context, req *milvuspb.DropAliasReque } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DropAliasRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -836,13 +838,13 @@ type RootCoord_DropAlias_Call struct { } // DropAlias is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.DropAliasRequest -func (_e *RootCoord_Expecter) DropAlias(ctx interface{}, req interface{}) *RootCoord_DropAlias_Call { - return &RootCoord_DropAlias_Call{Call: _e.mock.On("DropAlias", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.DropAliasRequest +func (_e *RootCoord_Expecter) DropAlias(_a0 interface{}, _a1 interface{}) *RootCoord_DropAlias_Call { + return &RootCoord_DropAlias_Call{Call: _e.mock.On("DropAlias", _a0, _a1)} } -func (_c *RootCoord_DropAlias_Call) Run(run func(ctx context.Context, req *milvuspb.DropAliasRequest)) *RootCoord_DropAlias_Call { +func (_c *RootCoord_DropAlias_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.DropAliasRequest)) *RootCoord_DropAlias_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.DropAliasRequest)) }) @@ -859,17 +861,17 @@ func (_c *RootCoord_DropAlias_Call) RunAndReturn(run func(context.Context, *milv return _c } -// DropCollection provides a mock function with given fields: ctx, req -func (_m *RootCoord) DropCollection(ctx 
context.Context, req *milvuspb.DropCollectionRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// DropCollection provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) DropCollection(_a0 context.Context, _a1 *milvuspb.DropCollectionRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropCollectionRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropCollectionRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -877,7 +879,7 @@ func (_m *RootCoord) DropCollection(ctx context.Context, req *milvuspb.DropColle } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DropCollectionRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -891,13 +893,13 @@ type RootCoord_DropCollection_Call struct { } // DropCollection is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.DropCollectionRequest -func (_e *RootCoord_Expecter) DropCollection(ctx interface{}, req interface{}) *RootCoord_DropCollection_Call { - return &RootCoord_DropCollection_Call{Call: _e.mock.On("DropCollection", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.DropCollectionRequest +func (_e *RootCoord_Expecter) DropCollection(_a0 interface{}, _a1 interface{}) *RootCoord_DropCollection_Call { + return &RootCoord_DropCollection_Call{Call: _e.mock.On("DropCollection", _a0, _a1)} } -func (_c *RootCoord_DropCollection_Call) Run(run func(ctx context.Context, req *milvuspb.DropCollectionRequest)) *RootCoord_DropCollection_Call { +func (_c *RootCoord_DropCollection_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.DropCollectionRequest)) *RootCoord_DropCollection_Call { _c.Call.Run(func(args 
mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.DropCollectionRequest)) }) @@ -914,17 +916,17 @@ func (_c *RootCoord_DropCollection_Call) RunAndReturn(run func(context.Context, return _c } -// DropDatabase provides a mock function with given fields: ctx, req -func (_m *RootCoord) DropDatabase(ctx context.Context, req *milvuspb.DropDatabaseRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// DropDatabase provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) DropDatabase(_a0 context.Context, _a1 *milvuspb.DropDatabaseRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropDatabaseRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropDatabaseRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -932,7 +934,7 @@ func (_m *RootCoord) DropDatabase(ctx context.Context, req *milvuspb.DropDatabas } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DropDatabaseRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -946,13 +948,13 @@ type RootCoord_DropDatabase_Call struct { } // DropDatabase is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.DropDatabaseRequest -func (_e *RootCoord_Expecter) DropDatabase(ctx interface{}, req interface{}) *RootCoord_DropDatabase_Call { - return &RootCoord_DropDatabase_Call{Call: _e.mock.On("DropDatabase", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.DropDatabaseRequest +func (_e *RootCoord_Expecter) DropDatabase(_a0 interface{}, _a1 interface{}) *RootCoord_DropDatabase_Call { + return &RootCoord_DropDatabase_Call{Call: _e.mock.On("DropDatabase", _a0, _a1)} } -func (_c *RootCoord_DropDatabase_Call) Run(run 
func(ctx context.Context, req *milvuspb.DropDatabaseRequest)) *RootCoord_DropDatabase_Call { +func (_c *RootCoord_DropDatabase_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.DropDatabaseRequest)) *RootCoord_DropDatabase_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.DropDatabaseRequest)) }) @@ -969,17 +971,17 @@ func (_c *RootCoord_DropDatabase_Call) RunAndReturn(run func(context.Context, *m return _c } -// DropPartition provides a mock function with given fields: ctx, req -func (_m *RootCoord) DropPartition(ctx context.Context, req *milvuspb.DropPartitionRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// DropPartition provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) DropPartition(_a0 context.Context, _a1 *milvuspb.DropPartitionRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropPartitionRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropPartitionRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -987,7 +989,7 @@ func (_m *RootCoord) DropPartition(ctx context.Context, req *milvuspb.DropPartit } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DropPartitionRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1001,13 +1003,13 @@ type RootCoord_DropPartition_Call struct { } // DropPartition is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.DropPartitionRequest -func (_e *RootCoord_Expecter) DropPartition(ctx interface{}, req interface{}) *RootCoord_DropPartition_Call { - return &RootCoord_DropPartition_Call{Call: _e.mock.On("DropPartition", ctx, req)} +// - _a0 context.Context +// - 
_a1 *milvuspb.DropPartitionRequest +func (_e *RootCoord_Expecter) DropPartition(_a0 interface{}, _a1 interface{}) *RootCoord_DropPartition_Call { + return &RootCoord_DropPartition_Call{Call: _e.mock.On("DropPartition", _a0, _a1)} } -func (_c *RootCoord_DropPartition_Call) Run(run func(ctx context.Context, req *milvuspb.DropPartitionRequest)) *RootCoord_DropPartition_Call { +func (_c *RootCoord_DropPartition_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.DropPartitionRequest)) *RootCoord_DropPartition_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.DropPartitionRequest)) }) @@ -1024,17 +1026,17 @@ func (_c *RootCoord_DropPartition_Call) RunAndReturn(run func(context.Context, * return _c } -// DropRole provides a mock function with given fields: ctx, req -func (_m *RootCoord) DropRole(ctx context.Context, req *milvuspb.DropRoleRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// DropRole provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) DropRole(_a0 context.Context, _a1 *milvuspb.DropRoleRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropRoleRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropRoleRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -1042,7 +1044,7 @@ func (_m *RootCoord) DropRole(ctx context.Context, req *milvuspb.DropRoleRequest } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DropRoleRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1056,13 +1058,13 @@ type RootCoord_DropRole_Call struct { } // DropRole is a helper method to define mock.On call -// - ctx context.Context -// - req 
*milvuspb.DropRoleRequest -func (_e *RootCoord_Expecter) DropRole(ctx interface{}, req interface{}) *RootCoord_DropRole_Call { - return &RootCoord_DropRole_Call{Call: _e.mock.On("DropRole", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.DropRoleRequest +func (_e *RootCoord_Expecter) DropRole(_a0 interface{}, _a1 interface{}) *RootCoord_DropRole_Call { + return &RootCoord_DropRole_Call{Call: _e.mock.On("DropRole", _a0, _a1)} } -func (_c *RootCoord_DropRole_Call) Run(run func(ctx context.Context, req *milvuspb.DropRoleRequest)) *RootCoord_DropRole_Call { +func (_c *RootCoord_DropRole_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.DropRoleRequest)) *RootCoord_DropRole_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.DropRoleRequest)) }) @@ -1079,25 +1081,25 @@ func (_c *RootCoord_DropRole_Call) RunAndReturn(run func(context.Context, *milvu return _c } -// GetComponentStates provides a mock function with given fields: ctx -func (_m *RootCoord) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { - ret := _m.Called(ctx) +// GetComponentStates provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) GetComponentStates(_a0 context.Context, _a1 *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.ComponentStates var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (*milvuspb.ComponentStates, error)); ok { - return rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error)); ok { + return rf(_a0, _a1) } - if rf, ok := ret.Get(0).(func(context.Context) *milvuspb.ComponentStates); ok { - r0 = rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetComponentStatesRequest) *milvuspb.ComponentStates); ok { + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.ComponentStates) } } - if rf, ok 
:= ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetComponentStatesRequest) error); ok { + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1111,14 +1113,15 @@ type RootCoord_GetComponentStates_Call struct { } // GetComponentStates is a helper method to define mock.On call -// - ctx context.Context -func (_e *RootCoord_Expecter) GetComponentStates(ctx interface{}) *RootCoord_GetComponentStates_Call { - return &RootCoord_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", ctx)} +// - _a0 context.Context +// - _a1 *milvuspb.GetComponentStatesRequest +func (_e *RootCoord_Expecter) GetComponentStates(_a0 interface{}, _a1 interface{}) *RootCoord_GetComponentStates_Call { + return &RootCoord_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", _a0, _a1)} } -func (_c *RootCoord_GetComponentStates_Call) Run(run func(ctx context.Context)) *RootCoord_GetComponentStates_Call { +func (_c *RootCoord_GetComponentStates_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetComponentStatesRequest)) *RootCoord_GetComponentStates_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context)) + run(args[0].(context.Context), args[1].(*milvuspb.GetComponentStatesRequest)) }) return _c } @@ -1128,22 +1131,22 @@ func (_c *RootCoord_GetComponentStates_Call) Return(_a0 *milvuspb.ComponentState return _c } -func (_c *RootCoord_GetComponentStates_Call) RunAndReturn(run func(context.Context) (*milvuspb.ComponentStates, error)) *RootCoord_GetComponentStates_Call { +func (_c *RootCoord_GetComponentStates_Call) RunAndReturn(run func(context.Context, *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error)) *RootCoord_GetComponentStates_Call { _c.Call.Return(run) return _c } -// GetCredential provides a mock function with given fields: ctx, req -func (_m *RootCoord) GetCredential(ctx context.Context, req *rootcoordpb.GetCredentialRequest) 
(*rootcoordpb.GetCredentialResponse, error) { - ret := _m.Called(ctx, req) +// GetCredential provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) GetCredential(_a0 context.Context, _a1 *rootcoordpb.GetCredentialRequest) (*rootcoordpb.GetCredentialResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *rootcoordpb.GetCredentialResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.GetCredentialRequest) (*rootcoordpb.GetCredentialResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.GetCredentialRequest) *rootcoordpb.GetCredentialResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*rootcoordpb.GetCredentialResponse) @@ -1151,7 +1154,7 @@ func (_m *RootCoord) GetCredential(ctx context.Context, req *rootcoordpb.GetCred } if rf, ok := ret.Get(1).(func(context.Context, *rootcoordpb.GetCredentialRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1165,13 +1168,13 @@ type RootCoord_GetCredential_Call struct { } // GetCredential is a helper method to define mock.On call -// - ctx context.Context -// - req *rootcoordpb.GetCredentialRequest -func (_e *RootCoord_Expecter) GetCredential(ctx interface{}, req interface{}) *RootCoord_GetCredential_Call { - return &RootCoord_GetCredential_Call{Call: _e.mock.On("GetCredential", ctx, req)} +// - _a0 context.Context +// - _a1 *rootcoordpb.GetCredentialRequest +func (_e *RootCoord_Expecter) GetCredential(_a0 interface{}, _a1 interface{}) *RootCoord_GetCredential_Call { + return &RootCoord_GetCredential_Call{Call: _e.mock.On("GetCredential", _a0, _a1)} } -func (_c *RootCoord_GetCredential_Call) Run(run func(ctx context.Context, req *rootcoordpb.GetCredentialRequest)) *RootCoord_GetCredential_Call { +func (_c *RootCoord_GetCredential_Call) Run(run func(_a0 context.Context, _a1 *rootcoordpb.GetCredentialRequest)) 
*RootCoord_GetCredential_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*rootcoordpb.GetCredentialRequest)) }) @@ -1188,17 +1191,17 @@ func (_c *RootCoord_GetCredential_Call) RunAndReturn(run func(context.Context, * return _c } -// GetImportState provides a mock function with given fields: ctx, req -func (_m *RootCoord) GetImportState(ctx context.Context, req *milvuspb.GetImportStateRequest) (*milvuspb.GetImportStateResponse, error) { - ret := _m.Called(ctx, req) +// GetImportState provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) GetImportState(_a0 context.Context, _a1 *milvuspb.GetImportStateRequest) (*milvuspb.GetImportStateResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.GetImportStateResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetImportStateRequest) (*milvuspb.GetImportStateResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetImportStateRequest) *milvuspb.GetImportStateResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.GetImportStateResponse) @@ -1206,7 +1209,7 @@ func (_m *RootCoord) GetImportState(ctx context.Context, req *milvuspb.GetImport } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetImportStateRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1220,13 +1223,13 @@ type RootCoord_GetImportState_Call struct { } // GetImportState is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.GetImportStateRequest -func (_e *RootCoord_Expecter) GetImportState(ctx interface{}, req interface{}) *RootCoord_GetImportState_Call { - return &RootCoord_GetImportState_Call{Call: _e.mock.On("GetImportState", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.GetImportStateRequest +func (_e *RootCoord_Expecter) 
GetImportState(_a0 interface{}, _a1 interface{}) *RootCoord_GetImportState_Call { + return &RootCoord_GetImportState_Call{Call: _e.mock.On("GetImportState", _a0, _a1)} } -func (_c *RootCoord_GetImportState_Call) Run(run func(ctx context.Context, req *milvuspb.GetImportStateRequest)) *RootCoord_GetImportState_Call { +func (_c *RootCoord_GetImportState_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.GetImportStateRequest)) *RootCoord_GetImportState_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.GetImportStateRequest)) }) @@ -1298,25 +1301,25 @@ func (_c *RootCoord_GetMetrics_Call) RunAndReturn(run func(context.Context, *mil return _c } -// GetStatisticsChannel provides a mock function with given fields: ctx -func (_m *RootCoord) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret := _m.Called(ctx) +// GetStatisticsChannel provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) GetStatisticsChannel(_a0 context.Context, _a1 *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.StringResponse var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (*milvuspb.StringResponse, error)); ok { - return rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error)); ok { + return rf(_a0, _a1) } - if rf, ok := ret.Get(0).(func(context.Context) *milvuspb.StringResponse); ok { - r0 = rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetStatisticsChannelRequest) *milvuspb.StringResponse); ok { + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.StringResponse) } } - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.GetStatisticsChannelRequest) error); ok { + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } 
@@ -1330,14 +1333,15 @@ type RootCoord_GetStatisticsChannel_Call struct { } // GetStatisticsChannel is a helper method to define mock.On call -// - ctx context.Context -func (_e *RootCoord_Expecter) GetStatisticsChannel(ctx interface{}) *RootCoord_GetStatisticsChannel_Call { - return &RootCoord_GetStatisticsChannel_Call{Call: _e.mock.On("GetStatisticsChannel", ctx)} +// - _a0 context.Context +// - _a1 *internalpb.GetStatisticsChannelRequest +func (_e *RootCoord_Expecter) GetStatisticsChannel(_a0 interface{}, _a1 interface{}) *RootCoord_GetStatisticsChannel_Call { + return &RootCoord_GetStatisticsChannel_Call{Call: _e.mock.On("GetStatisticsChannel", _a0, _a1)} } -func (_c *RootCoord_GetStatisticsChannel_Call) Run(run func(ctx context.Context)) *RootCoord_GetStatisticsChannel_Call { +func (_c *RootCoord_GetStatisticsChannel_Call) Run(run func(_a0 context.Context, _a1 *internalpb.GetStatisticsChannelRequest)) *RootCoord_GetStatisticsChannel_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context)) + run(args[0].(context.Context), args[1].(*internalpb.GetStatisticsChannelRequest)) }) return _c } @@ -1347,30 +1351,30 @@ func (_c *RootCoord_GetStatisticsChannel_Call) Return(_a0 *milvuspb.StringRespon return _c } -func (_c *RootCoord_GetStatisticsChannel_Call) RunAndReturn(run func(context.Context) (*milvuspb.StringResponse, error)) *RootCoord_GetStatisticsChannel_Call { +func (_c *RootCoord_GetStatisticsChannel_Call) RunAndReturn(run func(context.Context, *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error)) *RootCoord_GetStatisticsChannel_Call { _c.Call.Return(run) return _c } -// GetTimeTickChannel provides a mock function with given fields: ctx -func (_m *RootCoord) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret := _m.Called(ctx) +// GetTimeTickChannel provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) GetTimeTickChannel(_a0 context.Context, _a1 
*internalpb.GetTimeTickChannelRequest) (*milvuspb.StringResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.StringResponse var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (*milvuspb.StringResponse, error)); ok { - return rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetTimeTickChannelRequest) (*milvuspb.StringResponse, error)); ok { + return rf(_a0, _a1) } - if rf, ok := ret.Get(0).(func(context.Context) *milvuspb.StringResponse); ok { - r0 = rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetTimeTickChannelRequest) *milvuspb.StringResponse); ok { + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.StringResponse) } } - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.GetTimeTickChannelRequest) error); ok { + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1384,14 +1388,15 @@ type RootCoord_GetTimeTickChannel_Call struct { } // GetTimeTickChannel is a helper method to define mock.On call -// - ctx context.Context -func (_e *RootCoord_Expecter) GetTimeTickChannel(ctx interface{}) *RootCoord_GetTimeTickChannel_Call { - return &RootCoord_GetTimeTickChannel_Call{Call: _e.mock.On("GetTimeTickChannel", ctx)} +// - _a0 context.Context +// - _a1 *internalpb.GetTimeTickChannelRequest +func (_e *RootCoord_Expecter) GetTimeTickChannel(_a0 interface{}, _a1 interface{}) *RootCoord_GetTimeTickChannel_Call { + return &RootCoord_GetTimeTickChannel_Call{Call: _e.mock.On("GetTimeTickChannel", _a0, _a1)} } -func (_c *RootCoord_GetTimeTickChannel_Call) Run(run func(ctx context.Context)) *RootCoord_GetTimeTickChannel_Call { +func (_c *RootCoord_GetTimeTickChannel_Call) Run(run func(_a0 context.Context, _a1 *internalpb.GetTimeTickChannelRequest)) *RootCoord_GetTimeTickChannel_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context)) + run(args[0].(context.Context), 
args[1].(*internalpb.GetTimeTickChannelRequest)) }) return _c } @@ -1401,22 +1406,22 @@ func (_c *RootCoord_GetTimeTickChannel_Call) Return(_a0 *milvuspb.StringResponse return _c } -func (_c *RootCoord_GetTimeTickChannel_Call) RunAndReturn(run func(context.Context) (*milvuspb.StringResponse, error)) *RootCoord_GetTimeTickChannel_Call { +func (_c *RootCoord_GetTimeTickChannel_Call) RunAndReturn(run func(context.Context, *internalpb.GetTimeTickChannelRequest) (*milvuspb.StringResponse, error)) *RootCoord_GetTimeTickChannel_Call { _c.Call.Return(run) return _c } -// HasCollection provides a mock function with given fields: ctx, req -func (_m *RootCoord) HasCollection(ctx context.Context, req *milvuspb.HasCollectionRequest) (*milvuspb.BoolResponse, error) { - ret := _m.Called(ctx, req) +// HasCollection provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) HasCollection(_a0 context.Context, _a1 *milvuspb.HasCollectionRequest) (*milvuspb.BoolResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.BoolResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.HasCollectionRequest) (*milvuspb.BoolResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.HasCollectionRequest) *milvuspb.BoolResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.BoolResponse) @@ -1424,7 +1429,7 @@ func (_m *RootCoord) HasCollection(ctx context.Context, req *milvuspb.HasCollect } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.HasCollectionRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1438,13 +1443,13 @@ type RootCoord_HasCollection_Call struct { } // HasCollection is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.HasCollectionRequest -func (_e *RootCoord_Expecter) HasCollection(ctx interface{}, req interface{}) 
*RootCoord_HasCollection_Call { - return &RootCoord_HasCollection_Call{Call: _e.mock.On("HasCollection", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.HasCollectionRequest +func (_e *RootCoord_Expecter) HasCollection(_a0 interface{}, _a1 interface{}) *RootCoord_HasCollection_Call { + return &RootCoord_HasCollection_Call{Call: _e.mock.On("HasCollection", _a0, _a1)} } -func (_c *RootCoord_HasCollection_Call) Run(run func(ctx context.Context, req *milvuspb.HasCollectionRequest)) *RootCoord_HasCollection_Call { +func (_c *RootCoord_HasCollection_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.HasCollectionRequest)) *RootCoord_HasCollection_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.HasCollectionRequest)) }) @@ -1461,17 +1466,17 @@ func (_c *RootCoord_HasCollection_Call) RunAndReturn(run func(context.Context, * return _c } -// HasPartition provides a mock function with given fields: ctx, req -func (_m *RootCoord) HasPartition(ctx context.Context, req *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error) { - ret := _m.Called(ctx, req) +// HasPartition provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) HasPartition(_a0 context.Context, _a1 *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.BoolResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.HasPartitionRequest) *milvuspb.BoolResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.BoolResponse) @@ -1479,7 +1484,7 @@ func (_m *RootCoord) HasPartition(ctx context.Context, req *milvuspb.HasPartitio } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.HasPartitionRequest) error); ok { - r1 = rf(ctx, req) + r1 = 
rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1493,13 +1498,13 @@ type RootCoord_HasPartition_Call struct { } // HasPartition is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.HasPartitionRequest -func (_e *RootCoord_Expecter) HasPartition(ctx interface{}, req interface{}) *RootCoord_HasPartition_Call { - return &RootCoord_HasPartition_Call{Call: _e.mock.On("HasPartition", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.HasPartitionRequest +func (_e *RootCoord_Expecter) HasPartition(_a0 interface{}, _a1 interface{}) *RootCoord_HasPartition_Call { + return &RootCoord_HasPartition_Call{Call: _e.mock.On("HasPartition", _a0, _a1)} } -func (_c *RootCoord_HasPartition_Call) Run(run func(ctx context.Context, req *milvuspb.HasPartitionRequest)) *RootCoord_HasPartition_Call { +func (_c *RootCoord_HasPartition_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.HasPartitionRequest)) *RootCoord_HasPartition_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.HasPartitionRequest)) }) @@ -1516,17 +1521,17 @@ func (_c *RootCoord_HasPartition_Call) RunAndReturn(run func(context.Context, *m return _c } -// Import provides a mock function with given fields: ctx, req -func (_m *RootCoord) Import(ctx context.Context, req *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) { - ret := _m.Called(ctx, req) +// Import provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) Import(_a0 context.Context, _a1 *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.ImportResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ImportRequest) *milvuspb.ImportResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = 
ret.Get(0).(*milvuspb.ImportResponse) @@ -1534,7 +1539,7 @@ func (_m *RootCoord) Import(ctx context.Context, req *milvuspb.ImportRequest) (* } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ImportRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1548,13 +1553,13 @@ type RootCoord_Import_Call struct { } // Import is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.ImportRequest -func (_e *RootCoord_Expecter) Import(ctx interface{}, req interface{}) *RootCoord_Import_Call { - return &RootCoord_Import_Call{Call: _e.mock.On("Import", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.ImportRequest +func (_e *RootCoord_Expecter) Import(_a0 interface{}, _a1 interface{}) *RootCoord_Import_Call { + return &RootCoord_Import_Call{Call: _e.mock.On("Import", _a0, _a1)} } -func (_c *RootCoord_Import_Call) Run(run func(ctx context.Context, req *milvuspb.ImportRequest)) *RootCoord_Import_Call { +func (_c *RootCoord_Import_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.ImportRequest)) *RootCoord_Import_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.ImportRequest)) }) @@ -1612,17 +1617,17 @@ func (_c *RootCoord_Init_Call) RunAndReturn(run func() error) *RootCoord_Init_Ca return _c } -// InvalidateCollectionMetaCache provides a mock function with given fields: ctx, request -func (_m *RootCoord) InvalidateCollectionMetaCache(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, request) +// InvalidateCollectionMetaCache provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) InvalidateCollectionMetaCache(_a0 context.Context, _a1 *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateCollMetaCacheRequest) 
(*commonpb.Status, error)); ok { - return rf(ctx, request) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateCollMetaCacheRequest) *commonpb.Status); ok { - r0 = rf(ctx, request) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -1630,7 +1635,7 @@ func (_m *RootCoord) InvalidateCollectionMetaCache(ctx context.Context, request } if rf, ok := ret.Get(1).(func(context.Context, *proxypb.InvalidateCollMetaCacheRequest) error); ok { - r1 = rf(ctx, request) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1644,13 +1649,13 @@ type RootCoord_InvalidateCollectionMetaCache_Call struct { } // InvalidateCollectionMetaCache is a helper method to define mock.On call -// - ctx context.Context -// - request *proxypb.InvalidateCollMetaCacheRequest -func (_e *RootCoord_Expecter) InvalidateCollectionMetaCache(ctx interface{}, request interface{}) *RootCoord_InvalidateCollectionMetaCache_Call { - return &RootCoord_InvalidateCollectionMetaCache_Call{Call: _e.mock.On("InvalidateCollectionMetaCache", ctx, request)} +// - _a0 context.Context +// - _a1 *proxypb.InvalidateCollMetaCacheRequest +func (_e *RootCoord_Expecter) InvalidateCollectionMetaCache(_a0 interface{}, _a1 interface{}) *RootCoord_InvalidateCollectionMetaCache_Call { + return &RootCoord_InvalidateCollectionMetaCache_Call{Call: _e.mock.On("InvalidateCollectionMetaCache", _a0, _a1)} } -func (_c *RootCoord_InvalidateCollectionMetaCache_Call) Run(run func(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest)) *RootCoord_InvalidateCollectionMetaCache_Call { +func (_c *RootCoord_InvalidateCollectionMetaCache_Call) Run(run func(_a0 context.Context, _a1 *proxypb.InvalidateCollMetaCacheRequest)) *RootCoord_InvalidateCollectionMetaCache_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*proxypb.InvalidateCollMetaCacheRequest)) }) @@ -1667,17 +1672,17 @@ func (_c 
*RootCoord_InvalidateCollectionMetaCache_Call) RunAndReturn(run func(co return _c } -// ListCredUsers provides a mock function with given fields: ctx, req -func (_m *RootCoord) ListCredUsers(ctx context.Context, req *milvuspb.ListCredUsersRequest) (*milvuspb.ListCredUsersResponse, error) { - ret := _m.Called(ctx, req) +// ListCredUsers provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) ListCredUsers(_a0 context.Context, _a1 *milvuspb.ListCredUsersRequest) (*milvuspb.ListCredUsersResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.ListCredUsersResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ListCredUsersRequest) (*milvuspb.ListCredUsersResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ListCredUsersRequest) *milvuspb.ListCredUsersResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.ListCredUsersResponse) @@ -1685,7 +1690,7 @@ func (_m *RootCoord) ListCredUsers(ctx context.Context, req *milvuspb.ListCredUs } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ListCredUsersRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1699,13 +1704,13 @@ type RootCoord_ListCredUsers_Call struct { } // ListCredUsers is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.ListCredUsersRequest -func (_e *RootCoord_Expecter) ListCredUsers(ctx interface{}, req interface{}) *RootCoord_ListCredUsers_Call { - return &RootCoord_ListCredUsers_Call{Call: _e.mock.On("ListCredUsers", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.ListCredUsersRequest +func (_e *RootCoord_Expecter) ListCredUsers(_a0 interface{}, _a1 interface{}) *RootCoord_ListCredUsers_Call { + return &RootCoord_ListCredUsers_Call{Call: _e.mock.On("ListCredUsers", _a0, _a1)} } -func (_c *RootCoord_ListCredUsers_Call) Run(run func(ctx 
context.Context, req *milvuspb.ListCredUsersRequest)) *RootCoord_ListCredUsers_Call { +func (_c *RootCoord_ListCredUsers_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.ListCredUsersRequest)) *RootCoord_ListCredUsers_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.ListCredUsersRequest)) }) @@ -1722,17 +1727,17 @@ func (_c *RootCoord_ListCredUsers_Call) RunAndReturn(run func(context.Context, * return _c } -// ListDatabases provides a mock function with given fields: ctx, req -func (_m *RootCoord) ListDatabases(ctx context.Context, req *milvuspb.ListDatabasesRequest) (*milvuspb.ListDatabasesResponse, error) { - ret := _m.Called(ctx, req) +// ListDatabases provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) ListDatabases(_a0 context.Context, _a1 *milvuspb.ListDatabasesRequest) (*milvuspb.ListDatabasesResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.ListDatabasesResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ListDatabasesRequest) (*milvuspb.ListDatabasesResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ListDatabasesRequest) *milvuspb.ListDatabasesResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.ListDatabasesResponse) @@ -1740,7 +1745,7 @@ func (_m *RootCoord) ListDatabases(ctx context.Context, req *milvuspb.ListDataba } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ListDatabasesRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1754,13 +1759,13 @@ type RootCoord_ListDatabases_Call struct { } // ListDatabases is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.ListDatabasesRequest -func (_e *RootCoord_Expecter) ListDatabases(ctx interface{}, req interface{}) *RootCoord_ListDatabases_Call { - return 
&RootCoord_ListDatabases_Call{Call: _e.mock.On("ListDatabases", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.ListDatabasesRequest +func (_e *RootCoord_Expecter) ListDatabases(_a0 interface{}, _a1 interface{}) *RootCoord_ListDatabases_Call { + return &RootCoord_ListDatabases_Call{Call: _e.mock.On("ListDatabases", _a0, _a1)} } -func (_c *RootCoord_ListDatabases_Call) Run(run func(ctx context.Context, req *milvuspb.ListDatabasesRequest)) *RootCoord_ListDatabases_Call { +func (_c *RootCoord_ListDatabases_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.ListDatabasesRequest)) *RootCoord_ListDatabases_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.ListDatabasesRequest)) }) @@ -1777,17 +1782,17 @@ func (_c *RootCoord_ListDatabases_Call) RunAndReturn(run func(context.Context, * return _c } -// ListImportTasks provides a mock function with given fields: ctx, req -func (_m *RootCoord) ListImportTasks(ctx context.Context, req *milvuspb.ListImportTasksRequest) (*milvuspb.ListImportTasksResponse, error) { - ret := _m.Called(ctx, req) +// ListImportTasks provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) ListImportTasks(_a0 context.Context, _a1 *milvuspb.ListImportTasksRequest) (*milvuspb.ListImportTasksResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.ListImportTasksResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ListImportTasksRequest) (*milvuspb.ListImportTasksResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ListImportTasksRequest) *milvuspb.ListImportTasksResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.ListImportTasksResponse) @@ -1795,7 +1800,7 @@ func (_m *RootCoord) ListImportTasks(ctx context.Context, req *milvuspb.ListImpo } if rf, ok := ret.Get(1).(func(context.Context, 
*milvuspb.ListImportTasksRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1809,13 +1814,13 @@ type RootCoord_ListImportTasks_Call struct { } // ListImportTasks is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.ListImportTasksRequest -func (_e *RootCoord_Expecter) ListImportTasks(ctx interface{}, req interface{}) *RootCoord_ListImportTasks_Call { - return &RootCoord_ListImportTasks_Call{Call: _e.mock.On("ListImportTasks", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.ListImportTasksRequest +func (_e *RootCoord_Expecter) ListImportTasks(_a0 interface{}, _a1 interface{}) *RootCoord_ListImportTasks_Call { + return &RootCoord_ListImportTasks_Call{Call: _e.mock.On("ListImportTasks", _a0, _a1)} } -func (_c *RootCoord_ListImportTasks_Call) Run(run func(ctx context.Context, req *milvuspb.ListImportTasksRequest)) *RootCoord_ListImportTasks_Call { +func (_c *RootCoord_ListImportTasks_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.ListImportTasksRequest)) *RootCoord_ListImportTasks_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.ListImportTasksRequest)) }) @@ -1832,17 +1837,17 @@ func (_c *RootCoord_ListImportTasks_Call) RunAndReturn(run func(context.Context, return _c } -// ListPolicy provides a mock function with given fields: ctx, in -func (_m *RootCoord) ListPolicy(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) { - ret := _m.Called(ctx, in) +// ListPolicy provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) ListPolicy(_a0 context.Context, _a1 *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *internalpb.ListPolicyResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error)); ok { - return rf(ctx, in) + return rf(_a0, 
_a1) } if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ListPolicyRequest) *internalpb.ListPolicyResponse); ok { - r0 = rf(ctx, in) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*internalpb.ListPolicyResponse) @@ -1850,7 +1855,7 @@ func (_m *RootCoord) ListPolicy(ctx context.Context, in *internalpb.ListPolicyRe } if rf, ok := ret.Get(1).(func(context.Context, *internalpb.ListPolicyRequest) error); ok { - r1 = rf(ctx, in) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1864,13 +1869,13 @@ type RootCoord_ListPolicy_Call struct { } // ListPolicy is a helper method to define mock.On call -// - ctx context.Context -// - in *internalpb.ListPolicyRequest -func (_e *RootCoord_Expecter) ListPolicy(ctx interface{}, in interface{}) *RootCoord_ListPolicy_Call { - return &RootCoord_ListPolicy_Call{Call: _e.mock.On("ListPolicy", ctx, in)} +// - _a0 context.Context +// - _a1 *internalpb.ListPolicyRequest +func (_e *RootCoord_Expecter) ListPolicy(_a0 interface{}, _a1 interface{}) *RootCoord_ListPolicy_Call { + return &RootCoord_ListPolicy_Call{Call: _e.mock.On("ListPolicy", _a0, _a1)} } -func (_c *RootCoord_ListPolicy_Call) Run(run func(ctx context.Context, in *internalpb.ListPolicyRequest)) *RootCoord_ListPolicy_Call { +func (_c *RootCoord_ListPolicy_Call) Run(run func(_a0 context.Context, _a1 *internalpb.ListPolicyRequest)) *RootCoord_ListPolicy_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*internalpb.ListPolicyRequest)) }) @@ -1887,17 +1892,17 @@ func (_c *RootCoord_ListPolicy_Call) RunAndReturn(run func(context.Context, *int return _c } -// OperatePrivilege provides a mock function with given fields: ctx, req -func (_m *RootCoord) OperatePrivilege(ctx context.Context, req *milvuspb.OperatePrivilegeRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// OperatePrivilege provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) OperatePrivilege(_a0 context.Context, _a1 
*milvuspb.OperatePrivilegeRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.OperatePrivilegeRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.OperatePrivilegeRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -1905,7 +1910,7 @@ func (_m *RootCoord) OperatePrivilege(ctx context.Context, req *milvuspb.Operate } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.OperatePrivilegeRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1919,13 +1924,13 @@ type RootCoord_OperatePrivilege_Call struct { } // OperatePrivilege is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.OperatePrivilegeRequest -func (_e *RootCoord_Expecter) OperatePrivilege(ctx interface{}, req interface{}) *RootCoord_OperatePrivilege_Call { - return &RootCoord_OperatePrivilege_Call{Call: _e.mock.On("OperatePrivilege", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.OperatePrivilegeRequest +func (_e *RootCoord_Expecter) OperatePrivilege(_a0 interface{}, _a1 interface{}) *RootCoord_OperatePrivilege_Call { + return &RootCoord_OperatePrivilege_Call{Call: _e.mock.On("OperatePrivilege", _a0, _a1)} } -func (_c *RootCoord_OperatePrivilege_Call) Run(run func(ctx context.Context, req *milvuspb.OperatePrivilegeRequest)) *RootCoord_OperatePrivilege_Call { +func (_c *RootCoord_OperatePrivilege_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.OperatePrivilegeRequest)) *RootCoord_OperatePrivilege_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.OperatePrivilegeRequest)) }) @@ -1942,17 +1947,17 @@ func (_c *RootCoord_OperatePrivilege_Call) RunAndReturn(run func(context.Context return 
_c } -// OperateUserRole provides a mock function with given fields: ctx, req -func (_m *RootCoord) OperateUserRole(ctx context.Context, req *milvuspb.OperateUserRoleRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// OperateUserRole provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) OperateUserRole(_a0 context.Context, _a1 *milvuspb.OperateUserRoleRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.OperateUserRoleRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.OperateUserRoleRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -1960,7 +1965,7 @@ func (_m *RootCoord) OperateUserRole(ctx context.Context, req *milvuspb.OperateU } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.OperateUserRoleRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -1974,13 +1979,13 @@ type RootCoord_OperateUserRole_Call struct { } // OperateUserRole is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.OperateUserRoleRequest -func (_e *RootCoord_Expecter) OperateUserRole(ctx interface{}, req interface{}) *RootCoord_OperateUserRole_Call { - return &RootCoord_OperateUserRole_Call{Call: _e.mock.On("OperateUserRole", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.OperateUserRoleRequest +func (_e *RootCoord_Expecter) OperateUserRole(_a0 interface{}, _a1 interface{}) *RootCoord_OperateUserRole_Call { + return &RootCoord_OperateUserRole_Call{Call: _e.mock.On("OperateUserRole", _a0, _a1)} } -func (_c *RootCoord_OperateUserRole_Call) Run(run func(ctx context.Context, req *milvuspb.OperateUserRoleRequest)) *RootCoord_OperateUserRole_Call { +func (_c 
*RootCoord_OperateUserRole_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.OperateUserRoleRequest)) *RootCoord_OperateUserRole_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.OperateUserRoleRequest)) }) @@ -2038,17 +2043,17 @@ func (_c *RootCoord_Register_Call) RunAndReturn(run func() error) *RootCoord_Reg return _c } -// RenameCollection provides a mock function with given fields: ctx, req -func (_m *RootCoord) RenameCollection(ctx context.Context, req *milvuspb.RenameCollectionRequest) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// RenameCollection provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) RenameCollection(_a0 context.Context, _a1 *milvuspb.RenameCollectionRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.RenameCollectionRequest) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.RenameCollectionRequest) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -2056,7 +2061,7 @@ func (_m *RootCoord) RenameCollection(ctx context.Context, req *milvuspb.RenameC } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.RenameCollectionRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2070,13 +2075,13 @@ type RootCoord_RenameCollection_Call struct { } // RenameCollection is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.RenameCollectionRequest -func (_e *RootCoord_Expecter) RenameCollection(ctx interface{}, req interface{}) *RootCoord_RenameCollection_Call { - return &RootCoord_RenameCollection_Call{Call: _e.mock.On("RenameCollection", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.RenameCollectionRequest 
+func (_e *RootCoord_Expecter) RenameCollection(_a0 interface{}, _a1 interface{}) *RootCoord_RenameCollection_Call { + return &RootCoord_RenameCollection_Call{Call: _e.mock.On("RenameCollection", _a0, _a1)} } -func (_c *RootCoord_RenameCollection_Call) Run(run func(ctx context.Context, req *milvuspb.RenameCollectionRequest)) *RootCoord_RenameCollection_Call { +func (_c *RootCoord_RenameCollection_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.RenameCollectionRequest)) *RootCoord_RenameCollection_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.RenameCollectionRequest)) }) @@ -2093,17 +2098,17 @@ func (_c *RootCoord_RenameCollection_Call) RunAndReturn(run func(context.Context return _c } -// ReportImport provides a mock function with given fields: ctx, req -func (_m *RootCoord) ReportImport(ctx context.Context, req *rootcoordpb.ImportResult) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// ReportImport provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) ReportImport(_a0 context.Context, _a1 *rootcoordpb.ImportResult) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.ImportResult) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.ImportResult) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -2111,7 +2116,7 @@ func (_m *RootCoord) ReportImport(ctx context.Context, req *rootcoordpb.ImportRe } if rf, ok := ret.Get(1).(func(context.Context, *rootcoordpb.ImportResult) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2125,13 +2130,13 @@ type RootCoord_ReportImport_Call struct { } // ReportImport is a helper method to define mock.On call -// - ctx context.Context -// - req 
*rootcoordpb.ImportResult -func (_e *RootCoord_Expecter) ReportImport(ctx interface{}, req interface{}) *RootCoord_ReportImport_Call { - return &RootCoord_ReportImport_Call{Call: _e.mock.On("ReportImport", ctx, req)} +// - _a0 context.Context +// - _a1 *rootcoordpb.ImportResult +func (_e *RootCoord_Expecter) ReportImport(_a0 interface{}, _a1 interface{}) *RootCoord_ReportImport_Call { + return &RootCoord_ReportImport_Call{Call: _e.mock.On("ReportImport", _a0, _a1)} } -func (_c *RootCoord_ReportImport_Call) Run(run func(ctx context.Context, req *rootcoordpb.ImportResult)) *RootCoord_ReportImport_Call { +func (_c *RootCoord_ReportImport_Call) Run(run func(_a0 context.Context, _a1 *rootcoordpb.ImportResult)) *RootCoord_ReportImport_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*rootcoordpb.ImportResult)) }) @@ -2148,17 +2153,17 @@ func (_c *RootCoord_ReportImport_Call) RunAndReturn(run func(context.Context, *r return _c } -// SelectGrant provides a mock function with given fields: ctx, req -func (_m *RootCoord) SelectGrant(ctx context.Context, req *milvuspb.SelectGrantRequest) (*milvuspb.SelectGrantResponse, error) { - ret := _m.Called(ctx, req) +// SelectGrant provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) SelectGrant(_a0 context.Context, _a1 *milvuspb.SelectGrantRequest) (*milvuspb.SelectGrantResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.SelectGrantResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.SelectGrantRequest) (*milvuspb.SelectGrantResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.SelectGrantRequest) *milvuspb.SelectGrantResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.SelectGrantResponse) @@ -2166,7 +2171,7 @@ func (_m *RootCoord) SelectGrant(ctx context.Context, req *milvuspb.SelectGrantR } if rf, ok 
:= ret.Get(1).(func(context.Context, *milvuspb.SelectGrantRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2180,13 +2185,13 @@ type RootCoord_SelectGrant_Call struct { } // SelectGrant is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.SelectGrantRequest -func (_e *RootCoord_Expecter) SelectGrant(ctx interface{}, req interface{}) *RootCoord_SelectGrant_Call { - return &RootCoord_SelectGrant_Call{Call: _e.mock.On("SelectGrant", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.SelectGrantRequest +func (_e *RootCoord_Expecter) SelectGrant(_a0 interface{}, _a1 interface{}) *RootCoord_SelectGrant_Call { + return &RootCoord_SelectGrant_Call{Call: _e.mock.On("SelectGrant", _a0, _a1)} } -func (_c *RootCoord_SelectGrant_Call) Run(run func(ctx context.Context, req *milvuspb.SelectGrantRequest)) *RootCoord_SelectGrant_Call { +func (_c *RootCoord_SelectGrant_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.SelectGrantRequest)) *RootCoord_SelectGrant_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.SelectGrantRequest)) }) @@ -2203,17 +2208,17 @@ func (_c *RootCoord_SelectGrant_Call) RunAndReturn(run func(context.Context, *mi return _c } -// SelectRole provides a mock function with given fields: ctx, req -func (_m *RootCoord) SelectRole(ctx context.Context, req *milvuspb.SelectRoleRequest) (*milvuspb.SelectRoleResponse, error) { - ret := _m.Called(ctx, req) +// SelectRole provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) SelectRole(_a0 context.Context, _a1 *milvuspb.SelectRoleRequest) (*milvuspb.SelectRoleResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.SelectRoleResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.SelectRoleRequest) (*milvuspb.SelectRoleResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, 
*milvuspb.SelectRoleRequest) *milvuspb.SelectRoleResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.SelectRoleResponse) @@ -2221,7 +2226,7 @@ func (_m *RootCoord) SelectRole(ctx context.Context, req *milvuspb.SelectRoleReq } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.SelectRoleRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2235,13 +2240,13 @@ type RootCoord_SelectRole_Call struct { } // SelectRole is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.SelectRoleRequest -func (_e *RootCoord_Expecter) SelectRole(ctx interface{}, req interface{}) *RootCoord_SelectRole_Call { - return &RootCoord_SelectRole_Call{Call: _e.mock.On("SelectRole", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.SelectRoleRequest +func (_e *RootCoord_Expecter) SelectRole(_a0 interface{}, _a1 interface{}) *RootCoord_SelectRole_Call { + return &RootCoord_SelectRole_Call{Call: _e.mock.On("SelectRole", _a0, _a1)} } -func (_c *RootCoord_SelectRole_Call) Run(run func(ctx context.Context, req *milvuspb.SelectRoleRequest)) *RootCoord_SelectRole_Call { +func (_c *RootCoord_SelectRole_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.SelectRoleRequest)) *RootCoord_SelectRole_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.SelectRoleRequest)) }) @@ -2258,17 +2263,17 @@ func (_c *RootCoord_SelectRole_Call) RunAndReturn(run func(context.Context, *mil return _c } -// SelectUser provides a mock function with given fields: ctx, req -func (_m *RootCoord) SelectUser(ctx context.Context, req *milvuspb.SelectUserRequest) (*milvuspb.SelectUserResponse, error) { - ret := _m.Called(ctx, req) +// SelectUser provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) SelectUser(_a0 context.Context, _a1 *milvuspb.SelectUserRequest) (*milvuspb.SelectUserResponse, error) { + ret := 
_m.Called(_a0, _a1) var r0 *milvuspb.SelectUserResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.SelectUserRequest) (*milvuspb.SelectUserResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.SelectUserRequest) *milvuspb.SelectUserResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.SelectUserResponse) @@ -2276,7 +2281,7 @@ func (_m *RootCoord) SelectUser(ctx context.Context, req *milvuspb.SelectUserReq } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.SelectUserRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2290,13 +2295,13 @@ type RootCoord_SelectUser_Call struct { } // SelectUser is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.SelectUserRequest -func (_e *RootCoord_Expecter) SelectUser(ctx interface{}, req interface{}) *RootCoord_SelectUser_Call { - return &RootCoord_SelectUser_Call{Call: _e.mock.On("SelectUser", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.SelectUserRequest +func (_e *RootCoord_Expecter) SelectUser(_a0 interface{}, _a1 interface{}) *RootCoord_SelectUser_Call { + return &RootCoord_SelectUser_Call{Call: _e.mock.On("SelectUser", _a0, _a1)} } -func (_c *RootCoord_SelectUser_Call) Run(run func(ctx context.Context, req *milvuspb.SelectUserRequest)) *RootCoord_SelectUser_Call { +func (_c *RootCoord_SelectUser_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.SelectUserRequest)) *RootCoord_SelectUser_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.SelectUserRequest)) }) @@ -2346,12 +2351,12 @@ func (_c *RootCoord_SetAddress_Call) RunAndReturn(run func(string)) *RootCoord_S return _c } -// SetDataCoord provides a mock function with given fields: dataCoord -func (_m *RootCoord) SetDataCoord(dataCoord types.DataCoord) error { +// 
SetDataCoordClient provides a mock function with given fields: dataCoord +func (_m *RootCoord) SetDataCoordClient(dataCoord types.DataCoordClient) error { ret := _m.Called(dataCoord) var r0 error - if rf, ok := ret.Get(0).(func(types.DataCoord) error); ok { + if rf, ok := ret.Get(0).(func(types.DataCoordClient) error); ok { r0 = rf(dataCoord) } else { r0 = ret.Error(0) @@ -2360,30 +2365,30 @@ func (_m *RootCoord) SetDataCoord(dataCoord types.DataCoord) error { return r0 } -// RootCoord_SetDataCoord_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetDataCoord' -type RootCoord_SetDataCoord_Call struct { +// RootCoord_SetDataCoordClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetDataCoordClient' +type RootCoord_SetDataCoordClient_Call struct { *mock.Call } -// SetDataCoord is a helper method to define mock.On call -// - dataCoord types.DataCoord -func (_e *RootCoord_Expecter) SetDataCoord(dataCoord interface{}) *RootCoord_SetDataCoord_Call { - return &RootCoord_SetDataCoord_Call{Call: _e.mock.On("SetDataCoord", dataCoord)} +// SetDataCoordClient is a helper method to define mock.On call +// - dataCoord types.DataCoordClient +func (_e *RootCoord_Expecter) SetDataCoordClient(dataCoord interface{}) *RootCoord_SetDataCoordClient_Call { + return &RootCoord_SetDataCoordClient_Call{Call: _e.mock.On("SetDataCoordClient", dataCoord)} } -func (_c *RootCoord_SetDataCoord_Call) Run(run func(dataCoord types.DataCoord)) *RootCoord_SetDataCoord_Call { +func (_c *RootCoord_SetDataCoordClient_Call) Run(run func(dataCoord types.DataCoordClient)) *RootCoord_SetDataCoordClient_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(types.DataCoord)) + run(args[0].(types.DataCoordClient)) }) return _c } -func (_c *RootCoord_SetDataCoord_Call) Return(_a0 error) *RootCoord_SetDataCoord_Call { +func (_c *RootCoord_SetDataCoordClient_Call) Return(_a0 error) *RootCoord_SetDataCoordClient_Call 
{ _c.Call.Return(_a0) return _c } -func (_c *RootCoord_SetDataCoord_Call) RunAndReturn(run func(types.DataCoord) error) *RootCoord_SetDataCoord_Call { +func (_c *RootCoord_SetDataCoordClient_Call) RunAndReturn(run func(types.DataCoordClient) error) *RootCoord_SetDataCoordClient_Call { _c.Call.Return(run) return _c } @@ -2422,7 +2427,7 @@ func (_c *RootCoord_SetEtcdClient_Call) RunAndReturn(run func(*clientv3.Client)) } // SetProxyCreator provides a mock function with given fields: _a0 -func (_m *RootCoord) SetProxyCreator(_a0 func(context.Context, string, int64) (types.Proxy, error)) { +func (_m *RootCoord) SetProxyCreator(_a0 func(context.Context, string, int64) (types.ProxyClient, error)) { _m.Called(_a0) } @@ -2432,14 +2437,14 @@ type RootCoord_SetProxyCreator_Call struct { } // SetProxyCreator is a helper method to define mock.On call -// - _a0 func(context.Context , string , int64)(types.Proxy , error) +// - _a0 func(context.Context , string , int64)(types.ProxyClient , error) func (_e *RootCoord_Expecter) SetProxyCreator(_a0 interface{}) *RootCoord_SetProxyCreator_Call { return &RootCoord_SetProxyCreator_Call{Call: _e.mock.On("SetProxyCreator", _a0)} } -func (_c *RootCoord_SetProxyCreator_Call) Run(run func(_a0 func(context.Context, string, int64) (types.Proxy, error))) *RootCoord_SetProxyCreator_Call { +func (_c *RootCoord_SetProxyCreator_Call) Run(run func(_a0 func(context.Context, string, int64) (types.ProxyClient, error))) *RootCoord_SetProxyCreator_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(func(context.Context, string, int64) (types.Proxy, error))) + run(args[0].(func(context.Context, string, int64) (types.ProxyClient, error))) }) return _c } @@ -2449,17 +2454,17 @@ func (_c *RootCoord_SetProxyCreator_Call) Return() *RootCoord_SetProxyCreator_Ca return _c } -func (_c *RootCoord_SetProxyCreator_Call) RunAndReturn(run func(func(context.Context, string, int64) (types.Proxy, error))) *RootCoord_SetProxyCreator_Call { +func (_c 
*RootCoord_SetProxyCreator_Call) RunAndReturn(run func(func(context.Context, string, int64) (types.ProxyClient, error))) *RootCoord_SetProxyCreator_Call { _c.Call.Return(run) return _c } -// SetQueryCoord provides a mock function with given fields: queryCoord -func (_m *RootCoord) SetQueryCoord(queryCoord types.QueryCoord) error { +// SetQueryCoordClient provides a mock function with given fields: queryCoord +func (_m *RootCoord) SetQueryCoordClient(queryCoord types.QueryCoordClient) error { ret := _m.Called(queryCoord) var r0 error - if rf, ok := ret.Get(0).(func(types.QueryCoord) error); ok { + if rf, ok := ret.Get(0).(func(types.QueryCoordClient) error); ok { r0 = rf(queryCoord) } else { r0 = ret.Error(0) @@ -2468,45 +2473,78 @@ func (_m *RootCoord) SetQueryCoord(queryCoord types.QueryCoord) error { return r0 } -// RootCoord_SetQueryCoord_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetQueryCoord' -type RootCoord_SetQueryCoord_Call struct { +// RootCoord_SetQueryCoordClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetQueryCoordClient' +type RootCoord_SetQueryCoordClient_Call struct { *mock.Call } -// SetQueryCoord is a helper method to define mock.On call -// - queryCoord types.QueryCoord -func (_e *RootCoord_Expecter) SetQueryCoord(queryCoord interface{}) *RootCoord_SetQueryCoord_Call { - return &RootCoord_SetQueryCoord_Call{Call: _e.mock.On("SetQueryCoord", queryCoord)} +// SetQueryCoordClient is a helper method to define mock.On call +// - queryCoord types.QueryCoordClient +func (_e *RootCoord_Expecter) SetQueryCoordClient(queryCoord interface{}) *RootCoord_SetQueryCoordClient_Call { + return &RootCoord_SetQueryCoordClient_Call{Call: _e.mock.On("SetQueryCoordClient", queryCoord)} } -func (_c *RootCoord_SetQueryCoord_Call) Run(run func(queryCoord types.QueryCoord)) *RootCoord_SetQueryCoord_Call { +func (_c *RootCoord_SetQueryCoordClient_Call) Run(run 
func(queryCoord types.QueryCoordClient)) *RootCoord_SetQueryCoordClient_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(types.QueryCoord)) + run(args[0].(types.QueryCoordClient)) }) return _c } -func (_c *RootCoord_SetQueryCoord_Call) Return(_a0 error) *RootCoord_SetQueryCoord_Call { +func (_c *RootCoord_SetQueryCoordClient_Call) Return(_a0 error) *RootCoord_SetQueryCoordClient_Call { _c.Call.Return(_a0) return _c } -func (_c *RootCoord_SetQueryCoord_Call) RunAndReturn(run func(types.QueryCoord) error) *RootCoord_SetQueryCoord_Call { +func (_c *RootCoord_SetQueryCoordClient_Call) RunAndReturn(run func(types.QueryCoordClient) error) *RootCoord_SetQueryCoordClient_Call { _c.Call.Return(run) return _c } -// ShowCollections provides a mock function with given fields: ctx, req -func (_m *RootCoord) ShowCollections(ctx context.Context, req *milvuspb.ShowCollectionsRequest) (*milvuspb.ShowCollectionsResponse, error) { - ret := _m.Called(ctx, req) +// SetTiKVClient provides a mock function with given fields: client +func (_m *RootCoord) SetTiKVClient(client *txnkv.Client) { + _m.Called(client) +} + +// RootCoord_SetTiKVClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetTiKVClient' +type RootCoord_SetTiKVClient_Call struct { + *mock.Call +} + +// SetTiKVClient is a helper method to define mock.On call +// - client *txnkv.Client +func (_e *RootCoord_Expecter) SetTiKVClient(client interface{}) *RootCoord_SetTiKVClient_Call { + return &RootCoord_SetTiKVClient_Call{Call: _e.mock.On("SetTiKVClient", client)} +} + +func (_c *RootCoord_SetTiKVClient_Call) Run(run func(client *txnkv.Client)) *RootCoord_SetTiKVClient_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*txnkv.Client)) + }) + return _c +} + +func (_c *RootCoord_SetTiKVClient_Call) Return() *RootCoord_SetTiKVClient_Call { + _c.Call.Return() + return _c +} + +func (_c *RootCoord_SetTiKVClient_Call) RunAndReturn(run func(*txnkv.Client)) 
*RootCoord_SetTiKVClient_Call { + _c.Call.Return(run) + return _c +} + +// ShowCollections provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) ShowCollections(_a0 context.Context, _a1 *milvuspb.ShowCollectionsRequest) (*milvuspb.ShowCollectionsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.ShowCollectionsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ShowCollectionsRequest) (*milvuspb.ShowCollectionsResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ShowCollectionsRequest) *milvuspb.ShowCollectionsResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.ShowCollectionsResponse) @@ -2514,7 +2552,7 @@ func (_m *RootCoord) ShowCollections(ctx context.Context, req *milvuspb.ShowColl } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ShowCollectionsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2528,13 +2566,13 @@ type RootCoord_ShowCollections_Call struct { } // ShowCollections is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.ShowCollectionsRequest -func (_e *RootCoord_Expecter) ShowCollections(ctx interface{}, req interface{}) *RootCoord_ShowCollections_Call { - return &RootCoord_ShowCollections_Call{Call: _e.mock.On("ShowCollections", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.ShowCollectionsRequest +func (_e *RootCoord_Expecter) ShowCollections(_a0 interface{}, _a1 interface{}) *RootCoord_ShowCollections_Call { + return &RootCoord_ShowCollections_Call{Call: _e.mock.On("ShowCollections", _a0, _a1)} } -func (_c *RootCoord_ShowCollections_Call) Run(run func(ctx context.Context, req *milvuspb.ShowCollectionsRequest)) *RootCoord_ShowCollections_Call { +func (_c *RootCoord_ShowCollections_Call) Run(run func(_a0 context.Context, _a1 
*milvuspb.ShowCollectionsRequest)) *RootCoord_ShowCollections_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.ShowCollectionsRequest)) }) @@ -2551,17 +2589,17 @@ func (_c *RootCoord_ShowCollections_Call) RunAndReturn(run func(context.Context, return _c } -// ShowConfigurations provides a mock function with given fields: ctx, req -func (_m *RootCoord) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { - ret := _m.Called(ctx, req) +// ShowConfigurations provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) ShowConfigurations(_a0 context.Context, _a1 *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *internalpb.ShowConfigurationsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ShowConfigurationsRequest) *internalpb.ShowConfigurationsResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*internalpb.ShowConfigurationsResponse) @@ -2569,7 +2607,7 @@ func (_m *RootCoord) ShowConfigurations(ctx context.Context, req *internalpb.Sho } if rf, ok := ret.Get(1).(func(context.Context, *internalpb.ShowConfigurationsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2583,13 +2621,13 @@ type RootCoord_ShowConfigurations_Call struct { } // ShowConfigurations is a helper method to define mock.On call -// - ctx context.Context -// - req *internalpb.ShowConfigurationsRequest -func (_e *RootCoord_Expecter) ShowConfigurations(ctx interface{}, req interface{}) *RootCoord_ShowConfigurations_Call { - return &RootCoord_ShowConfigurations_Call{Call: 
_e.mock.On("ShowConfigurations", ctx, req)} +// - _a0 context.Context +// - _a1 *internalpb.ShowConfigurationsRequest +func (_e *RootCoord_Expecter) ShowConfigurations(_a0 interface{}, _a1 interface{}) *RootCoord_ShowConfigurations_Call { + return &RootCoord_ShowConfigurations_Call{Call: _e.mock.On("ShowConfigurations", _a0, _a1)} } -func (_c *RootCoord_ShowConfigurations_Call) Run(run func(ctx context.Context, req *internalpb.ShowConfigurationsRequest)) *RootCoord_ShowConfigurations_Call { +func (_c *RootCoord_ShowConfigurations_Call) Run(run func(_a0 context.Context, _a1 *internalpb.ShowConfigurationsRequest)) *RootCoord_ShowConfigurations_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*internalpb.ShowConfigurationsRequest)) }) @@ -2606,17 +2644,17 @@ func (_c *RootCoord_ShowConfigurations_Call) RunAndReturn(run func(context.Conte return _c } -// ShowPartitions provides a mock function with given fields: ctx, req -func (_m *RootCoord) ShowPartitions(ctx context.Context, req *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) { - ret := _m.Called(ctx, req) +// ShowPartitions provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) ShowPartitions(_a0 context.Context, _a1 *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.ShowPartitionsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ShowPartitionsRequest) *milvuspb.ShowPartitionsResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.ShowPartitionsResponse) @@ -2624,7 +2662,7 @@ func (_m *RootCoord) ShowPartitions(ctx context.Context, req *milvuspb.ShowParti } if rf, ok := ret.Get(1).(func(context.Context, 
*milvuspb.ShowPartitionsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2638,13 +2676,13 @@ type RootCoord_ShowPartitions_Call struct { } // ShowPartitions is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.ShowPartitionsRequest -func (_e *RootCoord_Expecter) ShowPartitions(ctx interface{}, req interface{}) *RootCoord_ShowPartitions_Call { - return &RootCoord_ShowPartitions_Call{Call: _e.mock.On("ShowPartitions", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.ShowPartitionsRequest +func (_e *RootCoord_Expecter) ShowPartitions(_a0 interface{}, _a1 interface{}) *RootCoord_ShowPartitions_Call { + return &RootCoord_ShowPartitions_Call{Call: _e.mock.On("ShowPartitions", _a0, _a1)} } -func (_c *RootCoord_ShowPartitions_Call) Run(run func(ctx context.Context, req *milvuspb.ShowPartitionsRequest)) *RootCoord_ShowPartitions_Call { +func (_c *RootCoord_ShowPartitions_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.ShowPartitionsRequest)) *RootCoord_ShowPartitions_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.ShowPartitionsRequest)) }) @@ -2661,17 +2699,17 @@ func (_c *RootCoord_ShowPartitions_Call) RunAndReturn(run func(context.Context, return _c } -// ShowPartitionsInternal provides a mock function with given fields: ctx, req -func (_m *RootCoord) ShowPartitionsInternal(ctx context.Context, req *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) { - ret := _m.Called(ctx, req) +// ShowPartitionsInternal provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) ShowPartitionsInternal(_a0 context.Context, _a1 *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.ShowPartitionsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error)); ok 
{ - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ShowPartitionsRequest) *milvuspb.ShowPartitionsResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.ShowPartitionsResponse) @@ -2679,7 +2717,7 @@ func (_m *RootCoord) ShowPartitionsInternal(ctx context.Context, req *milvuspb.S } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ShowPartitionsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2693,13 +2731,13 @@ type RootCoord_ShowPartitionsInternal_Call struct { } // ShowPartitionsInternal is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.ShowPartitionsRequest -func (_e *RootCoord_Expecter) ShowPartitionsInternal(ctx interface{}, req interface{}) *RootCoord_ShowPartitionsInternal_Call { - return &RootCoord_ShowPartitionsInternal_Call{Call: _e.mock.On("ShowPartitionsInternal", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.ShowPartitionsRequest +func (_e *RootCoord_Expecter) ShowPartitionsInternal(_a0 interface{}, _a1 interface{}) *RootCoord_ShowPartitionsInternal_Call { + return &RootCoord_ShowPartitionsInternal_Call{Call: _e.mock.On("ShowPartitionsInternal", _a0, _a1)} } -func (_c *RootCoord_ShowPartitionsInternal_Call) Run(run func(ctx context.Context, req *milvuspb.ShowPartitionsRequest)) *RootCoord_ShowPartitionsInternal_Call { +func (_c *RootCoord_ShowPartitionsInternal_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.ShowPartitionsRequest)) *RootCoord_ShowPartitionsInternal_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.ShowPartitionsRequest)) }) @@ -2716,17 +2754,17 @@ func (_c *RootCoord_ShowPartitionsInternal_Call) RunAndReturn(run func(context.C return _c } -// ShowSegments provides a mock function with given fields: ctx, req -func (_m *RootCoord) ShowSegments(ctx context.Context, req 
*milvuspb.ShowSegmentsRequest) (*milvuspb.ShowSegmentsResponse, error) { - ret := _m.Called(ctx, req) +// ShowSegments provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) ShowSegments(_a0 context.Context, _a1 *milvuspb.ShowSegmentsRequest) (*milvuspb.ShowSegmentsResponse, error) { + ret := _m.Called(_a0, _a1) var r0 *milvuspb.ShowSegmentsResponse var r1 error if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ShowSegmentsRequest) (*milvuspb.ShowSegmentsResponse, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ShowSegmentsRequest) *milvuspb.ShowSegmentsResponse); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*milvuspb.ShowSegmentsResponse) @@ -2734,7 +2772,7 @@ func (_m *RootCoord) ShowSegments(ctx context.Context, req *milvuspb.ShowSegment } if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ShowSegmentsRequest) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2748,13 +2786,13 @@ type RootCoord_ShowSegments_Call struct { } // ShowSegments is a helper method to define mock.On call -// - ctx context.Context -// - req *milvuspb.ShowSegmentsRequest -func (_e *RootCoord_Expecter) ShowSegments(ctx interface{}, req interface{}) *RootCoord_ShowSegments_Call { - return &RootCoord_ShowSegments_Call{Call: _e.mock.On("ShowSegments", ctx, req)} +// - _a0 context.Context +// - _a1 *milvuspb.ShowSegmentsRequest +func (_e *RootCoord_Expecter) ShowSegments(_a0 interface{}, _a1 interface{}) *RootCoord_ShowSegments_Call { + return &RootCoord_ShowSegments_Call{Call: _e.mock.On("ShowSegments", _a0, _a1)} } -func (_c *RootCoord_ShowSegments_Call) Run(run func(ctx context.Context, req *milvuspb.ShowSegmentsRequest)) *RootCoord_ShowSegments_Call { +func (_c *RootCoord_ShowSegments_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.ShowSegmentsRequest)) *RootCoord_ShowSegments_Call { 
_c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*milvuspb.ShowSegmentsRequest)) }) @@ -2853,17 +2891,17 @@ func (_c *RootCoord_Stop_Call) RunAndReturn(run func() error) *RootCoord_Stop_Ca return _c } -// UpdateChannelTimeTick provides a mock function with given fields: ctx, req -func (_m *RootCoord) UpdateChannelTimeTick(ctx context.Context, req *internalpb.ChannelTimeTickMsg) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// UpdateChannelTimeTick provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) UpdateChannelTimeTick(_a0 context.Context, _a1 *internalpb.ChannelTimeTickMsg) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ChannelTimeTickMsg) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ChannelTimeTickMsg) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -2871,7 +2909,7 @@ func (_m *RootCoord) UpdateChannelTimeTick(ctx context.Context, req *internalpb. 
} if rf, ok := ret.Get(1).(func(context.Context, *internalpb.ChannelTimeTickMsg) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2885,13 +2923,13 @@ type RootCoord_UpdateChannelTimeTick_Call struct { } // UpdateChannelTimeTick is a helper method to define mock.On call -// - ctx context.Context -// - req *internalpb.ChannelTimeTickMsg -func (_e *RootCoord_Expecter) UpdateChannelTimeTick(ctx interface{}, req interface{}) *RootCoord_UpdateChannelTimeTick_Call { - return &RootCoord_UpdateChannelTimeTick_Call{Call: _e.mock.On("UpdateChannelTimeTick", ctx, req)} +// - _a0 context.Context +// - _a1 *internalpb.ChannelTimeTickMsg +func (_e *RootCoord_Expecter) UpdateChannelTimeTick(_a0 interface{}, _a1 interface{}) *RootCoord_UpdateChannelTimeTick_Call { + return &RootCoord_UpdateChannelTimeTick_Call{Call: _e.mock.On("UpdateChannelTimeTick", _a0, _a1)} } -func (_c *RootCoord_UpdateChannelTimeTick_Call) Run(run func(ctx context.Context, req *internalpb.ChannelTimeTickMsg)) *RootCoord_UpdateChannelTimeTick_Call { +func (_c *RootCoord_UpdateChannelTimeTick_Call) Run(run func(_a0 context.Context, _a1 *internalpb.ChannelTimeTickMsg)) *RootCoord_UpdateChannelTimeTick_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*internalpb.ChannelTimeTickMsg)) }) @@ -2908,17 +2946,17 @@ func (_c *RootCoord_UpdateChannelTimeTick_Call) RunAndReturn(run func(context.Co return _c } -// UpdateCredential provides a mock function with given fields: ctx, req -func (_m *RootCoord) UpdateCredential(ctx context.Context, req *internalpb.CredentialInfo) (*commonpb.Status, error) { - ret := _m.Called(ctx, req) +// UpdateCredential provides a mock function with given fields: _a0, _a1 +func (_m *RootCoord) UpdateCredential(_a0 context.Context, _a1 *internalpb.CredentialInfo) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) var r0 *commonpb.Status var r1 error if rf, ok := ret.Get(0).(func(context.Context, 
*internalpb.CredentialInfo) (*commonpb.Status, error)); ok { - return rf(ctx, req) + return rf(_a0, _a1) } if rf, ok := ret.Get(0).(func(context.Context, *internalpb.CredentialInfo) *commonpb.Status); ok { - r0 = rf(ctx, req) + r0 = rf(_a0, _a1) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*commonpb.Status) @@ -2926,7 +2964,7 @@ func (_m *RootCoord) UpdateCredential(ctx context.Context, req *internalpb.Crede } if rf, ok := ret.Get(1).(func(context.Context, *internalpb.CredentialInfo) error); ok { - r1 = rf(ctx, req) + r1 = rf(_a0, _a1) } else { r1 = ret.Error(1) } @@ -2940,13 +2978,13 @@ type RootCoord_UpdateCredential_Call struct { } // UpdateCredential is a helper method to define mock.On call -// - ctx context.Context -// - req *internalpb.CredentialInfo -func (_e *RootCoord_Expecter) UpdateCredential(ctx interface{}, req interface{}) *RootCoord_UpdateCredential_Call { - return &RootCoord_UpdateCredential_Call{Call: _e.mock.On("UpdateCredential", ctx, req)} +// - _a0 context.Context +// - _a1 *internalpb.CredentialInfo +func (_e *RootCoord_Expecter) UpdateCredential(_a0 interface{}, _a1 interface{}) *RootCoord_UpdateCredential_Call { + return &RootCoord_UpdateCredential_Call{Call: _e.mock.On("UpdateCredential", _a0, _a1)} } -func (_c *RootCoord_UpdateCredential_Call) Run(run func(ctx context.Context, req *internalpb.CredentialInfo)) *RootCoord_UpdateCredential_Call { +func (_c *RootCoord_UpdateCredential_Call) Run(run func(_a0 context.Context, _a1 *internalpb.CredentialInfo)) *RootCoord_UpdateCredential_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(*internalpb.CredentialInfo)) }) diff --git a/internal/mocks/mock_rootcoord_client.go b/internal/mocks/mock_rootcoord_client.go new file mode 100644 index 0000000000000..b1deb3977c285 --- /dev/null +++ b/internal/mocks/mock_rootcoord_client.go @@ -0,0 +1,3379 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. 
+ +package mocks + +import ( + context "context" + + commonpb "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + + grpc "google.golang.org/grpc" + + internalpb "github.com/milvus-io/milvus/internal/proto/internalpb" + + milvuspb "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + + mock "github.com/stretchr/testify/mock" + + proxypb "github.com/milvus-io/milvus/internal/proto/proxypb" + + rootcoordpb "github.com/milvus-io/milvus/internal/proto/rootcoordpb" +) + +// MockRootCoordClient is an autogenerated mock type for the RootCoordClient type +type MockRootCoordClient struct { + mock.Mock +} + +type MockRootCoordClient_Expecter struct { + mock *mock.Mock +} + +func (_m *MockRootCoordClient) EXPECT() *MockRootCoordClient_Expecter { + return &MockRootCoordClient_Expecter{mock: &_m.Mock} +} + +// AllocID provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) AllocID(ctx context.Context, in *rootcoordpb.AllocIDRequest, opts ...grpc.CallOption) (*rootcoordpb.AllocIDResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *rootcoordpb.AllocIDResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.AllocIDRequest, ...grpc.CallOption) (*rootcoordpb.AllocIDResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.AllocIDRequest, ...grpc.CallOption) *rootcoordpb.AllocIDResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*rootcoordpb.AllocIDResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *rootcoordpb.AllocIDRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_AllocID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AllocID' +type MockRootCoordClient_AllocID_Call struct { + *mock.Call +} + +// AllocID is a helper method to define mock.On call +// - ctx context.Context +// - in *rootcoordpb.AllocIDRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) AllocID(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_AllocID_Call { + return &MockRootCoordClient_AllocID_Call{Call: _e.mock.On("AllocID", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_AllocID_Call) Run(run func(ctx context.Context, in *rootcoordpb.AllocIDRequest, opts ...grpc.CallOption)) *MockRootCoordClient_AllocID_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*rootcoordpb.AllocIDRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockRootCoordClient_AllocID_Call) Return(_a0 *rootcoordpb.AllocIDResponse, _a1 error) *MockRootCoordClient_AllocID_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_AllocID_Call) RunAndReturn(run func(context.Context, *rootcoordpb.AllocIDRequest, ...grpc.CallOption) (*rootcoordpb.AllocIDResponse, error)) *MockRootCoordClient_AllocID_Call { + _c.Call.Return(run) + return _c +} + +// AllocTimestamp provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) AllocTimestamp(ctx context.Context, in *rootcoordpb.AllocTimestampRequest, opts ...grpc.CallOption) (*rootcoordpb.AllocTimestampResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) 
+ ret := _m.Called(_ca...) + + var r0 *rootcoordpb.AllocTimestampResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.AllocTimestampRequest, ...grpc.CallOption) (*rootcoordpb.AllocTimestampResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.AllocTimestampRequest, ...grpc.CallOption) *rootcoordpb.AllocTimestampResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*rootcoordpb.AllocTimestampResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *rootcoordpb.AllocTimestampRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_AllocTimestamp_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AllocTimestamp' +type MockRootCoordClient_AllocTimestamp_Call struct { + *mock.Call +} + +// AllocTimestamp is a helper method to define mock.On call +// - ctx context.Context +// - in *rootcoordpb.AllocTimestampRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) AllocTimestamp(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_AllocTimestamp_Call { + return &MockRootCoordClient_AllocTimestamp_Call{Call: _e.mock.On("AllocTimestamp", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_AllocTimestamp_Call) Run(run func(ctx context.Context, in *rootcoordpb.AllocTimestampRequest, opts ...grpc.CallOption)) *MockRootCoordClient_AllocTimestamp_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*rootcoordpb.AllocTimestampRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_AllocTimestamp_Call) Return(_a0 *rootcoordpb.AllocTimestampResponse, _a1 error) *MockRootCoordClient_AllocTimestamp_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_AllocTimestamp_Call) RunAndReturn(run func(context.Context, *rootcoordpb.AllocTimestampRequest, ...grpc.CallOption) (*rootcoordpb.AllocTimestampResponse, error)) *MockRootCoordClient_AllocTimestamp_Call { + _c.Call.Return(run) + return _c +} + +// AlterAlias provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) AlterAlias(ctx context.Context, in *milvuspb.AlterAliasRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.AlterAliasRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.AlterAliasRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.AlterAliasRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_AlterAlias_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AlterAlias' +type MockRootCoordClient_AlterAlias_Call struct { + *mock.Call +} + +// AlterAlias is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.AlterAliasRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) AlterAlias(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_AlterAlias_Call { + return &MockRootCoordClient_AlterAlias_Call{Call: _e.mock.On("AlterAlias", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_AlterAlias_Call) Run(run func(ctx context.Context, in *milvuspb.AlterAliasRequest, opts ...grpc.CallOption)) *MockRootCoordClient_AlterAlias_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.AlterAliasRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockRootCoordClient_AlterAlias_Call) Return(_a0 *commonpb.Status, _a1 error) *MockRootCoordClient_AlterAlias_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_AlterAlias_Call) RunAndReturn(run func(context.Context, *milvuspb.AlterAliasRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockRootCoordClient_AlterAlias_Call { + _c.Call.Return(run) + return _c +} + +// AlterCollection provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) AlterCollection(ctx context.Context, in *milvuspb.AlterCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) 
+ ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.AlterCollectionRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.AlterCollectionRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.AlterCollectionRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_AlterCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AlterCollection' +type MockRootCoordClient_AlterCollection_Call struct { + *mock.Call +} + +// AlterCollection is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.AlterCollectionRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) AlterCollection(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_AlterCollection_Call { + return &MockRootCoordClient_AlterCollection_Call{Call: _e.mock.On("AlterCollection", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_AlterCollection_Call) Run(run func(ctx context.Context, in *milvuspb.AlterCollectionRequest, opts ...grpc.CallOption)) *MockRootCoordClient_AlterCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.AlterCollectionRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_AlterCollection_Call) Return(_a0 *commonpb.Status, _a1 error) *MockRootCoordClient_AlterCollection_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_AlterCollection_Call) RunAndReturn(run func(context.Context, *milvuspb.AlterCollectionRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockRootCoordClient_AlterCollection_Call { + _c.Call.Return(run) + return _c +} + +// CheckHealth provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) CheckHealth(ctx context.Context, in *milvuspb.CheckHealthRequest, opts ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.CheckHealthResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CheckHealthRequest, ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CheckHealthRequest, ...grpc.CallOption) *milvuspb.CheckHealthResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.CheckHealthResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CheckHealthRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_CheckHealth_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckHealth' +type MockRootCoordClient_CheckHealth_Call struct { + *mock.Call +} + +// CheckHealth is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.CheckHealthRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) CheckHealth(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_CheckHealth_Call { + return &MockRootCoordClient_CheckHealth_Call{Call: _e.mock.On("CheckHealth", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_CheckHealth_Call) Run(run func(ctx context.Context, in *milvuspb.CheckHealthRequest, opts ...grpc.CallOption)) *MockRootCoordClient_CheckHealth_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.CheckHealthRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_CheckHealth_Call) Return(_a0 *milvuspb.CheckHealthResponse, _a1 error) *MockRootCoordClient_CheckHealth_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_CheckHealth_Call) RunAndReturn(run func(context.Context, *milvuspb.CheckHealthRequest, ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error)) *MockRootCoordClient_CheckHealth_Call { + _c.Call.Return(run) + return _c +} + +// Close provides a mock function with given fields: +func (_m *MockRootCoordClient) Close() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockRootCoordClient_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type MockRootCoordClient_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *MockRootCoordClient_Expecter) Close() *MockRootCoordClient_Close_Call { + return &MockRootCoordClient_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *MockRootCoordClient_Close_Call) Run(run func()) *MockRootCoordClient_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockRootCoordClient_Close_Call) Return(_a0 error) *MockRootCoordClient_Close_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockRootCoordClient_Close_Call) RunAndReturn(run func() error) *MockRootCoordClient_Close_Call { + _c.Call.Return(run) + return _c +} + +// CreateAlias provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) CreateAlias(ctx context.Context, in *milvuspb.CreateAliasRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) 
+ + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateAliasRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateAliasRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CreateAliasRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_CreateAlias_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateAlias' +type MockRootCoordClient_CreateAlias_Call struct { + *mock.Call +} + +// CreateAlias is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.CreateAliasRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) CreateAlias(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_CreateAlias_Call { + return &MockRootCoordClient_CreateAlias_Call{Call: _e.mock.On("CreateAlias", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_CreateAlias_Call) Run(run func(ctx context.Context, in *milvuspb.CreateAliasRequest, opts ...grpc.CallOption)) *MockRootCoordClient_CreateAlias_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.CreateAliasRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_CreateAlias_Call) Return(_a0 *commonpb.Status, _a1 error) *MockRootCoordClient_CreateAlias_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_CreateAlias_Call) RunAndReturn(run func(context.Context, *milvuspb.CreateAliasRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockRootCoordClient_CreateAlias_Call { + _c.Call.Return(run) + return _c +} + +// CreateCollection provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) CreateCollection(ctx context.Context, in *milvuspb.CreateCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateCollectionRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateCollectionRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CreateCollectionRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_CreateCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateCollection' +type MockRootCoordClient_CreateCollection_Call struct { + *mock.Call +} + +// CreateCollection is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.CreateCollectionRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) CreateCollection(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_CreateCollection_Call { + return &MockRootCoordClient_CreateCollection_Call{Call: _e.mock.On("CreateCollection", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_CreateCollection_Call) Run(run func(ctx context.Context, in *milvuspb.CreateCollectionRequest, opts ...grpc.CallOption)) *MockRootCoordClient_CreateCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.CreateCollectionRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_CreateCollection_Call) Return(_a0 *commonpb.Status, _a1 error) *MockRootCoordClient_CreateCollection_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_CreateCollection_Call) RunAndReturn(run func(context.Context, *milvuspb.CreateCollectionRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockRootCoordClient_CreateCollection_Call { + _c.Call.Return(run) + return _c +} + +// CreateCredential provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) CreateCredential(ctx context.Context, in *internalpb.CredentialInfo, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.CredentialInfo, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.CredentialInfo, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.CredentialInfo, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_CreateCredential_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateCredential' +type MockRootCoordClient_CreateCredential_Call struct { + *mock.Call +} + +// CreateCredential is a helper method to define mock.On call +// - ctx context.Context +// - in *internalpb.CredentialInfo +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) CreateCredential(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_CreateCredential_Call { + return &MockRootCoordClient_CreateCredential_Call{Call: _e.mock.On("CreateCredential", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_CreateCredential_Call) Run(run func(ctx context.Context, in *internalpb.CredentialInfo, opts ...grpc.CallOption)) *MockRootCoordClient_CreateCredential_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*internalpb.CredentialInfo), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_CreateCredential_Call) Return(_a0 *commonpb.Status, _a1 error) *MockRootCoordClient_CreateCredential_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_CreateCredential_Call) RunAndReturn(run func(context.Context, *internalpb.CredentialInfo, ...grpc.CallOption) (*commonpb.Status, error)) *MockRootCoordClient_CreateCredential_Call { + _c.Call.Return(run) + return _c +} + +// CreateDatabase provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) CreateDatabase(ctx context.Context, in *milvuspb.CreateDatabaseRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateDatabaseRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateDatabaseRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CreateDatabaseRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_CreateDatabase_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateDatabase' +type MockRootCoordClient_CreateDatabase_Call struct { + *mock.Call +} + +// CreateDatabase is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.CreateDatabaseRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) CreateDatabase(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_CreateDatabase_Call { + return &MockRootCoordClient_CreateDatabase_Call{Call: _e.mock.On("CreateDatabase", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_CreateDatabase_Call) Run(run func(ctx context.Context, in *milvuspb.CreateDatabaseRequest, opts ...grpc.CallOption)) *MockRootCoordClient_CreateDatabase_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.CreateDatabaseRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_CreateDatabase_Call) Return(_a0 *commonpb.Status, _a1 error) *MockRootCoordClient_CreateDatabase_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_CreateDatabase_Call) RunAndReturn(run func(context.Context, *milvuspb.CreateDatabaseRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockRootCoordClient_CreateDatabase_Call { + _c.Call.Return(run) + return _c +} + +// CreatePartition provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) CreatePartition(ctx context.Context, in *milvuspb.CreatePartitionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreatePartitionRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreatePartitionRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CreatePartitionRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_CreatePartition_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreatePartition' +type MockRootCoordClient_CreatePartition_Call struct { + *mock.Call +} + +// CreatePartition is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.CreatePartitionRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) CreatePartition(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_CreatePartition_Call { + return &MockRootCoordClient_CreatePartition_Call{Call: _e.mock.On("CreatePartition", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_CreatePartition_Call) Run(run func(ctx context.Context, in *milvuspb.CreatePartitionRequest, opts ...grpc.CallOption)) *MockRootCoordClient_CreatePartition_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.CreatePartitionRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_CreatePartition_Call) Return(_a0 *commonpb.Status, _a1 error) *MockRootCoordClient_CreatePartition_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_CreatePartition_Call) RunAndReturn(run func(context.Context, *milvuspb.CreatePartitionRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockRootCoordClient_CreatePartition_Call { + _c.Call.Return(run) + return _c +} + +// CreateRole provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) CreateRole(ctx context.Context, in *milvuspb.CreateRoleRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateRoleRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.CreateRoleRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.CreateRoleRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_CreateRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateRole' +type MockRootCoordClient_CreateRole_Call struct { + *mock.Call +} + +// CreateRole is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.CreateRoleRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) CreateRole(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_CreateRole_Call { + return &MockRootCoordClient_CreateRole_Call{Call: _e.mock.On("CreateRole", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_CreateRole_Call) Run(run func(ctx context.Context, in *milvuspb.CreateRoleRequest, opts ...grpc.CallOption)) *MockRootCoordClient_CreateRole_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.CreateRoleRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockRootCoordClient_CreateRole_Call) Return(_a0 *commonpb.Status, _a1 error) *MockRootCoordClient_CreateRole_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_CreateRole_Call) RunAndReturn(run func(context.Context, *milvuspb.CreateRoleRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockRootCoordClient_CreateRole_Call { + _c.Call.Return(run) + return _c +} + +// DeleteCredential provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) DeleteCredential(ctx context.Context, in *milvuspb.DeleteCredentialRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) 
+ ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DeleteCredentialRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DeleteCredentialRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DeleteCredentialRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_DeleteCredential_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteCredential' +type MockRootCoordClient_DeleteCredential_Call struct { + *mock.Call +} + +// DeleteCredential is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.DeleteCredentialRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) DeleteCredential(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_DeleteCredential_Call { + return &MockRootCoordClient_DeleteCredential_Call{Call: _e.mock.On("DeleteCredential", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_DeleteCredential_Call) Run(run func(ctx context.Context, in *milvuspb.DeleteCredentialRequest, opts ...grpc.CallOption)) *MockRootCoordClient_DeleteCredential_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.DeleteCredentialRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_DeleteCredential_Call) Return(_a0 *commonpb.Status, _a1 error) *MockRootCoordClient_DeleteCredential_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_DeleteCredential_Call) RunAndReturn(run func(context.Context, *milvuspb.DeleteCredentialRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockRootCoordClient_DeleteCredential_Call { + _c.Call.Return(run) + return _c +} + +// DescribeCollection provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) DescribeCollection(ctx context.Context, in *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.DescribeCollectionResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DescribeCollectionRequest, ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DescribeCollectionRequest, ...grpc.CallOption) *milvuspb.DescribeCollectionResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.DescribeCollectionResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DescribeCollectionRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_DescribeCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DescribeCollection' +type MockRootCoordClient_DescribeCollection_Call struct { + *mock.Call +} + +// DescribeCollection is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.DescribeCollectionRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) DescribeCollection(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_DescribeCollection_Call { + return &MockRootCoordClient_DescribeCollection_Call{Call: _e.mock.On("DescribeCollection", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_DescribeCollection_Call) Run(run func(ctx context.Context, in *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption)) *MockRootCoordClient_DescribeCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.DescribeCollectionRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_DescribeCollection_Call) Return(_a0 *milvuspb.DescribeCollectionResponse, _a1 error) *MockRootCoordClient_DescribeCollection_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_DescribeCollection_Call) RunAndReturn(run func(context.Context, *milvuspb.DescribeCollectionRequest, ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error)) *MockRootCoordClient_DescribeCollection_Call { + _c.Call.Return(run) + return _c +} + +// DescribeCollectionInternal provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) DescribeCollectionInternal(ctx context.Context, in *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.DescribeCollectionResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DescribeCollectionRequest, ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DescribeCollectionRequest, ...grpc.CallOption) *milvuspb.DescribeCollectionResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.DescribeCollectionResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DescribeCollectionRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_DescribeCollectionInternal_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DescribeCollectionInternal' +type MockRootCoordClient_DescribeCollectionInternal_Call struct { + *mock.Call +} + +// DescribeCollectionInternal is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.DescribeCollectionRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) DescribeCollectionInternal(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_DescribeCollectionInternal_Call { + return &MockRootCoordClient_DescribeCollectionInternal_Call{Call: _e.mock.On("DescribeCollectionInternal", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_DescribeCollectionInternal_Call) Run(run func(ctx context.Context, in *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption)) *MockRootCoordClient_DescribeCollectionInternal_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.DescribeCollectionRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_DescribeCollectionInternal_Call) Return(_a0 *milvuspb.DescribeCollectionResponse, _a1 error) *MockRootCoordClient_DescribeCollectionInternal_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_DescribeCollectionInternal_Call) RunAndReturn(run func(context.Context, *milvuspb.DescribeCollectionRequest, ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error)) *MockRootCoordClient_DescribeCollectionInternal_Call { + _c.Call.Return(run) + return _c +} + +// DropAlias provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) DropAlias(ctx context.Context, in *milvuspb.DropAliasRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropAliasRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropAliasRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DropAliasRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_DropAlias_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DropAlias' +type MockRootCoordClient_DropAlias_Call struct { + *mock.Call +} + +// DropAlias is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.DropAliasRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) DropAlias(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_DropAlias_Call { + return &MockRootCoordClient_DropAlias_Call{Call: _e.mock.On("DropAlias", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_DropAlias_Call) Run(run func(ctx context.Context, in *milvuspb.DropAliasRequest, opts ...grpc.CallOption)) *MockRootCoordClient_DropAlias_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.DropAliasRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockRootCoordClient_DropAlias_Call) Return(_a0 *commonpb.Status, _a1 error) *MockRootCoordClient_DropAlias_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_DropAlias_Call) RunAndReturn(run func(context.Context, *milvuspb.DropAliasRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockRootCoordClient_DropAlias_Call { + _c.Call.Return(run) + return _c +} + +// DropCollection provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) DropCollection(ctx context.Context, in *milvuspb.DropCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) 
+ + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropCollectionRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropCollectionRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DropCollectionRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_DropCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DropCollection' +type MockRootCoordClient_DropCollection_Call struct { + *mock.Call +} + +// DropCollection is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.DropCollectionRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) DropCollection(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_DropCollection_Call { + return &MockRootCoordClient_DropCollection_Call{Call: _e.mock.On("DropCollection", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_DropCollection_Call) Run(run func(ctx context.Context, in *milvuspb.DropCollectionRequest, opts ...grpc.CallOption)) *MockRootCoordClient_DropCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.DropCollectionRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_DropCollection_Call) Return(_a0 *commonpb.Status, _a1 error) *MockRootCoordClient_DropCollection_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_DropCollection_Call) RunAndReturn(run func(context.Context, *milvuspb.DropCollectionRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockRootCoordClient_DropCollection_Call { + _c.Call.Return(run) + return _c +} + +// DropDatabase provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) DropDatabase(ctx context.Context, in *milvuspb.DropDatabaseRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropDatabaseRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropDatabaseRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DropDatabaseRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_DropDatabase_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DropDatabase' +type MockRootCoordClient_DropDatabase_Call struct { + *mock.Call +} + +// DropDatabase is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.DropDatabaseRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) DropDatabase(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_DropDatabase_Call { + return &MockRootCoordClient_DropDatabase_Call{Call: _e.mock.On("DropDatabase", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_DropDatabase_Call) Run(run func(ctx context.Context, in *milvuspb.DropDatabaseRequest, opts ...grpc.CallOption)) *MockRootCoordClient_DropDatabase_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.DropDatabaseRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_DropDatabase_Call) Return(_a0 *commonpb.Status, _a1 error) *MockRootCoordClient_DropDatabase_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_DropDatabase_Call) RunAndReturn(run func(context.Context, *milvuspb.DropDatabaseRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockRootCoordClient_DropDatabase_Call { + _c.Call.Return(run) + return _c +} + +// DropPartition provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) DropPartition(ctx context.Context, in *milvuspb.DropPartitionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropPartitionRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropPartitionRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DropPartitionRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_DropPartition_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DropPartition' +type MockRootCoordClient_DropPartition_Call struct { + *mock.Call +} + +// DropPartition is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.DropPartitionRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) DropPartition(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_DropPartition_Call { + return &MockRootCoordClient_DropPartition_Call{Call: _e.mock.On("DropPartition", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_DropPartition_Call) Run(run func(ctx context.Context, in *milvuspb.DropPartitionRequest, opts ...grpc.CallOption)) *MockRootCoordClient_DropPartition_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.DropPartitionRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_DropPartition_Call) Return(_a0 *commonpb.Status, _a1 error) *MockRootCoordClient_DropPartition_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_DropPartition_Call) RunAndReturn(run func(context.Context, *milvuspb.DropPartitionRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockRootCoordClient_DropPartition_Call { + _c.Call.Return(run) + return _c +} + +// DropRole provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) DropRole(ctx context.Context, in *milvuspb.DropRoleRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropRoleRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.DropRoleRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.DropRoleRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_DropRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DropRole' +type MockRootCoordClient_DropRole_Call struct { + *mock.Call +} + +// DropRole is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.DropRoleRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) DropRole(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_DropRole_Call { + return &MockRootCoordClient_DropRole_Call{Call: _e.mock.On("DropRole", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_DropRole_Call) Run(run func(ctx context.Context, in *milvuspb.DropRoleRequest, opts ...grpc.CallOption)) *MockRootCoordClient_DropRole_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.DropRoleRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockRootCoordClient_DropRole_Call) Return(_a0 *commonpb.Status, _a1 error) *MockRootCoordClient_DropRole_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_DropRole_Call) RunAndReturn(run func(context.Context, *milvuspb.DropRoleRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockRootCoordClient_DropRole_Call { + _c.Call.Return(run) + return _c +} + +// GetComponentStates provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) GetComponentStates(ctx context.Context, in *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) 
+ ret := _m.Called(_ca...) + + var r0 *milvuspb.ComponentStates + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetComponentStatesRequest, ...grpc.CallOption) (*milvuspb.ComponentStates, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetComponentStatesRequest, ...grpc.CallOption) *milvuspb.ComponentStates); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.ComponentStates) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetComponentStatesRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_GetComponentStates_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetComponentStates' +type MockRootCoordClient_GetComponentStates_Call struct { + *mock.Call +} + +// GetComponentStates is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.GetComponentStatesRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) GetComponentStates(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_GetComponentStates_Call { + return &MockRootCoordClient_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_GetComponentStates_Call) Run(run func(ctx context.Context, in *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption)) *MockRootCoordClient_GetComponentStates_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.GetComponentStatesRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_GetComponentStates_Call) Return(_a0 *milvuspb.ComponentStates, _a1 error) *MockRootCoordClient_GetComponentStates_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_GetComponentStates_Call) RunAndReturn(run func(context.Context, *milvuspb.GetComponentStatesRequest, ...grpc.CallOption) (*milvuspb.ComponentStates, error)) *MockRootCoordClient_GetComponentStates_Call { + _c.Call.Return(run) + return _c +} + +// GetCredential provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) GetCredential(ctx context.Context, in *rootcoordpb.GetCredentialRequest, opts ...grpc.CallOption) (*rootcoordpb.GetCredentialResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *rootcoordpb.GetCredentialResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.GetCredentialRequest, ...grpc.CallOption) (*rootcoordpb.GetCredentialResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.GetCredentialRequest, ...grpc.CallOption) *rootcoordpb.GetCredentialResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*rootcoordpb.GetCredentialResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *rootcoordpb.GetCredentialRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_GetCredential_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCredential' +type MockRootCoordClient_GetCredential_Call struct { + *mock.Call +} + +// GetCredential is a helper method to define mock.On call +// - ctx context.Context +// - in *rootcoordpb.GetCredentialRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) GetCredential(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_GetCredential_Call { + return &MockRootCoordClient_GetCredential_Call{Call: _e.mock.On("GetCredential", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_GetCredential_Call) Run(run func(ctx context.Context, in *rootcoordpb.GetCredentialRequest, opts ...grpc.CallOption)) *MockRootCoordClient_GetCredential_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*rootcoordpb.GetCredentialRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_GetCredential_Call) Return(_a0 *rootcoordpb.GetCredentialResponse, _a1 error) *MockRootCoordClient_GetCredential_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_GetCredential_Call) RunAndReturn(run func(context.Context, *rootcoordpb.GetCredentialRequest, ...grpc.CallOption) (*rootcoordpb.GetCredentialResponse, error)) *MockRootCoordClient_GetCredential_Call { + _c.Call.Return(run) + return _c +} + +// GetImportState provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) GetImportState(ctx context.Context, in *milvuspb.GetImportStateRequest, opts ...grpc.CallOption) (*milvuspb.GetImportStateResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.GetImportStateResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetImportStateRequest, ...grpc.CallOption) (*milvuspb.GetImportStateResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetImportStateRequest, ...grpc.CallOption) *milvuspb.GetImportStateResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.GetImportStateResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetImportStateRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_GetImportState_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetImportState' +type MockRootCoordClient_GetImportState_Call struct { + *mock.Call +} + +// GetImportState is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.GetImportStateRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) GetImportState(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_GetImportState_Call { + return &MockRootCoordClient_GetImportState_Call{Call: _e.mock.On("GetImportState", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_GetImportState_Call) Run(run func(ctx context.Context, in *milvuspb.GetImportStateRequest, opts ...grpc.CallOption)) *MockRootCoordClient_GetImportState_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.GetImportStateRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_GetImportState_Call) Return(_a0 *milvuspb.GetImportStateResponse, _a1 error) *MockRootCoordClient_GetImportState_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_GetImportState_Call) RunAndReturn(run func(context.Context, *milvuspb.GetImportStateRequest, ...grpc.CallOption) (*milvuspb.GetImportStateResponse, error)) *MockRootCoordClient_GetImportState_Call { + _c.Call.Return(run) + return _c +} + +// GetMetrics provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.GetMetricsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest, ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.GetMetricsRequest, ...grpc.CallOption) *milvuspb.GetMetricsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.GetMetricsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.GetMetricsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_GetMetrics_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetMetrics' +type MockRootCoordClient_GetMetrics_Call struct { + *mock.Call +} + +// GetMetrics is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.GetMetricsRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) GetMetrics(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_GetMetrics_Call { + return &MockRootCoordClient_GetMetrics_Call{Call: _e.mock.On("GetMetrics", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_GetMetrics_Call) Run(run func(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption)) *MockRootCoordClient_GetMetrics_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.GetMetricsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_GetMetrics_Call) Return(_a0 *milvuspb.GetMetricsResponse, _a1 error) *MockRootCoordClient_GetMetrics_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_GetMetrics_Call) RunAndReturn(run func(context.Context, *milvuspb.GetMetricsRequest, ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error)) *MockRootCoordClient_GetMetrics_Call { + _c.Call.Return(run) + return _c +} + +// GetStatisticsChannel provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.StringResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetStatisticsChannelRequest, ...grpc.CallOption) (*milvuspb.StringResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetStatisticsChannelRequest, ...grpc.CallOption) *milvuspb.StringResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.StringResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.GetStatisticsChannelRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_GetStatisticsChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetStatisticsChannel' +type MockRootCoordClient_GetStatisticsChannel_Call struct { + *mock.Call +} + +// GetStatisticsChannel is a helper method to define mock.On call +// - ctx context.Context +// - in *internalpb.GetStatisticsChannelRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) GetStatisticsChannel(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_GetStatisticsChannel_Call { + return &MockRootCoordClient_GetStatisticsChannel_Call{Call: _e.mock.On("GetStatisticsChannel", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_GetStatisticsChannel_Call) Run(run func(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption)) *MockRootCoordClient_GetStatisticsChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*internalpb.GetStatisticsChannelRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_GetStatisticsChannel_Call) Return(_a0 *milvuspb.StringResponse, _a1 error) *MockRootCoordClient_GetStatisticsChannel_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_GetStatisticsChannel_Call) RunAndReturn(run func(context.Context, *internalpb.GetStatisticsChannelRequest, ...grpc.CallOption) (*milvuspb.StringResponse, error)) *MockRootCoordClient_GetStatisticsChannel_Call { + _c.Call.Return(run) + return _c +} + +// GetTimeTickChannel provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) GetTimeTickChannel(ctx context.Context, in *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.StringResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetTimeTickChannelRequest, ...grpc.CallOption) (*milvuspb.StringResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.GetTimeTickChannelRequest, ...grpc.CallOption) *milvuspb.StringResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.StringResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.GetTimeTickChannelRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_GetTimeTickChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetTimeTickChannel' +type MockRootCoordClient_GetTimeTickChannel_Call struct { + *mock.Call +} + +// GetTimeTickChannel is a helper method to define mock.On call +// - ctx context.Context +// - in *internalpb.GetTimeTickChannelRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) GetTimeTickChannel(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_GetTimeTickChannel_Call { + return &MockRootCoordClient_GetTimeTickChannel_Call{Call: _e.mock.On("GetTimeTickChannel", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_GetTimeTickChannel_Call) Run(run func(ctx context.Context, in *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption)) *MockRootCoordClient_GetTimeTickChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*internalpb.GetTimeTickChannelRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_GetTimeTickChannel_Call) Return(_a0 *milvuspb.StringResponse, _a1 error) *MockRootCoordClient_GetTimeTickChannel_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_GetTimeTickChannel_Call) RunAndReturn(run func(context.Context, *internalpb.GetTimeTickChannelRequest, ...grpc.CallOption) (*milvuspb.StringResponse, error)) *MockRootCoordClient_GetTimeTickChannel_Call { + _c.Call.Return(run) + return _c +} + +// HasCollection provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) HasCollection(ctx context.Context, in *milvuspb.HasCollectionRequest, opts ...grpc.CallOption) (*milvuspb.BoolResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.BoolResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.HasCollectionRequest, ...grpc.CallOption) (*milvuspb.BoolResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.HasCollectionRequest, ...grpc.CallOption) *milvuspb.BoolResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.BoolResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.HasCollectionRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_HasCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HasCollection' +type MockRootCoordClient_HasCollection_Call struct { + *mock.Call +} + +// HasCollection is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.HasCollectionRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) HasCollection(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_HasCollection_Call { + return &MockRootCoordClient_HasCollection_Call{Call: _e.mock.On("HasCollection", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_HasCollection_Call) Run(run func(ctx context.Context, in *milvuspb.HasCollectionRequest, opts ...grpc.CallOption)) *MockRootCoordClient_HasCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.HasCollectionRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_HasCollection_Call) Return(_a0 *milvuspb.BoolResponse, _a1 error) *MockRootCoordClient_HasCollection_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_HasCollection_Call) RunAndReturn(run func(context.Context, *milvuspb.HasCollectionRequest, ...grpc.CallOption) (*milvuspb.BoolResponse, error)) *MockRootCoordClient_HasCollection_Call { + _c.Call.Return(run) + return _c +} + +// HasPartition provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) HasPartition(ctx context.Context, in *milvuspb.HasPartitionRequest, opts ...grpc.CallOption) (*milvuspb.BoolResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.BoolResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.HasPartitionRequest, ...grpc.CallOption) (*milvuspb.BoolResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.HasPartitionRequest, ...grpc.CallOption) *milvuspb.BoolResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.BoolResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.HasPartitionRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_HasPartition_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HasPartition' +type MockRootCoordClient_HasPartition_Call struct { + *mock.Call +} + +// HasPartition is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.HasPartitionRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) HasPartition(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_HasPartition_Call { + return &MockRootCoordClient_HasPartition_Call{Call: _e.mock.On("HasPartition", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_HasPartition_Call) Run(run func(ctx context.Context, in *milvuspb.HasPartitionRequest, opts ...grpc.CallOption)) *MockRootCoordClient_HasPartition_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.HasPartitionRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_HasPartition_Call) Return(_a0 *milvuspb.BoolResponse, _a1 error) *MockRootCoordClient_HasPartition_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_HasPartition_Call) RunAndReturn(run func(context.Context, *milvuspb.HasPartitionRequest, ...grpc.CallOption) (*milvuspb.BoolResponse, error)) *MockRootCoordClient_HasPartition_Call { + _c.Call.Return(run) + return _c +} + +// Import provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) Import(ctx context.Context, in *milvuspb.ImportRequest, opts ...grpc.CallOption) (*milvuspb.ImportResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.ImportResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ImportRequest, ...grpc.CallOption) (*milvuspb.ImportResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ImportRequest, ...grpc.CallOption) *milvuspb.ImportResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.ImportResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ImportRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_Import_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Import' +type MockRootCoordClient_Import_Call struct { + *mock.Call +} + +// Import is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.ImportRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) Import(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_Import_Call { + return &MockRootCoordClient_Import_Call{Call: _e.mock.On("Import", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_Import_Call) Run(run func(ctx context.Context, in *milvuspb.ImportRequest, opts ...grpc.CallOption)) *MockRootCoordClient_Import_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.ImportRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockRootCoordClient_Import_Call) Return(_a0 *milvuspb.ImportResponse, _a1 error) *MockRootCoordClient_Import_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_Import_Call) RunAndReturn(run func(context.Context, *milvuspb.ImportRequest, ...grpc.CallOption) (*milvuspb.ImportResponse, error)) *MockRootCoordClient_Import_Call { + _c.Call.Return(run) + return _c +} + +// InvalidateCollectionMetaCache provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) InvalidateCollectionMetaCache(ctx context.Context, in *proxypb.InvalidateCollMetaCacheRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) 
+ + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateCollMetaCacheRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateCollMetaCacheRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *proxypb.InvalidateCollMetaCacheRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_InvalidateCollectionMetaCache_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InvalidateCollectionMetaCache' +type MockRootCoordClient_InvalidateCollectionMetaCache_Call struct { + *mock.Call +} + +// InvalidateCollectionMetaCache is a helper method to define mock.On call +// - ctx context.Context +// - in *proxypb.InvalidateCollMetaCacheRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) InvalidateCollectionMetaCache(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_InvalidateCollectionMetaCache_Call { + return &MockRootCoordClient_InvalidateCollectionMetaCache_Call{Call: _e.mock.On("InvalidateCollectionMetaCache", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_InvalidateCollectionMetaCache_Call) Run(run func(ctx context.Context, in *proxypb.InvalidateCollMetaCacheRequest, opts ...grpc.CallOption)) *MockRootCoordClient_InvalidateCollectionMetaCache_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*proxypb.InvalidateCollMetaCacheRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_InvalidateCollectionMetaCache_Call) Return(_a0 *commonpb.Status, _a1 error) *MockRootCoordClient_InvalidateCollectionMetaCache_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_InvalidateCollectionMetaCache_Call) RunAndReturn(run func(context.Context, *proxypb.InvalidateCollMetaCacheRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockRootCoordClient_InvalidateCollectionMetaCache_Call { + _c.Call.Return(run) + return _c +} + +// ListCredUsers provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) ListCredUsers(ctx context.Context, in *milvuspb.ListCredUsersRequest, opts ...grpc.CallOption) (*milvuspb.ListCredUsersResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.ListCredUsersResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ListCredUsersRequest, ...grpc.CallOption) (*milvuspb.ListCredUsersResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ListCredUsersRequest, ...grpc.CallOption) *milvuspb.ListCredUsersResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.ListCredUsersResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ListCredUsersRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_ListCredUsers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListCredUsers' +type MockRootCoordClient_ListCredUsers_Call struct { + *mock.Call +} + +// ListCredUsers is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.ListCredUsersRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) ListCredUsers(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_ListCredUsers_Call { + return &MockRootCoordClient_ListCredUsers_Call{Call: _e.mock.On("ListCredUsers", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_ListCredUsers_Call) Run(run func(ctx context.Context, in *milvuspb.ListCredUsersRequest, opts ...grpc.CallOption)) *MockRootCoordClient_ListCredUsers_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.ListCredUsersRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_ListCredUsers_Call) Return(_a0 *milvuspb.ListCredUsersResponse, _a1 error) *MockRootCoordClient_ListCredUsers_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_ListCredUsers_Call) RunAndReturn(run func(context.Context, *milvuspb.ListCredUsersRequest, ...grpc.CallOption) (*milvuspb.ListCredUsersResponse, error)) *MockRootCoordClient_ListCredUsers_Call { + _c.Call.Return(run) + return _c +} + +// ListDatabases provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) ListDatabases(ctx context.Context, in *milvuspb.ListDatabasesRequest, opts ...grpc.CallOption) (*milvuspb.ListDatabasesResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.ListDatabasesResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ListDatabasesRequest, ...grpc.CallOption) (*milvuspb.ListDatabasesResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ListDatabasesRequest, ...grpc.CallOption) *milvuspb.ListDatabasesResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.ListDatabasesResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ListDatabasesRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_ListDatabases_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListDatabases' +type MockRootCoordClient_ListDatabases_Call struct { + *mock.Call +} + +// ListDatabases is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.ListDatabasesRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) ListDatabases(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_ListDatabases_Call { + return &MockRootCoordClient_ListDatabases_Call{Call: _e.mock.On("ListDatabases", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_ListDatabases_Call) Run(run func(ctx context.Context, in *milvuspb.ListDatabasesRequest, opts ...grpc.CallOption)) *MockRootCoordClient_ListDatabases_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.ListDatabasesRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_ListDatabases_Call) Return(_a0 *milvuspb.ListDatabasesResponse, _a1 error) *MockRootCoordClient_ListDatabases_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_ListDatabases_Call) RunAndReturn(run func(context.Context, *milvuspb.ListDatabasesRequest, ...grpc.CallOption) (*milvuspb.ListDatabasesResponse, error)) *MockRootCoordClient_ListDatabases_Call { + _c.Call.Return(run) + return _c +} + +// ListImportTasks provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) ListImportTasks(ctx context.Context, in *milvuspb.ListImportTasksRequest, opts ...grpc.CallOption) (*milvuspb.ListImportTasksResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.ListImportTasksResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ListImportTasksRequest, ...grpc.CallOption) (*milvuspb.ListImportTasksResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ListImportTasksRequest, ...grpc.CallOption) *milvuspb.ListImportTasksResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.ListImportTasksResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ListImportTasksRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_ListImportTasks_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListImportTasks' +type MockRootCoordClient_ListImportTasks_Call struct { + *mock.Call +} + +// ListImportTasks is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.ListImportTasksRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) ListImportTasks(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_ListImportTasks_Call { + return &MockRootCoordClient_ListImportTasks_Call{Call: _e.mock.On("ListImportTasks", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_ListImportTasks_Call) Run(run func(ctx context.Context, in *milvuspb.ListImportTasksRequest, opts ...grpc.CallOption)) *MockRootCoordClient_ListImportTasks_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.ListImportTasksRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_ListImportTasks_Call) Return(_a0 *milvuspb.ListImportTasksResponse, _a1 error) *MockRootCoordClient_ListImportTasks_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_ListImportTasks_Call) RunAndReturn(run func(context.Context, *milvuspb.ListImportTasksRequest, ...grpc.CallOption) (*milvuspb.ListImportTasksResponse, error)) *MockRootCoordClient_ListImportTasks_Call { + _c.Call.Return(run) + return _c +} + +// ListPolicy provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) ListPolicy(ctx context.Context, in *internalpb.ListPolicyRequest, opts ...grpc.CallOption) (*internalpb.ListPolicyResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *internalpb.ListPolicyResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ListPolicyRequest, ...grpc.CallOption) (*internalpb.ListPolicyResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ListPolicyRequest, ...grpc.CallOption) *internalpb.ListPolicyResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*internalpb.ListPolicyResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.ListPolicyRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_ListPolicy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListPolicy' +type MockRootCoordClient_ListPolicy_Call struct { + *mock.Call +} + +// ListPolicy is a helper method to define mock.On call +// - ctx context.Context +// - in *internalpb.ListPolicyRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) ListPolicy(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_ListPolicy_Call { + return &MockRootCoordClient_ListPolicy_Call{Call: _e.mock.On("ListPolicy", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_ListPolicy_Call) Run(run func(ctx context.Context, in *internalpb.ListPolicyRequest, opts ...grpc.CallOption)) *MockRootCoordClient_ListPolicy_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*internalpb.ListPolicyRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_ListPolicy_Call) Return(_a0 *internalpb.ListPolicyResponse, _a1 error) *MockRootCoordClient_ListPolicy_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_ListPolicy_Call) RunAndReturn(run func(context.Context, *internalpb.ListPolicyRequest, ...grpc.CallOption) (*internalpb.ListPolicyResponse, error)) *MockRootCoordClient_ListPolicy_Call { + _c.Call.Return(run) + return _c +} + +// OperatePrivilege provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) OperatePrivilege(ctx context.Context, in *milvuspb.OperatePrivilegeRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.OperatePrivilegeRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.OperatePrivilegeRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.OperatePrivilegeRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_OperatePrivilege_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OperatePrivilege' +type MockRootCoordClient_OperatePrivilege_Call struct { + *mock.Call +} + +// OperatePrivilege is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.OperatePrivilegeRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) OperatePrivilege(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_OperatePrivilege_Call { + return &MockRootCoordClient_OperatePrivilege_Call{Call: _e.mock.On("OperatePrivilege", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_OperatePrivilege_Call) Run(run func(ctx context.Context, in *milvuspb.OperatePrivilegeRequest, opts ...grpc.CallOption)) *MockRootCoordClient_OperatePrivilege_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.OperatePrivilegeRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_OperatePrivilege_Call) Return(_a0 *commonpb.Status, _a1 error) *MockRootCoordClient_OperatePrivilege_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_OperatePrivilege_Call) RunAndReturn(run func(context.Context, *milvuspb.OperatePrivilegeRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockRootCoordClient_OperatePrivilege_Call { + _c.Call.Return(run) + return _c +} + +// OperateUserRole provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) OperateUserRole(ctx context.Context, in *milvuspb.OperateUserRoleRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.OperateUserRoleRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.OperateUserRoleRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.OperateUserRoleRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_OperateUserRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OperateUserRole' +type MockRootCoordClient_OperateUserRole_Call struct { + *mock.Call +} + +// OperateUserRole is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.OperateUserRoleRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) OperateUserRole(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_OperateUserRole_Call { + return &MockRootCoordClient_OperateUserRole_Call{Call: _e.mock.On("OperateUserRole", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_OperateUserRole_Call) Run(run func(ctx context.Context, in *milvuspb.OperateUserRoleRequest, opts ...grpc.CallOption)) *MockRootCoordClient_OperateUserRole_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.OperateUserRoleRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_OperateUserRole_Call) Return(_a0 *commonpb.Status, _a1 error) *MockRootCoordClient_OperateUserRole_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_OperateUserRole_Call) RunAndReturn(run func(context.Context, *milvuspb.OperateUserRoleRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockRootCoordClient_OperateUserRole_Call { + _c.Call.Return(run) + return _c +} + +// RenameCollection provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) RenameCollection(ctx context.Context, in *milvuspb.RenameCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.RenameCollectionRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.RenameCollectionRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.RenameCollectionRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_RenameCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RenameCollection' +type MockRootCoordClient_RenameCollection_Call struct { + *mock.Call +} + +// RenameCollection is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.RenameCollectionRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) RenameCollection(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_RenameCollection_Call { + return &MockRootCoordClient_RenameCollection_Call{Call: _e.mock.On("RenameCollection", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_RenameCollection_Call) Run(run func(ctx context.Context, in *milvuspb.RenameCollectionRequest, opts ...grpc.CallOption)) *MockRootCoordClient_RenameCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.RenameCollectionRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_RenameCollection_Call) Return(_a0 *commonpb.Status, _a1 error) *MockRootCoordClient_RenameCollection_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_RenameCollection_Call) RunAndReturn(run func(context.Context, *milvuspb.RenameCollectionRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockRootCoordClient_RenameCollection_Call { + _c.Call.Return(run) + return _c +} + +// ReportImport provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) ReportImport(ctx context.Context, in *rootcoordpb.ImportResult, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.ImportResult, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.ImportResult, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *rootcoordpb.ImportResult, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_ReportImport_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReportImport' +type MockRootCoordClient_ReportImport_Call struct { + *mock.Call +} + +// ReportImport is a helper method to define mock.On call +// - ctx context.Context +// - in *rootcoordpb.ImportResult +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) ReportImport(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_ReportImport_Call { + return &MockRootCoordClient_ReportImport_Call{Call: _e.mock.On("ReportImport", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_ReportImport_Call) Run(run func(ctx context.Context, in *rootcoordpb.ImportResult, opts ...grpc.CallOption)) *MockRootCoordClient_ReportImport_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*rootcoordpb.ImportResult), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_ReportImport_Call) Return(_a0 *commonpb.Status, _a1 error) *MockRootCoordClient_ReportImport_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_ReportImport_Call) RunAndReturn(run func(context.Context, *rootcoordpb.ImportResult, ...grpc.CallOption) (*commonpb.Status, error)) *MockRootCoordClient_ReportImport_Call { + _c.Call.Return(run) + return _c +} + +// SelectGrant provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) SelectGrant(ctx context.Context, in *milvuspb.SelectGrantRequest, opts ...grpc.CallOption) (*milvuspb.SelectGrantResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.SelectGrantResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.SelectGrantRequest, ...grpc.CallOption) (*milvuspb.SelectGrantResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.SelectGrantRequest, ...grpc.CallOption) *milvuspb.SelectGrantResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.SelectGrantResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.SelectGrantRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_SelectGrant_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SelectGrant' +type MockRootCoordClient_SelectGrant_Call struct { + *mock.Call +} + +// SelectGrant is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.SelectGrantRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) SelectGrant(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_SelectGrant_Call { + return &MockRootCoordClient_SelectGrant_Call{Call: _e.mock.On("SelectGrant", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_SelectGrant_Call) Run(run func(ctx context.Context, in *milvuspb.SelectGrantRequest, opts ...grpc.CallOption)) *MockRootCoordClient_SelectGrant_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.SelectGrantRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_SelectGrant_Call) Return(_a0 *milvuspb.SelectGrantResponse, _a1 error) *MockRootCoordClient_SelectGrant_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_SelectGrant_Call) RunAndReturn(run func(context.Context, *milvuspb.SelectGrantRequest, ...grpc.CallOption) (*milvuspb.SelectGrantResponse, error)) *MockRootCoordClient_SelectGrant_Call { + _c.Call.Return(run) + return _c +} + +// SelectRole provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) SelectRole(ctx context.Context, in *milvuspb.SelectRoleRequest, opts ...grpc.CallOption) (*milvuspb.SelectRoleResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.SelectRoleResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.SelectRoleRequest, ...grpc.CallOption) (*milvuspb.SelectRoleResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.SelectRoleRequest, ...grpc.CallOption) *milvuspb.SelectRoleResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.SelectRoleResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.SelectRoleRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_SelectRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SelectRole' +type MockRootCoordClient_SelectRole_Call struct { + *mock.Call +} + +// SelectRole is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.SelectRoleRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) SelectRole(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_SelectRole_Call { + return &MockRootCoordClient_SelectRole_Call{Call: _e.mock.On("SelectRole", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_SelectRole_Call) Run(run func(ctx context.Context, in *milvuspb.SelectRoleRequest, opts ...grpc.CallOption)) *MockRootCoordClient_SelectRole_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.SelectRoleRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_SelectRole_Call) Return(_a0 *milvuspb.SelectRoleResponse, _a1 error) *MockRootCoordClient_SelectRole_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_SelectRole_Call) RunAndReturn(run func(context.Context, *milvuspb.SelectRoleRequest, ...grpc.CallOption) (*milvuspb.SelectRoleResponse, error)) *MockRootCoordClient_SelectRole_Call { + _c.Call.Return(run) + return _c +} + +// SelectUser provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) SelectUser(ctx context.Context, in *milvuspb.SelectUserRequest, opts ...grpc.CallOption) (*milvuspb.SelectUserResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.SelectUserResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.SelectUserRequest, ...grpc.CallOption) (*milvuspb.SelectUserResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.SelectUserRequest, ...grpc.CallOption) *milvuspb.SelectUserResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.SelectUserResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.SelectUserRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_SelectUser_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SelectUser' +type MockRootCoordClient_SelectUser_Call struct { + *mock.Call +} + +// SelectUser is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.SelectUserRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) SelectUser(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_SelectUser_Call { + return &MockRootCoordClient_SelectUser_Call{Call: _e.mock.On("SelectUser", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_SelectUser_Call) Run(run func(ctx context.Context, in *milvuspb.SelectUserRequest, opts ...grpc.CallOption)) *MockRootCoordClient_SelectUser_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.SelectUserRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_SelectUser_Call) Return(_a0 *milvuspb.SelectUserResponse, _a1 error) *MockRootCoordClient_SelectUser_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_SelectUser_Call) RunAndReturn(run func(context.Context, *milvuspb.SelectUserRequest, ...grpc.CallOption) (*milvuspb.SelectUserResponse, error)) *MockRootCoordClient_SelectUser_Call { + _c.Call.Return(run) + return _c +} + +// ShowCollections provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) ShowCollections(ctx context.Context, in *milvuspb.ShowCollectionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowCollectionsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.ShowCollectionsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ShowCollectionsRequest, ...grpc.CallOption) (*milvuspb.ShowCollectionsResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ShowCollectionsRequest, ...grpc.CallOption) *milvuspb.ShowCollectionsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.ShowCollectionsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ShowCollectionsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_ShowCollections_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ShowCollections' +type MockRootCoordClient_ShowCollections_Call struct { + *mock.Call +} + +// ShowCollections is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.ShowCollectionsRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) ShowCollections(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_ShowCollections_Call { + return &MockRootCoordClient_ShowCollections_Call{Call: _e.mock.On("ShowCollections", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_ShowCollections_Call) Run(run func(ctx context.Context, in *milvuspb.ShowCollectionsRequest, opts ...grpc.CallOption)) *MockRootCoordClient_ShowCollections_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.ShowCollectionsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_ShowCollections_Call) Return(_a0 *milvuspb.ShowCollectionsResponse, _a1 error) *MockRootCoordClient_ShowCollections_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_ShowCollections_Call) RunAndReturn(run func(context.Context, *milvuspb.ShowCollectionsRequest, ...grpc.CallOption) (*milvuspb.ShowCollectionsResponse, error)) *MockRootCoordClient_ShowCollections_Call { + _c.Call.Return(run) + return _c +} + +// ShowConfigurations provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) ShowConfigurations(ctx context.Context, in *internalpb.ShowConfigurationsRequest, opts ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *internalpb.ShowConfigurationsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ShowConfigurationsRequest, ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ShowConfigurationsRequest, ...grpc.CallOption) *internalpb.ShowConfigurationsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*internalpb.ShowConfigurationsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.ShowConfigurationsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_ShowConfigurations_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ShowConfigurations' +type MockRootCoordClient_ShowConfigurations_Call struct { + *mock.Call +} + +// ShowConfigurations is a helper method to define mock.On call +// - ctx context.Context +// - in *internalpb.ShowConfigurationsRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) ShowConfigurations(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_ShowConfigurations_Call { + return &MockRootCoordClient_ShowConfigurations_Call{Call: _e.mock.On("ShowConfigurations", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_ShowConfigurations_Call) Run(run func(ctx context.Context, in *internalpb.ShowConfigurationsRequest, opts ...grpc.CallOption)) *MockRootCoordClient_ShowConfigurations_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*internalpb.ShowConfigurationsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_ShowConfigurations_Call) Return(_a0 *internalpb.ShowConfigurationsResponse, _a1 error) *MockRootCoordClient_ShowConfigurations_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_ShowConfigurations_Call) RunAndReturn(run func(context.Context, *internalpb.ShowConfigurationsRequest, ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error)) *MockRootCoordClient_ShowConfigurations_Call { + _c.Call.Return(run) + return _c +} + +// ShowPartitions provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) ShowPartitions(ctx context.Context, in *milvuspb.ShowPartitionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowPartitionsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.ShowPartitionsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ShowPartitionsRequest, ...grpc.CallOption) (*milvuspb.ShowPartitionsResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ShowPartitionsRequest, ...grpc.CallOption) *milvuspb.ShowPartitionsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.ShowPartitionsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ShowPartitionsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_ShowPartitions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ShowPartitions' +type MockRootCoordClient_ShowPartitions_Call struct { + *mock.Call +} + +// ShowPartitions is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.ShowPartitionsRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) ShowPartitions(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_ShowPartitions_Call { + return &MockRootCoordClient_ShowPartitions_Call{Call: _e.mock.On("ShowPartitions", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_ShowPartitions_Call) Run(run func(ctx context.Context, in *milvuspb.ShowPartitionsRequest, opts ...grpc.CallOption)) *MockRootCoordClient_ShowPartitions_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.ShowPartitionsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_ShowPartitions_Call) Return(_a0 *milvuspb.ShowPartitionsResponse, _a1 error) *MockRootCoordClient_ShowPartitions_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_ShowPartitions_Call) RunAndReturn(run func(context.Context, *milvuspb.ShowPartitionsRequest, ...grpc.CallOption) (*milvuspb.ShowPartitionsResponse, error)) *MockRootCoordClient_ShowPartitions_Call { + _c.Call.Return(run) + return _c +} + +// ShowPartitionsInternal provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) ShowPartitionsInternal(ctx context.Context, in *milvuspb.ShowPartitionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowPartitionsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.ShowPartitionsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ShowPartitionsRequest, ...grpc.CallOption) (*milvuspb.ShowPartitionsResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ShowPartitionsRequest, ...grpc.CallOption) *milvuspb.ShowPartitionsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.ShowPartitionsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ShowPartitionsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_ShowPartitionsInternal_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ShowPartitionsInternal' +type MockRootCoordClient_ShowPartitionsInternal_Call struct { + *mock.Call +} + +// ShowPartitionsInternal is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.ShowPartitionsRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) ShowPartitionsInternal(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_ShowPartitionsInternal_Call { + return &MockRootCoordClient_ShowPartitionsInternal_Call{Call: _e.mock.On("ShowPartitionsInternal", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_ShowPartitionsInternal_Call) Run(run func(ctx context.Context, in *milvuspb.ShowPartitionsRequest, opts ...grpc.CallOption)) *MockRootCoordClient_ShowPartitionsInternal_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.ShowPartitionsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_ShowPartitionsInternal_Call) Return(_a0 *milvuspb.ShowPartitionsResponse, _a1 error) *MockRootCoordClient_ShowPartitionsInternal_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_ShowPartitionsInternal_Call) RunAndReturn(run func(context.Context, *milvuspb.ShowPartitionsRequest, ...grpc.CallOption) (*milvuspb.ShowPartitionsResponse, error)) *MockRootCoordClient_ShowPartitionsInternal_Call { + _c.Call.Return(run) + return _c +} + +// ShowSegments provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) ShowSegments(ctx context.Context, in *milvuspb.ShowSegmentsRequest, opts ...grpc.CallOption) (*milvuspb.ShowSegmentsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *milvuspb.ShowSegmentsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ShowSegmentsRequest, ...grpc.CallOption) (*milvuspb.ShowSegmentsResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.ShowSegmentsRequest, ...grpc.CallOption) *milvuspb.ShowSegmentsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.ShowSegmentsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.ShowSegmentsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_ShowSegments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ShowSegments' +type MockRootCoordClient_ShowSegments_Call struct { + *mock.Call +} + +// ShowSegments is a helper method to define mock.On call +// - ctx context.Context +// - in *milvuspb.ShowSegmentsRequest +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) ShowSegments(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_ShowSegments_Call { + return &MockRootCoordClient_ShowSegments_Call{Call: _e.mock.On("ShowSegments", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_ShowSegments_Call) Run(run func(ctx context.Context, in *milvuspb.ShowSegmentsRequest, opts ...grpc.CallOption)) *MockRootCoordClient_ShowSegments_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*milvuspb.ShowSegmentsRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_ShowSegments_Call) Return(_a0 *milvuspb.ShowSegmentsResponse, _a1 error) *MockRootCoordClient_ShowSegments_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_ShowSegments_Call) RunAndReturn(run func(context.Context, *milvuspb.ShowSegmentsRequest, ...grpc.CallOption) (*milvuspb.ShowSegmentsResponse, error)) *MockRootCoordClient_ShowSegments_Call { + _c.Call.Return(run) + return _c +} + +// UpdateChannelTimeTick provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) UpdateChannelTimeTick(ctx context.Context, in *internalpb.ChannelTimeTickMsg, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ChannelTimeTickMsg, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.ChannelTimeTickMsg, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.ChannelTimeTickMsg, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_UpdateChannelTimeTick_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateChannelTimeTick' +type MockRootCoordClient_UpdateChannelTimeTick_Call struct { + *mock.Call +} + +// UpdateChannelTimeTick is a helper method to define mock.On call +// - ctx context.Context +// - in *internalpb.ChannelTimeTickMsg +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) UpdateChannelTimeTick(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_UpdateChannelTimeTick_Call { + return &MockRootCoordClient_UpdateChannelTimeTick_Call{Call: _e.mock.On("UpdateChannelTimeTick", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_UpdateChannelTimeTick_Call) Run(run func(ctx context.Context, in *internalpb.ChannelTimeTickMsg, opts ...grpc.CallOption)) *MockRootCoordClient_UpdateChannelTimeTick_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*internalpb.ChannelTimeTickMsg), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockRootCoordClient_UpdateChannelTimeTick_Call) Return(_a0 *commonpb.Status, _a1 error) *MockRootCoordClient_UpdateChannelTimeTick_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_UpdateChannelTimeTick_Call) RunAndReturn(run func(context.Context, *internalpb.ChannelTimeTickMsg, ...grpc.CallOption) (*commonpb.Status, error)) *MockRootCoordClient_UpdateChannelTimeTick_Call { + _c.Call.Return(run) + return _c +} + +// UpdateCredential provides a mock function with given fields: ctx, in, opts +func (_m *MockRootCoordClient) UpdateCredential(ctx context.Context, in *internalpb.CredentialInfo, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.CredentialInfo, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *internalpb.CredentialInfo, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *internalpb.CredentialInfo, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRootCoordClient_UpdateCredential_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateCredential' +type MockRootCoordClient_UpdateCredential_Call struct { + *mock.Call +} + +// UpdateCredential is a helper method to define mock.On call +// - ctx context.Context +// - in *internalpb.CredentialInfo +// - opts ...grpc.CallOption +func (_e *MockRootCoordClient_Expecter) UpdateCredential(ctx interface{}, in interface{}, opts ...interface{}) *MockRootCoordClient_UpdateCredential_Call { + return &MockRootCoordClient_UpdateCredential_Call{Call: _e.mock.On("UpdateCredential", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockRootCoordClient_UpdateCredential_Call) Run(run func(ctx context.Context, in *internalpb.CredentialInfo, opts ...grpc.CallOption)) *MockRootCoordClient_UpdateCredential_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*internalpb.CredentialInfo), variadicArgs...) + }) + return _c +} + +func (_c *MockRootCoordClient_UpdateCredential_Call) Return(_a0 *commonpb.Status, _a1 error) *MockRootCoordClient_UpdateCredential_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRootCoordClient_UpdateCredential_Call) RunAndReturn(run func(context.Context, *internalpb.CredentialInfo, ...grpc.CallOption) (*commonpb.Status, error)) *MockRootCoordClient_UpdateCredential_Call { + _c.Call.Return(run) + return _c +} + +// NewMockRootCoordClient creates a new instance of MockRootCoordClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. 
+func NewMockRootCoordClient(t interface { + mock.TestingT + Cleanup(func()) +}) *MockRootCoordClient { + mock := &MockRootCoordClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/mq/mqimpl/rocksmq/client/client.go b/internal/mq/mqimpl/rocksmq/client/client.go index cc25d8bd09c2e..8bc6aab90d4cc 100644 --- a/internal/mq/mqimpl/rocksmq/client/client.go +++ b/internal/mq/mqimpl/rocksmq/client/client.go @@ -11,9 +11,7 @@ package client -import ( - "github.com/milvus-io/milvus/internal/mq/mqimpl/rocksmq/server" -) +import "github.com/milvus-io/milvus/internal/mq/mqimpl/rocksmq/server" // RocksMQ is the type server.RocksMQ type RocksMQ = server.RocksMQ diff --git a/internal/mq/mqimpl/rocksmq/client/client_impl.go b/internal/mq/mqimpl/rocksmq/client/client_impl.go index b32fc7943523b..8680f77e97194 100644 --- a/internal/mq/mqimpl/rocksmq/client/client_impl.go +++ b/internal/mq/mqimpl/rocksmq/client/client_impl.go @@ -32,7 +32,6 @@ type client struct { } func newClient(options Options) (*client, error) { - if options.Server == nil { return nil, newError(InvalidConfiguration, "options.Server is nil") } @@ -50,7 +49,6 @@ func newClient(options Options) (*client, error) { func (c *client) CreateProducer(options ProducerOptions) (Producer, error) { // Create a producer producer, err := newProducer(c, options) - if err != nil { return nil, err } @@ -166,12 +164,19 @@ func (c *client) deliver(consumer *consumer) { break } for _, msg := range msgs { + // This is the hack, we put property into pl + properties := make(map[string]string, 0) + pl, err := UnmarshalHeader(msg.Payload) + if err == nil && pl != nil && pl.Base != nil { + properties = pl.Base.Properties + } select { - case consumer.messageCh <- Message{ - MsgID: msg.MsgID, - Payload: msg.Payload, - Properties: msg.Properties, - Topic: consumer.Topic()}: + case consumer.messageCh <- &RmqMessage{ + msgID: msg.MsgID, + payload: msg.Payload, + properties: 
properties, + topic: consumer.Topic(), + }: case <-c.closeCh: return } diff --git a/internal/mq/mqimpl/rocksmq/client/client_impl_test.go b/internal/mq/mqimpl/rocksmq/client/client_impl_test.go index 0a0f2e03d4fd8..19c0a1bab6201 100644 --- a/internal/mq/mqimpl/rocksmq/client/client_impl_test.go +++ b/internal/mq/mqimpl/rocksmq/client/client_impl_test.go @@ -17,10 +17,14 @@ import ( "testing" "time" + "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/internal/mq/mqimpl/rocksmq/server" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -187,7 +191,7 @@ func TestClient_SeekLatest(t *testing.T) { }) assert.NotNil(t, producer) assert.NoError(t, err) - msg := &ProducerMessage{ + msg := &mqwrapper.ProducerMessage{ Payload: make([]byte, 10), Properties: map[string]string{}, } @@ -197,7 +201,7 @@ func TestClient_SeekLatest(t *testing.T) { msgChan := consumer1.Chan() msgRead, ok := <-msgChan assert.Equal(t, ok, true) - assert.Equal(t, msgRead.MsgID, id) + assert.Equal(t, msgRead.ID(), &server.RmqID{MessageID: id}) consumer1.Close() @@ -217,10 +221,10 @@ func TestClient_SeekLatest(t *testing.T) { for loop { select { case msg := <-msgChan: - assert.Equal(t, len(msg.Payload), 8) + assert.Equal(t, len(msg.Payload()), 8) loop = false case <-ticker.C: - msg := &ProducerMessage{ + msg := &mqwrapper.ProducerMessage{ Payload: make([]byte, 8), } _, err = producer.Send(msg) @@ -261,7 +265,7 @@ func TestClient_consume(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, consumer) - msg := &ProducerMessage{ + msg := &mqwrapper.ProducerMessage{ Payload: make([]byte, 10), } id, err := producer.Send(msg) @@ -270,5 +274,80 @@ func TestClient_consume(t *testing.T) { msgChan := consumer.Chan() 
msgConsume, ok := <-msgChan assert.Equal(t, ok, true) - assert.Equal(t, id, msgConsume.MsgID) + assert.Equal(t, &server.RmqID{MessageID: id}, msgConsume.ID()) +} + +func TestRocksmq_Properties(t *testing.T) { + os.MkdirAll(rmqPath, os.ModePerm) + rmqPathTest := rmqPath + "/test_client4" + rmq := newRocksMQ(t, rmqPathTest) + defer removePath(rmqPath) + client, err := NewClient(Options{ + Server: rmq, + }) + assert.NoError(t, err) + defer client.Close() + topicName := newTopicName() + producer, err := client.CreateProducer(ProducerOptions{ + Topic: topicName, + }) + assert.NotNil(t, producer) + assert.NoError(t, err) + + opt := ConsumerOptions{ + Topic: topicName, + SubscriptionName: newConsumerName(), + SubscriptionInitialPosition: mqwrapper.SubscriptionPositionEarliest, + } + consumer, err := client.Subscribe(opt) + assert.NoError(t, err) + assert.NotNil(t, consumer) + + timeTickMsg := &msgpb.TimeTickMsg{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_TimeTick, + MsgID: UniqueID(0), + Timestamp: 100, + SourceID: 0, + }, + } + msgb, errMarshal := proto.Marshal(timeTickMsg) + assert.NoError(t, errMarshal) + assert.True(t, len(msgb) > 0) + header, err := UnmarshalHeader(msgb) + assert.NoError(t, err) + assert.NotNil(t, header) + msg := &mqwrapper.ProducerMessage{ + Payload: msgb, + Properties: map[string]string{common.TraceIDKey: "a"}, + } + + _, err = producer.Send(msg) + assert.NoError(t, err) + + msg = &mqwrapper.ProducerMessage{ + Payload: msgb, + Properties: map[string]string{common.TraceIDKey: "b"}, + } + _, err = producer.Send(msg) + assert.NoError(t, err) + + msgChan := consumer.Chan() + msgConsume, ok := <-msgChan + assert.True(t, ok) + assert.Equal(t, len(msgConsume.Properties()), 1) + assert.Equal(t, msgConsume.Properties()[common.TraceIDKey], "a") + assert.NoError(t, err) + + msgConsume, ok = <-msgChan + assert.True(t, ok) + assert.Equal(t, len(msgConsume.Properties()), 1) + assert.Equal(t, msgConsume.Properties()[common.TraceIDKey], "b") + 
assert.NoError(t, err) + + timeTickMsg2 := &msgpb.TimeTickMsg{} + proto.Unmarshal(msgConsume.Payload(), timeTickMsg2) + + assert.Equal(t, timeTickMsg2.Base.MsgType, commonpb.MsgType_TimeTick) + assert.Equal(t, timeTickMsg2.Base.Timestamp, uint64(100)) } diff --git a/internal/mq/mqimpl/rocksmq/client/consumer.go b/internal/mq/mqimpl/rocksmq/client/consumer.go index b555b78bc45af..6790f5a520de7 100644 --- a/internal/mq/mqimpl/rocksmq/client/consumer.go +++ b/internal/mq/mqimpl/rocksmq/client/consumer.go @@ -38,16 +38,7 @@ type ConsumerOptions struct { // Message for this consumer // When a message is received, it will be pushed to this channel for consumption - MessageChannel chan Message -} - -// Message is the message content of a consumer message -type Message struct { - Consumer - MsgID UniqueID - Topic string - Payload []byte - Properties map[string]string + MessageChannel chan mqwrapper.Message } // Consumer interface provide operations for a consumer @@ -62,7 +53,7 @@ type Consumer interface { MsgMutex() chan struct{} // Message channel - Chan() <-chan Message + Chan() <-chan mqwrapper.Message // Seek to the uniqueID position Seek(UniqueID) error //nolint:govet diff --git a/internal/mq/mqimpl/rocksmq/client/consumer_impl.go b/internal/mq/mqimpl/rocksmq/client/consumer_impl.go index 75e220bf6ed4c..1f95087ef3e48 100644 --- a/internal/mq/mqimpl/rocksmq/client/consumer_impl.go +++ b/internal/mq/mqimpl/rocksmq/client/consumer_impl.go @@ -17,6 +17,7 @@ import ( "go.uber.org/zap" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" ) type consumer struct { @@ -29,7 +30,7 @@ type consumer struct { msgMutex chan struct{} initCh chan struct{} - messageCh chan Message + messageCh chan mqwrapper.Message } func newConsumer(c *client, options ConsumerOptions) (*consumer, error) { @@ -47,7 +48,7 @@ func newConsumer(c *client, options ConsumerOptions) (*consumer, error) { messageCh := options.MessageChannel if options.MessageChannel 
== nil { - messageCh = make(chan Message, 1) + messageCh = make(chan mqwrapper.Message, 1) } // only used for initCh := make(chan struct{}, 1) @@ -79,7 +80,7 @@ func getExistedConsumer(c *client, options ConsumerOptions, msgMutex chan struct messageCh := options.MessageChannel if options.MessageChannel == nil { - messageCh = make(chan Message, 1) + messageCh = make(chan mqwrapper.Message, 1) } return &consumer{ @@ -108,7 +109,7 @@ func (c *consumer) MsgMutex() chan struct{} { } // Chan start consume goroutine and return message channel -func (c *consumer) Chan() <-chan Message { +func (c *consumer) Chan() <-chan mqwrapper.Message { c.startOnce.Do(func() { c.client.wg.Add(1) go c.client.consume(c) diff --git a/internal/mq/mqimpl/rocksmq/client/consumer_impl_test.go b/internal/mq/mqimpl/rocksmq/client/consumer_impl_test.go index 9afd9d39ef24a..feeb689139b9e 100644 --- a/internal/mq/mqimpl/rocksmq/client/consumer_impl_test.go +++ b/internal/mq/mqimpl/rocksmq/client/consumer_impl_test.go @@ -117,7 +117,7 @@ func TestConsumer_Subscription(t *testing.T) { }) assert.Nil(t, consumer) assert.Error(t, err) - //assert.Equal(t, consumerName, consumer.Subscription()) + // assert.Equal(t, consumerName, consumer.Subscription()) } func TestConsumer_Seek(t *testing.T) { diff --git a/internal/mq/mqimpl/rocksmq/client/producer.go b/internal/mq/mqimpl/rocksmq/client/producer.go index 9e1d22074b3db..6a9f74f8e6cfb 100644 --- a/internal/mq/mqimpl/rocksmq/client/producer.go +++ b/internal/mq/mqimpl/rocksmq/client/producer.go @@ -11,24 +11,20 @@ package client +import "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" + // ProducerOptions is the options of a producer type ProducerOptions struct { Topic string } -// ProducerMessage is the message of a producer -type ProducerMessage struct { - Payload []byte - Properties map[string]string -} - // Producer provedes some operations for a producer type Producer interface { // return the topic which producer is publishing to Topic() string 
// publish a message - Send(message *ProducerMessage) (UniqueID, error) + Send(message *mqwrapper.ProducerMessage) (UniqueID, error) // Close a producer Close() diff --git a/internal/mq/mqimpl/rocksmq/client/producer_impl.go b/internal/mq/mqimpl/rocksmq/client/producer_impl.go index b3b5de92c93b6..d401f7ed2ad9b 100644 --- a/internal/mq/mqimpl/rocksmq/client/producer_impl.go +++ b/internal/mq/mqimpl/rocksmq/client/producer_impl.go @@ -16,6 +16,7 @@ import ( "github.com/milvus-io/milvus/internal/mq/mqimpl/rocksmq/server" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" ) // assertion make sure implementation @@ -50,11 +51,23 @@ func (p *producer) Topic() string { } // Send produce message in rocksmq -func (p *producer) Send(message *ProducerMessage) (UniqueID, error) { +func (p *producer) Send(message *mqwrapper.ProducerMessage) (UniqueID, error) { + // NOTICE: this is the hack. + // we should not unmarshal the payload here but we can not extend the payload byte + payload := message.Payload + header, err := UnmarshalHeader(message.Payload) + if err == nil && header != nil && header.Base != nil { + // try to marshal properties into message if message is real message + header.Base.Properties = message.Properties + payload, err = MarshalHeader(header) + if err != nil { + return 0, err + } + } + ids, err := p.c.server.Produce(p.topic, []server.ProducerMessage{ { - Payload: message.Payload, - Properties: message.Properties, + Payload: payload, }, }) if err != nil { diff --git a/internal/mq/mqimpl/rocksmq/client/producer_impl_test.go b/internal/mq/mqimpl/rocksmq/client/producer_impl_test.go index d04aa02359822..f372ab1b979ea 100644 --- a/internal/mq/mqimpl/rocksmq/client/producer_impl_test.go +++ b/internal/mq/mqimpl/rocksmq/client/producer_impl_test.go @@ -40,5 +40,5 @@ func TestProducerTopic(t *testing.T) { }) assert.Nil(t, producer) assert.Error(t, err) - //assert.Equal(t, topicName, producer.Topic()) + // assert.Equal(t, 
topicName, producer.Topic()) } diff --git a/internal/mq/msgstream/mqwrapper/rmq/rmq_message.go b/internal/mq/mqimpl/rocksmq/client/rmq_message.go similarity index 62% rename from internal/mq/msgstream/mqwrapper/rmq/rmq_message.go rename to internal/mq/mqimpl/rocksmq/client/rmq_message.go index f9d5e3f63a637..7133f392344e5 100644 --- a/internal/mq/msgstream/mqwrapper/rmq/rmq_message.go +++ b/internal/mq/mqimpl/rocksmq/client/rmq_message.go @@ -9,37 +9,41 @@ // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express // or implied. See the License for the specific language governing permissions and limitations under the License. -package rmq +package client import ( - "github.com/milvus-io/milvus/internal/mq/mqimpl/rocksmq/client" + "github.com/milvus-io/milvus/internal/mq/mqimpl/rocksmq/server" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) // Check rmqMessage implements ConsumerMessage -var _ mqwrapper.Message = (*rmqMessage)(nil) +var _ mqwrapper.Message = (*RmqMessage)(nil) // rmqMessage wraps the message for rocksmq -type rmqMessage struct { - msg client.Message +type RmqMessage struct { + msgID typeutil.UniqueID + topic string + payload []byte + properties map[string]string } // Topic returns the topic name of rocksmq message -func (rm *rmqMessage) Topic() string { - return rm.msg.Topic +func (rm *RmqMessage) Topic() string { + return rm.topic } // Properties returns the properties of rocksmq message -func (rm *rmqMessage) Properties() map[string]string { - return rm.msg.Properties +func (rm *RmqMessage) Properties() map[string]string { + return rm.properties } // Payload returns the payload of rocksmq message -func (rm *rmqMessage) Payload() []byte { - return rm.msg.Payload +func (rm *RmqMessage) Payload() []byte { + return rm.payload } // ID returns the id of rocksmq message -func (rm *rmqMessage) ID() mqwrapper.MessageID { - return &rmqID{messageID: 
rm.msg.MsgID} +func (rm *RmqMessage) ID() mqwrapper.MessageID { + return &server.RmqID{MessageID: rm.msgID} } diff --git a/internal/mq/mqimpl/rocksmq/client/test_helper.go b/internal/mq/mqimpl/rocksmq/client/test_helper.go index 6fee40a46dfeb..d99ade29e8acf 100644 --- a/internal/mq/mqimpl/rocksmq/client/test_helper.go +++ b/internal/mq/mqimpl/rocksmq/client/test_helper.go @@ -22,6 +22,7 @@ import ( server2 "github.com/milvus-io/milvus/internal/mq/mqimpl/rocksmq/server" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) func newTopicName() string { @@ -46,6 +47,7 @@ func newMockClient() *client { func newRocksMQ(t *testing.T, rmqPath string) server2.RocksMQ { rocksdbPath := rmqPath + paramtable.Get().Save("rocksmq.compressionTypes", "0,0,0,0,0") rmq, err := server2.NewRocksMQ(rocksdbPath, nil) assert.NoError(t, err) return rmq diff --git a/internal/mq/mqimpl/rocksmq/client/util.go b/internal/mq/mqimpl/rocksmq/client/util.go new file mode 100644 index 0000000000000..bdefdb666d4f4 --- /dev/null +++ b/internal/mq/mqimpl/rocksmq/client/util.go @@ -0,0 +1,43 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. 
+ +package client + +import ( + "fmt" + + "github.com/golang/protobuf/proto" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" +) + +func MarshalHeader(header *commonpb.MsgHeader) ([]byte, error) { + hb, err := proto.Marshal(header) + if err != nil { + return nil, err + } + return hb, nil +} + +func UnmarshalHeader(headerbyte []byte) (*commonpb.MsgHeader, error) { + header := commonpb.MsgHeader{} + if headerbyte == nil { + return &header, fmt.Errorf("failed to unmarshal message header, payload is empty") + } + err := proto.Unmarshal(headerbyte, &header) + if err != nil { + return &header, err + } + if header.Base == nil { + return nil, fmt.Errorf("failed to unmarshal message, header is uncomplete") + } + return &header, nil +} diff --git a/internal/mq/msgstream/mqwrapper/rmq/rmq_id.go b/internal/mq/mqimpl/rocksmq/server/rmq_id.go similarity index 74% rename from internal/mq/msgstream/mqwrapper/rmq/rmq_id.go rename to internal/mq/mqimpl/rocksmq/server/rmq_id.go index cf7ed44daf467..8e252e3346196 100644 --- a/internal/mq/msgstream/mqwrapper/rmq/rmq_id.go +++ b/internal/mq/mqimpl/rocksmq/server/rmq_id.go @@ -14,39 +14,38 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-package rmq +package server import ( - "github.com/milvus-io/milvus/internal/mq/mqimpl/rocksmq/server" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" ) // rmqID wraps message ID for rocksmq -type rmqID struct { - messageID server.UniqueID +type RmqID struct { + MessageID UniqueID } // Check if rmqID implements MessageID interface -var _ mqwrapper.MessageID = &rmqID{} +var _ mqwrapper.MessageID = &RmqID{} // Serialize convert rmq message id to []byte -func (rid *rmqID) Serialize() []byte { - return SerializeRmqID(rid.messageID) +func (rid *RmqID) Serialize() []byte { + return SerializeRmqID(rid.MessageID) } -func (rid *rmqID) AtEarliestPosition() bool { - return rid.messageID <= 0 +func (rid *RmqID) AtEarliestPosition() bool { + return rid.MessageID <= 0 } -func (rid *rmqID) LessOrEqualThan(msgID []byte) (bool, error) { +func (rid *RmqID) LessOrEqualThan(msgID []byte) (bool, error) { rMsgID := DeserializeRmqID(msgID) - return rid.messageID <= rMsgID, nil + return rid.MessageID <= rMsgID, nil } -func (rid *rmqID) Equal(msgID []byte) (bool, error) { +func (rid *RmqID) Equal(msgID []byte) (bool, error) { rMsgID := DeserializeRmqID(msgID) - return rid.messageID == rMsgID, nil + return rid.MessageID == rMsgID, nil } // SerializeRmqID is used to serialize a message ID to byte array diff --git a/internal/mq/msgstream/mqwrapper/rmq/rmq_id_test.go b/internal/mq/mqimpl/rocksmq/server/rmq_id_test.go similarity index 87% rename from internal/mq/msgstream/mqwrapper/rmq/rmq_id_test.go rename to internal/mq/mqimpl/rocksmq/server/rmq_id_test.go index e92db5a93f724..cb6edf1e64054 100644 --- a/internal/mq/msgstream/mqwrapper/rmq/rmq_id_test.go +++ b/internal/mq/mqimpl/rocksmq/server/rmq_id_test.go @@ -14,7 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-package rmq +package server import ( "math" @@ -24,8 +24,8 @@ import ( ) func TestRmqID_Serialize(t *testing.T) { - rid := &rmqID{ - messageID: 8, + rid := &RmqID{ + MessageID: 8, } bin := rid.Serialize() @@ -34,23 +34,23 @@ func TestRmqID_Serialize(t *testing.T) { } func Test_AtEarliestPosition(t *testing.T) { - rid := &rmqID{ - messageID: 0, + rid := &RmqID{ + MessageID: 0, } assert.True(t, rid.AtEarliestPosition()) - rid = &rmqID{ - messageID: math.MaxInt64, + rid = &RmqID{ + MessageID: math.MaxInt64, } assert.False(t, rid.AtEarliestPosition()) } func TestLessOrEqualThan(t *testing.T) { - rid1 := &rmqID{ - messageID: 0, + rid1 := &RmqID{ + MessageID: 0, } - rid2 := &rmqID{ - messageID: math.MaxInt64, + rid2 := &RmqID{ + MessageID: math.MaxInt64, } ret, err := rid1.LessOrEqualThan(rid2.Serialize()) @@ -67,19 +67,18 @@ func TestLessOrEqualThan(t *testing.T) { } func Test_Equal(t *testing.T) { - rid1 := &rmqID{ - messageID: 0, + rid1 := &RmqID{ + MessageID: 0, } - rid2 := &rmqID{ - messageID: math.MaxInt64, + rid2 := &RmqID{ + MessageID: math.MaxInt64, } { ret, err := rid1.Equal(rid1.Serialize()) assert.NoError(t, err) assert.True(t, ret) - } { diff --git a/internal/mq/mqimpl/rocksmq/server/rocksmq.go b/internal/mq/mqimpl/rocksmq/server/rocksmq.go index dee82f3147564..9b222834dda82 100644 --- a/internal/mq/mqimpl/rocksmq/server/rocksmq.go +++ b/internal/mq/mqimpl/rocksmq/server/rocksmq.go @@ -13,8 +13,7 @@ package server // ProducerMessage that will be written to rocksdb type ProducerMessage struct { - Payload []byte - Properties map[string]string + Payload []byte } // Consumer is rocksmq consumer @@ -26,9 +25,8 @@ type Consumer struct { // ConsumerMessage that consumed from rocksdb type ConsumerMessage struct { - MsgID UniqueID - Payload []byte - Properties map[string]string + MsgID UniqueID + Payload []byte } // RocksMQ is an interface thatmay be implemented by the application diff --git a/internal/mq/mqimpl/rocksmq/server/rocksmq_impl.go 
b/internal/mq/mqimpl/rocksmq/server/rocksmq_impl.go index 312aa534b217e..38c0bc8140024 100644 --- a/internal/mq/mqimpl/rocksmq/server/rocksmq_impl.go +++ b/internal/mq/mqimpl/rocksmq/server/rocksmq_impl.go @@ -12,7 +12,6 @@ package server import ( - "encoding/json" "fmt" "path" "runtime" @@ -31,7 +30,6 @@ import ( "github.com/milvus-io/milvus/internal/kv" rocksdbkv "github.com/milvus-io/milvus/internal/kv/rocksdb" "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" "github.com/milvus-io/milvus/pkg/util/hardware" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -623,7 +621,6 @@ func (rmq *rocksmq) Produce(topicName string, messages []ProducerMessage) ([]Uni msgLen := len(messages) idStart, idEnd, err := rmq.idAllocator.Alloc(uint32(msgLen)) - if err != nil { return []UniqueID{}, err } @@ -631,7 +628,6 @@ func (rmq *rocksmq) Produce(topicName string, messages []ProducerMessage) ([]Uni if UniqueID(msgLen) != idEnd-idStart { return []UniqueID{}, errors.New("Obtained id length is not equal that of message") } - // Insert data to store system batch := gorocksdb.NewWriteBatch() defer batch.Destroy() @@ -641,16 +637,6 @@ func (rmq *rocksmq) Produce(topicName string, messages []ProducerMessage) ([]Uni msgID := idStart + UniqueID(i) key := path.Join(topicName, strconv.FormatInt(msgID, 10)) batch.PutCF(rmq.cfh[0], []byte(key), messages[i].Payload) - // batch.Put([]byte(key), messages[i].Payload) - if messages[i].Properties != nil { - properties, err := json.Marshal(messages[i].Properties) - if err != nil { - log.Warn("properties marshal failed", zap.Int64("msgID", msgID), zap.String("topicName", topicName), - zap.Error(err)) - return nil, err - } - batch.PutCF(rmq.cfh[1], []byte(key), properties) - } msgIDs[i] = msgID msgSizes[msgID] = int64(len(messages[i].Payload)) } @@ -784,9 +770,7 @@ func (rmq *rocksmq) Consume(topicName string, groupName string, n int) ([]Consum defer 
readOpts.Destroy() prefix := topicName + "/" iter := rocksdbkv.NewRocksIteratorCFWithUpperBound(rmq.store, rmq.cfh[0], typeutil.AddOne(prefix), readOpts) - iterProperty := rocksdbkv.NewRocksIteratorCFWithUpperBound(rmq.store, rmq.cfh[1], typeutil.AddOne(prefix), readOpts) defer iter.Close() - defer iterProperty.Close() var dataKey string if currentID == DefaultMessageID { @@ -795,7 +779,6 @@ func (rmq *rocksmq) Consume(topicName string, groupName string, n int) ([]Consum dataKey = path.Join(topicName, strconv.FormatInt(currentID, 10)) } iter.Seek([]byte(dataKey)) - iterProperty.Seek([]byte(dataKey)) consumerMessage := make([]ConsumerMessage, 0, n) offset := 0 @@ -803,11 +786,9 @@ func (rmq *rocksmq) Consume(topicName string, groupName string, n int) ([]Consum for ; iter.Valid() && offset < n; iter.Next() { key := iter.Key() val := iter.Value() - strKey := string(key.Data()) key.Free() - properties := make(map[string]string) - var propertiesValue []byte + strKey := string(key.Data()) msgID, err := strconv.ParseInt(strKey[len(topicName)+1:], 10, 64) if err != nil { val.Free() @@ -815,23 +796,6 @@ func (rmq *rocksmq) Consume(topicName string, groupName string, n int) ([]Consum } offset++ - if iterProperty.Valid() && string(iterProperty.Key().Data()) == string(iter.Key().Data()) { - // the key of properties is the same with the key of payload - // to prevent mix message with or without property column family - propertiesValue = iterProperty.Value().Data() - iterProperty.Next() - } - - // between 2.2.0 and 2.3.0, the key of Payload is topic/properties/msgid/Payload - // will ingnore the property before 2.3.0, just make sure property empty is ok for 2.3 - - // before 2.2.0, there have no properties in ProducerMessage and ConsumerMessage in rocksmq - // when produce before 2.2.0, but consume after 2.2.0, propertiesValue will be [] - if len(propertiesValue) != 0 { - if err = json.Unmarshal(propertiesValue, &properties); err != nil { - return nil, err - } - } msg := 
ConsumerMessage{ MsgID: msgID, } @@ -839,10 +803,8 @@ func (rmq *rocksmq) Consume(topicName string, groupName string, n int) ([]Consum dataLen := len(origData) if dataLen == 0 { msg.Payload = nil - msg.Properties = nil } else { msg.Payload = make([]byte, dataLen) - msg.Properties = properties copy(msg.Payload, origData) } consumerMessage = append(consumerMessage, msg) @@ -902,7 +864,7 @@ func (rmq *rocksmq) seek(topicName string, groupName string, msgID UniqueID) err log.Warn("RocksMQ: trying to seek to no exist position, reset current id", zap.String("topic", topicName), zap.String("group", groupName), zap.Int64("msgId", msgID)) err := rmq.moveConsumePos(topicName, groupName, DefaultMessageID) - //skip seek if key is not found, this is the behavior as pulsar + // skip seek if key is not found, this is the behavior as pulsar return err } /* Step II: update current_id */ @@ -922,7 +884,7 @@ func (rmq *rocksmq) moveConsumePos(topicName string, groupName string, msgID Uni panic("move consume position backward") } - //update ack if position move forward + // update ack if position move forward err := rmq.updateAckedInfo(topicName, groupName, oldPos, msgID-1) if err != nil { log.Warn("failed to update acked info ", zap.String("topic", topicName), @@ -942,7 +904,7 @@ func (rmq *rocksmq) Seek(topicName string, groupName string, msgID UniqueID) err /* Step I: Check if key exists */ ll, ok := topicMu.Load(topicName) if !ok { - return fmt.Errorf("topic %s not exist, %w", topicName, mqwrapper.ErrTopicNotExist) + return merr.WrapErrMqTopicNotFound(topicName) } lock, ok := ll.(*sync.Mutex) if !ok { @@ -968,7 +930,7 @@ func (rmq *rocksmq) ForceSeek(topicName string, groupName string, msgID UniqueID /* Step I: Check if key exists */ ll, ok := topicMu.Load(topicName) if !ok { - return fmt.Errorf("topic %s not exist, %w", topicName, mqwrapper.ErrTopicNotExist) + return merr.WrapErrMqTopicNotFound(topicName) } lock, ok := ll.(*sync.Mutex) if !ok { @@ -1054,7 +1016,6 @@ func (rmq 
*rocksmq) getLatestMsg(topicName string) (int64, error) { } msgID, err := strconv.ParseInt(seekMsgID[len(topicName)+1:], 10, 64) - if err != nil { return DefaultMessageID, err } @@ -1154,22 +1115,14 @@ func (rmq *rocksmq) updateAckedInfo(topicName, groupName string, firstID UniqueI } func (rmq *rocksmq) CheckTopicValid(topic string) error { - // Check if key exists - log := log.With(zap.String("topic", topic)) - _, ok := topicMu.Load(topic) if !ok { - return merr.WrapErrTopicNotFound(topic, "failed to get topic") + return merr.WrapErrMqTopicNotFound(topic, "failed to get topic") } - latestMsgID, err := rmq.GetLatestMsg(topic) + _, err := rmq.GetLatestMsg(topic) if err != nil { return err } - - if latestMsgID != DefaultMessageID { - return merr.WrapErrTopicNotEmpty(topic, "topic is not empty") - } - log.Info("created topic is empty") return nil } diff --git a/internal/mq/mqimpl/rocksmq/server/rocksmq_impl_test.go b/internal/mq/mqimpl/rocksmq/server/rocksmq_impl_test.go index b089857d105a7..7dbfdd0778924 100644 --- a/internal/mq/mqimpl/rocksmq/server/rocksmq_impl_test.go +++ b/internal/mq/mqimpl/rocksmq/server/rocksmq_impl_test.go @@ -12,7 +12,6 @@ package server import ( - "encoding/json" "fmt" "os" "path" @@ -31,7 +30,6 @@ import ( "github.com/milvus-io/milvus/internal/allocator" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" rocksdbkv "github.com/milvus-io/milvus/internal/kv/rocksdb" - "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/merr" @@ -102,88 +100,6 @@ func (rmq *rocksmq) produceBefore2(topicName string, messages []producerMessageB msgLen := len(messages) idStart, idEnd, err := rmq.idAllocator.Alloc(uint32(msgLen)) - - if err != nil { - return []UniqueID{}, err - } - allocTime := time.Since(start).Milliseconds() - if UniqueID(msgLen) != idEnd-idStart { - return []UniqueID{}, errors.New("Obtained id length is not equal that of message") - } 
- - // Insert data to store system - batch := gorocksdb.NewWriteBatch() - defer batch.Destroy() - msgSizes := make(map[UniqueID]int64) - msgIDs := make([]UniqueID, msgLen) - for i := 0; i < msgLen && idStart+UniqueID(i) < idEnd; i++ { - msgID := idStart + UniqueID(i) - key := path.Join(topicName, strconv.FormatInt(msgID, 10)) - batch.Put([]byte(key), messages[i].Payload) - msgIDs[i] = msgID - msgSizes[msgID] = int64(len(messages[i].Payload)) - } - - opts := gorocksdb.NewDefaultWriteOptions() - defer opts.Destroy() - err = rmq.store.Write(opts, batch) - if err != nil { - return []UniqueID{}, err - } - writeTime := time.Since(start).Milliseconds() - if vals, ok := rmq.consumers.Load(topicName); ok { - for _, v := range vals.([]*Consumer) { - select { - case v.MsgMutex <- struct{}{}: - continue - default: - continue - } - } - } - - // Update message page info - err = rmq.updatePageInfo(topicName, msgIDs, msgSizes) - if err != nil { - return []UniqueID{}, err - } - - getProduceTime := time.Since(start).Milliseconds() - if getProduceTime > 200 { - - log.Warn("rocksmq produce too slowly", zap.String("topic", topicName), - zap.Int64("get lock elapse", getLockTime), - zap.Int64("alloc elapse", allocTime-getLockTime), - zap.Int64("write elapse", writeTime-allocTime), - zap.Int64("updatePage elapse", getProduceTime-writeTime), - zap.Int64("produce total elapse", getProduceTime), - ) - } - return msgIDs, nil -} - -// to test compatibility concern -func (rmq *rocksmq) produceIn2(topicName string, messages []ProducerMessage) ([]UniqueID, error) { - if rmq.isClosed() { - return nil, errors.New(RmqNotServingErrMsg) - } - start := time.Now() - ll, ok := topicMu.Load(topicName) - if !ok { - return []UniqueID{}, fmt.Errorf("topic name = %s not exist", topicName) - } - lock, ok := ll.(*sync.Mutex) - if !ok { - return []UniqueID{}, fmt.Errorf("get mutex failed, topic name = %s", topicName) - } - lock.Lock() - defer lock.Unlock() - - getLockTime := time.Since(start).Milliseconds() - - 
msgLen := len(messages) - idStart, idEnd, err := rmq.idAllocator.Alloc(uint32(msgLen)) - if err != nil { return []UniqueID{}, err } @@ -201,16 +117,6 @@ func (rmq *rocksmq) produceIn2(topicName string, messages []ProducerMessage) ([] msgID := idStart + UniqueID(i) key := path.Join(topicName, strconv.FormatInt(msgID, 10)) batch.Put([]byte(key), messages[i].Payload) - properties, err := json.Marshal(messages[i].Properties) - if err != nil { - log.Warn("properties marshal failed", - zap.Int64("msgID", msgID), - zap.String("topicName", topicName), - zap.Error(err)) - return nil, err - } - pKey := path.Join(common.PropertiesKey, topicName, strconv.FormatInt(msgID, 10)) - batch.Put([]byte(pKey), properties) msgIDs[i] = msgID msgSizes[msgID] = int64(len(messages[i].Payload)) } @@ -239,7 +145,6 @@ func (rmq *rocksmq) produceIn2(topicName string, messages []ProducerMessage) ([] return []UniqueID{}, err } - // TODO add this to monitor metrics getProduceTime := time.Since(start).Milliseconds() if getProduceTime > 200 { log.Warn("rocksmq produce too slowly", zap.String("topic", topicName), @@ -250,8 +155,6 @@ func (rmq *rocksmq) produceIn2(topicName string, messages []ProducerMessage) ([] zap.Int64("produce total elapse", getProduceTime), ) } - - rmq.topicLastID.Store(topicName, msgIDs[len(msgIDs)-1]) return msgIDs, nil } @@ -266,6 +169,7 @@ func TestRocksmq_RegisterConsumer(t *testing.T) { defer os.RemoveAll(rocksdbPath) paramtable.Init() + paramtable.Get().Save("rocksmq.compressionTypes", "0,0,0,0,0") rmq, err := NewRocksMQ(rocksdbPath, idAllocator) assert.NoError(t, err) defer rmq.Close() @@ -330,6 +234,7 @@ func TestRocksmq_Basic(t *testing.T) { defer os.RemoveAll(rocksdbPath + kvSuffix) defer os.RemoveAll(rocksdbPath) paramtable.Init() + paramtable.Get().Save("rocksmq.compressionTypes", "0,0,0,0,0") rmq, err := NewRocksMQ(rocksdbPath, idAllocator) assert.NoError(t, err) defer rmq.Close() @@ -341,14 +246,14 @@ func TestRocksmq_Basic(t *testing.T) { msgA := "a_message" 
pMsgs := make([]ProducerMessage, 1) - pMsgA := ProducerMessage{Payload: []byte(msgA), Properties: map[string]string{common.TraceIDKey: "a"}} + pMsgA := ProducerMessage{Payload: []byte(msgA)} pMsgs[0] = pMsgA _, err = rmq.Produce(channelName, pMsgs) assert.NoError(t, err) - pMsgB := ProducerMessage{Payload: []byte("b_message"), Properties: map[string]string{common.TraceIDKey: "b"}} - pMsgC := ProducerMessage{Payload: []byte("c_message"), Properties: map[string]string{common.TraceIDKey: "c"}} + pMsgB := ProducerMessage{Payload: []byte("b_message")} + pMsgC := ProducerMessage{Payload: []byte("c_message")} pMsgs[0] = pMsgB pMsgs = append(pMsgs, pMsgC) @@ -366,121 +271,12 @@ func TestRocksmq_Basic(t *testing.T) { assert.NoError(t, err) assert.Equal(t, len(cMsgs), 1) assert.Equal(t, string(cMsgs[0].Payload), "a_message") - _, ok := cMsgs[0].Properties[common.TraceIDKey] - assert.True(t, ok) - assert.Equal(t, cMsgs[0].Properties[common.TraceIDKey], "a") cMsgs, err = rmq.Consume(channelName, groupName, 2) assert.NoError(t, err) assert.Equal(t, len(cMsgs), 2) assert.Equal(t, string(cMsgs[0].Payload), "b_message") - _, ok = cMsgs[0].Properties[common.TraceIDKey] - assert.True(t, ok) - assert.Equal(t, cMsgs[0].Properties[common.TraceIDKey], "b") assert.Equal(t, string(cMsgs[1].Payload), "c_message") - _, ok = cMsgs[1].Properties[common.TraceIDKey] - assert.True(t, ok) - assert.Equal(t, cMsgs[1].Properties[common.TraceIDKey], "c") -} - -func TestRocksmq_Compatibility(t *testing.T) { - suffix := "rmq_compatibility" - - kvPath := rmqPath + kvPathSuffix + suffix - defer os.RemoveAll(kvPath) - idAllocator := InitIDAllocator(kvPath) - - rocksdbPath := rmqPath + suffix - defer os.RemoveAll(rocksdbPath + kvSuffix) - defer os.RemoveAll(rocksdbPath) - paramtable.Init() - rmq, err := NewRocksMQ(rocksdbPath, idAllocator) - assert.NoError(t, err) - defer rmq.Close() - - channelName := "channel_rocks" - err = rmq.CreateTopic(channelName) - assert.NoError(t, err) - defer 
rmq.DestroyTopic(channelName) - - // before 2.2.0, there have no properties in ProducerMessage and ConsumerMessage in rocksmq - // it aims to test if produce before 2.2.0, will consume after 2.2.0 successfully - msgD := "d_message" - tMsgs := make([]producerMessageBefore2, 1) - tMsgD := producerMessageBefore2{Payload: []byte(msgD)} - tMsgs[0] = tMsgD - _, err = rmq.produceBefore2(channelName, tMsgs) - assert.NoError(t, err) - - groupName := "test_group" - _ = rmq.DestroyConsumerGroup(channelName, groupName) - err = rmq.CreateConsumerGroup(channelName, groupName) - assert.NoError(t, err) - - cMsgs, err := rmq.Consume(channelName, groupName, 1) - if err != nil { - log.Info("test", zap.Any("err", err)) - } - assert.NoError(t, err) - assert.Equal(t, len(cMsgs), 1) - assert.Equal(t, string(cMsgs[0].Payload), "d_message") - _, ok := cMsgs[0].Properties[common.TraceIDKey] - assert.False(t, ok) - // it will be set empty map if produce message has no properties field - expect := make(map[string]string) - assert.Equal(t, cMsgs[0].Properties, expect) - - // between 2.2.0 and 2.3.0, the key of Payload is topic/properties/msgid/Payload - // will ingnore the property before 2.3.0, just make sure property empty is ok for 2.3 - // after 2.3, the properties will be stored in column families - // it aims to test if produce in 2.2.0, but consume in 2.3.0, will get properties successfully - msg1 := "1_message" - tMsgs1 := make([]ProducerMessage, 1) - properties := make(map[string]string) - properties[common.TraceIDKey] = "1" - tMsg1 := ProducerMessage{Payload: []byte(msg1), Properties: properties} - tMsgs1[0] = tMsg1 - _, err = rmq.produceIn2(channelName, tMsgs1) - assert.NoError(t, err) - - msg2, err := rmq.Consume(channelName, groupName, 1) - assert.NoError(t, err) - assert.Equal(t, len(msg2), 1) - assert.Equal(t, string(msg2[0].Payload), "1_message") - _, ok = msg2[0].Properties[common.TraceIDKey] - assert.False(t, ok) - // will ingnore the property before 2.3.0, just make sure 
property empty is ok for 2.3 - expect = make(map[string]string) - assert.Equal(t, cMsgs[0].Properties, expect) - - // between 2.2.0 and 2.3.0, the key of Payload is topic/properties/msgid/Payload - // after 2.3, the properties will be stored in column families - // it aims to test the mixed message before 2.3.0 and after 2.3.0, will get properties successfully - msg3 := "3_message" - tMsgs3 := make([]ProducerMessage, 2) - properties3 := make(map[string]string) - properties3[common.TraceIDKey] = "3" - tMsg3 := ProducerMessage{Payload: []byte(msg3), Properties: properties3} - tMsgs3[0] = tMsg3 - msg4 := "4_message" - tMsg4 := ProducerMessage{Payload: []byte(msg4)} - tMsgs3[1] = tMsg4 - _, err = rmq.Produce(channelName, tMsgs3) - assert.NoError(t, err) - - msg5, err := rmq.Consume(channelName, groupName, 2) - assert.NoError(t, err) - assert.Equal(t, len(msg5), 2) - assert.Equal(t, string(msg5[0].Payload), "3_message") - _, ok = msg5[0].Properties[common.TraceIDKey] - assert.True(t, ok) - assert.Equal(t, msg5[0].Properties, properties3) - assert.Equal(t, string(msg5[1].Payload), "4_message") - _, ok = msg5[1].Properties[common.TraceIDKey] - assert.False(t, ok) - // it will be set empty map if produce message has no properties field - expect = make(map[string]string) - assert.Equal(t, msg5[1].Properties, expect) } func TestRocksmq_MultiConsumer(t *testing.T) { @@ -495,6 +291,7 @@ func TestRocksmq_MultiConsumer(t *testing.T) { params := paramtable.Get() params.Save(params.RocksmqCfg.PageSize.Key, "10") + paramtable.Get().Save("rocksmq.compressionTypes", "0,0,0,0,0") rmq, err := NewRocksMQ(rocksdbPath, idAllocator) assert.NoError(t, err) defer rmq.Close() @@ -547,6 +344,7 @@ func TestRocksmq_Dummy(t *testing.T) { defer os.RemoveAll(rocksdbPath + kvSuffix) defer os.RemoveAll(rocksdbPath) paramtable.Init() + paramtable.Get().Save("rocksmq.compressionTypes", "0,0,0,0,0") rmq, err := NewRocksMQ(rocksdbPath, idAllocator) assert.NoError(t, err) defer rmq.Close() @@ -617,10 
+415,12 @@ func TestRocksmq_Seek(t *testing.T) { defer os.RemoveAll(rocksdbPath) paramtable.Init() + paramtable.Get().Save("rocksmq.compressionTypes", "0,0,0,0,0") rmq, err := NewRocksMQ(rocksdbPath, idAllocator) assert.NoError(t, err) defer rmq.Close() + paramtable.Get().Save("rocksmq.compressionTypes", "0,0,0,0,0") _, err = NewRocksMQ("", idAllocator) assert.Error(t, err) defer os.RemoveAll("_meta_kv") @@ -663,7 +463,6 @@ func TestRocksmq_Seek(t *testing.T) { assert.Equal(t, messages[0].MsgID, seekID2) _ = rmq.DestroyConsumerGroup(channelName, groupName1) - } func TestRocksmq_Loop(t *testing.T) { @@ -685,6 +484,7 @@ func TestRocksmq_Loop(t *testing.T) { defer os.RemoveAll(kvName) paramtable.Init() + paramtable.Get().Save("rocksmq.compressionTypes", "0,0,0,0,0") rmq, err := NewRocksMQ(name, idAllocator) assert.NoError(t, err) defer rmq.Close() @@ -757,6 +557,7 @@ func TestRocksmq_Goroutines(t *testing.T) { defer os.RemoveAll(kvName) paramtable.Init() + paramtable.Get().Save("rocksmq.compressionTypes", "0,0,0,0,0") rmq, err := NewRocksMQ(name, idAllocator) assert.NoError(t, err) defer rmq.Close() @@ -836,6 +637,7 @@ func TestRocksmq_Throughout(t *testing.T) { defer os.RemoveAll(kvName) paramtable.Init() + paramtable.Get().Save("rocksmq.compressionTypes", "0,0,0,0,0") rmq, err := NewRocksMQ(name, idAllocator) assert.NoError(t, err) defer rmq.Close() @@ -901,6 +703,7 @@ func TestRocksmq_MultiChan(t *testing.T) { defer os.RemoveAll(kvName) paramtable.Init() + paramtable.Get().Save("rocksmq.compressionTypes", "0,0,0,0,0") rmq, err := NewRocksMQ(name, idAllocator) assert.NoError(t, err) defer rmq.Close() @@ -955,6 +758,7 @@ func TestRocksmq_CopyData(t *testing.T) { defer os.RemoveAll(kvName) paramtable.Init() + paramtable.Get().Save("rocksmq.compressionTypes", "0,0,0,0,0") rmq, err := NewRocksMQ(name, idAllocator) assert.NoError(t, err) defer rmq.Close() @@ -1023,6 +827,7 @@ func TestRocksmq_SeekToLatest(t *testing.T) { defer os.RemoveAll(kvName) paramtable.Init() + 
paramtable.Get().Save("rocksmq.compressionTypes", "0,0,0,0,0") rmq, err := NewRocksMQ(name, idAllocator) assert.NoError(t, err) defer rmq.Close() @@ -1114,6 +919,7 @@ func TestRocksmq_GetLatestMsg(t *testing.T) { kvName := name + "_meta_kv" _ = os.RemoveAll(kvName) defer os.RemoveAll(kvName) + paramtable.Get().Save("rocksmq.compressionTypes", "0,0,0,0,0") rmq, err := NewRocksMQ(name, idAllocator) assert.NoError(t, err) @@ -1182,6 +988,7 @@ func TestRocksmq_CheckPreTopicValid(t *testing.T) { defer os.RemoveAll(rocksdbPath + kvSuffix) defer os.RemoveAll(rocksdbPath) paramtable.Init() + paramtable.Get().Save("rocksmq.compressionTypes", "0,0,0,0,0") rmq, err := NewRocksMQ(rocksdbPath, idAllocator) assert.NoError(t, err) defer rmq.Close() @@ -1189,10 +996,10 @@ func TestRocksmq_CheckPreTopicValid(t *testing.T) { channelName1 := "topic1" // topic not exist err = rmq.CheckTopicValid(channelName1) - assert.Equal(t, true, errors.Is(err, merr.ErrTopicNotFound)) + assert.Equal(t, true, errors.Is(err, merr.ErrMqTopicNotFound)) channelName2 := "topic2" - // topic is not empty + // allow topic is not empty err = rmq.CreateTopic(channelName2) defer rmq.DestroyTopic(channelName2) assert.NoError(t, err) @@ -1208,7 +1015,7 @@ func TestRocksmq_CheckPreTopicValid(t *testing.T) { assert.NoError(t, err) err = rmq.CheckTopicValid(channelName2) - assert.Equal(t, true, errors.Is(err, merr.ErrTopicNotEmpty)) + assert.NoError(t, err) channelName3 := "topic3" // pass @@ -1237,6 +1044,7 @@ func TestRocksmq_Close(t *testing.T) { kvName := name + "_meta_kv" _ = os.RemoveAll(kvName) defer os.RemoveAll(kvName) + paramtable.Get().Save("rocksmq.compressionTypes", "0,0,0,0,0") rmq, err := NewRocksMQ(name, idAllocator) assert.NoError(t, err) defer rmq.Close() @@ -1269,6 +1077,7 @@ func TestRocksmq_SeekWithNoConsumerError(t *testing.T) { kvName := name + "_meta_kv" _ = os.RemoveAll(kvName) defer os.RemoveAll(kvName) + paramtable.Get().Save("rocksmq.compressionTypes", "0,0,0,0,0") rmq, err := 
NewRocksMQ(name, idAllocator) assert.NoError(t, err) defer rmq.Close() @@ -1294,6 +1103,7 @@ func TestRocksmq_SeekTopicNotExistError(t *testing.T) { kvName := name + "_meta_kv" _ = os.RemoveAll(kvName) defer os.RemoveAll(kvName) + paramtable.Get().Save("rocksmq.compressionTypes", "0,0,0,0,0") rmq, err := NewRocksMQ(name, idAllocator) assert.NoError(t, err) defer rmq.Close() @@ -1316,6 +1126,7 @@ func TestRocksmq_SeekTopicMutexError(t *testing.T) { kvName := name + "_meta_kv" _ = os.RemoveAll(kvName) defer os.RemoveAll(kvName) + paramtable.Get().Save("rocksmq.compressionTypes", "0,0,0,0,0") rmq, err := NewRocksMQ(name, idAllocator) assert.NoError(t, err) defer rmq.Close() @@ -1339,6 +1150,7 @@ func TestRocksmq_moveConsumePosError(t *testing.T) { kvName := name + "_meta_kv" _ = os.RemoveAll(kvName) defer os.RemoveAll(kvName) + paramtable.Get().Save("rocksmq.compressionTypes", "0,0,0,0,0") rmq, err := NewRocksMQ(name, idAllocator) assert.NoError(t, err) defer rmq.Close() @@ -1363,6 +1175,7 @@ func TestRocksmq_updateAckedInfoErr(t *testing.T) { defer os.RemoveAll(kvName) params := paramtable.Get() params.Save(params.RocksmqCfg.PageSize.Key, "10") + paramtable.Get().Save("rocksmq.compressionTypes", "0,0,0,0,0") rmq, err := NewRocksMQ(name, idAllocator) assert.NoError(t, err) defer rmq.Close() @@ -1371,7 +1184,7 @@ func TestRocksmq_updateAckedInfoErr(t *testing.T) { rmq.CreateTopic(topicName) defer rmq.DestroyTopic(topicName) - //add message, make sure rmq has more than one page + // add message, make sure rmq has more than one page msgNum := 100 pMsgs := make([]ProducerMessage, msgNum) for i := 0; i < msgNum; i++ { @@ -1390,9 +1203,9 @@ func TestRocksmq_updateAckedInfoErr(t *testing.T) { GroupName: groupName + strconv.Itoa(i), MsgMutex: make(chan struct{}), } - //make sure consumer not in rmq.consumersID + // make sure consumer not in rmq.consumersID rmq.DestroyConsumerGroup(topicName, groupName+strconv.Itoa(i)) - //add consumer to rmq.consumers + // add consumer to 
rmq.consumers rmq.RegisterConsumer(consumer) } @@ -1422,6 +1235,7 @@ func TestRocksmq_Info(t *testing.T) { defer os.RemoveAll(kvName) params := paramtable.Get() params.Save(params.RocksmqCfg.PageSize.Key, "10") + paramtable.Get().Save("rocksmq.compressionTypes", "0,0,0,0,0") rmq, err := NewRocksMQ(name, idAllocator) assert.NoError(t, err) defer rmq.Close() @@ -1445,7 +1259,7 @@ func TestRocksmq_Info(t *testing.T) { assert.True(t, rmq.Info()) - //test error + // test error rmq.kv = &rocksdbkv.RocksdbKV{} assert.False(t, rmq.Info()) } diff --git a/internal/mq/mqimpl/rocksmq/server/rocksmq_retention_test.go b/internal/mq/mqimpl/rocksmq/server/rocksmq_retention_test.go index af42fbee4bed8..ecaf612cdb120 100644 --- a/internal/mq/mqimpl/rocksmq/server/rocksmq_retention_test.go +++ b/internal/mq/mqimpl/rocksmq/server/rocksmq_retention_test.go @@ -391,7 +391,6 @@ func TestRmqRetention_MultipleTopic(t *testing.T) { newRes, err = rmq.Consume(topicName, groupName, 1) assert.NoError(t, err) assert.Equal(t, len(newRes), 0) - } func TestRetentionInfo_InitRetentionInfo(t *testing.T) { diff --git a/internal/mq/msgstream/mq_factory.go b/internal/mq/msgstream/mq_factory.go index 60f1dea3f06ca..6d19d2abf0a50 100644 --- a/internal/mq/msgstream/mq_factory.go +++ b/internal/mq/msgstream/mq_factory.go @@ -1,12 +1,13 @@ package msgstream import ( + "go.uber.org/zap" + "github.com/milvus-io/milvus/internal/mq/mqimpl/rocksmq/server" "github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper/rmq" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/paramtable" - "go.uber.org/zap" ) // NewRocksmqFactory creates a new message stream factory based on rocksmq. 
diff --git a/internal/mq/msgstream/mq_factory_test.go b/internal/mq/msgstream/mq_factory_test.go index c66d6a3310426..dc0e9213c1c06 100644 --- a/internal/mq/msgstream/mq_factory_test.go +++ b/internal/mq/msgstream/mq_factory_test.go @@ -40,5 +40,4 @@ func TestRmsFactory(t *testing.T) { _, err = rmsFactory.NewTtMsgStream(ctx) assert.NoError(t, err) - } diff --git a/internal/mq/msgstream/mqwrapper/rmq/rmq_client.go b/internal/mq/msgstream/mqwrapper/rmq/rmq_client.go index 117500ad7aab7..6de1d3ae68507 100644 --- a/internal/mq/msgstream/mqwrapper/rmq/rmq_client.go +++ b/internal/mq/msgstream/mqwrapper/rmq/rmq_client.go @@ -17,13 +17,14 @@ package rmq import ( + "context" "strconv" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus/internal/mq/mqimpl/rocksmq/client" - "github.com/milvus-io/milvus/internal/mq/mqimpl/rocksmq/server" "go.uber.org/zap" + "github.com/milvus-io/milvus/internal/mq/mqimpl/rocksmq/client" + "github.com/milvus-io/milvus/internal/mq/mqimpl/rocksmq/server" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" @@ -38,7 +39,7 @@ type rmqClient struct { client client.Client } -func NewClientWithDefaultOptions() (mqwrapper.Client, error) { +func NewClientWithDefaultOptions(ctx context.Context) (mqwrapper.Client, error) { option := client.Options{Server: server.Rmq} return NewClient(option) } @@ -83,7 +84,7 @@ func (rc *rmqClient) Subscribe(options mqwrapper.ConsumerOptions) (mqwrapper.Con log.Warn("unexpected subscription consumer options", zap.Error(err)) return nil, err } - receiveChannel := make(chan client.Message, options.BufSize) + receiveChannel := make(chan mqwrapper.Message, options.BufSize) cli, err := rc.client.Subscribe(client.ConsumerOptions{ Topic: options.Topic, @@ -107,7 +108,7 @@ func (rc *rmqClient) Subscribe(options mqwrapper.ConsumerOptions) (mqwrapper.Con // EarliestMessageID returns the earliest message ID for rmq client func (rc 
*rmqClient) EarliestMessageID() mqwrapper.MessageID { rID := client.EarliestMessageID() - return &rmqID{messageID: rID} + return &server.RmqID{MessageID: rID} } // StringToMsgID converts string id to MessageID @@ -116,13 +117,13 @@ func (rc *rmqClient) StringToMsgID(id string) (mqwrapper.MessageID, error) { if err != nil { return nil, err } - return &rmqID{messageID: rID}, nil + return &server.RmqID{MessageID: rID}, nil } // BytesToMsgID converts a byte array to messageID func (rc *rmqClient) BytesToMsgID(id []byte) (mqwrapper.MessageID, error) { - rID := DeserializeRmqID(id) - return &rmqID{messageID: rID}, nil + rID := server.DeserializeRmqID(id) + return &server.RmqID{MessageID: rID}, nil } func (rc *rmqClient) Close() { diff --git a/internal/mq/msgstream/mqwrapper/rmq/rmq_client_test.go b/internal/mq/msgstream/mqwrapper/rmq/rmq_client_test.go index 11bed58150e47..1248914485e1f 100644 --- a/internal/mq/msgstream/mqwrapper/rmq/rmq_client_test.go +++ b/internal/mq/msgstream/mqwrapper/rmq/rmq_client_test.go @@ -25,11 +25,11 @@ import ( "time" "github.com/apache/pulsar-client-go/pulsar" - rocksmqimplclient "github.com/milvus-io/milvus/internal/mq/mqimpl/rocksmq/client" - rocksmqimplserver "github.com/milvus-io/milvus/internal/mq/mqimpl/rocksmq/server" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + rocksmqimplclient "github.com/milvus-io/milvus/internal/mq/mqimpl/rocksmq/client" + rocksmqimplserver "github.com/milvus-io/milvus/internal/mq/mqimpl/rocksmq/server" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" pulsarwrapper "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper/pulsar" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -43,6 +43,7 @@ func TestMain(m *testing.M) { rand.Seed(time.Now().UnixNano()) path := "/tmp/milvus/rdb_data" defer os.RemoveAll(path) + paramtable.Get().Save("rocksmq.compressionTypes", "0,0,0,0,0") _ = rocksmqimplserver.InitRocksMQ(path) exitCode := m.Run() defer rocksmqimplserver.CloseRocksMQ() 
@@ -66,11 +67,11 @@ func TestRmqClient_CreateProducer(t *testing.T) { topic := "TestRmqClient_CreateProducer" proOpts := mqwrapper.ProducerOptions{Topic: topic} producer, err := client.CreateProducer(proOpts) - - defer producer.Close() assert.NoError(t, err) assert.NotNil(t, producer) + defer producer.Close() + rmqProducer := producer.(*rmqProducer) defer rmqProducer.Close() assert.Equal(t, rmqProducer.Topic(), topic) @@ -150,9 +151,9 @@ func TestRmqClient_Subscribe(t *testing.T) { topic := "TestRmqClient_Subscribe" proOpts := mqwrapper.ProducerOptions{Topic: topic} producer, err := client.CreateProducer(proOpts) - defer producer.Close() assert.NoError(t, err) assert.NotNil(t, producer) + defer producer.Close() subName := "subName" consumerOpts := mqwrapper.ConsumerOptions{ @@ -197,7 +198,7 @@ func TestRmqClient_Subscribe(t *testing.T) { assert.FailNow(t, "consumer failed to yield message in 100 milliseconds") case msg := <-consumer.Chan(): consumer.Ack(msg) - rmqmsg := msg.(*rmqMessage) + rmqmsg := msg.(*rocksmqimplclient.RmqMessage) msgPayload := rmqmsg.Payload() assert.NotEmpty(t, msgPayload) msgTopic := rmqmsg.Topic() @@ -205,7 +206,7 @@ func TestRmqClient_Subscribe(t *testing.T) { msgProp := rmqmsg.Properties() assert.Empty(t, msgProp) msgID := rmqmsg.ID() - rID := msgID.(*rmqID) + rID := msgID.(*rocksmqimplserver.RmqID) assert.NotZero(t, rID) } } diff --git a/internal/mq/msgstream/mqwrapper/rmq/rmq_consumer.go b/internal/mq/msgstream/mqwrapper/rmq/rmq_consumer.go index d0baefc9a2b86..d02730cdc3ba5 100644 --- a/internal/mq/msgstream/mqwrapper/rmq/rmq_consumer.go +++ b/internal/mq/msgstream/mqwrapper/rmq/rmq_consumer.go @@ -21,6 +21,7 @@ import ( "sync/atomic" "github.com/milvus-io/milvus/internal/mq/mqimpl/rocksmq/client" + "github.com/milvus-io/milvus/internal/mq/mqimpl/rocksmq/server" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" ) @@ -57,7 +58,7 @@ func (rc *Consumer) Chan() <-chan mqwrapper.Message { skip := atomic.LoadInt32(&rc.skip) if skip != 
1 { select { - case rc.msgChannel <- &rmqMessage{msg: msg}: + case rc.msgChannel <- msg: case <-rc.closeCh: // if consumer closed, enter close branch below } @@ -78,7 +79,7 @@ func (rc *Consumer) Chan() <-chan mqwrapper.Message { // Seek is used to seek the position in rocksmq topic func (rc *Consumer) Seek(id mqwrapper.MessageID, inclusive bool) error { - msgID := id.(*rmqID).messageID + msgID := id.(*server.RmqID).MessageID // skip the first message when consume if !inclusive { atomic.StoreInt32(&rc.skip, 1) @@ -98,7 +99,7 @@ func (rc *Consumer) Close() { func (rc *Consumer) GetLatestMsgID() (mqwrapper.MessageID, error) { msgID, err := rc.c.GetLatestMsgID() - return &rmqID{messageID: msgID}, err + return &server.RmqID{MessageID: msgID}, err } func (rc *Consumer) CheckTopicValid(topic string) error { diff --git a/internal/mq/msgstream/mqwrapper/rmq/rmq_producer.go b/internal/mq/msgstream/mqwrapper/rmq/rmq_producer.go index 408fe3810f5d2..40d051716f207 100644 --- a/internal/mq/msgstream/mqwrapper/rmq/rmq_producer.go +++ b/internal/mq/msgstream/mqwrapper/rmq/rmq_producer.go @@ -15,6 +15,7 @@ import ( "context" "github.com/milvus-io/milvus/internal/mq/mqimpl/rocksmq/client" + "github.com/milvus-io/milvus/internal/mq/mqimpl/rocksmq/server" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" "github.com/milvus-io/milvus/pkg/util/timerecord" @@ -37,21 +38,20 @@ func (rp *rmqProducer) Send(ctx context.Context, message *mqwrapper.ProducerMess start := timerecord.NewTimeRecorder("send msg to stream") metrics.MsgStreamOpCounter.WithLabelValues(metrics.SendMsgLabel, metrics.TotalLabel).Inc() - pm := &client.ProducerMessage{Payload: message.Payload, Properties: message.Properties} - id, err := rp.p.Send(pm) + id, err := rp.p.Send(message) if err != nil { metrics.MsgStreamOpCounter.WithLabelValues(metrics.SendMsgLabel, metrics.FailLabel).Inc() - return &rmqID{messageID: id}, err + return &server.RmqID{MessageID: id}, err } elapsed 
:= start.ElapseSpan() metrics.MsgStreamRequestLatency.WithLabelValues(metrics.SendMsgLabel).Observe(float64(elapsed.Milliseconds())) metrics.MsgStreamOpCounter.WithLabelValues(metrics.SendMsgLabel, metrics.SuccessLabel).Inc() - return &rmqID{messageID: id}, nil + return &server.RmqID{MessageID: id}, nil } // Close does nothing currently func (rp *rmqProducer) Close() { - //TODO: close producer. Now it has bug - //rp.p.Close() + // TODO: close producer. Now it has bug + // rp.p.Close() } diff --git a/internal/mq/msgstream/mqwrapper/rmq/rocksmq_msgstream_test.go b/internal/mq/msgstream/mqwrapper/rmq/rocksmq_msgstream_test.go index 8ead7cda03275..e29171d1437a5 100644 --- a/internal/mq/msgstream/mqwrapper/rmq/rocksmq_msgstream_test.go +++ b/internal/mq/msgstream/mqwrapper/rmq/rocksmq_msgstream_test.go @@ -66,8 +66,8 @@ func TestMqMsgStream_AsConsumer(t *testing.T) { assert.NoError(t, err) // repeat calling AsConsumer - m.AsConsumer([]string{"a"}, "b", mqwrapper.SubscriptionPositionUnknown) - m.AsConsumer([]string{"a"}, "b", mqwrapper.SubscriptionPositionUnknown) + m.AsConsumer(context.Background(), []string{"a"}, "b", mqwrapper.SubscriptionPositionUnknown) + m.AsConsumer(context.Background(), []string{"a"}, "b", mqwrapper.SubscriptionPositionUnknown) } func TestMqMsgStream_ComputeProduceChannelIndexes(t *testing.T) { @@ -240,7 +240,7 @@ func TestMqMsgStream_SeekNotSubscribed(t *testing.T) { ChannelName: "b", }, } - err = m.Seek(p) + err = m.Seek(context.Background(), p) assert.Error(t, err) } @@ -265,7 +265,7 @@ func initRmqStream(ctx context.Context, ) (msgstream.MsgStream, msgstream.MsgStream) { factory := msgstream.ProtoUDFactory{} - rmqClient, _ := NewClientWithDefaultOptions() + rmqClient, _ := NewClientWithDefaultOptions(ctx) inputStream, _ := msgstream.NewMqMsgStream(ctx, 100, 100, rmqClient, factory.NewUnmarshalDispatcher()) inputStream.AsProducer(producerChannels) for _, opt := range opts { @@ -273,9 +273,9 @@ func initRmqStream(ctx context.Context, } var 
input msgstream.MsgStream = inputStream - rmqClient2, _ := NewClientWithDefaultOptions() + rmqClient2, _ := NewClientWithDefaultOptions(ctx) outputStream, _ := msgstream.NewMqMsgStream(ctx, 100, 100, rmqClient2, factory.NewUnmarshalDispatcher()) - outputStream.AsConsumer(consumerChannels, consumerGroupName, mqwrapper.SubscriptionPositionEarliest) + outputStream.AsConsumer(ctx, consumerChannels, consumerGroupName, mqwrapper.SubscriptionPositionEarliest) var output msgstream.MsgStream = outputStream return input, output @@ -289,7 +289,7 @@ func initRmqTtStream(ctx context.Context, ) (msgstream.MsgStream, msgstream.MsgStream) { factory := msgstream.ProtoUDFactory{} - rmqClient, _ := NewClientWithDefaultOptions() + rmqClient, _ := NewClientWithDefaultOptions(ctx) inputStream, _ := msgstream.NewMqMsgStream(ctx, 100, 100, rmqClient, factory.NewUnmarshalDispatcher()) inputStream.AsProducer(producerChannels) for _, opt := range opts { @@ -297,9 +297,9 @@ func initRmqTtStream(ctx context.Context, } var input msgstream.MsgStream = inputStream - rmqClient2, _ := NewClientWithDefaultOptions() + rmqClient2, _ := NewClientWithDefaultOptions(ctx) outputStream, _ := msgstream.NewMqTtMsgStream(ctx, 100, 100, rmqClient2, factory.NewUnmarshalDispatcher()) - outputStream.AsConsumer(consumerChannels, consumerGroupName, mqwrapper.SubscriptionPositionEarliest) + outputStream.AsConsumer(ctx, consumerChannels, consumerGroupName, mqwrapper.SubscriptionPositionEarliest) var output msgstream.MsgStream = outputStream return input, output @@ -399,11 +399,11 @@ func TestStream_RmqTtMsgStream_DuplicatedIDs(t *testing.T) { factory := msgstream.ProtoUDFactory{} - rmqClient, _ := NewClientWithDefaultOptions() + rmqClient, _ := NewClientWithDefaultOptions(ctx) outputStream, _ = msgstream.NewMqTtMsgStream(context.Background(), 100, 100, rmqClient, factory.NewUnmarshalDispatcher()) consumerSubName = funcutil.RandomString(8) - outputStream.AsConsumer(consumerChannels, consumerSubName, 
mqwrapper.SubscriptionPositionUnknown) - outputStream.Seek(receivedMsg.StartPositions) + outputStream.AsConsumer(ctx, consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionUnknown) + outputStream.Seek(ctx, receivedMsg.StartPositions) seekMsg := consumer(ctx, outputStream) assert.Equal(t, len(seekMsg.Msgs), 1+2) assert.EqualValues(t, seekMsg.Msgs[0].BeginTs(), 1) @@ -501,12 +501,12 @@ func TestStream_RmqTtMsgStream_Seek(t *testing.T) { factory := msgstream.ProtoUDFactory{} - rmqClient, _ := NewClientWithDefaultOptions() + rmqClient, _ := NewClientWithDefaultOptions(ctx) outputStream, _ = msgstream.NewMqTtMsgStream(context.Background(), 100, 100, rmqClient, factory.NewUnmarshalDispatcher()) consumerSubName = funcutil.RandomString(8) - outputStream.AsConsumer(consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionUnknown) + outputStream.AsConsumer(ctx, consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionUnknown) - outputStream.Seek(receivedMsg3.StartPositions) + outputStream.Seek(ctx, receivedMsg3.StartPositions) seekMsg := consumer(ctx, outputStream) assert.Equal(t, len(seekMsg.Msgs), 3) result := []uint64{14, 12, 13} @@ -549,9 +549,9 @@ func TestStream_RMqMsgStream_SeekInvalidMessage(t *testing.T) { outputStream.Close() factory := msgstream.ProtoUDFactory{} - rmqClient2, _ := NewClientWithDefaultOptions() + rmqClient2, _ := NewClientWithDefaultOptions(ctx) outputStream2, _ := msgstream.NewMqMsgStream(ctx, 100, 100, rmqClient2, factory.NewUnmarshalDispatcher()) - outputStream2.AsConsumer(consumerChannels, funcutil.RandomString(8), mqwrapper.SubscriptionPositionUnknown) + outputStream2.AsConsumer(ctx, consumerChannels, funcutil.RandomString(8), mqwrapper.SubscriptionPositionUnknown) id := common.Endian.Uint64(seekPosition.MsgID) + 10 bs := make([]byte, 8) @@ -565,7 +565,7 @@ func TestStream_RMqMsgStream_SeekInvalidMessage(t *testing.T) { }, } - err = outputStream2.Seek(p) + err = outputStream2.Seek(ctx, p) assert.NoError(t, err) for i 
:= 10; i < 20; i++ { @@ -589,7 +589,7 @@ func TestStream_RmqTtMsgStream_AsConsumerWithPosition(t *testing.T) { factory := msgstream.ProtoUDFactory{} - rmqClient, _ := NewClientWithDefaultOptions() + rmqClient, _ := NewClientWithDefaultOptions(context.Background()) otherInputStream, _ := msgstream.NewMqMsgStream(context.Background(), 100, 100, rmqClient, factory.NewUnmarshalDispatcher()) otherInputStream.AsProducer([]string{"root_timetick"}) @@ -602,9 +602,9 @@ func TestStream_RmqTtMsgStream_AsConsumerWithPosition(t *testing.T) { inputStream.Produce(getTimeTickMsgPack(int64(i))) } - rmqClient2, _ := NewClientWithDefaultOptions() + rmqClient2, _ := NewClientWithDefaultOptions(context.Background()) outputStream, _ := msgstream.NewMqMsgStream(context.Background(), 100, 100, rmqClient2, factory.NewUnmarshalDispatcher()) - outputStream.AsConsumer(consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionLatest) + outputStream.AsConsumer(context.Background(), consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionLatest) inputStream.Produce(getTimeTickMsgPack(1000)) pack := <-outputStream.Chan() diff --git a/internal/parser/planparserv2/Plan.g4 b/internal/parser/planparserv2/Plan.g4 index 3bace1f90c0f5..cc8a479c3c347 100644 --- a/internal/parser/planparserv2/Plan.g4 +++ b/internal/parser/planparserv2/Plan.g4 @@ -18,9 +18,10 @@ expr: | expr op = (SHL | SHR) expr # Shift | expr op = (IN | NIN) ('[' expr (',' expr)* ','? 
']') # Term | expr op = (IN | NIN) EmptyTerm # EmptyTerm - | JSONContains'('expr',' expr')' # JSONContains - | JSONContainsAll'('expr',' expr')' # JSONContainsAll - | JSONContainsAny'('expr',' expr')' # JSONContainsAny + | (JSONContains | ArrayContains)'('expr',' expr')' # JSONContains + | (JSONContainsAll | ArrayContainsAll)'('expr',' expr')' # JSONContainsAll + | (JSONContainsAny | ArrayContainsAny)'('expr',' expr')' # JSONContainsAny + | ArrayLength'('(Identifier | JSONIdentifier)')' # ArrayLength | expr op1 = (LT | LE) (Identifier | JSONIdentifier) op2 = (LT | LE) expr # Range | expr op1 = (GT | GE) (Identifier | JSONIdentifier) op2 = (GT | GE) expr # ReverseRange | expr op = (LT | LE | GT | GE) expr # Relational @@ -78,6 +79,11 @@ JSONContains: 'json_contains' | 'JSON_CONTAINS'; JSONContainsAll: 'json_contains_all' | 'JSON_CONTAINS_ALL'; JSONContainsAny: 'json_contains_any' | 'JSON_CONTAINS_ANY'; +ArrayContains: 'array_contains' | 'ARRAY_CONTAINS'; +ArrayContainsAll: 'array_contains_all' | 'ARRAY_CONTAINS_ALL'; +ArrayContainsAny: 'array_contains_any' | 'ARRAY_CONTAINS_ANY'; +ArrayLength: 'array_length' | 'ARRAY_LENGTH'; + BooleanConstant: 'true' | 'True' | 'TRUE' | 'false' | 'False' | 'FALSE'; IntegerConstant: diff --git a/internal/parser/planparserv2/check_identical.go b/internal/parser/planparserv2/check_identical.go index a967f3b0f4622..faa6efd022f2c 100644 --- a/internal/parser/planparserv2/check_identical.go +++ b/internal/parser/planparserv2/check_identical.go @@ -37,7 +37,7 @@ func CheckQueryInfoIdentical(info1, info2 *planpb.QueryInfo) bool { } func CheckVectorANNSIdentical(node1, node2 *planpb.VectorANNS) bool { - if node1.GetIsBinary() != node2.GetIsBinary() { + if node1.GetVectorType() != node2.GetVectorType() { return false } if node1.GetFieldId() != node2.GetFieldId() { diff --git a/internal/parser/planparserv2/check_identical_test.go b/internal/parser/planparserv2/check_identical_test.go index f50f3abb0a4bc..9f48aec504d8e 100644 --- 
a/internal/parser/planparserv2/check_identical_test.go +++ b/internal/parser/planparserv2/check_identical_test.go @@ -3,10 +3,10 @@ package planparserv2 import ( "testing" - "github.com/milvus-io/milvus/internal/proto/planpb" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus/internal/proto/planpb" "github.com/milvus-io/milvus/pkg/util/typeutil" - "github.com/stretchr/testify/assert" ) func TestCheckIdentical(t *testing.T) { @@ -94,65 +94,73 @@ func TestCheckVectorANNSIdentical(t *testing.T) { }{ { args: args{ - node1: &planpb.VectorANNS{IsBinary: true}, - node2: &planpb.VectorANNS{IsBinary: false}, + node1: &planpb.VectorANNS{VectorType: planpb.VectorType_BinaryVector}, + node2: &planpb.VectorANNS{VectorType: planpb.VectorType_FloatVector}, }, want: false, }, { args: args{ - node1: &planpb.VectorANNS{IsBinary: false, FieldId: 100}, - node2: &planpb.VectorANNS{IsBinary: false, FieldId: 101}, + node1: &planpb.VectorANNS{VectorType: planpb.VectorType_FloatVector, FieldId: 100}, + node2: &planpb.VectorANNS{VectorType: planpb.VectorType_FloatVector, FieldId: 101}, }, want: false, }, { args: args{ - node1: &planpb.VectorANNS{IsBinary: false, FieldId: 100, PlaceholderTag: "$0"}, - node2: &planpb.VectorANNS{IsBinary: false, FieldId: 100, PlaceholderTag: "$1"}, + node1: &planpb.VectorANNS{VectorType: planpb.VectorType_FloatVector, FieldId: 100, PlaceholderTag: "$0"}, + node2: &planpb.VectorANNS{VectorType: planpb.VectorType_FloatVector, FieldId: 100, PlaceholderTag: "$1"}, }, want: false, }, { args: args{ - node1: &planpb.VectorANNS{IsBinary: false, FieldId: 100, PlaceholderTag: "$0", QueryInfo: &planpb.QueryInfo{Topk: 100}}, - node2: &planpb.VectorANNS{IsBinary: false, FieldId: 100, PlaceholderTag: "$0", QueryInfo: &planpb.QueryInfo{Topk: 10}}, + node1: &planpb.VectorANNS{VectorType: planpb.VectorType_FloatVector, FieldId: 100, PlaceholderTag: "$0", QueryInfo: &planpb.QueryInfo{Topk: 100}}, + node2: &planpb.VectorANNS{VectorType: 
planpb.VectorType_FloatVector, FieldId: 100, PlaceholderTag: "$0", QueryInfo: &planpb.QueryInfo{Topk: 10}}, }, want: false, }, { args: args{ - node1: &planpb.VectorANNS{IsBinary: false, FieldId: 100, PlaceholderTag: "$0", QueryInfo: &planpb.QueryInfo{Topk: 1, MetricType: "L2", SearchParams: `{"nprobe": 10}`, RoundDecimal: 6}, + node1: &planpb.VectorANNS{ + VectorType: planpb.VectorType_FloatVector, FieldId: 100, PlaceholderTag: "$0", QueryInfo: &planpb.QueryInfo{Topk: 1, MetricType: "L2", SearchParams: `{"nprobe": 10}`, RoundDecimal: 6}, Predicates: &planpb.Expr{ Expr: &planpb.Expr_ColumnExpr{ ColumnExpr: &planpb.ColumnExpr{ Info: &planpb.ColumnInfo{}, }, }, - }}, - node2: &planpb.VectorANNS{IsBinary: false, FieldId: 100, PlaceholderTag: "$0", QueryInfo: &planpb.QueryInfo{Topk: 1, MetricType: "L2", SearchParams: `{"nprobe": 10}`, RoundDecimal: 6}, + }, + }, + node2: &planpb.VectorANNS{ + VectorType: planpb.VectorType_FloatVector, FieldId: 100, PlaceholderTag: "$0", QueryInfo: &planpb.QueryInfo{Topk: 1, MetricType: "L2", SearchParams: `{"nprobe": 10}`, RoundDecimal: 6}, Predicates: &planpb.Expr{ Expr: &planpb.Expr_ValueExpr{ ValueExpr: &planpb.ValueExpr{Value: NewInt(100)}, }, - }}, + }, + }, }, want: false, }, { args: args{ - node1: &planpb.VectorANNS{IsBinary: false, FieldId: 100, PlaceholderTag: "$0", QueryInfo: &planpb.QueryInfo{Topk: 1, MetricType: "L2", SearchParams: `{"nprobe": 10}`, RoundDecimal: 6}, + node1: &planpb.VectorANNS{ + VectorType: planpb.VectorType_FloatVector, FieldId: 100, PlaceholderTag: "$0", QueryInfo: &planpb.QueryInfo{Topk: 1, MetricType: "L2", SearchParams: `{"nprobe": 10}`, RoundDecimal: 6}, Predicates: &planpb.Expr{ Expr: &planpb.Expr_ValueExpr{ ValueExpr: &planpb.ValueExpr{Value: NewInt(100)}, }, - }}, - node2: &planpb.VectorANNS{IsBinary: false, FieldId: 100, PlaceholderTag: "$0", QueryInfo: &planpb.QueryInfo{Topk: 1, MetricType: "L2", SearchParams: `{"nprobe": 10}`, RoundDecimal: 6}, + }, + }, + node2: &planpb.VectorANNS{ + 
VectorType: planpb.VectorType_FloatVector, FieldId: 100, PlaceholderTag: "$0", QueryInfo: &planpb.QueryInfo{Topk: 1, MetricType: "L2", SearchParams: `{"nprobe": 10}`, RoundDecimal: 6}, Predicates: &planpb.Expr{ Expr: &planpb.Expr_ValueExpr{ ValueExpr: &planpb.ValueExpr{Value: NewInt(100)}, }, - }}, + }, + }, }, want: true, }, @@ -194,7 +202,7 @@ func TestCheckPlanNodeIdentical(t *testing.T) { node1: &planpb.PlanNode{ Node: &planpb.PlanNode_VectorAnns{ VectorAnns: &planpb.VectorANNS{ - IsBinary: true, + VectorType: planpb.VectorType_BinaryVector, }, }, OutputFieldIds: []int64{100}, @@ -202,7 +210,7 @@ func TestCheckPlanNodeIdentical(t *testing.T) { node2: &planpb.PlanNode{ Node: &planpb.PlanNode_VectorAnns{ VectorAnns: &planpb.VectorANNS{ - IsBinary: false, + VectorType: planpb.VectorType_FloatVector, }, }, OutputFieldIds: []int64{100}, @@ -214,23 +222,27 @@ func TestCheckPlanNodeIdentical(t *testing.T) { args: args{ node1: &planpb.PlanNode{ Node: &planpb.PlanNode_VectorAnns{ - VectorAnns: &planpb.VectorANNS{IsBinary: false, FieldId: 100, PlaceholderTag: "$0", QueryInfo: &planpb.QueryInfo{Topk: 1, MetricType: "L2", SearchParams: `{"nprobe": 10}`, RoundDecimal: 6}, + VectorAnns: &planpb.VectorANNS{ + VectorType: planpb.VectorType_FloatVector, FieldId: 100, PlaceholderTag: "$0", QueryInfo: &planpb.QueryInfo{Topk: 1, MetricType: "L2", SearchParams: `{"nprobe": 10}`, RoundDecimal: 6}, Predicates: &planpb.Expr{ Expr: &planpb.Expr_ValueExpr{ ValueExpr: &planpb.ValueExpr{Value: NewInt(100)}, }, - }}, + }, + }, }, OutputFieldIds: []int64{100}, }, node2: &planpb.PlanNode{ Node: &planpb.PlanNode_VectorAnns{ - VectorAnns: &planpb.VectorANNS{IsBinary: false, FieldId: 100, PlaceholderTag: "$0", QueryInfo: &planpb.QueryInfo{Topk: 1, MetricType: "L2", SearchParams: `{"nprobe": 10}`, RoundDecimal: 6}, + VectorAnns: &planpb.VectorANNS{ + VectorType: planpb.VectorType_FloatVector, FieldId: 100, PlaceholderTag: "$0", QueryInfo: &planpb.QueryInfo{Topk: 1, MetricType: "L2", 
SearchParams: `{"nprobe": 10}`, RoundDecimal: 6}, Predicates: &planpb.Expr{ Expr: &planpb.Expr_ValueExpr{ ValueExpr: &planpb.ValueExpr{Value: NewInt(100)}, }, - }}, + }, + }, }, OutputFieldIds: []int64{100}, }, diff --git a/internal/parser/planparserv2/generated/Plan.interp b/internal/parser/planparserv2/generated/Plan.interp index 0b6b8b448b31a..1c4888f9bffed 100644 --- a/internal/parser/planparserv2/generated/Plan.interp +++ b/internal/parser/planparserv2/generated/Plan.interp @@ -42,6 +42,10 @@ null null null null +null +null +null +null token symbolic names: null @@ -79,6 +83,10 @@ EmptyTerm JSONContains JSONContainsAll JSONContainsAny +ArrayContains +ArrayContainsAll +ArrayContainsAny +ArrayLength BooleanConstant IntegerConstant FloatingConstant @@ -93,4 +101,4 @@ expr atn: -[3, 24715, 42794, 33075, 47597, 16764, 15335, 30598, 22884, 3, 44, 127, 4, 2, 9, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 7, 2, 20, 10, 2, 12, 2, 14, 2, 23, 11, 2, 3, 2, 5, 2, 26, 10, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 5, 2, 55, 10, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 7, 2, 109, 10, 2, 12, 2, 14, 2, 112, 11, 2, 3, 2, 5, 2, 115, 10, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 7, 2, 122, 10, 2, 12, 2, 14, 2, 125, 11, 2, 3, 2, 2, 3, 2, 3, 2, 2, 12, 4, 2, 16, 17, 29, 30, 3, 2, 18, 20, 3, 2, 16, 17, 3, 2, 22, 23, 3, 2, 8, 9, 4, 2, 40, 40, 42, 42, 3, 2, 10, 11, 3, 2, 8, 11, 3, 2, 12, 13, 3, 2, 31, 32, 2, 157, 2, 54, 3, 2, 2, 2, 4, 5, 8, 2, 1, 2, 5, 55, 7, 38, 2, 2, 6, 55, 7, 39, 2, 2, 7, 55, 7, 37, 2, 2, 8, 55, 7, 41, 2, 2, 9, 55, 7, 40, 2, 2, 10, 55, 7, 
42, 2, 2, 11, 12, 7, 3, 2, 2, 12, 13, 5, 2, 2, 2, 13, 14, 7, 4, 2, 2, 14, 55, 3, 2, 2, 2, 15, 16, 7, 5, 2, 2, 16, 21, 5, 2, 2, 2, 17, 18, 7, 6, 2, 2, 18, 20, 5, 2, 2, 2, 19, 17, 3, 2, 2, 2, 20, 23, 3, 2, 2, 2, 21, 19, 3, 2, 2, 2, 21, 22, 3, 2, 2, 2, 22, 25, 3, 2, 2, 2, 23, 21, 3, 2, 2, 2, 24, 26, 7, 6, 2, 2, 25, 24, 3, 2, 2, 2, 25, 26, 3, 2, 2, 2, 26, 27, 3, 2, 2, 2, 27, 28, 7, 7, 2, 2, 28, 55, 3, 2, 2, 2, 29, 30, 9, 2, 2, 2, 30, 55, 5, 2, 2, 21, 31, 32, 7, 34, 2, 2, 32, 33, 7, 3, 2, 2, 33, 34, 5, 2, 2, 2, 34, 35, 7, 6, 2, 2, 35, 36, 5, 2, 2, 2, 36, 37, 7, 4, 2, 2, 37, 55, 3, 2, 2, 2, 38, 39, 7, 35, 2, 2, 39, 40, 7, 3, 2, 2, 40, 41, 5, 2, 2, 2, 41, 42, 7, 6, 2, 2, 42, 43, 5, 2, 2, 2, 43, 44, 7, 4, 2, 2, 44, 55, 3, 2, 2, 2, 45, 46, 7, 36, 2, 2, 46, 47, 7, 3, 2, 2, 47, 48, 5, 2, 2, 2, 48, 49, 7, 6, 2, 2, 49, 50, 5, 2, 2, 2, 50, 51, 7, 4, 2, 2, 51, 55, 3, 2, 2, 2, 52, 53, 7, 15, 2, 2, 53, 55, 5, 2, 2, 3, 54, 4, 3, 2, 2, 2, 54, 6, 3, 2, 2, 2, 54, 7, 3, 2, 2, 2, 54, 8, 3, 2, 2, 2, 54, 9, 3, 2, 2, 2, 54, 10, 3, 2, 2, 2, 54, 11, 3, 2, 2, 2, 54, 15, 3, 2, 2, 2, 54, 29, 3, 2, 2, 2, 54, 31, 3, 2, 2, 2, 54, 38, 3, 2, 2, 2, 54, 45, 3, 2, 2, 2, 54, 52, 3, 2, 2, 2, 55, 123, 3, 2, 2, 2, 56, 57, 12, 22, 2, 2, 57, 58, 7, 21, 2, 2, 58, 122, 5, 2, 2, 23, 59, 60, 12, 20, 2, 2, 60, 61, 9, 3, 2, 2, 61, 122, 5, 2, 2, 21, 62, 63, 12, 19, 2, 2, 63, 64, 9, 4, 2, 2, 64, 122, 5, 2, 2, 20, 65, 66, 12, 18, 2, 2, 66, 67, 9, 5, 2, 2, 67, 122, 5, 2, 2, 19, 68, 69, 12, 12, 2, 2, 69, 70, 9, 6, 2, 2, 70, 71, 9, 7, 2, 2, 71, 72, 9, 6, 2, 2, 72, 122, 5, 2, 2, 13, 73, 74, 12, 11, 2, 2, 74, 75, 9, 8, 2, 2, 75, 76, 9, 7, 2, 2, 76, 77, 9, 8, 2, 2, 77, 122, 5, 2, 2, 12, 78, 79, 12, 10, 2, 2, 79, 80, 9, 9, 2, 2, 80, 122, 5, 2, 2, 11, 81, 82, 12, 9, 2, 2, 82, 83, 9, 10, 2, 2, 83, 122, 5, 2, 2, 10, 84, 85, 12, 8, 2, 2, 85, 86, 7, 24, 2, 2, 86, 122, 5, 2, 2, 9, 87, 88, 12, 7, 2, 2, 88, 89, 7, 26, 2, 2, 89, 122, 5, 2, 2, 8, 90, 91, 12, 6, 2, 2, 91, 92, 7, 25, 2, 2, 92, 122, 5, 2, 2, 7, 93, 94, 12, 5, 2, 2, 94, 
95, 7, 27, 2, 2, 95, 122, 5, 2, 2, 6, 96, 97, 12, 4, 2, 2, 97, 98, 7, 28, 2, 2, 98, 122, 5, 2, 2, 5, 99, 100, 12, 23, 2, 2, 100, 101, 7, 14, 2, 2, 101, 122, 7, 41, 2, 2, 102, 103, 12, 17, 2, 2, 103, 104, 9, 11, 2, 2, 104, 105, 7, 5, 2, 2, 105, 110, 5, 2, 2, 2, 106, 107, 7, 6, 2, 2, 107, 109, 5, 2, 2, 2, 108, 106, 3, 2, 2, 2, 109, 112, 3, 2, 2, 2, 110, 108, 3, 2, 2, 2, 110, 111, 3, 2, 2, 2, 111, 114, 3, 2, 2, 2, 112, 110, 3, 2, 2, 2, 113, 115, 7, 6, 2, 2, 114, 113, 3, 2, 2, 2, 114, 115, 3, 2, 2, 2, 115, 116, 3, 2, 2, 2, 116, 117, 7, 7, 2, 2, 117, 122, 3, 2, 2, 2, 118, 119, 12, 16, 2, 2, 119, 120, 9, 11, 2, 2, 120, 122, 7, 33, 2, 2, 121, 56, 3, 2, 2, 2, 121, 59, 3, 2, 2, 2, 121, 62, 3, 2, 2, 2, 121, 65, 3, 2, 2, 2, 121, 68, 3, 2, 2, 2, 121, 73, 3, 2, 2, 2, 121, 78, 3, 2, 2, 2, 121, 81, 3, 2, 2, 2, 121, 84, 3, 2, 2, 2, 121, 87, 3, 2, 2, 2, 121, 90, 3, 2, 2, 2, 121, 93, 3, 2, 2, 2, 121, 96, 3, 2, 2, 2, 121, 99, 3, 2, 2, 2, 121, 102, 3, 2, 2, 2, 121, 118, 3, 2, 2, 2, 122, 125, 3, 2, 2, 2, 123, 121, 3, 2, 2, 2, 123, 124, 3, 2, 2, 2, 124, 3, 3, 2, 2, 2, 125, 123, 3, 2, 2, 2, 9, 21, 25, 54, 110, 114, 121, 123] \ No newline at end of file +[3, 24715, 42794, 33075, 47597, 16764, 15335, 30598, 22884, 3, 48, 131, 4, 2, 9, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 7, 2, 20, 10, 2, 12, 2, 14, 2, 23, 11, 2, 3, 2, 5, 2, 26, 10, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 5, 2, 59, 10, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 7, 2, 113, 10, 2, 12, 2, 14, 2, 116, 11, 2, 3, 2, 5, 2, 119, 10, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 
2, 7, 2, 126, 10, 2, 12, 2, 14, 2, 129, 11, 2, 3, 2, 2, 3, 2, 3, 2, 2, 15, 4, 2, 16, 17, 29, 30, 4, 2, 34, 34, 37, 37, 4, 2, 35, 35, 38, 38, 4, 2, 36, 36, 39, 39, 4, 2, 44, 44, 46, 46, 3, 2, 18, 20, 3, 2, 16, 17, 3, 2, 22, 23, 3, 2, 8, 9, 3, 2, 10, 11, 3, 2, 8, 11, 3, 2, 12, 13, 3, 2, 31, 32, 2, 162, 2, 58, 3, 2, 2, 2, 4, 5, 8, 2, 1, 2, 5, 59, 7, 42, 2, 2, 6, 59, 7, 43, 2, 2, 7, 59, 7, 41, 2, 2, 8, 59, 7, 45, 2, 2, 9, 59, 7, 44, 2, 2, 10, 59, 7, 46, 2, 2, 11, 12, 7, 3, 2, 2, 12, 13, 5, 2, 2, 2, 13, 14, 7, 4, 2, 2, 14, 59, 3, 2, 2, 2, 15, 16, 7, 5, 2, 2, 16, 21, 5, 2, 2, 2, 17, 18, 7, 6, 2, 2, 18, 20, 5, 2, 2, 2, 19, 17, 3, 2, 2, 2, 20, 23, 3, 2, 2, 2, 21, 19, 3, 2, 2, 2, 21, 22, 3, 2, 2, 2, 22, 25, 3, 2, 2, 2, 23, 21, 3, 2, 2, 2, 24, 26, 7, 6, 2, 2, 25, 24, 3, 2, 2, 2, 25, 26, 3, 2, 2, 2, 26, 27, 3, 2, 2, 2, 27, 28, 7, 7, 2, 2, 28, 59, 3, 2, 2, 2, 29, 30, 9, 2, 2, 2, 30, 59, 5, 2, 2, 22, 31, 32, 9, 3, 2, 2, 32, 33, 7, 3, 2, 2, 33, 34, 5, 2, 2, 2, 34, 35, 7, 6, 2, 2, 35, 36, 5, 2, 2, 2, 36, 37, 7, 4, 2, 2, 37, 59, 3, 2, 2, 2, 38, 39, 9, 4, 2, 2, 39, 40, 7, 3, 2, 2, 40, 41, 5, 2, 2, 2, 41, 42, 7, 6, 2, 2, 42, 43, 5, 2, 2, 2, 43, 44, 7, 4, 2, 2, 44, 59, 3, 2, 2, 2, 45, 46, 9, 5, 2, 2, 46, 47, 7, 3, 2, 2, 47, 48, 5, 2, 2, 2, 48, 49, 7, 6, 2, 2, 49, 50, 5, 2, 2, 2, 50, 51, 7, 4, 2, 2, 51, 59, 3, 2, 2, 2, 52, 53, 7, 40, 2, 2, 53, 54, 7, 3, 2, 2, 54, 55, 9, 6, 2, 2, 55, 59, 7, 4, 2, 2, 56, 57, 7, 15, 2, 2, 57, 59, 5, 2, 2, 3, 58, 4, 3, 2, 2, 2, 58, 6, 3, 2, 2, 2, 58, 7, 3, 2, 2, 2, 58, 8, 3, 2, 2, 2, 58, 9, 3, 2, 2, 2, 58, 10, 3, 2, 2, 2, 58, 11, 3, 2, 2, 2, 58, 15, 3, 2, 2, 2, 58, 29, 3, 2, 2, 2, 58, 31, 3, 2, 2, 2, 58, 38, 3, 2, 2, 2, 58, 45, 3, 2, 2, 2, 58, 52, 3, 2, 2, 2, 58, 56, 3, 2, 2, 2, 59, 127, 3, 2, 2, 2, 60, 61, 12, 23, 2, 2, 61, 62, 7, 21, 2, 2, 62, 126, 5, 2, 2, 24, 63, 64, 12, 21, 2, 2, 64, 65, 9, 7, 2, 2, 65, 126, 5, 2, 2, 22, 66, 67, 12, 20, 2, 2, 67, 68, 9, 8, 2, 2, 68, 126, 5, 2, 2, 21, 69, 70, 12, 19, 2, 2, 70, 71, 9, 9, 2, 2, 71, 126, 5, 2, 2, 20, 72, 
73, 12, 12, 2, 2, 73, 74, 9, 10, 2, 2, 74, 75, 9, 6, 2, 2, 75, 76, 9, 10, 2, 2, 76, 126, 5, 2, 2, 13, 77, 78, 12, 11, 2, 2, 78, 79, 9, 11, 2, 2, 79, 80, 9, 6, 2, 2, 80, 81, 9, 11, 2, 2, 81, 126, 5, 2, 2, 12, 82, 83, 12, 10, 2, 2, 83, 84, 9, 12, 2, 2, 84, 126, 5, 2, 2, 11, 85, 86, 12, 9, 2, 2, 86, 87, 9, 13, 2, 2, 87, 126, 5, 2, 2, 10, 88, 89, 12, 8, 2, 2, 89, 90, 7, 24, 2, 2, 90, 126, 5, 2, 2, 9, 91, 92, 12, 7, 2, 2, 92, 93, 7, 26, 2, 2, 93, 126, 5, 2, 2, 8, 94, 95, 12, 6, 2, 2, 95, 96, 7, 25, 2, 2, 96, 126, 5, 2, 2, 7, 97, 98, 12, 5, 2, 2, 98, 99, 7, 27, 2, 2, 99, 126, 5, 2, 2, 6, 100, 101, 12, 4, 2, 2, 101, 102, 7, 28, 2, 2, 102, 126, 5, 2, 2, 5, 103, 104, 12, 24, 2, 2, 104, 105, 7, 14, 2, 2, 105, 126, 7, 45, 2, 2, 106, 107, 12, 18, 2, 2, 107, 108, 9, 14, 2, 2, 108, 109, 7, 5, 2, 2, 109, 114, 5, 2, 2, 2, 110, 111, 7, 6, 2, 2, 111, 113, 5, 2, 2, 2, 112, 110, 3, 2, 2, 2, 113, 116, 3, 2, 2, 2, 114, 112, 3, 2, 2, 2, 114, 115, 3, 2, 2, 2, 115, 118, 3, 2, 2, 2, 116, 114, 3, 2, 2, 2, 117, 119, 7, 6, 2, 2, 118, 117, 3, 2, 2, 2, 118, 119, 3, 2, 2, 2, 119, 120, 3, 2, 2, 2, 120, 121, 7, 7, 2, 2, 121, 126, 3, 2, 2, 2, 122, 123, 12, 17, 2, 2, 123, 124, 9, 14, 2, 2, 124, 126, 7, 33, 2, 2, 125, 60, 3, 2, 2, 2, 125, 63, 3, 2, 2, 2, 125, 66, 3, 2, 2, 2, 125, 69, 3, 2, 2, 2, 125, 72, 3, 2, 2, 2, 125, 77, 3, 2, 2, 2, 125, 82, 3, 2, 2, 2, 125, 85, 3, 2, 2, 2, 125, 88, 3, 2, 2, 2, 125, 91, 3, 2, 2, 2, 125, 94, 3, 2, 2, 2, 125, 97, 3, 2, 2, 2, 125, 100, 3, 2, 2, 2, 125, 103, 3, 2, 2, 2, 125, 106, 3, 2, 2, 2, 125, 122, 3, 2, 2, 2, 126, 129, 3, 2, 2, 2, 127, 125, 3, 2, 2, 2, 127, 128, 3, 2, 2, 2, 128, 3, 3, 2, 2, 2, 129, 127, 3, 2, 2, 2, 9, 21, 25, 58, 114, 118, 125, 127] \ No newline at end of file diff --git a/internal/parser/planparserv2/generated/Plan.tokens b/internal/parser/planparserv2/generated/Plan.tokens index ca7b53db2ebfa..e808c9b6391b3 100644 --- a/internal/parser/planparserv2/generated/Plan.tokens +++ b/internal/parser/planparserv2/generated/Plan.tokens @@ -32,14 +32,18 @@ 
EmptyTerm=31 JSONContains=32 JSONContainsAll=33 JSONContainsAny=34 -BooleanConstant=35 -IntegerConstant=36 -FloatingConstant=37 -Identifier=38 -StringLiteral=39 -JSONIdentifier=40 -Whitespace=41 -Newline=42 +ArrayContains=35 +ArrayContainsAll=36 +ArrayContainsAny=37 +ArrayLength=38 +BooleanConstant=39 +IntegerConstant=40 +FloatingConstant=41 +Identifier=42 +StringLiteral=43 +JSONIdentifier=44 +Whitespace=45 +Newline=46 '('=1 ')'=2 '['=3 diff --git a/internal/parser/planparserv2/generated/PlanLexer.interp b/internal/parser/planparserv2/generated/PlanLexer.interp index e82c9859dccd1..e4294edea7641 100644 --- a/internal/parser/planparserv2/generated/PlanLexer.interp +++ b/internal/parser/planparserv2/generated/PlanLexer.interp @@ -42,6 +42,10 @@ null null null null +null +null +null +null token symbolic names: null @@ -79,6 +83,10 @@ EmptyTerm JSONContains JSONContainsAll JSONContainsAny +ArrayContains +ArrayContainsAll +ArrayContainsAny +ArrayLength BooleanConstant IntegerConstant FloatingConstant @@ -123,6 +131,10 @@ EmptyTerm JSONContains JSONContainsAll JSONContainsAny +ArrayContains +ArrayContainsAll +ArrayContainsAny +ArrayLength BooleanConstant IntegerConstant FloatingConstant @@ -165,4 +177,4 @@ mode names: DEFAULT_MODE atn: -[3, 24715, 42794, 33075, 47597, 16764, 15335, 30598, 22884, 2, 44, 614, 8, 1, 4, 2, 9, 2, 4, 3, 9, 3, 4, 4, 9, 4, 4, 5, 9, 5, 4, 6, 9, 6, 4, 7, 9, 7, 4, 8, 9, 8, 4, 9, 9, 9, 4, 10, 9, 10, 4, 11, 9, 11, 4, 12, 9, 12, 4, 13, 9, 13, 4, 14, 9, 14, 4, 15, 9, 15, 4, 16, 9, 16, 4, 17, 9, 17, 4, 18, 9, 18, 4, 19, 9, 19, 4, 20, 9, 20, 4, 21, 9, 21, 4, 22, 9, 22, 4, 23, 9, 23, 4, 24, 9, 24, 4, 25, 9, 25, 4, 26, 9, 26, 4, 27, 9, 27, 4, 28, 9, 28, 4, 29, 9, 29, 4, 30, 9, 30, 4, 31, 9, 31, 4, 32, 9, 32, 4, 33, 9, 33, 4, 34, 9, 34, 4, 35, 9, 35, 4, 36, 9, 36, 4, 37, 9, 37, 4, 38, 9, 38, 4, 39, 9, 39, 4, 40, 9, 40, 4, 41, 9, 41, 4, 42, 9, 42, 4, 43, 9, 43, 4, 44, 9, 44, 4, 45, 9, 45, 4, 46, 9, 46, 4, 47, 9, 47, 4, 48, 9, 48, 4, 49, 9, 49, 4, 50, 9, 50, 
4, 51, 9, 51, 4, 52, 9, 52, 4, 53, 9, 53, 4, 54, 9, 54, 4, 55, 9, 55, 4, 56, 9, 56, 4, 57, 9, 57, 4, 58, 9, 58, 4, 59, 9, 59, 4, 60, 9, 60, 4, 61, 9, 61, 4, 62, 9, 62, 4, 63, 9, 63, 4, 64, 9, 64, 4, 65, 9, 65, 4, 66, 9, 66, 4, 67, 9, 67, 4, 68, 9, 68, 3, 2, 3, 2, 3, 3, 3, 3, 3, 4, 3, 4, 3, 5, 3, 5, 3, 6, 3, 6, 3, 7, 3, 7, 3, 8, 3, 8, 3, 8, 3, 9, 3, 9, 3, 10, 3, 10, 3, 10, 3, 11, 3, 11, 3, 11, 3, 12, 3, 12, 3, 12, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 5, 13, 172, 10, 13, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 5, 14, 186, 10, 14, 3, 15, 3, 15, 3, 16, 3, 16, 3, 17, 3, 17, 3, 18, 3, 18, 3, 19, 3, 19, 3, 20, 3, 20, 3, 20, 3, 21, 3, 21, 3, 21, 3, 22, 3, 22, 3, 22, 3, 23, 3, 23, 3, 24, 3, 24, 3, 25, 3, 25, 3, 26, 3, 26, 3, 26, 3, 26, 3, 26, 5, 26, 218, 10, 26, 3, 27, 3, 27, 3, 27, 3, 27, 5, 27, 224, 10, 27, 3, 28, 3, 28, 3, 29, 3, 29, 3, 29, 3, 29, 5, 29, 232, 10, 29, 3, 30, 3, 30, 3, 30, 3, 31, 3, 31, 3, 31, 3, 31, 3, 31, 3, 31, 3, 31, 3, 32, 3, 32, 3, 32, 7, 32, 247, 10, 32, 12, 32, 14, 32, 250, 11, 32, 3, 32, 3, 32, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 5, 33, 280, 10, 33, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 5, 34, 316, 10, 34, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 5, 35, 352, 10, 35, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 
5, 36, 381, 10, 36, 3, 37, 3, 37, 3, 37, 3, 37, 5, 37, 387, 10, 37, 3, 38, 3, 38, 5, 38, 391, 10, 38, 3, 39, 3, 39, 3, 39, 7, 39, 396, 10, 39, 12, 39, 14, 39, 399, 11, 39, 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, 5, 39, 406, 10, 39, 3, 40, 5, 40, 409, 10, 40, 3, 40, 3, 40, 5, 40, 413, 10, 40, 3, 40, 3, 40, 3, 40, 5, 40, 418, 10, 40, 3, 40, 5, 40, 421, 10, 40, 3, 41, 3, 41, 3, 41, 3, 41, 5, 41, 427, 10, 41, 3, 41, 3, 41, 6, 41, 431, 10, 41, 13, 41, 14, 41, 432, 3, 42, 3, 42, 3, 42, 5, 42, 438, 10, 42, 3, 43, 6, 43, 441, 10, 43, 13, 43, 14, 43, 442, 3, 44, 6, 44, 446, 10, 44, 13, 44, 14, 44, 447, 3, 45, 3, 45, 3, 45, 3, 45, 3, 45, 3, 45, 3, 45, 5, 45, 457, 10, 45, 3, 46, 3, 46, 3, 46, 3, 46, 3, 46, 3, 46, 3, 46, 5, 46, 466, 10, 46, 3, 47, 3, 47, 3, 48, 3, 48, 3, 49, 3, 49, 3, 49, 6, 49, 475, 10, 49, 13, 49, 14, 49, 476, 3, 50, 3, 50, 7, 50, 481, 10, 50, 12, 50, 14, 50, 484, 11, 50, 3, 50, 5, 50, 487, 10, 50, 3, 51, 3, 51, 7, 51, 491, 10, 51, 12, 51, 14, 51, 494, 11, 51, 3, 52, 3, 52, 3, 52, 3, 52, 3, 53, 3, 53, 3, 54, 3, 54, 3, 55, 3, 55, 3, 56, 3, 56, 3, 56, 3, 56, 3, 56, 3, 57, 3, 57, 3, 57, 3, 57, 3, 57, 3, 57, 3, 57, 3, 57, 3, 57, 3, 57, 5, 57, 521, 10, 57, 3, 58, 3, 58, 5, 58, 525, 10, 58, 3, 58, 3, 58, 3, 58, 5, 58, 530, 10, 58, 3, 59, 3, 59, 3, 59, 3, 59, 5, 59, 536, 10, 59, 3, 59, 3, 59, 3, 60, 5, 60, 541, 10, 60, 3, 60, 3, 60, 3, 60, 3, 60, 3, 60, 5, 60, 548, 10, 60, 3, 61, 3, 61, 5, 61, 552, 10, 61, 3, 61, 3, 61, 3, 62, 6, 62, 557, 10, 62, 13, 62, 14, 62, 558, 3, 63, 5, 63, 562, 10, 63, 3, 63, 3, 63, 3, 63, 3, 63, 3, 63, 5, 63, 569, 10, 63, 3, 64, 6, 64, 572, 10, 64, 13, 64, 14, 64, 573, 3, 65, 3, 65, 5, 65, 578, 10, 65, 3, 65, 3, 65, 3, 66, 3, 66, 3, 66, 3, 66, 3, 66, 5, 66, 587, 10, 66, 3, 66, 5, 66, 590, 10, 66, 3, 66, 3, 66, 3, 66, 3, 66, 3, 66, 5, 66, 597, 10, 66, 3, 67, 6, 67, 600, 10, 67, 13, 67, 14, 67, 601, 3, 67, 3, 67, 3, 68, 3, 68, 5, 68, 608, 10, 68, 3, 68, 5, 68, 611, 10, 68, 3, 68, 3, 68, 2, 2, 69, 3, 3, 5, 4, 7, 5, 9, 6, 11, 7, 13, 8, 15, 9, 17, 
10, 19, 11, 21, 12, 23, 13, 25, 14, 27, 15, 29, 16, 31, 17, 33, 18, 35, 19, 37, 20, 39, 21, 41, 22, 43, 23, 45, 24, 47, 25, 49, 26, 51, 27, 53, 28, 55, 29, 57, 30, 59, 31, 61, 32, 63, 33, 65, 34, 67, 35, 69, 36, 71, 37, 73, 38, 75, 39, 77, 40, 79, 41, 81, 42, 83, 2, 85, 2, 87, 2, 89, 2, 91, 2, 93, 2, 95, 2, 97, 2, 99, 2, 101, 2, 103, 2, 105, 2, 107, 2, 109, 2, 111, 2, 113, 2, 115, 2, 117, 2, 119, 2, 121, 2, 123, 2, 125, 2, 127, 2, 129, 2, 131, 2, 133, 43, 135, 44, 3, 2, 18, 5, 2, 78, 78, 87, 87, 119, 119, 6, 2, 12, 12, 15, 15, 36, 36, 94, 94, 6, 2, 12, 12, 15, 15, 41, 41, 94, 94, 5, 2, 67, 92, 97, 97, 99, 124, 3, 2, 50, 59, 4, 2, 68, 68, 100, 100, 3, 2, 50, 51, 4, 2, 90, 90, 122, 122, 3, 2, 51, 59, 3, 2, 50, 57, 5, 2, 50, 59, 67, 72, 99, 104, 4, 2, 71, 71, 103, 103, 4, 2, 45, 45, 47, 47, 4, 2, 82, 82, 114, 114, 12, 2, 36, 36, 41, 41, 65, 65, 94, 94, 99, 100, 104, 104, 112, 112, 116, 116, 118, 118, 120, 120, 4, 2, 11, 11, 34, 34, 2, 649, 2, 3, 3, 2, 2, 2, 2, 5, 3, 2, 2, 2, 2, 7, 3, 2, 2, 2, 2, 9, 3, 2, 2, 2, 2, 11, 3, 2, 2, 2, 2, 13, 3, 2, 2, 2, 2, 15, 3, 2, 2, 2, 2, 17, 3, 2, 2, 2, 2, 19, 3, 2, 2, 2, 2, 21, 3, 2, 2, 2, 2, 23, 3, 2, 2, 2, 2, 25, 3, 2, 2, 2, 2, 27, 3, 2, 2, 2, 2, 29, 3, 2, 2, 2, 2, 31, 3, 2, 2, 2, 2, 33, 3, 2, 2, 2, 2, 35, 3, 2, 2, 2, 2, 37, 3, 2, 2, 2, 2, 39, 3, 2, 2, 2, 2, 41, 3, 2, 2, 2, 2, 43, 3, 2, 2, 2, 2, 45, 3, 2, 2, 2, 2, 47, 3, 2, 2, 2, 2, 49, 3, 2, 2, 2, 2, 51, 3, 2, 2, 2, 2, 53, 3, 2, 2, 2, 2, 55, 3, 2, 2, 2, 2, 57, 3, 2, 2, 2, 2, 59, 3, 2, 2, 2, 2, 61, 3, 2, 2, 2, 2, 63, 3, 2, 2, 2, 2, 65, 3, 2, 2, 2, 2, 67, 3, 2, 2, 2, 2, 69, 3, 2, 2, 2, 2, 71, 3, 2, 2, 2, 2, 73, 3, 2, 2, 2, 2, 75, 3, 2, 2, 2, 2, 77, 3, 2, 2, 2, 2, 79, 3, 2, 2, 2, 2, 81, 3, 2, 2, 2, 2, 133, 3, 2, 2, 2, 2, 135, 3, 2, 2, 2, 3, 137, 3, 2, 2, 2, 5, 139, 3, 2, 2, 2, 7, 141, 3, 2, 2, 2, 9, 143, 3, 2, 2, 2, 11, 145, 3, 2, 2, 2, 13, 147, 3, 2, 2, 2, 15, 149, 3, 2, 2, 2, 17, 152, 3, 2, 2, 2, 19, 154, 3, 2, 2, 2, 21, 157, 3, 2, 2, 2, 23, 160, 3, 2, 2, 2, 25, 171, 3, 2, 2, 2, 27, 
185, 3, 2, 2, 2, 29, 187, 3, 2, 2, 2, 31, 189, 3, 2, 2, 2, 33, 191, 3, 2, 2, 2, 35, 193, 3, 2, 2, 2, 37, 195, 3, 2, 2, 2, 39, 197, 3, 2, 2, 2, 41, 200, 3, 2, 2, 2, 43, 203, 3, 2, 2, 2, 45, 206, 3, 2, 2, 2, 47, 208, 3, 2, 2, 2, 49, 210, 3, 2, 2, 2, 51, 217, 3, 2, 2, 2, 53, 223, 3, 2, 2, 2, 55, 225, 3, 2, 2, 2, 57, 231, 3, 2, 2, 2, 59, 233, 3, 2, 2, 2, 61, 236, 3, 2, 2, 2, 63, 243, 3, 2, 2, 2, 65, 279, 3, 2, 2, 2, 67, 315, 3, 2, 2, 2, 69, 351, 3, 2, 2, 2, 71, 380, 3, 2, 2, 2, 73, 386, 3, 2, 2, 2, 75, 390, 3, 2, 2, 2, 77, 405, 3, 2, 2, 2, 79, 408, 3, 2, 2, 2, 81, 422, 3, 2, 2, 2, 83, 437, 3, 2, 2, 2, 85, 440, 3, 2, 2, 2, 87, 445, 3, 2, 2, 2, 89, 456, 3, 2, 2, 2, 91, 465, 3, 2, 2, 2, 93, 467, 3, 2, 2, 2, 95, 469, 3, 2, 2, 2, 97, 471, 3, 2, 2, 2, 99, 486, 3, 2, 2, 2, 101, 488, 3, 2, 2, 2, 103, 495, 3, 2, 2, 2, 105, 499, 3, 2, 2, 2, 107, 501, 3, 2, 2, 2, 109, 503, 3, 2, 2, 2, 111, 505, 3, 2, 2, 2, 113, 520, 3, 2, 2, 2, 115, 529, 3, 2, 2, 2, 117, 531, 3, 2, 2, 2, 119, 547, 3, 2, 2, 2, 121, 549, 3, 2, 2, 2, 123, 556, 3, 2, 2, 2, 125, 568, 3, 2, 2, 2, 127, 571, 3, 2, 2, 2, 129, 575, 3, 2, 2, 2, 131, 596, 3, 2, 2, 2, 133, 599, 3, 2, 2, 2, 135, 610, 3, 2, 2, 2, 137, 138, 7, 42, 2, 2, 138, 4, 3, 2, 2, 2, 139, 140, 7, 43, 2, 2, 140, 6, 3, 2, 2, 2, 141, 142, 7, 93, 2, 2, 142, 8, 3, 2, 2, 2, 143, 144, 7, 46, 2, 2, 144, 10, 3, 2, 2, 2, 145, 146, 7, 95, 2, 2, 146, 12, 3, 2, 2, 2, 147, 148, 7, 62, 2, 2, 148, 14, 3, 2, 2, 2, 149, 150, 7, 62, 2, 2, 150, 151, 7, 63, 2, 2, 151, 16, 3, 2, 2, 2, 152, 153, 7, 64, 2, 2, 153, 18, 3, 2, 2, 2, 154, 155, 7, 64, 2, 2, 155, 156, 7, 63, 2, 2, 156, 20, 3, 2, 2, 2, 157, 158, 7, 63, 2, 2, 158, 159, 7, 63, 2, 2, 159, 22, 3, 2, 2, 2, 160, 161, 7, 35, 2, 2, 161, 162, 7, 63, 2, 2, 162, 24, 3, 2, 2, 2, 163, 164, 7, 110, 2, 2, 164, 165, 7, 107, 2, 2, 165, 166, 7, 109, 2, 2, 166, 172, 7, 103, 2, 2, 167, 168, 7, 78, 2, 2, 168, 169, 7, 75, 2, 2, 169, 170, 7, 77, 2, 2, 170, 172, 7, 71, 2, 2, 171, 163, 3, 2, 2, 2, 171, 167, 3, 2, 2, 2, 172, 26, 3, 2, 2, 2, 173, 
174, 7, 103, 2, 2, 174, 175, 7, 122, 2, 2, 175, 176, 7, 107, 2, 2, 176, 177, 7, 117, 2, 2, 177, 178, 7, 118, 2, 2, 178, 186, 7, 117, 2, 2, 179, 180, 7, 71, 2, 2, 180, 181, 7, 90, 2, 2, 181, 182, 7, 75, 2, 2, 182, 183, 7, 85, 2, 2, 183, 184, 7, 86, 2, 2, 184, 186, 7, 85, 2, 2, 185, 173, 3, 2, 2, 2, 185, 179, 3, 2, 2, 2, 186, 28, 3, 2, 2, 2, 187, 188, 7, 45, 2, 2, 188, 30, 3, 2, 2, 2, 189, 190, 7, 47, 2, 2, 190, 32, 3, 2, 2, 2, 191, 192, 7, 44, 2, 2, 192, 34, 3, 2, 2, 2, 193, 194, 7, 49, 2, 2, 194, 36, 3, 2, 2, 2, 195, 196, 7, 39, 2, 2, 196, 38, 3, 2, 2, 2, 197, 198, 7, 44, 2, 2, 198, 199, 7, 44, 2, 2, 199, 40, 3, 2, 2, 2, 200, 201, 7, 62, 2, 2, 201, 202, 7, 62, 2, 2, 202, 42, 3, 2, 2, 2, 203, 204, 7, 64, 2, 2, 204, 205, 7, 64, 2, 2, 205, 44, 3, 2, 2, 2, 206, 207, 7, 40, 2, 2, 207, 46, 3, 2, 2, 2, 208, 209, 7, 126, 2, 2, 209, 48, 3, 2, 2, 2, 210, 211, 7, 96, 2, 2, 211, 50, 3, 2, 2, 2, 212, 213, 7, 40, 2, 2, 213, 218, 7, 40, 2, 2, 214, 215, 7, 99, 2, 2, 215, 216, 7, 112, 2, 2, 216, 218, 7, 102, 2, 2, 217, 212, 3, 2, 2, 2, 217, 214, 3, 2, 2, 2, 218, 52, 3, 2, 2, 2, 219, 220, 7, 126, 2, 2, 220, 224, 7, 126, 2, 2, 221, 222, 7, 113, 2, 2, 222, 224, 7, 116, 2, 2, 223, 219, 3, 2, 2, 2, 223, 221, 3, 2, 2, 2, 224, 54, 3, 2, 2, 2, 225, 226, 7, 128, 2, 2, 226, 56, 3, 2, 2, 2, 227, 232, 7, 35, 2, 2, 228, 229, 7, 112, 2, 2, 229, 230, 7, 113, 2, 2, 230, 232, 7, 118, 2, 2, 231, 227, 3, 2, 2, 2, 231, 228, 3, 2, 2, 2, 232, 58, 3, 2, 2, 2, 233, 234, 7, 107, 2, 2, 234, 235, 7, 112, 2, 2, 235, 60, 3, 2, 2, 2, 236, 237, 7, 112, 2, 2, 237, 238, 7, 113, 2, 2, 238, 239, 7, 118, 2, 2, 239, 240, 7, 34, 2, 2, 240, 241, 7, 107, 2, 2, 241, 242, 7, 112, 2, 2, 242, 62, 3, 2, 2, 2, 243, 248, 7, 93, 2, 2, 244, 247, 5, 133, 67, 2, 245, 247, 5, 135, 68, 2, 246, 244, 3, 2, 2, 2, 246, 245, 3, 2, 2, 2, 247, 250, 3, 2, 2, 2, 248, 246, 3, 2, 2, 2, 248, 249, 3, 2, 2, 2, 249, 251, 3, 2, 2, 2, 250, 248, 3, 2, 2, 2, 251, 252, 7, 95, 2, 2, 252, 64, 3, 2, 2, 2, 253, 254, 7, 108, 2, 2, 254, 255, 7, 117, 2, 2, 
255, 256, 7, 113, 2, 2, 256, 257, 7, 112, 2, 2, 257, 258, 7, 97, 2, 2, 258, 259, 7, 101, 2, 2, 259, 260, 7, 113, 2, 2, 260, 261, 7, 112, 2, 2, 261, 262, 7, 118, 2, 2, 262, 263, 7, 99, 2, 2, 263, 264, 7, 107, 2, 2, 264, 265, 7, 112, 2, 2, 265, 280, 7, 117, 2, 2, 266, 267, 7, 76, 2, 2, 267, 268, 7, 85, 2, 2, 268, 269, 7, 81, 2, 2, 269, 270, 7, 80, 2, 2, 270, 271, 7, 97, 2, 2, 271, 272, 7, 69, 2, 2, 272, 273, 7, 81, 2, 2, 273, 274, 7, 80, 2, 2, 274, 275, 7, 86, 2, 2, 275, 276, 7, 67, 2, 2, 276, 277, 7, 75, 2, 2, 277, 278, 7, 80, 2, 2, 278, 280, 7, 85, 2, 2, 279, 253, 3, 2, 2, 2, 279, 266, 3, 2, 2, 2, 280, 66, 3, 2, 2, 2, 281, 282, 7, 108, 2, 2, 282, 283, 7, 117, 2, 2, 283, 284, 7, 113, 2, 2, 284, 285, 7, 112, 2, 2, 285, 286, 7, 97, 2, 2, 286, 287, 7, 101, 2, 2, 287, 288, 7, 113, 2, 2, 288, 289, 7, 112, 2, 2, 289, 290, 7, 118, 2, 2, 290, 291, 7, 99, 2, 2, 291, 292, 7, 107, 2, 2, 292, 293, 7, 112, 2, 2, 293, 294, 7, 117, 2, 2, 294, 295, 7, 97, 2, 2, 295, 296, 7, 99, 2, 2, 296, 297, 7, 110, 2, 2, 297, 316, 7, 110, 2, 2, 298, 299, 7, 76, 2, 2, 299, 300, 7, 85, 2, 2, 300, 301, 7, 81, 2, 2, 301, 302, 7, 80, 2, 2, 302, 303, 7, 97, 2, 2, 303, 304, 7, 69, 2, 2, 304, 305, 7, 81, 2, 2, 305, 306, 7, 80, 2, 2, 306, 307, 7, 86, 2, 2, 307, 308, 7, 67, 2, 2, 308, 309, 7, 75, 2, 2, 309, 310, 7, 80, 2, 2, 310, 311, 7, 85, 2, 2, 311, 312, 7, 97, 2, 2, 312, 313, 7, 67, 2, 2, 313, 314, 7, 78, 2, 2, 314, 316, 7, 78, 2, 2, 315, 281, 3, 2, 2, 2, 315, 298, 3, 2, 2, 2, 316, 68, 3, 2, 2, 2, 317, 318, 7, 108, 2, 2, 318, 319, 7, 117, 2, 2, 319, 320, 7, 113, 2, 2, 320, 321, 7, 112, 2, 2, 321, 322, 7, 97, 2, 2, 322, 323, 7, 101, 2, 2, 323, 324, 7, 113, 2, 2, 324, 325, 7, 112, 2, 2, 325, 326, 7, 118, 2, 2, 326, 327, 7, 99, 2, 2, 327, 328, 7, 107, 2, 2, 328, 329, 7, 112, 2, 2, 329, 330, 7, 117, 2, 2, 330, 331, 7, 97, 2, 2, 331, 332, 7, 99, 2, 2, 332, 333, 7, 112, 2, 2, 333, 352, 7, 123, 2, 2, 334, 335, 7, 76, 2, 2, 335, 336, 7, 85, 2, 2, 336, 337, 7, 81, 2, 2, 337, 338, 7, 80, 2, 2, 338, 339, 7, 97, 
2, 2, 339, 340, 7, 69, 2, 2, 340, 341, 7, 81, 2, 2, 341, 342, 7, 80, 2, 2, 342, 343, 7, 86, 2, 2, 343, 344, 7, 67, 2, 2, 344, 345, 7, 75, 2, 2, 345, 346, 7, 80, 2, 2, 346, 347, 7, 85, 2, 2, 347, 348, 7, 97, 2, 2, 348, 349, 7, 67, 2, 2, 349, 350, 7, 80, 2, 2, 350, 352, 7, 91, 2, 2, 351, 317, 3, 2, 2, 2, 351, 334, 3, 2, 2, 2, 352, 70, 3, 2, 2, 2, 353, 354, 7, 118, 2, 2, 354, 355, 7, 116, 2, 2, 355, 356, 7, 119, 2, 2, 356, 381, 7, 103, 2, 2, 357, 358, 7, 86, 2, 2, 358, 359, 7, 116, 2, 2, 359, 360, 7, 119, 2, 2, 360, 381, 7, 103, 2, 2, 361, 362, 7, 86, 2, 2, 362, 363, 7, 84, 2, 2, 363, 364, 7, 87, 2, 2, 364, 381, 7, 71, 2, 2, 365, 366, 7, 104, 2, 2, 366, 367, 7, 99, 2, 2, 367, 368, 7, 110, 2, 2, 368, 369, 7, 117, 2, 2, 369, 381, 7, 103, 2, 2, 370, 371, 7, 72, 2, 2, 371, 372, 7, 99, 2, 2, 372, 373, 7, 110, 2, 2, 373, 374, 7, 117, 2, 2, 374, 381, 7, 103, 2, 2, 375, 376, 7, 72, 2, 2, 376, 377, 7, 67, 2, 2, 377, 378, 7, 78, 2, 2, 378, 379, 7, 85, 2, 2, 379, 381, 7, 71, 2, 2, 380, 353, 3, 2, 2, 2, 380, 357, 3, 2, 2, 2, 380, 361, 3, 2, 2, 2, 380, 365, 3, 2, 2, 2, 380, 370, 3, 2, 2, 2, 380, 375, 3, 2, 2, 2, 381, 72, 3, 2, 2, 2, 382, 387, 5, 99, 50, 2, 383, 387, 5, 101, 51, 2, 384, 387, 5, 103, 52, 2, 385, 387, 5, 97, 49, 2, 386, 382, 3, 2, 2, 2, 386, 383, 3, 2, 2, 2, 386, 384, 3, 2, 2, 2, 386, 385, 3, 2, 2, 2, 387, 74, 3, 2, 2, 2, 388, 391, 5, 115, 58, 2, 389, 391, 5, 117, 59, 2, 390, 388, 3, 2, 2, 2, 390, 389, 3, 2, 2, 2, 391, 76, 3, 2, 2, 2, 392, 397, 5, 93, 47, 2, 393, 396, 5, 93, 47, 2, 394, 396, 5, 95, 48, 2, 395, 393, 3, 2, 2, 2, 395, 394, 3, 2, 2, 2, 396, 399, 3, 2, 2, 2, 397, 395, 3, 2, 2, 2, 397, 398, 3, 2, 2, 2, 398, 406, 3, 2, 2, 2, 399, 397, 3, 2, 2, 2, 400, 401, 7, 38, 2, 2, 401, 402, 7, 111, 2, 2, 402, 403, 7, 103, 2, 2, 403, 404, 7, 118, 2, 2, 404, 406, 7, 99, 2, 2, 405, 392, 3, 2, 2, 2, 405, 400, 3, 2, 2, 2, 406, 78, 3, 2, 2, 2, 407, 409, 5, 83, 42, 2, 408, 407, 3, 2, 2, 2, 408, 409, 3, 2, 2, 2, 409, 420, 3, 2, 2, 2, 410, 412, 7, 36, 2, 2, 411, 413, 5, 85, 43, 
2, 412, 411, 3, 2, 2, 2, 412, 413, 3, 2, 2, 2, 413, 414, 3, 2, 2, 2, 414, 421, 7, 36, 2, 2, 415, 417, 7, 41, 2, 2, 416, 418, 5, 87, 44, 2, 417, 416, 3, 2, 2, 2, 417, 418, 3, 2, 2, 2, 418, 419, 3, 2, 2, 2, 419, 421, 7, 41, 2, 2, 420, 410, 3, 2, 2, 2, 420, 415, 3, 2, 2, 2, 421, 80, 3, 2, 2, 2, 422, 430, 5, 77, 39, 2, 423, 426, 7, 93, 2, 2, 424, 427, 5, 79, 40, 2, 425, 427, 5, 99, 50, 2, 426, 424, 3, 2, 2, 2, 426, 425, 3, 2, 2, 2, 427, 428, 3, 2, 2, 2, 428, 429, 7, 95, 2, 2, 429, 431, 3, 2, 2, 2, 430, 423, 3, 2, 2, 2, 431, 432, 3, 2, 2, 2, 432, 430, 3, 2, 2, 2, 432, 433, 3, 2, 2, 2, 433, 82, 3, 2, 2, 2, 434, 435, 7, 119, 2, 2, 435, 438, 7, 58, 2, 2, 436, 438, 9, 2, 2, 2, 437, 434, 3, 2, 2, 2, 437, 436, 3, 2, 2, 2, 438, 84, 3, 2, 2, 2, 439, 441, 5, 89, 45, 2, 440, 439, 3, 2, 2, 2, 441, 442, 3, 2, 2, 2, 442, 440, 3, 2, 2, 2, 442, 443, 3, 2, 2, 2, 443, 86, 3, 2, 2, 2, 444, 446, 5, 91, 46, 2, 445, 444, 3, 2, 2, 2, 446, 447, 3, 2, 2, 2, 447, 445, 3, 2, 2, 2, 447, 448, 3, 2, 2, 2, 448, 88, 3, 2, 2, 2, 449, 457, 10, 3, 2, 2, 450, 457, 5, 131, 66, 2, 451, 452, 7, 94, 2, 2, 452, 457, 7, 12, 2, 2, 453, 454, 7, 94, 2, 2, 454, 455, 7, 15, 2, 2, 455, 457, 7, 12, 2, 2, 456, 449, 3, 2, 2, 2, 456, 450, 3, 2, 2, 2, 456, 451, 3, 2, 2, 2, 456, 453, 3, 2, 2, 2, 457, 90, 3, 2, 2, 2, 458, 466, 10, 4, 2, 2, 459, 466, 5, 131, 66, 2, 460, 461, 7, 94, 2, 2, 461, 466, 7, 12, 2, 2, 462, 463, 7, 94, 2, 2, 463, 464, 7, 15, 2, 2, 464, 466, 7, 12, 2, 2, 465, 458, 3, 2, 2, 2, 465, 459, 3, 2, 2, 2, 465, 460, 3, 2, 2, 2, 465, 462, 3, 2, 2, 2, 466, 92, 3, 2, 2, 2, 467, 468, 9, 5, 2, 2, 468, 94, 3, 2, 2, 2, 469, 470, 9, 6, 2, 2, 470, 96, 3, 2, 2, 2, 471, 472, 7, 50, 2, 2, 472, 474, 9, 7, 2, 2, 473, 475, 9, 8, 2, 2, 474, 473, 3, 2, 2, 2, 475, 476, 3, 2, 2, 2, 476, 474, 3, 2, 2, 2, 476, 477, 3, 2, 2, 2, 477, 98, 3, 2, 2, 2, 478, 482, 5, 105, 53, 2, 479, 481, 5, 95, 48, 2, 480, 479, 3, 2, 2, 2, 481, 484, 3, 2, 2, 2, 482, 480, 3, 2, 2, 2, 482, 483, 3, 2, 2, 2, 483, 487, 3, 2, 2, 2, 484, 482, 3, 2, 2, 2, 485, 
487, 7, 50, 2, 2, 486, 478, 3, 2, 2, 2, 486, 485, 3, 2, 2, 2, 487, 100, 3, 2, 2, 2, 488, 492, 7, 50, 2, 2, 489, 491, 5, 107, 54, 2, 490, 489, 3, 2, 2, 2, 491, 494, 3, 2, 2, 2, 492, 490, 3, 2, 2, 2, 492, 493, 3, 2, 2, 2, 493, 102, 3, 2, 2, 2, 494, 492, 3, 2, 2, 2, 495, 496, 7, 50, 2, 2, 496, 497, 9, 9, 2, 2, 497, 498, 5, 127, 64, 2, 498, 104, 3, 2, 2, 2, 499, 500, 9, 10, 2, 2, 500, 106, 3, 2, 2, 2, 501, 502, 9, 11, 2, 2, 502, 108, 3, 2, 2, 2, 503, 504, 9, 12, 2, 2, 504, 110, 3, 2, 2, 2, 505, 506, 5, 109, 55, 2, 506, 507, 5, 109, 55, 2, 507, 508, 5, 109, 55, 2, 508, 509, 5, 109, 55, 2, 509, 112, 3, 2, 2, 2, 510, 511, 7, 94, 2, 2, 511, 512, 7, 119, 2, 2, 512, 513, 3, 2, 2, 2, 513, 521, 5, 111, 56, 2, 514, 515, 7, 94, 2, 2, 515, 516, 7, 87, 2, 2, 516, 517, 3, 2, 2, 2, 517, 518, 5, 111, 56, 2, 518, 519, 5, 111, 56, 2, 519, 521, 3, 2, 2, 2, 520, 510, 3, 2, 2, 2, 520, 514, 3, 2, 2, 2, 521, 114, 3, 2, 2, 2, 522, 524, 5, 119, 60, 2, 523, 525, 5, 121, 61, 2, 524, 523, 3, 2, 2, 2, 524, 525, 3, 2, 2, 2, 525, 530, 3, 2, 2, 2, 526, 527, 5, 123, 62, 2, 527, 528, 5, 121, 61, 2, 528, 530, 3, 2, 2, 2, 529, 522, 3, 2, 2, 2, 529, 526, 3, 2, 2, 2, 530, 116, 3, 2, 2, 2, 531, 532, 7, 50, 2, 2, 532, 535, 9, 9, 2, 2, 533, 536, 5, 125, 63, 2, 534, 536, 5, 127, 64, 2, 535, 533, 3, 2, 2, 2, 535, 534, 3, 2, 2, 2, 536, 537, 3, 2, 2, 2, 537, 538, 5, 129, 65, 2, 538, 118, 3, 2, 2, 2, 539, 541, 5, 123, 62, 2, 540, 539, 3, 2, 2, 2, 540, 541, 3, 2, 2, 2, 541, 542, 3, 2, 2, 2, 542, 543, 7, 48, 2, 2, 543, 548, 5, 123, 62, 2, 544, 545, 5, 123, 62, 2, 545, 546, 7, 48, 2, 2, 546, 548, 3, 2, 2, 2, 547, 540, 3, 2, 2, 2, 547, 544, 3, 2, 2, 2, 548, 120, 3, 2, 2, 2, 549, 551, 9, 13, 2, 2, 550, 552, 9, 14, 2, 2, 551, 550, 3, 2, 2, 2, 551, 552, 3, 2, 2, 2, 552, 553, 3, 2, 2, 2, 553, 554, 5, 123, 62, 2, 554, 122, 3, 2, 2, 2, 555, 557, 5, 95, 48, 2, 556, 555, 3, 2, 2, 2, 557, 558, 3, 2, 2, 2, 558, 556, 3, 2, 2, 2, 558, 559, 3, 2, 2, 2, 559, 124, 3, 2, 2, 2, 560, 562, 5, 127, 64, 2, 561, 560, 3, 2, 2, 2, 561, 562, 
3, 2, 2, 2, 562, 563, 3, 2, 2, 2, 563, 564, 7, 48, 2, 2, 564, 569, 5, 127, 64, 2, 565, 566, 5, 127, 64, 2, 566, 567, 7, 48, 2, 2, 567, 569, 3, 2, 2, 2, 568, 561, 3, 2, 2, 2, 568, 565, 3, 2, 2, 2, 569, 126, 3, 2, 2, 2, 570, 572, 5, 109, 55, 2, 571, 570, 3, 2, 2, 2, 572, 573, 3, 2, 2, 2, 573, 571, 3, 2, 2, 2, 573, 574, 3, 2, 2, 2, 574, 128, 3, 2, 2, 2, 575, 577, 9, 15, 2, 2, 576, 578, 9, 14, 2, 2, 577, 576, 3, 2, 2, 2, 577, 578, 3, 2, 2, 2, 578, 579, 3, 2, 2, 2, 579, 580, 5, 123, 62, 2, 580, 130, 3, 2, 2, 2, 581, 582, 7, 94, 2, 2, 582, 597, 9, 16, 2, 2, 583, 584, 7, 94, 2, 2, 584, 586, 5, 107, 54, 2, 585, 587, 5, 107, 54, 2, 586, 585, 3, 2, 2, 2, 586, 587, 3, 2, 2, 2, 587, 589, 3, 2, 2, 2, 588, 590, 5, 107, 54, 2, 589, 588, 3, 2, 2, 2, 589, 590, 3, 2, 2, 2, 590, 597, 3, 2, 2, 2, 591, 592, 7, 94, 2, 2, 592, 593, 7, 122, 2, 2, 593, 594, 3, 2, 2, 2, 594, 597, 5, 127, 64, 2, 595, 597, 5, 113, 57, 2, 596, 581, 3, 2, 2, 2, 596, 583, 3, 2, 2, 2, 596, 591, 3, 2, 2, 2, 596, 595, 3, 2, 2, 2, 597, 132, 3, 2, 2, 2, 598, 600, 9, 17, 2, 2, 599, 598, 3, 2, 2, 2, 600, 601, 3, 2, 2, 2, 601, 599, 3, 2, 2, 2, 601, 602, 3, 2, 2, 2, 602, 603, 3, 2, 2, 2, 603, 604, 8, 67, 2, 2, 604, 134, 3, 2, 2, 2, 605, 607, 7, 15, 2, 2, 606, 608, 7, 12, 2, 2, 607, 606, 3, 2, 2, 2, 607, 608, 3, 2, 2, 2, 608, 611, 3, 2, 2, 2, 609, 611, 7, 12, 2, 2, 610, 605, 3, 2, 2, 2, 610, 609, 3, 2, 2, 2, 611, 612, 3, 2, 2, 2, 612, 613, 8, 68, 2, 2, 613, 136, 3, 2, 2, 2, 52, 2, 171, 185, 217, 223, 231, 246, 248, 279, 315, 351, 380, 386, 390, 395, 397, 405, 408, 412, 417, 420, 426, 432, 437, 442, 447, 456, 465, 476, 482, 486, 492, 520, 524, 529, 535, 540, 547, 551, 558, 561, 568, 573, 577, 586, 589, 596, 601, 607, 610, 3, 8, 2, 2] \ No newline at end of file +[3, 24715, 42794, 33075, 47597, 16764, 15335, 30598, 22884, 2, 48, 754, 8, 1, 4, 2, 9, 2, 4, 3, 9, 3, 4, 4, 9, 4, 4, 5, 9, 5, 4, 6, 9, 6, 4, 7, 9, 7, 4, 8, 9, 8, 4, 9, 9, 9, 4, 10, 9, 10, 4, 11, 9, 11, 4, 12, 9, 12, 4, 13, 9, 13, 4, 14, 9, 14, 4, 15, 9, 15, 4, 16, 
9, 16, 4, 17, 9, 17, 4, 18, 9, 18, 4, 19, 9, 19, 4, 20, 9, 20, 4, 21, 9, 21, 4, 22, 9, 22, 4, 23, 9, 23, 4, 24, 9, 24, 4, 25, 9, 25, 4, 26, 9, 26, 4, 27, 9, 27, 4, 28, 9, 28, 4, 29, 9, 29, 4, 30, 9, 30, 4, 31, 9, 31, 4, 32, 9, 32, 4, 33, 9, 33, 4, 34, 9, 34, 4, 35, 9, 35, 4, 36, 9, 36, 4, 37, 9, 37, 4, 38, 9, 38, 4, 39, 9, 39, 4, 40, 9, 40, 4, 41, 9, 41, 4, 42, 9, 42, 4, 43, 9, 43, 4, 44, 9, 44, 4, 45, 9, 45, 4, 46, 9, 46, 4, 47, 9, 47, 4, 48, 9, 48, 4, 49, 9, 49, 4, 50, 9, 50, 4, 51, 9, 51, 4, 52, 9, 52, 4, 53, 9, 53, 4, 54, 9, 54, 4, 55, 9, 55, 4, 56, 9, 56, 4, 57, 9, 57, 4, 58, 9, 58, 4, 59, 9, 59, 4, 60, 9, 60, 4, 61, 9, 61, 4, 62, 9, 62, 4, 63, 9, 63, 4, 64, 9, 64, 4, 65, 9, 65, 4, 66, 9, 66, 4, 67, 9, 67, 4, 68, 9, 68, 4, 69, 9, 69, 4, 70, 9, 70, 4, 71, 9, 71, 4, 72, 9, 72, 3, 2, 3, 2, 3, 3, 3, 3, 3, 4, 3, 4, 3, 5, 3, 5, 3, 6, 3, 6, 3, 7, 3, 7, 3, 8, 3, 8, 3, 8, 3, 9, 3, 9, 3, 10, 3, 10, 3, 10, 3, 11, 3, 11, 3, 11, 3, 12, 3, 12, 3, 12, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 5, 13, 180, 10, 13, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 5, 14, 194, 10, 14, 3, 15, 3, 15, 3, 16, 3, 16, 3, 17, 3, 17, 3, 18, 3, 18, 3, 19, 3, 19, 3, 20, 3, 20, 3, 20, 3, 21, 3, 21, 3, 21, 3, 22, 3, 22, 3, 22, 3, 23, 3, 23, 3, 24, 3, 24, 3, 25, 3, 25, 3, 26, 3, 26, 3, 26, 3, 26, 3, 26, 5, 26, 226, 10, 26, 3, 27, 3, 27, 3, 27, 3, 27, 5, 27, 232, 10, 27, 3, 28, 3, 28, 3, 29, 3, 29, 3, 29, 3, 29, 5, 29, 240, 10, 29, 3, 30, 3, 30, 3, 30, 3, 31, 3, 31, 3, 31, 3, 31, 3, 31, 3, 31, 3, 31, 3, 32, 3, 32, 3, 32, 7, 32, 255, 10, 32, 12, 32, 14, 32, 258, 11, 32, 3, 32, 3, 32, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 5, 33, 288, 10, 33, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 
3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 5, 34, 324, 10, 34, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 5, 35, 360, 10, 35, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 5, 36, 390, 10, 36, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 5, 37, 428, 10, 37, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 5, 38, 466, 10, 38, 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, 5, 39, 492, 10, 39, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 5, 40, 521, 10, 40, 3, 41, 3, 41, 3, 41, 3, 41, 5, 41, 527, 10, 41, 3, 42, 3, 42, 5, 42, 531, 10, 42, 3, 43, 3, 43, 3, 43, 7, 43, 536, 10, 43, 12, 43, 14, 43, 539, 11, 43, 3, 43, 3, 43, 3, 43, 3, 43, 3, 43, 5, 43, 546, 10, 43, 3, 44, 5, 44, 549, 10, 44, 3, 44, 3, 44, 5, 44, 553, 10, 44, 3, 44, 3, 44, 3, 44, 5, 44, 558, 10, 44, 3, 44, 5, 44, 561, 10, 44, 3, 45, 3, 45, 3, 45, 3, 45, 5, 45, 567, 10, 45, 3, 45, 3, 45, 6, 45, 571, 10, 45, 13, 45, 14, 45, 572, 3, 46, 3, 46, 3, 46, 5, 46, 578, 10, 46, 3, 47, 6, 47, 
581, 10, 47, 13, 47, 14, 47, 582, 3, 48, 6, 48, 586, 10, 48, 13, 48, 14, 48, 587, 3, 49, 3, 49, 3, 49, 3, 49, 3, 49, 3, 49, 3, 49, 5, 49, 597, 10, 49, 3, 50, 3, 50, 3, 50, 3, 50, 3, 50, 3, 50, 3, 50, 5, 50, 606, 10, 50, 3, 51, 3, 51, 3, 52, 3, 52, 3, 53, 3, 53, 3, 53, 6, 53, 615, 10, 53, 13, 53, 14, 53, 616, 3, 54, 3, 54, 7, 54, 621, 10, 54, 12, 54, 14, 54, 624, 11, 54, 3, 54, 5, 54, 627, 10, 54, 3, 55, 3, 55, 7, 55, 631, 10, 55, 12, 55, 14, 55, 634, 11, 55, 3, 56, 3, 56, 3, 56, 3, 56, 3, 57, 3, 57, 3, 58, 3, 58, 3, 59, 3, 59, 3, 60, 3, 60, 3, 60, 3, 60, 3, 60, 3, 61, 3, 61, 3, 61, 3, 61, 3, 61, 3, 61, 3, 61, 3, 61, 3, 61, 3, 61, 5, 61, 661, 10, 61, 3, 62, 3, 62, 5, 62, 665, 10, 62, 3, 62, 3, 62, 3, 62, 5, 62, 670, 10, 62, 3, 63, 3, 63, 3, 63, 3, 63, 5, 63, 676, 10, 63, 3, 63, 3, 63, 3, 64, 5, 64, 681, 10, 64, 3, 64, 3, 64, 3, 64, 3, 64, 3, 64, 5, 64, 688, 10, 64, 3, 65, 3, 65, 5, 65, 692, 10, 65, 3, 65, 3, 65, 3, 66, 6, 66, 697, 10, 66, 13, 66, 14, 66, 698, 3, 67, 5, 67, 702, 10, 67, 3, 67, 3, 67, 3, 67, 3, 67, 3, 67, 5, 67, 709, 10, 67, 3, 68, 6, 68, 712, 10, 68, 13, 68, 14, 68, 713, 3, 69, 3, 69, 5, 69, 718, 10, 69, 3, 69, 3, 69, 3, 70, 3, 70, 3, 70, 3, 70, 3, 70, 5, 70, 727, 10, 70, 3, 70, 5, 70, 730, 10, 70, 3, 70, 3, 70, 3, 70, 3, 70, 3, 70, 5, 70, 737, 10, 70, 3, 71, 6, 71, 740, 10, 71, 13, 71, 14, 71, 741, 3, 71, 3, 71, 3, 72, 3, 72, 5, 72, 748, 10, 72, 3, 72, 5, 72, 751, 10, 72, 3, 72, 3, 72, 2, 2, 73, 3, 3, 5, 4, 7, 5, 9, 6, 11, 7, 13, 8, 15, 9, 17, 10, 19, 11, 21, 12, 23, 13, 25, 14, 27, 15, 29, 16, 31, 17, 33, 18, 35, 19, 37, 20, 39, 21, 41, 22, 43, 23, 45, 24, 47, 25, 49, 26, 51, 27, 53, 28, 55, 29, 57, 30, 59, 31, 61, 32, 63, 33, 65, 34, 67, 35, 69, 36, 71, 37, 73, 38, 75, 39, 77, 40, 79, 41, 81, 42, 83, 43, 85, 44, 87, 45, 89, 46, 91, 2, 93, 2, 95, 2, 97, 2, 99, 2, 101, 2, 103, 2, 105, 2, 107, 2, 109, 2, 111, 2, 113, 2, 115, 2, 117, 2, 119, 2, 121, 2, 123, 2, 125, 2, 127, 2, 129, 2, 131, 2, 133, 2, 135, 2, 137, 2, 139, 2, 141, 47, 143, 48, 3, 2, 18, 
5, 2, 78, 78, 87, 87, 119, 119, 6, 2, 12, 12, 15, 15, 36, 36, 94, 94, 6, 2, 12, 12, 15, 15, 41, 41, 94, 94, 5, 2, 67, 92, 97, 97, 99, 124, 3, 2, 50, 59, 4, 2, 68, 68, 100, 100, 3, 2, 50, 51, 4, 2, 90, 90, 122, 122, 3, 2, 51, 59, 3, 2, 50, 57, 5, 2, 50, 59, 67, 72, 99, 104, 4, 2, 71, 71, 103, 103, 4, 2, 45, 45, 47, 47, 4, 2, 82, 82, 114, 114, 12, 2, 36, 36, 41, 41, 65, 65, 94, 94, 99, 100, 104, 104, 112, 112, 116, 116, 118, 118, 120, 120, 4, 2, 11, 11, 34, 34, 2, 793, 2, 3, 3, 2, 2, 2, 2, 5, 3, 2, 2, 2, 2, 7, 3, 2, 2, 2, 2, 9, 3, 2, 2, 2, 2, 11, 3, 2, 2, 2, 2, 13, 3, 2, 2, 2, 2, 15, 3, 2, 2, 2, 2, 17, 3, 2, 2, 2, 2, 19, 3, 2, 2, 2, 2, 21, 3, 2, 2, 2, 2, 23, 3, 2, 2, 2, 2, 25, 3, 2, 2, 2, 2, 27, 3, 2, 2, 2, 2, 29, 3, 2, 2, 2, 2, 31, 3, 2, 2, 2, 2, 33, 3, 2, 2, 2, 2, 35, 3, 2, 2, 2, 2, 37, 3, 2, 2, 2, 2, 39, 3, 2, 2, 2, 2, 41, 3, 2, 2, 2, 2, 43, 3, 2, 2, 2, 2, 45, 3, 2, 2, 2, 2, 47, 3, 2, 2, 2, 2, 49, 3, 2, 2, 2, 2, 51, 3, 2, 2, 2, 2, 53, 3, 2, 2, 2, 2, 55, 3, 2, 2, 2, 2, 57, 3, 2, 2, 2, 2, 59, 3, 2, 2, 2, 2, 61, 3, 2, 2, 2, 2, 63, 3, 2, 2, 2, 2, 65, 3, 2, 2, 2, 2, 67, 3, 2, 2, 2, 2, 69, 3, 2, 2, 2, 2, 71, 3, 2, 2, 2, 2, 73, 3, 2, 2, 2, 2, 75, 3, 2, 2, 2, 2, 77, 3, 2, 2, 2, 2, 79, 3, 2, 2, 2, 2, 81, 3, 2, 2, 2, 2, 83, 3, 2, 2, 2, 2, 85, 3, 2, 2, 2, 2, 87, 3, 2, 2, 2, 2, 89, 3, 2, 2, 2, 2, 141, 3, 2, 2, 2, 2, 143, 3, 2, 2, 2, 3, 145, 3, 2, 2, 2, 5, 147, 3, 2, 2, 2, 7, 149, 3, 2, 2, 2, 9, 151, 3, 2, 2, 2, 11, 153, 3, 2, 2, 2, 13, 155, 3, 2, 2, 2, 15, 157, 3, 2, 2, 2, 17, 160, 3, 2, 2, 2, 19, 162, 3, 2, 2, 2, 21, 165, 3, 2, 2, 2, 23, 168, 3, 2, 2, 2, 25, 179, 3, 2, 2, 2, 27, 193, 3, 2, 2, 2, 29, 195, 3, 2, 2, 2, 31, 197, 3, 2, 2, 2, 33, 199, 3, 2, 2, 2, 35, 201, 3, 2, 2, 2, 37, 203, 3, 2, 2, 2, 39, 205, 3, 2, 2, 2, 41, 208, 3, 2, 2, 2, 43, 211, 3, 2, 2, 2, 45, 214, 3, 2, 2, 2, 47, 216, 3, 2, 2, 2, 49, 218, 3, 2, 2, 2, 51, 225, 3, 2, 2, 2, 53, 231, 3, 2, 2, 2, 55, 233, 3, 2, 2, 2, 57, 239, 3, 2, 2, 2, 59, 241, 3, 2, 2, 2, 61, 244, 3, 2, 2, 2, 63, 251, 3, 2, 2, 2, 65, 287, 
3, 2, 2, 2, 67, 323, 3, 2, 2, 2, 69, 359, 3, 2, 2, 2, 71, 389, 3, 2, 2, 2, 73, 427, 3, 2, 2, 2, 75, 465, 3, 2, 2, 2, 77, 491, 3, 2, 2, 2, 79, 520, 3, 2, 2, 2, 81, 526, 3, 2, 2, 2, 83, 530, 3, 2, 2, 2, 85, 545, 3, 2, 2, 2, 87, 548, 3, 2, 2, 2, 89, 562, 3, 2, 2, 2, 91, 577, 3, 2, 2, 2, 93, 580, 3, 2, 2, 2, 95, 585, 3, 2, 2, 2, 97, 596, 3, 2, 2, 2, 99, 605, 3, 2, 2, 2, 101, 607, 3, 2, 2, 2, 103, 609, 3, 2, 2, 2, 105, 611, 3, 2, 2, 2, 107, 626, 3, 2, 2, 2, 109, 628, 3, 2, 2, 2, 111, 635, 3, 2, 2, 2, 113, 639, 3, 2, 2, 2, 115, 641, 3, 2, 2, 2, 117, 643, 3, 2, 2, 2, 119, 645, 3, 2, 2, 2, 121, 660, 3, 2, 2, 2, 123, 669, 3, 2, 2, 2, 125, 671, 3, 2, 2, 2, 127, 687, 3, 2, 2, 2, 129, 689, 3, 2, 2, 2, 131, 696, 3, 2, 2, 2, 133, 708, 3, 2, 2, 2, 135, 711, 3, 2, 2, 2, 137, 715, 3, 2, 2, 2, 139, 736, 3, 2, 2, 2, 141, 739, 3, 2, 2, 2, 143, 750, 3, 2, 2, 2, 145, 146, 7, 42, 2, 2, 146, 4, 3, 2, 2, 2, 147, 148, 7, 43, 2, 2, 148, 6, 3, 2, 2, 2, 149, 150, 7, 93, 2, 2, 150, 8, 3, 2, 2, 2, 151, 152, 7, 46, 2, 2, 152, 10, 3, 2, 2, 2, 153, 154, 7, 95, 2, 2, 154, 12, 3, 2, 2, 2, 155, 156, 7, 62, 2, 2, 156, 14, 3, 2, 2, 2, 157, 158, 7, 62, 2, 2, 158, 159, 7, 63, 2, 2, 159, 16, 3, 2, 2, 2, 160, 161, 7, 64, 2, 2, 161, 18, 3, 2, 2, 2, 162, 163, 7, 64, 2, 2, 163, 164, 7, 63, 2, 2, 164, 20, 3, 2, 2, 2, 165, 166, 7, 63, 2, 2, 166, 167, 7, 63, 2, 2, 167, 22, 3, 2, 2, 2, 168, 169, 7, 35, 2, 2, 169, 170, 7, 63, 2, 2, 170, 24, 3, 2, 2, 2, 171, 172, 7, 110, 2, 2, 172, 173, 7, 107, 2, 2, 173, 174, 7, 109, 2, 2, 174, 180, 7, 103, 2, 2, 175, 176, 7, 78, 2, 2, 176, 177, 7, 75, 2, 2, 177, 178, 7, 77, 2, 2, 178, 180, 7, 71, 2, 2, 179, 171, 3, 2, 2, 2, 179, 175, 3, 2, 2, 2, 180, 26, 3, 2, 2, 2, 181, 182, 7, 103, 2, 2, 182, 183, 7, 122, 2, 2, 183, 184, 7, 107, 2, 2, 184, 185, 7, 117, 2, 2, 185, 186, 7, 118, 2, 2, 186, 194, 7, 117, 2, 2, 187, 188, 7, 71, 2, 2, 188, 189, 7, 90, 2, 2, 189, 190, 7, 75, 2, 2, 190, 191, 7, 85, 2, 2, 191, 192, 7, 86, 2, 2, 192, 194, 7, 85, 2, 2, 193, 181, 3, 2, 2, 2, 193, 187, 3, 2, 
2, 2, 194, 28, 3, 2, 2, 2, 195, 196, 7, 45, 2, 2, 196, 30, 3, 2, 2, 2, 197, 198, 7, 47, 2, 2, 198, 32, 3, 2, 2, 2, 199, 200, 7, 44, 2, 2, 200, 34, 3, 2, 2, 2, 201, 202, 7, 49, 2, 2, 202, 36, 3, 2, 2, 2, 203, 204, 7, 39, 2, 2, 204, 38, 3, 2, 2, 2, 205, 206, 7, 44, 2, 2, 206, 207, 7, 44, 2, 2, 207, 40, 3, 2, 2, 2, 208, 209, 7, 62, 2, 2, 209, 210, 7, 62, 2, 2, 210, 42, 3, 2, 2, 2, 211, 212, 7, 64, 2, 2, 212, 213, 7, 64, 2, 2, 213, 44, 3, 2, 2, 2, 214, 215, 7, 40, 2, 2, 215, 46, 3, 2, 2, 2, 216, 217, 7, 126, 2, 2, 217, 48, 3, 2, 2, 2, 218, 219, 7, 96, 2, 2, 219, 50, 3, 2, 2, 2, 220, 221, 7, 40, 2, 2, 221, 226, 7, 40, 2, 2, 222, 223, 7, 99, 2, 2, 223, 224, 7, 112, 2, 2, 224, 226, 7, 102, 2, 2, 225, 220, 3, 2, 2, 2, 225, 222, 3, 2, 2, 2, 226, 52, 3, 2, 2, 2, 227, 228, 7, 126, 2, 2, 228, 232, 7, 126, 2, 2, 229, 230, 7, 113, 2, 2, 230, 232, 7, 116, 2, 2, 231, 227, 3, 2, 2, 2, 231, 229, 3, 2, 2, 2, 232, 54, 3, 2, 2, 2, 233, 234, 7, 128, 2, 2, 234, 56, 3, 2, 2, 2, 235, 240, 7, 35, 2, 2, 236, 237, 7, 112, 2, 2, 237, 238, 7, 113, 2, 2, 238, 240, 7, 118, 2, 2, 239, 235, 3, 2, 2, 2, 239, 236, 3, 2, 2, 2, 240, 58, 3, 2, 2, 2, 241, 242, 7, 107, 2, 2, 242, 243, 7, 112, 2, 2, 243, 60, 3, 2, 2, 2, 244, 245, 7, 112, 2, 2, 245, 246, 7, 113, 2, 2, 246, 247, 7, 118, 2, 2, 247, 248, 7, 34, 2, 2, 248, 249, 7, 107, 2, 2, 249, 250, 7, 112, 2, 2, 250, 62, 3, 2, 2, 2, 251, 256, 7, 93, 2, 2, 252, 255, 5, 141, 71, 2, 253, 255, 5, 143, 72, 2, 254, 252, 3, 2, 2, 2, 254, 253, 3, 2, 2, 2, 255, 258, 3, 2, 2, 2, 256, 254, 3, 2, 2, 2, 256, 257, 3, 2, 2, 2, 257, 259, 3, 2, 2, 2, 258, 256, 3, 2, 2, 2, 259, 260, 7, 95, 2, 2, 260, 64, 3, 2, 2, 2, 261, 262, 7, 108, 2, 2, 262, 263, 7, 117, 2, 2, 263, 264, 7, 113, 2, 2, 264, 265, 7, 112, 2, 2, 265, 266, 7, 97, 2, 2, 266, 267, 7, 101, 2, 2, 267, 268, 7, 113, 2, 2, 268, 269, 7, 112, 2, 2, 269, 270, 7, 118, 2, 2, 270, 271, 7, 99, 2, 2, 271, 272, 7, 107, 2, 2, 272, 273, 7, 112, 2, 2, 273, 288, 7, 117, 2, 2, 274, 275, 7, 76, 2, 2, 275, 276, 7, 85, 2, 2, 276, 277, 
7, 81, 2, 2, 277, 278, 7, 80, 2, 2, 278, 279, 7, 97, 2, 2, 279, 280, 7, 69, 2, 2, 280, 281, 7, 81, 2, 2, 281, 282, 7, 80, 2, 2, 282, 283, 7, 86, 2, 2, 283, 284, 7, 67, 2, 2, 284, 285, 7, 75, 2, 2, 285, 286, 7, 80, 2, 2, 286, 288, 7, 85, 2, 2, 287, 261, 3, 2, 2, 2, 287, 274, 3, 2, 2, 2, 288, 66, 3, 2, 2, 2, 289, 290, 7, 108, 2, 2, 290, 291, 7, 117, 2, 2, 291, 292, 7, 113, 2, 2, 292, 293, 7, 112, 2, 2, 293, 294, 7, 97, 2, 2, 294, 295, 7, 101, 2, 2, 295, 296, 7, 113, 2, 2, 296, 297, 7, 112, 2, 2, 297, 298, 7, 118, 2, 2, 298, 299, 7, 99, 2, 2, 299, 300, 7, 107, 2, 2, 300, 301, 7, 112, 2, 2, 301, 302, 7, 117, 2, 2, 302, 303, 7, 97, 2, 2, 303, 304, 7, 99, 2, 2, 304, 305, 7, 110, 2, 2, 305, 324, 7, 110, 2, 2, 306, 307, 7, 76, 2, 2, 307, 308, 7, 85, 2, 2, 308, 309, 7, 81, 2, 2, 309, 310, 7, 80, 2, 2, 310, 311, 7, 97, 2, 2, 311, 312, 7, 69, 2, 2, 312, 313, 7, 81, 2, 2, 313, 314, 7, 80, 2, 2, 314, 315, 7, 86, 2, 2, 315, 316, 7, 67, 2, 2, 316, 317, 7, 75, 2, 2, 317, 318, 7, 80, 2, 2, 318, 319, 7, 85, 2, 2, 319, 320, 7, 97, 2, 2, 320, 321, 7, 67, 2, 2, 321, 322, 7, 78, 2, 2, 322, 324, 7, 78, 2, 2, 323, 289, 3, 2, 2, 2, 323, 306, 3, 2, 2, 2, 324, 68, 3, 2, 2, 2, 325, 326, 7, 108, 2, 2, 326, 327, 7, 117, 2, 2, 327, 328, 7, 113, 2, 2, 328, 329, 7, 112, 2, 2, 329, 330, 7, 97, 2, 2, 330, 331, 7, 101, 2, 2, 331, 332, 7, 113, 2, 2, 332, 333, 7, 112, 2, 2, 333, 334, 7, 118, 2, 2, 334, 335, 7, 99, 2, 2, 335, 336, 7, 107, 2, 2, 336, 337, 7, 112, 2, 2, 337, 338, 7, 117, 2, 2, 338, 339, 7, 97, 2, 2, 339, 340, 7, 99, 2, 2, 340, 341, 7, 112, 2, 2, 341, 360, 7, 123, 2, 2, 342, 343, 7, 76, 2, 2, 343, 344, 7, 85, 2, 2, 344, 345, 7, 81, 2, 2, 345, 346, 7, 80, 2, 2, 346, 347, 7, 97, 2, 2, 347, 348, 7, 69, 2, 2, 348, 349, 7, 81, 2, 2, 349, 350, 7, 80, 2, 2, 350, 351, 7, 86, 2, 2, 351, 352, 7, 67, 2, 2, 352, 353, 7, 75, 2, 2, 353, 354, 7, 80, 2, 2, 354, 355, 7, 85, 2, 2, 355, 356, 7, 97, 2, 2, 356, 357, 7, 67, 2, 2, 357, 358, 7, 80, 2, 2, 358, 360, 7, 91, 2, 2, 359, 325, 3, 2, 2, 2, 359, 342, 3, 
2, 2, 2, 360, 70, 3, 2, 2, 2, 361, 362, 7, 99, 2, 2, 362, 363, 7, 116, 2, 2, 363, 364, 7, 116, 2, 2, 364, 365, 7, 99, 2, 2, 365, 366, 7, 123, 2, 2, 366, 367, 7, 97, 2, 2, 367, 368, 7, 101, 2, 2, 368, 369, 7, 113, 2, 2, 369, 370, 7, 112, 2, 2, 370, 371, 7, 118, 2, 2, 371, 372, 7, 99, 2, 2, 372, 373, 7, 107, 2, 2, 373, 374, 7, 112, 2, 2, 374, 390, 7, 117, 2, 2, 375, 376, 7, 67, 2, 2, 376, 377, 7, 84, 2, 2, 377, 378, 7, 84, 2, 2, 378, 379, 7, 67, 2, 2, 379, 380, 7, 91, 2, 2, 380, 381, 7, 97, 2, 2, 381, 382, 7, 69, 2, 2, 382, 383, 7, 81, 2, 2, 383, 384, 7, 80, 2, 2, 384, 385, 7, 86, 2, 2, 385, 386, 7, 67, 2, 2, 386, 387, 7, 75, 2, 2, 387, 388, 7, 80, 2, 2, 388, 390, 7, 85, 2, 2, 389, 361, 3, 2, 2, 2, 389, 375, 3, 2, 2, 2, 390, 72, 3, 2, 2, 2, 391, 392, 7, 99, 2, 2, 392, 393, 7, 116, 2, 2, 393, 394, 7, 116, 2, 2, 394, 395, 7, 99, 2, 2, 395, 396, 7, 123, 2, 2, 396, 397, 7, 97, 2, 2, 397, 398, 7, 101, 2, 2, 398, 399, 7, 113, 2, 2, 399, 400, 7, 112, 2, 2, 400, 401, 7, 118, 2, 2, 401, 402, 7, 99, 2, 2, 402, 403, 7, 107, 2, 2, 403, 404, 7, 112, 2, 2, 404, 405, 7, 117, 2, 2, 405, 406, 7, 97, 2, 2, 406, 407, 7, 99, 2, 2, 407, 408, 7, 110, 2, 2, 408, 428, 7, 110, 2, 2, 409, 410, 7, 67, 2, 2, 410, 411, 7, 84, 2, 2, 411, 412, 7, 84, 2, 2, 412, 413, 7, 67, 2, 2, 413, 414, 7, 91, 2, 2, 414, 415, 7, 97, 2, 2, 415, 416, 7, 69, 2, 2, 416, 417, 7, 81, 2, 2, 417, 418, 7, 80, 2, 2, 418, 419, 7, 86, 2, 2, 419, 420, 7, 67, 2, 2, 420, 421, 7, 75, 2, 2, 421, 422, 7, 80, 2, 2, 422, 423, 7, 85, 2, 2, 423, 424, 7, 97, 2, 2, 424, 425, 7, 67, 2, 2, 425, 426, 7, 78, 2, 2, 426, 428, 7, 78, 2, 2, 427, 391, 3, 2, 2, 2, 427, 409, 3, 2, 2, 2, 428, 74, 3, 2, 2, 2, 429, 430, 7, 99, 2, 2, 430, 431, 7, 116, 2, 2, 431, 432, 7, 116, 2, 2, 432, 433, 7, 99, 2, 2, 433, 434, 7, 123, 2, 2, 434, 435, 7, 97, 2, 2, 435, 436, 7, 101, 2, 2, 436, 437, 7, 113, 2, 2, 437, 438, 7, 112, 2, 2, 438, 439, 7, 118, 2, 2, 439, 440, 7, 99, 2, 2, 440, 441, 7, 107, 2, 2, 441, 442, 7, 112, 2, 2, 442, 443, 7, 117, 2, 2, 443, 444, 7, 
97, 2, 2, 444, 445, 7, 99, 2, 2, 445, 446, 7, 112, 2, 2, 446, 466, 7, 123, 2, 2, 447, 448, 7, 67, 2, 2, 448, 449, 7, 84, 2, 2, 449, 450, 7, 84, 2, 2, 450, 451, 7, 67, 2, 2, 451, 452, 7, 91, 2, 2, 452, 453, 7, 97, 2, 2, 453, 454, 7, 69, 2, 2, 454, 455, 7, 81, 2, 2, 455, 456, 7, 80, 2, 2, 456, 457, 7, 86, 2, 2, 457, 458, 7, 67, 2, 2, 458, 459, 7, 75, 2, 2, 459, 460, 7, 80, 2, 2, 460, 461, 7, 85, 2, 2, 461, 462, 7, 97, 2, 2, 462, 463, 7, 67, 2, 2, 463, 464, 7, 80, 2, 2, 464, 466, 7, 91, 2, 2, 465, 429, 3, 2, 2, 2, 465, 447, 3, 2, 2, 2, 466, 76, 3, 2, 2, 2, 467, 468, 7, 99, 2, 2, 468, 469, 7, 116, 2, 2, 469, 470, 7, 116, 2, 2, 470, 471, 7, 99, 2, 2, 471, 472, 7, 123, 2, 2, 472, 473, 7, 97, 2, 2, 473, 474, 7, 110, 2, 2, 474, 475, 7, 103, 2, 2, 475, 476, 7, 112, 2, 2, 476, 477, 7, 105, 2, 2, 477, 478, 7, 118, 2, 2, 478, 492, 7, 106, 2, 2, 479, 480, 7, 67, 2, 2, 480, 481, 7, 84, 2, 2, 481, 482, 7, 84, 2, 2, 482, 483, 7, 67, 2, 2, 483, 484, 7, 91, 2, 2, 484, 485, 7, 97, 2, 2, 485, 486, 7, 78, 2, 2, 486, 487, 7, 71, 2, 2, 487, 488, 7, 80, 2, 2, 488, 489, 7, 73, 2, 2, 489, 490, 7, 86, 2, 2, 490, 492, 7, 74, 2, 2, 491, 467, 3, 2, 2, 2, 491, 479, 3, 2, 2, 2, 492, 78, 3, 2, 2, 2, 493, 494, 7, 118, 2, 2, 494, 495, 7, 116, 2, 2, 495, 496, 7, 119, 2, 2, 496, 521, 7, 103, 2, 2, 497, 498, 7, 86, 2, 2, 498, 499, 7, 116, 2, 2, 499, 500, 7, 119, 2, 2, 500, 521, 7, 103, 2, 2, 501, 502, 7, 86, 2, 2, 502, 503, 7, 84, 2, 2, 503, 504, 7, 87, 2, 2, 504, 521, 7, 71, 2, 2, 505, 506, 7, 104, 2, 2, 506, 507, 7, 99, 2, 2, 507, 508, 7, 110, 2, 2, 508, 509, 7, 117, 2, 2, 509, 521, 7, 103, 2, 2, 510, 511, 7, 72, 2, 2, 511, 512, 7, 99, 2, 2, 512, 513, 7, 110, 2, 2, 513, 514, 7, 117, 2, 2, 514, 521, 7, 103, 2, 2, 515, 516, 7, 72, 2, 2, 516, 517, 7, 67, 2, 2, 517, 518, 7, 78, 2, 2, 518, 519, 7, 85, 2, 2, 519, 521, 7, 71, 2, 2, 520, 493, 3, 2, 2, 2, 520, 497, 3, 2, 2, 2, 520, 501, 3, 2, 2, 2, 520, 505, 3, 2, 2, 2, 520, 510, 3, 2, 2, 2, 520, 515, 3, 2, 2, 2, 521, 80, 3, 2, 2, 2, 522, 527, 5, 107, 54, 2, 
523, 527, 5, 109, 55, 2, 524, 527, 5, 111, 56, 2, 525, 527, 5, 105, 53, 2, 526, 522, 3, 2, 2, 2, 526, 523, 3, 2, 2, 2, 526, 524, 3, 2, 2, 2, 526, 525, 3, 2, 2, 2, 527, 82, 3, 2, 2, 2, 528, 531, 5, 123, 62, 2, 529, 531, 5, 125, 63, 2, 530, 528, 3, 2, 2, 2, 530, 529, 3, 2, 2, 2, 531, 84, 3, 2, 2, 2, 532, 537, 5, 101, 51, 2, 533, 536, 5, 101, 51, 2, 534, 536, 5, 103, 52, 2, 535, 533, 3, 2, 2, 2, 535, 534, 3, 2, 2, 2, 536, 539, 3, 2, 2, 2, 537, 535, 3, 2, 2, 2, 537, 538, 3, 2, 2, 2, 538, 546, 3, 2, 2, 2, 539, 537, 3, 2, 2, 2, 540, 541, 7, 38, 2, 2, 541, 542, 7, 111, 2, 2, 542, 543, 7, 103, 2, 2, 543, 544, 7, 118, 2, 2, 544, 546, 7, 99, 2, 2, 545, 532, 3, 2, 2, 2, 545, 540, 3, 2, 2, 2, 546, 86, 3, 2, 2, 2, 547, 549, 5, 91, 46, 2, 548, 547, 3, 2, 2, 2, 548, 549, 3, 2, 2, 2, 549, 560, 3, 2, 2, 2, 550, 552, 7, 36, 2, 2, 551, 553, 5, 93, 47, 2, 552, 551, 3, 2, 2, 2, 552, 553, 3, 2, 2, 2, 553, 554, 3, 2, 2, 2, 554, 561, 7, 36, 2, 2, 555, 557, 7, 41, 2, 2, 556, 558, 5, 95, 48, 2, 557, 556, 3, 2, 2, 2, 557, 558, 3, 2, 2, 2, 558, 559, 3, 2, 2, 2, 559, 561, 7, 41, 2, 2, 560, 550, 3, 2, 2, 2, 560, 555, 3, 2, 2, 2, 561, 88, 3, 2, 2, 2, 562, 570, 5, 85, 43, 2, 563, 566, 7, 93, 2, 2, 564, 567, 5, 87, 44, 2, 565, 567, 5, 107, 54, 2, 566, 564, 3, 2, 2, 2, 566, 565, 3, 2, 2, 2, 567, 568, 3, 2, 2, 2, 568, 569, 7, 95, 2, 2, 569, 571, 3, 2, 2, 2, 570, 563, 3, 2, 2, 2, 571, 572, 3, 2, 2, 2, 572, 570, 3, 2, 2, 2, 572, 573, 3, 2, 2, 2, 573, 90, 3, 2, 2, 2, 574, 575, 7, 119, 2, 2, 575, 578, 7, 58, 2, 2, 576, 578, 9, 2, 2, 2, 577, 574, 3, 2, 2, 2, 577, 576, 3, 2, 2, 2, 578, 92, 3, 2, 2, 2, 579, 581, 5, 97, 49, 2, 580, 579, 3, 2, 2, 2, 581, 582, 3, 2, 2, 2, 582, 580, 3, 2, 2, 2, 582, 583, 3, 2, 2, 2, 583, 94, 3, 2, 2, 2, 584, 586, 5, 99, 50, 2, 585, 584, 3, 2, 2, 2, 586, 587, 3, 2, 2, 2, 587, 585, 3, 2, 2, 2, 587, 588, 3, 2, 2, 2, 588, 96, 3, 2, 2, 2, 589, 597, 10, 3, 2, 2, 590, 597, 5, 139, 70, 2, 591, 592, 7, 94, 2, 2, 592, 597, 7, 12, 2, 2, 593, 594, 7, 94, 2, 2, 594, 595, 7, 15, 2, 2, 595, 
597, 7, 12, 2, 2, 596, 589, 3, 2, 2, 2, 596, 590, 3, 2, 2, 2, 596, 591, 3, 2, 2, 2, 596, 593, 3, 2, 2, 2, 597, 98, 3, 2, 2, 2, 598, 606, 10, 4, 2, 2, 599, 606, 5, 139, 70, 2, 600, 601, 7, 94, 2, 2, 601, 606, 7, 12, 2, 2, 602, 603, 7, 94, 2, 2, 603, 604, 7, 15, 2, 2, 604, 606, 7, 12, 2, 2, 605, 598, 3, 2, 2, 2, 605, 599, 3, 2, 2, 2, 605, 600, 3, 2, 2, 2, 605, 602, 3, 2, 2, 2, 606, 100, 3, 2, 2, 2, 607, 608, 9, 5, 2, 2, 608, 102, 3, 2, 2, 2, 609, 610, 9, 6, 2, 2, 610, 104, 3, 2, 2, 2, 611, 612, 7, 50, 2, 2, 612, 614, 9, 7, 2, 2, 613, 615, 9, 8, 2, 2, 614, 613, 3, 2, 2, 2, 615, 616, 3, 2, 2, 2, 616, 614, 3, 2, 2, 2, 616, 617, 3, 2, 2, 2, 617, 106, 3, 2, 2, 2, 618, 622, 5, 113, 57, 2, 619, 621, 5, 103, 52, 2, 620, 619, 3, 2, 2, 2, 621, 624, 3, 2, 2, 2, 622, 620, 3, 2, 2, 2, 622, 623, 3, 2, 2, 2, 623, 627, 3, 2, 2, 2, 624, 622, 3, 2, 2, 2, 625, 627, 7, 50, 2, 2, 626, 618, 3, 2, 2, 2, 626, 625, 3, 2, 2, 2, 627, 108, 3, 2, 2, 2, 628, 632, 7, 50, 2, 2, 629, 631, 5, 115, 58, 2, 630, 629, 3, 2, 2, 2, 631, 634, 3, 2, 2, 2, 632, 630, 3, 2, 2, 2, 632, 633, 3, 2, 2, 2, 633, 110, 3, 2, 2, 2, 634, 632, 3, 2, 2, 2, 635, 636, 7, 50, 2, 2, 636, 637, 9, 9, 2, 2, 637, 638, 5, 135, 68, 2, 638, 112, 3, 2, 2, 2, 639, 640, 9, 10, 2, 2, 640, 114, 3, 2, 2, 2, 641, 642, 9, 11, 2, 2, 642, 116, 3, 2, 2, 2, 643, 644, 9, 12, 2, 2, 644, 118, 3, 2, 2, 2, 645, 646, 5, 117, 59, 2, 646, 647, 5, 117, 59, 2, 647, 648, 5, 117, 59, 2, 648, 649, 5, 117, 59, 2, 649, 120, 3, 2, 2, 2, 650, 651, 7, 94, 2, 2, 651, 652, 7, 119, 2, 2, 652, 653, 3, 2, 2, 2, 653, 661, 5, 119, 60, 2, 654, 655, 7, 94, 2, 2, 655, 656, 7, 87, 2, 2, 656, 657, 3, 2, 2, 2, 657, 658, 5, 119, 60, 2, 658, 659, 5, 119, 60, 2, 659, 661, 3, 2, 2, 2, 660, 650, 3, 2, 2, 2, 660, 654, 3, 2, 2, 2, 661, 122, 3, 2, 2, 2, 662, 664, 5, 127, 64, 2, 663, 665, 5, 129, 65, 2, 664, 663, 3, 2, 2, 2, 664, 665, 3, 2, 2, 2, 665, 670, 3, 2, 2, 2, 666, 667, 5, 131, 66, 2, 667, 668, 5, 129, 65, 2, 668, 670, 3, 2, 2, 2, 669, 662, 3, 2, 2, 2, 669, 666, 3, 2, 2, 2, 
670, 124, 3, 2, 2, 2, 671, 672, 7, 50, 2, 2, 672, 675, 9, 9, 2, 2, 673, 676, 5, 133, 67, 2, 674, 676, 5, 135, 68, 2, 675, 673, 3, 2, 2, 2, 675, 674, 3, 2, 2, 2, 676, 677, 3, 2, 2, 2, 677, 678, 5, 137, 69, 2, 678, 126, 3, 2, 2, 2, 679, 681, 5, 131, 66, 2, 680, 679, 3, 2, 2, 2, 680, 681, 3, 2, 2, 2, 681, 682, 3, 2, 2, 2, 682, 683, 7, 48, 2, 2, 683, 688, 5, 131, 66, 2, 684, 685, 5, 131, 66, 2, 685, 686, 7, 48, 2, 2, 686, 688, 3, 2, 2, 2, 687, 680, 3, 2, 2, 2, 687, 684, 3, 2, 2, 2, 688, 128, 3, 2, 2, 2, 689, 691, 9, 13, 2, 2, 690, 692, 9, 14, 2, 2, 691, 690, 3, 2, 2, 2, 691, 692, 3, 2, 2, 2, 692, 693, 3, 2, 2, 2, 693, 694, 5, 131, 66, 2, 694, 130, 3, 2, 2, 2, 695, 697, 5, 103, 52, 2, 696, 695, 3, 2, 2, 2, 697, 698, 3, 2, 2, 2, 698, 696, 3, 2, 2, 2, 698, 699, 3, 2, 2, 2, 699, 132, 3, 2, 2, 2, 700, 702, 5, 135, 68, 2, 701, 700, 3, 2, 2, 2, 701, 702, 3, 2, 2, 2, 702, 703, 3, 2, 2, 2, 703, 704, 7, 48, 2, 2, 704, 709, 5, 135, 68, 2, 705, 706, 5, 135, 68, 2, 706, 707, 7, 48, 2, 2, 707, 709, 3, 2, 2, 2, 708, 701, 3, 2, 2, 2, 708, 705, 3, 2, 2, 2, 709, 134, 3, 2, 2, 2, 710, 712, 5, 117, 59, 2, 711, 710, 3, 2, 2, 2, 712, 713, 3, 2, 2, 2, 713, 711, 3, 2, 2, 2, 713, 714, 3, 2, 2, 2, 714, 136, 3, 2, 2, 2, 715, 717, 9, 15, 2, 2, 716, 718, 9, 14, 2, 2, 717, 716, 3, 2, 2, 2, 717, 718, 3, 2, 2, 2, 718, 719, 3, 2, 2, 2, 719, 720, 5, 131, 66, 2, 720, 138, 3, 2, 2, 2, 721, 722, 7, 94, 2, 2, 722, 737, 9, 16, 2, 2, 723, 724, 7, 94, 2, 2, 724, 726, 5, 115, 58, 2, 725, 727, 5, 115, 58, 2, 726, 725, 3, 2, 2, 2, 726, 727, 3, 2, 2, 2, 727, 729, 3, 2, 2, 2, 728, 730, 5, 115, 58, 2, 729, 728, 3, 2, 2, 2, 729, 730, 3, 2, 2, 2, 730, 737, 3, 2, 2, 2, 731, 732, 7, 94, 2, 2, 732, 733, 7, 122, 2, 2, 733, 734, 3, 2, 2, 2, 734, 737, 5, 135, 68, 2, 735, 737, 5, 121, 61, 2, 736, 721, 3, 2, 2, 2, 736, 723, 3, 2, 2, 2, 736, 731, 3, 2, 2, 2, 736, 735, 3, 2, 2, 2, 737, 140, 3, 2, 2, 2, 738, 740, 9, 17, 2, 2, 739, 738, 3, 2, 2, 2, 740, 741, 3, 2, 2, 2, 741, 739, 3, 2, 2, 2, 741, 742, 3, 2, 2, 2, 742, 743, 3, 2, 
2, 2, 743, 744, 8, 71, 2, 2, 744, 142, 3, 2, 2, 2, 745, 747, 7, 15, 2, 2, 746, 748, 7, 12, 2, 2, 747, 746, 3, 2, 2, 2, 747, 748, 3, 2, 2, 2, 748, 751, 3, 2, 2, 2, 749, 751, 7, 12, 2, 2, 750, 745, 3, 2, 2, 2, 750, 749, 3, 2, 2, 2, 751, 752, 3, 2, 2, 2, 752, 753, 8, 72, 2, 2, 753, 144, 3, 2, 2, 2, 56, 2, 179, 193, 225, 231, 239, 254, 256, 287, 323, 359, 389, 427, 465, 491, 520, 526, 530, 535, 537, 545, 548, 552, 557, 560, 566, 572, 577, 582, 587, 596, 605, 616, 622, 626, 632, 660, 664, 669, 675, 680, 687, 691, 698, 701, 708, 713, 717, 726, 729, 736, 741, 747, 750, 3, 8, 2, 2] \ No newline at end of file diff --git a/internal/parser/planparserv2/generated/PlanLexer.tokens b/internal/parser/planparserv2/generated/PlanLexer.tokens index ca7b53db2ebfa..e808c9b6391b3 100644 --- a/internal/parser/planparserv2/generated/PlanLexer.tokens +++ b/internal/parser/planparserv2/generated/PlanLexer.tokens @@ -32,14 +32,18 @@ EmptyTerm=31 JSONContains=32 JSONContainsAll=33 JSONContainsAny=34 -BooleanConstant=35 -IntegerConstant=36 -FloatingConstant=37 -Identifier=38 -StringLiteral=39 -JSONIdentifier=40 -Whitespace=41 -Newline=42 +ArrayContains=35 +ArrayContainsAll=36 +ArrayContainsAny=37 +ArrayLength=38 +BooleanConstant=39 +IntegerConstant=40 +FloatingConstant=41 +Identifier=42 +StringLiteral=43 +JSONIdentifier=44 +Whitespace=45 +Newline=46 '('=1 ')'=2 '['=3 diff --git a/internal/parser/planparserv2/generated/plan_base_visitor.go b/internal/parser/planparserv2/generated/plan_base_visitor.go index 6eb76d4f45ca9..8752aa2555122 100644 --- a/internal/parser/planparserv2/generated/plan_base_visitor.go +++ b/internal/parser/planparserv2/generated/plan_base_visitor.go @@ -75,6 +75,10 @@ func (v *BasePlanVisitor) VisitRelational(ctx *RelationalContext) interface{} { return v.VisitChildren(ctx) } +func (v *BasePlanVisitor) VisitArrayLength(ctx *ArrayLengthContext) interface{} { + return v.VisitChildren(ctx) +} + func (v *BasePlanVisitor) VisitTerm(ctx *TermContext) interface{} { return 
v.VisitChildren(ctx) } diff --git a/internal/parser/planparserv2/generated/plan_lexer.go b/internal/parser/planparserv2/generated/plan_lexer.go index 7b4bdfe5af407..cab7502bc7241 100644 --- a/internal/parser/planparserv2/generated/plan_lexer.go +++ b/internal/parser/planparserv2/generated/plan_lexer.go @@ -14,7 +14,7 @@ var _ = fmt.Printf var _ = unicode.IsLetter var serializedLexerAtn = []uint16{ - 3, 24715, 42794, 33075, 47597, 16764, 15335, 30598, 22884, 2, 44, 614, + 3, 24715, 42794, 33075, 47597, 16764, 15335, 30598, 22884, 2, 48, 754, 8, 1, 4, 2, 9, 2, 4, 3, 9, 3, 4, 4, 9, 4, 4, 5, 9, 5, 4, 6, 9, 6, 4, 7, 9, 7, 4, 8, 9, 8, 4, 9, 9, 9, 4, 10, 9, 10, 4, 11, 9, 11, 4, 12, 9, 12, 4, 13, 9, 13, 4, 14, 9, 14, 4, 15, 9, 15, 4, 16, 9, 16, 4, 17, 9, 17, 4, @@ -27,275 +27,336 @@ var serializedLexerAtn = []uint16{ 49, 4, 50, 9, 50, 4, 51, 9, 51, 4, 52, 9, 52, 4, 53, 9, 53, 4, 54, 9, 54, 4, 55, 9, 55, 4, 56, 9, 56, 4, 57, 9, 57, 4, 58, 9, 58, 4, 59, 9, 59, 4, 60, 9, 60, 4, 61, 9, 61, 4, 62, 9, 62, 4, 63, 9, 63, 4, 64, 9, 64, 4, 65, - 9, 65, 4, 66, 9, 66, 4, 67, 9, 67, 4, 68, 9, 68, 3, 2, 3, 2, 3, 3, 3, 3, - 3, 4, 3, 4, 3, 5, 3, 5, 3, 6, 3, 6, 3, 7, 3, 7, 3, 8, 3, 8, 3, 8, 3, 9, - 3, 9, 3, 10, 3, 10, 3, 10, 3, 11, 3, 11, 3, 11, 3, 12, 3, 12, 3, 12, 3, - 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 5, 13, 172, 10, 13, - 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, - 14, 3, 14, 5, 14, 186, 10, 14, 3, 15, 3, 15, 3, 16, 3, 16, 3, 17, 3, 17, - 3, 18, 3, 18, 3, 19, 3, 19, 3, 20, 3, 20, 3, 20, 3, 21, 3, 21, 3, 21, 3, - 22, 3, 22, 3, 22, 3, 23, 3, 23, 3, 24, 3, 24, 3, 25, 3, 25, 3, 26, 3, 26, - 3, 26, 3, 26, 3, 26, 5, 26, 218, 10, 26, 3, 27, 3, 27, 3, 27, 3, 27, 5, - 27, 224, 10, 27, 3, 28, 3, 28, 3, 29, 3, 29, 3, 29, 3, 29, 5, 29, 232, - 10, 29, 3, 30, 3, 30, 3, 30, 3, 31, 3, 31, 3, 31, 3, 31, 3, 31, 3, 31, - 3, 31, 3, 32, 3, 32, 3, 32, 7, 32, 247, 10, 32, 12, 32, 14, 32, 250, 11, - 32, 3, 32, 3, 32, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 
3, 33, 3, 33, + 9, 65, 4, 66, 9, 66, 4, 67, 9, 67, 4, 68, 9, 68, 4, 69, 9, 69, 4, 70, 9, + 70, 4, 71, 9, 71, 4, 72, 9, 72, 3, 2, 3, 2, 3, 3, 3, 3, 3, 4, 3, 4, 3, + 5, 3, 5, 3, 6, 3, 6, 3, 7, 3, 7, 3, 8, 3, 8, 3, 8, 3, 9, 3, 9, 3, 10, 3, + 10, 3, 10, 3, 11, 3, 11, 3, 11, 3, 12, 3, 12, 3, 12, 3, 13, 3, 13, 3, 13, + 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 5, 13, 180, 10, 13, 3, 14, 3, 14, 3, + 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 3, 14, 5, 14, + 194, 10, 14, 3, 15, 3, 15, 3, 16, 3, 16, 3, 17, 3, 17, 3, 18, 3, 18, 3, + 19, 3, 19, 3, 20, 3, 20, 3, 20, 3, 21, 3, 21, 3, 21, 3, 22, 3, 22, 3, 22, + 3, 23, 3, 23, 3, 24, 3, 24, 3, 25, 3, 25, 3, 26, 3, 26, 3, 26, 3, 26, 3, + 26, 5, 26, 226, 10, 26, 3, 27, 3, 27, 3, 27, 3, 27, 5, 27, 232, 10, 27, + 3, 28, 3, 28, 3, 29, 3, 29, 3, 29, 3, 29, 5, 29, 240, 10, 29, 3, 30, 3, + 30, 3, 30, 3, 31, 3, 31, 3, 31, 3, 31, 3, 31, 3, 31, 3, 31, 3, 32, 3, 32, + 3, 32, 7, 32, 255, 10, 32, 12, 32, 14, 32, 258, 11, 32, 3, 32, 3, 32, 3, + 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, - 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 3, 33, 5, 33, 280, 10, 33, + 33, 3, 33, 3, 33, 3, 33, 3, 33, 5, 33, 288, 10, 33, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, - 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, - 34, 3, 34, 3, 34, 5, 34, 316, 10, 34, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, + 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 3, 34, 5, + 34, 324, 10, 34, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, - 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 5, 35, 352, 10, - 35, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 
36, 3, 36, 3, 36, 3, 36, 3, 36, + 3, 35, 3, 35, 3, 35, 3, 35, 3, 35, 5, 35, 360, 10, 35, 3, 36, 3, 36, 3, + 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, - 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 3, 36, 5, 36, 381, 10, 36, 3, 37, - 3, 37, 3, 37, 3, 37, 5, 37, 387, 10, 37, 3, 38, 3, 38, 5, 38, 391, 10, - 38, 3, 39, 3, 39, 3, 39, 7, 39, 396, 10, 39, 12, 39, 14, 39, 399, 11, 39, - 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, 5, 39, 406, 10, 39, 3, 40, 5, 40, 409, - 10, 40, 3, 40, 3, 40, 5, 40, 413, 10, 40, 3, 40, 3, 40, 3, 40, 5, 40, 418, - 10, 40, 3, 40, 5, 40, 421, 10, 40, 3, 41, 3, 41, 3, 41, 3, 41, 5, 41, 427, - 10, 41, 3, 41, 3, 41, 6, 41, 431, 10, 41, 13, 41, 14, 41, 432, 3, 42, 3, - 42, 3, 42, 5, 42, 438, 10, 42, 3, 43, 6, 43, 441, 10, 43, 13, 43, 14, 43, - 442, 3, 44, 6, 44, 446, 10, 44, 13, 44, 14, 44, 447, 3, 45, 3, 45, 3, 45, - 3, 45, 3, 45, 3, 45, 3, 45, 5, 45, 457, 10, 45, 3, 46, 3, 46, 3, 46, 3, - 46, 3, 46, 3, 46, 3, 46, 5, 46, 466, 10, 46, 3, 47, 3, 47, 3, 48, 3, 48, - 3, 49, 3, 49, 3, 49, 6, 49, 475, 10, 49, 13, 49, 14, 49, 476, 3, 50, 3, - 50, 7, 50, 481, 10, 50, 12, 50, 14, 50, 484, 11, 50, 3, 50, 5, 50, 487, - 10, 50, 3, 51, 3, 51, 7, 51, 491, 10, 51, 12, 51, 14, 51, 494, 11, 51, - 3, 52, 3, 52, 3, 52, 3, 52, 3, 53, 3, 53, 3, 54, 3, 54, 3, 55, 3, 55, 3, - 56, 3, 56, 3, 56, 3, 56, 3, 56, 3, 57, 3, 57, 3, 57, 3, 57, 3, 57, 3, 57, - 3, 57, 3, 57, 3, 57, 3, 57, 5, 57, 521, 10, 57, 3, 58, 3, 58, 5, 58, 525, - 10, 58, 3, 58, 3, 58, 3, 58, 5, 58, 530, 10, 58, 3, 59, 3, 59, 3, 59, 3, - 59, 5, 59, 536, 10, 59, 3, 59, 3, 59, 3, 60, 5, 60, 541, 10, 60, 3, 60, - 3, 60, 3, 60, 3, 60, 3, 60, 5, 60, 548, 10, 60, 3, 61, 3, 61, 5, 61, 552, - 10, 61, 3, 61, 3, 61, 3, 62, 6, 62, 557, 10, 62, 13, 62, 14, 62, 558, 3, - 63, 5, 63, 562, 10, 63, 3, 63, 3, 63, 3, 63, 3, 63, 3, 63, 5, 63, 569, - 10, 63, 3, 64, 6, 64, 572, 10, 64, 13, 64, 14, 64, 573, 3, 65, 3, 65, 5, - 65, 578, 10, 65, 
3, 65, 3, 65, 3, 66, 3, 66, 3, 66, 3, 66, 3, 66, 5, 66, - 587, 10, 66, 3, 66, 5, 66, 590, 10, 66, 3, 66, 3, 66, 3, 66, 3, 66, 3, - 66, 5, 66, 597, 10, 66, 3, 67, 6, 67, 600, 10, 67, 13, 67, 14, 67, 601, - 3, 67, 3, 67, 3, 68, 3, 68, 5, 68, 608, 10, 68, 3, 68, 5, 68, 611, 10, - 68, 3, 68, 3, 68, 2, 2, 69, 3, 3, 5, 4, 7, 5, 9, 6, 11, 7, 13, 8, 15, 9, - 17, 10, 19, 11, 21, 12, 23, 13, 25, 14, 27, 15, 29, 16, 31, 17, 33, 18, - 35, 19, 37, 20, 39, 21, 41, 22, 43, 23, 45, 24, 47, 25, 49, 26, 51, 27, - 53, 28, 55, 29, 57, 30, 59, 31, 61, 32, 63, 33, 65, 34, 67, 35, 69, 36, - 71, 37, 73, 38, 75, 39, 77, 40, 79, 41, 81, 42, 83, 2, 85, 2, 87, 2, 89, - 2, 91, 2, 93, 2, 95, 2, 97, 2, 99, 2, 101, 2, 103, 2, 105, 2, 107, 2, 109, - 2, 111, 2, 113, 2, 115, 2, 117, 2, 119, 2, 121, 2, 123, 2, 125, 2, 127, - 2, 129, 2, 131, 2, 133, 43, 135, 44, 3, 2, 18, 5, 2, 78, 78, 87, 87, 119, - 119, 6, 2, 12, 12, 15, 15, 36, 36, 94, 94, 6, 2, 12, 12, 15, 15, 41, 41, - 94, 94, 5, 2, 67, 92, 97, 97, 99, 124, 3, 2, 50, 59, 4, 2, 68, 68, 100, - 100, 3, 2, 50, 51, 4, 2, 90, 90, 122, 122, 3, 2, 51, 59, 3, 2, 50, 57, - 5, 2, 50, 59, 67, 72, 99, 104, 4, 2, 71, 71, 103, 103, 4, 2, 45, 45, 47, - 47, 4, 2, 82, 82, 114, 114, 12, 2, 36, 36, 41, 41, 65, 65, 94, 94, 99, - 100, 104, 104, 112, 112, 116, 116, 118, 118, 120, 120, 4, 2, 11, 11, 34, - 34, 2, 649, 2, 3, 3, 2, 2, 2, 2, 5, 3, 2, 2, 2, 2, 7, 3, 2, 2, 2, 2, 9, - 3, 2, 2, 2, 2, 11, 3, 2, 2, 2, 2, 13, 3, 2, 2, 2, 2, 15, 3, 2, 2, 2, 2, - 17, 3, 2, 2, 2, 2, 19, 3, 2, 2, 2, 2, 21, 3, 2, 2, 2, 2, 23, 3, 2, 2, 2, - 2, 25, 3, 2, 2, 2, 2, 27, 3, 2, 2, 2, 2, 29, 3, 2, 2, 2, 2, 31, 3, 2, 2, - 2, 2, 33, 3, 2, 2, 2, 2, 35, 3, 2, 2, 2, 2, 37, 3, 2, 2, 2, 2, 39, 3, 2, - 2, 2, 2, 41, 3, 2, 2, 2, 2, 43, 3, 2, 2, 2, 2, 45, 3, 2, 2, 2, 2, 47, 3, - 2, 2, 2, 2, 49, 3, 2, 2, 2, 2, 51, 3, 2, 2, 2, 2, 53, 3, 2, 2, 2, 2, 55, - 3, 2, 2, 2, 2, 57, 3, 2, 2, 2, 2, 59, 3, 2, 2, 2, 2, 61, 3, 2, 2, 2, 2, - 63, 3, 2, 2, 2, 2, 65, 3, 2, 2, 2, 2, 67, 3, 2, 2, 2, 2, 69, 3, 2, 2, 2, - 2, 
71, 3, 2, 2, 2, 2, 73, 3, 2, 2, 2, 2, 75, 3, 2, 2, 2, 2, 77, 3, 2, 2, - 2, 2, 79, 3, 2, 2, 2, 2, 81, 3, 2, 2, 2, 2, 133, 3, 2, 2, 2, 2, 135, 3, - 2, 2, 2, 3, 137, 3, 2, 2, 2, 5, 139, 3, 2, 2, 2, 7, 141, 3, 2, 2, 2, 9, - 143, 3, 2, 2, 2, 11, 145, 3, 2, 2, 2, 13, 147, 3, 2, 2, 2, 15, 149, 3, - 2, 2, 2, 17, 152, 3, 2, 2, 2, 19, 154, 3, 2, 2, 2, 21, 157, 3, 2, 2, 2, - 23, 160, 3, 2, 2, 2, 25, 171, 3, 2, 2, 2, 27, 185, 3, 2, 2, 2, 29, 187, - 3, 2, 2, 2, 31, 189, 3, 2, 2, 2, 33, 191, 3, 2, 2, 2, 35, 193, 3, 2, 2, - 2, 37, 195, 3, 2, 2, 2, 39, 197, 3, 2, 2, 2, 41, 200, 3, 2, 2, 2, 43, 203, - 3, 2, 2, 2, 45, 206, 3, 2, 2, 2, 47, 208, 3, 2, 2, 2, 49, 210, 3, 2, 2, - 2, 51, 217, 3, 2, 2, 2, 53, 223, 3, 2, 2, 2, 55, 225, 3, 2, 2, 2, 57, 231, - 3, 2, 2, 2, 59, 233, 3, 2, 2, 2, 61, 236, 3, 2, 2, 2, 63, 243, 3, 2, 2, - 2, 65, 279, 3, 2, 2, 2, 67, 315, 3, 2, 2, 2, 69, 351, 3, 2, 2, 2, 71, 380, - 3, 2, 2, 2, 73, 386, 3, 2, 2, 2, 75, 390, 3, 2, 2, 2, 77, 405, 3, 2, 2, - 2, 79, 408, 3, 2, 2, 2, 81, 422, 3, 2, 2, 2, 83, 437, 3, 2, 2, 2, 85, 440, - 3, 2, 2, 2, 87, 445, 3, 2, 2, 2, 89, 456, 3, 2, 2, 2, 91, 465, 3, 2, 2, - 2, 93, 467, 3, 2, 2, 2, 95, 469, 3, 2, 2, 2, 97, 471, 3, 2, 2, 2, 99, 486, - 3, 2, 2, 2, 101, 488, 3, 2, 2, 2, 103, 495, 3, 2, 2, 2, 105, 499, 3, 2, - 2, 2, 107, 501, 3, 2, 2, 2, 109, 503, 3, 2, 2, 2, 111, 505, 3, 2, 2, 2, - 113, 520, 3, 2, 2, 2, 115, 529, 3, 2, 2, 2, 117, 531, 3, 2, 2, 2, 119, - 547, 3, 2, 2, 2, 121, 549, 3, 2, 2, 2, 123, 556, 3, 2, 2, 2, 125, 568, - 3, 2, 2, 2, 127, 571, 3, 2, 2, 2, 129, 575, 3, 2, 2, 2, 131, 596, 3, 2, - 2, 2, 133, 599, 3, 2, 2, 2, 135, 610, 3, 2, 2, 2, 137, 138, 7, 42, 2, 2, - 138, 4, 3, 2, 2, 2, 139, 140, 7, 43, 2, 2, 140, 6, 3, 2, 2, 2, 141, 142, - 7, 93, 2, 2, 142, 8, 3, 2, 2, 2, 143, 144, 7, 46, 2, 2, 144, 10, 3, 2, - 2, 2, 145, 146, 7, 95, 2, 2, 146, 12, 3, 2, 2, 2, 147, 148, 7, 62, 2, 2, - 148, 14, 3, 2, 2, 2, 149, 150, 7, 62, 2, 2, 150, 151, 7, 63, 2, 2, 151, - 16, 3, 2, 2, 2, 152, 153, 7, 64, 2, 2, 153, 18, 3, 2, 2, 2, 
154, 155, 7, - 64, 2, 2, 155, 156, 7, 63, 2, 2, 156, 20, 3, 2, 2, 2, 157, 158, 7, 63, - 2, 2, 158, 159, 7, 63, 2, 2, 159, 22, 3, 2, 2, 2, 160, 161, 7, 35, 2, 2, - 161, 162, 7, 63, 2, 2, 162, 24, 3, 2, 2, 2, 163, 164, 7, 110, 2, 2, 164, - 165, 7, 107, 2, 2, 165, 166, 7, 109, 2, 2, 166, 172, 7, 103, 2, 2, 167, - 168, 7, 78, 2, 2, 168, 169, 7, 75, 2, 2, 169, 170, 7, 77, 2, 2, 170, 172, - 7, 71, 2, 2, 171, 163, 3, 2, 2, 2, 171, 167, 3, 2, 2, 2, 172, 26, 3, 2, - 2, 2, 173, 174, 7, 103, 2, 2, 174, 175, 7, 122, 2, 2, 175, 176, 7, 107, - 2, 2, 176, 177, 7, 117, 2, 2, 177, 178, 7, 118, 2, 2, 178, 186, 7, 117, - 2, 2, 179, 180, 7, 71, 2, 2, 180, 181, 7, 90, 2, 2, 181, 182, 7, 75, 2, - 2, 182, 183, 7, 85, 2, 2, 183, 184, 7, 86, 2, 2, 184, 186, 7, 85, 2, 2, - 185, 173, 3, 2, 2, 2, 185, 179, 3, 2, 2, 2, 186, 28, 3, 2, 2, 2, 187, 188, - 7, 45, 2, 2, 188, 30, 3, 2, 2, 2, 189, 190, 7, 47, 2, 2, 190, 32, 3, 2, - 2, 2, 191, 192, 7, 44, 2, 2, 192, 34, 3, 2, 2, 2, 193, 194, 7, 49, 2, 2, - 194, 36, 3, 2, 2, 2, 195, 196, 7, 39, 2, 2, 196, 38, 3, 2, 2, 2, 197, 198, - 7, 44, 2, 2, 198, 199, 7, 44, 2, 2, 199, 40, 3, 2, 2, 2, 200, 201, 7, 62, - 2, 2, 201, 202, 7, 62, 2, 2, 202, 42, 3, 2, 2, 2, 203, 204, 7, 64, 2, 2, - 204, 205, 7, 64, 2, 2, 205, 44, 3, 2, 2, 2, 206, 207, 7, 40, 2, 2, 207, - 46, 3, 2, 2, 2, 208, 209, 7, 126, 2, 2, 209, 48, 3, 2, 2, 2, 210, 211, - 7, 96, 2, 2, 211, 50, 3, 2, 2, 2, 212, 213, 7, 40, 2, 2, 213, 218, 7, 40, - 2, 2, 214, 215, 7, 99, 2, 2, 215, 216, 7, 112, 2, 2, 216, 218, 7, 102, - 2, 2, 217, 212, 3, 2, 2, 2, 217, 214, 3, 2, 2, 2, 218, 52, 3, 2, 2, 2, - 219, 220, 7, 126, 2, 2, 220, 224, 7, 126, 2, 2, 221, 222, 7, 113, 2, 2, - 222, 224, 7, 116, 2, 2, 223, 219, 3, 2, 2, 2, 223, 221, 3, 2, 2, 2, 224, - 54, 3, 2, 2, 2, 225, 226, 7, 128, 2, 2, 226, 56, 3, 2, 2, 2, 227, 232, - 7, 35, 2, 2, 228, 229, 7, 112, 2, 2, 229, 230, 7, 113, 2, 2, 230, 232, - 7, 118, 2, 2, 231, 227, 3, 2, 2, 2, 231, 228, 3, 2, 2, 2, 232, 58, 3, 2, - 2, 2, 233, 234, 7, 107, 2, 2, 234, 235, 7, 112, 
2, 2, 235, 60, 3, 2, 2, - 2, 236, 237, 7, 112, 2, 2, 237, 238, 7, 113, 2, 2, 238, 239, 7, 118, 2, - 2, 239, 240, 7, 34, 2, 2, 240, 241, 7, 107, 2, 2, 241, 242, 7, 112, 2, - 2, 242, 62, 3, 2, 2, 2, 243, 248, 7, 93, 2, 2, 244, 247, 5, 133, 67, 2, - 245, 247, 5, 135, 68, 2, 246, 244, 3, 2, 2, 2, 246, 245, 3, 2, 2, 2, 247, - 250, 3, 2, 2, 2, 248, 246, 3, 2, 2, 2, 248, 249, 3, 2, 2, 2, 249, 251, - 3, 2, 2, 2, 250, 248, 3, 2, 2, 2, 251, 252, 7, 95, 2, 2, 252, 64, 3, 2, - 2, 2, 253, 254, 7, 108, 2, 2, 254, 255, 7, 117, 2, 2, 255, 256, 7, 113, - 2, 2, 256, 257, 7, 112, 2, 2, 257, 258, 7, 97, 2, 2, 258, 259, 7, 101, - 2, 2, 259, 260, 7, 113, 2, 2, 260, 261, 7, 112, 2, 2, 261, 262, 7, 118, - 2, 2, 262, 263, 7, 99, 2, 2, 263, 264, 7, 107, 2, 2, 264, 265, 7, 112, - 2, 2, 265, 280, 7, 117, 2, 2, 266, 267, 7, 76, 2, 2, 267, 268, 7, 85, 2, - 2, 268, 269, 7, 81, 2, 2, 269, 270, 7, 80, 2, 2, 270, 271, 7, 97, 2, 2, - 271, 272, 7, 69, 2, 2, 272, 273, 7, 81, 2, 2, 273, 274, 7, 80, 2, 2, 274, - 275, 7, 86, 2, 2, 275, 276, 7, 67, 2, 2, 276, 277, 7, 75, 2, 2, 277, 278, - 7, 80, 2, 2, 278, 280, 7, 85, 2, 2, 279, 253, 3, 2, 2, 2, 279, 266, 3, - 2, 2, 2, 280, 66, 3, 2, 2, 2, 281, 282, 7, 108, 2, 2, 282, 283, 7, 117, - 2, 2, 283, 284, 7, 113, 2, 2, 284, 285, 7, 112, 2, 2, 285, 286, 7, 97, - 2, 2, 286, 287, 7, 101, 2, 2, 287, 288, 7, 113, 2, 2, 288, 289, 7, 112, - 2, 2, 289, 290, 7, 118, 2, 2, 290, 291, 7, 99, 2, 2, 291, 292, 7, 107, - 2, 2, 292, 293, 7, 112, 2, 2, 293, 294, 7, 117, 2, 2, 294, 295, 7, 97, - 2, 2, 295, 296, 7, 99, 2, 2, 296, 297, 7, 110, 2, 2, 297, 316, 7, 110, - 2, 2, 298, 299, 7, 76, 2, 2, 299, 300, 7, 85, 2, 2, 300, 301, 7, 81, 2, - 2, 301, 302, 7, 80, 2, 2, 302, 303, 7, 97, 2, 2, 303, 304, 7, 69, 2, 2, - 304, 305, 7, 81, 2, 2, 305, 306, 7, 80, 2, 2, 306, 307, 7, 86, 2, 2, 307, - 308, 7, 67, 2, 2, 308, 309, 7, 75, 2, 2, 309, 310, 7, 80, 2, 2, 310, 311, - 7, 85, 2, 2, 311, 312, 7, 97, 2, 2, 312, 313, 7, 67, 2, 2, 313, 314, 7, - 78, 2, 2, 314, 316, 7, 78, 2, 2, 315, 281, 3, 
2, 2, 2, 315, 298, 3, 2, - 2, 2, 316, 68, 3, 2, 2, 2, 317, 318, 7, 108, 2, 2, 318, 319, 7, 117, 2, - 2, 319, 320, 7, 113, 2, 2, 320, 321, 7, 112, 2, 2, 321, 322, 7, 97, 2, - 2, 322, 323, 7, 101, 2, 2, 323, 324, 7, 113, 2, 2, 324, 325, 7, 112, 2, - 2, 325, 326, 7, 118, 2, 2, 326, 327, 7, 99, 2, 2, 327, 328, 7, 107, 2, - 2, 328, 329, 7, 112, 2, 2, 329, 330, 7, 117, 2, 2, 330, 331, 7, 97, 2, - 2, 331, 332, 7, 99, 2, 2, 332, 333, 7, 112, 2, 2, 333, 352, 7, 123, 2, - 2, 334, 335, 7, 76, 2, 2, 335, 336, 7, 85, 2, 2, 336, 337, 7, 81, 2, 2, - 337, 338, 7, 80, 2, 2, 338, 339, 7, 97, 2, 2, 339, 340, 7, 69, 2, 2, 340, - 341, 7, 81, 2, 2, 341, 342, 7, 80, 2, 2, 342, 343, 7, 86, 2, 2, 343, 344, - 7, 67, 2, 2, 344, 345, 7, 75, 2, 2, 345, 346, 7, 80, 2, 2, 346, 347, 7, - 85, 2, 2, 347, 348, 7, 97, 2, 2, 348, 349, 7, 67, 2, 2, 349, 350, 7, 80, - 2, 2, 350, 352, 7, 91, 2, 2, 351, 317, 3, 2, 2, 2, 351, 334, 3, 2, 2, 2, - 352, 70, 3, 2, 2, 2, 353, 354, 7, 118, 2, 2, 354, 355, 7, 116, 2, 2, 355, - 356, 7, 119, 2, 2, 356, 381, 7, 103, 2, 2, 357, 358, 7, 86, 2, 2, 358, - 359, 7, 116, 2, 2, 359, 360, 7, 119, 2, 2, 360, 381, 7, 103, 2, 2, 361, - 362, 7, 86, 2, 2, 362, 363, 7, 84, 2, 2, 363, 364, 7, 87, 2, 2, 364, 381, - 7, 71, 2, 2, 365, 366, 7, 104, 2, 2, 366, 367, 7, 99, 2, 2, 367, 368, 7, - 110, 2, 2, 368, 369, 7, 117, 2, 2, 369, 381, 7, 103, 2, 2, 370, 371, 7, - 72, 2, 2, 371, 372, 7, 99, 2, 2, 372, 373, 7, 110, 2, 2, 373, 374, 7, 117, - 2, 2, 374, 381, 7, 103, 2, 2, 375, 376, 7, 72, 2, 2, 376, 377, 7, 67, 2, - 2, 377, 378, 7, 78, 2, 2, 378, 379, 7, 85, 2, 2, 379, 381, 7, 71, 2, 2, - 380, 353, 3, 2, 2, 2, 380, 357, 3, 2, 2, 2, 380, 361, 3, 2, 2, 2, 380, - 365, 3, 2, 2, 2, 380, 370, 3, 2, 2, 2, 380, 375, 3, 2, 2, 2, 381, 72, 3, - 2, 2, 2, 382, 387, 5, 99, 50, 2, 383, 387, 5, 101, 51, 2, 384, 387, 5, - 103, 52, 2, 385, 387, 5, 97, 49, 2, 386, 382, 3, 2, 2, 2, 386, 383, 3, - 2, 2, 2, 386, 384, 3, 2, 2, 2, 386, 385, 3, 2, 2, 2, 387, 74, 3, 2, 2, - 2, 388, 391, 5, 115, 58, 2, 389, 391, 5, 
117, 59, 2, 390, 388, 3, 2, 2, - 2, 390, 389, 3, 2, 2, 2, 391, 76, 3, 2, 2, 2, 392, 397, 5, 93, 47, 2, 393, - 396, 5, 93, 47, 2, 394, 396, 5, 95, 48, 2, 395, 393, 3, 2, 2, 2, 395, 394, - 3, 2, 2, 2, 396, 399, 3, 2, 2, 2, 397, 395, 3, 2, 2, 2, 397, 398, 3, 2, - 2, 2, 398, 406, 3, 2, 2, 2, 399, 397, 3, 2, 2, 2, 400, 401, 7, 38, 2, 2, - 401, 402, 7, 111, 2, 2, 402, 403, 7, 103, 2, 2, 403, 404, 7, 118, 2, 2, - 404, 406, 7, 99, 2, 2, 405, 392, 3, 2, 2, 2, 405, 400, 3, 2, 2, 2, 406, - 78, 3, 2, 2, 2, 407, 409, 5, 83, 42, 2, 408, 407, 3, 2, 2, 2, 408, 409, - 3, 2, 2, 2, 409, 420, 3, 2, 2, 2, 410, 412, 7, 36, 2, 2, 411, 413, 5, 85, - 43, 2, 412, 411, 3, 2, 2, 2, 412, 413, 3, 2, 2, 2, 413, 414, 3, 2, 2, 2, - 414, 421, 7, 36, 2, 2, 415, 417, 7, 41, 2, 2, 416, 418, 5, 87, 44, 2, 417, - 416, 3, 2, 2, 2, 417, 418, 3, 2, 2, 2, 418, 419, 3, 2, 2, 2, 419, 421, - 7, 41, 2, 2, 420, 410, 3, 2, 2, 2, 420, 415, 3, 2, 2, 2, 421, 80, 3, 2, - 2, 2, 422, 430, 5, 77, 39, 2, 423, 426, 7, 93, 2, 2, 424, 427, 5, 79, 40, - 2, 425, 427, 5, 99, 50, 2, 426, 424, 3, 2, 2, 2, 426, 425, 3, 2, 2, 2, - 427, 428, 3, 2, 2, 2, 428, 429, 7, 95, 2, 2, 429, 431, 3, 2, 2, 2, 430, - 423, 3, 2, 2, 2, 431, 432, 3, 2, 2, 2, 432, 430, 3, 2, 2, 2, 432, 433, - 3, 2, 2, 2, 433, 82, 3, 2, 2, 2, 434, 435, 7, 119, 2, 2, 435, 438, 7, 58, - 2, 2, 436, 438, 9, 2, 2, 2, 437, 434, 3, 2, 2, 2, 437, 436, 3, 2, 2, 2, - 438, 84, 3, 2, 2, 2, 439, 441, 5, 89, 45, 2, 440, 439, 3, 2, 2, 2, 441, - 442, 3, 2, 2, 2, 442, 440, 3, 2, 2, 2, 442, 443, 3, 2, 2, 2, 443, 86, 3, - 2, 2, 2, 444, 446, 5, 91, 46, 2, 445, 444, 3, 2, 2, 2, 446, 447, 3, 2, - 2, 2, 447, 445, 3, 2, 2, 2, 447, 448, 3, 2, 2, 2, 448, 88, 3, 2, 2, 2, - 449, 457, 10, 3, 2, 2, 450, 457, 5, 131, 66, 2, 451, 452, 7, 94, 2, 2, - 452, 457, 7, 12, 2, 2, 453, 454, 7, 94, 2, 2, 454, 455, 7, 15, 2, 2, 455, - 457, 7, 12, 2, 2, 456, 449, 3, 2, 2, 2, 456, 450, 3, 2, 2, 2, 456, 451, - 3, 2, 2, 2, 456, 453, 3, 2, 2, 2, 457, 90, 3, 2, 2, 2, 458, 466, 10, 4, - 2, 2, 459, 466, 5, 131, 66, 
2, 460, 461, 7, 94, 2, 2, 461, 466, 7, 12, - 2, 2, 462, 463, 7, 94, 2, 2, 463, 464, 7, 15, 2, 2, 464, 466, 7, 12, 2, - 2, 465, 458, 3, 2, 2, 2, 465, 459, 3, 2, 2, 2, 465, 460, 3, 2, 2, 2, 465, - 462, 3, 2, 2, 2, 466, 92, 3, 2, 2, 2, 467, 468, 9, 5, 2, 2, 468, 94, 3, - 2, 2, 2, 469, 470, 9, 6, 2, 2, 470, 96, 3, 2, 2, 2, 471, 472, 7, 50, 2, - 2, 472, 474, 9, 7, 2, 2, 473, 475, 9, 8, 2, 2, 474, 473, 3, 2, 2, 2, 475, - 476, 3, 2, 2, 2, 476, 474, 3, 2, 2, 2, 476, 477, 3, 2, 2, 2, 477, 98, 3, - 2, 2, 2, 478, 482, 5, 105, 53, 2, 479, 481, 5, 95, 48, 2, 480, 479, 3, - 2, 2, 2, 481, 484, 3, 2, 2, 2, 482, 480, 3, 2, 2, 2, 482, 483, 3, 2, 2, - 2, 483, 487, 3, 2, 2, 2, 484, 482, 3, 2, 2, 2, 485, 487, 7, 50, 2, 2, 486, - 478, 3, 2, 2, 2, 486, 485, 3, 2, 2, 2, 487, 100, 3, 2, 2, 2, 488, 492, - 7, 50, 2, 2, 489, 491, 5, 107, 54, 2, 490, 489, 3, 2, 2, 2, 491, 494, 3, - 2, 2, 2, 492, 490, 3, 2, 2, 2, 492, 493, 3, 2, 2, 2, 493, 102, 3, 2, 2, - 2, 494, 492, 3, 2, 2, 2, 495, 496, 7, 50, 2, 2, 496, 497, 9, 9, 2, 2, 497, - 498, 5, 127, 64, 2, 498, 104, 3, 2, 2, 2, 499, 500, 9, 10, 2, 2, 500, 106, - 3, 2, 2, 2, 501, 502, 9, 11, 2, 2, 502, 108, 3, 2, 2, 2, 503, 504, 9, 12, - 2, 2, 504, 110, 3, 2, 2, 2, 505, 506, 5, 109, 55, 2, 506, 507, 5, 109, - 55, 2, 507, 508, 5, 109, 55, 2, 508, 509, 5, 109, 55, 2, 509, 112, 3, 2, - 2, 2, 510, 511, 7, 94, 2, 2, 511, 512, 7, 119, 2, 2, 512, 513, 3, 2, 2, - 2, 513, 521, 5, 111, 56, 2, 514, 515, 7, 94, 2, 2, 515, 516, 7, 87, 2, - 2, 516, 517, 3, 2, 2, 2, 517, 518, 5, 111, 56, 2, 518, 519, 5, 111, 56, - 2, 519, 521, 3, 2, 2, 2, 520, 510, 3, 2, 2, 2, 520, 514, 3, 2, 2, 2, 521, - 114, 3, 2, 2, 2, 522, 524, 5, 119, 60, 2, 523, 525, 5, 121, 61, 2, 524, - 523, 3, 2, 2, 2, 524, 525, 3, 2, 2, 2, 525, 530, 3, 2, 2, 2, 526, 527, - 5, 123, 62, 2, 527, 528, 5, 121, 61, 2, 528, 530, 3, 2, 2, 2, 529, 522, - 3, 2, 2, 2, 529, 526, 3, 2, 2, 2, 530, 116, 3, 2, 2, 2, 531, 532, 7, 50, - 2, 2, 532, 535, 9, 9, 2, 2, 533, 536, 5, 125, 63, 2, 534, 536, 5, 127, - 64, 2, 535, 
533, 3, 2, 2, 2, 535, 534, 3, 2, 2, 2, 536, 537, 3, 2, 2, 2, - 537, 538, 5, 129, 65, 2, 538, 118, 3, 2, 2, 2, 539, 541, 5, 123, 62, 2, - 540, 539, 3, 2, 2, 2, 540, 541, 3, 2, 2, 2, 541, 542, 3, 2, 2, 2, 542, - 543, 7, 48, 2, 2, 543, 548, 5, 123, 62, 2, 544, 545, 5, 123, 62, 2, 545, - 546, 7, 48, 2, 2, 546, 548, 3, 2, 2, 2, 547, 540, 3, 2, 2, 2, 547, 544, - 3, 2, 2, 2, 548, 120, 3, 2, 2, 2, 549, 551, 9, 13, 2, 2, 550, 552, 9, 14, - 2, 2, 551, 550, 3, 2, 2, 2, 551, 552, 3, 2, 2, 2, 552, 553, 3, 2, 2, 2, - 553, 554, 5, 123, 62, 2, 554, 122, 3, 2, 2, 2, 555, 557, 5, 95, 48, 2, - 556, 555, 3, 2, 2, 2, 557, 558, 3, 2, 2, 2, 558, 556, 3, 2, 2, 2, 558, - 559, 3, 2, 2, 2, 559, 124, 3, 2, 2, 2, 560, 562, 5, 127, 64, 2, 561, 560, - 3, 2, 2, 2, 561, 562, 3, 2, 2, 2, 562, 563, 3, 2, 2, 2, 563, 564, 7, 48, - 2, 2, 564, 569, 5, 127, 64, 2, 565, 566, 5, 127, 64, 2, 566, 567, 7, 48, - 2, 2, 567, 569, 3, 2, 2, 2, 568, 561, 3, 2, 2, 2, 568, 565, 3, 2, 2, 2, - 569, 126, 3, 2, 2, 2, 570, 572, 5, 109, 55, 2, 571, 570, 3, 2, 2, 2, 572, - 573, 3, 2, 2, 2, 573, 571, 3, 2, 2, 2, 573, 574, 3, 2, 2, 2, 574, 128, - 3, 2, 2, 2, 575, 577, 9, 15, 2, 2, 576, 578, 9, 14, 2, 2, 577, 576, 3, - 2, 2, 2, 577, 578, 3, 2, 2, 2, 578, 579, 3, 2, 2, 2, 579, 580, 5, 123, - 62, 2, 580, 130, 3, 2, 2, 2, 581, 582, 7, 94, 2, 2, 582, 597, 9, 16, 2, - 2, 583, 584, 7, 94, 2, 2, 584, 586, 5, 107, 54, 2, 585, 587, 5, 107, 54, - 2, 586, 585, 3, 2, 2, 2, 586, 587, 3, 2, 2, 2, 587, 589, 3, 2, 2, 2, 588, - 590, 5, 107, 54, 2, 589, 588, 3, 2, 2, 2, 589, 590, 3, 2, 2, 2, 590, 597, - 3, 2, 2, 2, 591, 592, 7, 94, 2, 2, 592, 593, 7, 122, 2, 2, 593, 594, 3, - 2, 2, 2, 594, 597, 5, 127, 64, 2, 595, 597, 5, 113, 57, 2, 596, 581, 3, - 2, 2, 2, 596, 583, 3, 2, 2, 2, 596, 591, 3, 2, 2, 2, 596, 595, 3, 2, 2, - 2, 597, 132, 3, 2, 2, 2, 598, 600, 9, 17, 2, 2, 599, 598, 3, 2, 2, 2, 600, - 601, 3, 2, 2, 2, 601, 599, 3, 2, 2, 2, 601, 602, 3, 2, 2, 2, 602, 603, - 3, 2, 2, 2, 603, 604, 8, 67, 2, 2, 604, 134, 3, 2, 2, 2, 605, 607, 7, 15, - 
2, 2, 606, 608, 7, 12, 2, 2, 607, 606, 3, 2, 2, 2, 607, 608, 3, 2, 2, 2, - 608, 611, 3, 2, 2, 2, 609, 611, 7, 12, 2, 2, 610, 605, 3, 2, 2, 2, 610, - 609, 3, 2, 2, 2, 611, 612, 3, 2, 2, 2, 612, 613, 8, 68, 2, 2, 613, 136, - 3, 2, 2, 2, 52, 2, 171, 185, 217, 223, 231, 246, 248, 279, 315, 351, 380, - 386, 390, 395, 397, 405, 408, 412, 417, 420, 426, 432, 437, 442, 447, 456, - 465, 476, 482, 486, 492, 520, 524, 529, 535, 540, 547, 551, 558, 561, 568, - 573, 577, 586, 589, 596, 601, 607, 610, 3, 8, 2, 2, + 36, 3, 36, 3, 36, 3, 36, 3, 36, 5, 36, 390, 10, 36, 3, 37, 3, 37, 3, 37, + 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, + 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, + 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, 37, 3, + 37, 3, 37, 5, 37, 428, 10, 37, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, + 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, + 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, + 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 5, 38, 466, + 10, 38, 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, + 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, 3, 39, 3, + 39, 3, 39, 3, 39, 3, 39, 3, 39, 5, 39, 492, 10, 39, 3, 40, 3, 40, 3, 40, + 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, + 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, 3, 40, + 3, 40, 3, 40, 3, 40, 5, 40, 521, 10, 40, 3, 41, 3, 41, 3, 41, 3, 41, 5, + 41, 527, 10, 41, 3, 42, 3, 42, 5, 42, 531, 10, 42, 3, 43, 3, 43, 3, 43, + 7, 43, 536, 10, 43, 12, 43, 14, 43, 539, 11, 43, 3, 43, 3, 43, 3, 43, 3, + 43, 3, 43, 5, 43, 546, 10, 43, 3, 44, 5, 44, 549, 10, 44, 3, 44, 3, 44, + 5, 44, 553, 10, 44, 3, 44, 3, 44, 3, 44, 5, 44, 558, 10, 44, 3, 44, 5, + 44, 561, 10, 44, 3, 45, 3, 45, 3, 45, 3, 45, 5, 45, 567, 10, 45, 3, 45, + 3, 45, 6, 45, 571, 10, 45, 13, 45, 14, 45, 572, 3, 46, 3, 46, 3, 46, 
5, + 46, 578, 10, 46, 3, 47, 6, 47, 581, 10, 47, 13, 47, 14, 47, 582, 3, 48, + 6, 48, 586, 10, 48, 13, 48, 14, 48, 587, 3, 49, 3, 49, 3, 49, 3, 49, 3, + 49, 3, 49, 3, 49, 5, 49, 597, 10, 49, 3, 50, 3, 50, 3, 50, 3, 50, 3, 50, + 3, 50, 3, 50, 5, 50, 606, 10, 50, 3, 51, 3, 51, 3, 52, 3, 52, 3, 53, 3, + 53, 3, 53, 6, 53, 615, 10, 53, 13, 53, 14, 53, 616, 3, 54, 3, 54, 7, 54, + 621, 10, 54, 12, 54, 14, 54, 624, 11, 54, 3, 54, 5, 54, 627, 10, 54, 3, + 55, 3, 55, 7, 55, 631, 10, 55, 12, 55, 14, 55, 634, 11, 55, 3, 56, 3, 56, + 3, 56, 3, 56, 3, 57, 3, 57, 3, 58, 3, 58, 3, 59, 3, 59, 3, 60, 3, 60, 3, + 60, 3, 60, 3, 60, 3, 61, 3, 61, 3, 61, 3, 61, 3, 61, 3, 61, 3, 61, 3, 61, + 3, 61, 3, 61, 5, 61, 661, 10, 61, 3, 62, 3, 62, 5, 62, 665, 10, 62, 3, + 62, 3, 62, 3, 62, 5, 62, 670, 10, 62, 3, 63, 3, 63, 3, 63, 3, 63, 5, 63, + 676, 10, 63, 3, 63, 3, 63, 3, 64, 5, 64, 681, 10, 64, 3, 64, 3, 64, 3, + 64, 3, 64, 3, 64, 5, 64, 688, 10, 64, 3, 65, 3, 65, 5, 65, 692, 10, 65, + 3, 65, 3, 65, 3, 66, 6, 66, 697, 10, 66, 13, 66, 14, 66, 698, 3, 67, 5, + 67, 702, 10, 67, 3, 67, 3, 67, 3, 67, 3, 67, 3, 67, 5, 67, 709, 10, 67, + 3, 68, 6, 68, 712, 10, 68, 13, 68, 14, 68, 713, 3, 69, 3, 69, 5, 69, 718, + 10, 69, 3, 69, 3, 69, 3, 70, 3, 70, 3, 70, 3, 70, 3, 70, 5, 70, 727, 10, + 70, 3, 70, 5, 70, 730, 10, 70, 3, 70, 3, 70, 3, 70, 3, 70, 3, 70, 5, 70, + 737, 10, 70, 3, 71, 6, 71, 740, 10, 71, 13, 71, 14, 71, 741, 3, 71, 3, + 71, 3, 72, 3, 72, 5, 72, 748, 10, 72, 3, 72, 5, 72, 751, 10, 72, 3, 72, + 3, 72, 2, 2, 73, 3, 3, 5, 4, 7, 5, 9, 6, 11, 7, 13, 8, 15, 9, 17, 10, 19, + 11, 21, 12, 23, 13, 25, 14, 27, 15, 29, 16, 31, 17, 33, 18, 35, 19, 37, + 20, 39, 21, 41, 22, 43, 23, 45, 24, 47, 25, 49, 26, 51, 27, 53, 28, 55, + 29, 57, 30, 59, 31, 61, 32, 63, 33, 65, 34, 67, 35, 69, 36, 71, 37, 73, + 38, 75, 39, 77, 40, 79, 41, 81, 42, 83, 43, 85, 44, 87, 45, 89, 46, 91, + 2, 93, 2, 95, 2, 97, 2, 99, 2, 101, 2, 103, 2, 105, 2, 107, 2, 109, 2, + 111, 2, 113, 2, 115, 2, 117, 2, 119, 2, 121, 2, 123, 2, 125, 
2, 127, 2, + 129, 2, 131, 2, 133, 2, 135, 2, 137, 2, 139, 2, 141, 47, 143, 48, 3, 2, + 18, 5, 2, 78, 78, 87, 87, 119, 119, 6, 2, 12, 12, 15, 15, 36, 36, 94, 94, + 6, 2, 12, 12, 15, 15, 41, 41, 94, 94, 5, 2, 67, 92, 97, 97, 99, 124, 3, + 2, 50, 59, 4, 2, 68, 68, 100, 100, 3, 2, 50, 51, 4, 2, 90, 90, 122, 122, + 3, 2, 51, 59, 3, 2, 50, 57, 5, 2, 50, 59, 67, 72, 99, 104, 4, 2, 71, 71, + 103, 103, 4, 2, 45, 45, 47, 47, 4, 2, 82, 82, 114, 114, 12, 2, 36, 36, + 41, 41, 65, 65, 94, 94, 99, 100, 104, 104, 112, 112, 116, 116, 118, 118, + 120, 120, 4, 2, 11, 11, 34, 34, 2, 793, 2, 3, 3, 2, 2, 2, 2, 5, 3, 2, 2, + 2, 2, 7, 3, 2, 2, 2, 2, 9, 3, 2, 2, 2, 2, 11, 3, 2, 2, 2, 2, 13, 3, 2, + 2, 2, 2, 15, 3, 2, 2, 2, 2, 17, 3, 2, 2, 2, 2, 19, 3, 2, 2, 2, 2, 21, 3, + 2, 2, 2, 2, 23, 3, 2, 2, 2, 2, 25, 3, 2, 2, 2, 2, 27, 3, 2, 2, 2, 2, 29, + 3, 2, 2, 2, 2, 31, 3, 2, 2, 2, 2, 33, 3, 2, 2, 2, 2, 35, 3, 2, 2, 2, 2, + 37, 3, 2, 2, 2, 2, 39, 3, 2, 2, 2, 2, 41, 3, 2, 2, 2, 2, 43, 3, 2, 2, 2, + 2, 45, 3, 2, 2, 2, 2, 47, 3, 2, 2, 2, 2, 49, 3, 2, 2, 2, 2, 51, 3, 2, 2, + 2, 2, 53, 3, 2, 2, 2, 2, 55, 3, 2, 2, 2, 2, 57, 3, 2, 2, 2, 2, 59, 3, 2, + 2, 2, 2, 61, 3, 2, 2, 2, 2, 63, 3, 2, 2, 2, 2, 65, 3, 2, 2, 2, 2, 67, 3, + 2, 2, 2, 2, 69, 3, 2, 2, 2, 2, 71, 3, 2, 2, 2, 2, 73, 3, 2, 2, 2, 2, 75, + 3, 2, 2, 2, 2, 77, 3, 2, 2, 2, 2, 79, 3, 2, 2, 2, 2, 81, 3, 2, 2, 2, 2, + 83, 3, 2, 2, 2, 2, 85, 3, 2, 2, 2, 2, 87, 3, 2, 2, 2, 2, 89, 3, 2, 2, 2, + 2, 141, 3, 2, 2, 2, 2, 143, 3, 2, 2, 2, 3, 145, 3, 2, 2, 2, 5, 147, 3, + 2, 2, 2, 7, 149, 3, 2, 2, 2, 9, 151, 3, 2, 2, 2, 11, 153, 3, 2, 2, 2, 13, + 155, 3, 2, 2, 2, 15, 157, 3, 2, 2, 2, 17, 160, 3, 2, 2, 2, 19, 162, 3, + 2, 2, 2, 21, 165, 3, 2, 2, 2, 23, 168, 3, 2, 2, 2, 25, 179, 3, 2, 2, 2, + 27, 193, 3, 2, 2, 2, 29, 195, 3, 2, 2, 2, 31, 197, 3, 2, 2, 2, 33, 199, + 3, 2, 2, 2, 35, 201, 3, 2, 2, 2, 37, 203, 3, 2, 2, 2, 39, 205, 3, 2, 2, + 2, 41, 208, 3, 2, 2, 2, 43, 211, 3, 2, 2, 2, 45, 214, 3, 2, 2, 2, 47, 216, + 3, 2, 2, 2, 49, 218, 3, 2, 2, 2, 51, 225, 3, 2, 
2, 2, 53, 231, 3, 2, 2, + 2, 55, 233, 3, 2, 2, 2, 57, 239, 3, 2, 2, 2, 59, 241, 3, 2, 2, 2, 61, 244, + 3, 2, 2, 2, 63, 251, 3, 2, 2, 2, 65, 287, 3, 2, 2, 2, 67, 323, 3, 2, 2, + 2, 69, 359, 3, 2, 2, 2, 71, 389, 3, 2, 2, 2, 73, 427, 3, 2, 2, 2, 75, 465, + 3, 2, 2, 2, 77, 491, 3, 2, 2, 2, 79, 520, 3, 2, 2, 2, 81, 526, 3, 2, 2, + 2, 83, 530, 3, 2, 2, 2, 85, 545, 3, 2, 2, 2, 87, 548, 3, 2, 2, 2, 89, 562, + 3, 2, 2, 2, 91, 577, 3, 2, 2, 2, 93, 580, 3, 2, 2, 2, 95, 585, 3, 2, 2, + 2, 97, 596, 3, 2, 2, 2, 99, 605, 3, 2, 2, 2, 101, 607, 3, 2, 2, 2, 103, + 609, 3, 2, 2, 2, 105, 611, 3, 2, 2, 2, 107, 626, 3, 2, 2, 2, 109, 628, + 3, 2, 2, 2, 111, 635, 3, 2, 2, 2, 113, 639, 3, 2, 2, 2, 115, 641, 3, 2, + 2, 2, 117, 643, 3, 2, 2, 2, 119, 645, 3, 2, 2, 2, 121, 660, 3, 2, 2, 2, + 123, 669, 3, 2, 2, 2, 125, 671, 3, 2, 2, 2, 127, 687, 3, 2, 2, 2, 129, + 689, 3, 2, 2, 2, 131, 696, 3, 2, 2, 2, 133, 708, 3, 2, 2, 2, 135, 711, + 3, 2, 2, 2, 137, 715, 3, 2, 2, 2, 139, 736, 3, 2, 2, 2, 141, 739, 3, 2, + 2, 2, 143, 750, 3, 2, 2, 2, 145, 146, 7, 42, 2, 2, 146, 4, 3, 2, 2, 2, + 147, 148, 7, 43, 2, 2, 148, 6, 3, 2, 2, 2, 149, 150, 7, 93, 2, 2, 150, + 8, 3, 2, 2, 2, 151, 152, 7, 46, 2, 2, 152, 10, 3, 2, 2, 2, 153, 154, 7, + 95, 2, 2, 154, 12, 3, 2, 2, 2, 155, 156, 7, 62, 2, 2, 156, 14, 3, 2, 2, + 2, 157, 158, 7, 62, 2, 2, 158, 159, 7, 63, 2, 2, 159, 16, 3, 2, 2, 2, 160, + 161, 7, 64, 2, 2, 161, 18, 3, 2, 2, 2, 162, 163, 7, 64, 2, 2, 163, 164, + 7, 63, 2, 2, 164, 20, 3, 2, 2, 2, 165, 166, 7, 63, 2, 2, 166, 167, 7, 63, + 2, 2, 167, 22, 3, 2, 2, 2, 168, 169, 7, 35, 2, 2, 169, 170, 7, 63, 2, 2, + 170, 24, 3, 2, 2, 2, 171, 172, 7, 110, 2, 2, 172, 173, 7, 107, 2, 2, 173, + 174, 7, 109, 2, 2, 174, 180, 7, 103, 2, 2, 175, 176, 7, 78, 2, 2, 176, + 177, 7, 75, 2, 2, 177, 178, 7, 77, 2, 2, 178, 180, 7, 71, 2, 2, 179, 171, + 3, 2, 2, 2, 179, 175, 3, 2, 2, 2, 180, 26, 3, 2, 2, 2, 181, 182, 7, 103, + 2, 2, 182, 183, 7, 122, 2, 2, 183, 184, 7, 107, 2, 2, 184, 185, 7, 117, + 2, 2, 185, 186, 7, 118, 2, 2, 186, 
194, 7, 117, 2, 2, 187, 188, 7, 71, + 2, 2, 188, 189, 7, 90, 2, 2, 189, 190, 7, 75, 2, 2, 190, 191, 7, 85, 2, + 2, 191, 192, 7, 86, 2, 2, 192, 194, 7, 85, 2, 2, 193, 181, 3, 2, 2, 2, + 193, 187, 3, 2, 2, 2, 194, 28, 3, 2, 2, 2, 195, 196, 7, 45, 2, 2, 196, + 30, 3, 2, 2, 2, 197, 198, 7, 47, 2, 2, 198, 32, 3, 2, 2, 2, 199, 200, 7, + 44, 2, 2, 200, 34, 3, 2, 2, 2, 201, 202, 7, 49, 2, 2, 202, 36, 3, 2, 2, + 2, 203, 204, 7, 39, 2, 2, 204, 38, 3, 2, 2, 2, 205, 206, 7, 44, 2, 2, 206, + 207, 7, 44, 2, 2, 207, 40, 3, 2, 2, 2, 208, 209, 7, 62, 2, 2, 209, 210, + 7, 62, 2, 2, 210, 42, 3, 2, 2, 2, 211, 212, 7, 64, 2, 2, 212, 213, 7, 64, + 2, 2, 213, 44, 3, 2, 2, 2, 214, 215, 7, 40, 2, 2, 215, 46, 3, 2, 2, 2, + 216, 217, 7, 126, 2, 2, 217, 48, 3, 2, 2, 2, 218, 219, 7, 96, 2, 2, 219, + 50, 3, 2, 2, 2, 220, 221, 7, 40, 2, 2, 221, 226, 7, 40, 2, 2, 222, 223, + 7, 99, 2, 2, 223, 224, 7, 112, 2, 2, 224, 226, 7, 102, 2, 2, 225, 220, + 3, 2, 2, 2, 225, 222, 3, 2, 2, 2, 226, 52, 3, 2, 2, 2, 227, 228, 7, 126, + 2, 2, 228, 232, 7, 126, 2, 2, 229, 230, 7, 113, 2, 2, 230, 232, 7, 116, + 2, 2, 231, 227, 3, 2, 2, 2, 231, 229, 3, 2, 2, 2, 232, 54, 3, 2, 2, 2, + 233, 234, 7, 128, 2, 2, 234, 56, 3, 2, 2, 2, 235, 240, 7, 35, 2, 2, 236, + 237, 7, 112, 2, 2, 237, 238, 7, 113, 2, 2, 238, 240, 7, 118, 2, 2, 239, + 235, 3, 2, 2, 2, 239, 236, 3, 2, 2, 2, 240, 58, 3, 2, 2, 2, 241, 242, 7, + 107, 2, 2, 242, 243, 7, 112, 2, 2, 243, 60, 3, 2, 2, 2, 244, 245, 7, 112, + 2, 2, 245, 246, 7, 113, 2, 2, 246, 247, 7, 118, 2, 2, 247, 248, 7, 34, + 2, 2, 248, 249, 7, 107, 2, 2, 249, 250, 7, 112, 2, 2, 250, 62, 3, 2, 2, + 2, 251, 256, 7, 93, 2, 2, 252, 255, 5, 141, 71, 2, 253, 255, 5, 143, 72, + 2, 254, 252, 3, 2, 2, 2, 254, 253, 3, 2, 2, 2, 255, 258, 3, 2, 2, 2, 256, + 254, 3, 2, 2, 2, 256, 257, 3, 2, 2, 2, 257, 259, 3, 2, 2, 2, 258, 256, + 3, 2, 2, 2, 259, 260, 7, 95, 2, 2, 260, 64, 3, 2, 2, 2, 261, 262, 7, 108, + 2, 2, 262, 263, 7, 117, 2, 2, 263, 264, 7, 113, 2, 2, 264, 265, 7, 112, + 2, 2, 265, 266, 7, 97, 2, 
2, 266, 267, 7, 101, 2, 2, 267, 268, 7, 113, + 2, 2, 268, 269, 7, 112, 2, 2, 269, 270, 7, 118, 2, 2, 270, 271, 7, 99, + 2, 2, 271, 272, 7, 107, 2, 2, 272, 273, 7, 112, 2, 2, 273, 288, 7, 117, + 2, 2, 274, 275, 7, 76, 2, 2, 275, 276, 7, 85, 2, 2, 276, 277, 7, 81, 2, + 2, 277, 278, 7, 80, 2, 2, 278, 279, 7, 97, 2, 2, 279, 280, 7, 69, 2, 2, + 280, 281, 7, 81, 2, 2, 281, 282, 7, 80, 2, 2, 282, 283, 7, 86, 2, 2, 283, + 284, 7, 67, 2, 2, 284, 285, 7, 75, 2, 2, 285, 286, 7, 80, 2, 2, 286, 288, + 7, 85, 2, 2, 287, 261, 3, 2, 2, 2, 287, 274, 3, 2, 2, 2, 288, 66, 3, 2, + 2, 2, 289, 290, 7, 108, 2, 2, 290, 291, 7, 117, 2, 2, 291, 292, 7, 113, + 2, 2, 292, 293, 7, 112, 2, 2, 293, 294, 7, 97, 2, 2, 294, 295, 7, 101, + 2, 2, 295, 296, 7, 113, 2, 2, 296, 297, 7, 112, 2, 2, 297, 298, 7, 118, + 2, 2, 298, 299, 7, 99, 2, 2, 299, 300, 7, 107, 2, 2, 300, 301, 7, 112, + 2, 2, 301, 302, 7, 117, 2, 2, 302, 303, 7, 97, 2, 2, 303, 304, 7, 99, 2, + 2, 304, 305, 7, 110, 2, 2, 305, 324, 7, 110, 2, 2, 306, 307, 7, 76, 2, + 2, 307, 308, 7, 85, 2, 2, 308, 309, 7, 81, 2, 2, 309, 310, 7, 80, 2, 2, + 310, 311, 7, 97, 2, 2, 311, 312, 7, 69, 2, 2, 312, 313, 7, 81, 2, 2, 313, + 314, 7, 80, 2, 2, 314, 315, 7, 86, 2, 2, 315, 316, 7, 67, 2, 2, 316, 317, + 7, 75, 2, 2, 317, 318, 7, 80, 2, 2, 318, 319, 7, 85, 2, 2, 319, 320, 7, + 97, 2, 2, 320, 321, 7, 67, 2, 2, 321, 322, 7, 78, 2, 2, 322, 324, 7, 78, + 2, 2, 323, 289, 3, 2, 2, 2, 323, 306, 3, 2, 2, 2, 324, 68, 3, 2, 2, 2, + 325, 326, 7, 108, 2, 2, 326, 327, 7, 117, 2, 2, 327, 328, 7, 113, 2, 2, + 328, 329, 7, 112, 2, 2, 329, 330, 7, 97, 2, 2, 330, 331, 7, 101, 2, 2, + 331, 332, 7, 113, 2, 2, 332, 333, 7, 112, 2, 2, 333, 334, 7, 118, 2, 2, + 334, 335, 7, 99, 2, 2, 335, 336, 7, 107, 2, 2, 336, 337, 7, 112, 2, 2, + 337, 338, 7, 117, 2, 2, 338, 339, 7, 97, 2, 2, 339, 340, 7, 99, 2, 2, 340, + 341, 7, 112, 2, 2, 341, 360, 7, 123, 2, 2, 342, 343, 7, 76, 2, 2, 343, + 344, 7, 85, 2, 2, 344, 345, 7, 81, 2, 2, 345, 346, 7, 80, 2, 2, 346, 347, + 7, 97, 2, 2, 347, 
348, 7, 69, 2, 2, 348, 349, 7, 81, 2, 2, 349, 350, 7, + 80, 2, 2, 350, 351, 7, 86, 2, 2, 351, 352, 7, 67, 2, 2, 352, 353, 7, 75, + 2, 2, 353, 354, 7, 80, 2, 2, 354, 355, 7, 85, 2, 2, 355, 356, 7, 97, 2, + 2, 356, 357, 7, 67, 2, 2, 357, 358, 7, 80, 2, 2, 358, 360, 7, 91, 2, 2, + 359, 325, 3, 2, 2, 2, 359, 342, 3, 2, 2, 2, 360, 70, 3, 2, 2, 2, 361, 362, + 7, 99, 2, 2, 362, 363, 7, 116, 2, 2, 363, 364, 7, 116, 2, 2, 364, 365, + 7, 99, 2, 2, 365, 366, 7, 123, 2, 2, 366, 367, 7, 97, 2, 2, 367, 368, 7, + 101, 2, 2, 368, 369, 7, 113, 2, 2, 369, 370, 7, 112, 2, 2, 370, 371, 7, + 118, 2, 2, 371, 372, 7, 99, 2, 2, 372, 373, 7, 107, 2, 2, 373, 374, 7, + 112, 2, 2, 374, 390, 7, 117, 2, 2, 375, 376, 7, 67, 2, 2, 376, 377, 7, + 84, 2, 2, 377, 378, 7, 84, 2, 2, 378, 379, 7, 67, 2, 2, 379, 380, 7, 91, + 2, 2, 380, 381, 7, 97, 2, 2, 381, 382, 7, 69, 2, 2, 382, 383, 7, 81, 2, + 2, 383, 384, 7, 80, 2, 2, 384, 385, 7, 86, 2, 2, 385, 386, 7, 67, 2, 2, + 386, 387, 7, 75, 2, 2, 387, 388, 7, 80, 2, 2, 388, 390, 7, 85, 2, 2, 389, + 361, 3, 2, 2, 2, 389, 375, 3, 2, 2, 2, 390, 72, 3, 2, 2, 2, 391, 392, 7, + 99, 2, 2, 392, 393, 7, 116, 2, 2, 393, 394, 7, 116, 2, 2, 394, 395, 7, + 99, 2, 2, 395, 396, 7, 123, 2, 2, 396, 397, 7, 97, 2, 2, 397, 398, 7, 101, + 2, 2, 398, 399, 7, 113, 2, 2, 399, 400, 7, 112, 2, 2, 400, 401, 7, 118, + 2, 2, 401, 402, 7, 99, 2, 2, 402, 403, 7, 107, 2, 2, 403, 404, 7, 112, + 2, 2, 404, 405, 7, 117, 2, 2, 405, 406, 7, 97, 2, 2, 406, 407, 7, 99, 2, + 2, 407, 408, 7, 110, 2, 2, 408, 428, 7, 110, 2, 2, 409, 410, 7, 67, 2, + 2, 410, 411, 7, 84, 2, 2, 411, 412, 7, 84, 2, 2, 412, 413, 7, 67, 2, 2, + 413, 414, 7, 91, 2, 2, 414, 415, 7, 97, 2, 2, 415, 416, 7, 69, 2, 2, 416, + 417, 7, 81, 2, 2, 417, 418, 7, 80, 2, 2, 418, 419, 7, 86, 2, 2, 419, 420, + 7, 67, 2, 2, 420, 421, 7, 75, 2, 2, 421, 422, 7, 80, 2, 2, 422, 423, 7, + 85, 2, 2, 423, 424, 7, 97, 2, 2, 424, 425, 7, 67, 2, 2, 425, 426, 7, 78, + 2, 2, 426, 428, 7, 78, 2, 2, 427, 391, 3, 2, 2, 2, 427, 409, 3, 2, 2, 2, + 428, 
74, 3, 2, 2, 2, 429, 430, 7, 99, 2, 2, 430, 431, 7, 116, 2, 2, 431, + 432, 7, 116, 2, 2, 432, 433, 7, 99, 2, 2, 433, 434, 7, 123, 2, 2, 434, + 435, 7, 97, 2, 2, 435, 436, 7, 101, 2, 2, 436, 437, 7, 113, 2, 2, 437, + 438, 7, 112, 2, 2, 438, 439, 7, 118, 2, 2, 439, 440, 7, 99, 2, 2, 440, + 441, 7, 107, 2, 2, 441, 442, 7, 112, 2, 2, 442, 443, 7, 117, 2, 2, 443, + 444, 7, 97, 2, 2, 444, 445, 7, 99, 2, 2, 445, 446, 7, 112, 2, 2, 446, 466, + 7, 123, 2, 2, 447, 448, 7, 67, 2, 2, 448, 449, 7, 84, 2, 2, 449, 450, 7, + 84, 2, 2, 450, 451, 7, 67, 2, 2, 451, 452, 7, 91, 2, 2, 452, 453, 7, 97, + 2, 2, 453, 454, 7, 69, 2, 2, 454, 455, 7, 81, 2, 2, 455, 456, 7, 80, 2, + 2, 456, 457, 7, 86, 2, 2, 457, 458, 7, 67, 2, 2, 458, 459, 7, 75, 2, 2, + 459, 460, 7, 80, 2, 2, 460, 461, 7, 85, 2, 2, 461, 462, 7, 97, 2, 2, 462, + 463, 7, 67, 2, 2, 463, 464, 7, 80, 2, 2, 464, 466, 7, 91, 2, 2, 465, 429, + 3, 2, 2, 2, 465, 447, 3, 2, 2, 2, 466, 76, 3, 2, 2, 2, 467, 468, 7, 99, + 2, 2, 468, 469, 7, 116, 2, 2, 469, 470, 7, 116, 2, 2, 470, 471, 7, 99, + 2, 2, 471, 472, 7, 123, 2, 2, 472, 473, 7, 97, 2, 2, 473, 474, 7, 110, + 2, 2, 474, 475, 7, 103, 2, 2, 475, 476, 7, 112, 2, 2, 476, 477, 7, 105, + 2, 2, 477, 478, 7, 118, 2, 2, 478, 492, 7, 106, 2, 2, 479, 480, 7, 67, + 2, 2, 480, 481, 7, 84, 2, 2, 481, 482, 7, 84, 2, 2, 482, 483, 7, 67, 2, + 2, 483, 484, 7, 91, 2, 2, 484, 485, 7, 97, 2, 2, 485, 486, 7, 78, 2, 2, + 486, 487, 7, 71, 2, 2, 487, 488, 7, 80, 2, 2, 488, 489, 7, 73, 2, 2, 489, + 490, 7, 86, 2, 2, 490, 492, 7, 74, 2, 2, 491, 467, 3, 2, 2, 2, 491, 479, + 3, 2, 2, 2, 492, 78, 3, 2, 2, 2, 493, 494, 7, 118, 2, 2, 494, 495, 7, 116, + 2, 2, 495, 496, 7, 119, 2, 2, 496, 521, 7, 103, 2, 2, 497, 498, 7, 86, + 2, 2, 498, 499, 7, 116, 2, 2, 499, 500, 7, 119, 2, 2, 500, 521, 7, 103, + 2, 2, 501, 502, 7, 86, 2, 2, 502, 503, 7, 84, 2, 2, 503, 504, 7, 87, 2, + 2, 504, 521, 7, 71, 2, 2, 505, 506, 7, 104, 2, 2, 506, 507, 7, 99, 2, 2, + 507, 508, 7, 110, 2, 2, 508, 509, 7, 117, 2, 2, 509, 521, 7, 103, 2, 
2, + 510, 511, 7, 72, 2, 2, 511, 512, 7, 99, 2, 2, 512, 513, 7, 110, 2, 2, 513, + 514, 7, 117, 2, 2, 514, 521, 7, 103, 2, 2, 515, 516, 7, 72, 2, 2, 516, + 517, 7, 67, 2, 2, 517, 518, 7, 78, 2, 2, 518, 519, 7, 85, 2, 2, 519, 521, + 7, 71, 2, 2, 520, 493, 3, 2, 2, 2, 520, 497, 3, 2, 2, 2, 520, 501, 3, 2, + 2, 2, 520, 505, 3, 2, 2, 2, 520, 510, 3, 2, 2, 2, 520, 515, 3, 2, 2, 2, + 521, 80, 3, 2, 2, 2, 522, 527, 5, 107, 54, 2, 523, 527, 5, 109, 55, 2, + 524, 527, 5, 111, 56, 2, 525, 527, 5, 105, 53, 2, 526, 522, 3, 2, 2, 2, + 526, 523, 3, 2, 2, 2, 526, 524, 3, 2, 2, 2, 526, 525, 3, 2, 2, 2, 527, + 82, 3, 2, 2, 2, 528, 531, 5, 123, 62, 2, 529, 531, 5, 125, 63, 2, 530, + 528, 3, 2, 2, 2, 530, 529, 3, 2, 2, 2, 531, 84, 3, 2, 2, 2, 532, 537, 5, + 101, 51, 2, 533, 536, 5, 101, 51, 2, 534, 536, 5, 103, 52, 2, 535, 533, + 3, 2, 2, 2, 535, 534, 3, 2, 2, 2, 536, 539, 3, 2, 2, 2, 537, 535, 3, 2, + 2, 2, 537, 538, 3, 2, 2, 2, 538, 546, 3, 2, 2, 2, 539, 537, 3, 2, 2, 2, + 540, 541, 7, 38, 2, 2, 541, 542, 7, 111, 2, 2, 542, 543, 7, 103, 2, 2, + 543, 544, 7, 118, 2, 2, 544, 546, 7, 99, 2, 2, 545, 532, 3, 2, 2, 2, 545, + 540, 3, 2, 2, 2, 546, 86, 3, 2, 2, 2, 547, 549, 5, 91, 46, 2, 548, 547, + 3, 2, 2, 2, 548, 549, 3, 2, 2, 2, 549, 560, 3, 2, 2, 2, 550, 552, 7, 36, + 2, 2, 551, 553, 5, 93, 47, 2, 552, 551, 3, 2, 2, 2, 552, 553, 3, 2, 2, + 2, 553, 554, 3, 2, 2, 2, 554, 561, 7, 36, 2, 2, 555, 557, 7, 41, 2, 2, + 556, 558, 5, 95, 48, 2, 557, 556, 3, 2, 2, 2, 557, 558, 3, 2, 2, 2, 558, + 559, 3, 2, 2, 2, 559, 561, 7, 41, 2, 2, 560, 550, 3, 2, 2, 2, 560, 555, + 3, 2, 2, 2, 561, 88, 3, 2, 2, 2, 562, 570, 5, 85, 43, 2, 563, 566, 7, 93, + 2, 2, 564, 567, 5, 87, 44, 2, 565, 567, 5, 107, 54, 2, 566, 564, 3, 2, + 2, 2, 566, 565, 3, 2, 2, 2, 567, 568, 3, 2, 2, 2, 568, 569, 7, 95, 2, 2, + 569, 571, 3, 2, 2, 2, 570, 563, 3, 2, 2, 2, 571, 572, 3, 2, 2, 2, 572, + 570, 3, 2, 2, 2, 572, 573, 3, 2, 2, 2, 573, 90, 3, 2, 2, 2, 574, 575, 7, + 119, 2, 2, 575, 578, 7, 58, 2, 2, 576, 578, 9, 2, 2, 2, 577, 
574, 3, 2, + 2, 2, 577, 576, 3, 2, 2, 2, 578, 92, 3, 2, 2, 2, 579, 581, 5, 97, 49, 2, + 580, 579, 3, 2, 2, 2, 581, 582, 3, 2, 2, 2, 582, 580, 3, 2, 2, 2, 582, + 583, 3, 2, 2, 2, 583, 94, 3, 2, 2, 2, 584, 586, 5, 99, 50, 2, 585, 584, + 3, 2, 2, 2, 586, 587, 3, 2, 2, 2, 587, 585, 3, 2, 2, 2, 587, 588, 3, 2, + 2, 2, 588, 96, 3, 2, 2, 2, 589, 597, 10, 3, 2, 2, 590, 597, 5, 139, 70, + 2, 591, 592, 7, 94, 2, 2, 592, 597, 7, 12, 2, 2, 593, 594, 7, 94, 2, 2, + 594, 595, 7, 15, 2, 2, 595, 597, 7, 12, 2, 2, 596, 589, 3, 2, 2, 2, 596, + 590, 3, 2, 2, 2, 596, 591, 3, 2, 2, 2, 596, 593, 3, 2, 2, 2, 597, 98, 3, + 2, 2, 2, 598, 606, 10, 4, 2, 2, 599, 606, 5, 139, 70, 2, 600, 601, 7, 94, + 2, 2, 601, 606, 7, 12, 2, 2, 602, 603, 7, 94, 2, 2, 603, 604, 7, 15, 2, + 2, 604, 606, 7, 12, 2, 2, 605, 598, 3, 2, 2, 2, 605, 599, 3, 2, 2, 2, 605, + 600, 3, 2, 2, 2, 605, 602, 3, 2, 2, 2, 606, 100, 3, 2, 2, 2, 607, 608, + 9, 5, 2, 2, 608, 102, 3, 2, 2, 2, 609, 610, 9, 6, 2, 2, 610, 104, 3, 2, + 2, 2, 611, 612, 7, 50, 2, 2, 612, 614, 9, 7, 2, 2, 613, 615, 9, 8, 2, 2, + 614, 613, 3, 2, 2, 2, 615, 616, 3, 2, 2, 2, 616, 614, 3, 2, 2, 2, 616, + 617, 3, 2, 2, 2, 617, 106, 3, 2, 2, 2, 618, 622, 5, 113, 57, 2, 619, 621, + 5, 103, 52, 2, 620, 619, 3, 2, 2, 2, 621, 624, 3, 2, 2, 2, 622, 620, 3, + 2, 2, 2, 622, 623, 3, 2, 2, 2, 623, 627, 3, 2, 2, 2, 624, 622, 3, 2, 2, + 2, 625, 627, 7, 50, 2, 2, 626, 618, 3, 2, 2, 2, 626, 625, 3, 2, 2, 2, 627, + 108, 3, 2, 2, 2, 628, 632, 7, 50, 2, 2, 629, 631, 5, 115, 58, 2, 630, 629, + 3, 2, 2, 2, 631, 634, 3, 2, 2, 2, 632, 630, 3, 2, 2, 2, 632, 633, 3, 2, + 2, 2, 633, 110, 3, 2, 2, 2, 634, 632, 3, 2, 2, 2, 635, 636, 7, 50, 2, 2, + 636, 637, 9, 9, 2, 2, 637, 638, 5, 135, 68, 2, 638, 112, 3, 2, 2, 2, 639, + 640, 9, 10, 2, 2, 640, 114, 3, 2, 2, 2, 641, 642, 9, 11, 2, 2, 642, 116, + 3, 2, 2, 2, 643, 644, 9, 12, 2, 2, 644, 118, 3, 2, 2, 2, 645, 646, 5, 117, + 59, 2, 646, 647, 5, 117, 59, 2, 647, 648, 5, 117, 59, 2, 648, 649, 5, 117, + 59, 2, 649, 120, 3, 2, 2, 2, 650, 651, 
7, 94, 2, 2, 651, 652, 7, 119, 2, + 2, 652, 653, 3, 2, 2, 2, 653, 661, 5, 119, 60, 2, 654, 655, 7, 94, 2, 2, + 655, 656, 7, 87, 2, 2, 656, 657, 3, 2, 2, 2, 657, 658, 5, 119, 60, 2, 658, + 659, 5, 119, 60, 2, 659, 661, 3, 2, 2, 2, 660, 650, 3, 2, 2, 2, 660, 654, + 3, 2, 2, 2, 661, 122, 3, 2, 2, 2, 662, 664, 5, 127, 64, 2, 663, 665, 5, + 129, 65, 2, 664, 663, 3, 2, 2, 2, 664, 665, 3, 2, 2, 2, 665, 670, 3, 2, + 2, 2, 666, 667, 5, 131, 66, 2, 667, 668, 5, 129, 65, 2, 668, 670, 3, 2, + 2, 2, 669, 662, 3, 2, 2, 2, 669, 666, 3, 2, 2, 2, 670, 124, 3, 2, 2, 2, + 671, 672, 7, 50, 2, 2, 672, 675, 9, 9, 2, 2, 673, 676, 5, 133, 67, 2, 674, + 676, 5, 135, 68, 2, 675, 673, 3, 2, 2, 2, 675, 674, 3, 2, 2, 2, 676, 677, + 3, 2, 2, 2, 677, 678, 5, 137, 69, 2, 678, 126, 3, 2, 2, 2, 679, 681, 5, + 131, 66, 2, 680, 679, 3, 2, 2, 2, 680, 681, 3, 2, 2, 2, 681, 682, 3, 2, + 2, 2, 682, 683, 7, 48, 2, 2, 683, 688, 5, 131, 66, 2, 684, 685, 5, 131, + 66, 2, 685, 686, 7, 48, 2, 2, 686, 688, 3, 2, 2, 2, 687, 680, 3, 2, 2, + 2, 687, 684, 3, 2, 2, 2, 688, 128, 3, 2, 2, 2, 689, 691, 9, 13, 2, 2, 690, + 692, 9, 14, 2, 2, 691, 690, 3, 2, 2, 2, 691, 692, 3, 2, 2, 2, 692, 693, + 3, 2, 2, 2, 693, 694, 5, 131, 66, 2, 694, 130, 3, 2, 2, 2, 695, 697, 5, + 103, 52, 2, 696, 695, 3, 2, 2, 2, 697, 698, 3, 2, 2, 2, 698, 696, 3, 2, + 2, 2, 698, 699, 3, 2, 2, 2, 699, 132, 3, 2, 2, 2, 700, 702, 5, 135, 68, + 2, 701, 700, 3, 2, 2, 2, 701, 702, 3, 2, 2, 2, 702, 703, 3, 2, 2, 2, 703, + 704, 7, 48, 2, 2, 704, 709, 5, 135, 68, 2, 705, 706, 5, 135, 68, 2, 706, + 707, 7, 48, 2, 2, 707, 709, 3, 2, 2, 2, 708, 701, 3, 2, 2, 2, 708, 705, + 3, 2, 2, 2, 709, 134, 3, 2, 2, 2, 710, 712, 5, 117, 59, 2, 711, 710, 3, + 2, 2, 2, 712, 713, 3, 2, 2, 2, 713, 711, 3, 2, 2, 2, 713, 714, 3, 2, 2, + 2, 714, 136, 3, 2, 2, 2, 715, 717, 9, 15, 2, 2, 716, 718, 9, 14, 2, 2, + 717, 716, 3, 2, 2, 2, 717, 718, 3, 2, 2, 2, 718, 719, 3, 2, 2, 2, 719, + 720, 5, 131, 66, 2, 720, 138, 3, 2, 2, 2, 721, 722, 7, 94, 2, 2, 722, 737, + 9, 16, 2, 2, 723, 724, 
7, 94, 2, 2, 724, 726, 5, 115, 58, 2, 725, 727, + 5, 115, 58, 2, 726, 725, 3, 2, 2, 2, 726, 727, 3, 2, 2, 2, 727, 729, 3, + 2, 2, 2, 728, 730, 5, 115, 58, 2, 729, 728, 3, 2, 2, 2, 729, 730, 3, 2, + 2, 2, 730, 737, 3, 2, 2, 2, 731, 732, 7, 94, 2, 2, 732, 733, 7, 122, 2, + 2, 733, 734, 3, 2, 2, 2, 734, 737, 5, 135, 68, 2, 735, 737, 5, 121, 61, + 2, 736, 721, 3, 2, 2, 2, 736, 723, 3, 2, 2, 2, 736, 731, 3, 2, 2, 2, 736, + 735, 3, 2, 2, 2, 737, 140, 3, 2, 2, 2, 738, 740, 9, 17, 2, 2, 739, 738, + 3, 2, 2, 2, 740, 741, 3, 2, 2, 2, 741, 739, 3, 2, 2, 2, 741, 742, 3, 2, + 2, 2, 742, 743, 3, 2, 2, 2, 743, 744, 8, 71, 2, 2, 744, 142, 3, 2, 2, 2, + 745, 747, 7, 15, 2, 2, 746, 748, 7, 12, 2, 2, 747, 746, 3, 2, 2, 2, 747, + 748, 3, 2, 2, 2, 748, 751, 3, 2, 2, 2, 749, 751, 7, 12, 2, 2, 750, 745, + 3, 2, 2, 2, 750, 749, 3, 2, 2, 2, 751, 752, 3, 2, 2, 2, 752, 753, 8, 72, + 2, 2, 753, 144, 3, 2, 2, 2, 56, 2, 179, 193, 225, 231, 239, 254, 256, 287, + 323, 359, 389, 427, 465, 491, 520, 526, 530, 535, 537, 545, 548, 552, 557, + 560, 566, 572, 577, 582, 587, 596, 605, 616, 622, 626, 632, 660, 664, 669, + 675, 680, 687, 691, 698, 701, 708, 713, 717, 726, 729, 736, 741, 747, 750, + 3, 8, 2, 2, } var lexerChannelNames = []string{ @@ -316,7 +377,8 @@ var lexerSymbolicNames = []string{ "", "", "", "", "", "", "LT", "LE", "GT", "GE", "EQ", "NE", "LIKE", "EXISTS", "ADD", "SUB", "MUL", "DIV", "MOD", "POW", "SHL", "SHR", "BAND", "BOR", "BXOR", "AND", "OR", "BNOT", "NOT", "IN", "NIN", "EmptyTerm", "JSONContains", - "JSONContainsAll", "JSONContainsAny", "BooleanConstant", "IntegerConstant", + "JSONContainsAll", "JSONContainsAny", "ArrayContains", "ArrayContainsAll", + "ArrayContainsAny", "ArrayLength", "BooleanConstant", "IntegerConstant", "FloatingConstant", "Identifier", "StringLiteral", "JSONIdentifier", "Whitespace", "Newline", } @@ -325,7 +387,8 @@ var lexerRuleNames = []string{ "T__0", "T__1", "T__2", "T__3", "T__4", "LT", "LE", "GT", "GE", "EQ", "NE", "LIKE", "EXISTS", "ADD", "SUB", "MUL", 
"DIV", "MOD", "POW", "SHL", "SHR", "BAND", "BOR", "BXOR", "AND", "OR", "BNOT", "NOT", "IN", "NIN", "EmptyTerm", - "JSONContains", "JSONContainsAll", "JSONContainsAny", "BooleanConstant", + "JSONContains", "JSONContainsAll", "JSONContainsAny", "ArrayContains", + "ArrayContainsAll", "ArrayContainsAny", "ArrayLength", "BooleanConstant", "IntegerConstant", "FloatingConstant", "Identifier", "StringLiteral", "JSONIdentifier", "EncodingPrefix", "DoubleSCharSequence", "SingleSCharSequence", "DoubleSChar", "SingleSChar", "Nondigit", "Digit", "BinaryConstant", "DecimalConstant", @@ -407,12 +470,16 @@ const ( PlanLexerJSONContains = 32 PlanLexerJSONContainsAll = 33 PlanLexerJSONContainsAny = 34 - PlanLexerBooleanConstant = 35 - PlanLexerIntegerConstant = 36 - PlanLexerFloatingConstant = 37 - PlanLexerIdentifier = 38 - PlanLexerStringLiteral = 39 - PlanLexerJSONIdentifier = 40 - PlanLexerWhitespace = 41 - PlanLexerNewline = 42 + PlanLexerArrayContains = 35 + PlanLexerArrayContainsAll = 36 + PlanLexerArrayContainsAny = 37 + PlanLexerArrayLength = 38 + PlanLexerBooleanConstant = 39 + PlanLexerIntegerConstant = 40 + PlanLexerFloatingConstant = 41 + PlanLexerIdentifier = 42 + PlanLexerStringLiteral = 43 + PlanLexerJSONIdentifier = 44 + PlanLexerWhitespace = 45 + PlanLexerNewline = 46 ) diff --git a/internal/parser/planparserv2/generated/plan_parser.go b/internal/parser/planparserv2/generated/plan_parser.go index a5368fdfa1ac5..53ed5f70cfc40 100644 --- a/internal/parser/planparserv2/generated/plan_parser.go +++ b/internal/parser/planparserv2/generated/plan_parser.go @@ -15,66 +15,69 @@ var _ = reflect.Copy var _ = strconv.Itoa var parserATN = []uint16{ - 3, 24715, 42794, 33075, 47597, 16764, 15335, 30598, 22884, 3, 44, 127, + 3, 24715, 42794, 33075, 47597, 16764, 15335, 30598, 22884, 3, 48, 131, 4, 2, 9, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 7, 2, 20, 10, 2, 12, 2, 14, 2, 23, 11, 2, 3, 2, 5, 2, 26, 10, 2, 3, 2, 3, 2, 3, 2, 3, 
2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, - 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 5, 2, 55, 10, 2, 3, 2, 3, 2, + 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 5, 2, + 59, 10, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, - 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, - 3, 2, 3, 2, 7, 2, 109, 10, 2, 12, 2, 14, 2, 112, 11, 2, 3, 2, 5, 2, 115, - 10, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 7, 2, 122, 10, 2, 12, 2, 14, 2, 125, - 11, 2, 3, 2, 2, 3, 2, 3, 2, 2, 12, 4, 2, 16, 17, 29, 30, 3, 2, 18, 20, - 3, 2, 16, 17, 3, 2, 22, 23, 3, 2, 8, 9, 4, 2, 40, 40, 42, 42, 3, 2, 10, - 11, 3, 2, 8, 11, 3, 2, 12, 13, 3, 2, 31, 32, 2, 157, 2, 54, 3, 2, 2, 2, - 4, 5, 8, 2, 1, 2, 5, 55, 7, 38, 2, 2, 6, 55, 7, 39, 2, 2, 7, 55, 7, 37, - 2, 2, 8, 55, 7, 41, 2, 2, 9, 55, 7, 40, 2, 2, 10, 55, 7, 42, 2, 2, 11, - 12, 7, 3, 2, 2, 12, 13, 5, 2, 2, 2, 13, 14, 7, 4, 2, 2, 14, 55, 3, 2, 2, - 2, 15, 16, 7, 5, 2, 2, 16, 21, 5, 2, 2, 2, 17, 18, 7, 6, 2, 2, 18, 20, - 5, 2, 2, 2, 19, 17, 3, 2, 2, 2, 20, 23, 3, 2, 2, 2, 21, 19, 3, 2, 2, 2, - 21, 22, 3, 2, 2, 2, 22, 25, 3, 2, 2, 2, 23, 21, 3, 2, 2, 2, 24, 26, 7, - 6, 2, 2, 25, 24, 3, 2, 2, 2, 25, 26, 3, 2, 2, 2, 26, 27, 3, 2, 2, 2, 27, - 28, 7, 7, 2, 2, 28, 55, 3, 2, 2, 2, 29, 30, 9, 2, 2, 2, 30, 55, 5, 2, 2, - 21, 31, 32, 7, 34, 2, 2, 32, 33, 7, 3, 2, 2, 33, 34, 5, 2, 2, 2, 34, 35, - 7, 6, 2, 2, 35, 36, 5, 2, 2, 2, 36, 37, 7, 4, 2, 2, 37, 55, 3, 2, 2, 2, - 38, 39, 7, 35, 2, 2, 39, 40, 7, 3, 2, 2, 40, 41, 5, 2, 2, 2, 41, 42, 7, - 6, 2, 2, 42, 43, 5, 2, 2, 2, 43, 44, 7, 4, 2, 2, 44, 55, 3, 2, 2, 2, 45, - 46, 7, 36, 2, 2, 46, 47, 7, 3, 2, 2, 47, 48, 5, 2, 2, 2, 48, 49, 7, 6, - 2, 2, 49, 50, 5, 2, 2, 2, 50, 51, 7, 4, 2, 2, 51, 55, 
3, 2, 2, 2, 52, 53, - 7, 15, 2, 2, 53, 55, 5, 2, 2, 3, 54, 4, 3, 2, 2, 2, 54, 6, 3, 2, 2, 2, - 54, 7, 3, 2, 2, 2, 54, 8, 3, 2, 2, 2, 54, 9, 3, 2, 2, 2, 54, 10, 3, 2, - 2, 2, 54, 11, 3, 2, 2, 2, 54, 15, 3, 2, 2, 2, 54, 29, 3, 2, 2, 2, 54, 31, - 3, 2, 2, 2, 54, 38, 3, 2, 2, 2, 54, 45, 3, 2, 2, 2, 54, 52, 3, 2, 2, 2, - 55, 123, 3, 2, 2, 2, 56, 57, 12, 22, 2, 2, 57, 58, 7, 21, 2, 2, 58, 122, - 5, 2, 2, 23, 59, 60, 12, 20, 2, 2, 60, 61, 9, 3, 2, 2, 61, 122, 5, 2, 2, - 21, 62, 63, 12, 19, 2, 2, 63, 64, 9, 4, 2, 2, 64, 122, 5, 2, 2, 20, 65, - 66, 12, 18, 2, 2, 66, 67, 9, 5, 2, 2, 67, 122, 5, 2, 2, 19, 68, 69, 12, - 12, 2, 2, 69, 70, 9, 6, 2, 2, 70, 71, 9, 7, 2, 2, 71, 72, 9, 6, 2, 2, 72, - 122, 5, 2, 2, 13, 73, 74, 12, 11, 2, 2, 74, 75, 9, 8, 2, 2, 75, 76, 9, - 7, 2, 2, 76, 77, 9, 8, 2, 2, 77, 122, 5, 2, 2, 12, 78, 79, 12, 10, 2, 2, - 79, 80, 9, 9, 2, 2, 80, 122, 5, 2, 2, 11, 81, 82, 12, 9, 2, 2, 82, 83, - 9, 10, 2, 2, 83, 122, 5, 2, 2, 10, 84, 85, 12, 8, 2, 2, 85, 86, 7, 24, - 2, 2, 86, 122, 5, 2, 2, 9, 87, 88, 12, 7, 2, 2, 88, 89, 7, 26, 2, 2, 89, - 122, 5, 2, 2, 8, 90, 91, 12, 6, 2, 2, 91, 92, 7, 25, 2, 2, 92, 122, 5, - 2, 2, 7, 93, 94, 12, 5, 2, 2, 94, 95, 7, 27, 2, 2, 95, 122, 5, 2, 2, 6, - 96, 97, 12, 4, 2, 2, 97, 98, 7, 28, 2, 2, 98, 122, 5, 2, 2, 5, 99, 100, - 12, 23, 2, 2, 100, 101, 7, 14, 2, 2, 101, 122, 7, 41, 2, 2, 102, 103, 12, - 17, 2, 2, 103, 104, 9, 11, 2, 2, 104, 105, 7, 5, 2, 2, 105, 110, 5, 2, - 2, 2, 106, 107, 7, 6, 2, 2, 107, 109, 5, 2, 2, 2, 108, 106, 3, 2, 2, 2, - 109, 112, 3, 2, 2, 2, 110, 108, 3, 2, 2, 2, 110, 111, 3, 2, 2, 2, 111, - 114, 3, 2, 2, 2, 112, 110, 3, 2, 2, 2, 113, 115, 7, 6, 2, 2, 114, 113, - 3, 2, 2, 2, 114, 115, 3, 2, 2, 2, 115, 116, 3, 2, 2, 2, 116, 117, 7, 7, - 2, 2, 117, 122, 3, 2, 2, 2, 118, 119, 12, 16, 2, 2, 119, 120, 9, 11, 2, - 2, 120, 122, 7, 33, 2, 2, 121, 56, 3, 2, 2, 2, 121, 59, 3, 2, 2, 2, 121, - 62, 3, 2, 2, 2, 121, 65, 3, 2, 2, 2, 121, 68, 3, 2, 2, 2, 121, 73, 3, 2, - 2, 2, 121, 78, 3, 2, 2, 2, 121, 81, 3, 2, 2, 2, 
121, 84, 3, 2, 2, 2, 121, - 87, 3, 2, 2, 2, 121, 90, 3, 2, 2, 2, 121, 93, 3, 2, 2, 2, 121, 96, 3, 2, - 2, 2, 121, 99, 3, 2, 2, 2, 121, 102, 3, 2, 2, 2, 121, 118, 3, 2, 2, 2, - 122, 125, 3, 2, 2, 2, 123, 121, 3, 2, 2, 2, 123, 124, 3, 2, 2, 2, 124, - 3, 3, 2, 2, 2, 125, 123, 3, 2, 2, 2, 9, 21, 25, 54, 110, 114, 121, 123, + 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 7, 2, 113, 10, 2, 12, 2, 14, 2, 116, + 11, 2, 3, 2, 5, 2, 119, 10, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 7, 2, 126, + 10, 2, 12, 2, 14, 2, 129, 11, 2, 3, 2, 2, 3, 2, 3, 2, 2, 15, 4, 2, 16, + 17, 29, 30, 4, 2, 34, 34, 37, 37, 4, 2, 35, 35, 38, 38, 4, 2, 36, 36, 39, + 39, 4, 2, 44, 44, 46, 46, 3, 2, 18, 20, 3, 2, 16, 17, 3, 2, 22, 23, 3, + 2, 8, 9, 3, 2, 10, 11, 3, 2, 8, 11, 3, 2, 12, 13, 3, 2, 31, 32, 2, 162, + 2, 58, 3, 2, 2, 2, 4, 5, 8, 2, 1, 2, 5, 59, 7, 42, 2, 2, 6, 59, 7, 43, + 2, 2, 7, 59, 7, 41, 2, 2, 8, 59, 7, 45, 2, 2, 9, 59, 7, 44, 2, 2, 10, 59, + 7, 46, 2, 2, 11, 12, 7, 3, 2, 2, 12, 13, 5, 2, 2, 2, 13, 14, 7, 4, 2, 2, + 14, 59, 3, 2, 2, 2, 15, 16, 7, 5, 2, 2, 16, 21, 5, 2, 2, 2, 17, 18, 7, + 6, 2, 2, 18, 20, 5, 2, 2, 2, 19, 17, 3, 2, 2, 2, 20, 23, 3, 2, 2, 2, 21, + 19, 3, 2, 2, 2, 21, 22, 3, 2, 2, 2, 22, 25, 3, 2, 2, 2, 23, 21, 3, 2, 2, + 2, 24, 26, 7, 6, 2, 2, 25, 24, 3, 2, 2, 2, 25, 26, 3, 2, 2, 2, 26, 27, + 3, 2, 2, 2, 27, 28, 7, 7, 2, 2, 28, 59, 3, 2, 2, 2, 29, 30, 9, 2, 2, 2, + 30, 59, 5, 2, 2, 22, 31, 32, 9, 3, 2, 2, 32, 33, 7, 3, 2, 2, 33, 34, 5, + 2, 2, 2, 34, 35, 7, 6, 2, 2, 35, 36, 5, 2, 2, 2, 36, 37, 7, 4, 2, 2, 37, + 59, 3, 2, 2, 2, 38, 39, 9, 4, 2, 2, 39, 40, 7, 3, 2, 2, 40, 41, 5, 2, 2, + 2, 41, 42, 7, 6, 2, 2, 42, 43, 5, 2, 2, 2, 43, 44, 7, 4, 2, 2, 44, 59, + 3, 2, 2, 2, 45, 46, 9, 5, 2, 2, 46, 47, 7, 3, 2, 2, 47, 48, 5, 2, 2, 2, + 48, 49, 7, 6, 2, 2, 49, 50, 5, 2, 2, 2, 50, 51, 7, 4, 2, 2, 51, 59, 3, + 2, 2, 2, 52, 53, 7, 40, 2, 2, 53, 54, 7, 3, 2, 2, 54, 55, 9, 6, 2, 2, 55, + 59, 7, 4, 2, 2, 56, 57, 7, 15, 2, 2, 57, 59, 5, 2, 2, 3, 58, 4, 3, 2, 2, + 2, 58, 6, 3, 2, 2, 2, 58, 7, 3, 2, 2, 2, 
58, 8, 3, 2, 2, 2, 58, 9, 3, 2, + 2, 2, 58, 10, 3, 2, 2, 2, 58, 11, 3, 2, 2, 2, 58, 15, 3, 2, 2, 2, 58, 29, + 3, 2, 2, 2, 58, 31, 3, 2, 2, 2, 58, 38, 3, 2, 2, 2, 58, 45, 3, 2, 2, 2, + 58, 52, 3, 2, 2, 2, 58, 56, 3, 2, 2, 2, 59, 127, 3, 2, 2, 2, 60, 61, 12, + 23, 2, 2, 61, 62, 7, 21, 2, 2, 62, 126, 5, 2, 2, 24, 63, 64, 12, 21, 2, + 2, 64, 65, 9, 7, 2, 2, 65, 126, 5, 2, 2, 22, 66, 67, 12, 20, 2, 2, 67, + 68, 9, 8, 2, 2, 68, 126, 5, 2, 2, 21, 69, 70, 12, 19, 2, 2, 70, 71, 9, + 9, 2, 2, 71, 126, 5, 2, 2, 20, 72, 73, 12, 12, 2, 2, 73, 74, 9, 10, 2, + 2, 74, 75, 9, 6, 2, 2, 75, 76, 9, 10, 2, 2, 76, 126, 5, 2, 2, 13, 77, 78, + 12, 11, 2, 2, 78, 79, 9, 11, 2, 2, 79, 80, 9, 6, 2, 2, 80, 81, 9, 11, 2, + 2, 81, 126, 5, 2, 2, 12, 82, 83, 12, 10, 2, 2, 83, 84, 9, 12, 2, 2, 84, + 126, 5, 2, 2, 11, 85, 86, 12, 9, 2, 2, 86, 87, 9, 13, 2, 2, 87, 126, 5, + 2, 2, 10, 88, 89, 12, 8, 2, 2, 89, 90, 7, 24, 2, 2, 90, 126, 5, 2, 2, 9, + 91, 92, 12, 7, 2, 2, 92, 93, 7, 26, 2, 2, 93, 126, 5, 2, 2, 8, 94, 95, + 12, 6, 2, 2, 95, 96, 7, 25, 2, 2, 96, 126, 5, 2, 2, 7, 97, 98, 12, 5, 2, + 2, 98, 99, 7, 27, 2, 2, 99, 126, 5, 2, 2, 6, 100, 101, 12, 4, 2, 2, 101, + 102, 7, 28, 2, 2, 102, 126, 5, 2, 2, 5, 103, 104, 12, 24, 2, 2, 104, 105, + 7, 14, 2, 2, 105, 126, 7, 45, 2, 2, 106, 107, 12, 18, 2, 2, 107, 108, 9, + 14, 2, 2, 108, 109, 7, 5, 2, 2, 109, 114, 5, 2, 2, 2, 110, 111, 7, 6, 2, + 2, 111, 113, 5, 2, 2, 2, 112, 110, 3, 2, 2, 2, 113, 116, 3, 2, 2, 2, 114, + 112, 3, 2, 2, 2, 114, 115, 3, 2, 2, 2, 115, 118, 3, 2, 2, 2, 116, 114, + 3, 2, 2, 2, 117, 119, 7, 6, 2, 2, 118, 117, 3, 2, 2, 2, 118, 119, 3, 2, + 2, 2, 119, 120, 3, 2, 2, 2, 120, 121, 7, 7, 2, 2, 121, 126, 3, 2, 2, 2, + 122, 123, 12, 17, 2, 2, 123, 124, 9, 14, 2, 2, 124, 126, 7, 33, 2, 2, 125, + 60, 3, 2, 2, 2, 125, 63, 3, 2, 2, 2, 125, 66, 3, 2, 2, 2, 125, 69, 3, 2, + 2, 2, 125, 72, 3, 2, 2, 2, 125, 77, 3, 2, 2, 2, 125, 82, 3, 2, 2, 2, 125, + 85, 3, 2, 2, 2, 125, 88, 3, 2, 2, 2, 125, 91, 3, 2, 2, 2, 125, 94, 3, 2, + 2, 2, 125, 97, 3, 2, 2, 
2, 125, 100, 3, 2, 2, 2, 125, 103, 3, 2, 2, 2, + 125, 106, 3, 2, 2, 2, 125, 122, 3, 2, 2, 2, 126, 129, 3, 2, 2, 2, 127, + 125, 3, 2, 2, 2, 127, 128, 3, 2, 2, 2, 128, 3, 3, 2, 2, 2, 129, 127, 3, + 2, 2, 2, 9, 21, 25, 58, 114, 118, 125, 127, } var literalNames = []string{ "", "'('", "')'", "'['", "','", "']'", "'<'", "'<='", "'>'", "'>='", "'=='", @@ -85,7 +88,8 @@ var symbolicNames = []string{ "", "", "", "", "", "", "LT", "LE", "GT", "GE", "EQ", "NE", "LIKE", "EXISTS", "ADD", "SUB", "MUL", "DIV", "MOD", "POW", "SHL", "SHR", "BAND", "BOR", "BXOR", "AND", "OR", "BNOT", "NOT", "IN", "NIN", "EmptyTerm", "JSONContains", - "JSONContainsAll", "JSONContainsAny", "BooleanConstant", "IntegerConstant", + "JSONContainsAll", "JSONContainsAny", "ArrayContains", "ArrayContainsAll", + "ArrayContainsAny", "ArrayLength", "BooleanConstant", "IntegerConstant", "FloatingConstant", "Identifier", "StringLiteral", "JSONIdentifier", "Whitespace", "Newline", } @@ -160,14 +164,18 @@ const ( PlanParserJSONContains = 32 PlanParserJSONContainsAll = 33 PlanParserJSONContainsAny = 34 - PlanParserBooleanConstant = 35 - PlanParserIntegerConstant = 36 - PlanParserFloatingConstant = 37 - PlanParserIdentifier = 38 - PlanParserStringLiteral = 39 - PlanParserJSONIdentifier = 40 - PlanParserWhitespace = 41 - PlanParserNewline = 42 + PlanParserArrayContains = 35 + PlanParserArrayContainsAll = 36 + PlanParserArrayContainsAny = 37 + PlanParserArrayLength = 38 + PlanParserBooleanConstant = 39 + PlanParserIntegerConstant = 40 + PlanParserFloatingConstant = 41 + PlanParserIdentifier = 42 + PlanParserStringLiteral = 43 + PlanParserJSONIdentifier = 44 + PlanParserWhitespace = 45 + PlanParserNewline = 46 ) // PlanParserRULE_expr is the PlanParser rule. 
@@ -375,10 +383,6 @@ func (s *JSONContainsAllContext) GetRuleContext() antlr.RuleContext { return s } -func (s *JSONContainsAllContext) JSONContainsAll() antlr.TerminalNode { - return s.GetToken(PlanParserJSONContainsAll, 0) -} - func (s *JSONContainsAllContext) AllExpr() []IExprContext { var ts = s.GetTypedRuleContexts(reflect.TypeOf((*IExprContext)(nil)).Elem()) var tst = make([]IExprContext, len(ts)) @@ -402,6 +406,14 @@ func (s *JSONContainsAllContext) Expr(i int) IExprContext { return t.(IExprContext) } +func (s *JSONContainsAllContext) JSONContainsAll() antlr.TerminalNode { + return s.GetToken(PlanParserJSONContainsAll, 0) +} + +func (s *JSONContainsAllContext) ArrayContainsAll() antlr.TerminalNode { + return s.GetToken(PlanParserArrayContainsAll, 0) +} + func (s *JSONContainsAllContext) Accept(visitor antlr.ParseTreeVisitor) interface{} { switch t := visitor.(type) { case PlanVisitor: @@ -1104,6 +1116,46 @@ func (s *RelationalContext) Accept(visitor antlr.ParseTreeVisitor) interface{} { } } +type ArrayLengthContext struct { + *ExprContext +} + +func NewArrayLengthContext(parser antlr.Parser, ctx antlr.ParserRuleContext) *ArrayLengthContext { + var p = new(ArrayLengthContext) + + p.ExprContext = NewEmptyExprContext() + p.parser = parser + p.CopyFrom(ctx.(*ExprContext)) + + return p +} + +func (s *ArrayLengthContext) GetRuleContext() antlr.RuleContext { + return s +} + +func (s *ArrayLengthContext) ArrayLength() antlr.TerminalNode { + return s.GetToken(PlanParserArrayLength, 0) +} + +func (s *ArrayLengthContext) Identifier() antlr.TerminalNode { + return s.GetToken(PlanParserIdentifier, 0) +} + +func (s *ArrayLengthContext) JSONIdentifier() antlr.TerminalNode { + return s.GetToken(PlanParserJSONIdentifier, 0) +} + +func (s *ArrayLengthContext) Accept(visitor antlr.ParseTreeVisitor) interface{} { + switch t := visitor.(type) { + case PlanVisitor: + return t.VisitArrayLength(s) + + default: + return t.VisitChildren(s) + } +} + type TermContext struct { 
*ExprContext op antlr.Token @@ -1186,10 +1238,6 @@ func (s *JSONContainsContext) GetRuleContext() antlr.RuleContext { return s } -func (s *JSONContainsContext) JSONContains() antlr.TerminalNode { - return s.GetToken(PlanParserJSONContains, 0) -} - func (s *JSONContainsContext) AllExpr() []IExprContext { var ts = s.GetTypedRuleContexts(reflect.TypeOf((*IExprContext)(nil)).Elem()) var tst = make([]IExprContext, len(ts)) @@ -1213,6 +1261,14 @@ func (s *JSONContainsContext) Expr(i int) IExprContext { return t.(IExprContext) } +func (s *JSONContainsContext) JSONContains() antlr.TerminalNode { + return s.GetToken(PlanParserJSONContains, 0) +} + +func (s *JSONContainsContext) ArrayContains() antlr.TerminalNode { + return s.GetToken(PlanParserArrayContains, 0) +} + func (s *JSONContainsContext) Accept(visitor antlr.ParseTreeVisitor) interface{} { switch t := visitor.(type) { case PlanVisitor: @@ -1468,10 +1524,6 @@ func (s *JSONContainsAnyContext) GetRuleContext() antlr.RuleContext { return s } -func (s *JSONContainsAnyContext) JSONContainsAny() antlr.TerminalNode { - return s.GetToken(PlanParserJSONContainsAny, 0) -} - func (s *JSONContainsAnyContext) AllExpr() []IExprContext { var ts = s.GetTypedRuleContexts(reflect.TypeOf((*IExprContext)(nil)).Elem()) var tst = make([]IExprContext, len(ts)) @@ -1495,6 +1547,14 @@ func (s *JSONContainsAnyContext) Expr(i int) IExprContext { return t.(IExprContext) } +func (s *JSONContainsAnyContext) JSONContainsAny() antlr.TerminalNode { + return s.GetToken(PlanParserJSONContainsAny, 0) +} + +func (s *JSONContainsAnyContext) ArrayContainsAny() antlr.TerminalNode { + return s.GetToken(PlanParserArrayContainsAny, 0) +} + func (s *JSONContainsAnyContext) Accept(visitor antlr.ParseTreeVisitor) interface{} { switch t := visitor.(type) { case PlanVisitor: @@ -1800,7 +1860,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { var _alt int p.EnterOuterAlt(localctx, 1) - p.SetState(52) + p.SetState(56) p.GetErrorHandler().Sync(p) switch 
p.GetTokenStream().LA(1) { @@ -1948,16 +2008,23 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } { p.SetState(28) - p.expr(19) + p.expr(20) } - case PlanParserJSONContains: + case PlanParserJSONContains, PlanParserArrayContains: localctx = NewJSONContainsContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx { p.SetState(29) - p.Match(PlanParserJSONContains) + _la = p.GetTokenStream().LA(1) + + if !(_la == PlanParserJSONContains || _la == PlanParserArrayContains) { + p.GetErrorHandler().RecoverInline(p) + } else { + p.GetErrorHandler().ReportMatch(p) + p.Consume() + } } { p.SetState(30) @@ -1980,13 +2047,20 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { p.Match(PlanParserT__1) } - case PlanParserJSONContainsAll: + case PlanParserJSONContainsAll, PlanParserArrayContainsAll: localctx = NewJSONContainsAllContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx { p.SetState(36) - p.Match(PlanParserJSONContainsAll) + _la = p.GetTokenStream().LA(1) + + if !(_la == PlanParserJSONContainsAll || _la == PlanParserArrayContainsAll) { + p.GetErrorHandler().RecoverInline(p) + } else { + p.GetErrorHandler().ReportMatch(p) + p.Consume() + } } { p.SetState(37) @@ -2009,13 +2083,20 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { p.Match(PlanParserT__1) } - case PlanParserJSONContainsAny: + case PlanParserJSONContainsAny, PlanParserArrayContainsAny: localctx = NewJSONContainsAnyContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx { p.SetState(43) - p.Match(PlanParserJSONContainsAny) + _la = p.GetTokenStream().LA(1) + + if !(_la == PlanParserJSONContainsAny || _la == PlanParserArrayContainsAny) { + p.GetErrorHandler().RecoverInline(p) + } else { + p.GetErrorHandler().ReportMatch(p) + p.Consume() + } } { p.SetState(44) @@ -2038,16 +2119,44 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { p.Match(PlanParserT__1) } + case PlanParserArrayLength: + localctx = 
NewArrayLengthContext(p, localctx) + p.SetParserRuleContext(localctx) + _prevctx = localctx + { + p.SetState(50) + p.Match(PlanParserArrayLength) + } + { + p.SetState(51) + p.Match(PlanParserT__0) + } + { + p.SetState(52) + _la = p.GetTokenStream().LA(1) + + if !(_la == PlanParserIdentifier || _la == PlanParserJSONIdentifier) { + p.GetErrorHandler().RecoverInline(p) + } else { + p.GetErrorHandler().ReportMatch(p) + p.Consume() + } + } + { + p.SetState(53) + p.Match(PlanParserT__1) + } + case PlanParserEXISTS: localctx = NewExistsContext(p, localctx) p.SetParserRuleContext(localctx) _prevctx = localctx { - p.SetState(50) + p.SetState(54) p.Match(PlanParserEXISTS) } { - p.SetState(51) + p.SetState(55) p.expr(1) } @@ -2055,7 +2164,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { panic(antlr.NewNoViableAltException(p, nil, nil, nil, nil, nil)) } p.GetParserRuleContext().SetStop(p.GetTokenStream().LT(-1)) - p.SetState(121) + p.SetState(125) p.GetErrorHandler().Sync(p) _alt = p.GetInterpreter().AdaptivePredict(p.GetTokenStream(), 6, p.GetParserRuleContext()) @@ -2065,36 +2174,36 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { p.TriggerExitRuleEvent() } _prevctx = localctx - p.SetState(119) + p.SetState(123) p.GetErrorHandler().Sync(p) switch p.GetInterpreter().AdaptivePredict(p.GetTokenStream(), 5, p.GetParserRuleContext()) { case 1: localctx = NewPowerContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(54) + p.SetState(58) - if !(p.Precpred(p.GetParserRuleContext(), 20)) { - panic(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 20)", "")) + if !(p.Precpred(p.GetParserRuleContext(), 21)) { + panic(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 21)", "")) } { - p.SetState(55) + p.SetState(59) p.Match(PlanParserPOW) } { - p.SetState(56) - p.expr(21) + p.SetState(60) + p.expr(22) } case 2: localctx = 
NewMulDivModContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(57) + p.SetState(61) - if !(p.Precpred(p.GetParserRuleContext(), 18)) { - panic(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 18)", "")) + if !(p.Precpred(p.GetParserRuleContext(), 19)) { + panic(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 19)", "")) } { - p.SetState(58) + p.SetState(62) var _lt = p.GetTokenStream().LT(1) @@ -2112,20 +2221,20 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(59) - p.expr(19) + p.SetState(63) + p.expr(20) } case 3: localctx = NewAddSubContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(60) + p.SetState(64) - if !(p.Precpred(p.GetParserRuleContext(), 17)) { - panic(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 17)", "")) + if !(p.Precpred(p.GetParserRuleContext(), 18)) { + panic(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 18)", "")) } { - p.SetState(61) + p.SetState(65) var _lt = p.GetTokenStream().LT(1) @@ -2143,20 +2252,20 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(62) - p.expr(18) + p.SetState(66) + p.expr(19) } case 4: localctx = NewShiftContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(63) + p.SetState(67) - if !(p.Precpred(p.GetParserRuleContext(), 16)) { - panic(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 16)", "")) + if !(p.Precpred(p.GetParserRuleContext(), 17)) { + panic(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 17)", "")) } { - p.SetState(64) + p.SetState(68) var _lt = p.GetTokenStream().LT(1) @@ -2174,20 +2283,20 @@ func (p *PlanParser) expr(_p int) 
(localctx IExprContext) { } } { - p.SetState(65) - p.expr(17) + p.SetState(69) + p.expr(18) } case 5: localctx = NewRangeContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(66) + p.SetState(70) if !(p.Precpred(p.GetParserRuleContext(), 10)) { panic(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 10)", "")) } { - p.SetState(67) + p.SetState(71) var _lt = p.GetTokenStream().LT(1) @@ -2205,7 +2314,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(68) + p.SetState(72) _la = p.GetTokenStream().LA(1) if !(_la == PlanParserIdentifier || _la == PlanParserJSONIdentifier) { @@ -2216,7 +2325,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(69) + p.SetState(73) var _lt = p.GetTokenStream().LT(1) @@ -2234,20 +2343,20 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(70) + p.SetState(74) p.expr(11) } case 6: localctx = NewReverseRangeContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(71) + p.SetState(75) if !(p.Precpred(p.GetParserRuleContext(), 9)) { panic(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 9)", "")) } { - p.SetState(72) + p.SetState(76) var _lt = p.GetTokenStream().LT(1) @@ -2265,7 +2374,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(73) + p.SetState(77) _la = p.GetTokenStream().LA(1) if !(_la == PlanParserIdentifier || _la == PlanParserJSONIdentifier) { @@ -2276,7 +2385,7 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(74) + p.SetState(78) var _lt = p.GetTokenStream().LT(1) @@ -2294,20 +2403,20 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(75) + p.SetState(79) p.expr(10) } case 7: localctx = NewRelationalContext(p, NewExprContext(p, 
_parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(76) + p.SetState(80) if !(p.Precpred(p.GetParserRuleContext(), 8)) { panic(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 8)", "")) } { - p.SetState(77) + p.SetState(81) var _lt = p.GetTokenStream().LT(1) @@ -2325,20 +2434,20 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(78) + p.SetState(82) p.expr(9) } case 8: localctx = NewEqualityContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(79) + p.SetState(83) if !(p.Precpred(p.GetParserRuleContext(), 7)) { panic(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 7)", "")) } { - p.SetState(80) + p.SetState(84) var _lt = p.GetTokenStream().LT(1) @@ -2356,122 +2465,122 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } } { - p.SetState(81) + p.SetState(85) p.expr(8) } case 9: localctx = NewBitAndContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(82) + p.SetState(86) if !(p.Precpred(p.GetParserRuleContext(), 6)) { panic(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 6)", "")) } { - p.SetState(83) + p.SetState(87) p.Match(PlanParserBAND) } { - p.SetState(84) + p.SetState(88) p.expr(7) } case 10: localctx = NewBitXorContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(85) + p.SetState(89) if !(p.Precpred(p.GetParserRuleContext(), 5)) { panic(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 5)", "")) } { - p.SetState(86) + p.SetState(90) p.Match(PlanParserBXOR) } { - p.SetState(87) + p.SetState(91) p.expr(6) } case 11: localctx = NewBitOrContext(p, NewExprContext(p, _parentctx, _parentState)) 
p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(88) + p.SetState(92) if !(p.Precpred(p.GetParserRuleContext(), 4)) { panic(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 4)", "")) } { - p.SetState(89) + p.SetState(93) p.Match(PlanParserBOR) } { - p.SetState(90) + p.SetState(94) p.expr(5) } case 12: localctx = NewLogicalAndContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(91) + p.SetState(95) if !(p.Precpred(p.GetParserRuleContext(), 3)) { panic(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 3)", "")) } { - p.SetState(92) + p.SetState(96) p.Match(PlanParserAND) } { - p.SetState(93) + p.SetState(97) p.expr(4) } case 13: localctx = NewLogicalOrContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(94) + p.SetState(98) if !(p.Precpred(p.GetParserRuleContext(), 2)) { panic(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 2)", "")) } { - p.SetState(95) + p.SetState(99) p.Match(PlanParserOR) } { - p.SetState(96) + p.SetState(100) p.expr(3) } case 14: localctx = NewLikeContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(97) + p.SetState(101) - if !(p.Precpred(p.GetParserRuleContext(), 21)) { - panic(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 21)", "")) + if !(p.Precpred(p.GetParserRuleContext(), 22)) { + panic(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 22)", "")) } { - p.SetState(98) + p.SetState(102) p.Match(PlanParserLIKE) } { - p.SetState(99) + p.SetState(103) p.Match(PlanParserStringLiteral) } case 15: localctx = NewTermContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, 
PlanParserRULE_expr) - p.SetState(100) + p.SetState(104) - if !(p.Precpred(p.GetParserRuleContext(), 15)) { - panic(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 15)", "")) + if !(p.Precpred(p.GetParserRuleContext(), 16)) { + panic(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 16)", "")) } { - p.SetState(101) + p.SetState(105) var _lt = p.GetTokenStream().LT(1) @@ -2490,59 +2599,59 @@ func (p *PlanParser) expr(_p int) (localctx IExprContext) { } { - p.SetState(102) + p.SetState(106) p.Match(PlanParserT__2) } { - p.SetState(103) + p.SetState(107) p.expr(0) } - p.SetState(108) + p.SetState(112) p.GetErrorHandler().Sync(p) _alt = p.GetInterpreter().AdaptivePredict(p.GetTokenStream(), 3, p.GetParserRuleContext()) for _alt != 2 && _alt != antlr.ATNInvalidAltNumber { if _alt == 1 { { - p.SetState(104) + p.SetState(108) p.Match(PlanParserT__3) } { - p.SetState(105) + p.SetState(109) p.expr(0) } } - p.SetState(110) + p.SetState(114) p.GetErrorHandler().Sync(p) _alt = p.GetInterpreter().AdaptivePredict(p.GetTokenStream(), 3, p.GetParserRuleContext()) } - p.SetState(112) + p.SetState(116) p.GetErrorHandler().Sync(p) _la = p.GetTokenStream().LA(1) if _la == PlanParserT__3 { { - p.SetState(111) + p.SetState(115) p.Match(PlanParserT__3) } } { - p.SetState(114) + p.SetState(118) p.Match(PlanParserT__4) } case 16: localctx = NewEmptyTermContext(p, NewExprContext(p, _parentctx, _parentState)) p.PushNewRecursionContext(localctx, _startState, PlanParserRULE_expr) - p.SetState(116) + p.SetState(120) - if !(p.Precpred(p.GetParserRuleContext(), 14)) { - panic(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 14)", "")) + if !(p.Precpred(p.GetParserRuleContext(), 15)) { + panic(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 15)", "")) } { - p.SetState(117) + p.SetState(121) var _lt = p.GetTokenStream().LT(1) @@ -2560,14 +2669,14 @@ func (p *PlanParser) expr(_p int) (localctx 
IExprContext) { } } { - p.SetState(118) + p.SetState(122) p.Match(PlanParserEmptyTerm) } } } - p.SetState(123) + p.SetState(127) p.GetErrorHandler().Sync(p) _alt = p.GetInterpreter().AdaptivePredict(p.GetTokenStream(), 6, p.GetParserRuleContext()) } @@ -2592,16 +2701,16 @@ func (p *PlanParser) Sempred(localctx antlr.RuleContext, ruleIndex, predIndex in func (p *PlanParser) Expr_Sempred(localctx antlr.RuleContext, predIndex int) bool { switch predIndex { case 0: - return p.Precpred(p.GetParserRuleContext(), 20) + return p.Precpred(p.GetParserRuleContext(), 21) case 1: - return p.Precpred(p.GetParserRuleContext(), 18) + return p.Precpred(p.GetParserRuleContext(), 19) case 2: - return p.Precpred(p.GetParserRuleContext(), 17) + return p.Precpred(p.GetParserRuleContext(), 18) case 3: - return p.Precpred(p.GetParserRuleContext(), 16) + return p.Precpred(p.GetParserRuleContext(), 17) case 4: return p.Precpred(p.GetParserRuleContext(), 10) @@ -2631,13 +2740,13 @@ func (p *PlanParser) Expr_Sempred(localctx antlr.RuleContext, predIndex int) boo return p.Precpred(p.GetParserRuleContext(), 2) case 13: - return p.Precpred(p.GetParserRuleContext(), 21) + return p.Precpred(p.GetParserRuleContext(), 22) case 14: - return p.Precpred(p.GetParserRuleContext(), 15) + return p.Precpred(p.GetParserRuleContext(), 16) case 15: - return p.Precpred(p.GetParserRuleContext(), 14) + return p.Precpred(p.GetParserRuleContext(), 15) default: panic("No predicate with index: " + fmt.Sprint(predIndex)) diff --git a/internal/parser/planparserv2/generated/plan_visitor.go b/internal/parser/planparserv2/generated/plan_visitor.go index cb7c5c76472bf..ebe1bb63c0c80 100644 --- a/internal/parser/planparserv2/generated/plan_visitor.go +++ b/internal/parser/planparserv2/generated/plan_visitor.go @@ -58,6 +58,9 @@ type PlanVisitor interface { // Visit a parse tree produced by PlanParser#Relational. VisitRelational(ctx *RelationalContext) interface{} + // Visit a parse tree produced by PlanParser#ArrayLength. 
+ VisitArrayLength(ctx *ArrayLengthContext) interface{} + // Visit a parse tree produced by PlanParser#Term. VisitTerm(ctx *TermContext) interface{} diff --git a/internal/parser/planparserv2/parser_visitor.go b/internal/parser/planparserv2/parser_visitor.go index 975cb1a661439..3476b2d44ced2 100644 --- a/internal/parser/planparserv2/parser_visitor.go +++ b/internal/parser/planparserv2/parser_visitor.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/antlr/antlr4/runtime/Go/antlr" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" parser "github.com/milvus-io/milvus/internal/parser/planparserv2/generated" "github.com/milvus-io/milvus/internal/proto/planpb" @@ -49,6 +50,7 @@ func (v *ParserVisitor) translateIdentifier(identifier string) (*ExprWithType, e IsAutoID: field.AutoID, NestedPath: nestedPath, IsPartitionKey: field.IsPartitionKey, + ElementType: field.GetElementType(), }, }, }, @@ -147,6 +149,13 @@ func (v *ParserVisitor) VisitString(ctx *parser.StringContext) interface{} { } } +func checkDirectComparisonBinaryField(columnInfo *planpb.ColumnInfo) error { + if typeutil.IsArrayType(columnInfo.GetDataType()) && len(columnInfo.GetNestedPath()) == 0 { + return fmt.Errorf("can not comparisons array fields directly") + } + return nil +} + // VisitAddSub translates expr to arithmetic plan. 
func (v *ParserVisitor) VisitAddSub(ctx *parser.AddSubContext) interface{} { left := ctx.Expr(0).Accept(v) @@ -191,12 +200,13 @@ func (v *ParserVisitor) VisitAddSub(ctx *parser.AddSubContext) interface{} { return fmt.Errorf("invalid arithmetic expression, left: %s, op: %s, right: %s", ctx.Expr(0).GetText(), ctx.GetOp(), ctx.Expr(1).GetText()) } - if typeutil.IsArrayType(leftExpr.dataType) || typeutil.IsArrayType(rightExpr.dataType) { - return fmt.Errorf("invalid expression, array is not supported for AddSub") + if err := checkDirectComparisonBinaryField(toColumnInfo(leftExpr)); err != nil { + return err } - - if (!typeutil.IsArithmetic(leftExpr.dataType) && !typeutil.IsJSONType(leftExpr.dataType)) || - (!typeutil.IsArithmetic(rightExpr.dataType) && !typeutil.IsJSONType(rightExpr.dataType)) { + if err := checkDirectComparisonBinaryField(toColumnInfo(rightExpr)); err != nil { + return err + } + if !canArithmetic(leftExpr, rightExpr) { return fmt.Errorf("'%s' can only be used between integer or floating or json field expressions", arithNameMap[ctx.GetOp().GetTokenType()]) } @@ -274,19 +284,19 @@ func (v *ParserVisitor) VisitMulDivMod(ctx *parser.MulDivModContext) interface{} return fmt.Errorf("invalid arithmetic expression, left: %s, op: %s, right: %s", ctx.Expr(0).GetText(), ctx.GetOp(), ctx.Expr(1).GetText()) } - if typeutil.IsArrayType(leftExpr.dataType) || typeutil.IsArrayType(rightExpr.dataType) { - return fmt.Errorf("invalid expression, array is not supported for MulDivMod") + if err := checkDirectComparisonBinaryField(toColumnInfo(leftExpr)); err != nil { + return err } - - if (!typeutil.IsArithmetic(leftExpr.dataType) && !typeutil.IsJSONType(leftExpr.dataType)) || - (!typeutil.IsArithmetic(rightExpr.dataType) && !typeutil.IsJSONType(rightExpr.dataType)) { - return fmt.Errorf("'%s' can only be used between integer or floating expressions", arithNameMap[ctx.GetOp().GetTokenType()]) + if err := checkDirectComparisonBinaryField(toColumnInfo(rightExpr)); err != nil 
{ + return err + } + if !canArithmetic(leftExpr, rightExpr) { + return fmt.Errorf("'%s' can only be used between integer or floating or json field expressions", arithNameMap[ctx.GetOp().GetTokenType()]) } switch ctx.GetOp().GetTokenType() { case parser.PlanParserMOD: - if (!typeutil.IsIntegerType(leftExpr.dataType) && !typeutil.IsJSONType(leftExpr.dataType)) || - (!typeutil.IsIntegerType(rightExpr.dataType) && !typeutil.IsJSONType(rightExpr.dataType)) { + if !isIntegerColumn(toColumnInfo(leftExpr)) && !isIntegerColumn(toColumnInfo(rightExpr)) { return fmt.Errorf("modulo can only apply on integer types") } default: @@ -349,10 +359,6 @@ func (v *ParserVisitor) VisitEquality(ctx *parser.EqualityContext) interface{} { rightExpr = getExpr(right) } - if typeutil.IsArrayType(leftExpr.dataType) || typeutil.IsArrayType(rightExpr.dataType) { - return fmt.Errorf("invalid expression, array is not supported for Equality") - } - expr, err := HandleCompare(ctx.GetOp().GetTokenType(), leftExpr, rightExpr) if err != nil { return err @@ -405,9 +411,11 @@ func (v *ParserVisitor) VisitRelational(ctx *parser.RelationalContext) interface } else { rightExpr = getExpr(right) } - - if typeutil.IsArrayType(leftExpr.dataType) || typeutil.IsArrayType(rightExpr.dataType) { - return fmt.Errorf("invalid expression, array is not supported for Relational") + if err := checkDirectComparisonBinaryField(toColumnInfo(leftExpr)); err != nil { + return err + } + if err := checkDirectComparisonBinaryField(toColumnInfo(rightExpr)); err != nil { + return err } expr, err := HandleCompare(ctx.GetOp().GetTokenType(), leftExpr, rightExpr) @@ -433,14 +441,19 @@ func (v *ParserVisitor) VisitLike(ctx *parser.LikeContext) interface{} { return fmt.Errorf("the left operand of like is invalid") } - if !typeutil.IsStringType(leftExpr.dataType) && !typeutil.IsJSONType(leftExpr.dataType) { - return fmt.Errorf("like operation on non-string or no-json field is unsupported") - } - column := toColumnInfo(leftExpr) if column 
== nil { return fmt.Errorf("like operation on complicated expr is unsupported") } + if err := checkDirectComparisonBinaryField(column); err != nil { + return err + } + + if !typeutil.IsStringType(leftExpr.dataType) && !typeutil.IsJSONType(leftExpr.dataType) && + !(typeutil.IsArrayType(leftExpr.dataType) && typeutil.IsStringType(column.GetElementType())) { + return fmt.Errorf("like operation on non-string or no-json field is unsupported") + } + pattern, err := convertEscapeSingle(ctx.StringLiteral().GetText()) if err != nil { return err @@ -482,6 +495,10 @@ func (v *ParserVisitor) VisitTerm(ctx *parser.TermContext) interface{} { return fmt.Errorf("'term' can only be used on single field, but got: %s", ctx.Expr(0).GetText()) } + dataType := columnInfo.GetDataType() + if typeutil.IsArrayType(dataType) && len(columnInfo.GetNestedPath()) != 0 { + dataType = columnInfo.GetElementType() + } allExpr := ctx.AllExpr() lenOfAllExpr := len(allExpr) values := make([]*planpb.GenericValue, 0, lenOfAllExpr) @@ -494,9 +511,9 @@ func (v *ParserVisitor) VisitTerm(ctx *parser.TermContext) interface{} { if n == nil { return fmt.Errorf("value '%s' in list cannot be a non-const expression", ctx.Expr(i).GetText()) } - castedValue, err := castValue(childExpr.dataType, n) + castedValue, err := castValue(dataType, n) if err != nil { - return fmt.Errorf("value '%s' in list cannot be casted to %s", ctx.Expr(i).GetText(), childExpr.dataType.String()) + return fmt.Errorf("value '%s' in list cannot be casted to %s", ctx.Expr(i).GetText(), dataType.String()) } values = append(values, castedValue) } @@ -543,6 +560,9 @@ func (v *ParserVisitor) VisitEmptyTerm(ctx *parser.EmptyTermContext) interface{} if columnInfo == nil { return fmt.Errorf("'term' can only be used on single field, but got: %s", ctx.Expr().GetText()) } + if err := checkDirectComparisonBinaryField(columnInfo); err != nil { + return err + } expr := &planpb.Expr{ Expr: &planpb.Expr_TermExpr{ @@ -589,6 +609,9 @@ func (v *ParserVisitor) 
VisitRange(ctx *parser.RangeContext) interface{} { if columnInfo == nil { return fmt.Errorf("range operations are only supported on single fields now, got: %s", ctx.Expr(1).GetText()) } + if err := checkDirectComparisonBinaryField(columnInfo); err != nil { + return err + } lower := ctx.Expr(0).Accept(v) upper := ctx.Expr(1).Accept(v) @@ -608,7 +631,12 @@ func (v *ParserVisitor) VisitRange(ctx *parser.RangeContext) interface{} { return fmt.Errorf("upperbound cannot be a non-const expression: %s", ctx.Expr(1).GetText()) } - switch columnInfo.GetDataType() { + fieldDataType := columnInfo.GetDataType() + if typeutil.IsArrayType(columnInfo.GetDataType()) { + fieldDataType = columnInfo.GetElementType() + } + + switch fieldDataType { case schemapb.DataType_String, schemapb.DataType_VarChar: if !IsString(lowerValue) || !IsString(upperValue) { return fmt.Errorf("invalid range operations") @@ -671,6 +699,10 @@ func (v *ParserVisitor) VisitReverseRange(ctx *parser.ReverseRangeContext) inter return fmt.Errorf("range operations are only supported on single fields now, got: %s", ctx.Expr(1).GetText()) } + if err := checkDirectComparisonBinaryField(columnInfo); err != nil { + return err + } + lower := ctx.Expr(1).Accept(v) upper := ctx.Expr(0).Accept(v) if err := getError(lower); err != nil { @@ -767,6 +799,9 @@ func (v *ParserVisitor) VisitUnary(ctx *parser.UnaryContext) interface{} { if childExpr == nil { return fmt.Errorf("failed to parse unary expressions") } + if err := checkDirectComparisonBinaryField(toColumnInfo(childExpr)); err != nil { + return err + } switch ctx.GetOp().GetTokenType() { case parser.PlanParserADD: return childExpr @@ -983,25 +1018,25 @@ func (v *ParserVisitor) getColumnInfoFromJSONIdentifier(identifier string) (*pla if path == "" { return nil, fmt.Errorf("invalid identifier: %s", identifier) } + if typeutil.IsArrayType(field.DataType) { + return nil, fmt.Errorf("can only access array field with integer index") + } } else if _, err := 
strconv.ParseInt(path, 10, 64); err != nil { return nil, fmt.Errorf("json key must be enclosed in double quotes or single quotes: \"%s\"", path) } nestedPath = append(nestedPath, path) } - if typeutil.IsJSONType(field.DataType) && len(nestedPath) == 0 { - return nil, fmt.Errorf("can not comparisons jsonField directly") - } - return &planpb.ColumnInfo{ - FieldId: field.FieldID, - DataType: field.DataType, - NestedPath: nestedPath, + FieldId: field.FieldID, + DataType: field.DataType, + NestedPath: nestedPath, + ElementType: field.GetElementType(), }, nil } func (v *ParserVisitor) VisitJSONIdentifier(ctx *parser.JSONIdentifierContext) interface{} { - jsonField, err := v.getColumnInfoFromJSONIdentifier(ctx.JSONIdentifier().GetText()) + field, err := v.getColumnInfoFromJSONIdentifier(ctx.JSONIdentifier().GetText()) if err != nil { return err } @@ -1010,14 +1045,15 @@ func (v *ParserVisitor) VisitJSONIdentifier(ctx *parser.JSONIdentifierContext) i Expr: &planpb.Expr_ColumnExpr{ ColumnExpr: &planpb.ColumnExpr{ Info: &planpb.ColumnInfo{ - FieldId: jsonField.GetFieldId(), - DataType: jsonField.GetDataType(), - NestedPath: jsonField.GetNestedPath(), + FieldId: field.GetFieldId(), + DataType: field.GetDataType(), + NestedPath: field.GetNestedPath(), + ElementType: field.GetElementType(), }, }, }, }, - dataType: jsonField.GetDataType(), + dataType: field.GetDataType(), nodeDependent: true, } } @@ -1037,6 +1073,7 @@ func (v *ParserVisitor) VisitExists(ctx *parser.ExistsContext) interface{} { return fmt.Errorf( "exists oerations are only supportted on json field, got:%s", columnInfo.GetDataType()) } + return &ExprWithType{ expr: &planpb.Expr{ Expr: &planpb.Expr_ExistsExpr{ @@ -1053,47 +1090,6 @@ func (v *ParserVisitor) VisitExists(ctx *parser.ExistsContext) interface{} { } } -func (v *ParserVisitor) VisitJSONContains(ctx *parser.JSONContainsContext) interface{} { - field := ctx.Expr(0).Accept(v) - if err := getError(field); err != nil { - return err - } - - columnInfo := 
toColumnInfo(field.(*ExprWithType)) - if columnInfo == nil || !typeutil.IsJSONType(columnInfo.GetDataType()) { - return fmt.Errorf( - "json_contains operation are only supported on json fields now, got: %s", ctx.Expr(0).GetText()) - } - - element := ctx.Expr(1).Accept(v) - if err := getError(element); err != nil { - return err - } - elementValue := getGenericValue(element) - if elementValue == nil { - return fmt.Errorf( - "json_contains operation are only supported explicitly specified element, got: %s", ctx.Expr(1).GetText()) - } - - elements := make([]*planpb.GenericValue, 1) - elements[0] = elementValue - - expr := &planpb.Expr{ - Expr: &planpb.Expr_JsonContainsExpr{ - JsonContainsExpr: &planpb.JSONContainsExpr{ - ColumnInfo: columnInfo, - Elements: elements, - Op: planpb.JSONContainsExpr_Contains, - ElementsSameType: true, - }, - }, - } - return &ExprWithType{ - expr: expr, - dataType: schemapb.DataType_Bool, - } -} - func (v *ParserVisitor) VisitArray(ctx *parser.ArrayContext) interface{} { allExpr := ctx.AllExpr() array := make([]*planpb.GenericValue, 0, len(allExpr)) @@ -1108,7 +1104,7 @@ func (v *ParserVisitor) VisitArray(ctx *parser.ArrayContext) interface{} { if elementValue == nil { return fmt.Errorf("array element type must be generic value, but got: %s", allExpr[i].GetText()) } - array = append(array, getGenericValue(element)) + array = append(array, elementValue) if dType == schemapb.DataType_None { dType = element.(*ExprWithType).dataType @@ -1116,6 +1112,9 @@ func (v *ParserVisitor) VisitArray(ctx *parser.ArrayContext) interface{} { sameType = false } } + if !sameType { + dType = schemapb.DataType_None + } return &ExprWithType{ dataType: schemapb.DataType_Array, @@ -1125,8 +1124,9 @@ func (v *ParserVisitor) VisitArray(ctx *parser.ArrayContext) interface{} { Value: &planpb.GenericValue{ Val: &planpb.GenericValue_ArrayVal{ ArrayVal: &planpb.Array{ - Array: array, - SameType: sameType, + Array: array, + SameType: sameType, + ElementType: dType, }, }, 
}, @@ -1137,6 +1137,56 @@ func (v *ParserVisitor) VisitArray(ctx *parser.ArrayContext) interface{} { } } +func (v *ParserVisitor) VisitJSONContains(ctx *parser.JSONContainsContext) interface{} { + field := ctx.Expr(0).Accept(v) + if err := getError(field); err != nil { + return err + } + + columnInfo := toColumnInfo(field.(*ExprWithType)) + if columnInfo == nil || + (!typeutil.IsJSONType(columnInfo.GetDataType()) && !typeutil.IsArrayType(columnInfo.GetDataType())) { + return fmt.Errorf( + "contains operation are only supported on json or array fields now, got: %s", ctx.Expr(0).GetText()) + } + + element := ctx.Expr(1).Accept(v) + if err := getError(element); err != nil { + return err + } + elementValue := getGenericValue(element) + if elementValue == nil { + return fmt.Errorf( + "contains operation are only supported explicitly specified element, got: %s", ctx.Expr(1).GetText()) + } + if typeutil.IsArrayType(columnInfo.GetDataType()) { + valExpr := toValueExpr(elementValue) + if !canBeCompared(field.(*ExprWithType), valExpr) { + return fmt.Errorf("contains operation can't compare between array element type: %s and %s", + columnInfo.GetElementType(), + valExpr.dataType) + } + } + + elements := make([]*planpb.GenericValue, 1) + elements[0] = elementValue + + expr := &planpb.Expr{ + Expr: &planpb.Expr_JsonContainsExpr{ + JsonContainsExpr: &planpb.JSONContainsExpr{ + ColumnInfo: columnInfo, + Elements: elements, + Op: planpb.JSONContainsExpr_Contains, + ElementsSameType: true, + }, + }, + } + return &ExprWithType{ + expr: expr, + dataType: schemapb.DataType_Bool, + } +} + func (v *ParserVisitor) VisitJSONContainsAll(ctx *parser.JSONContainsAllContext) interface{} { field := ctx.Expr(0).Accept(v) if err := getError(field); err != nil { @@ -1144,9 +1194,10 @@ func (v *ParserVisitor) VisitJSONContainsAll(ctx *parser.JSONContainsAllContext) } columnInfo := toColumnInfo(field.(*ExprWithType)) - if columnInfo == nil || !typeutil.IsJSONType(columnInfo.GetDataType()) { + if 
columnInfo == nil || + (!typeutil.IsJSONType(columnInfo.GetDataType()) && !typeutil.IsArrayType(columnInfo.GetDataType())) { return fmt.Errorf( - "json_contains_all operation are only supported on json fields now, got: %s", ctx.Expr(0).GetText()) + "contains_all operation are only supported on json or array fields now, got: %s", ctx.Expr(0).GetText()) } element := ctx.Expr(1).Accept(v) @@ -1156,11 +1207,22 @@ func (v *ParserVisitor) VisitJSONContainsAll(ctx *parser.JSONContainsAllContext) elementValue := getGenericValue(element) if elementValue == nil { return fmt.Errorf( - "json_contains_all operation are only supported explicitly specified element, got: %s", ctx.Expr(1).GetText()) + "contains_all operation are only supported explicitly specified element, got: %s", ctx.Expr(1).GetText()) } if elementValue.GetArrayVal() == nil { - return fmt.Errorf("json_contains_all operation element must be an array") + return fmt.Errorf("contains_all operation element must be an array") + } + + if typeutil.IsArrayType(columnInfo.GetDataType()) { + for _, value := range elementValue.GetArrayVal().GetArray() { + valExpr := toValueExpr(value) + if !canBeCompared(field.(*ExprWithType), valExpr) { + return fmt.Errorf("contains_all operation can't compare between array element type: %s and %s", + columnInfo.GetElementType(), + valExpr.dataType) + } + } } expr := &planpb.Expr{ @@ -1186,9 +1248,10 @@ func (v *ParserVisitor) VisitJSONContainsAny(ctx *parser.JSONContainsAnyContext) } columnInfo := toColumnInfo(field.(*ExprWithType)) - if columnInfo == nil || !typeutil.IsJSONType(columnInfo.GetDataType()) { + if columnInfo == nil || + (!typeutil.IsJSONType(columnInfo.GetDataType()) && !typeutil.IsArrayType(columnInfo.GetDataType())) { return fmt.Errorf( - "json_contains_any operation are only supported on json fields now, got: %s", ctx.Expr(0).GetText()) + "contains_any operation are only supported on json or array fields now, got: %s", ctx.Expr(0).GetText()) } element := 
ctx.Expr(1).Accept(v) @@ -1198,11 +1261,22 @@ func (v *ParserVisitor) VisitJSONContainsAny(ctx *parser.JSONContainsAnyContext) elementValue := getGenericValue(element) if elementValue == nil { return fmt.Errorf( - "json_contains_any operation are only supported explicitly specified element, got: %s", ctx.Expr(1).GetText()) + "contains_any operation are only supported explicitly specified element, got: %s", ctx.Expr(1).GetText()) } if elementValue.GetArrayVal() == nil { - return fmt.Errorf("json_contains_any operation element must be an array") + return fmt.Errorf("contains_any operation element must be an array") + } + + if typeutil.IsArrayType(columnInfo.GetDataType()) { + for _, value := range elementValue.GetArrayVal().GetArray() { + valExpr := toValueExpr(value) + if !canBeCompared(field.(*ExprWithType), valExpr) { + return fmt.Errorf("contains_any operation can't compare between array element type: %s and %s", + columnInfo.GetElementType(), + valExpr.dataType) + } + } } expr := &planpb.Expr{ @@ -1220,3 +1294,36 @@ func (v *ParserVisitor) VisitJSONContainsAny(ctx *parser.JSONContainsAnyContext) dataType: schemapb.DataType_Bool, } } + +func (v *ParserVisitor) VisitArrayLength(ctx *parser.ArrayLengthContext) interface{} { + columnInfo, err := v.getChildColumnInfo(ctx.Identifier(), ctx.JSONIdentifier()) + if err != nil { + return err + } + if columnInfo == nil || + (!typeutil.IsJSONType(columnInfo.GetDataType()) && !typeutil.IsArrayType(columnInfo.GetDataType())) { + return fmt.Errorf( + "array_length operation are only supported on json or array fields now, got: %s", ctx.GetText()) + } + + expr := &planpb.Expr{ + Expr: &planpb.Expr_BinaryArithExpr{ + BinaryArithExpr: &planpb.BinaryArithExpr{ + Left: &planpb.Expr{ + Expr: &planpb.Expr_ColumnExpr{ + ColumnExpr: &planpb.ColumnExpr{ + Info: columnInfo, + }, + }, + }, + Right: nil, + Op: planpb.ArithOpType_ArrayLength, + }, + }, + } + return &ExprWithType{ + expr: expr, + dataType: schemapb.DataType_Int64, + 
nodeDependent: true, + } +} diff --git a/internal/parser/planparserv2/plan_parser_v2.go b/internal/parser/planparserv2/plan_parser_v2.go index a549afadaf014..bd7feb3b58050 100644 --- a/internal/parser/planparserv2/plan_parser_v2.go +++ b/internal/parser/planparserv2/plan_parser_v2.go @@ -3,9 +3,9 @@ package planparserv2 import ( "fmt" + "github.com/antlr/antlr4/runtime/Go/antlr" "go.uber.org/zap" - "github.com/antlr/antlr4/runtime/Go/antlr" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/planpb" "github.com/milvus-io/milvus/pkg/log" @@ -134,14 +134,21 @@ func CreateSearchPlan(schemaPb *schemapb.CollectionSchema, exprStr string, vecto fieldID := vectorField.FieldID dataType := vectorField.DataType + var vectorType planpb.VectorType if !typeutil.IsVectorType(dataType) { return nil, fmt.Errorf("field (%s) to search is not of vector data type", vectorFieldName) } - + if dataType == schemapb.DataType_FloatVector { + vectorType = planpb.VectorType_FloatVector + } else if dataType == schemapb.DataType_BinaryVector { + vectorType = planpb.VectorType_BinaryVector + } else { + vectorType = planpb.VectorType_Float16Vector + } planNode := &planpb.PlanNode{ Node: &planpb.PlanNode_VectorAnns{ VectorAnns: &planpb.VectorANNS{ - IsBinary: dataType == schemapb.DataType_BinaryVector, + VectorType: vectorType, Predicates: expr, QueryInfo: queryInfo, PlaceholderTag: "$0", diff --git a/internal/parser/planparserv2/plan_parser_v2_test.go b/internal/parser/planparserv2/plan_parser_v2_test.go index 2383bc848a391..fce00dc72b556 100644 --- a/internal/parser/planparserv2/plan_parser_v2_test.go +++ b/internal/parser/planparserv2/plan_parser_v2_test.go @@ -4,11 +4,12 @@ import ( "sync" "testing" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/planpb" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/typeutil" - 
"github.com/stretchr/testify/assert" ) func newTestSchema() *schemapb.CollectionSchema { @@ -21,12 +22,21 @@ func newTestSchema() *schemapb.CollectionSchema { newField := &schemapb.FieldSchema{ FieldID: int64(100 + value), Name: name + "Field", IsPrimaryKey: false, Description: "", DataType: dataType, } + if dataType == schemapb.DataType_Array { + newField.ElementType = schemapb.DataType_Int64 + } fields = append(fields, newField) } fields = append(fields, &schemapb.FieldSchema{ - FieldID: 199, Name: common.MetaFieldName, IsPrimaryKey: false, Description: "dynamic field", DataType: schemapb.DataType_JSON, + FieldID: 130, Name: common.MetaFieldName, IsPrimaryKey: false, Description: "dynamic field", DataType: schemapb.DataType_JSON, IsDynamic: true, }) + fields = append(fields, &schemapb.FieldSchema{ + FieldID: 131, Name: "StringArrayField", IsPrimaryKey: false, Description: "string array field", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_VarChar, + IsDynamic: true, + }) return &schemapb.CollectionSchema{ Name: "test", @@ -79,6 +89,7 @@ func TestExpr_Term(t *testing.T) { `A in []`, `A in ["abc", "def"]`, `A in ["1", "2", "abc", "def"]`, + `A in ["1", 2, "abc", 2.2]`, } for _, exprStr := range exprStrs { assertValidExpr(t, helper, exprStr) @@ -185,6 +196,16 @@ func TestExpr_BinaryRange(t *testing.T) { for _, exprStr := range exprStrs { assertValidExpr(t, helper, exprStr) } + + invalidExprs := []string{ + `1 < JSONField < 3`, + `1 < ArrayField < 3`, + `1 < A+B < 3`, + } + + for _, exprStr := range invalidExprs { + assertInvalidExpr(t, helper, exprStr) + } } func TestExpr_BinaryArith(t *testing.T) { @@ -195,6 +216,7 @@ func TestExpr_BinaryArith(t *testing.T) { exprStrs := []string{ `Int64Field % 10 == 9`, `Int64Field % 10 != 9`, + `Int64Field + 1.1 == 2.1`, `A % 10 != 2`, } for _, exprStr := range exprStrs { @@ -210,6 +232,10 @@ func TestExpr_BinaryArith(t *testing.T) { `FloatField + 11 < 12`, `DoubleField - 13 < 14`, `A - 15 < 16`, + 
`JSONField + 15 == 16`, + `15 + JSONField == 16`, + `ArrayField + 15 == 16`, + `15 + ArrayField == 16`, } for _, exprStr := range unsupported { assertInvalidExpr(t, helper, exprStr) @@ -407,6 +433,10 @@ func TestExpr_Invalid(t *testing.T) { `StringField % VarCharField`, `StringField * 2`, `2 / StringField`, + `JSONField / 2 == 1`, + `2 % JSONField == 1`, + `ArrayField / 2 == 1`, + `2 / ArrayField == 1`, // ----------------------- ==/!= ------------------------- //`not_in_schema != 1`, // maybe in json //`1 == not_in_schema`, // maybe in json @@ -421,6 +451,10 @@ func TestExpr_Invalid(t *testing.T) { `"str" >= false`, `VarCharField < FloatField`, `FloatField > VarCharField`, + `JSONField > 1`, + `1 < JSONField`, + `ArrayField > 2`, + `2 < ArrayField`, // ------------------------ like ------------------------ `(VarCharField % 2) like "prefix%"`, `FloatField like "prefix%"`, @@ -481,6 +515,16 @@ func TestExpr_Invalid(t *testing.T) { `Int64Field > 100 and BoolField`, `Int64Field < 100 or false`, // maybe this can be optimized. 
`!BoolField`, + // -------------------- array ---------------------- + //`A == [1, 2, 3]`, + `Int64Field == [1, 2, 3]`, + `Int64Field > [1, 2, 3]`, + `Int64Field + [1, 2, 3] == 10`, + `Int64Field % [1, 2, 3] == 10`, + `[1, 2, 3] < Int64Field < [4, 5, 6]`, + `Int64Field["A"] == 123`, + `[1,2,3] == [4,5,6]`, + `[1,2,3] == 1`, } for _, exprStr := range exprStrs { _, err := ParseExpr(helper, exprStr) @@ -575,1032 +619,270 @@ func Test_handleExpr_17126_26662(t *testing.T) { func Test_JSONExpr(t *testing.T) { schema := newTestSchema() expr := "" + var err error // search - expr = `$meta["A"] > 90` - _, err := CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - - expr = `JSONField["A"] > 90` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) + exprs := []string{ + `$meta["A"] > 90`, + `JSONField["A"] > 90`, + `A < 10`, + `JSONField["A"] <= 5`, + `$meta["A"] <= 5`, + `$meta["A"] >= 95`, + `$meta["A"] == 5`, + `$meta["A"] != 95`, + `$meta["A"] > 90 && $meta["B"] < 5`, + `$meta["A"] > 95 || $meta["B"] < 5`, + `A > 95 || $meta["B"] < 5`, + `not ($meta["A"] == 95)`, + `$meta["A"] in [90, 91, 95, 97]`, + `$meta["A"] not in [90, 91, 95, 97]`, + `$meta["C"]["0"] in [90, 91, 95, 97]`, + `$meta["C"]["0"] not in [90, 91, 95, 97]`, + `C["0"] not in [90, 91, 95, 97]`, + `C[0] in [90, 91, 95, 97]`, + `C["0"] > 90`, + `C["0"] < 90`, + `C["0"] == 90`, + `10 < C["0"] < 90`, + `100 > C["0"] > 90`, + `0 <= $meta["A"] < 5`, + `0 <= A < 5`, + `$meta["A"] + 5 == 10`, + `$meta["A"] > 10 + 5`, + `100 - 5 < $meta["A"]`, + `100 == $meta["A"] + 6`, + `exists $meta["A"]`, + `exists $meta["A"]["B"]["C"] `, + `A["B"][0] > 100`, + `$meta[0] > 100`, + `A["\"\"B\"\""] > 10`, + `A["[\"B\"]"] == "abc\"bbb\"cc"`, + `A['B'] == "abc\"bbb\"cc"`, + `A['B'] == 
'abc"cba'`, + `A['B'] == 'abc\"cba'`, + `A == [1,2,3]`, + `A + 1.2 == 3.3`, + `A + 1 == 2`, + } + for _, expr = range exprs { + _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ + Topk: 0, + MetricType: "", + SearchParams: "", + RoundDecimal: 0, + }) + assert.NoError(t, err) + } +} - expr = `A < 10` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) +func Test_InvalidExprOnJSONField(t *testing.T) { + schema := newTestSchema() + expr := "" + var err error + // search + exprs := []string{ + `exists $meta`, + `exists JSONField`, + `exists ArrayField`, + `$meta > 0`, + `JSONField == 0`, + `$meta < 100`, + `0 < $meta < 100`, + `20 > $meta > 0`, + `$meta + 5 > 0`, + `$meta > 2 + 5`, + `exists $meta["A"] > 10 `, + `exists Int64Field `, + `A[[""B""]] > 10`, + `A["[""B""]"] > 10`, + `A[[""B""]] > 10`, + `A[B] > 10`, + `A + B == 3.3`, + } - expr = `JSONField["A"] <= 5` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) + for _, expr = range exprs { + _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ + Topk: 0, + MetricType: "", + SearchParams: "", + RoundDecimal: 0, + }) + assert.Error(t, err, expr) + } +} - expr = `$meta["A"] <= 5` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) +func Test_InvalidExprWithoutJSONField(t *testing.T) { + fields := []*schemapb.FieldSchema{ + {FieldID: 100, Name: "id", IsPrimaryKey: true, Description: "id", DataType: schemapb.DataType_Int64}, + {FieldID: 101, Name: "vector", IsPrimaryKey: false, Description: "vector", DataType: schemapb.DataType_FloatVector}, + } - expr = `$meta["A"] >= 95` - _, err = 
CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) + schema := &schemapb.CollectionSchema{ + Name: "test", + Description: "schema for test used", + AutoID: true, + Fields: fields, + } + expr := "" + var err error - expr = `$meta["A"] == 5` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) + exprs := []string{ + `A == 0`, + `JSON["A"] > 0`, + `A < 100`, + `0 < JSON["A"] < 100`, + `0 < A < 100`, + `100 > JSON["A"] > 0`, + `100 > A > 0`, + } - expr = `$meta["A"] != 95` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) + for _, expr = range exprs { + _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ + Topk: 0, + MetricType: "", + SearchParams: "", + RoundDecimal: 0, + }) + assert.Error(t, err) + } +} - expr = `$meta["A"] > 90 && $meta["B"] < 5` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) +func Test_InvalidExprWithMultipleJSONField(t *testing.T) { + fields := []*schemapb.FieldSchema{ + {FieldID: 100, Name: "id", IsPrimaryKey: true, Description: "id", DataType: schemapb.DataType_Int64}, + {FieldID: 101, Name: "vector", IsPrimaryKey: false, Description: "vector", DataType: schemapb.DataType_FloatVector}, + {FieldID: 102, Name: "json1", IsPrimaryKey: false, Description: "json field 1", DataType: schemapb.DataType_JSON}, + {FieldID: 103, Name: "json2", IsPrimaryKey: false, Description: "json field 2", DataType: schemapb.DataType_JSON}, + } - expr = `$meta["A"] > 95 || $meta["B"] < 5` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", 
&planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) + schema := &schemapb.CollectionSchema{ + Name: "test", + Description: "schema for test used", + AutoID: true, + Fields: fields, + } - expr = `A > 95 || $meta["B"] < 5` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) + expr := "" + var err error + exprs := []string{ + `A == 0`, + `A in [1, 2, 3]`, + `A not in [1, 2, 3]`, + `"1" in A`, + `"1" not in A`, + } - expr = `not ($meta["A"] == 95)` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) + for _, expr = range exprs { + _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ + Topk: 0, + MetricType: "", + SearchParams: "", + RoundDecimal: 0, + }) + assert.Error(t, err) + } +} - expr = `$meta["A"] in [90, 91, 95, 97]` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) +func Test_exprWithSingleQuotes(t *testing.T) { + schema := newTestSchema() + expr := "" + var err error + exprs := []string{ + `'abc' < StringField < "def"`, + `'ab"c' < StringField < "d'ef"`, + `'ab\"c' < StringField < "d\'ef"`, + `'ab\'c' < StringField < "d\"ef"`, + } + for _, expr = range exprs { + _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ + Topk: 0, + MetricType: "", + SearchParams: "", + RoundDecimal: 0, + }) + assert.NoError(t, err) + } - expr = `$meta["A"] not in [90, 91, 95, 97]` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) + invalidExprs := []string{ + 
`'abc'd' < StringField < "def"`, + `'abc' < StringField < "def"g"`, + } - expr = `$meta["C"]["0"] in [90, 91, 95, 97]` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) + for _, expr = range invalidExprs { + _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ + Topk: 0, + MetricType: "", + SearchParams: "", + RoundDecimal: 0, + }) + assert.Error(t, err) + } +} - expr = `$meta["C"]["0"] not in [90, 91, 95, 97]` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) +func Test_JSONContains(t *testing.T) { + schema := newTestSchema() + expr := "" + var err error + exprs := []string{ + `json_contains(A, 10)`, + `not json_contains(A, 10)`, + `json_contains(A, 10.5)`, + `not json_contains(A, 10.5)`, + `json_contains(A, "10")`, + `not json_contains(A, "10")`, + `json_contains($meta["A"], 10)`, + `not json_contains($meta["A"], 10)`, + `json_contains(JSONField["x"], 5)`, + `not json_contains(JSONField["x"], 5)`, + `JSON_CONTAINS(JSONField["x"], 5)`, + `json_contains(A, [1,2,3])`, + `array_contains(A, [1,2,3])`, + `array_contains(ArrayField, [1,2,3])`, + `array_contains(ArrayField, 1)`, + } + for _, expr = range exprs { + _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ + Topk: 0, + MetricType: "", + SearchParams: "", + RoundDecimal: 0, + }) + assert.NoError(t, err) + } +} - expr = `C["0"] not in [90, 91, 95, 97]` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - - expr = `C[0] in [90, 91, 95, 97]` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 
0, - }) - assert.NoError(t, err) - - expr = `C["0"] > 90` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - - expr = `C["0"] < 90` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - - expr = `C["0"] == 90` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - - expr = `10 < C["0"] < 90` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - - expr = `100 > C["0"] > 90` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - - expr = `0 <= $meta["A"] < 5` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - - expr = `0 <= A < 5` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - - expr = `$meta["A"] + 5 == 10` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - - expr = `$meta["A"] > 10 + 5` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - - expr = `100 - 5 < $meta["A"]` - _, err = CreateSearchPlan(schema, expr, 
"FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - - expr = `100 == $meta["A"] + 6` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - - expr = `exists $meta["A"]` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - - expr = `exists $meta["A"]["B"]["C"] ` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - - expr = `A["B"][0] > 100` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - - expr = `A["B"][0] > 100` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - - expr = `$meta[0] > 100` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - - expr = `A["\"\"B\"\""] > 10` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - - expr = `A["[\"B\"]"] == "abc\"bbb\"cc"` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - - expr = `A['B'] == "abc\"bbb\"cc"` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - 
SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - - expr = `A['B'] == 'abc"cba'` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - - expr = `A['B'] == 'abc\"cba'` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) -} - -func Test_InvalidExprOnJSONField(t *testing.T) { - schema := newTestSchema() - expr := "" - var err error - expr = `exists $meta` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `exists JSONField` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `$meta > 0` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `$meta > 0` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `JSONField == 0` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `$meta < 100` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `0 < $meta < 100` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - 
}) - assert.Error(t, err) - - expr = `20 > $meta > 0` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `$meta + 5 > 0` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `$meta > 2 + 5` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `exists $meta["A"] > 10 ` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `exists Int64Field ` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `A[[""B""]] > 10` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `A["[""B""]"] > 10` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `A[[""B""]] > 10` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `A[B] > 10` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) -} - -func Test_InvalidExprWithoutJSONField(t *testing.T) { - fields := []*schemapb.FieldSchema{ - {FieldID: 100, 
Name: "id", IsPrimaryKey: true, Description: "id", DataType: schemapb.DataType_Int64}, - {FieldID: 101, Name: "vector", IsPrimaryKey: false, Description: "vector", DataType: schemapb.DataType_FloatVector}, - } - - schema := &schemapb.CollectionSchema{ - Name: "test", - Description: "schema for test used", - AutoID: true, - Fields: fields, - } - - expr := "" - var err error - expr = `A == 0` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `JSON["A"] > 0` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `A < 100` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `0 < JSON["A"] < 100` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `0 < A < 100` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `100 > JSON["A"] > 0` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `100 > A > 0` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) -} - -func Test_InvalidExprWithMultipleJSONField(t *testing.T) { - fields := []*schemapb.FieldSchema{ - {FieldID: 100, Name: "id", IsPrimaryKey: true, Description: "id", DataType: 
schemapb.DataType_Int64}, - {FieldID: 101, Name: "vector", IsPrimaryKey: false, Description: "vector", DataType: schemapb.DataType_FloatVector}, - {FieldID: 102, Name: "json1", IsPrimaryKey: false, Description: "json field 1", DataType: schemapb.DataType_JSON}, - {FieldID: 103, Name: "json2", IsPrimaryKey: false, Description: "json field 2", DataType: schemapb.DataType_JSON}, - } - - schema := &schemapb.CollectionSchema{ - Name: "test", - Description: "schema for test used", - AutoID: true, - Fields: fields, - } - - expr := "" - var err error - expr = `A == 0` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `A in [1, 2, 3]` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `A not in [1, 2, 3]` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `"1" in A` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `"1" not in A` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) -} - -func Test_exprWithSingleQuotes(t *testing.T) { - schema := newTestSchema() - expr := "" - var err error - expr = `'abc' < StringField < "def"` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - - expr = `'ab"c' < StringField < "d'ef"` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", 
&planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - - expr = `'ab\"c' < StringField < "d\'ef"` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - - expr = `'ab\'c' < StringField < "d\"ef"` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - - expr = `'abc'd' < StringField < "def"` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `'abc' < StringField < "def"g"` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) -} - -func Test_JSONContains(t *testing.T) { - schema := newTestSchema() - expr := "" - var err error - expr = `json_contains(A, 10)` - plan, err := CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - assert.NotNil(t, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr()) - - expr = `not json_contains(A, 10)` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - assert.NotNil(t, plan.GetVectorAnns().GetPredicates().GetUnaryExpr()) - - expr = `json_contains(A, 10.5)` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - assert.NotNil(t, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr()) - - expr 
= `not json_contains(A, 10.5)` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - assert.NotNil(t, plan.GetVectorAnns().GetPredicates().GetUnaryExpr()) - - expr = `json_contains(A, "10")` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - assert.NotNil(t, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr()) - - expr = `not json_contains(A, "10")` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - assert.NotNil(t, plan.GetVectorAnns().GetPredicates().GetUnaryExpr()) - - expr = `json_contains($meta["A"], 10)` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - assert.NotNil(t, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr()) - - expr = `not json_contains($meta["A"], 10)` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - assert.NotNil(t, plan.GetVectorAnns().GetPredicates().GetUnaryExpr()) - - expr = `json_contains(JSONField["x"], 5)` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - assert.NotNil(t, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr()) - - expr = `not json_contains(JSONField["x"], 5)` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - 
assert.NoError(t, err) - assert.NotNil(t, plan.GetVectorAnns().GetPredicates().GetUnaryExpr()) - - expr = `JSON_CONTAINS(JSONField["x"], 5)` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - assert.NotNil(t, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr()) - - expr = `json_contains(A, [1,2,3])` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - assert.NotNil(t, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr()) -} - -func Test_InvalidJSONContains(t *testing.T) { - schema := newTestSchema() - expr := "" - var err error - expr = `json_contains(10, A)` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `json_contains(1, [1,2,3])` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `json_contains([1,2,3], 1)` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `json_contains([1,2,3], [1,2,3])` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `json_contains([1,2,3], [1,2])` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `json_contains($meta, 1)` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", 
&planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `json_contains(A, B)` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `not json_contains(A, B)` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `json_contains(A, B > 5)` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `json_contains(StringField, "a")` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `json_contains(A, StringField > 5)` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `json_contains(A)` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `json_contains(A, 5, C)` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `json_contains(JSONField, 5)` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `json_Contains(JSONField, 5)` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - 
SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - - expr = `JSON_contains(JSONField, 5)` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) -} - -func Test_EscapeString(t *testing.T) { - schema := newTestSchema() - expr := "" - var err error - exprs := []string{ - `A == "\"" || B == '\"'`, - `A == "\n" || B == '\n'`, - `A == "\367" || B == '\367'`, - `A == "\3678" || B == '\3678'`, - `A == "ab'c\'d" || B == 'abc"de\"'`, - `A == "'" || B == '"'`, - `A == "\'" || B == '\"' || C == '\''`, - `A == "\\'" || B == '\\"' || C == '\''`, - `A == "\\\'" || B == '\\\"' || C == '\\\''`, - `A == "\\\\'" || B == '\\\\"' || C == '\\\''`, - `A == "\\\\\'" || B == '\\\\\"' || C == '\\\\\''`, - `A == "\\\\\\'" || B == '\\\\\\"' || C == '\\\\\''`, - `str2 like 'abc\"def-%'`, - `str2 like 'abc"def-%'`, - `str4 like "abc\367-%"`, - `str4 like "中国"`, - } - for _, expr = range exprs { - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - } - - invalidExprs := []string{ - `A == "ab -c" || B == 'ab -c'`, - `A == "\423" || B == '\378'`, - `A == "\中国"`, - } - for _, expr = range invalidExprs { - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - } -} +func Test_InvalidJSONContains(t *testing.T) { + schema := newTestSchema() + expr := "" + var err error + exprs := []string{ + `json_contains(10, A)`, + `json_contains(1, [1,2,3])`, + `json_contains([1,2,3], 1)`, + `json_contains([1,2,3], [1,2,3])`, + `json_contains([1,2,3], [1,2])`, + `json_contains($meta, 1)`, + `json_contains(A, B)`, + `not json_contains(A, B)`, + `json_contains(A, B > 5)`, + `json_contains(StringField, "a")`, + `json_contains(A, 
StringField > 5)`, + `json_contains(A)`, + `json_contains(A, 5, C)`, + `json_contains(JSONField, 5)`, + `json_Contains(JSONField, 5)`, + `JSON_contains(JSONField, 5)`, + } + for _, expr = range exprs { + _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ + Topk: 0, + MetricType: "", + SearchParams: "", + RoundDecimal: 0, + }) + assert.Error(t, err) + } +} func Test_isEmptyExpression(t *testing.T) { type args struct { @@ -1610,354 +892,298 @@ func Test_isEmptyExpression(t *testing.T) { name string args args want bool - }{ - { - args: args{s: ""}, - want: true, - }, - { - args: args{s: " "}, - want: true, - }, - { - args: args{s: "not empty"}, - want: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - assert.Equalf(t, tt.want, isEmptyExpression(tt.args.s), "isEmptyExpression(%v)", tt.args.s) - }) - } -} - -func Test_JSONContainsAll(t *testing.T) { - schema := newTestSchema() - expr := "" - var err error - var plan *planpb.PlanNode - - expr = `json_contains_all(A, [1,2,3])` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - assert.NotNil(t, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr()) - assert.Equal(t, planpb.JSONContainsExpr_ContainsAll, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr().GetOp()) - assert.True(t, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr().GetElementsSameType()) - - expr = `json_contains_all(A, [1,"2",3.0])` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - assert.NotNil(t, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr()) - assert.Equal(t, planpb.JSONContainsExpr_ContainsAll, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr().GetOp()) - assert.False(t, 
plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr().GetElementsSameType()) - - expr = `JSON_CONTAINS_ALL(A, [1,"2",3.0])` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - assert.NotNil(t, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr()) - assert.Equal(t, planpb.JSONContainsExpr_ContainsAll, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr().GetOp()) - assert.False(t, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr().GetElementsSameType()) -} - -func Test_InvalidJSONContainsAll(t *testing.T) { - schema := newTestSchema() - expr := "" - var err error - var plan *planpb.PlanNode - - expr = `JSON_CONTAINS_ALL(A, 1)` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - assert.Nil(t, plan) - - expr = `JSON_CONTAINS_ALL(A, [abc])` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - assert.Nil(t, plan) - - expr = `JSON_CONTAINS_ALL(A, [2>a])` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - assert.Nil(t, plan) - - expr = `JSON_CONTAINS_ALL(A, [2>>a])` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - assert.Nil(t, plan) - - expr = `JSON_CONTAINS_ALL(A[""], [1,2,3])` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - assert.Nil(t, plan) - - expr = 
`JSON_CONTAINS_ALL(Int64Field, [1,2,3])` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - assert.Nil(t, plan) - - expr = `JSON_CONTAINS_ALL(A, B)` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - assert.Nil(t, plan) + }{ + { + args: args{s: ""}, + want: true, + }, + { + args: args{s: " "}, + want: true, + }, + { + args: args{s: "not empty"}, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, isEmptyExpression(tt.args.s), "isEmptyExpression(%v)", tt.args.s) + }) + } } -func Test_JSONContainsAny(t *testing.T) { +func Test_EscapeString(t *testing.T) { schema := newTestSchema() expr := "" var err error - var plan *planpb.PlanNode - - expr = `json_contains_any(A, [1,2,3])` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - assert.NotNil(t, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr()) - assert.Equal(t, planpb.JSONContainsExpr_ContainsAny, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr().GetOp()) - assert.True(t, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr().GetElementsSameType()) - - expr = `json_contains_any(A, [1,"2",3.0])` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - assert.NotNil(t, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr()) - assert.Equal(t, planpb.JSONContainsExpr_ContainsAny, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr().GetOp()) - assert.False(t, 
plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr().GetElementsSameType()) - - expr = `JSON_CONTAINS_ANY(A, [1,"2",3.0])` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.NoError(t, err) - assert.NotNil(t, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr()) - assert.Equal(t, planpb.JSONContainsExpr_ContainsAny, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr().GetOp()) - assert.False(t, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr().GetElementsSameType()) + exprs := []string{ + `A == "\"" || B == '\"'`, + `A == "\n" || B == '\n'`, + `A == "\367" || B == '\367'`, + `A == "\3678" || B == '\3678'`, + `A == "ab'c\'d" || B == 'abc"de\"'`, + `A == "'" || B == '"'`, + `A == "\'" || B == '\"' || C == '\''`, + `A == "\\'" || B == '\\"' || C == '\''`, + `A == "\\\'" || B == '\\\"' || C == '\\\''`, + `A == "\\\\'" || B == '\\\\"' || C == '\\\''`, + `A == "\\\\\'" || B == '\\\\\"' || C == '\\\\\''`, + `A == "\\\\\\'" || B == '\\\\\\"' || C == '\\\\\''`, + `str2 like 'abc\"def-%'`, + `str2 like 'abc"def-%'`, + `str4 like "abc\367-%"`, + `str4 like "中国"`, + } + for _, expr = range exprs { + _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ + Topk: 0, + MetricType: "", + SearchParams: "", + RoundDecimal: 0, + }) + assert.NoError(t, err) + } - expr = `JSON_CONTAINS_ANY(A, 1)` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) + invalidExprs := []string{ + `A == "ab +c" || B == 'ab +c'`, + `A == "\423" || B == '\378'`, + `A == "\中国"`, + } + for _, expr = range invalidExprs { + _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ + Topk: 0, + MetricType: "", + SearchParams: "", + RoundDecimal: 0, + }) + assert.Error(t, err) + } } -func 
Test_InvalidJSONContainsAny(t *testing.T) { +func Test_JSONContainsAll(t *testing.T) { schema := newTestSchema() expr := "" var err error var plan *planpb.PlanNode - expr = `JSON_CONTAINS_ANY(A, 1)` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - assert.Nil(t, plan) - - expr = `JSON_CONTAINS_ANY(A, [abc])` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - assert.Nil(t, plan) - - expr = `JSON_CONTAINS_ANY(A, [2>a])` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - assert.Nil(t, plan) - - expr = `JSON_CONTAINS_ANY(A, [2>>a])` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - assert.Nil(t, plan) - - expr = `JSON_CONTAINS_ANY(A[""], [1,2,3])` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - assert.Nil(t, plan) - - expr = `JSON_CONTAINS_ANY(Int64Field, [1,2,3])` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - assert.Nil(t, plan) + exprs := []string{ + `json_contains_all(A, [1,2,3])`, + `json_contains_all(A, [1,"2",3.0])`, + `JSON_CONTAINS_ALL(A, [1,"2",3.0])`, + `array_contains_all(ArrayField, [1,2,3])`, + `array_contains_all(ArrayField, [1])`, + `json_contains_all(ArrayField, [1,2,3])`, + } + for _, expr = range exprs { + plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", 
&planpb.QueryInfo{ + Topk: 0, + MetricType: "", + SearchParams: "", + RoundDecimal: 0, + }) + assert.NoError(t, err) + assert.NotNil(t, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr()) + assert.Equal(t, planpb.JSONContainsExpr_ContainsAll, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr().GetOp()) + } - expr = `JSON_CONTAINS_ANY(A, B)` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - assert.Nil(t, plan) + invalidExprs := []string{ + `JSON_CONTAINS_ALL(A, 1)`, + `JSON_CONTAINS_ALL(A, [abc])`, + `JSON_CONTAINS_ALL(A, [2>a])`, + `JSON_CONTAINS_ALL(A, [2>>a])`, + `JSON_CONTAINS_ALL(A[""], [1,2,3])`, + `JSON_CONTAINS_ALL(Int64Field, [1,2,3])`, + `JSON_CONTAINS_ALL(A, B)`, + } + for _, expr = range invalidExprs { + _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ + Topk: 0, + MetricType: "", + SearchParams: "", + RoundDecimal: 0, + }) + assert.Error(t, err) + } } -func Test_UnsupportedExpr(t *testing.T) { +func Test_JSONContainsAny(t *testing.T) { schema := newTestSchema() expr := "" var err error var plan *planpb.PlanNode - expr = `A == [1, 2, 3]` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - assert.Nil(t, plan) - - expr = `Int64Field == [1, 2, 3]` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - assert.Nil(t, plan) + exprs := []string{ + `json_contains_any(A, [1,2,3])`, + `json_contains_any(A, [1,"2",3.0])`, + `JSON_CONTAINS_ANY(A, [1,"2",3.0])`, + `JSON_CONTAINS_ANY(ArrayField, [1,2,3])`, + `JSON_CONTAINS_ANY(ArrayField, [3,4,5])`, + `JSON_CONTAINS_ANY(ArrayField, [1,2,3])`, + } + for _, expr = range exprs { + plan, err = 
CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ + Topk: 0, + MetricType: "", + SearchParams: "", + RoundDecimal: 0, + }) + assert.NoError(t, err) + assert.NotNil(t, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr()) + assert.Equal(t, planpb.JSONContainsExpr_ContainsAny, plan.GetVectorAnns().GetPredicates().GetJsonContainsExpr().GetOp()) + } - expr = `Int64Field > [1, 2, 3]` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - assert.Nil(t, plan) + invalidExprs := []string{ + `JSON_CONTAINS_ANY(A, 1)`, + `JSON_CONTAINS_ANY(A, [abc])`, + `JSON_CONTAINS_ANY(A, [2>a])`, + `JSON_CONTAINS_ANY(A, [2>>a])`, + `JSON_CONTAINS_ANY(A[""], [1,2,3])`, + `JSON_CONTAINS_ANY(Int64Field, [1,2,3])`, + `JSON_CONTAINS_ANY(A, B)`, + } + for _, expr = range invalidExprs { + _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ + Topk: 0, + MetricType: "", + SearchParams: "", + RoundDecimal: 0, + }) + assert.Error(t, err) + } +} - expr = `Int64Field + [1, 2, 3] == 10` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - assert.Nil(t, plan) +func Test_ArrayExpr(t *testing.T) { + schema := newTestSchema() + expr := "" + var err error - expr = `Int64Field % [1, 2, 3] == 10` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - assert.Nil(t, plan) + exprs := []string{ + `ArrayField == [1,2,3,4]`, + `ArrayField[0] == 1`, + `ArrayField[0] > 1`, + `1 < ArrayField[0] < 3`, + `StringArrayField[0] == "abc"`, + `StringArrayField[0] < "abc"`, + `"abc" < StringArrayField[0] < "efg"`, + `array_contains(ArrayField, 1)`, + `not ARRAY_CONTAINS(ArrayField, 1)`, + 
`array_contains_all(ArrayField, [1,2,3,4])`, + `not ARRAY_CONTAINS_ALL(ArrayField, [1,2,3,4])`, + `array_contains_any(ArrayField, [1,2,3,4])`, + `not ARRAY_CONTAINS_ANY(ArrayField, [1,2,3,4])`, + `array_contains(StringArrayField, "abc")`, + `not ARRAY_CONTAINS(StringArrayField, "abc")`, + `array_contains_all(StringArrayField, ["a", "b", "c", "d"])`, + `not ARRAY_CONTAINS_ALL(StringArrayField, ["a", "b", "c", "d"])`, + `array_contains_any(StringArrayField, ["a", "b", "c", "d"])`, + `not ARRAY_CONTAINS_ANY(StringArrayField, ["a", "b", "c", "d"])`, + `StringArrayField[0] like "abd%"`, + `+ArrayField[0] == 1`, + `ArrayField[0] % 3 == 1`, + `ArrayField[0] + 3 == 1`, + `ArrayField[0] in [1,2,3]`, + `ArrayField[0] in []`, + `0 < ArrayField[0] < 100`, + `100 > ArrayField[0] > 0`, + `ArrayField[0] > 1`, + `ArrayField[0] == 1`, + } + for _, expr = range exprs { + _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ + Topk: 0, + MetricType: "", + SearchParams: "", + RoundDecimal: 0, + }) + assert.NoError(t, err, expr) + } - expr = `[1, 2, 3] < Int64Field < [4, 5, 6]` - plan, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) - assert.Nil(t, plan) + invalidExprs := []string{ + `ArrayField == ["abc", "def"]`, + `"abc" < ArrayField[0] < "def"`, + `ArrayField[0] == "def"`, + `ArrayField["0"] == 1`, + `array_contains(ArrayField, "a")`, + `array_contains(StringArrayField, 1)`, + `array_contains_all(StringArrayField, ["abc", 123])`, + `array_contains_any(StringArrayField, ["abc", 123])`, + `StringArrayField like "abd%"`, + `+ArrayField == 1`, + `ArrayField % 3 == 1`, + `ArrayField + 3 == 1`, + `ArrayField in [1,2,3]`, + `ArrayField[0] in [1, "abc",3.3]`, + `ArrayField in []`, + `0 < ArrayField < 100`, + `100 > ArrayField > 0`, + `ArrayField > 1`, + `ArrayField == 1`, + `ArrayField[] == 1`, + `A[] == 1`, + `ArrayField[0] + ArrayField[1] == 
1`, + `ArrayField == []`, + } + for _, expr = range invalidExprs { + _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ + Topk: 0, + MetricType: "", + SearchParams: "", + RoundDecimal: 0, + }) + assert.Error(t, err, expr) + } } -func Test_InvalidAccess(t *testing.T) { +func Test_ArrayLength(t *testing.T) { schema := newTestSchema() expr := "" var err error - expr = `Int64Field["A"] == 123` - _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }) - assert.Error(t, err) + exprs := []string{ + `array_length(ArrayField) == 10`, + `array_length(A) != 10`, + `array_length(StringArrayField) == 1`, + `array_length(B) != 1`, + `not (array_length(C[0]) == 1)`, + `not (array_length(C["D"]) != 1)`, + } + for _, expr = range exprs { + _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ + Topk: 0, + MetricType: "", + SearchParams: "", + RoundDecimal: 0, + }) + assert.NoError(t, err, expr) + } + + invalidExprs := []string{ + `array_length(a > b) == 0`, + `array_length(a, b) == 1`, + `array_length(A)`, + `array_length("A") / 10 == 2`, + `array_length(Int64Field) == 2`, + `array_length(a-b) == 2`, + `0 < array_length(a-b) < 2`, + `0 < array_length(StringArrayField) < 1`, + `100 > array_length(ArrayField) > 10`, + `array_length(StringArrayField) < 1`, + `array_length(A) % 10 == 2`, + `array_length(A) / 10 == 2`, + `array_length(A) + 1 == 2`, + `array_length(JSONField) + 1 == 2`, + `array_length(A) == 2.2`, + } + for _, expr = range invalidExprs { + _, err = CreateSearchPlan(schema, expr, "FloatVectorField", &planpb.QueryInfo{ + Topk: 0, + MetricType: "", + SearchParams: "", + RoundDecimal: 0, + }) + assert.Error(t, err, expr) + } } diff --git a/internal/parser/planparserv2/pool.go b/internal/parser/planparserv2/pool.go index 005884b60ef21..73433572eea91 100644 --- a/internal/parser/planparserv2/pool.go +++ 
b/internal/parser/planparserv2/pool.go @@ -4,6 +4,7 @@ import ( "sync" "github.com/antlr/antlr4/runtime/Go/antlr" + antlrparser "github.com/milvus-io/milvus/internal/parser/planparserv2/generated" ) diff --git a/internal/parser/planparserv2/pool_test.go b/internal/parser/planparserv2/pool_test.go index 37b64e12cb4a6..bea97fb9b637f 100644 --- a/internal/parser/planparserv2/pool_test.go +++ b/internal/parser/planparserv2/pool_test.go @@ -4,8 +4,9 @@ import ( "testing" "github.com/antlr/antlr4/runtime/Go/antlr" - antlrparser "github.com/milvus-io/milvus/internal/parser/planparserv2/generated" "github.com/stretchr/testify/assert" + + antlrparser "github.com/milvus-io/milvus/internal/parser/planparserv2/generated" ) func genNaiveInputStream() *antlr.InputStream { diff --git a/internal/parser/planparserv2/show_visitor.go b/internal/parser/planparserv2/show_visitor.go index 6930e8fe76d3f..b9b263b6e0631 100644 --- a/internal/parser/planparserv2/show_visitor.go +++ b/internal/parser/planparserv2/show_visitor.go @@ -9,8 +9,7 @@ import ( "github.com/milvus-io/milvus/pkg/log" ) -type ShowExprVisitor struct { -} +type ShowExprVisitor struct{} func extractColumnInfo(info *planpb.ColumnInfo) interface{} { js := make(map[string]interface{}) diff --git a/internal/parser/planparserv2/utils.go b/internal/parser/planparserv2/utils.go index 7a03773d45edf..d6c2172e7cf36 100644 --- a/internal/parser/planparserv2/utils.go +++ b/internal/parser/planparserv2/utils.go @@ -5,10 +5,9 @@ import ( "strconv" "strings" - "github.com/milvus-io/milvus/pkg/util/typeutil" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/planpb" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) func IsBool(n *planpb.GenericValue) bool { @@ -47,6 +46,14 @@ func IsString(n *planpb.GenericValue) bool { return false } +func IsArray(n *planpb.GenericValue) bool { + switch n.GetVal().(type) { + case *planpb.GenericValue_ArrayVal: + return true + } + return false +} + func 
NewBool(value bool) *planpb.GenericValue { return &planpb.GenericValue{ Val: &planpb.GenericValue_BoolVal{ @@ -119,33 +126,47 @@ func toValueExpr(n *planpb.GenericValue) *ExprWithType { } } -func getSameType(a, b schemapb.DataType) (schemapb.DataType, error) { - if hasJSONField(a, b) { - return schemapb.DataType_JSON, nil +func getSameType(left, right *ExprWithType) (schemapb.DataType, error) { + lDataType, rDataType := left.dataType, right.dataType + if typeutil.IsArrayType(lDataType) { + lDataType = toColumnInfo(left).GetElementType() } - if typeutil.IsFloatingType(a) && typeutil.IsArithmetic(b) { - return schemapb.DataType_Double, nil + if typeutil.IsArrayType(rDataType) { + rDataType = toColumnInfo(right).GetElementType() } - - if typeutil.IsIntegerType(a) && typeutil.IsIntegerType(b) { - return schemapb.DataType_Int64, nil + if typeutil.IsJSONType(lDataType) { + if typeutil.IsJSONType(rDataType) { + return schemapb.DataType_JSON, nil + } + if typeutil.IsFloatingType(rDataType) { + return schemapb.DataType_Double, nil + } + if typeutil.IsIntegerType(rDataType) { + return schemapb.DataType_Int64, nil + } } - - return schemapb.DataType_None, fmt.Errorf("incompatible data type, %s, %s", a.String(), b.String()) -} - -func hasJSONField(a, b schemapb.DataType) bool { - if typeutil.IsJSONType(a) || typeutil.IsJSONType(b) { - return true + if typeutil.IsFloatingType(lDataType) { + if typeutil.IsJSONType(rDataType) || typeutil.IsArithmetic(rDataType) { + return schemapb.DataType_Double, nil + } } - return false + if typeutil.IsIntegerType(lDataType) { + if typeutil.IsFloatingType(rDataType) { + return schemapb.DataType_Double, nil + } + if typeutil.IsIntegerType(rDataType) || typeutil.IsJSONType(rDataType) { + return schemapb.DataType_Int64, nil + } + } + + return schemapb.DataType_None, fmt.Errorf("incompatible data type, %s, %s", lDataType.String(), rDataType.String()) } func calcDataType(left, right *ExprWithType, reverse bool) (schemapb.DataType, error) { if reverse 
{ - return getSameType(right.dataType, left.dataType) + return getSameType(right, left) } - return getSameType(left.dataType, right.dataType) + return getSameType(left, right) } func reverseOrder(op planpb.OpType) (planpb.OpType, error) { @@ -175,6 +196,9 @@ func castValue(dataType schemapb.DataType, value *planpb.GenericValue) (*planpb. if typeutil.IsJSONType(dataType) { return value, nil } + if typeutil.IsArrayType(dataType) && IsArray(value) { + return value, nil + } if typeutil.IsStringType(dataType) && IsString(value) { return value, nil } @@ -192,17 +216,19 @@ func castValue(dataType schemapb.DataType, value *planpb.GenericValue) (*planpb. } } - if typeutil.IsIntegerType(dataType) { - if IsInteger(value) { - return value, nil - } + if typeutil.IsIntegerType(dataType) && IsInteger(value) { + return value, nil } return nil, fmt.Errorf("cannot cast value to %s, value: %s", dataType.String(), value) } func combineBinaryArithExpr(op planpb.OpType, arithOp planpb.ArithOpType, columnInfo *planpb.ColumnInfo, operand *planpb.GenericValue, value *planpb.GenericValue) *planpb.Expr { - castedValue, err := castValue(columnInfo.GetDataType(), operand) + dataType := columnInfo.GetDataType() + if typeutil.IsArrayType(dataType) && len(columnInfo.GetNestedPath()) != 0 { + dataType = columnInfo.GetElementType() + } + castedValue, err := castValue(dataType, operand) if err != nil { return nil } @@ -219,6 +245,19 @@ func combineBinaryArithExpr(op planpb.OpType, arithOp planpb.ArithOpType, column } } +func combineArrayLengthExpr(op planpb.OpType, arithOp planpb.ArithOpType, columnInfo *planpb.ColumnInfo, value *planpb.GenericValue) (*planpb.Expr, error) { + return &planpb.Expr{ + Expr: &planpb.Expr_BinaryArithOpEvalRangeExpr{ + BinaryArithOpEvalRangeExpr: &planpb.BinaryArithOpEvalRangeExpr{ + ColumnInfo: columnInfo, + ArithOp: arithOp, + Op: op, + Value: value, + }, + }, + }, nil +} + func handleBinaryArithExpr(op planpb.OpType, arithExpr *planpb.BinaryArithExpr, valueExpr 
*planpb.ValueExpr) (*planpb.Expr, error) { switch op { case planpb.OpType_Equal, planpb.OpType_NotEqual: @@ -230,6 +269,10 @@ func handleBinaryArithExpr(op planpb.OpType, arithExpr *planpb.BinaryArithExpr, leftExpr, leftValue := arithExpr.Left.GetColumnExpr(), arithExpr.Left.GetValueExpr() rightExpr, rightValue := arithExpr.Right.GetColumnExpr(), arithExpr.Right.GetValueExpr() + arithOp := arithExpr.GetOp() + if arithOp == planpb.ArithOpType_ArrayLength { + return combineArrayLengthExpr(op, arithOp, leftExpr.GetInfo(), valueExpr.GetValue()) + } if leftExpr != nil && rightExpr != nil { // a + b == 3 @@ -247,7 +290,7 @@ func handleBinaryArithExpr(op planpb.OpType, arithExpr *planpb.BinaryArithExpr, // a * 2 == 3 // a / 2 == 3 // a % 2 == 3 - return combineBinaryArithExpr(op, arithExpr.GetOp(), leftExpr.GetInfo(), rightValue.GetValue(), valueExpr.GetValue()), nil + return combineBinaryArithExpr(op, arithOp, leftExpr.GetInfo(), rightValue.GetValue(), valueExpr.GetValue()), nil } else if rightExpr != nil && leftValue != nil { // 2 + a == 3 // 2 - a == 3 @@ -257,9 +300,9 @@ func handleBinaryArithExpr(op planpb.OpType, arithExpr *planpb.BinaryArithExpr, switch arithExpr.GetOp() { case planpb.ArithOpType_Add, planpb.ArithOpType_Mul: - return combineBinaryArithExpr(op, arithExpr.GetOp(), rightExpr.GetInfo(), leftValue.GetValue(), valueExpr.GetValue()), nil + return combineBinaryArithExpr(op, arithOp, rightExpr.GetInfo(), leftValue.GetValue(), valueExpr.GetValue()), nil default: - return nil, fmt.Errorf("todo") + return nil, fmt.Errorf("module field is not yet supported") } } else { // (a + b) / 2 == 3 @@ -268,7 +311,11 @@ func handleBinaryArithExpr(op planpb.OpType, arithExpr *planpb.BinaryArithExpr, } func handleCompareRightValue(op planpb.OpType, left *ExprWithType, right *planpb.ValueExpr) (*planpb.Expr, error) { - castedValue, err := castValue(left.dataType, right.GetValue()) + dataType := left.dataType + if typeutil.IsArrayType(dataType) && 
len(toColumnInfo(left).GetNestedPath()) != 0 { + dataType = toColumnInfo(left).GetElementType() + } + castedValue, err := castValue(dataType, right.GetValue()) if err != nil { return nil, err } @@ -281,7 +328,6 @@ func handleCompareRightValue(op planpb.OpType, left *ExprWithType, right *planpb if columnInfo == nil { return nil, fmt.Errorf("not supported to combine multiple fields") } - expr := &planpb.Expr{ Expr: &planpb.Expr_UnaryRangeExpr{ UnaryRangeExpr: &planpb.UnaryRangeExpr{ @@ -332,9 +378,49 @@ func relationalCompatible(t1, t2 schemapb.DataType) bool { return both || neither } +func canBeComparedDataType(left, right schemapb.DataType) bool { + switch left { + case schemapb.DataType_Bool: + return typeutil.IsBoolType(right) || typeutil.IsJSONType(right) + case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32, schemapb.DataType_Int64, + schemapb.DataType_Float, schemapb.DataType_Double: + return typeutil.IsArithmetic(right) || typeutil.IsJSONType(right) + case schemapb.DataType_String, schemapb.DataType_VarChar: + return typeutil.IsStringType(right) || typeutil.IsJSONType(right) + case schemapb.DataType_JSON: + return true + default: + return false + } +} + +func getArrayElementType(expr *ExprWithType) schemapb.DataType { + if columnInfo := toColumnInfo(expr); columnInfo != nil { + return columnInfo.GetElementType() + } + if valueExpr := expr.expr.GetValueExpr(); valueExpr != nil { + return valueExpr.GetValue().GetArrayVal().GetElementType() + } + return schemapb.DataType_None +} + +func canBeCompared(left, right *ExprWithType) bool { + if !typeutil.IsArrayType(left.dataType) && !typeutil.IsArrayType(right.dataType) { + return canBeComparedDataType(left.dataType, right.dataType) + } + if typeutil.IsArrayType(left.dataType) && typeutil.IsArrayType(right.dataType) { + return canBeComparedDataType(getArrayElementType(left), getArrayElementType(right)) + } + if typeutil.IsArrayType(left.dataType) { + return 
canBeComparedDataType(getArrayElementType(left), right.dataType) + } + return canBeComparedDataType(left.dataType, getArrayElementType(right)) +} + func HandleCompare(op int, left, right *ExprWithType) (*planpb.Expr, error) { - if !relationalCompatible(left.dataType, right.dataType) { - return nil, fmt.Errorf("comparisons between string and non-string are not supported") + if !canBeCompared(left, right) { + return nil, fmt.Errorf("comparisons between %s, element_type: %s and %s elementType: %s are not supported", + left.dataType, getArrayElementType(left), right.dataType, getArrayElementType(right)) } cmpOp := cmpOpMap[op] @@ -424,3 +510,34 @@ func convertEscapeSingle(literal string) (string, error) { b.WriteString(`"`) return strconv.Unquote(b.String()) } + +func canArithmeticDataType(left, right schemapb.DataType) bool { + switch left { + case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32, schemapb.DataType_Int64, + schemapb.DataType_Float, schemapb.DataType_Double: + return typeutil.IsArithmetic(right) || typeutil.IsJSONType(right) + case schemapb.DataType_JSON: + return typeutil.IsArithmetic(right) + default: + return false + } +} + +func canArithmetic(left *ExprWithType, right *ExprWithType) bool { + if !typeutil.IsArrayType(left.dataType) && !typeutil.IsArrayType(right.dataType) { + return canArithmeticDataType(left.dataType, right.dataType) + } + if typeutil.IsArrayType(left.dataType) && typeutil.IsArrayType(right.dataType) { + return canArithmeticDataType(getArrayElementType(left), getArrayElementType(right)) + } + if typeutil.IsArrayType(left.dataType) { + return canArithmeticDataType(getArrayElementType(left), right.dataType) + } + return canArithmeticDataType(left.dataType, getArrayElementType(right)) +} + +func isIntegerColumn(col *planpb.ColumnInfo) bool { + return typeutil.IsIntegerType(col.GetDataType()) || + (typeutil.IsArrayType(col.GetDataType()) && typeutil.IsIntegerType(col.GetElementType())) || + 
typeutil.IsJSONType(col.GetDataType()) +} diff --git a/internal/parser/planparserv2/utils_test.go b/internal/parser/planparserv2/utils_test.go index 4c399401d8cd0..843bd2d34bba3 100644 --- a/internal/parser/planparserv2/utils_test.go +++ b/internal/parser/planparserv2/utils_test.go @@ -3,10 +3,10 @@ package planparserv2 import ( "testing" - "github.com/milvus-io/milvus/internal/proto/planpb" "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/proto/planpb" ) func Test_relationalCompatible(t *testing.T) { @@ -209,3 +209,115 @@ func Test_convertEscapeSingle(t *testing.T) { assert.Equal(t, c.expected, actual) } } + +func Test_canBeComparedDataType(t *testing.T) { + type testCases struct { + left schemapb.DataType + right schemapb.DataType + expected bool + } + + cases := []testCases{ + {schemapb.DataType_Bool, schemapb.DataType_Bool, true}, + {schemapb.DataType_Bool, schemapb.DataType_JSON, true}, + {schemapb.DataType_Bool, schemapb.DataType_Int8, false}, + {schemapb.DataType_Bool, schemapb.DataType_Int16, false}, + {schemapb.DataType_Bool, schemapb.DataType_Int32, false}, + {schemapb.DataType_Bool, schemapb.DataType_Int64, false}, + {schemapb.DataType_Bool, schemapb.DataType_Float, false}, + {schemapb.DataType_Bool, schemapb.DataType_Double, false}, + {schemapb.DataType_Bool, schemapb.DataType_String, false}, + {schemapb.DataType_Int8, schemapb.DataType_Int16, true}, + {schemapb.DataType_Int16, schemapb.DataType_Int32, true}, + {schemapb.DataType_Int32, schemapb.DataType_Int64, true}, + {schemapb.DataType_Int64, schemapb.DataType_Float, true}, + {schemapb.DataType_Float, schemapb.DataType_Double, true}, + {schemapb.DataType_Double, schemapb.DataType_Int32, true}, + {schemapb.DataType_Double, schemapb.DataType_String, false}, + {schemapb.DataType_Int64, schemapb.DataType_String, false}, + {schemapb.DataType_Int64, schemapb.DataType_JSON, true}, + {schemapb.DataType_Double, 
schemapb.DataType_JSON, true}, + {schemapb.DataType_String, schemapb.DataType_Double, false}, + {schemapb.DataType_String, schemapb.DataType_Int64, false}, + {schemapb.DataType_String, schemapb.DataType_JSON, true}, + {schemapb.DataType_String, schemapb.DataType_String, true}, + {schemapb.DataType_String, schemapb.DataType_VarChar, true}, + {schemapb.DataType_VarChar, schemapb.DataType_VarChar, true}, + {schemapb.DataType_VarChar, schemapb.DataType_JSON, true}, + {schemapb.DataType_VarChar, schemapb.DataType_Int64, false}, + {schemapb.DataType_Array, schemapb.DataType_Int64, false}, + {schemapb.DataType_Array, schemapb.DataType_Array, false}, + } + + for _, c := range cases { + assert.Equal(t, c.expected, canBeComparedDataType(c.left, c.right)) + } +} + +func Test_getArrayElementType(t *testing.T) { + t.Run("array element", func(t *testing.T) { + expr := &ExprWithType{ + expr: &planpb.Expr{ + Expr: &planpb.Expr_ValueExpr{ + ValueExpr: &planpb.ValueExpr{ + Value: &planpb.GenericValue{ + Val: &planpb.GenericValue_ArrayVal{ + ArrayVal: &planpb.Array{ + Array: nil, + ElementType: schemapb.DataType_Int64, + }, + }, + }, + }, + }, + }, + dataType: schemapb.DataType_Array, + nodeDependent: true, + } + + assert.Equal(t, schemapb.DataType_Int64, getArrayElementType(expr)) + }) + + t.Run("array field", func(t *testing.T) { + expr := &ExprWithType{ + expr: &planpb.Expr{ + Expr: &planpb.Expr_ColumnExpr{ + ColumnExpr: &planpb.ColumnExpr{ + Info: &planpb.ColumnInfo{ + FieldId: 101, + DataType: schemapb.DataType_Array, + IsPrimaryKey: false, + IsAutoID: false, + NestedPath: nil, + IsPartitionKey: false, + ElementType: schemapb.DataType_Int64, + }, + }, + }, + }, + dataType: schemapb.DataType_Array, + nodeDependent: true, + } + + assert.Equal(t, schemapb.DataType_Int64, getArrayElementType(expr)) + }) + + t.Run("not array", func(t *testing.T) { + expr := &ExprWithType{ + expr: &planpb.Expr{ + Expr: &planpb.Expr_ColumnExpr{ + ColumnExpr: &planpb.ColumnExpr{ + Info: 
&planpb.ColumnInfo{ + FieldId: 102, + DataType: schemapb.DataType_String, + }, + }, + }, + }, + dataType: schemapb.DataType_String, + nodeDependent: true, + } + + assert.Equal(t, schemapb.DataType_None, getArrayElementType(expr)) + }) +} diff --git a/internal/proto/data_coord.proto b/internal/proto/data_coord.proto index 8aa1e1658d1dc..d94e4b2470cef 100644 --- a/internal/proto/data_coord.proto +++ b/internal/proto/data_coord.proto @@ -21,6 +21,12 @@ enum SegmentType { Compacted = 3; } +enum SegmentLevel { + Legacy = 0; // zero value for legacy logic + L0 = 1; // L0 segment, contains delta data for current channel + L1 = 2; // L1 segment, normal segment, with no extra compaction attribute +} + service DataCoord { rpc GetComponentStates(milvus.GetComponentStatesRequest) returns (milvus.ComponentStates) {} rpc GetTimeTickChannel(internal.GetTimeTickChannelRequest) returns(milvus.StringResponse) {} @@ -54,7 +60,7 @@ service DataCoord { rpc GetCompactionStateWithPlans(milvus.GetCompactionPlansRequest) returns (milvus.GetCompactionPlansResponse) {} rpc WatchChannels(WatchChannelsRequest) returns (WatchChannelsResponse) {} - rpc GetFlushState(milvus.GetFlushStateRequest) returns (milvus.GetFlushStateResponse) {} + rpc GetFlushState(GetFlushStateRequest) returns (milvus.GetFlushStateResponse) {} rpc DropVirtualChannel(DropVirtualChannelRequest) returns (DropVirtualChannelResponse) {} rpc SetSegmentState(SetSegmentStateRequest) returns (SetSegmentStateResponse) {} @@ -83,7 +89,7 @@ service DataCoord { rpc GetIndexBuildProgress(index.GetIndexBuildProgressRequest) returns (index.GetIndexBuildProgressResponse) {} rpc GcConfirm(GcConfirmRequest) returns (GcConfirmResponse) {} - + rpc ReportDataNodeTtMsgs(ReportDataNodeTtMsgsRequest) returns (common.Status) {} } @@ -105,9 +111,14 @@ service DataNode { // https://wiki.lfaidata.foundation/display/MIL/MEP+24+--+Support+bulk+load rpc Import(ImportTaskRequest) returns(common.Status) {} + // Deprecated rpc 
ResendSegmentStats(ResendSegmentStatsRequest) returns(ResendSegmentStatsResponse) {} rpc AddImportSegment(AddImportSegmentRequest) returns(AddImportSegmentResponse) {} + + rpc FlushChannels(FlushChannelsRequest) returns(common.Status) {} + rpc NotifyChannelOperation(ChannelOperationsRequest) returns(common.Status) {} + rpc CheckChannelOperationProgress(ChannelWatchInfo) returns(ChannelOperationProgressResponse) {} } message FlushRequest { @@ -125,6 +136,13 @@ message FlushResponse { repeated int64 segmentIDs = 4; // newly sealed segments repeated int64 flushSegmentIDs = 5; // old flushed segment int64 timeOfSeal = 6; + uint64 flush_ts = 7; +} + +message FlushChannelsRequest { + common.MsgBase base = 1; + uint64 flush_ts = 2; + repeated string channels = 3; } message SegmentIDRequest { @@ -248,6 +266,7 @@ message FlushSegmentsRequest { int64 dbID = 2; int64 collectionID = 3; repeated int64 segmentIDs = 4; // segments to flush + string channelName = 5; // vchannel name to flush } message SegmentMsg{ @@ -284,7 +303,13 @@ message SegmentInfo { // For compatibility reasons, this flag of an old compacted segment may still be False. // As for new fields added in the message, they will be populated with their respective field types' default values. bool compacted = 19; - int64 storage_version = 20; + + // Segment level, indicating compaction segment level + // Available value: Legacy, L0, L1 + // For legacy level, it represent old segment before segment level introduced + // so segments with Legacy level shall be treated as L1 segment + SegmentLevel level = 20; + int64 storage_version = 21; } message SegmentStartPosition { @@ -431,8 +456,9 @@ message ChannelWatchInfo { int64 timeoutTs = 4; // the schema of the collection to watch, to avoid get schema rpc issues. 
schema.CollectionSchema schema = 5; - // watch progress + // watch progress, deprecated int32 progress = 6; + int64 opID = 7; } enum CompactionType { @@ -440,6 +466,11 @@ enum CompactionType { reserved 1; MergeCompaction = 2; MixCompaction = 3; + // compactionV2 + SingleCompaction = 4; + MinorCompaction = 5; + MajorCompaction = 6; + Level0DeleteCompaction = 7; } message CompactionStateRequest { @@ -505,7 +536,7 @@ message WatchChannelsRequest { int64 collectionID = 1; repeated string channelNames = 2; repeated common.KeyDataPair start_positions = 3; - schema.CollectionSchema schema = 4; + schema.CollectionSchema schema = 4; uint64 create_timestamp = 5; } @@ -679,3 +710,22 @@ message ReportDataNodeTtMsgsRequest { common.MsgBase base = 1; repeated msg.DataNodeTtMsg msgs = 2; // -1 means whole collection. } + +message GetFlushStateRequest { + repeated int64 segmentIDs = 1; + uint64 flush_ts = 2; + string db_name = 3; + string collection_name = 4; + int64 collectionID = 5; +} + +message ChannelOperationsRequest { + repeated ChannelWatchInfo infos = 1; +} + +message ChannelOperationProgressResponse { + common.Status status = 1; + int64 opID = 2; + ChannelWatchState state = 3; + int32 progress = 4; +} diff --git a/internal/proto/datapb/data_coord.pb.go b/internal/proto/datapb/data_coord.pb.go index 1ad0168428fd0..11452bea51cd8 100644 --- a/internal/proto/datapb/data_coord.pb.go +++ b/internal/proto/datapb/data_coord.pb.go @@ -61,6 +61,34 @@ func (SegmentType) EnumDescriptor() ([]byte, []int) { return fileDescriptor_82cd95f524594f49, []int{0} } +type SegmentLevel int32 + +const ( + SegmentLevel_Legacy SegmentLevel = 0 + SegmentLevel_L0 SegmentLevel = 1 + SegmentLevel_L1 SegmentLevel = 2 +) + +var SegmentLevel_name = map[int32]string{ + 0: "Legacy", + 1: "L0", + 2: "L1", +} + +var SegmentLevel_value = map[string]int32{ + "Legacy": 0, + "L0": 1, + "L1": 2, +} + +func (x SegmentLevel) String() string { + return proto.EnumName(SegmentLevel_name, int32(x)) +} + +func 
(SegmentLevel) EnumDescriptor() ([]byte, []int) { + return fileDescriptor_82cd95f524594f49, []int{1} +} + type ChannelWatchState int32 const ( @@ -101,7 +129,7 @@ func (x ChannelWatchState) String() string { } func (ChannelWatchState) EnumDescriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{1} + return fileDescriptor_82cd95f524594f49, []int{2} } type CompactionType int32 @@ -110,18 +138,31 @@ const ( CompactionType_UndefinedCompaction CompactionType = 0 CompactionType_MergeCompaction CompactionType = 2 CompactionType_MixCompaction CompactionType = 3 + // compactionV2 + CompactionType_SingleCompaction CompactionType = 4 + CompactionType_MinorCompaction CompactionType = 5 + CompactionType_MajorCompaction CompactionType = 6 + CompactionType_Level0DeleteCompaction CompactionType = 7 ) var CompactionType_name = map[int32]string{ 0: "UndefinedCompaction", 2: "MergeCompaction", 3: "MixCompaction", + 4: "SingleCompaction", + 5: "MinorCompaction", + 6: "MajorCompaction", + 7: "Level0DeleteCompaction", } var CompactionType_value = map[string]int32{ - "UndefinedCompaction": 0, - "MergeCompaction": 2, - "MixCompaction": 3, + "UndefinedCompaction": 0, + "MergeCompaction": 2, + "MixCompaction": 3, + "SingleCompaction": 4, + "MinorCompaction": 5, + "MajorCompaction": 6, + "Level0DeleteCompaction": 7, } func (x CompactionType) String() string { @@ -129,7 +170,7 @@ func (x CompactionType) String() string { } func (CompactionType) EnumDescriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{2} + return fileDescriptor_82cd95f524594f49, []int{3} } // TODO: import google/protobuf/empty.proto @@ -242,6 +283,7 @@ type FlushResponse struct { SegmentIDs []int64 `protobuf:"varint,4,rep,packed,name=segmentIDs,proto3" json:"segmentIDs,omitempty"` FlushSegmentIDs []int64 `protobuf:"varint,5,rep,packed,name=flushSegmentIDs,proto3" json:"flushSegmentIDs,omitempty"` TimeOfSeal int64 `protobuf:"varint,6,opt,name=timeOfSeal,proto3" 
json:"timeOfSeal,omitempty"` + FlushTs uint64 `protobuf:"varint,7,opt,name=flush_ts,json=flushTs,proto3" json:"flush_ts,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` @@ -314,6 +356,68 @@ func (m *FlushResponse) GetTimeOfSeal() int64 { return 0 } +func (m *FlushResponse) GetFlushTs() uint64 { + if m != nil { + return m.FlushTs + } + return 0 +} + +type FlushChannelsRequest struct { + Base *commonpb.MsgBase `protobuf:"bytes,1,opt,name=base,proto3" json:"base,omitempty"` + FlushTs uint64 `protobuf:"varint,2,opt,name=flush_ts,json=flushTs,proto3" json:"flush_ts,omitempty"` + Channels []string `protobuf:"bytes,3,rep,name=channels,proto3" json:"channels,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *FlushChannelsRequest) Reset() { *m = FlushChannelsRequest{} } +func (m *FlushChannelsRequest) String() string { return proto.CompactTextString(m) } +func (*FlushChannelsRequest) ProtoMessage() {} +func (*FlushChannelsRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_82cd95f524594f49, []int{3} +} + +func (m *FlushChannelsRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_FlushChannelsRequest.Unmarshal(m, b) +} +func (m *FlushChannelsRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_FlushChannelsRequest.Marshal(b, m, deterministic) +} +func (m *FlushChannelsRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_FlushChannelsRequest.Merge(m, src) +} +func (m *FlushChannelsRequest) XXX_Size() int { + return xxx_messageInfo_FlushChannelsRequest.Size(m) +} +func (m *FlushChannelsRequest) XXX_DiscardUnknown() { + xxx_messageInfo_FlushChannelsRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_FlushChannelsRequest proto.InternalMessageInfo + +func (m *FlushChannelsRequest) GetBase() *commonpb.MsgBase { + if m != nil { + return m.Base + } + return 
nil +} + +func (m *FlushChannelsRequest) GetFlushTs() uint64 { + if m != nil { + return m.FlushTs + } + return 0 +} + +func (m *FlushChannelsRequest) GetChannels() []string { + if m != nil { + return m.Channels + } + return nil +} + type SegmentIDRequest struct { Count uint32 `protobuf:"varint,1,opt,name=count,proto3" json:"count,omitempty"` ChannelName string `protobuf:"bytes,2,opt,name=channel_name,json=channelName,proto3" json:"channel_name,omitempty"` @@ -330,7 +434,7 @@ func (m *SegmentIDRequest) Reset() { *m = SegmentIDRequest{} } func (m *SegmentIDRequest) String() string { return proto.CompactTextString(m) } func (*SegmentIDRequest) ProtoMessage() {} func (*SegmentIDRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{3} + return fileDescriptor_82cd95f524594f49, []int{4} } func (m *SegmentIDRequest) XXX_Unmarshal(b []byte) error { @@ -406,7 +510,7 @@ func (m *AssignSegmentIDRequest) Reset() { *m = AssignSegmentIDRequest{} func (m *AssignSegmentIDRequest) String() string { return proto.CompactTextString(m) } func (*AssignSegmentIDRequest) ProtoMessage() {} func (*AssignSegmentIDRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{4} + return fileDescriptor_82cd95f524594f49, []int{5} } func (m *AssignSegmentIDRequest) XXX_Unmarshal(b []byte) error { @@ -465,7 +569,7 @@ func (m *SegmentIDAssignment) Reset() { *m = SegmentIDAssignment{} } func (m *SegmentIDAssignment) String() string { return proto.CompactTextString(m) } func (*SegmentIDAssignment) ProtoMessage() {} func (*SegmentIDAssignment) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{5} + return fileDescriptor_82cd95f524594f49, []int{6} } func (m *SegmentIDAssignment) XXX_Unmarshal(b []byte) error { @@ -547,7 +651,7 @@ func (m *AssignSegmentIDResponse) Reset() { *m = AssignSegmentIDResponse func (m *AssignSegmentIDResponse) String() string { return proto.CompactTextString(m) } func 
(*AssignSegmentIDResponse) ProtoMessage() {} func (*AssignSegmentIDResponse) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{6} + return fileDescriptor_82cd95f524594f49, []int{7} } func (m *AssignSegmentIDResponse) XXX_Unmarshal(b []byte) error { @@ -594,7 +698,7 @@ func (m *GetSegmentStatesRequest) Reset() { *m = GetSegmentStatesRequest func (m *GetSegmentStatesRequest) String() string { return proto.CompactTextString(m) } func (*GetSegmentStatesRequest) ProtoMessage() {} func (*GetSegmentStatesRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{7} + return fileDescriptor_82cd95f524594f49, []int{8} } func (m *GetSegmentStatesRequest) XXX_Unmarshal(b []byte) error { @@ -644,7 +748,7 @@ func (m *SegmentStateInfo) Reset() { *m = SegmentStateInfo{} } func (m *SegmentStateInfo) String() string { return proto.CompactTextString(m) } func (*SegmentStateInfo) ProtoMessage() {} func (*SegmentStateInfo) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{8} + return fileDescriptor_82cd95f524594f49, []int{9} } func (m *SegmentStateInfo) XXX_Unmarshal(b []byte) error { @@ -712,7 +816,7 @@ func (m *GetSegmentStatesResponse) Reset() { *m = GetSegmentStatesRespon func (m *GetSegmentStatesResponse) String() string { return proto.CompactTextString(m) } func (*GetSegmentStatesResponse) ProtoMessage() {} func (*GetSegmentStatesResponse) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{9} + return fileDescriptor_82cd95f524594f49, []int{10} } func (m *GetSegmentStatesResponse) XXX_Unmarshal(b []byte) error { @@ -760,7 +864,7 @@ func (m *GetSegmentInfoRequest) Reset() { *m = GetSegmentInfoRequest{} } func (m *GetSegmentInfoRequest) String() string { return proto.CompactTextString(m) } func (*GetSegmentInfoRequest) ProtoMessage() {} func (*GetSegmentInfoRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{10} + return 
fileDescriptor_82cd95f524594f49, []int{11} } func (m *GetSegmentInfoRequest) XXX_Unmarshal(b []byte) error { @@ -815,7 +919,7 @@ func (m *GetSegmentInfoResponse) Reset() { *m = GetSegmentInfoResponse{} func (m *GetSegmentInfoResponse) String() string { return proto.CompactTextString(m) } func (*GetSegmentInfoResponse) ProtoMessage() {} func (*GetSegmentInfoResponse) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{11} + return fileDescriptor_82cd95f524594f49, []int{12} } func (m *GetSegmentInfoResponse) XXX_Unmarshal(b []byte) error { @@ -869,7 +973,7 @@ func (m *GetInsertBinlogPathsRequest) Reset() { *m = GetInsertBinlogPath func (m *GetInsertBinlogPathsRequest) String() string { return proto.CompactTextString(m) } func (*GetInsertBinlogPathsRequest) ProtoMessage() {} func (*GetInsertBinlogPathsRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{12} + return fileDescriptor_82cd95f524594f49, []int{13} } func (m *GetInsertBinlogPathsRequest) XXX_Unmarshal(b []byte) error { @@ -917,7 +1021,7 @@ func (m *GetInsertBinlogPathsResponse) Reset() { *m = GetInsertBinlogPat func (m *GetInsertBinlogPathsResponse) String() string { return proto.CompactTextString(m) } func (*GetInsertBinlogPathsResponse) ProtoMessage() {} func (*GetInsertBinlogPathsResponse) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{13} + return fileDescriptor_82cd95f524594f49, []int{14} } func (m *GetInsertBinlogPathsResponse) XXX_Unmarshal(b []byte) error { @@ -972,7 +1076,7 @@ func (m *GetCollectionStatisticsRequest) Reset() { *m = GetCollectionSta func (m *GetCollectionStatisticsRequest) String() string { return proto.CompactTextString(m) } func (*GetCollectionStatisticsRequest) ProtoMessage() {} func (*GetCollectionStatisticsRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{14} + return fileDescriptor_82cd95f524594f49, []int{15} } func (m 
*GetCollectionStatisticsRequest) XXX_Unmarshal(b []byte) error { @@ -1026,7 +1130,7 @@ func (m *GetCollectionStatisticsResponse) Reset() { *m = GetCollectionSt func (m *GetCollectionStatisticsResponse) String() string { return proto.CompactTextString(m) } func (*GetCollectionStatisticsResponse) ProtoMessage() {} func (*GetCollectionStatisticsResponse) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{15} + return fileDescriptor_82cd95f524594f49, []int{16} } func (m *GetCollectionStatisticsResponse) XXX_Unmarshal(b []byte) error { @@ -1075,7 +1179,7 @@ func (m *GetPartitionStatisticsRequest) Reset() { *m = GetPartitionStati func (m *GetPartitionStatisticsRequest) String() string { return proto.CompactTextString(m) } func (*GetPartitionStatisticsRequest) ProtoMessage() {} func (*GetPartitionStatisticsRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{16} + return fileDescriptor_82cd95f524594f49, []int{17} } func (m *GetPartitionStatisticsRequest) XXX_Unmarshal(b []byte) error { @@ -1136,7 +1240,7 @@ func (m *GetPartitionStatisticsResponse) Reset() { *m = GetPartitionStat func (m *GetPartitionStatisticsResponse) String() string { return proto.CompactTextString(m) } func (*GetPartitionStatisticsResponse) ProtoMessage() {} func (*GetPartitionStatisticsResponse) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{17} + return fileDescriptor_82cd95f524594f49, []int{18} } func (m *GetPartitionStatisticsResponse) XXX_Unmarshal(b []byte) error { @@ -1181,7 +1285,7 @@ func (m *GetSegmentInfoChannelRequest) Reset() { *m = GetSegmentInfoChan func (m *GetSegmentInfoChannelRequest) String() string { return proto.CompactTextString(m) } func (*GetSegmentInfoChannelRequest) ProtoMessage() {} func (*GetSegmentInfoChannelRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{18} + return fileDescriptor_82cd95f524594f49, []int{19} } func (m 
*GetSegmentInfoChannelRequest) XXX_Unmarshal(b []byte) error { @@ -1223,7 +1327,7 @@ func (m *VchannelInfo) Reset() { *m = VchannelInfo{} } func (m *VchannelInfo) String() string { return proto.CompactTextString(m) } func (*VchannelInfo) ProtoMessage() {} func (*VchannelInfo) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{19} + return fileDescriptor_82cd95f524594f49, []int{20} } func (m *VchannelInfo) XXX_Unmarshal(b []byte) error { @@ -1333,7 +1437,7 @@ func (m *WatchDmChannelsRequest) Reset() { *m = WatchDmChannelsRequest{} func (m *WatchDmChannelsRequest) String() string { return proto.CompactTextString(m) } func (*WatchDmChannelsRequest) ProtoMessage() {} func (*WatchDmChannelsRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{20} + return fileDescriptor_82cd95f524594f49, []int{21} } func (m *WatchDmChannelsRequest) XXX_Unmarshal(b []byte) error { @@ -1373,6 +1477,7 @@ type FlushSegmentsRequest struct { DbID int64 `protobuf:"varint,2,opt,name=dbID,proto3" json:"dbID,omitempty"` CollectionID int64 `protobuf:"varint,3,opt,name=collectionID,proto3" json:"collectionID,omitempty"` SegmentIDs []int64 `protobuf:"varint,4,rep,packed,name=segmentIDs,proto3" json:"segmentIDs,omitempty"` + ChannelName string `protobuf:"bytes,5,opt,name=channelName,proto3" json:"channelName,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` @@ -1382,7 +1487,7 @@ func (m *FlushSegmentsRequest) Reset() { *m = FlushSegmentsRequest{} } func (m *FlushSegmentsRequest) String() string { return proto.CompactTextString(m) } func (*FlushSegmentsRequest) ProtoMessage() {} func (*FlushSegmentsRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{21} + return fileDescriptor_82cd95f524594f49, []int{22} } func (m *FlushSegmentsRequest) XXX_Unmarshal(b []byte) error { @@ -1431,6 +1536,13 @@ func (m *FlushSegmentsRequest) GetSegmentIDs() 
[]int64 { return nil } +func (m *FlushSegmentsRequest) GetChannelName() string { + if m != nil { + return m.ChannelName + } + return "" +} + type SegmentMsg struct { Base *commonpb.MsgBase `protobuf:"bytes,1,opt,name=base,proto3" json:"base,omitempty"` Segment *SegmentInfo `protobuf:"bytes,2,opt,name=segment,proto3" json:"segment,omitempty"` @@ -1443,7 +1555,7 @@ func (m *SegmentMsg) Reset() { *m = SegmentMsg{} } func (m *SegmentMsg) String() string { return proto.CompactTextString(m) } func (*SegmentMsg) ProtoMessage() {} func (*SegmentMsg) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{22} + return fileDescriptor_82cd95f524594f49, []int{23} } func (m *SegmentMsg) XXX_Unmarshal(b []byte) error { @@ -1505,18 +1617,23 @@ type SegmentInfo struct { // denote if this segment is compacted to other segment. // For compatibility reasons, this flag of an old compacted segment may still be False. // As for new fields added in the message, they will be populated with their respective field types' default values. 
- Compacted bool `protobuf:"varint,19,opt,name=compacted,proto3" json:"compacted,omitempty"` - StorageVersion int64 `protobuf:"varint,20,opt,name=storage_version,json=storageVersion,proto3" json:"storage_version,omitempty"` - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` + Compacted bool `protobuf:"varint,19,opt,name=compacted,proto3" json:"compacted,omitempty"` + // Segment level, indicating compaction segment level + // Available value: Legacy, L0, L1 + // For legacy level, it represent old segment before segment level introduced + // so segments with Legacy level shall be treated as L1 segment + Level SegmentLevel `protobuf:"varint,20,opt,name=level,proto3,enum=milvus.proto.data.SegmentLevel" json:"level,omitempty"` + StorageVersion int64 `protobuf:"varint,21,opt,name=storage_version,json=storageVersion,proto3" json:"storage_version,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` } func (m *SegmentInfo) Reset() { *m = SegmentInfo{} } func (m *SegmentInfo) String() string { return proto.CompactTextString(m) } func (*SegmentInfo) ProtoMessage() {} func (*SegmentInfo) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{23} + return fileDescriptor_82cd95f524594f49, []int{24} } func (m *SegmentInfo) XXX_Unmarshal(b []byte) error { @@ -1670,6 +1787,13 @@ func (m *SegmentInfo) GetCompacted() bool { return false } +func (m *SegmentInfo) GetLevel() SegmentLevel { + if m != nil { + return m.Level + } + return SegmentLevel_Legacy +} + func (m *SegmentInfo) GetStorageVersion() int64 { if m != nil { return m.StorageVersion @@ -1689,7 +1813,7 @@ func (m *SegmentStartPosition) Reset() { *m = SegmentStartPosition{} } func (m *SegmentStartPosition) String() string { return proto.CompactTextString(m) } func (*SegmentStartPosition) ProtoMessage() {} func (*SegmentStartPosition) Descriptor() ([]byte, []int) { - 
return fileDescriptor_82cd95f524594f49, []int{24} + return fileDescriptor_82cd95f524594f49, []int{25} } func (m *SegmentStartPosition) XXX_Unmarshal(b []byte) error { @@ -1747,7 +1871,7 @@ func (m *SaveBinlogPathsRequest) Reset() { *m = SaveBinlogPathsRequest{} func (m *SaveBinlogPathsRequest) String() string { return proto.CompactTextString(m) } func (*SaveBinlogPathsRequest) ProtoMessage() {} func (*SaveBinlogPathsRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{25} + return fileDescriptor_82cd95f524594f49, []int{26} } func (m *SaveBinlogPathsRequest) XXX_Unmarshal(b []byte) error { @@ -1872,7 +1996,7 @@ func (m *CheckPoint) Reset() { *m = CheckPoint{} } func (m *CheckPoint) String() string { return proto.CompactTextString(m) } func (*CheckPoint) ProtoMessage() {} func (*CheckPoint) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{26} + return fileDescriptor_82cd95f524594f49, []int{27} } func (m *CheckPoint) XXX_Unmarshal(b []byte) error { @@ -1929,7 +2053,7 @@ func (m *DeltaLogInfo) Reset() { *m = DeltaLogInfo{} } func (m *DeltaLogInfo) String() string { return proto.CompactTextString(m) } func (*DeltaLogInfo) ProtoMessage() {} func (*DeltaLogInfo) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{27} + return fileDescriptor_82cd95f524594f49, []int{28} } func (m *DeltaLogInfo) XXX_Unmarshal(b []byte) error { @@ -1998,7 +2122,7 @@ func (m *ChannelStatus) Reset() { *m = ChannelStatus{} } func (m *ChannelStatus) String() string { return proto.CompactTextString(m) } func (*ChannelStatus) ProtoMessage() {} func (*ChannelStatus) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{28} + return fileDescriptor_82cd95f524594f49, []int{29} } func (m *ChannelStatus) XXX_Unmarshal(b []byte) error { @@ -2053,7 +2177,7 @@ func (m *DataNodeInfo) Reset() { *m = DataNodeInfo{} } func (m *DataNodeInfo) String() string { return proto.CompactTextString(m) } 
func (*DataNodeInfo) ProtoMessage() {} func (*DataNodeInfo) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{29} + return fileDescriptor_82cd95f524594f49, []int{30} } func (m *DataNodeInfo) XXX_Unmarshal(b []byte) error { @@ -2111,7 +2235,7 @@ func (m *SegmentBinlogs) Reset() { *m = SegmentBinlogs{} } func (m *SegmentBinlogs) String() string { return proto.CompactTextString(m) } func (*SegmentBinlogs) ProtoMessage() {} func (*SegmentBinlogs) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{30} + return fileDescriptor_82cd95f524594f49, []int{31} } func (m *SegmentBinlogs) XXX_Unmarshal(b []byte) error { @@ -2186,7 +2310,7 @@ func (m *FieldBinlog) Reset() { *m = FieldBinlog{} } func (m *FieldBinlog) String() string { return proto.CompactTextString(m) } func (*FieldBinlog) ProtoMessage() {} func (*FieldBinlog) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{31} + return fileDescriptor_82cd95f524594f49, []int{32} } func (m *FieldBinlog) XXX_Unmarshal(b []byte) error { @@ -2238,7 +2362,7 @@ func (m *Binlog) Reset() { *m = Binlog{} } func (m *Binlog) String() string { return proto.CompactTextString(m) } func (*Binlog) ProtoMessage() {} func (*Binlog) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{32} + return fileDescriptor_82cd95f524594f49, []int{33} } func (m *Binlog) XXX_Unmarshal(b []byte) error { @@ -2314,7 +2438,7 @@ func (m *GetRecoveryInfoResponse) Reset() { *m = GetRecoveryInfoResponse func (m *GetRecoveryInfoResponse) String() string { return proto.CompactTextString(m) } func (*GetRecoveryInfoResponse) ProtoMessage() {} func (*GetRecoveryInfoResponse) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{33} + return fileDescriptor_82cd95f524594f49, []int{34} } func (m *GetRecoveryInfoResponse) XXX_Unmarshal(b []byte) error { @@ -2369,7 +2493,7 @@ func (m *GetRecoveryInfoRequest) Reset() { *m = 
GetRecoveryInfoRequest{} func (m *GetRecoveryInfoRequest) String() string { return proto.CompactTextString(m) } func (*GetRecoveryInfoRequest) ProtoMessage() {} func (*GetRecoveryInfoRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{34} + return fileDescriptor_82cd95f524594f49, []int{35} } func (m *GetRecoveryInfoRequest) XXX_Unmarshal(b []byte) error { @@ -2424,7 +2548,7 @@ func (m *GetRecoveryInfoResponseV2) Reset() { *m = GetRecoveryInfoRespon func (m *GetRecoveryInfoResponseV2) String() string { return proto.CompactTextString(m) } func (*GetRecoveryInfoResponseV2) ProtoMessage() {} func (*GetRecoveryInfoResponseV2) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{35} + return fileDescriptor_82cd95f524594f49, []int{36} } func (m *GetRecoveryInfoResponseV2) XXX_Unmarshal(b []byte) error { @@ -2479,7 +2603,7 @@ func (m *GetRecoveryInfoRequestV2) Reset() { *m = GetRecoveryInfoRequest func (m *GetRecoveryInfoRequestV2) String() string { return proto.CompactTextString(m) } func (*GetRecoveryInfoRequestV2) ProtoMessage() {} func (*GetRecoveryInfoRequestV2) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{36} + return fileDescriptor_82cd95f524594f49, []int{37} } func (m *GetRecoveryInfoRequestV2) XXX_Unmarshal(b []byte) error { @@ -2535,7 +2659,7 @@ func (m *GetSegmentsByStatesRequest) Reset() { *m = GetSegmentsByStatesR func (m *GetSegmentsByStatesRequest) String() string { return proto.CompactTextString(m) } func (*GetSegmentsByStatesRequest) ProtoMessage() {} func (*GetSegmentsByStatesRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{37} + return fileDescriptor_82cd95f524594f49, []int{38} } func (m *GetSegmentsByStatesRequest) XXX_Unmarshal(b []byte) error { @@ -2596,7 +2720,7 @@ func (m *GetSegmentsByStatesResponse) Reset() { *m = GetSegmentsByStates func (m *GetSegmentsByStatesResponse) String() string { return 
proto.CompactTextString(m) } func (*GetSegmentsByStatesResponse) ProtoMessage() {} func (*GetSegmentsByStatesResponse) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{38} + return fileDescriptor_82cd95f524594f49, []int{39} } func (m *GetSegmentsByStatesResponse) XXX_Unmarshal(b []byte) error { @@ -2645,7 +2769,7 @@ func (m *GetFlushedSegmentsRequest) Reset() { *m = GetFlushedSegmentsReq func (m *GetFlushedSegmentsRequest) String() string { return proto.CompactTextString(m) } func (*GetFlushedSegmentsRequest) ProtoMessage() {} func (*GetFlushedSegmentsRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{39} + return fileDescriptor_82cd95f524594f49, []int{40} } func (m *GetFlushedSegmentsRequest) XXX_Unmarshal(b []byte) error { @@ -2706,7 +2830,7 @@ func (m *GetFlushedSegmentsResponse) Reset() { *m = GetFlushedSegmentsRe func (m *GetFlushedSegmentsResponse) String() string { return proto.CompactTextString(m) } func (*GetFlushedSegmentsResponse) ProtoMessage() {} func (*GetFlushedSegmentsResponse) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{40} + return fileDescriptor_82cd95f524594f49, []int{41} } func (m *GetFlushedSegmentsResponse) XXX_Unmarshal(b []byte) error { @@ -2753,7 +2877,7 @@ func (m *SegmentFlushCompletedMsg) Reset() { *m = SegmentFlushCompletedM func (m *SegmentFlushCompletedMsg) String() string { return proto.CompactTextString(m) } func (*SegmentFlushCompletedMsg) ProtoMessage() {} func (*SegmentFlushCompletedMsg) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{41} + return fileDescriptor_82cd95f524594f49, []int{42} } func (m *SegmentFlushCompletedMsg) XXX_Unmarshal(b []byte) error { @@ -2797,8 +2921,9 @@ type ChannelWatchInfo struct { TimeoutTs int64 `protobuf:"varint,4,opt,name=timeoutTs,proto3" json:"timeoutTs,omitempty"` // the schema of the collection to watch, to avoid get schema rpc issues. 
Schema *schemapb.CollectionSchema `protobuf:"bytes,5,opt,name=schema,proto3" json:"schema,omitempty"` - // watch progress + // watch progress, deprecated Progress int32 `protobuf:"varint,6,opt,name=progress,proto3" json:"progress,omitempty"` + OpID int64 `protobuf:"varint,7,opt,name=opID,proto3" json:"opID,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` @@ -2808,7 +2933,7 @@ func (m *ChannelWatchInfo) Reset() { *m = ChannelWatchInfo{} } func (m *ChannelWatchInfo) String() string { return proto.CompactTextString(m) } func (*ChannelWatchInfo) ProtoMessage() {} func (*ChannelWatchInfo) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{42} + return fileDescriptor_82cd95f524594f49, []int{43} } func (m *ChannelWatchInfo) XXX_Unmarshal(b []byte) error { @@ -2871,6 +2996,13 @@ func (m *ChannelWatchInfo) GetProgress() int32 { return 0 } +func (m *ChannelWatchInfo) GetOpID() int64 { + if m != nil { + return m.OpID + } + return 0 +} + type CompactionStateRequest struct { Base *commonpb.MsgBase `protobuf:"bytes,1,opt,name=base,proto3" json:"base,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` @@ -2882,7 +3014,7 @@ func (m *CompactionStateRequest) Reset() { *m = CompactionStateRequest{} func (m *CompactionStateRequest) String() string { return proto.CompactTextString(m) } func (*CompactionStateRequest) ProtoMessage() {} func (*CompactionStateRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{43} + return fileDescriptor_82cd95f524594f49, []int{44} } func (m *CompactionStateRequest) XXX_Unmarshal(b []byte) error { @@ -2925,7 +3057,7 @@ func (m *SyncSegmentsRequest) Reset() { *m = SyncSegmentsRequest{} } func (m *SyncSegmentsRequest) String() string { return proto.CompactTextString(m) } func (*SyncSegmentsRequest) ProtoMessage() {} func (*SyncSegmentsRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{44} 
+ return fileDescriptor_82cd95f524594f49, []int{45} } func (m *SyncSegmentsRequest) XXX_Unmarshal(b []byte) error { @@ -2996,7 +3128,7 @@ func (m *CompactionSegmentBinlogs) Reset() { *m = CompactionSegmentBinlo func (m *CompactionSegmentBinlogs) String() string { return proto.CompactTextString(m) } func (*CompactionSegmentBinlogs) ProtoMessage() {} func (*CompactionSegmentBinlogs) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{45} + return fileDescriptor_82cd95f524594f49, []int{46} } func (m *CompactionSegmentBinlogs) XXX_Unmarshal(b []byte) error { @@ -3071,7 +3203,7 @@ func (m *CompactionPlan) Reset() { *m = CompactionPlan{} } func (m *CompactionPlan) String() string { return proto.CompactTextString(m) } func (*CompactionPlan) ProtoMessage() {} func (*CompactionPlan) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{46} + return fileDescriptor_82cd95f524594f49, []int{47} } func (m *CompactionPlan) XXX_Unmarshal(b []byte) error { @@ -3172,7 +3304,7 @@ func (m *CompactionResult) Reset() { *m = CompactionResult{} } func (m *CompactionResult) String() string { return proto.CompactTextString(m) } func (*CompactionResult) ProtoMessage() {} func (*CompactionResult) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{47} + return fileDescriptor_82cd95f524594f49, []int{48} } func (m *CompactionResult) XXX_Unmarshal(b []byte) error { @@ -3255,7 +3387,7 @@ func (m *CompactionStateResult) Reset() { *m = CompactionStateResult{} } func (m *CompactionStateResult) String() string { return proto.CompactTextString(m) } func (*CompactionStateResult) ProtoMessage() {} func (*CompactionStateResult) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{48} + return fileDescriptor_82cd95f524594f49, []int{49} } func (m *CompactionStateResult) XXX_Unmarshal(b []byte) error { @@ -3309,7 +3441,7 @@ func (m *CompactionStateResponse) Reset() { *m = CompactionStateResponse func (m 
*CompactionStateResponse) String() string { return proto.CompactTextString(m) } func (*CompactionStateResponse) ProtoMessage() {} func (*CompactionStateResponse) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{49} + return fileDescriptor_82cd95f524594f49, []int{50} } func (m *CompactionStateResponse) XXX_Unmarshal(b []byte) error { @@ -3357,7 +3489,7 @@ func (m *SegmentFieldBinlogMeta) Reset() { *m = SegmentFieldBinlogMeta{} func (m *SegmentFieldBinlogMeta) String() string { return proto.CompactTextString(m) } func (*SegmentFieldBinlogMeta) ProtoMessage() {} func (*SegmentFieldBinlogMeta) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{50} + return fileDescriptor_82cd95f524594f49, []int{51} } func (m *SegmentFieldBinlogMeta) XXX_Unmarshal(b []byte) error { @@ -3407,7 +3539,7 @@ func (m *WatchChannelsRequest) Reset() { *m = WatchChannelsRequest{} } func (m *WatchChannelsRequest) String() string { return proto.CompactTextString(m) } func (*WatchChannelsRequest) ProtoMessage() {} func (*WatchChannelsRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{51} + return fileDescriptor_82cd95f524594f49, []int{52} } func (m *WatchChannelsRequest) XXX_Unmarshal(b []byte) error { @@ -3474,7 +3606,7 @@ func (m *WatchChannelsResponse) Reset() { *m = WatchChannelsResponse{} } func (m *WatchChannelsResponse) String() string { return proto.CompactTextString(m) } func (*WatchChannelsResponse) ProtoMessage() {} func (*WatchChannelsResponse) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{52} + return fileDescriptor_82cd95f524594f49, []int{53} } func (m *WatchChannelsResponse) XXX_Unmarshal(b []byte) error { @@ -3515,7 +3647,7 @@ func (m *SetSegmentStateRequest) Reset() { *m = SetSegmentStateRequest{} func (m *SetSegmentStateRequest) String() string { return proto.CompactTextString(m) } func (*SetSegmentStateRequest) ProtoMessage() {} func 
(*SetSegmentStateRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{53} + return fileDescriptor_82cd95f524594f49, []int{54} } func (m *SetSegmentStateRequest) XXX_Unmarshal(b []byte) error { @@ -3568,7 +3700,7 @@ func (m *SetSegmentStateResponse) Reset() { *m = SetSegmentStateResponse func (m *SetSegmentStateResponse) String() string { return proto.CompactTextString(m) } func (*SetSegmentStateResponse) ProtoMessage() {} func (*SetSegmentStateResponse) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{54} + return fileDescriptor_82cd95f524594f49, []int{55} } func (m *SetSegmentStateResponse) XXX_Unmarshal(b []byte) error { @@ -3609,7 +3741,7 @@ func (m *DropVirtualChannelRequest) Reset() { *m = DropVirtualChannelReq func (m *DropVirtualChannelRequest) String() string { return proto.CompactTextString(m) } func (*DropVirtualChannelRequest) ProtoMessage() {} func (*DropVirtualChannelRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{55} + return fileDescriptor_82cd95f524594f49, []int{56} } func (m *DropVirtualChannelRequest) XXX_Unmarshal(b []byte) error { @@ -3669,7 +3801,7 @@ func (m *DropVirtualChannelSegment) Reset() { *m = DropVirtualChannelSeg func (m *DropVirtualChannelSegment) String() string { return proto.CompactTextString(m) } func (*DropVirtualChannelSegment) ProtoMessage() {} func (*DropVirtualChannelSegment) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{56} + return fileDescriptor_82cd95f524594f49, []int{57} } func (m *DropVirtualChannelSegment) XXX_Unmarshal(b []byte) error { @@ -3757,7 +3889,7 @@ func (m *DropVirtualChannelResponse) Reset() { *m = DropVirtualChannelRe func (m *DropVirtualChannelResponse) String() string { return proto.CompactTextString(m) } func (*DropVirtualChannelResponse) ProtoMessage() {} func (*DropVirtualChannelResponse) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, 
[]int{57} + return fileDescriptor_82cd95f524594f49, []int{58} } func (m *DropVirtualChannelResponse) XXX_Unmarshal(b []byte) error { @@ -3804,7 +3936,7 @@ func (m *ImportTask) Reset() { *m = ImportTask{} } func (m *ImportTask) String() string { return proto.CompactTextString(m) } func (*ImportTask) ProtoMessage() {} func (*ImportTask) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{58} + return fileDescriptor_82cd95f524594f49, []int{59} } func (m *ImportTask) XXX_Unmarshal(b []byte) error { @@ -3903,7 +4035,7 @@ func (m *ImportTaskState) Reset() { *m = ImportTaskState{} } func (m *ImportTaskState) String() string { return proto.CompactTextString(m) } func (*ImportTaskState) ProtoMessage() {} func (*ImportTaskState) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{59} + return fileDescriptor_82cd95f524594f49, []int{60} } func (m *ImportTaskState) XXX_Unmarshal(b []byte) error { @@ -3985,7 +4117,7 @@ func (m *ImportTaskInfo) Reset() { *m = ImportTaskInfo{} } func (m *ImportTaskInfo) String() string { return proto.CompactTextString(m) } func (*ImportTaskInfo) ProtoMessage() {} func (*ImportTaskInfo) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{60} + return fileDescriptor_82cd95f524594f49, []int{61} } func (m *ImportTaskInfo) XXX_Unmarshal(b []byte) error { @@ -4131,7 +4263,7 @@ func (m *ImportTaskResponse) Reset() { *m = ImportTaskResponse{} } func (m *ImportTaskResponse) String() string { return proto.CompactTextString(m) } func (*ImportTaskResponse) ProtoMessage() {} func (*ImportTaskResponse) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{61} + return fileDescriptor_82cd95f524594f49, []int{62} } func (m *ImportTaskResponse) XXX_Unmarshal(b []byte) error { @@ -4179,7 +4311,7 @@ func (m *ImportTaskRequest) Reset() { *m = ImportTaskRequest{} } func (m *ImportTaskRequest) String() string { return proto.CompactTextString(m) } func 
(*ImportTaskRequest) ProtoMessage() {} func (*ImportTaskRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{62} + return fileDescriptor_82cd95f524594f49, []int{63} } func (m *ImportTaskRequest) XXX_Unmarshal(b []byte) error { @@ -4233,7 +4365,7 @@ func (m *UpdateSegmentStatisticsRequest) Reset() { *m = UpdateSegmentSta func (m *UpdateSegmentStatisticsRequest) String() string { return proto.CompactTextString(m) } func (*UpdateSegmentStatisticsRequest) ProtoMessage() {} func (*UpdateSegmentStatisticsRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{63} + return fileDescriptor_82cd95f524594f49, []int{64} } func (m *UpdateSegmentStatisticsRequest) XXX_Unmarshal(b []byte) error { @@ -4281,7 +4413,7 @@ func (m *UpdateChannelCheckpointRequest) Reset() { *m = UpdateChannelChe func (m *UpdateChannelCheckpointRequest) String() string { return proto.CompactTextString(m) } func (*UpdateChannelCheckpointRequest) ProtoMessage() {} func (*UpdateChannelCheckpointRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{64} + return fileDescriptor_82cd95f524594f49, []int{65} } func (m *UpdateChannelCheckpointRequest) XXX_Unmarshal(b []byte) error { @@ -4334,7 +4466,7 @@ func (m *ResendSegmentStatsRequest) Reset() { *m = ResendSegmentStatsReq func (m *ResendSegmentStatsRequest) String() string { return proto.CompactTextString(m) } func (*ResendSegmentStatsRequest) ProtoMessage() {} func (*ResendSegmentStatsRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{65} + return fileDescriptor_82cd95f524594f49, []int{66} } func (m *ResendSegmentStatsRequest) XXX_Unmarshal(b []byte) error { @@ -4374,7 +4506,7 @@ func (m *ResendSegmentStatsResponse) Reset() { *m = ResendSegmentStatsRe func (m *ResendSegmentStatsResponse) String() string { return proto.CompactTextString(m) } func (*ResendSegmentStatsResponse) ProtoMessage() {} func 
(*ResendSegmentStatsResponse) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{66} + return fileDescriptor_82cd95f524594f49, []int{67} } func (m *ResendSegmentStatsResponse) XXX_Unmarshal(b []byte) error { @@ -4426,7 +4558,7 @@ func (m *AddImportSegmentRequest) Reset() { *m = AddImportSegmentRequest func (m *AddImportSegmentRequest) String() string { return proto.CompactTextString(m) } func (*AddImportSegmentRequest) ProtoMessage() {} func (*AddImportSegmentRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{67} + return fileDescriptor_82cd95f524594f49, []int{68} } func (m *AddImportSegmentRequest) XXX_Unmarshal(b []byte) error { @@ -4508,7 +4640,7 @@ func (m *AddImportSegmentResponse) Reset() { *m = AddImportSegmentRespon func (m *AddImportSegmentResponse) String() string { return proto.CompactTextString(m) } func (*AddImportSegmentResponse) ProtoMessage() {} func (*AddImportSegmentResponse) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{68} + return fileDescriptor_82cd95f524594f49, []int{69} } func (m *AddImportSegmentResponse) XXX_Unmarshal(b []byte) error { @@ -4561,7 +4693,7 @@ func (m *SaveImportSegmentRequest) Reset() { *m = SaveImportSegmentReque func (m *SaveImportSegmentRequest) String() string { return proto.CompactTextString(m) } func (*SaveImportSegmentRequest) ProtoMessage() {} func (*SaveImportSegmentRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{69} + return fileDescriptor_82cd95f524594f49, []int{70} } func (m *SaveImportSegmentRequest) XXX_Unmarshal(b []byte) error { @@ -4650,7 +4782,7 @@ func (m *UnsetIsImportingStateRequest) Reset() { *m = UnsetIsImportingSt func (m *UnsetIsImportingStateRequest) String() string { return proto.CompactTextString(m) } func (*UnsetIsImportingStateRequest) ProtoMessage() {} func (*UnsetIsImportingStateRequest) Descriptor() ([]byte, []int) { - return 
fileDescriptor_82cd95f524594f49, []int{70} + return fileDescriptor_82cd95f524594f49, []int{71} } func (m *UnsetIsImportingStateRequest) XXX_Unmarshal(b []byte) error { @@ -4697,7 +4829,7 @@ func (m *MarkSegmentsDroppedRequest) Reset() { *m = MarkSegmentsDroppedR func (m *MarkSegmentsDroppedRequest) String() string { return proto.CompactTextString(m) } func (*MarkSegmentsDroppedRequest) ProtoMessage() {} func (*MarkSegmentsDroppedRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{71} + return fileDescriptor_82cd95f524594f49, []int{72} } func (m *MarkSegmentsDroppedRequest) XXX_Unmarshal(b []byte) error { @@ -4745,7 +4877,7 @@ func (m *SegmentReferenceLock) Reset() { *m = SegmentReferenceLock{} } func (m *SegmentReferenceLock) String() string { return proto.CompactTextString(m) } func (*SegmentReferenceLock) ProtoMessage() {} func (*SegmentReferenceLock) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{72} + return fileDescriptor_82cd95f524594f49, []int{73} } func (m *SegmentReferenceLock) XXX_Unmarshal(b []byte) error { @@ -4802,7 +4934,7 @@ func (m *AlterCollectionRequest) Reset() { *m = AlterCollectionRequest{} func (m *AlterCollectionRequest) String() string { return proto.CompactTextString(m) } func (*AlterCollectionRequest) ProtoMessage() {} func (*AlterCollectionRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{73} + return fileDescriptor_82cd95f524594f49, []int{74} } func (m *AlterCollectionRequest) XXX_Unmarshal(b []byte) error { @@ -4870,7 +5002,7 @@ func (m *GcConfirmRequest) Reset() { *m = GcConfirmRequest{} } func (m *GcConfirmRequest) String() string { return proto.CompactTextString(m) } func (*GcConfirmRequest) ProtoMessage() {} func (*GcConfirmRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{74} + return fileDescriptor_82cd95f524594f49, []int{75} } func (m *GcConfirmRequest) XXX_Unmarshal(b []byte) error { @@ 
-4917,7 +5049,7 @@ func (m *GcConfirmResponse) Reset() { *m = GcConfirmResponse{} } func (m *GcConfirmResponse) String() string { return proto.CompactTextString(m) } func (*GcConfirmResponse) ProtoMessage() {} func (*GcConfirmResponse) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{75} + return fileDescriptor_82cd95f524594f49, []int{76} } func (m *GcConfirmResponse) XXX_Unmarshal(b []byte) error { @@ -4964,7 +5096,7 @@ func (m *ReportDataNodeTtMsgsRequest) Reset() { *m = ReportDataNodeTtMsg func (m *ReportDataNodeTtMsgsRequest) String() string { return proto.CompactTextString(m) } func (*ReportDataNodeTtMsgsRequest) ProtoMessage() {} func (*ReportDataNodeTtMsgsRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_82cd95f524594f49, []int{76} + return fileDescriptor_82cd95f524594f49, []int{77} } func (m *ReportDataNodeTtMsgsRequest) XXX_Unmarshal(b []byte) error { @@ -4999,13 +5131,188 @@ func (m *ReportDataNodeTtMsgsRequest) GetMsgs() []*msgpb.DataNodeTtMsg { return nil } +type GetFlushStateRequest struct { + SegmentIDs []int64 `protobuf:"varint,1,rep,packed,name=segmentIDs,proto3" json:"segmentIDs,omitempty"` + FlushTs uint64 `protobuf:"varint,2,opt,name=flush_ts,json=flushTs,proto3" json:"flush_ts,omitempty"` + DbName string `protobuf:"bytes,3,opt,name=db_name,json=dbName,proto3" json:"db_name,omitempty"` + CollectionName string `protobuf:"bytes,4,opt,name=collection_name,json=collectionName,proto3" json:"collection_name,omitempty"` + CollectionID int64 `protobuf:"varint,5,opt,name=collectionID,proto3" json:"collectionID,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *GetFlushStateRequest) Reset() { *m = GetFlushStateRequest{} } +func (m *GetFlushStateRequest) String() string { return proto.CompactTextString(m) } +func (*GetFlushStateRequest) ProtoMessage() {} +func (*GetFlushStateRequest) Descriptor() ([]byte, []int) { + return 
fileDescriptor_82cd95f524594f49, []int{78} +} + +func (m *GetFlushStateRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_GetFlushStateRequest.Unmarshal(m, b) +} +func (m *GetFlushStateRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_GetFlushStateRequest.Marshal(b, m, deterministic) +} +func (m *GetFlushStateRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_GetFlushStateRequest.Merge(m, src) +} +func (m *GetFlushStateRequest) XXX_Size() int { + return xxx_messageInfo_GetFlushStateRequest.Size(m) +} +func (m *GetFlushStateRequest) XXX_DiscardUnknown() { + xxx_messageInfo_GetFlushStateRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_GetFlushStateRequest proto.InternalMessageInfo + +func (m *GetFlushStateRequest) GetSegmentIDs() []int64 { + if m != nil { + return m.SegmentIDs + } + return nil +} + +func (m *GetFlushStateRequest) GetFlushTs() uint64 { + if m != nil { + return m.FlushTs + } + return 0 +} + +func (m *GetFlushStateRequest) GetDbName() string { + if m != nil { + return m.DbName + } + return "" +} + +func (m *GetFlushStateRequest) GetCollectionName() string { + if m != nil { + return m.CollectionName + } + return "" +} + +func (m *GetFlushStateRequest) GetCollectionID() int64 { + if m != nil { + return m.CollectionID + } + return 0 +} + +type ChannelOperationsRequest struct { + Infos []*ChannelWatchInfo `protobuf:"bytes,1,rep,name=infos,proto3" json:"infos,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *ChannelOperationsRequest) Reset() { *m = ChannelOperationsRequest{} } +func (m *ChannelOperationsRequest) String() string { return proto.CompactTextString(m) } +func (*ChannelOperationsRequest) ProtoMessage() {} +func (*ChannelOperationsRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_82cd95f524594f49, []int{79} +} + +func (m *ChannelOperationsRequest) XXX_Unmarshal(b []byte) error { 
+ return xxx_messageInfo_ChannelOperationsRequest.Unmarshal(m, b) +} +func (m *ChannelOperationsRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_ChannelOperationsRequest.Marshal(b, m, deterministic) +} +func (m *ChannelOperationsRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_ChannelOperationsRequest.Merge(m, src) +} +func (m *ChannelOperationsRequest) XXX_Size() int { + return xxx_messageInfo_ChannelOperationsRequest.Size(m) +} +func (m *ChannelOperationsRequest) XXX_DiscardUnknown() { + xxx_messageInfo_ChannelOperationsRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_ChannelOperationsRequest proto.InternalMessageInfo + +func (m *ChannelOperationsRequest) GetInfos() []*ChannelWatchInfo { + if m != nil { + return m.Infos + } + return nil +} + +type ChannelOperationProgressResponse struct { + Status *commonpb.Status `protobuf:"bytes,1,opt,name=status,proto3" json:"status,omitempty"` + OpID int64 `protobuf:"varint,2,opt,name=opID,proto3" json:"opID,omitempty"` + State ChannelWatchState `protobuf:"varint,3,opt,name=state,proto3,enum=milvus.proto.data.ChannelWatchState" json:"state,omitempty"` + Progress int32 `protobuf:"varint,4,opt,name=progress,proto3" json:"progress,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *ChannelOperationProgressResponse) Reset() { *m = ChannelOperationProgressResponse{} } +func (m *ChannelOperationProgressResponse) String() string { return proto.CompactTextString(m) } +func (*ChannelOperationProgressResponse) ProtoMessage() {} +func (*ChannelOperationProgressResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_82cd95f524594f49, []int{80} +} + +func (m *ChannelOperationProgressResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_ChannelOperationProgressResponse.Unmarshal(m, b) +} +func (m *ChannelOperationProgressResponse) XXX_Marshal(b []byte, deterministic bool) 
([]byte, error) { + return xxx_messageInfo_ChannelOperationProgressResponse.Marshal(b, m, deterministic) +} +func (m *ChannelOperationProgressResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_ChannelOperationProgressResponse.Merge(m, src) +} +func (m *ChannelOperationProgressResponse) XXX_Size() int { + return xxx_messageInfo_ChannelOperationProgressResponse.Size(m) +} +func (m *ChannelOperationProgressResponse) XXX_DiscardUnknown() { + xxx_messageInfo_ChannelOperationProgressResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_ChannelOperationProgressResponse proto.InternalMessageInfo + +func (m *ChannelOperationProgressResponse) GetStatus() *commonpb.Status { + if m != nil { + return m.Status + } + return nil +} + +func (m *ChannelOperationProgressResponse) GetOpID() int64 { + if m != nil { + return m.OpID + } + return 0 +} + +func (m *ChannelOperationProgressResponse) GetState() ChannelWatchState { + if m != nil { + return m.State + } + return ChannelWatchState_Uncomplete +} + +func (m *ChannelOperationProgressResponse) GetProgress() int32 { + if m != nil { + return m.Progress + } + return 0 +} + func init() { proto.RegisterEnum("milvus.proto.data.SegmentType", SegmentType_name, SegmentType_value) + proto.RegisterEnum("milvus.proto.data.SegmentLevel", SegmentLevel_name, SegmentLevel_value) proto.RegisterEnum("milvus.proto.data.ChannelWatchState", ChannelWatchState_name, ChannelWatchState_value) proto.RegisterEnum("milvus.proto.data.CompactionType", CompactionType_name, CompactionType_value) proto.RegisterType((*Empty)(nil), "milvus.proto.data.Empty") proto.RegisterType((*FlushRequest)(nil), "milvus.proto.data.FlushRequest") proto.RegisterType((*FlushResponse)(nil), "milvus.proto.data.FlushResponse") + proto.RegisterType((*FlushChannelsRequest)(nil), "milvus.proto.data.FlushChannelsRequest") proto.RegisterType((*SegmentIDRequest)(nil), "milvus.proto.data.SegmentIDRequest") proto.RegisterType((*AssignSegmentIDRequest)(nil), 
"milvus.proto.data.AssignSegmentIDRequest") proto.RegisterType((*SegmentIDAssignment)(nil), "milvus.proto.data.SegmentIDAssignment") @@ -5081,316 +5388,336 @@ func init() { proto.RegisterType((*GcConfirmRequest)(nil), "milvus.proto.data.GcConfirmRequest") proto.RegisterType((*GcConfirmResponse)(nil), "milvus.proto.data.GcConfirmResponse") proto.RegisterType((*ReportDataNodeTtMsgsRequest)(nil), "milvus.proto.data.ReportDataNodeTtMsgsRequest") + proto.RegisterType((*GetFlushStateRequest)(nil), "milvus.proto.data.GetFlushStateRequest") + proto.RegisterType((*ChannelOperationsRequest)(nil), "milvus.proto.data.ChannelOperationsRequest") + proto.RegisterType((*ChannelOperationProgressResponse)(nil), "milvus.proto.data.ChannelOperationProgressResponse") } func init() { proto.RegisterFile("data_coord.proto", fileDescriptor_82cd95f524594f49) } var fileDescriptor_82cd95f524594f49 = []byte{ - // 4849 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe4, 0x3c, 0x4b, 0x6f, 0x1c, 0x47, - 0x7a, 0xea, 0x79, 0x71, 0xe6, 0x1b, 0x72, 0x38, 0x2c, 0xd1, 0xd4, 0x68, 0x64, 0x3d, 0xb6, 0x25, - 0x59, 0xb4, 0x6c, 0x4b, 0x36, 0x95, 0xc5, 0x7a, 0xed, 0xb5, 0x77, 0x45, 0xd2, 0x92, 0x27, 0x11, - 0xb5, 0xdc, 0x26, 0x25, 0x07, 0xde, 0x00, 0x83, 0xe6, 0x74, 0x71, 0xd8, 0xcb, 0x9e, 0xee, 0x71, - 0x77, 0x0f, 0x29, 0x3a, 0x40, 0xe2, 0x24, 0x9b, 0x00, 0x79, 0x20, 0xef, 0x00, 0xc9, 0x2d, 0xc8, - 0x21, 0xd8, 0x24, 0xd8, 0x53, 0x92, 0x43, 0x82, 0x00, 0x7b, 0x75, 0x90, 0x43, 0x10, 0x04, 0x08, - 0x92, 0x43, 0xae, 0xc9, 0x0f, 0xc8, 0x1f, 0x08, 0xea, 0xd1, 0xd5, 0xaf, 0xea, 0x9e, 0x26, 0x47, - 0xb2, 0x80, 0xe4, 0x36, 0x55, 0xfd, 0xd5, 0x57, 0x5f, 0x7d, 0xf5, 0xbd, 0xab, 0x6a, 0xa0, 0x6d, - 0xe8, 0xbe, 0xde, 0x1f, 0x38, 0x8e, 0x6b, 0xdc, 0x19, 0xbb, 0x8e, 0xef, 0xa0, 0xa5, 0x91, 0x69, - 0x1d, 0x4d, 0x3c, 0xd6, 0xba, 0x43, 0x3e, 0x77, 0xe7, 0x07, 0xce, 0x68, 0xe4, 0xd8, 0xac, 0xab, - 0xdb, 0x32, 0x6d, 0x1f, 0xbb, 0xb6, 0x6e, 0xf1, 0xf6, 0x7c, 0x74, 0x40, 
0x77, 0xde, 0x1b, 0x1c, - 0xe0, 0x91, 0xce, 0x5b, 0x8d, 0x91, 0x37, 0xe4, 0x3f, 0x97, 0x4c, 0xdb, 0xc0, 0xcf, 0xa2, 0x53, - 0xa9, 0x73, 0x50, 0xfd, 0x68, 0x34, 0xf6, 0x4f, 0xd4, 0xbf, 0x51, 0x60, 0xfe, 0x81, 0x35, 0xf1, - 0x0e, 0x34, 0xfc, 0xd9, 0x04, 0x7b, 0x3e, 0x7a, 0x1b, 0x2a, 0x7b, 0xba, 0x87, 0x3b, 0xca, 0x35, - 0x65, 0xb5, 0xb9, 0xf6, 0xea, 0x9d, 0x18, 0x4d, 0x9c, 0x9a, 0x2d, 0x6f, 0xb8, 0xae, 0x7b, 0x58, - 0xa3, 0x90, 0x08, 0x41, 0xc5, 0xd8, 0xeb, 0x6d, 0x76, 0x4a, 0xd7, 0x94, 0xd5, 0xb2, 0x46, 0x7f, - 0xa3, 0x2b, 0x00, 0x1e, 0x1e, 0x8e, 0xb0, 0xed, 0xf7, 0x36, 0xbd, 0x4e, 0xf9, 0x5a, 0x79, 0xb5, - 0xac, 0x45, 0x7a, 0x90, 0x0a, 0xf3, 0x03, 0xc7, 0xb2, 0xf0, 0xc0, 0x37, 0x1d, 0xbb, 0xb7, 0xd9, - 0xa9, 0xd0, 0xb1, 0xb1, 0x3e, 0xd4, 0x85, 0xba, 0xe9, 0xf5, 0x46, 0x63, 0xc7, 0xf5, 0x3b, 0xd5, - 0x6b, 0xca, 0x6a, 0x5d, 0x13, 0x6d, 0xf5, 0xbf, 0x14, 0x58, 0xe0, 0x64, 0x7b, 0x63, 0xc7, 0xf6, - 0x30, 0xba, 0x07, 0x35, 0xcf, 0xd7, 0xfd, 0x89, 0xc7, 0x29, 0xbf, 0x24, 0xa5, 0x7c, 0x87, 0x82, - 0x68, 0x1c, 0x54, 0x4a, 0x7a, 0x92, 0xb4, 0xb2, 0x84, 0xb4, 0xf8, 0xf2, 0x2a, 0xa9, 0xe5, 0xad, - 0xc2, 0xe2, 0x3e, 0xa1, 0x6e, 0x27, 0x04, 0xaa, 0x52, 0xa0, 0x64, 0x37, 0xc1, 0xe4, 0x9b, 0x23, - 0xfc, 0xdd, 0xfd, 0x1d, 0xac, 0x5b, 0x9d, 0x1a, 0x9d, 0x2b, 0xd2, 0xa3, 0xfe, 0x8b, 0x02, 0x6d, - 0x01, 0x1e, 0xec, 0xd1, 0x32, 0x54, 0x07, 0xce, 0xc4, 0xf6, 0xe9, 0x52, 0x17, 0x34, 0xd6, 0x40, - 0x5f, 0x83, 0xf9, 0xc1, 0x81, 0x6e, 0xdb, 0xd8, 0xea, 0xdb, 0xfa, 0x08, 0xd3, 0x45, 0x35, 0xb4, - 0x26, 0xef, 0x7b, 0xac, 0x8f, 0x70, 0xa1, 0xb5, 0x5d, 0x83, 0xe6, 0x58, 0x77, 0x7d, 0x33, 0xb6, - 0x33, 0xd1, 0xae, 0xbc, 0x8d, 0x21, 0x33, 0x98, 0xf4, 0xd7, 0xae, 0xee, 0x1d, 0xf6, 0x36, 0xf9, - 0x8a, 0x62, 0x7d, 0xea, 0x9f, 0x2a, 0xb0, 0x72, 0xdf, 0xf3, 0xcc, 0xa1, 0x9d, 0x5a, 0xd9, 0x0a, - 0xd4, 0x6c, 0xc7, 0xc0, 0xbd, 0x4d, 0xba, 0xb4, 0xb2, 0xc6, 0x5b, 0xe8, 0x12, 0x34, 0xc6, 0x18, - 0xbb, 0x7d, 0xd7, 0xb1, 0x82, 0x85, 0xd5, 0x49, 0x87, 0xe6, 0x58, 0x18, 0x7d, 0x0f, 0x96, 0xbc, - 0x04, 0x22, 
0x26, 0x73, 0xcd, 0xb5, 0xeb, 0x77, 0x52, 0x3a, 0x75, 0x27, 0x39, 0xa9, 0x96, 0x1e, - 0xad, 0x7e, 0x51, 0x82, 0xf3, 0x02, 0x8e, 0xd1, 0x4a, 0x7e, 0x13, 0xce, 0x7b, 0x78, 0x28, 0xc8, - 0x63, 0x8d, 0x22, 0x9c, 0x17, 0x5b, 0x56, 0x8e, 0x6e, 0x59, 0x11, 0x35, 0x48, 0xec, 0x47, 0x35, - 0xbd, 0x1f, 0x57, 0xa1, 0x89, 0x9f, 0x8d, 0x4d, 0x17, 0xf7, 0x89, 0xe0, 0x50, 0x96, 0x57, 0x34, - 0x60, 0x5d, 0xbb, 0xe6, 0x28, 0xaa, 0x1b, 0x73, 0x85, 0x75, 0x43, 0xfd, 0x33, 0x05, 0x2e, 0xa4, - 0x76, 0x89, 0x2b, 0x9b, 0x06, 0x6d, 0xba, 0xf2, 0x90, 0x33, 0x44, 0xed, 0x08, 0xc3, 0x5f, 0xcb, - 0x63, 0x78, 0x08, 0xae, 0xa5, 0xc6, 0x47, 0x88, 0x2c, 0x15, 0x27, 0xf2, 0x10, 0x2e, 0x3c, 0xc4, - 0x3e, 0x9f, 0x80, 0x7c, 0xc3, 0xde, 0xd9, 0x0d, 0x59, 0x5c, 0xab, 0x4b, 0x49, 0xad, 0x56, 0xff, - 0xbc, 0x24, 0x74, 0x91, 0x4e, 0xd5, 0xb3, 0xf7, 0x1d, 0xf4, 0x2a, 0x34, 0x04, 0x08, 0x97, 0x8a, - 0xb0, 0x03, 0x7d, 0x03, 0xaa, 0x84, 0x52, 0x26, 0x12, 0xad, 0xb5, 0xaf, 0xc9, 0xd7, 0x14, 0xc1, - 0xa9, 0x31, 0x78, 0xb4, 0x09, 0x2d, 0xcf, 0xd7, 0x5d, 0xbf, 0x3f, 0x76, 0x3c, 0xba, 0xcf, 0x54, - 0x70, 0x9a, 0x6b, 0x97, 0xe3, 0x18, 0x88, 0x91, 0xdf, 0xf2, 0x86, 0xdb, 0x1c, 0x48, 0x5b, 0xa0, - 0x83, 0x82, 0x26, 0xfa, 0x0e, 0xcc, 0x63, 0xdb, 0x08, 0x71, 0x54, 0x8a, 0xe0, 0x68, 0x62, 0xdb, - 0x10, 0x18, 0xc2, 0x5d, 0xa9, 0x16, 0xdf, 0x95, 0xdf, 0x52, 0xa0, 0x93, 0xde, 0x96, 0x59, 0x0c, - 0xf5, 0xfb, 0x6c, 0x10, 0x66, 0xdb, 0x92, 0xab, 0xd7, 0x62, 0x6b, 0x34, 0x3e, 0x44, 0xfd, 0x23, - 0x05, 0x5e, 0x09, 0xc9, 0xa1, 0x9f, 0x5e, 0x94, 0x8c, 0xa0, 0xdb, 0xd0, 0x36, 0xed, 0x81, 0x35, - 0x31, 0xf0, 0x13, 0xfb, 0x63, 0xac, 0x5b, 0xfe, 0xc1, 0x09, 0xdd, 0xb9, 0xba, 0x96, 0xea, 0x57, - 0xff, 0xa3, 0x04, 0x2b, 0x49, 0xba, 0x66, 0x61, 0xd2, 0x4f, 0x41, 0xd5, 0xb4, 0xf7, 0x9d, 0x80, - 0x47, 0x57, 0x72, 0x54, 0x91, 0xcc, 0xc5, 0x80, 0x91, 0x03, 0x28, 0x30, 0x5e, 0x83, 0x03, 0x3c, - 0x38, 0x1c, 0x3b, 0x26, 0x35, 0x53, 0x04, 0xc5, 0x77, 0x24, 0x28, 0xe4, 0x14, 0xdf, 0xd9, 0x60, - 0x38, 0x36, 0x04, 0x8a, 0x8f, 0x6c, 0xdf, 0x3d, 
0xd1, 0x96, 0x06, 0xc9, 0xfe, 0xee, 0x00, 0x56, - 0xe4, 0xc0, 0xa8, 0x0d, 0xe5, 0x43, 0x7c, 0x42, 0x97, 0xdc, 0xd0, 0xc8, 0x4f, 0x74, 0x0f, 0xaa, - 0x47, 0xba, 0x35, 0xc1, 0xdc, 0x26, 0x4c, 0x91, 0x5c, 0x06, 0xfb, 0x5e, 0xe9, 0x5d, 0x45, 0x1d, - 0xc1, 0xa5, 0x87, 0xd8, 0xef, 0xd9, 0x1e, 0x76, 0xfd, 0x75, 0xd3, 0xb6, 0x9c, 0xe1, 0xb6, 0xee, - 0x1f, 0xcc, 0x60, 0x1c, 0x62, 0x7a, 0x5e, 0x4a, 0xe8, 0xb9, 0xfa, 0x23, 0x05, 0x5e, 0x95, 0xcf, - 0xc7, 0x37, 0xb4, 0x0b, 0xf5, 0x7d, 0x13, 0x5b, 0x06, 0x91, 0x1a, 0x85, 0x4a, 0x8d, 0x68, 0x13, - 0x23, 0x31, 0x26, 0xc0, 0x7c, 0xdf, 0x12, 0x46, 0x42, 0xc4, 0x7c, 0x3b, 0xbe, 0x6b, 0xda, 0xc3, - 0x47, 0xa6, 0xe7, 0x6b, 0x0c, 0x3e, 0x22, 0x25, 0xe5, 0xe2, 0xca, 0xf9, 0x1b, 0x0a, 0x5c, 0x79, - 0x88, 0xfd, 0x0d, 0xe1, 0x63, 0xc8, 0x77, 0xd3, 0xf3, 0xcd, 0x81, 0xf7, 0x7c, 0x63, 0xc0, 0x02, - 0xc1, 0x86, 0xfa, 0x3b, 0x0a, 0x5c, 0xcd, 0x24, 0x86, 0xb3, 0x8e, 0xdb, 0xd0, 0xc0, 0xc3, 0xc8, - 0x6d, 0xe8, 0xcf, 0xe0, 0x93, 0xa7, 0x64, 0xf3, 0xb7, 0x75, 0xd3, 0x65, 0x36, 0xf4, 0x8c, 0x1e, - 0xe5, 0xc7, 0x0a, 0x5c, 0x7e, 0x88, 0xfd, 0xed, 0xc0, 0xbf, 0xbe, 0x44, 0xee, 0x10, 0x98, 0x88, - 0x9f, 0x0f, 0x02, 0xcd, 0x58, 0x9f, 0xfa, 0xdb, 0x6c, 0x3b, 0xa5, 0xf4, 0xbe, 0x14, 0x06, 0x5e, - 0xa1, 0x9a, 0x10, 0x31, 0x11, 0x5c, 0xd9, 0x39, 0xfb, 0xd4, 0x1f, 0x56, 0x61, 0xfe, 0x29, 0xb7, - 0x0a, 0xd4, 0x83, 0x26, 0x39, 0xa1, 0xc8, 0x83, 0xa0, 0x48, 0x34, 0x25, 0x0b, 0xb0, 0xd6, 0x61, - 0xc1, 0xc3, 0xf8, 0xf0, 0x94, 0xfe, 0x72, 0x9e, 0x8c, 0x11, 0xce, 0xee, 0x11, 0x2c, 0x4d, 0x6c, - 0x1a, 0xa1, 0x63, 0x83, 0x2f, 0x80, 0x31, 0x7d, 0xba, 0x31, 0x4d, 0x0f, 0x44, 0x1f, 0xf3, 0x24, - 0x20, 0x82, 0xab, 0x5a, 0x08, 0x57, 0x72, 0x18, 0xea, 0x41, 0xdb, 0x70, 0x9d, 0xf1, 0x18, 0x1b, - 0x7d, 0x2f, 0x40, 0x55, 0x2b, 0x86, 0x8a, 0x8f, 0x13, 0xa8, 0xde, 0x86, 0xf3, 0x49, 0x4a, 0x7b, - 0x06, 0x89, 0x0b, 0x89, 0x64, 0xc9, 0x3e, 0xa1, 0x37, 0x61, 0x29, 0x0d, 0x5f, 0xa7, 0xf0, 0xe9, - 0x0f, 0xe8, 0x2d, 0x40, 0x09, 0x52, 0x09, 0x78, 0x83, 0x81, 0xc7, 0x89, 0xe1, 0xe0, 
0x34, 0x39, - 0x8d, 0x83, 0x03, 0x03, 0xe7, 0x5f, 0x22, 0xe0, 0x3d, 0xe2, 0x5d, 0x63, 0xe0, 0x5e, 0xa7, 0x59, - 0x8c, 0x11, 0x71, 0x64, 0x9e, 0xfa, 0xeb, 0x0a, 0xac, 0x7c, 0xa2, 0xfb, 0x83, 0x83, 0xcd, 0x11, - 0x17, 0xd0, 0x19, 0x14, 0xfc, 0x03, 0x68, 0x1c, 0x71, 0x61, 0x0c, 0xac, 0xf8, 0x55, 0x09, 0x41, - 0x51, 0xb1, 0xd7, 0xc2, 0x11, 0x24, 0x21, 0x5a, 0x7e, 0x10, 0x49, 0x0c, 0x5f, 0x82, 0xa9, 0x99, - 0x92, 0xd1, 0xaa, 0xcf, 0x00, 0x38, 0x71, 0x5b, 0xde, 0xf0, 0x0c, 0x74, 0xbd, 0x0b, 0x73, 0x1c, - 0x1b, 0xb7, 0x25, 0xd3, 0x36, 0x2c, 0x00, 0x57, 0xff, 0xbb, 0x06, 0xcd, 0xc8, 0x07, 0xd4, 0x82, - 0x92, 0x30, 0x12, 0x25, 0xc9, 0xea, 0x4a, 0xd3, 0x73, 0xa8, 0x72, 0x3a, 0x87, 0xba, 0x09, 0x2d, - 0x93, 0x3a, 0xef, 0x3e, 0xdf, 0x15, 0x1a, 0x2b, 0x37, 0xb4, 0x05, 0xd6, 0xcb, 0x45, 0x04, 0x5d, - 0x81, 0xa6, 0x3d, 0x19, 0xf5, 0x9d, 0xfd, 0xbe, 0xeb, 0x1c, 0x7b, 0x3c, 0x19, 0x6b, 0xd8, 0x93, - 0xd1, 0x77, 0xf7, 0x35, 0xe7, 0xd8, 0x0b, 0xe3, 0xfd, 0xda, 0x29, 0xe3, 0xfd, 0x2b, 0xd0, 0x1c, - 0xe9, 0xcf, 0x08, 0xd6, 0xbe, 0x3d, 0x19, 0xd1, 0x3c, 0xad, 0xac, 0x35, 0x46, 0xfa, 0x33, 0xcd, - 0x39, 0x7e, 0x3c, 0x19, 0xa1, 0x55, 0x68, 0x5b, 0xba, 0xe7, 0xf7, 0xa3, 0x89, 0x5e, 0x9d, 0x26, - 0x7a, 0x2d, 0xd2, 0xff, 0x51, 0x98, 0xec, 0xa5, 0x33, 0x87, 0xc6, 0xd9, 0x32, 0x07, 0x63, 0x64, - 0x85, 0x38, 0xa0, 0x50, 0xe6, 0x60, 0x8c, 0x2c, 0x81, 0xe1, 0x5d, 0x98, 0xdb, 0xa3, 0x81, 0x50, - 0x9e, 0x8a, 0x3e, 0x20, 0x31, 0x10, 0x8b, 0x97, 0xb4, 0x00, 0x1c, 0x7d, 0x0b, 0x1a, 0xd4, 0xff, - 0xd0, 0xb1, 0xf3, 0x85, 0xc6, 0x86, 0x03, 0xc8, 0x68, 0x03, 0x5b, 0xbe, 0x4e, 0x47, 0x2f, 0x14, - 0x1b, 0x2d, 0x06, 0x10, 0xfb, 0x38, 0x70, 0xb1, 0xee, 0x63, 0x63, 0xfd, 0x64, 0xc3, 0x19, 0x8d, - 0x75, 0x2a, 0x42, 0x9d, 0x16, 0x0d, 0xe1, 0x65, 0x9f, 0xd0, 0x6b, 0xd0, 0x1a, 0x88, 0xd6, 0x03, - 0xd7, 0x19, 0x75, 0x16, 0xa9, 0xf6, 0x24, 0x7a, 0xd1, 0x65, 0x80, 0xc0, 0x32, 0xea, 0x7e, 0xa7, - 0x4d, 0xf7, 0xae, 0xc1, 0x7b, 0xee, 0xd3, 0xea, 0x8d, 0xe9, 0xf5, 0x59, 0x9d, 0xc4, 0xb4, 0x87, - 0x9d, 0x25, 0x3a, 0x63, 
0x33, 0x28, 0xac, 0x98, 0xf6, 0x10, 0x5d, 0x80, 0x39, 0xd3, 0xeb, 0xef, - 0xeb, 0x87, 0xb8, 0x83, 0xe8, 0xd7, 0x9a, 0xe9, 0x3d, 0xd0, 0x0f, 0x69, 0x6c, 0xca, 0x27, 0xc3, - 0x46, 0xe7, 0x3c, 0xfd, 0x14, 0x76, 0xa0, 0x5b, 0xb0, 0xe8, 0xf9, 0x8e, 0xab, 0x0f, 0x71, 0xff, - 0x08, 0xbb, 0x1e, 0x59, 0xce, 0x32, 0x15, 0xaf, 0x16, 0xef, 0x7e, 0xca, 0x7a, 0xd5, 0xcf, 0x61, - 0x39, 0x14, 0xcd, 0x88, 0x2c, 0xa4, 0x25, 0x4a, 0x39, 0x83, 0x44, 0xe5, 0x07, 0xd0, 0xbf, 0x57, - 0x85, 0x95, 0x1d, 0xfd, 0x08, 0xbf, 0xf8, 0x58, 0xbd, 0x90, 0x39, 0x7c, 0x04, 0x4b, 0x34, 0x3c, - 0x5f, 0x8b, 0xd0, 0x93, 0x13, 0x09, 0x44, 0x85, 0x29, 0x3d, 0x10, 0x7d, 0x9b, 0x44, 0x2f, 0x78, - 0x70, 0xb8, 0x4d, 0x52, 0x9d, 0x20, 0x0a, 0xb8, 0x2c, 0xc1, 0xb3, 0x21, 0xa0, 0xb4, 0xe8, 0x08, - 0xb4, 0x4d, 0xb6, 0x30, 0xba, 0x03, 0x81, 0xff, 0xbf, 0x95, 0x9b, 0x07, 0x87, 0xdc, 0xd7, 0x5a, - 0xb1, 0xcd, 0xf0, 0x50, 0x07, 0xe6, 0xb8, 0xf3, 0xa6, 0xb6, 0xa6, 0xae, 0x05, 0x4d, 0xb4, 0x0d, - 0xe7, 0xd9, 0x0a, 0x76, 0xb8, 0x4a, 0xb1, 0xc5, 0xd7, 0x0b, 0x2d, 0x5e, 0x36, 0x34, 0xae, 0x91, - 0x8d, 0xd3, 0x6a, 0x64, 0x07, 0xe6, 0xb8, 0x96, 0x50, 0x23, 0x54, 0xd7, 0x82, 0x26, 0xd9, 0xe6, - 0x50, 0x5f, 0x9a, 0x4c, 0xec, 0x45, 0x07, 0x19, 0x17, 0x98, 0xf2, 0x79, 0x6a, 0xca, 0x83, 0x26, - 0xd1, 0xd8, 0xb8, 0xe4, 0x77, 0x16, 0xa4, 0xfa, 0xf0, 0xab, 0x0a, 0x40, 0xb8, 0x23, 0x53, 0x2a, - 0x3d, 0xdf, 0x84, 0xba, 0x50, 0x8f, 0x42, 0xc9, 0xaa, 0x00, 0x4f, 0x3a, 0x95, 0x72, 0xc2, 0xa9, - 0xa8, 0xff, 0xa4, 0xc0, 0xfc, 0x26, 0xe1, 0xc7, 0x23, 0x67, 0x48, 0x5d, 0xe0, 0x4d, 0x68, 0xb9, - 0x78, 0xe0, 0xb8, 0x46, 0x1f, 0xdb, 0xbe, 0x6b, 0x62, 0x56, 0x25, 0xa8, 0x68, 0x0b, 0xac, 0xf7, - 0x23, 0xd6, 0x49, 0xc0, 0x88, 0x9f, 0xf0, 0x7c, 0x7d, 0x34, 0xee, 0xef, 0x13, 0xcb, 0x54, 0x62, - 0x60, 0xa2, 0x97, 0x1a, 0xa6, 0xaf, 0xc1, 0x7c, 0x08, 0xe6, 0x3b, 0x74, 0xfe, 0x8a, 0xd6, 0x14, - 0x7d, 0xbb, 0x0e, 0xba, 0x01, 0x2d, 0xba, 0x21, 0x7d, 0xcb, 0x19, 0xf6, 0x49, 0xee, 0xc9, 0xbd, - 0xe3, 0xbc, 0xc1, 0xc9, 0x22, 0x1b, 0x1d, 0x87, 0xf2, 0xcc, 
0xcf, 0x31, 0xf7, 0x8f, 0x02, 0x6a, - 0xc7, 0xfc, 0x1c, 0xab, 0xbf, 0xa2, 0xc0, 0x02, 0x77, 0xa7, 0x3b, 0xa2, 0x0a, 0x4f, 0xcb, 0xa6, - 0x2c, 0xef, 0xa7, 0xbf, 0xd1, 0x7b, 0xf1, 0xc2, 0xd9, 0x0d, 0xa9, 0xb2, 0x50, 0x24, 0x34, 0x88, - 0x8b, 0xf9, 0xd2, 0x22, 0x89, 0xe7, 0x17, 0x84, 0xa7, 0xba, 0xaf, 0x3f, 0x76, 0x0c, 0x56, 0xc7, - 0xeb, 0xc0, 0x9c, 0x6e, 0x18, 0x2e, 0xf6, 0x3c, 0x4e, 0x47, 0xd0, 0x24, 0x5f, 0x02, 0xbb, 0xc9, - 0x6c, 0x49, 0xd0, 0x44, 0xdf, 0x82, 0xba, 0x88, 0xfa, 0x58, 0xc1, 0xe4, 0x5a, 0x36, 0x9d, 0x3c, - 0x4d, 0x12, 0x23, 0xd4, 0xbf, 0x2d, 0x41, 0x8b, 0xeb, 0xea, 0x3a, 0xf7, 0x7c, 0xf9, 0x22, 0xb6, - 0x0e, 0xf3, 0xfb, 0xa1, 0x8e, 0xe4, 0x95, 0x79, 0xa2, 0xaa, 0x14, 0x1b, 0x33, 0x4d, 0xd6, 0xe2, - 0xbe, 0xb7, 0x32, 0x93, 0xef, 0xad, 0x9e, 0x56, 0xd3, 0xd3, 0x31, 0x58, 0x4d, 0x12, 0x83, 0xa9, - 0x3f, 0x07, 0xcd, 0x08, 0x02, 0x6a, 0xc9, 0x58, 0x25, 0x85, 0x73, 0x2c, 0x68, 0xa2, 0x7b, 0x61, - 0x04, 0xc2, 0x58, 0x75, 0x51, 0x42, 0x4b, 0x22, 0xf8, 0x50, 0x7f, 0xa2, 0x40, 0x8d, 0x63, 0xbe, - 0x0a, 0x4d, 0xae, 0x5f, 0x34, 0x26, 0x63, 0xd8, 0x81, 0x77, 0x91, 0xa0, 0xec, 0xf9, 0x29, 0xd8, - 0x45, 0xa8, 0x27, 0x54, 0x6b, 0x8e, 0x9b, 0xcf, 0xe0, 0x53, 0x44, 0x9f, 0xc8, 0x27, 0xa2, 0x4a, - 0x68, 0x19, 0xaa, 0x96, 0x33, 0x14, 0xa7, 0x2c, 0xac, 0xa1, 0x7e, 0xa9, 0xd0, 0xa2, 0xb8, 0x86, - 0x07, 0xce, 0x11, 0x76, 0x4f, 0x66, 0xaf, 0x2b, 0xbe, 0x1f, 0x11, 0xf3, 0x82, 0xc9, 0x8d, 0x18, - 0x80, 0xde, 0x0f, 0x37, 0xa1, 0x2c, 0x2b, 0x3f, 0x44, 0x5d, 0x16, 0x17, 0xd2, 0x70, 0x33, 0x7e, - 0x57, 0xa1, 0x15, 0xd2, 0xf8, 0x52, 0xce, 0x1a, 0x15, 0x3c, 0x97, 0x44, 0x41, 0xfd, 0x47, 0x05, - 0x2e, 0x66, 0x70, 0xf7, 0xe9, 0xda, 0x4b, 0xe0, 0xef, 0x7b, 0x50, 0x17, 0xa9, 0x70, 0xb9, 0x50, - 0x2a, 0x2c, 0xe0, 0xd5, 0x3f, 0x64, 0x75, 0x7a, 0x09, 0x7b, 0x9f, 0xae, 0xbd, 0x20, 0x06, 0x27, - 0x4b, 0x5a, 0x65, 0x49, 0x49, 0xeb, 0x9f, 0x15, 0xe8, 0x86, 0x25, 0x24, 0x6f, 0xfd, 0x64, 0xd6, - 0x83, 0x9d, 0xe7, 0x93, 0x22, 0x7e, 0x53, 0x9c, 0x41, 0x10, 0xbb, 0x58, 0x28, 0xb9, 0x0b, 0x4e, - 
0x20, 0x6c, 0x5a, 0x8d, 0x4e, 0x2f, 0x68, 0x16, 0xad, 0xec, 0x46, 0x36, 0x9e, 0x9d, 0x43, 0x84, - 0x1b, 0xfb, 0x13, 0x26, 0xa4, 0x0f, 0xe2, 0x75, 0xa4, 0x97, 0xcd, 0xc0, 0xe8, 0xd9, 0xc8, 0x01, - 0x3f, 0x1b, 0xa9, 0x24, 0xce, 0x46, 0x78, 0xbf, 0x3a, 0xa2, 0x22, 0x90, 0x5a, 0xc0, 0x8b, 0x62, - 0xd8, 0xaf, 0x29, 0xd0, 0xe1, 0xb3, 0xd0, 0x39, 0x49, 0x7e, 0x67, 0x61, 0x1f, 0x1b, 0x5f, 0x75, - 0xb5, 0xe3, 0x8f, 0x4b, 0xd0, 0x8e, 0x06, 0x36, 0x34, 0x36, 0xf9, 0x3a, 0x54, 0x69, 0xb1, 0x88, - 0x53, 0x30, 0xd5, 0x3a, 0x30, 0x68, 0xe2, 0x19, 0x69, 0xd4, 0xbf, 0xeb, 0x05, 0x81, 0x0b, 0x6f, - 0x86, 0xd1, 0x55, 0xf9, 0xf4, 0xd1, 0xd5, 0xab, 0xd0, 0x20, 0x9e, 0xcb, 0x99, 0x10, 0xbc, 0xec, - 0xc0, 0x3a, 0xec, 0x40, 0x1f, 0x40, 0x8d, 0x5d, 0x43, 0xe1, 0xe7, 0x85, 0x37, 0xe3, 0xa8, 0xf9, - 0x15, 0x95, 0x48, 0xbd, 0x9f, 0x76, 0x68, 0x7c, 0x10, 0xd9, 0xa3, 0xb1, 0xeb, 0x0c, 0x69, 0x18, - 0x46, 0x9c, 0x5a, 0x55, 0x13, 0x6d, 0xf5, 0xa7, 0x61, 0x25, 0x4c, 0xbb, 0x19, 0x49, 0x67, 0x15, - 0x68, 0xf5, 0xdf, 0x14, 0x38, 0xbf, 0x73, 0x62, 0x0f, 0x92, 0xaa, 0xb1, 0x02, 0xb5, 0xb1, 0xa5, - 0x87, 0x55, 0x68, 0xde, 0xa2, 0x27, 0xfc, 0x41, 0x42, 0x4d, 0x5c, 0x38, 0xe3, 0x67, 0x53, 0xf4, - 0xed, 0x3a, 0x53, 0x23, 0xab, 0x9b, 0xa2, 0x4e, 0x80, 0x0d, 0x16, 0x2c, 0xb0, 0x2a, 0xdb, 0x82, - 0xe8, 0xa5, 0xc1, 0xc2, 0x07, 0x00, 0x34, 0x9e, 0xea, 0x9f, 0x26, 0x86, 0xa2, 0x23, 0x1e, 0x11, - 0x8f, 0xf9, 0xd7, 0x25, 0xe8, 0x44, 0xb8, 0xf4, 0x55, 0x87, 0x97, 0x19, 0xc9, 0x63, 0xf9, 0x39, - 0x25, 0x8f, 0x95, 0xd9, 0x43, 0xca, 0xaa, 0x2c, 0xa4, 0xfc, 0xa5, 0x32, 0xb4, 0x42, 0xae, 0x6d, - 0x5b, 0xba, 0x9d, 0x29, 0x09, 0x3b, 0xd0, 0xf2, 0x62, 0x5c, 0xe5, 0x7c, 0x7a, 0x43, 0xa6, 0x43, - 0x19, 0x1b, 0xa1, 0x25, 0x50, 0xa0, 0xcb, 0x74, 0xd3, 0x5d, 0x9f, 0xd5, 0xf5, 0x58, 0x7c, 0xd8, - 0x60, 0xca, 0x6a, 0x8e, 0x30, 0x7a, 0x13, 0x10, 0xd7, 0xb0, 0xbe, 0x69, 0xf7, 0x3d, 0x3c, 0x70, - 0x6c, 0x83, 0xe9, 0x5e, 0x55, 0x6b, 0xf3, 0x2f, 0x3d, 0x7b, 0x87, 0xf5, 0xa3, 0xaf, 0x43, 0xc5, - 0x3f, 0x19, 0xb3, 0x60, 0xb1, 0x25, 
0x0d, 0xb7, 0x42, 0xba, 0x76, 0x4f, 0xc6, 0x58, 0xa3, 0xe0, - 0xc1, 0x4d, 0x24, 0xdf, 0xd5, 0x8f, 0x78, 0xe4, 0x5d, 0xd1, 0x22, 0x3d, 0xd1, 0x7c, 0x7a, 0x2e, - 0x9e, 0x4f, 0x53, 0xc9, 0x0e, 0x14, 0xba, 0xef, 0xfb, 0x16, 0xad, 0x4c, 0x52, 0xc9, 0x0e, 0x7a, - 0x77, 0x7d, 0x8b, 0x2c, 0xd2, 0x77, 0x7c, 0xdd, 0x62, 0xfa, 0xd1, 0xe0, 0x96, 0x83, 0xf4, 0xd0, - 0x2c, 0xf7, 0x5f, 0x89, 0xe5, 0x13, 0x84, 0x69, 0xd8, 0x9b, 0x58, 0xd9, 0xfa, 0x98, 0x5f, 0xe1, - 0x99, 0xa6, 0x8a, 0xdf, 0x86, 0x26, 0x97, 0x8a, 0x53, 0x48, 0x15, 0xb0, 0x21, 0x8f, 0x72, 0xc4, - 0xbc, 0xfa, 0x9c, 0xc4, 0xbc, 0x76, 0x86, 0x1a, 0x89, 0x7c, 0x6f, 0xd4, 0x1f, 0x29, 0xf0, 0x4a, - 0xca, 0x6a, 0xe6, 0xb2, 0x36, 0x3f, 0xf3, 0xe6, 0xd6, 0x34, 0x89, 0x92, 0xfb, 0x86, 0xf7, 0xa1, - 0xe6, 0x52, 0xec, 0xfc, 0xf4, 0xed, 0x7a, 0xae, 0xf0, 0x31, 0x42, 0x34, 0x3e, 0x44, 0xfd, 0x7d, - 0x05, 0x2e, 0xa4, 0x49, 0x9d, 0xc1, 0xe1, 0xaf, 0xc3, 0x1c, 0x43, 0x1d, 0xe8, 0xe8, 0x6a, 0xbe, - 0x8e, 0x86, 0xcc, 0xd1, 0x82, 0x81, 0xea, 0x0e, 0xac, 0x04, 0x71, 0x41, 0xc8, 0xfa, 0x2d, 0xec, - 0xeb, 0x39, 0x79, 0xe7, 0x55, 0x68, 0xb2, 0x04, 0x86, 0xe5, 0x73, 0xec, 0xb0, 0x12, 0xf6, 0x44, - 0x41, 0x50, 0xfd, 0x83, 0x12, 0x2c, 0x53, 0xc7, 0x9a, 0x3c, 0x79, 0x2a, 0x72, 0x14, 0xaa, 0x8a, - 0xcb, 0x66, 0x8f, 0xf5, 0x11, 0xbf, 0x10, 0xd3, 0xd0, 0x62, 0x7d, 0xa8, 0x97, 0xae, 0x17, 0x4a, - 0xeb, 0x13, 0xe1, 0xd9, 0xef, 0xa6, 0xee, 0xeb, 0xf4, 0xe8, 0x37, 0x59, 0x28, 0x0c, 0x1d, 0x7a, - 0xe5, 0x2c, 0x0e, 0xfd, 0x75, 0x68, 0xb3, 0xa2, 0x79, 0x5f, 0xa4, 0xbb, 0xd4, 0x30, 0x55, 0xb4, - 0x45, 0xd6, 0xbf, 0x1b, 0x74, 0xab, 0x8f, 0xe0, 0x95, 0x04, 0x53, 0x66, 0xd8, 0x7c, 0xf5, 0x2f, - 0x14, 0xb2, 0x73, 0xb1, 0x3b, 0x48, 0x67, 0x8f, 0x7f, 0x2f, 0x8b, 0xd3, 0xb1, 0xbe, 0x69, 0x24, - 0xed, 0x8d, 0x81, 0x3e, 0x84, 0x86, 0x8d, 0x8f, 0xfb, 0xd1, 0x90, 0xaa, 0x40, 0x72, 0x50, 0xb7, - 0xf1, 0x31, 0xfd, 0xa5, 0x3e, 0x86, 0x0b, 0x29, 0x52, 0x67, 0x59, 0xfb, 0xdf, 0x2b, 0x70, 0x71, - 0xd3, 0x75, 0xc6, 0x4f, 0x4d, 0xd7, 0x9f, 0xe8, 0x56, 0xfc, 0x00, 0xfe, 
0x0c, 0xcb, 0x2f, 0x70, - 0xbf, 0xf1, 0xe3, 0x54, 0x1a, 0xfa, 0xa6, 0x44, 0xd9, 0xd2, 0x44, 0xf1, 0x45, 0x47, 0x42, 0xf1, - 0xff, 0x2c, 0xcb, 0x88, 0xe7, 0x70, 0x53, 0x42, 0x98, 0x22, 0x79, 0x8a, 0xb4, 0xb4, 0x5f, 0x3e, - 0x6b, 0x69, 0x3f, 0xc3, 0x13, 0x54, 0x9e, 0x93, 0x27, 0x38, 0x75, 0x0d, 0x6d, 0x03, 0xe2, 0xc7, - 0x2e, 0xd4, 0x91, 0x9f, 0xf6, 0xa8, 0xe6, 0x03, 0x80, 0xf0, 0xf4, 0x81, 0xdf, 0x19, 0x9d, 0x82, - 0x21, 0x32, 0x80, 0xec, 0x91, 0xf0, 0xb5, 0x3c, 0x14, 0x88, 0x54, 0xb3, 0xbf, 0x07, 0x5d, 0x99, - 0x6c, 0xce, 0x22, 0xef, 0xff, 0x5e, 0x02, 0xe8, 0x89, 0x1b, 0xc6, 0x67, 0x73, 0x16, 0xd7, 0x21, - 0x12, 0xae, 0x84, 0x5a, 0x1e, 0x95, 0x1d, 0x83, 0x28, 0x82, 0x48, 0x68, 0x09, 0x4c, 0x2a, 0xc9, - 0x35, 0x28, 0x9e, 0x88, 0xae, 0x30, 0x51, 0x48, 0xda, 0xe7, 0x4b, 0xd0, 0x70, 0x9d, 0xe3, 0x3e, - 0x51, 0x2e, 0x23, 0xb8, 0x42, 0xed, 0x3a, 0xc7, 0x44, 0xe5, 0x0c, 0x74, 0x01, 0xe6, 0x7c, 0xdd, - 0x3b, 0x24, 0xf8, 0x59, 0x5d, 0xaf, 0x46, 0x9a, 0x3d, 0x03, 0x2d, 0x43, 0x75, 0xdf, 0xb4, 0x30, - 0xbb, 0xad, 0xd1, 0xd0, 0x58, 0x03, 0x7d, 0x23, 0xb8, 0xf5, 0x57, 0x2f, 0x7c, 0xbb, 0x87, 0x5d, - 0xfc, 0xbb, 0x0e, 0x0b, 0x44, 0x92, 0x08, 0x11, 0x4c, 0xad, 0xdb, 0xbc, 0xa6, 0xcf, 0x3b, 0x09, - 0xa9, 0xea, 0x97, 0x0a, 0x2c, 0x86, 0xac, 0xa5, 0xb6, 0x89, 0x98, 0x3b, 0x6a, 0xea, 0x36, 0x1c, - 0x83, 0x59, 0x91, 0x56, 0x86, 0x5f, 0x61, 0x03, 0x99, 0x41, 0x0b, 0x87, 0xe4, 0x25, 0xe2, 0x64, - 0xf1, 0x84, 0x33, 0xa6, 0x11, 0x94, 0x86, 0x6a, 0xae, 0x73, 0xdc, 0x33, 0x04, 0xcb, 0xd8, 0x25, - 0x6a, 0x96, 0x76, 0x12, 0x96, 0x6d, 0xd0, 0x7b, 0xd4, 0xd7, 0x61, 0x01, 0xbb, 0xae, 0xe3, 0xf6, - 0x47, 0xd8, 0xf3, 0xf4, 0x21, 0xe6, 0x51, 0xfe, 0x3c, 0xed, 0xdc, 0x62, 0x7d, 0xea, 0x3f, 0x54, - 0xa0, 0x15, 0x2e, 0x25, 0xb8, 0x4b, 0x60, 0x1a, 0xc1, 0x5d, 0x02, 0x93, 0xec, 0x2f, 0xb8, 0xcc, - 0x4a, 0x0a, 0x09, 0x58, 0x2f, 0x75, 0x14, 0xad, 0xc1, 0x7b, 0x7b, 0x06, 0x71, 0xee, 0x84, 0x41, - 0xb6, 0x63, 0xe0, 0x50, 0x02, 0x20, 0xe8, 0xe2, 0x02, 0x10, 0x13, 0xa4, 0x4a, 0x01, 0x41, 0xaa, - 0x16, 0x10, 
0xa4, 0x9a, 0x44, 0x90, 0x56, 0xa0, 0xb6, 0x37, 0x19, 0x1c, 0x62, 0x9f, 0xc7, 0x7d, - 0xbc, 0x15, 0x17, 0xb0, 0x7a, 0x42, 0xc0, 0x84, 0x1c, 0x35, 0xa2, 0x72, 0x74, 0x09, 0x1a, 0x81, - 0xa7, 0xf6, 0xe8, 0x49, 0x5b, 0x59, 0xab, 0x73, 0x17, 0xed, 0xa1, 0x77, 0x83, 0xa0, 0xb0, 0x49, - 0x35, 0x4a, 0x95, 0x18, 0xa4, 0x84, 0x94, 0x04, 0x21, 0xe1, 0x2d, 0x58, 0x8c, 0xb0, 0x83, 0xca, - 0x19, 0x3b, 0x8e, 0x8b, 0xe4, 0x0c, 0xd4, 0x83, 0xdc, 0x84, 0x56, 0xc8, 0x12, 0x0a, 0xb7, 0xc0, - 0x52, 0x35, 0xd1, 0x4b, 0xc1, 0x84, 0xb8, 0xb7, 0x4e, 0x29, 0xee, 0x17, 0xa1, 0xce, 0x73, 0x2c, - 0xaf, 0xb3, 0x18, 0x2f, 0x87, 0x14, 0xd2, 0x84, 0x1f, 0x00, 0x0a, 0x97, 0x38, 0x5b, 0x60, 0x9a, - 0x90, 0xa1, 0x52, 0x52, 0x86, 0xd4, 0xbf, 0x54, 0x60, 0x29, 0x3a, 0xd9, 0x59, 0x1d, 0xf7, 0x87, - 0xd0, 0x64, 0x07, 0xa2, 0x7d, 0x62, 0x42, 0xe4, 0xe7, 0x92, 0x89, 0xcd, 0xd3, 0x20, 0x7c, 0xab, - 0x41, 0x18, 0x73, 0xec, 0xb8, 0x87, 0xa6, 0x3d, 0xec, 0x13, 0xca, 0x44, 0xb9, 0x96, 0x77, 0x3e, - 0x26, 0x7d, 0xea, 0x6f, 0x2a, 0x70, 0xe5, 0xc9, 0xd8, 0xd0, 0x7d, 0x1c, 0x89, 0x60, 0x66, 0xbd, - 0x32, 0x29, 0xee, 0x2c, 0x96, 0x72, 0xb6, 0x39, 0x32, 0x9f, 0xc7, 0xef, 0x2c, 0x92, 0xb8, 0x8f, - 0x53, 0x93, 0xba, 0x64, 0x7c, 0x76, 0x6a, 0xba, 0x50, 0x3f, 0xe2, 0xe8, 0x82, 0xd7, 0x27, 0x41, - 0x3b, 0x76, 0xf0, 0x5b, 0x3e, 0xd5, 0xc1, 0xaf, 0xba, 0x05, 0x17, 0x35, 0xec, 0x61, 0xdb, 0x88, - 0x2d, 0xe4, 0xcc, 0x45, 0xad, 0x31, 0x74, 0x65, 0xe8, 0x66, 0x91, 0x54, 0x16, 0xf8, 0xf6, 0x5d, - 0x82, 0xd6, 0xe7, 0xc6, 0x9a, 0xc4, 0x5b, 0x74, 0x1e, 0x5f, 0xfd, 0xab, 0x12, 0x5c, 0xb8, 0x6f, - 0x18, 0xdc, 0xce, 0xf3, 0x50, 0xee, 0x45, 0x45, 0xd9, 0xc9, 0x28, 0xb4, 0x9c, 0x8e, 0x42, 0x9f, - 0x97, 0xed, 0xe5, 0x5e, 0xc8, 0x9e, 0x8c, 0x02, 0x17, 0xec, 0xb2, 0x6b, 0x58, 0xef, 0xf3, 0xe3, - 0xd1, 0xbe, 0xe5, 0x0c, 0xa9, 0x1b, 0x9e, 0x1e, 0x9c, 0xd5, 0x83, 0xe2, 0x9c, 0x3a, 0x86, 0x4e, - 0x9a, 0x59, 0x33, 0xda, 0x91, 0x80, 0x23, 0x63, 0x87, 0x15, 0x79, 0xe7, 0x49, 0x24, 0x46, 0xbb, - 0xb6, 0x1d, 0x4f, 0xfd, 0x9f, 0x12, 0x74, 0x76, 
0xf4, 0x23, 0xfc, 0xff, 0x67, 0x83, 0x3e, 0x85, - 0x65, 0x4f, 0x3f, 0xc2, 0xfd, 0x48, 0x02, 0xde, 0x77, 0xf1, 0x67, 0x3c, 0x88, 0x7d, 0x5d, 0x56, - 0x86, 0x97, 0xde, 0x3a, 0xd2, 0x96, 0xbc, 0x58, 0xbf, 0x86, 0x3f, 0x43, 0xaf, 0xc1, 0x62, 0xf4, - 0x4e, 0x1c, 0x21, 0xad, 0x4e, 0x59, 0xbe, 0x10, 0xb9, 0xf7, 0xd6, 0x33, 0xd4, 0xcf, 0xe0, 0xd5, - 0x27, 0xb6, 0x87, 0xfd, 0x5e, 0x78, 0x77, 0x6b, 0xc6, 0xfc, 0xf3, 0x2a, 0x34, 0x43, 0xc6, 0xa7, - 0x9e, 0x9d, 0x18, 0x9e, 0xea, 0x40, 0x77, 0x4b, 0x77, 0x0f, 0x83, 0x72, 0xf6, 0x26, 0xbb, 0x21, - 0xf3, 0x02, 0x27, 0xdc, 0x17, 0x77, 0xc5, 0x34, 0xbc, 0x8f, 0x5d, 0x6c, 0x0f, 0xf0, 0x23, 0x67, - 0x70, 0x48, 0x02, 0x12, 0x9f, 0xbd, 0xfc, 0x53, 0x22, 0xb1, 0xeb, 0x66, 0xe4, 0x61, 0x5f, 0x29, - 0xf6, 0xb0, 0x6f, 0xca, 0x43, 0x51, 0xf5, 0xc7, 0x25, 0x58, 0xb9, 0x6f, 0xf9, 0xd8, 0x0d, 0x2b, - 0x0c, 0xa7, 0x29, 0x96, 0x84, 0xd5, 0x8b, 0xd2, 0x59, 0xaa, 0x17, 0x05, 0x4e, 0x2b, 0x65, 0xb5, - 0x96, 0xca, 0x19, 0x6b, 0x2d, 0xf7, 0x01, 0xc6, 0xae, 0x33, 0xc6, 0xae, 0x6f, 0xe2, 0x20, 0xf7, - 0x2b, 0x10, 0xe0, 0x44, 0x06, 0xa9, 0x9f, 0x42, 0xfb, 0xe1, 0x60, 0xc3, 0xb1, 0xf7, 0x4d, 0x77, - 0x14, 0x30, 0x2a, 0xa5, 0x74, 0x4a, 0x01, 0xa5, 0x2b, 0xa5, 0x94, 0x4e, 0x35, 0x61, 0x29, 0x82, - 0x7b, 0x46, 0xc3, 0x35, 0x1c, 0xf4, 0xf7, 0x4d, 0xdb, 0xa4, 0x37, 0xd0, 0x4a, 0x34, 0x40, 0x85, - 0xe1, 0xe0, 0x01, 0xef, 0x51, 0x7f, 0xa8, 0xc0, 0x25, 0x0d, 0x13, 0xe5, 0x09, 0x2e, 0xe9, 0xec, - 0xfa, 0x5b, 0xde, 0x70, 0x86, 0x80, 0xe2, 0x1e, 0x54, 0x46, 0xde, 0x30, 0xe3, 0x80, 0x9d, 0xb8, - 0xe8, 0xd8, 0x44, 0x1a, 0x05, 0xbe, 0xfd, 0xa1, 0xb8, 0x7a, 0xbc, 0x7b, 0x32, 0xc6, 0x68, 0x0e, - 0xca, 0x8f, 0xf1, 0x71, 0xfb, 0x1c, 0x02, 0xa8, 0x3d, 0x76, 0xdc, 0x91, 0x6e, 0xb5, 0x15, 0xd4, - 0x84, 0x39, 0x7e, 0x4c, 0xd9, 0x2e, 0xa1, 0x05, 0x68, 0x6c, 0x04, 0xc7, 0x39, 0xed, 0xf2, 0xed, - 0x3f, 0x51, 0x60, 0x29, 0x75, 0x90, 0x86, 0x5a, 0x00, 0x4f, 0xec, 0x01, 0x3f, 0x61, 0x6c, 0x9f, - 0x43, 0xf3, 0x50, 0x0f, 0xce, 0x1b, 0x19, 0xbe, 0x5d, 0x87, 0x42, 0xb7, 0x4b, 0xa8, 
0x0d, 0xf3, - 0x6c, 0xe0, 0x64, 0x30, 0xc0, 0x9e, 0xd7, 0x2e, 0x8b, 0x9e, 0x07, 0xba, 0x69, 0x4d, 0x5c, 0xdc, - 0xae, 0x90, 0x39, 0x77, 0x1d, 0x0d, 0x5b, 0x58, 0xf7, 0x70, 0xbb, 0x8a, 0x10, 0xb4, 0x78, 0x23, - 0x18, 0x54, 0x8b, 0xf4, 0x05, 0xc3, 0xe6, 0x6e, 0x7f, 0x12, 0x3d, 0xf2, 0xa0, 0xcb, 0xbb, 0x00, - 0xe7, 0x9f, 0xd8, 0x06, 0xde, 0x37, 0x6d, 0x6c, 0x84, 0x9f, 0xda, 0xe7, 0xd0, 0x79, 0x58, 0xdc, - 0xc2, 0xee, 0x10, 0x47, 0x3a, 0x4b, 0x68, 0x09, 0x16, 0xb6, 0xcc, 0x67, 0x91, 0xae, 0xb2, 0x5a, - 0xa9, 0x2b, 0x6d, 0x65, 0xed, 0xef, 0x6e, 0x42, 0x83, 0x30, 0x73, 0xc3, 0x71, 0x5c, 0x03, 0x59, - 0x80, 0xe8, 0x03, 0x9f, 0xd1, 0xd8, 0xb1, 0xc5, 0x63, 0x40, 0x74, 0x27, 0xc1, 0x7f, 0xd6, 0x48, - 0x03, 0xf2, 0xfd, 0xee, 0xde, 0x90, 0xc2, 0x27, 0x80, 0xd5, 0x73, 0x68, 0x44, 0x67, 0xdb, 0x35, - 0x47, 0x78, 0xd7, 0x1c, 0x1c, 0x06, 0x21, 0xda, 0xdb, 0x19, 0x2f, 0xaa, 0xd2, 0xa0, 0xc1, 0x7c, - 0xd7, 0xa5, 0xf3, 0xb1, 0x17, 0x58, 0x81, 0xe8, 0xab, 0xe7, 0xd0, 0x67, 0xb0, 0xfc, 0x10, 0x47, - 0xe2, 0xdd, 0x60, 0xc2, 0xb5, 0xec, 0x09, 0x53, 0xc0, 0xa7, 0x9c, 0xf2, 0x11, 0x54, 0xa9, 0xb8, - 0x21, 0xd9, 0x29, 0x70, 0xf4, 0x25, 0x7f, 0xf7, 0x5a, 0x36, 0x80, 0xc0, 0xf6, 0x03, 0x58, 0x4c, - 0xbc, 0xf1, 0x45, 0x32, 0x1f, 0x29, 0x7f, 0xad, 0xdd, 0xbd, 0x5d, 0x04, 0x54, 0xcc, 0x35, 0x84, - 0x56, 0xfc, 0x61, 0x10, 0x5a, 0x2d, 0xf0, 0xbc, 0x90, 0xcd, 0xf4, 0x7a, 0xe1, 0x87, 0x88, 0x54, - 0x08, 0xda, 0xc9, 0xd7, 0xa7, 0xe8, 0x76, 0x2e, 0x82, 0xb8, 0xb0, 0xbd, 0x51, 0x08, 0x56, 0x4c, - 0x77, 0x42, 0x85, 0x20, 0xf5, 0xf4, 0x2f, 0x29, 0xe3, 0x01, 0x9a, 0xac, 0x37, 0x89, 0xdd, 0xbb, - 0x85, 0xe1, 0xc5, 0xd4, 0xbf, 0xcc, 0xae, 0x7a, 0xc9, 0x9e, 0xcf, 0xa1, 0x77, 0xe4, 0xe8, 0x72, - 0xde, 0xfd, 0x75, 0xd7, 0x4e, 0x33, 0x44, 0x10, 0xf1, 0x8b, 0xf4, 0x8e, 0x96, 0xe4, 0x01, 0x5a, - 0x52, 0xef, 0x02, 0x7c, 0xd9, 0x6f, 0xeb, 0xba, 0xef, 0x9c, 0x62, 0x84, 0x20, 0xc0, 0x49, 0x3e, - 0xef, 0x0d, 0xd4, 0xf0, 0xee, 0x54, 0xa9, 0x39, 0x9b, 0x0e, 0x7e, 0x1f, 0x16, 0x13, 0x51, 0x23, - 0x2a, 0x1e, 0x59, 0x76, 
0xf3, 0x3c, 0x24, 0x53, 0xc9, 0xc4, 0x9d, 0x2c, 0x94, 0x21, 0xfd, 0x92, - 0x7b, 0x5b, 0xdd, 0xdb, 0x45, 0x40, 0xc5, 0x42, 0xc6, 0xb0, 0x94, 0xf8, 0xf8, 0x74, 0x0d, 0xbd, - 0x51, 0x78, 0xb6, 0xa7, 0x6b, 0xdd, 0x37, 0x8b, 0xcf, 0xf7, 0x74, 0x4d, 0x3d, 0x87, 0x3c, 0x6a, - 0xa0, 0x13, 0xf7, 0x7a, 0x50, 0x06, 0x16, 0xf9, 0xfd, 0xa5, 0xee, 0x5b, 0x05, 0xa1, 0xc5, 0x32, - 0x8f, 0xe0, 0xbc, 0xe4, 0xfa, 0x15, 0x7a, 0x2b, 0x57, 0x3c, 0x92, 0xf7, 0xce, 0xba, 0x77, 0x8a, - 0x82, 0x47, 0xdc, 0x43, 0x3b, 0xa0, 0xeb, 0xbe, 0x65, 0x31, 0xe7, 0xff, 0x66, 0x96, 0xe7, 0x8b, - 0x81, 0x65, 0x2c, 0x35, 0x13, 0x5a, 0x4c, 0xf9, 0xf3, 0x80, 0x76, 0x0e, 0x9c, 0x63, 0x1a, 0xa5, - 0x0d, 0x27, 0xae, 0xce, 0x02, 0xcb, 0x2c, 0x07, 0x98, 0x06, 0xcd, 0x50, 0xc4, 0xdc, 0x11, 0x62, - 0xf2, 0x3e, 0xc0, 0x43, 0xec, 0x6f, 0x61, 0xdf, 0x25, 0xda, 0xff, 0x5a, 0x16, 0xed, 0x1c, 0x20, - 0x98, 0xea, 0xd6, 0x54, 0xb8, 0x28, 0x43, 0xb7, 0x74, 0x7b, 0xa2, 0x5b, 0x91, 0xf7, 0x37, 0x72, - 0x86, 0x26, 0xc1, 0xf2, 0x19, 0x9a, 0x86, 0x16, 0x53, 0x1e, 0x8b, 0xf8, 0x25, 0x72, 0x4a, 0x9c, - 0x1f, 0xbf, 0xa4, 0x6f, 0x28, 0x25, 0x6d, 0x7b, 0x0e, 0xbc, 0x98, 0xf8, 0x0b, 0x85, 0x5e, 0x1a, - 0x4c, 0x00, 0x7c, 0x62, 0xfa, 0x07, 0xdb, 0x96, 0x6e, 0x7b, 0x45, 0x48, 0xa0, 0x80, 0xa7, 0x20, - 0x81, 0xc3, 0x0b, 0x12, 0x0c, 0x58, 0x88, 0x9d, 0xc8, 0x22, 0xd9, 0x73, 0x13, 0xd9, 0x41, 0x76, - 0x77, 0x75, 0x3a, 0xa0, 0x98, 0xe5, 0x00, 0x16, 0x02, 0x81, 0x66, 0xcc, 0x7d, 0x3d, 0x57, 0xe8, - 0x63, 0x7c, 0xbd, 0x5d, 0x04, 0x54, 0xcc, 0xe4, 0x01, 0x4a, 0x1f, 0x3d, 0xa1, 0x62, 0x07, 0x95, - 0x79, 0xc6, 0x27, 0xfb, 0x3c, 0x8b, 0xd9, 0xf3, 0xc4, 0xe1, 0xae, 0xdc, 0x59, 0x48, 0xcf, 0xaa, - 0xa5, 0xf6, 0x3c, 0xe3, 0xac, 0x58, 0x3d, 0x87, 0x3e, 0x81, 0x1a, 0xff, 0x1f, 0x9e, 0x1b, 0xf9, - 0x45, 0x5e, 0x8e, 0xfd, 0xe6, 0x14, 0x28, 0x81, 0xf8, 0x10, 0x2e, 0x64, 0x94, 0x78, 0xa5, 0x71, - 0x46, 0x7e, 0x39, 0x78, 0x9a, 0x07, 0x14, 0x93, 0xa5, 0x2a, 0xb8, 0x39, 0x93, 0x65, 0x55, 0x7b, - 0xa7, 0x4d, 0xd6, 0x87, 0xa5, 0x54, 0x85, 0x4c, 0xea, 0x02, 
0xb3, 0xea, 0x68, 0xd3, 0x26, 0x18, - 0xc2, 0x2b, 0xd2, 0x6a, 0x90, 0x34, 0x3a, 0xc9, 0xab, 0x1b, 0x4d, 0x9b, 0x68, 0x00, 0xe7, 0x25, - 0x35, 0x20, 0xa9, 0x97, 0xcb, 0xae, 0x15, 0x4d, 0x9b, 0x64, 0x1f, 0xba, 0xeb, 0xae, 0xa3, 0x1b, - 0x03, 0xdd, 0xf3, 0x69, 0x5d, 0x86, 0xa4, 0x8a, 0x41, 0x78, 0x28, 0xcf, 0x1d, 0xa4, 0xd5, 0x9b, - 0x69, 0xf3, 0xec, 0x41, 0x93, 0x6e, 0x25, 0xfb, 0xaf, 0x14, 0x24, 0xf7, 0x11, 0x11, 0x88, 0x0c, - 0xc3, 0x23, 0x03, 0x14, 0x42, 0xbd, 0x0b, 0xcd, 0x0d, 0x7a, 0xc0, 0xd5, 0xb3, 0x0d, 0xfc, 0x2c, - 0xe9, 0xaf, 0xe8, 0x83, 0xf1, 0x3b, 0x11, 0x80, 0xc2, 0x1c, 0x5a, 0xa0, 0x51, 0xbb, 0x81, 0x9f, - 0xb1, 0x7d, 0x5e, 0x95, 0xe1, 0x8d, 0x81, 0x64, 0x64, 0x39, 0x52, 0xc8, 0x88, 0xa7, 0x5f, 0x8e, - 0xc6, 0xb2, 0x62, 0xba, 0xbb, 0x19, 0x48, 0x52, 0x90, 0xc1, 0xac, 0x6f, 0x17, 0x1f, 0x10, 0xf5, - 0x0c, 0x01, 0x5d, 0x3d, 0x7a, 0xba, 0x76, 0x2b, 0x8f, 0xf4, 0x68, 0x80, 0xba, 0x3a, 0x1d, 0x50, - 0xcc, 0xb2, 0x0d, 0x0d, 0x22, 0x9d, 0x6c, 0x7b, 0x6e, 0xc8, 0x06, 0x8a, 0xcf, 0xc5, 0x37, 0x67, - 0x13, 0x7b, 0x03, 0xd7, 0xdc, 0xe3, 0x9b, 0x2e, 0x25, 0x27, 0x06, 0x92, 0xbb, 0x39, 0x09, 0x48, - 0x41, 0xf9, 0x84, 0x46, 0x0d, 0x82, 0x75, 0xdc, 0x54, 0xbe, 0x35, 0x6d, 0x7f, 0xe3, 0x66, 0xf2, - 0x4e, 0x51, 0x70, 0x31, 0xed, 0x2f, 0xd0, 0x4c, 0x88, 0x7e, 0x5f, 0x9f, 0x98, 0x96, 0xb1, 0xcd, - 0xef, 0x4e, 0xa3, 0xb7, 0xf3, 0x50, 0xc5, 0x40, 0x33, 0x03, 0xc0, 0x9c, 0x11, 0x62, 0xfe, 0x9f, - 0x85, 0x86, 0xa8, 0x10, 0x22, 0xd9, 0xe5, 0xbf, 0x64, 0x6d, 0xb2, 0x7b, 0x23, 0x1f, 0x48, 0x60, - 0xc6, 0xb0, 0x2c, 0xab, 0x07, 0x4a, 0x93, 0xec, 0x9c, 0xc2, 0xe1, 0x14, 0xf9, 0x58, 0xfb, 0xb2, - 0x01, 0xf5, 0x60, 0xe0, 0x57, 0x5c, 0xba, 0x7a, 0x09, 0xb5, 0xa4, 0xef, 0xc3, 0x62, 0xe2, 0x2f, - 0x30, 0xa4, 0x16, 0x5c, 0xfe, 0x37, 0x19, 0xd3, 0x54, 0xed, 0x13, 0xfe, 0x0f, 0x8d, 0x22, 0xc9, - 0xbb, 0x95, 0x55, 0x8f, 0x4a, 0xe6, 0x77, 0x53, 0x10, 0xff, 0xdf, 0x4e, 0x71, 0x1e, 0x03, 0x44, - 0x92, 0x9b, 0xfc, 0xdb, 0xdb, 0x24, 0x5e, 0x9f, 0xc6, 0xad, 0x91, 0x34, 0x7f, 0x79, 0xbd, 0xc8, - 
0x4d, 0xd8, 0xec, 0x08, 0x34, 0x3b, 0x6b, 0x79, 0x02, 0xf3, 0xd1, 0x77, 0x15, 0x48, 0xfa, 0x7f, - 0x80, 0xe9, 0x87, 0x17, 0xd3, 0x56, 0xb1, 0x75, 0xca, 0xc0, 0x76, 0x0a, 0x3a, 0x0f, 0x50, 0xfa, - 0xa4, 0x5c, 0x9a, 0x08, 0x64, 0x9e, 0xcf, 0x4b, 0x13, 0x81, 0xec, 0xe3, 0x77, 0x56, 0x96, 0x4c, - 0x1e, 0xff, 0x4a, 0xcb, 0x92, 0x19, 0x07, 0xea, 0xd2, 0xb2, 0x64, 0xd6, 0x79, 0xb2, 0x7a, 0x6e, - 0xfd, 0xde, 0xa7, 0xef, 0x0c, 0x4d, 0xff, 0x60, 0xb2, 0x47, 0x56, 0x7f, 0x97, 0x0d, 0x7d, 0xcb, - 0x74, 0xf8, 0xaf, 0xbb, 0x81, 0xb8, 0xdf, 0xa5, 0xd8, 0xee, 0x12, 0x6c, 0xe3, 0xbd, 0xbd, 0x1a, - 0x6d, 0xdd, 0xfb, 0xdf, 0x00, 0x00, 0x00, 0xff, 0xff, 0xc2, 0xc7, 0xa4, 0xb3, 0x9d, 0x56, 0x00, - 0x00, + // 5135 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe4, 0x3c, 0x5b, 0x6f, 0x1b, 0x57, + 0x7a, 0x1e, 0xde, 0x44, 0x7e, 0xa4, 0x28, 0xea, 0x58, 0x91, 0x69, 0x3a, 0xbe, 0x64, 0x6c, 0xc7, + 0x8a, 0x93, 0xd8, 0x8e, 0xdc, 0x60, 0x73, 0xd9, 0x64, 0xd7, 0x92, 0x62, 0x87, 0xad, 0xe4, 0x68, + 0x47, 0xb2, 0x53, 0x64, 0x0b, 0x10, 0x23, 0xce, 0x11, 0x3d, 0xd1, 0x70, 0x86, 0x99, 0x19, 0x4a, + 0x56, 0x0a, 0x74, 0xb3, 0x6d, 0x5a, 0xa0, 0x17, 0xf4, 0xde, 0x87, 0xbe, 0x14, 0x45, 0x1f, 0x8a, + 0xed, 0x65, 0x9f, 0xb6, 0xc5, 0x02, 0x45, 0x81, 0x05, 0xda, 0x97, 0x2d, 0xfa, 0x50, 0x14, 0x05, + 0x8a, 0xf6, 0xa1, 0xbf, 0xa2, 0x7f, 0xa0, 0x38, 0x97, 0x39, 0x73, 0x3b, 0x43, 0x8e, 0x44, 0x3b, + 0x01, 0xba, 0x4f, 0xd2, 0x9c, 0xf9, 0xce, 0x39, 0xdf, 0x7c, 0xe7, 0xbb, 0x7f, 0xdf, 0x21, 0xb4, + 0x0c, 0xdd, 0xd7, 0x7b, 0x7d, 0xc7, 0x71, 0x8d, 0x5b, 0x23, 0xd7, 0xf1, 0x1d, 0xb4, 0x38, 0x34, + 0xad, 0xc3, 0xb1, 0xc7, 0x9e, 0x6e, 0x91, 0xd7, 0x9d, 0x46, 0xdf, 0x19, 0x0e, 0x1d, 0x9b, 0x0d, + 0x75, 0x9a, 0xa6, 0xed, 0x63, 0xd7, 0xd6, 0x2d, 0xfe, 0xdc, 0x88, 0x4e, 0xe8, 0x34, 0xbc, 0xfe, + 0x13, 0x3c, 0xd4, 0xf9, 0x53, 0x6d, 0xe8, 0x0d, 0xf8, 0xbf, 0x8b, 0xa6, 0x6d, 0xe0, 0xa7, 0xd1, + 0xad, 0xd4, 0x39, 0x28, 0x7f, 0x30, 0x1c, 0xf9, 0xc7, 0xea, 0xdf, 0x29, 0xd0, 
0xb8, 0x6f, 0x8d, + 0xbd, 0x27, 0x1a, 0xfe, 0x6c, 0x8c, 0x3d, 0x1f, 0xdd, 0x81, 0xd2, 0x9e, 0xee, 0xe1, 0xb6, 0x72, + 0x45, 0x59, 0xa9, 0xaf, 0xbe, 0x78, 0x2b, 0x86, 0x13, 0xc7, 0x66, 0xcb, 0x1b, 0xac, 0xe9, 0x1e, + 0xd6, 0x28, 0x24, 0x42, 0x50, 0x32, 0xf6, 0xba, 0x1b, 0xed, 0xc2, 0x15, 0x65, 0xa5, 0xa8, 0xd1, + 0xff, 0xd1, 0x25, 0x00, 0x0f, 0x0f, 0x86, 0xd8, 0xf6, 0xbb, 0x1b, 0x5e, 0xbb, 0x78, 0xa5, 0xb8, + 0x52, 0xd4, 0x22, 0x23, 0x48, 0x85, 0x46, 0xdf, 0xb1, 0x2c, 0xdc, 0xf7, 0x4d, 0xc7, 0xee, 0x6e, + 0xb4, 0x4b, 0x74, 0x6e, 0x6c, 0x0c, 0x75, 0xa0, 0x6a, 0x7a, 0xdd, 0xe1, 0xc8, 0x71, 0xfd, 0x76, + 0xf9, 0x8a, 0xb2, 0x52, 0xd5, 0xc4, 0xb3, 0xfa, 0xfd, 0x02, 0xcc, 0x73, 0xb4, 0xbd, 0x91, 0x63, + 0x7b, 0x18, 0xdd, 0x85, 0x8a, 0xe7, 0xeb, 0xfe, 0xd8, 0xe3, 0x98, 0x5f, 0x90, 0x62, 0xbe, 0x43, + 0x41, 0x34, 0x0e, 0x2a, 0x45, 0x3d, 0x89, 0x5a, 0x51, 0x82, 0x5a, 0xfc, 0xf3, 0x4a, 0xa9, 0xcf, + 0x5b, 0x81, 0x85, 0x7d, 0x82, 0xdd, 0x4e, 0x08, 0x54, 0xa6, 0x40, 0xc9, 0x61, 0xb2, 0x92, 0x6f, + 0x0e, 0xf1, 0x47, 0xfb, 0x3b, 0x58, 0xb7, 0xda, 0x15, 0xba, 0x57, 0x64, 0x04, 0x9d, 0x87, 0x2a, + 0x9d, 0xd2, 0xf3, 0xbd, 0xf6, 0xdc, 0x15, 0x65, 0xa5, 0xa4, 0xcd, 0xd1, 0xe7, 0x5d, 0x4f, 0xfd, + 0x1e, 0x2c, 0x51, 0x12, 0xac, 0x3f, 0xd1, 0x6d, 0x1b, 0x5b, 0xde, 0xe9, 0x4f, 0x30, 0xba, 0x49, + 0x21, 0xb6, 0x09, 0x39, 0x84, 0x3e, 0x5f, 0x9f, 0x1e, 0x63, 0x4d, 0x13, 0xcf, 0xea, 0xbf, 0x2b, + 0xd0, 0x12, 0x9f, 0x12, 0xec, 0xbe, 0x04, 0xe5, 0xbe, 0x33, 0xb6, 0x7d, 0xba, 0xfd, 0xbc, 0xc6, + 0x1e, 0xd0, 0x4b, 0xd0, 0xe0, 0xd3, 0x7a, 0xb6, 0x3e, 0xc4, 0x74, 0x97, 0x9a, 0x56, 0xe7, 0x63, + 0x0f, 0xf5, 0x21, 0xce, 0x45, 0xf7, 0x2b, 0x50, 0x1f, 0xe9, 0xae, 0x6f, 0xc6, 0xb8, 0x26, 0x3a, + 0x34, 0x89, 0x69, 0xc8, 0x0e, 0x26, 0xfd, 0x6f, 0x57, 0xf7, 0x0e, 0xba, 0x1b, 0x9c, 0xda, 0xb1, + 0x31, 0xf5, 0xcf, 0x15, 0x58, 0xbe, 0xe7, 0x79, 0xe6, 0xc0, 0x4e, 0x7d, 0xd9, 0x32, 0x54, 0x6c, + 0xc7, 0xc0, 0xdd, 0x0d, 0xfa, 0x69, 0x45, 0x8d, 0x3f, 0xa1, 0x0b, 0x50, 0x1b, 0x61, 0xec, 0xf6, + 0x5c, 0xc7, 0x0a, 
0x3e, 0xac, 0x4a, 0x06, 0x34, 0xc7, 0xc2, 0xe8, 0x3b, 0xb0, 0xe8, 0x25, 0x16, + 0x62, 0x84, 0xac, 0xaf, 0x5e, 0xbd, 0x95, 0x92, 0xf7, 0x5b, 0xc9, 0x4d, 0xb5, 0xf4, 0x6c, 0xf5, + 0x8b, 0x02, 0x9c, 0x15, 0x70, 0x0c, 0x57, 0xf2, 0x3f, 0xa1, 0xbc, 0x87, 0x07, 0x02, 0x3d, 0xf6, + 0x90, 0x87, 0xf2, 0xe2, 0xc8, 0x8a, 0xd1, 0x23, 0xcb, 0x23, 0xa2, 0x89, 0xf3, 0x28, 0xa7, 0xcf, + 0xe3, 0x32, 0xd4, 0xf1, 0xd3, 0x91, 0xe9, 0xe2, 0x1e, 0x61, 0x6a, 0x4a, 0xf2, 0x92, 0x06, 0x6c, + 0x68, 0xd7, 0x1c, 0x46, 0xe5, 0x76, 0x2e, 0xb7, 0xdc, 0xaa, 0x7f, 0xa1, 0xc0, 0xb9, 0xd4, 0x29, + 0x71, 0x45, 0xa0, 0x41, 0x8b, 0x7e, 0x79, 0x48, 0x19, 0xa2, 0x12, 0x08, 0xc1, 0x5f, 0x9e, 0x44, + 0xf0, 0x10, 0x5c, 0x4b, 0xcd, 0x8f, 0x20, 0x59, 0xc8, 0x8f, 0xe4, 0x01, 0x9c, 0x7b, 0x80, 0x7d, + 0xbe, 0x01, 0x79, 0x87, 0x67, 0x10, 0xd1, 0xb8, 0xc6, 0x29, 0x24, 0x35, 0x8e, 0xfa, 0x97, 0x05, + 0x21, 0x8b, 0x74, 0xab, 0xae, 0xbd, 0xef, 0xa0, 0x17, 0xa1, 0x26, 0x40, 0x38, 0x57, 0x84, 0x03, + 0xe8, 0x1b, 0x50, 0x26, 0x98, 0x32, 0x96, 0x68, 0xae, 0xbe, 0x24, 0xff, 0xa6, 0xc8, 0x9a, 0x1a, + 0x83, 0x47, 0x1b, 0xd0, 0xf4, 0x7c, 0xdd, 0xf5, 0x7b, 0x23, 0xc7, 0xa3, 0xe7, 0x4c, 0x19, 0xa7, + 0xbe, 0x7a, 0x31, 0xbe, 0x02, 0x31, 0x40, 0x5b, 0xde, 0x60, 0x9b, 0x03, 0x69, 0xf3, 0x74, 0x52, + 0xf0, 0x88, 0xbe, 0x0d, 0x0d, 0x6c, 0x1b, 0xe1, 0x1a, 0xa5, 0x3c, 0x6b, 0xd4, 0xb1, 0x6d, 0x88, + 0x15, 0xc2, 0x53, 0x29, 0xe7, 0x3f, 0x95, 0xdf, 0x51, 0xa0, 0x9d, 0x3e, 0x96, 0x59, 0x8c, 0xc8, + 0xbb, 0x6c, 0x12, 0x66, 0xc7, 0x32, 0x51, 0xae, 0xc5, 0xd1, 0x68, 0x7c, 0x8a, 0xfa, 0x27, 0x0a, + 0xbc, 0x10, 0xa2, 0x43, 0x5f, 0x3d, 0x2f, 0x1e, 0x41, 0x37, 0xa1, 0x65, 0xda, 0x7d, 0x6b, 0x6c, + 0xe0, 0x47, 0xf6, 0x87, 0x58, 0xb7, 0xfc, 0x27, 0xc7, 0xf4, 0xe4, 0xaa, 0x5a, 0x6a, 0x5c, 0xfd, + 0xef, 0x02, 0x2c, 0x27, 0xf1, 0x9a, 0x85, 0x48, 0x3f, 0x07, 0x65, 0xd3, 0xde, 0x77, 0x02, 0x1a, + 0x5d, 0x9a, 0x20, 0x8a, 0x64, 0x2f, 0x06, 0x8c, 0x1c, 0x40, 0x81, 0xf2, 0xea, 0x3f, 0xc1, 0xfd, + 0x83, 0x91, 0x63, 0x52, 0x35, 0x45, 0x96, 0xf8, 0xb6, 
0x64, 0x09, 0x39, 0xc6, 0xb7, 0xb8, 0x85, + 0x5c, 0x17, 0x4b, 0x7c, 0x60, 0xfb, 0xee, 0xb1, 0xb6, 0xd8, 0x4f, 0x8e, 0x77, 0xfa, 0xb0, 0x2c, + 0x07, 0x46, 0x2d, 0x28, 0x1e, 0xe0, 0x63, 0xfa, 0xc9, 0x35, 0x8d, 0xfc, 0x8b, 0xee, 0x42, 0xf9, + 0x50, 0xb7, 0xc6, 0x98, 0xeb, 0x84, 0x29, 0x9c, 0xcb, 0x60, 0xdf, 0x29, 0xbc, 0xa5, 0xa8, 0x43, + 0xb8, 0xf0, 0x00, 0xfb, 0x5d, 0xdb, 0xc3, 0xae, 0xbf, 0x66, 0xda, 0x96, 0x33, 0xd8, 0xd6, 0xfd, + 0x27, 0x33, 0x28, 0x87, 0x98, 0x9c, 0x17, 0x12, 0x72, 0xae, 0xfe, 0x40, 0x81, 0x17, 0xe5, 0xfb, + 0xf1, 0x03, 0xed, 0x40, 0x75, 0xdf, 0xc4, 0x96, 0x41, 0xb8, 0x46, 0xa1, 0x5c, 0x23, 0x9e, 0x89, + 0x92, 0x18, 0x11, 0x60, 0x7e, 0x6e, 0x09, 0x25, 0x21, 0xfc, 0xd1, 0x1d, 0xdf, 0x35, 0xed, 0xc1, + 0xa6, 0xe9, 0xf9, 0x1a, 0x83, 0x8f, 0x70, 0x49, 0x31, 0xbf, 0x70, 0xfe, 0x96, 0x02, 0x97, 0x1e, + 0x60, 0x7f, 0x5d, 0xd8, 0x18, 0xf2, 0xde, 0xf4, 0x7c, 0xb3, 0xef, 0x3d, 0x5b, 0xff, 0x34, 0x87, + 0xb3, 0xa1, 0xfe, 0x9e, 0x02, 0x97, 0x33, 0x91, 0xe1, 0xa4, 0xe3, 0x3a, 0x34, 0xb0, 0x30, 0x72, + 0x1d, 0xfa, 0x0b, 0xf8, 0xf8, 0x31, 0x39, 0xfc, 0x6d, 0xdd, 0x74, 0x99, 0x0e, 0x3d, 0xa5, 0x45, + 0xf9, 0xa1, 0x02, 0x17, 0x1f, 0x60, 0x7f, 0x3b, 0xb0, 0xaf, 0x5f, 0x23, 0x75, 0x08, 0x4c, 0xc4, + 0xce, 0x07, 0x4e, 0x70, 0x6c, 0x4c, 0xfd, 0x5d, 0x76, 0x9c, 0x52, 0x7c, 0xbf, 0x16, 0x02, 0x5e, + 0xa2, 0x92, 0x10, 0x51, 0x11, 0x5c, 0xd8, 0x39, 0xf9, 0xd4, 0x2f, 0xcb, 0xd0, 0x78, 0xcc, 0xb5, + 0x02, 0xb5, 0xa0, 0x49, 0x4a, 0x28, 0x72, 0x27, 0x28, 0xe2, 0x4d, 0xc9, 0x1c, 0xac, 0x35, 0x98, + 0xf7, 0x30, 0x3e, 0x38, 0xa1, 0xbd, 0x6c, 0x90, 0x39, 0xc2, 0xd8, 0x6d, 0xc2, 0xe2, 0xd8, 0xa6, + 0x5e, 0x39, 0x36, 0xf8, 0x07, 0x30, 0xa2, 0x4f, 0x57, 0xa6, 0xe9, 0x89, 0xe8, 0x43, 0x1e, 0xa0, + 0x44, 0xd6, 0x2a, 0xe7, 0x5a, 0x2b, 0x39, 0x0d, 0x75, 0xa1, 0x65, 0xb8, 0xce, 0x68, 0x84, 0x8d, + 0x9e, 0x17, 0x2c, 0x55, 0xc9, 0xb7, 0x14, 0x9f, 0x27, 0x96, 0xba, 0x03, 0x67, 0x93, 0x98, 0x76, + 0x0d, 0xe2, 0x17, 0x12, 0xce, 0x92, 0xbd, 0x42, 0xaf, 0xc1, 0x62, 0x1a, 0xbe, 0x4a, 0xe1, 
0xd3, + 0x2f, 0xd0, 0xeb, 0x80, 0x12, 0xa8, 0x12, 0xf0, 0x1a, 0x03, 0x8f, 0x23, 0xc3, 0xc1, 0x69, 0xe0, + 0x1c, 0x07, 0x07, 0x06, 0xce, 0xdf, 0x44, 0xc0, 0xbb, 0xc4, 0xba, 0xc6, 0xc0, 0xbd, 0x76, 0x3d, + 0x1f, 0x21, 0xe2, 0x8b, 0x79, 0xea, 0x6f, 0x2a, 0xb0, 0xfc, 0xb1, 0xee, 0xf7, 0x9f, 0x6c, 0x0c, + 0x67, 0x0f, 0xee, 0xde, 0x83, 0xda, 0xa1, 0x08, 0xe1, 0x98, 0x16, 0xbf, 0x2c, 0x41, 0x28, 0xca, + 0xf6, 0x5a, 0x38, 0x43, 0xfd, 0x27, 0x85, 0x87, 0x99, 0x01, 0x76, 0x5f, 0xbd, 0xaa, 0x99, 0x16, + 0x6d, 0x27, 0x04, 0xb0, 0x9c, 0x12, 0x40, 0xf5, 0x29, 0x00, 0x47, 0x7f, 0xcb, 0x1b, 0x9c, 0x02, + 0xf3, 0xb7, 0x60, 0x8e, 0xef, 0xc7, 0xb5, 0xcd, 0xb4, 0x23, 0x0d, 0xc0, 0xd5, 0x3f, 0x9b, 0x83, + 0x7a, 0xe4, 0x05, 0x6a, 0x42, 0x41, 0xa8, 0x91, 0x82, 0xe4, 0xfb, 0x0b, 0xd3, 0xa3, 0xac, 0x62, + 0x3a, 0xca, 0xba, 0x0e, 0x4d, 0x93, 0x9a, 0xf7, 0x1e, 0xff, 0x6a, 0xea, 0x4d, 0xd7, 0xb4, 0x79, + 0x36, 0xca, 0x99, 0x08, 0x5d, 0x82, 0xba, 0x3d, 0x1e, 0xf6, 0x9c, 0xfd, 0x9e, 0xeb, 0x1c, 0x79, + 0x3c, 0x5c, 0xab, 0xd9, 0xe3, 0xe1, 0x47, 0xfb, 0x9a, 0x73, 0xe4, 0x85, 0x11, 0x41, 0xe5, 0x84, + 0x11, 0xc1, 0x25, 0xa8, 0x0f, 0xf5, 0xa7, 0x64, 0xd5, 0x9e, 0x3d, 0x1e, 0xd2, 0x48, 0xae, 0xa8, + 0xd5, 0x86, 0xfa, 0x53, 0xcd, 0x39, 0x7a, 0x38, 0x1e, 0xa2, 0x15, 0x68, 0x59, 0xba, 0xe7, 0xf7, + 0xa2, 0xa1, 0x60, 0x95, 0x86, 0x82, 0x4d, 0x32, 0xfe, 0x41, 0x18, 0x0e, 0xa6, 0x63, 0x8b, 0xda, + 0xe9, 0x62, 0x0b, 0x63, 0x68, 0x85, 0x6b, 0x40, 0xae, 0xd8, 0xc2, 0x18, 0x5a, 0x62, 0x85, 0xb7, + 0x60, 0x6e, 0x8f, 0xba, 0x4a, 0x93, 0x84, 0xf8, 0x3e, 0xf1, 0x92, 0x98, 0x47, 0xa5, 0x05, 0xe0, + 0xe8, 0x9b, 0x50, 0xa3, 0x16, 0x8a, 0xce, 0x6d, 0xe4, 0x9a, 0x1b, 0x4e, 0x20, 0xb3, 0x0d, 0x6c, + 0xf9, 0x3a, 0x9d, 0x3d, 0x9f, 0x6f, 0xb6, 0x98, 0x40, 0x34, 0x68, 0xdf, 0xc5, 0xba, 0x8f, 0x8d, + 0xb5, 0xe3, 0x75, 0x67, 0x38, 0xd2, 0x29, 0x0b, 0xb5, 0x9b, 0xd4, 0xc9, 0x97, 0xbd, 0x42, 0x2f, + 0x43, 0xb3, 0x2f, 0x9e, 0xee, 0xbb, 0xce, 0xb0, 0xbd, 0x40, 0xe5, 0x2b, 0x31, 0x8a, 0x2e, 0x02, + 0x04, 0xba, 0x53, 0xf7, 0xdb, 
0x2d, 0x7a, 0x76, 0x35, 0x3e, 0x72, 0x8f, 0xe6, 0x77, 0x4c, 0xaf, + 0xc7, 0x32, 0x29, 0xa6, 0x3d, 0x68, 0x2f, 0xd2, 0x1d, 0xeb, 0x41, 0xea, 0xc5, 0xb4, 0x07, 0xe8, + 0x1c, 0xcc, 0x99, 0x5e, 0x6f, 0x5f, 0x3f, 0xc0, 0x6d, 0x44, 0xdf, 0x56, 0x4c, 0xef, 0xbe, 0x7e, + 0x40, 0xbd, 0x57, 0xbe, 0x19, 0x36, 0xda, 0x67, 0xe9, 0xab, 0x70, 0x00, 0xbd, 0x09, 0x65, 0x0b, + 0x1f, 0x62, 0xab, 0xbd, 0x44, 0x79, 0xf2, 0x72, 0xb6, 0xe0, 0x6d, 0x12, 0x30, 0x8d, 0x41, 0xa3, + 0x1b, 0xb0, 0xe0, 0xf9, 0x8e, 0xab, 0x0f, 0x70, 0xef, 0x10, 0xbb, 0x1e, 0xa1, 0xc2, 0x0b, 0x94, + 0x2b, 0x9b, 0x7c, 0xf8, 0x31, 0x1b, 0x55, 0x3f, 0x87, 0xa5, 0x90, 0xa3, 0x23, 0x2c, 0x94, 0x66, + 0x44, 0xe5, 0x14, 0x8c, 0x38, 0xd9, 0x33, 0xff, 0x83, 0x32, 0x2c, 0xef, 0xe8, 0x87, 0xf8, 0xf9, + 0x07, 0x01, 0xb9, 0xf4, 0xec, 0x26, 0x2c, 0x52, 0xbf, 0x7f, 0x35, 0x82, 0xcf, 0x04, 0x17, 0x23, + 0xca, 0x83, 0xe9, 0x89, 0xe8, 0x5b, 0x44, 0x2b, 0xe3, 0xfe, 0xc1, 0x36, 0x89, 0xa1, 0x02, 0xf7, + 0xe2, 0xa2, 0x64, 0x9d, 0x75, 0x01, 0xa5, 0x45, 0x67, 0xa0, 0x6d, 0x72, 0x84, 0xd1, 0x13, 0x08, + 0x1c, 0x8b, 0x1b, 0x13, 0x03, 0xec, 0x90, 0xfa, 0x5a, 0x33, 0x76, 0x18, 0x1e, 0x6a, 0xc3, 0x1c, + 0xf7, 0x0a, 0xa8, 0x8a, 0xaa, 0x6a, 0xc1, 0x23, 0xda, 0x86, 0xb3, 0xec, 0x0b, 0x76, 0xb8, 0x24, + 0xb2, 0x8f, 0xaf, 0xe6, 0xfa, 0x78, 0xd9, 0xd4, 0xb8, 0x20, 0xd7, 0x4e, 0x2a, 0xc8, 0x6d, 0x98, + 0xe3, 0xc2, 0x45, 0x75, 0x57, 0x55, 0x0b, 0x1e, 0xc9, 0x31, 0x87, 0x62, 0x56, 0x67, 0xd2, 0x22, + 0x06, 0xc8, 0xbc, 0xc0, 0x02, 0x34, 0xa8, 0x05, 0x08, 0x1e, 0x89, 0xa0, 0xc7, 0x39, 0xbf, 0x3d, + 0x2f, 0x95, 0x87, 0x5f, 0x57, 0x00, 0xc2, 0x13, 0x99, 0x92, 0x42, 0x7a, 0x1b, 0xaa, 0x42, 0x3c, + 0x72, 0x45, 0xc1, 0x02, 0x3c, 0x69, 0x8b, 0x8a, 0x09, 0x5b, 0xa4, 0xfe, 0xab, 0x02, 0x8d, 0x0d, + 0x42, 0x8f, 0x4d, 0x67, 0x40, 0x2d, 0xe7, 0x75, 0x68, 0xba, 0xb8, 0xef, 0xb8, 0x46, 0x0f, 0xdb, + 0xbe, 0x6b, 0x62, 0x96, 0x7e, 0x28, 0x69, 0xf3, 0x6c, 0xf4, 0x03, 0x36, 0x48, 0xc0, 0x88, 0x79, + 0xf1, 0x7c, 0x7d, 0x38, 0xea, 0xed, 0x13, 0x85, 0xc6, 0x32, 0xda, 
0xf3, 0x62, 0x94, 0xea, 0xb3, + 0x97, 0xa0, 0x11, 0x82, 0xf9, 0x0e, 0xdd, 0xbf, 0xa4, 0xd5, 0xc5, 0xd8, 0xae, 0x83, 0xae, 0x41, + 0x93, 0x1e, 0x48, 0xcf, 0x72, 0x06, 0x3d, 0x12, 0xd4, 0x72, 0xa3, 0xda, 0x30, 0x38, 0x5a, 0xe4, + 0xa0, 0xe3, 0x50, 0x9e, 0xf9, 0x39, 0xe6, 0x66, 0x55, 0x40, 0xed, 0x98, 0x9f, 0x63, 0xf5, 0xd7, + 0x14, 0x98, 0xe7, 0x56, 0x78, 0x47, 0x94, 0x1e, 0x68, 0x3e, 0x96, 0x25, 0x14, 0xe8, 0xff, 0xe8, + 0x9d, 0x78, 0x46, 0xee, 0x9a, 0x54, 0x58, 0xe8, 0x22, 0xd4, 0x3b, 0x8c, 0x99, 0xe0, 0x3c, 0x11, + 0xed, 0x17, 0x84, 0xa6, 0xba, 0xaf, 0x3f, 0x74, 0x0c, 0x96, 0x20, 0x6c, 0xc3, 0x9c, 0x6e, 0x18, + 0x2e, 0xf6, 0x3c, 0x8e, 0x47, 0xf0, 0x48, 0xde, 0x04, 0x7a, 0x93, 0xe9, 0x92, 0xe0, 0x11, 0x7d, + 0x33, 0x51, 0x11, 0xa8, 0xaf, 0x5e, 0xc9, 0xc6, 0x93, 0xc7, 0x5f, 0x61, 0xcd, 0xe0, 0xef, 0x0b, + 0xd0, 0xe4, 0xb2, 0xba, 0xc6, 0x0d, 0xe6, 0x64, 0x16, 0x5b, 0x83, 0xc6, 0x7e, 0x28, 0x23, 0x93, + 0xf2, 0x47, 0x51, 0x51, 0x8a, 0xcd, 0x99, 0xc6, 0x6b, 0x71, 0x93, 0x5d, 0x9a, 0xc9, 0x64, 0x97, + 0x4f, 0x2a, 0xe9, 0x69, 0xd7, 0xad, 0x22, 0x71, 0xdd, 0xd4, 0x5f, 0x82, 0x7a, 0x64, 0x01, 0xaa, + 0xc9, 0x58, 0x8a, 0x86, 0x53, 0x2c, 0x78, 0x44, 0x77, 0x43, 0xc7, 0x85, 0x91, 0xea, 0xbc, 0x04, + 0x97, 0x84, 0xcf, 0xa2, 0xfe, 0x44, 0x81, 0x0a, 0x5f, 0xf9, 0x32, 0xd4, 0xb9, 0x7c, 0x51, 0x57, + 0x8e, 0xad, 0x0e, 0x7c, 0x88, 0xf8, 0x72, 0xcf, 0x4e, 0xc0, 0xce, 0x43, 0x35, 0x21, 0x5a, 0x73, + 0x5c, 0x7d, 0x06, 0xaf, 0x22, 0xf2, 0x44, 0x5e, 0x11, 0x51, 0x42, 0x4b, 0x50, 0xb6, 0x9c, 0x81, + 0x28, 0xdf, 0xb0, 0x07, 0xf5, 0xa7, 0x0a, 0xcd, 0xb6, 0x6b, 0xb8, 0xef, 0x1c, 0x62, 0xf7, 0x78, + 0xf6, 0x84, 0xe5, 0xbb, 0x11, 0x36, 0xcf, 0x19, 0x35, 0x89, 0x09, 0xe8, 0xdd, 0xf0, 0x10, 0x8a, + 0xb2, 0xbc, 0x46, 0xd4, 0x64, 0x71, 0x26, 0x0d, 0x0f, 0xe3, 0xf7, 0x15, 0x9a, 0x7a, 0x8d, 0x7f, + 0xca, 0x69, 0xbd, 0x82, 0x67, 0x12, 0x5f, 0xa8, 0xff, 0xa2, 0xc0, 0xf9, 0x0c, 0xea, 0x3e, 0x5e, + 0xfd, 0x1a, 0xe8, 0xfb, 0x0e, 0x54, 0x45, 0x8c, 0x5d, 0xcc, 0x15, 0x63, 0x0b, 0x78, 0xf5, 0x8f, + 0x59, 
0x01, 0x40, 0x42, 0xde, 0xc7, 0xab, 0xcf, 0x89, 0xc0, 0xc9, 0x5c, 0x59, 0x51, 0x92, 0x2b, + 0xfb, 0x37, 0x05, 0x3a, 0x61, 0x6e, 0xca, 0x5b, 0x3b, 0x9e, 0xb5, 0x62, 0xf4, 0x6c, 0x22, 0xcb, + 0xb7, 0x45, 0x71, 0x83, 0xe8, 0xc5, 0x5c, 0x31, 0x61, 0x50, 0xda, 0xb0, 0x69, 0x9a, 0x3b, 0xfd, + 0x41, 0xb3, 0x48, 0x65, 0x27, 0x72, 0xf0, 0xac, 0xc0, 0x11, 0x1e, 0xec, 0x4f, 0x18, 0x93, 0xde, + 0x8f, 0x27, 0xa8, 0xbe, 0x6e, 0x02, 0x46, 0x8b, 0x2e, 0x4f, 0x78, 0xd1, 0xa5, 0x94, 0x28, 0xba, + 0xf0, 0x71, 0x75, 0x48, 0x59, 0x20, 0xf5, 0x01, 0xcf, 0x8b, 0x60, 0xbf, 0xa1, 0x40, 0x9b, 0xef, + 0xc2, 0x1a, 0x09, 0x9c, 0xe1, 0xc8, 0xc2, 0x3e, 0x36, 0xbe, 0xea, 0x24, 0xc9, 0xdf, 0x16, 0xa0, + 0x15, 0x75, 0x6c, 0xa8, 0x6f, 0xf2, 0x26, 0x94, 0x69, 0x16, 0x8a, 0x63, 0x30, 0x55, 0x3b, 0x30, + 0x68, 0x62, 0x19, 0xa9, 0xd7, 0xbf, 0xeb, 0x05, 0x8e, 0x0b, 0x7f, 0x0c, 0xbd, 0xab, 0xe2, 0xc9, + 0xbd, 0xab, 0x17, 0xa1, 0x46, 0x2c, 0x97, 0x33, 0x26, 0xeb, 0xb2, 0x4a, 0x78, 0x38, 0x80, 0xde, + 0x83, 0x0a, 0xeb, 0xbd, 0xe1, 0x85, 0xc8, 0xeb, 0xf1, 0xa5, 0x79, 0x5f, 0x4e, 0xa4, 0x90, 0x40, + 0x07, 0x34, 0x3e, 0x89, 0x9c, 0xd1, 0xc8, 0x75, 0x06, 0xd4, 0x0d, 0x23, 0x46, 0xad, 0xac, 0x89, + 0x67, 0xe2, 0x26, 0x3a, 0xa3, 0xee, 0x06, 0x4f, 0xa9, 0xd0, 0xff, 0xd5, 0x9f, 0x87, 0xe5, 0x30, + 0x82, 0x67, 0x68, 0x9e, 0x96, 0xc9, 0xd5, 0xff, 0x54, 0xe0, 0xec, 0xce, 0xb1, 0xdd, 0x4f, 0x8a, + 0xcb, 0x32, 0x54, 0x46, 0x96, 0x1e, 0xa6, 0xbc, 0xf9, 0x13, 0x6d, 0x27, 0x08, 0x62, 0x73, 0x62, + 0xd6, 0x19, 0x8d, 0xeb, 0x62, 0x6c, 0xd7, 0x99, 0xea, 0x6d, 0x5d, 0x17, 0x29, 0x07, 0x6c, 0x30, + 0x07, 0x82, 0xa5, 0xf4, 0xe6, 0xc5, 0x28, 0x75, 0x20, 0xde, 0x03, 0xa0, 0x3e, 0x56, 0xef, 0x24, + 0x7e, 0x15, 0x9d, 0xb1, 0x49, 0xac, 0xe8, 0x8f, 0x0a, 0xd0, 0x8e, 0x50, 0xe9, 0xab, 0x76, 0x39, + 0x33, 0x02, 0xca, 0xe2, 0x33, 0x0a, 0x28, 0x4b, 0xb3, 0xbb, 0x99, 0x65, 0x99, 0x9b, 0xf9, 0xfd, + 0x22, 0x34, 0x43, 0xaa, 0x6d, 0x5b, 0xba, 0x9d, 0xc9, 0x09, 0x3b, 0xd0, 0xf4, 0x62, 0x54, 0xe5, + 0x74, 0x7a, 0x55, 0x26, 0x57, 0x19, 0x07, 
0xa1, 0x25, 0x96, 0x40, 0x17, 0xe9, 0xa1, 0xbb, 0x3e, + 0x4b, 0x11, 0x32, 0x9f, 0xb1, 0xc6, 0x04, 0xd8, 0x1c, 0x62, 0xf4, 0x1a, 0x20, 0x2e, 0x75, 0x3d, + 0xd3, 0xee, 0x79, 0xb8, 0xef, 0xd8, 0x06, 0x93, 0xc7, 0xb2, 0xd6, 0xe2, 0x6f, 0xba, 0xf6, 0x0e, + 0x1b, 0x47, 0x6f, 0x42, 0xc9, 0x3f, 0x1e, 0x31, 0x07, 0xb2, 0x29, 0x75, 0xc1, 0x42, 0xbc, 0x76, + 0x8f, 0x47, 0x58, 0xa3, 0xe0, 0x41, 0x4b, 0x96, 0xef, 0xea, 0x87, 0xdc, 0x1b, 0x2f, 0x69, 0x91, + 0x91, 0x68, 0x8c, 0x3d, 0x17, 0x8f, 0xb1, 0x29, 0x67, 0x07, 0x42, 0xde, 0xf3, 0x7d, 0x8b, 0x26, + 0x39, 0x29, 0x67, 0x07, 0xa3, 0xbb, 0xbe, 0x45, 0x3e, 0xd2, 0x77, 0x7c, 0xdd, 0x62, 0xf2, 0x51, + 0xe3, 0xda, 0x84, 0x8c, 0xd0, 0xc8, 0xf7, 0x3f, 0x88, 0x36, 0x14, 0x88, 0x69, 0xd8, 0x1b, 0x5b, + 0xd9, 0xf2, 0x38, 0x39, 0xeb, 0x33, 0x4d, 0x14, 0xbf, 0x05, 0x75, 0xce, 0x15, 0x27, 0xe0, 0x2a, + 0x60, 0x53, 0x36, 0x27, 0xb0, 0x79, 0xf9, 0x19, 0xb1, 0x79, 0xe5, 0x14, 0x79, 0x13, 0xf9, 0xd9, + 0xa8, 0x3f, 0x50, 0xe0, 0x85, 0x94, 0xd6, 0x9c, 0x48, 0xda, 0xc9, 0xd1, 0x38, 0xd7, 0xa6, 0xc9, + 0x25, 0xb9, 0xbd, 0x78, 0x17, 0x2a, 0x2e, 0x5d, 0x9d, 0x97, 0xfa, 0xae, 0x4e, 0x64, 0x3e, 0x86, + 0x88, 0xc6, 0xa7, 0xa8, 0x7f, 0xa8, 0xc0, 0xb9, 0x34, 0xaa, 0x33, 0x38, 0x01, 0x6b, 0x30, 0xc7, + 0x96, 0x0e, 0x64, 0x74, 0x65, 0xb2, 0x8c, 0x86, 0xc4, 0xd1, 0x82, 0x89, 0xea, 0x0e, 0x2c, 0x07, + 0xbe, 0x42, 0x48, 0xfa, 0x2d, 0xec, 0xeb, 0x13, 0x62, 0xd1, 0xcb, 0x50, 0x67, 0x41, 0x0d, 0x8b, + 0xf1, 0x58, 0x65, 0x14, 0xf6, 0x44, 0x92, 0x50, 0xfd, 0xa3, 0x02, 0x2c, 0x51, 0x63, 0x9b, 0x2c, + 0x73, 0xe5, 0xa9, 0xbb, 0xaa, 0xa2, 0xb3, 0xed, 0xa1, 0x3e, 0xe4, 0xdd, 0x37, 0x35, 0x2d, 0x36, + 0x86, 0xba, 0xe9, 0x1c, 0xa2, 0x34, 0x67, 0x11, 0x16, 0x9a, 0x37, 0x74, 0x5f, 0xa7, 0x75, 0xe6, + 0x64, 0xf2, 0x30, 0x34, 0xf2, 0xa5, 0xd3, 0x18, 0xf9, 0x57, 0xa0, 0xc5, 0xf2, 0xef, 0x3d, 0x11, + 0x02, 0x53, 0xc5, 0x54, 0xd2, 0x16, 0xd8, 0xf8, 0x6e, 0x30, 0xac, 0x6e, 0xc2, 0x0b, 0x09, 0xa2, + 0xcc, 0x70, 0xf8, 0xea, 0x5f, 0x29, 0xe4, 0xe4, 0x62, 0x0d, 0x4f, 0xa7, 0xf7, 
0x89, 0x2f, 0x8a, + 0x52, 0x5c, 0xcf, 0x34, 0x92, 0xfa, 0xc6, 0x40, 0xef, 0x43, 0xcd, 0xc6, 0x47, 0xbd, 0xa8, 0x9b, + 0x95, 0x23, 0x60, 0xa8, 0xda, 0xf8, 0x88, 0xfe, 0xa7, 0x3e, 0x84, 0x73, 0x29, 0x54, 0x67, 0xf9, + 0xf6, 0x7f, 0x50, 0xe0, 0xfc, 0x86, 0xeb, 0x8c, 0x1e, 0x9b, 0xae, 0x3f, 0xd6, 0xad, 0x78, 0xb5, + 0xff, 0x14, 0x9f, 0x9f, 0xa3, 0x99, 0xf2, 0xc3, 0x54, 0x68, 0xfa, 0x9a, 0x44, 0xd8, 0xd2, 0x48, + 0xf1, 0x8f, 0x8e, 0xb8, 0xe7, 0xff, 0x53, 0x94, 0x21, 0xcf, 0xe1, 0xa6, 0xb8, 0x30, 0x79, 0x62, + 0x17, 0x69, 0xba, 0xbf, 0x78, 0xda, 0x74, 0x7f, 0x86, 0x25, 0x28, 0x3d, 0x23, 0x4b, 0x70, 0xe2, + 0xbc, 0xda, 0x3a, 0xc4, 0x4b, 0x31, 0xd4, 0x90, 0x9f, 0xb4, 0x7c, 0xf3, 0x1e, 0x40, 0x58, 0x91, + 0xe0, 0x0d, 0xaa, 0x53, 0x56, 0x88, 0x4c, 0x20, 0x67, 0x24, 0x6c, 0x2d, 0x77, 0x05, 0x22, 0x19, + 0xee, 0xef, 0x40, 0x47, 0xc6, 0x9b, 0xb3, 0xf0, 0xfb, 0x7f, 0x15, 0x00, 0xba, 0xa2, 0x9d, 0xf9, + 0x74, 0xc6, 0xe2, 0x2a, 0x44, 0xdc, 0x95, 0x50, 0xca, 0xa3, 0xbc, 0x63, 0x10, 0x41, 0x10, 0x41, + 0x2e, 0x81, 0x49, 0x05, 0xbe, 0x06, 0x5d, 0x27, 0x22, 0x2b, 0x8c, 0x15, 0x92, 0xfa, 0xf9, 0x02, + 0xd4, 0x5c, 0xe7, 0xa8, 0x47, 0x84, 0xcb, 0x08, 0xfa, 0xb5, 0x5d, 0xe7, 0x88, 0x88, 0x9c, 0x81, + 0xce, 0xc1, 0x9c, 0xaf, 0x7b, 0x07, 0x64, 0x7d, 0x96, 0xeb, 0xab, 0x90, 0xc7, 0xae, 0x81, 0x96, + 0xa0, 0xbc, 0x6f, 0x5a, 0x98, 0xb5, 0x86, 0xd4, 0x34, 0xf6, 0x80, 0xbe, 0x11, 0xb4, 0x18, 0x56, + 0x73, 0xb7, 0x12, 0xb1, 0x2e, 0xc3, 0xab, 0x30, 0x4f, 0x38, 0x89, 0x20, 0xc1, 0xc4, 0xba, 0xc5, + 0xf3, 0xfc, 0x7c, 0x90, 0xb6, 0x10, 0xfc, 0x54, 0x81, 0x85, 0x90, 0xb4, 0x54, 0x37, 0x11, 0x75, + 0x47, 0x55, 0xdd, 0xba, 0x63, 0x30, 0x2d, 0xd2, 0xcc, 0xb0, 0x2b, 0x6c, 0x22, 0x53, 0x68, 0xe1, + 0x94, 0x49, 0xc1, 0x39, 0xf9, 0x78, 0x42, 0x19, 0xd3, 0x08, 0xd2, 0x45, 0x15, 0xd7, 0x39, 0xea, + 0x1a, 0x82, 0x64, 0xac, 0x63, 0x9b, 0x85, 0xa2, 0x84, 0x64, 0xeb, 0xb4, 0x69, 0xfb, 0x2a, 0xcc, + 0x63, 0xd7, 0x75, 0xdc, 0xde, 0x10, 0x7b, 0x9e, 0x3e, 0x08, 0x9a, 0x21, 0x1a, 0x74, 0x70, 0x8b, + 0x8d, 0xa9, 0xff, 
0x58, 0x82, 0x66, 0xf8, 0x29, 0x41, 0x5b, 0x82, 0x69, 0x04, 0x6d, 0x09, 0x26, + 0x39, 0x5f, 0x70, 0x99, 0x96, 0x14, 0x1c, 0xb0, 0x56, 0x68, 0x2b, 0x5a, 0x8d, 0x8f, 0x76, 0x0d, + 0x62, 0xdc, 0x09, 0x81, 0x6c, 0xc7, 0xc0, 0x21, 0x07, 0x40, 0x30, 0xc4, 0x19, 0x20, 0xc6, 0x48, + 0xa5, 0x1c, 0x8c, 0x54, 0xce, 0xc1, 0x48, 0x15, 0x09, 0x23, 0x2d, 0x43, 0x65, 0x6f, 0xdc, 0x3f, + 0xc0, 0x3e, 0xf7, 0xfb, 0xf8, 0x53, 0x9c, 0xc1, 0xaa, 0x09, 0x06, 0x13, 0x7c, 0x54, 0x8b, 0xf2, + 0xd1, 0x05, 0xa8, 0x05, 0x96, 0xda, 0xa3, 0xd5, 0xb7, 0xa2, 0x56, 0xe5, 0x26, 0xda, 0x43, 0x6f, + 0x05, 0x4e, 0x61, 0x9d, 0x4a, 0x94, 0x2a, 0x51, 0x48, 0x09, 0x2e, 0x09, 0x5c, 0xc2, 0x1b, 0xb0, + 0x10, 0x21, 0x07, 0xe5, 0x33, 0x56, 0xa2, 0x8b, 0xc4, 0x0c, 0xd4, 0x82, 0x5c, 0x87, 0x66, 0x48, + 0x12, 0x0a, 0x37, 0xcf, 0x42, 0x35, 0x31, 0x4a, 0xc1, 0x04, 0xbb, 0x37, 0x4f, 0xc8, 0xee, 0xe7, + 0xa1, 0xca, 0x63, 0x2c, 0xaf, 0xbd, 0x10, 0x4f, 0x91, 0xe4, 0x92, 0x84, 0x4f, 0x01, 0x85, 0x9f, + 0x38, 0x9b, 0x63, 0x9a, 0xe0, 0xa1, 0x42, 0x92, 0x87, 0xd4, 0xbf, 0x56, 0x60, 0x31, 0xba, 0xd9, + 0x69, 0x0d, 0xf7, 0xfb, 0x50, 0x67, 0x45, 0xd2, 0x1e, 0x51, 0x21, 0xf2, 0x5a, 0x65, 0xe2, 0xf0, + 0x34, 0x08, 0x2f, 0x86, 0x10, 0xc2, 0x1c, 0x39, 0xee, 0x81, 0x69, 0x0f, 0x7a, 0x04, 0x33, 0x91, + 0xc2, 0xe5, 0x83, 0x0f, 0xc9, 0x98, 0xfa, 0xdb, 0x0a, 0x5c, 0x7a, 0x34, 0x32, 0x74, 0x1f, 0x47, + 0x3c, 0x98, 0x59, 0xfb, 0x33, 0x45, 0x83, 0x64, 0x61, 0xc2, 0x31, 0x47, 0xf6, 0xf3, 0x78, 0x83, + 0x24, 0xf1, 0xfb, 0x38, 0x36, 0xa9, 0x8e, 0xe6, 0xd3, 0x63, 0xd3, 0x81, 0xea, 0x21, 0x5f, 0x2e, + 0xb8, 0xea, 0x12, 0x3c, 0xc7, 0x8a, 0xc1, 0xc5, 0x13, 0x15, 0x83, 0xd5, 0x2d, 0x38, 0xaf, 0x61, + 0x0f, 0xdb, 0x46, 0xec, 0x43, 0x4e, 0x9d, 0xd4, 0x1a, 0x41, 0x47, 0xb6, 0xdc, 0x2c, 0x9c, 0xca, + 0x1c, 0xdf, 0x9e, 0x4b, 0x96, 0xf5, 0xb9, 0xb2, 0x26, 0xfe, 0x16, 0xdd, 0xc7, 0x57, 0xff, 0xa6, + 0x00, 0xe7, 0xee, 0x19, 0x06, 0xd7, 0xf3, 0xdc, 0x95, 0x7b, 0x5e, 0x5e, 0x76, 0xd2, 0x0b, 0x2d, + 0xa6, 0xbd, 0xd0, 0x67, 0xa5, 0x7b, 0xb9, 0x15, 0xb2, 
0xc7, 0xc3, 0xc0, 0x04, 0xbb, 0xac, 0xa3, + 0xeb, 0x5d, 0x5e, 0x32, 0xed, 0x59, 0xce, 0x80, 0x9a, 0xe1, 0xe9, 0xce, 0x59, 0x35, 0x48, 0xce, + 0xa9, 0x23, 0x68, 0xa7, 0x89, 0x35, 0xa3, 0x1e, 0x09, 0x28, 0x32, 0x72, 0x58, 0xe2, 0xb7, 0x41, + 0x3c, 0x31, 0x3a, 0xb4, 0xed, 0x78, 0xea, 0xff, 0x16, 0xa0, 0xbd, 0xa3, 0x1f, 0xe2, 0x9f, 0x9d, + 0x03, 0xfa, 0x04, 0x96, 0x3c, 0xfd, 0x10, 0xf7, 0x22, 0x01, 0x78, 0xcf, 0xc5, 0x9f, 0x71, 0x27, + 0xf6, 0x15, 0x59, 0x6a, 0x5e, 0xda, 0x89, 0xa4, 0x2d, 0x7a, 0xb1, 0x71, 0x0d, 0x7f, 0x86, 0x5e, + 0x86, 0x85, 0x68, 0x7b, 0x1d, 0x41, 0xad, 0x4a, 0x49, 0x3e, 0x1f, 0x69, 0xa1, 0xeb, 0x1a, 0xea, + 0x67, 0xf0, 0xe2, 0x23, 0xdb, 0xc3, 0x7e, 0x37, 0x6c, 0x03, 0x9b, 0x31, 0xfe, 0xbc, 0x0c, 0xf5, + 0x90, 0xf0, 0xa9, 0x3b, 0x2e, 0x86, 0xa7, 0x3a, 0xd0, 0xd9, 0xd2, 0xdd, 0x83, 0x20, 0x9d, 0xbd, + 0xc1, 0xba, 0x66, 0x9e, 0xe3, 0x86, 0xfb, 0xa2, 0x7f, 0x4c, 0xc3, 0xfb, 0xd8, 0xc5, 0x76, 0x1f, + 0x6f, 0x3a, 0xfd, 0x03, 0xe2, 0x90, 0xf8, 0xec, 0x9a, 0xa1, 0x12, 0xf1, 0x5d, 0x37, 0x22, 0xb7, + 0x08, 0x0b, 0xb1, 0x5b, 0x84, 0x53, 0x6e, 0xcc, 0xaa, 0x3f, 0x2c, 0xc0, 0xf2, 0x3d, 0xcb, 0xc7, + 0x6e, 0x98, 0x61, 0x38, 0x49, 0xb2, 0x24, 0xcc, 0x5e, 0x14, 0x4e, 0x93, 0xbd, 0xc8, 0x51, 0xc1, + 0x94, 0xe5, 0x5a, 0x4a, 0xa7, 0xcc, 0xb5, 0xdc, 0x03, 0x18, 0xb9, 0xce, 0x08, 0xbb, 0xbe, 0x89, + 0x83, 0xd8, 0x2f, 0x87, 0x83, 0x13, 0x99, 0xa4, 0x7e, 0x02, 0xad, 0x07, 0xfd, 0x75, 0xc7, 0xde, + 0x37, 0xdd, 0x61, 0x40, 0xa8, 0x94, 0xd0, 0x29, 0x39, 0x84, 0xae, 0x90, 0x12, 0x3a, 0xd5, 0x84, + 0xc5, 0xc8, 0xda, 0x33, 0x2a, 0xae, 0x41, 0xbf, 0xb7, 0x6f, 0xda, 0x26, 0xed, 0x4a, 0x2b, 0x50, + 0x07, 0x15, 0x06, 0xfd, 0xfb, 0x7c, 0x44, 0xfd, 0x52, 0x81, 0x0b, 0x1a, 0x26, 0xc2, 0x13, 0x34, + 0xee, 0xec, 0xfa, 0x5b, 0xde, 0x60, 0x06, 0x87, 0xe2, 0x2e, 0x94, 0x86, 0xde, 0x20, 0xa3, 0xe8, + 0x4e, 0x4c, 0x74, 0x6c, 0x23, 0x8d, 0x02, 0xab, 0x3f, 0x56, 0x60, 0x29, 0x28, 0x4d, 0xc6, 0x44, + 0x38, 0xce, 0xb6, 0x4a, 0xaa, 0x37, 0x7b, 0xc2, 0xd5, 0xe2, 0x73, 0x30, 0x67, 0xec, 0x45, 
0x15, + 0x64, 0xc5, 0xd8, 0xa3, 0xba, 0x51, 0xe2, 0x29, 0x97, 0xa4, 0x9e, 0x72, 0x92, 0xf1, 0xcb, 0x92, + 0x9e, 0xa7, 0x47, 0xd0, 0xe6, 0x0e, 0xca, 0x47, 0x23, 0xec, 0xea, 0x94, 0xbf, 0x02, 0xe4, 0xdf, + 0x0e, 0x5c, 0x68, 0x25, 0xf3, 0xe2, 0x5e, 0xb2, 0x2c, 0xc9, 0x9d, 0x68, 0xf5, 0x9f, 0x15, 0xb8, + 0x92, 0x5c, 0x77, 0x9b, 0x17, 0xed, 0x66, 0xbe, 0x93, 0x4e, 0x2b, 0x7e, 0x85, 0xb0, 0xe2, 0x37, + 0x53, 0xe9, 0x32, 0x5a, 0x5d, 0x2c, 0xc5, 0xab, 0x8b, 0x37, 0xdf, 0x17, 0xcd, 0xe9, 0xbb, 0xc7, + 0x23, 0x8c, 0xe6, 0xa0, 0xf8, 0x10, 0x1f, 0xb5, 0xce, 0x20, 0x80, 0xca, 0x43, 0xc7, 0x1d, 0xea, + 0x56, 0x4b, 0x41, 0x75, 0x98, 0xe3, 0x15, 0xe9, 0x56, 0x01, 0xcd, 0x43, 0x6d, 0x3d, 0xa8, 0xd2, + 0xb5, 0x8a, 0x37, 0x6f, 0x42, 0x23, 0xda, 0x7c, 0x4b, 0xe6, 0x6d, 0xe2, 0x81, 0xde, 0x3f, 0x6e, + 0x9d, 0x41, 0x15, 0x28, 0x6c, 0xde, 0x69, 0x29, 0xf4, 0xef, 0x1b, 0xad, 0xc2, 0xcd, 0x3f, 0x55, + 0x60, 0x31, 0x85, 0x24, 0x6a, 0x02, 0x3c, 0xb2, 0xfb, 0xbc, 0xf0, 0xdc, 0x3a, 0x83, 0x1a, 0x50, + 0x0d, 0xca, 0xd0, 0x6c, 0xef, 0x5d, 0x87, 0x42, 0xb7, 0x0a, 0xa8, 0x05, 0x0d, 0x36, 0x71, 0xdc, + 0xef, 0x63, 0xcf, 0x6b, 0x15, 0xc5, 0xc8, 0x7d, 0xdd, 0xb4, 0xc6, 0x2e, 0x6e, 0x95, 0x08, 0x7e, + 0xbb, 0x8e, 0x86, 0x2d, 0xac, 0x7b, 0xb8, 0x55, 0x46, 0x08, 0x9a, 0xfc, 0x21, 0x98, 0x54, 0x89, + 0x8c, 0x05, 0xd3, 0xe6, 0x6e, 0xfe, 0x48, 0x89, 0x96, 0xbd, 0x28, 0x2d, 0xce, 0xc1, 0xd9, 0x47, + 0xb6, 0x81, 0xf7, 0x4d, 0x1b, 0x1b, 0xe1, 0xab, 0xd6, 0x19, 0x74, 0x16, 0x16, 0xb6, 0xb0, 0x3b, + 0xc0, 0x91, 0xc1, 0x02, 0x5a, 0x84, 0xf9, 0x2d, 0xf3, 0x69, 0x64, 0xa8, 0x88, 0x96, 0xa0, 0xb5, + 0x63, 0xda, 0x03, 0x2b, 0x0a, 0x58, 0xa2, 0xb3, 0x4d, 0xdb, 0x71, 0x23, 0x83, 0x65, 0x3a, 0xa8, + 0x7f, 0x1a, 0x1b, 0xac, 0xa0, 0x0e, 0x2c, 0x53, 0xa2, 0xde, 0xd9, 0xc0, 0x84, 0x1a, 0x91, 0x77, + 0x73, 0x6a, 0xa9, 0xaa, 0xb4, 0x94, 0xd5, 0x1f, 0x5f, 0x87, 0x1a, 0x11, 0xd6, 0x75, 0xc7, 0x71, + 0x0d, 0x64, 0x01, 0xa2, 0xb7, 0xd5, 0x86, 0x23, 0xc7, 0x16, 0x37, 0x5b, 0xd1, 0xad, 0x84, 0x7c, + 0xb3, 0x87, 0x34, 0x20, 0x17, 
0x89, 0xce, 0x35, 0x29, 0x7c, 0x02, 0x58, 0x3d, 0x83, 0x86, 0x74, + 0xb7, 0x5d, 0x73, 0x88, 0x77, 0xcd, 0xfe, 0x41, 0x10, 0x02, 0xdc, 0xc9, 0xb8, 0x1e, 0x98, 0x06, + 0x0d, 0xf6, 0xbb, 0x2a, 0xdd, 0x8f, 0x5d, 0x27, 0x0c, 0xe4, 0x48, 0x3d, 0x83, 0x3e, 0xa3, 0xea, + 0x27, 0x8c, 0xa7, 0x82, 0x0d, 0x57, 0xb3, 0x37, 0x4c, 0x01, 0x9f, 0x70, 0xcb, 0x4d, 0x28, 0x53, + 0xbe, 0x47, 0xb2, 0xce, 0x83, 0xe8, 0x4f, 0x66, 0x74, 0xae, 0x64, 0x03, 0x88, 0xd5, 0x3e, 0x85, + 0x85, 0xc4, 0x85, 0x75, 0x24, 0xf3, 0xc1, 0xe4, 0x3f, 0x3d, 0xd0, 0xb9, 0x99, 0x07, 0x54, 0xec, + 0x35, 0x80, 0x66, 0xfc, 0x96, 0x1b, 0x5a, 0xc9, 0x71, 0x57, 0x96, 0xed, 0xf4, 0x4a, 0xee, 0x5b, + 0xb5, 0x94, 0x09, 0x5a, 0xc9, 0xab, 0xd4, 0xe8, 0xe6, 0xc4, 0x05, 0xe2, 0xcc, 0xf6, 0x6a, 0x2e, + 0x58, 0xb1, 0xdd, 0x31, 0x65, 0x82, 0xd4, 0x3d, 0xd6, 0x24, 0x8f, 0x07, 0xcb, 0x64, 0x5d, 0xb0, + 0xed, 0xdc, 0xce, 0x0d, 0x2f, 0xb6, 0xfe, 0x55, 0xd6, 0x5e, 0x28, 0xbb, 0x0b, 0x8a, 0xde, 0x90, + 0x2f, 0x37, 0xe1, 0x12, 0x6b, 0x67, 0xf5, 0x24, 0x53, 0x04, 0x12, 0xdf, 0xa3, 0x7d, 0x81, 0x92, + 0xdb, 0x94, 0x49, 0xb9, 0x0b, 0xd6, 0xcb, 0xbe, 0x28, 0xda, 0x79, 0xe3, 0x04, 0x33, 0x04, 0x02, + 0x4e, 0xf2, 0xae, 0x7a, 0x20, 0x86, 0xb7, 0xa7, 0x72, 0xcd, 0xe9, 0x64, 0xf0, 0xbb, 0xb0, 0x90, + 0x88, 0x4a, 0x50, 0xfe, 0xc8, 0xa5, 0x33, 0xc9, 0xdc, 0x32, 0x91, 0x4c, 0xf4, 0x01, 0xa2, 0x0c, + 0xee, 0x97, 0xf4, 0x0a, 0x76, 0x6e, 0xe6, 0x01, 0x15, 0x1f, 0x32, 0x82, 0xc5, 0xc4, 0xcb, 0xc7, + 0xab, 0xe8, 0xd5, 0xdc, 0xbb, 0x3d, 0x5e, 0xed, 0xbc, 0x96, 0x7f, 0xbf, 0xc7, 0xab, 0xea, 0x19, + 0xe4, 0x51, 0x05, 0x9d, 0xe8, 0x25, 0x43, 0x19, 0xab, 0xc8, 0x7b, 0xe6, 0x3a, 0xaf, 0xe7, 0x84, + 0x16, 0x9f, 0x79, 0x08, 0x67, 0x25, 0x2d, 0x7f, 0xe8, 0xf5, 0x89, 0xec, 0x91, 0xec, 0x75, 0xec, + 0xdc, 0xca, 0x0b, 0x1e, 0x31, 0x0f, 0xad, 0x00, 0xaf, 0x7b, 0x96, 0xc5, 0x3c, 0x8b, 0xd7, 0xb2, + 0x2c, 0x5f, 0x0c, 0x2c, 0xe3, 0x53, 0x33, 0xa1, 0xc5, 0x96, 0xbf, 0x0c, 0x68, 0xe7, 0x89, 0x73, + 0x44, 0xa3, 0x80, 0xc1, 0x98, 0x3b, 0x96, 0x99, 0x06, 0x30, 0x0d, 
0x9a, 0x21, 0x88, 0x13, 0x67, + 0x88, 0xcd, 0x7b, 0x00, 0x0f, 0xb0, 0xbf, 0x85, 0x7d, 0x97, 0x48, 0xff, 0xcb, 0x59, 0xb8, 0x73, + 0x80, 0x60, 0xab, 0x1b, 0x53, 0xe1, 0xa2, 0x04, 0xdd, 0xd2, 0xed, 0xb1, 0x6e, 0x45, 0xae, 0x8a, + 0xc9, 0x09, 0x9a, 0x04, 0x9b, 0x4c, 0xd0, 0x34, 0xb4, 0xd8, 0xf2, 0x48, 0xf8, 0x2f, 0x91, 0x2e, + 0x84, 0xc9, 0xfe, 0x4b, 0xba, 0x03, 0x2e, 0xa9, 0xdb, 0x27, 0xc0, 0x8b, 0x8d, 0xbf, 0x50, 0x68, + 0xa3, 0x6a, 0x02, 0xe0, 0x63, 0xd3, 0x7f, 0xb2, 0x6d, 0xe9, 0xb6, 0x97, 0x07, 0x05, 0x0a, 0x78, + 0x02, 0x14, 0x38, 0xbc, 0x40, 0xc1, 0x80, 0xf9, 0x58, 0xc5, 0x1f, 0xc9, 0xae, 0x38, 0xc9, 0x1a, + 0x25, 0x3a, 0x2b, 0xd3, 0x01, 0xc5, 0x2e, 0xfb, 0x30, 0x1f, 0x8b, 0xe1, 0xa4, 0xbb, 0xc8, 0xa2, + 0xbc, 0xa4, 0xb2, 0x4b, 0x48, 0x47, 0x92, 0xa0, 0x1e, 0xa0, 0x74, 0x61, 0x13, 0xe5, 0x2b, 0x83, + 0x4f, 0x52, 0x3d, 0xd9, 0xd5, 0x52, 0xa6, 0xcd, 0x13, 0xad, 0x03, 0x72, 0x53, 0x21, 0xed, 0x84, + 0x90, 0x6a, 0xf3, 0x8c, 0x4e, 0x04, 0xf5, 0x0c, 0xfa, 0x18, 0x2a, 0xfc, 0x27, 0xa5, 0xae, 0x4d, + 0x2e, 0x21, 0xf0, 0xd5, 0xaf, 0x4f, 0x81, 0x12, 0x0b, 0x1f, 0xc0, 0xb9, 0x8c, 0x02, 0x82, 0xd4, + 0xcb, 0x98, 0x5c, 0x6c, 0x98, 0x66, 0xff, 0xc4, 0x66, 0xa9, 0xfa, 0xc0, 0x84, 0xcd, 0xb2, 0x6a, + 0x09, 0xd3, 0x36, 0xeb, 0xc1, 0x62, 0x2a, 0xff, 0x2a, 0x35, 0x80, 0x59, 0x59, 0xda, 0x69, 0x1b, + 0x0c, 0xe0, 0x05, 0x69, 0xae, 0x51, 0xea, 0x9b, 0x4c, 0xca, 0x4a, 0x4e, 0xdb, 0xa8, 0x0f, 0x67, + 0x25, 0x19, 0x46, 0xa9, 0x8d, 0xcb, 0xce, 0x44, 0x4e, 0xdb, 0x64, 0x1f, 0x3a, 0x6b, 0xae, 0xa3, + 0x1b, 0x7d, 0xdd, 0xf3, 0x69, 0xd6, 0x8f, 0x04, 0xa1, 0x81, 0x73, 0x28, 0x8f, 0x1c, 0xa4, 0xb9, + 0xc1, 0x69, 0xfb, 0xec, 0x41, 0x9d, 0x1e, 0x25, 0xfb, 0xd9, 0x1f, 0x24, 0xb7, 0x10, 0x11, 0x88, + 0x0c, 0xb5, 0x23, 0x03, 0x14, 0x4c, 0xbd, 0x0b, 0xf5, 0x75, 0x5a, 0x3e, 0xed, 0xda, 0x06, 0x7e, + 0x9a, 0xb4, 0x56, 0xf4, 0xb7, 0x0f, 0x6e, 0x45, 0x00, 0x72, 0x53, 0x68, 0x9e, 0xfa, 0xec, 0x06, + 0x7e, 0xca, 0xce, 0x79, 0x45, 0xb6, 0x6e, 0x0c, 0x24, 0x23, 0xc6, 0x91, 0x42, 0x46, 0xec, 0xfc, + 0x52, 
0xd4, 0x93, 0x15, 0xdb, 0xdd, 0xce, 0x58, 0x24, 0x05, 0x19, 0xec, 0x7a, 0x27, 0xff, 0x84, + 0xa8, 0x5d, 0x08, 0xf0, 0xea, 0xd2, 0xda, 0xed, 0x8d, 0x49, 0xa8, 0x47, 0xdd, 0xd3, 0x95, 0xe9, + 0x80, 0x62, 0x97, 0x6d, 0xa8, 0x11, 0xee, 0x64, 0xc7, 0x73, 0x4d, 0x36, 0x51, 0xbc, 0xce, 0x7f, + 0x38, 0x1b, 0xd8, 0xeb, 0xbb, 0xe6, 0x1e, 0x3f, 0x74, 0x29, 0x3a, 0x31, 0x90, 0x89, 0x87, 0x93, + 0x80, 0x14, 0x98, 0x8f, 0xa9, 0xcf, 0x20, 0x48, 0xc7, 0x55, 0xe5, 0xeb, 0xd3, 0xce, 0x37, 0xae, + 0x26, 0x6f, 0xe5, 0x05, 0x17, 0xdb, 0xfe, 0x0a, 0x8d, 0x83, 0xe8, 0xfb, 0xb5, 0xb1, 0x69, 0x19, + 0x41, 0xe2, 0x0f, 0xdd, 0x99, 0xb4, 0x54, 0x0c, 0x34, 0xd3, 0xfd, 0x9b, 0x30, 0x43, 0xec, 0xff, + 0x8b, 0x50, 0x13, 0xf9, 0x67, 0x24, 0xcb, 0x5a, 0x26, 0x33, 0xdf, 0x9d, 0x6b, 0x93, 0x81, 0xc4, + 0xca, 0x18, 0x96, 0x64, 0xd9, 0x66, 0x69, 0x88, 0x3d, 0x21, 0x2d, 0x3d, 0x85, 0x3f, 0x56, 0xbf, + 0x6c, 0x40, 0x35, 0x98, 0xf8, 0x15, 0x27, 0xae, 0xbe, 0x86, 0x4c, 0xd2, 0x77, 0x61, 0x21, 0xf1, + 0x6b, 0x2e, 0x52, 0x0d, 0x2e, 0xff, 0xc5, 0x97, 0x69, 0xa2, 0xf6, 0x31, 0xff, 0x21, 0x54, 0x11, + 0xe2, 0xdd, 0xc8, 0xca, 0x46, 0x25, 0xa3, 0xbb, 0x29, 0x0b, 0xff, 0xff, 0x0e, 0x70, 0x1e, 0x02, + 0x44, 0x42, 0x9b, 0xc9, 0x77, 0x03, 0x88, 0xb7, 0x3e, 0x8d, 0x5a, 0x43, 0x69, 0xf4, 0xf2, 0x4a, + 0x9e, 0x3e, 0xeb, 0x6c, 0x0f, 0x34, 0x3b, 0x66, 0x79, 0x04, 0x8d, 0xe8, 0xad, 0x1d, 0x24, 0xfd, + 0x69, 0xcb, 0xf4, 0xb5, 0x9e, 0x69, 0x5f, 0xb1, 0x75, 0x42, 0xc7, 0x76, 0xca, 0x72, 0x1e, 0xa0, + 0x74, 0x1f, 0x86, 0x34, 0x10, 0xc8, 0xec, 0xfe, 0x90, 0x06, 0x02, 0xd9, 0xcd, 0x1d, 0x2c, 0x29, + 0x99, 0x6c, 0x2e, 0x90, 0x26, 0x25, 0x33, 0xda, 0x35, 0xa4, 0x49, 0xc9, 0xac, 0x6e, 0x85, 0x88, + 0xfc, 0x4d, 0x0c, 0xdd, 0x64, 0xbf, 0xd3, 0x3b, 0x8d, 0x78, 0x06, 0x2c, 0x3f, 0x74, 0x7c, 0x73, + 0xff, 0x38, 0x59, 0x66, 0x92, 0xba, 0xcd, 0x59, 0x35, 0xae, 0xe9, 0x52, 0x7e, 0x91, 0x7a, 0x6d, + 0x59, 0xb5, 0x2c, 0x94, 0xa7, 0x28, 0xd6, 0xb9, 0x9b, 0x03, 0xa3, 0xb4, 0x1d, 0x5b, 0xbb, 0xfb, + 0xc9, 0x1b, 0x03, 0xd3, 0x7f, 0x32, 0xde, 
0x23, 0x68, 0xdd, 0x66, 0x4b, 0xbc, 0x6e, 0x3a, 0xfc, + 0xbf, 0xdb, 0x81, 0xaa, 0xb8, 0x4d, 0x57, 0xbd, 0x4d, 0x56, 0x1d, 0xed, 0xed, 0x55, 0xe8, 0xd3, + 0xdd, 0xff, 0x0b, 0x00, 0x00, 0xff, 0xff, 0x56, 0x96, 0x37, 0x49, 0x40, 0x5b, 0x00, 0x00, } // Reference imports to suppress errors if they are not otherwise used. @@ -5429,7 +5756,7 @@ type DataCoordClient interface { GetCompactionState(ctx context.Context, in *milvuspb.GetCompactionStateRequest, opts ...grpc.CallOption) (*milvuspb.GetCompactionStateResponse, error) GetCompactionStateWithPlans(ctx context.Context, in *milvuspb.GetCompactionPlansRequest, opts ...grpc.CallOption) (*milvuspb.GetCompactionPlansResponse, error) WatchChannels(ctx context.Context, in *WatchChannelsRequest, opts ...grpc.CallOption) (*WatchChannelsResponse, error) - GetFlushState(ctx context.Context, in *milvuspb.GetFlushStateRequest, opts ...grpc.CallOption) (*milvuspb.GetFlushStateResponse, error) + GetFlushState(ctx context.Context, in *GetFlushStateRequest, opts ...grpc.CallOption) (*milvuspb.GetFlushStateResponse, error) DropVirtualChannel(ctx context.Context, in *DropVirtualChannelRequest, opts ...grpc.CallOption) (*DropVirtualChannelResponse, error) SetSegmentState(ctx context.Context, in *SetSegmentStateRequest, opts ...grpc.CallOption) (*SetSegmentStateResponse, error) // https://wiki.lfaidata.foundation/display/MIL/MEP+24+--+Support+bulk+load @@ -5670,7 +5997,7 @@ func (c *dataCoordClient) WatchChannels(ctx context.Context, in *WatchChannelsRe return out, nil } -func (c *dataCoordClient) GetFlushState(ctx context.Context, in *milvuspb.GetFlushStateRequest, opts ...grpc.CallOption) (*milvuspb.GetFlushStateResponse, error) { +func (c *dataCoordClient) GetFlushState(ctx context.Context, in *GetFlushStateRequest, opts ...grpc.CallOption) (*milvuspb.GetFlushStateResponse, error) { out := new(milvuspb.GetFlushStateResponse) err := c.cc.Invoke(ctx, "/milvus.proto.data.DataCoord/GetFlushState", in, out, opts...) 
if err != nil { @@ -5885,7 +6212,7 @@ type DataCoordServer interface { GetCompactionState(context.Context, *milvuspb.GetCompactionStateRequest) (*milvuspb.GetCompactionStateResponse, error) GetCompactionStateWithPlans(context.Context, *milvuspb.GetCompactionPlansRequest) (*milvuspb.GetCompactionPlansResponse, error) WatchChannels(context.Context, *WatchChannelsRequest) (*WatchChannelsResponse, error) - GetFlushState(context.Context, *milvuspb.GetFlushStateRequest) (*milvuspb.GetFlushStateResponse, error) + GetFlushState(context.Context, *GetFlushStateRequest) (*milvuspb.GetFlushStateResponse, error) DropVirtualChannel(context.Context, *DropVirtualChannelRequest) (*DropVirtualChannelResponse, error) SetSegmentState(context.Context, *SetSegmentStateRequest) (*SetSegmentStateResponse, error) // https://wiki.lfaidata.foundation/display/MIL/MEP+24+--+Support+bulk+load @@ -5984,7 +6311,7 @@ func (*UnimplementedDataCoordServer) GetCompactionStateWithPlans(ctx context.Con func (*UnimplementedDataCoordServer) WatchChannels(ctx context.Context, req *WatchChannelsRequest) (*WatchChannelsResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method WatchChannels not implemented") } -func (*UnimplementedDataCoordServer) GetFlushState(ctx context.Context, req *milvuspb.GetFlushStateRequest) (*milvuspb.GetFlushStateResponse, error) { +func (*UnimplementedDataCoordServer) GetFlushState(ctx context.Context, req *GetFlushStateRequest) (*milvuspb.GetFlushStateResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method GetFlushState not implemented") } func (*UnimplementedDataCoordServer) DropVirtualChannel(ctx context.Context, req *DropVirtualChannelRequest) (*DropVirtualChannelResponse, error) { @@ -6467,7 +6794,7 @@ func _DataCoord_WatchChannels_Handler(srv interface{}, ctx context.Context, dec } func _DataCoord_GetFlushState_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, 
error) { - in := new(milvuspb.GetFlushStateRequest) + in := new(GetFlushStateRequest) if err := dec(in); err != nil { return nil, err } @@ -6479,7 +6806,7 @@ func _DataCoord_GetFlushState_Handler(srv interface{}, ctx context.Context, dec FullMethod: "/milvus.proto.data.DataCoord/GetFlushState", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(DataCoordServer).GetFlushState(ctx, req.(*milvuspb.GetFlushStateRequest)) + return srv.(DataCoordServer).GetFlushState(ctx, req.(*GetFlushStateRequest)) } return interceptor(ctx, in, info, handler) } @@ -7045,8 +7372,12 @@ type DataNodeClient interface { SyncSegments(ctx context.Context, in *SyncSegmentsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) // https://wiki.lfaidata.foundation/display/MIL/MEP+24+--+Support+bulk+load Import(ctx context.Context, in *ImportTaskRequest, opts ...grpc.CallOption) (*commonpb.Status, error) + // Deprecated ResendSegmentStats(ctx context.Context, in *ResendSegmentStatsRequest, opts ...grpc.CallOption) (*ResendSegmentStatsResponse, error) AddImportSegment(ctx context.Context, in *AddImportSegmentRequest, opts ...grpc.CallOption) (*AddImportSegmentResponse, error) + FlushChannels(ctx context.Context, in *FlushChannelsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) + NotifyChannelOperation(ctx context.Context, in *ChannelOperationsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) + CheckChannelOperationProgress(ctx context.Context, in *ChannelWatchInfo, opts ...grpc.CallOption) (*ChannelOperationProgressResponse, error) } type dataNodeClient struct { @@ -7165,6 +7496,33 @@ func (c *dataNodeClient) AddImportSegment(ctx context.Context, in *AddImportSegm return out, nil } +func (c *dataNodeClient) FlushChannels(ctx context.Context, in *FlushChannelsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + out := new(commonpb.Status) + err := c.cc.Invoke(ctx, "/milvus.proto.data.DataNode/FlushChannels", in, out, 
opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *dataNodeClient) NotifyChannelOperation(ctx context.Context, in *ChannelOperationsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + out := new(commonpb.Status) + err := c.cc.Invoke(ctx, "/milvus.proto.data.DataNode/NotifyChannelOperation", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *dataNodeClient) CheckChannelOperationProgress(ctx context.Context, in *ChannelWatchInfo, opts ...grpc.CallOption) (*ChannelOperationProgressResponse, error) { + out := new(ChannelOperationProgressResponse) + err := c.cc.Invoke(ctx, "/milvus.proto.data.DataNode/CheckChannelOperationProgress", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // DataNodeServer is the server API for DataNode service. type DataNodeServer interface { GetComponentStates(context.Context, *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) @@ -7179,8 +7537,12 @@ type DataNodeServer interface { SyncSegments(context.Context, *SyncSegmentsRequest) (*commonpb.Status, error) // https://wiki.lfaidata.foundation/display/MIL/MEP+24+--+Support+bulk+load Import(context.Context, *ImportTaskRequest) (*commonpb.Status, error) + // Deprecated ResendSegmentStats(context.Context, *ResendSegmentStatsRequest) (*ResendSegmentStatsResponse, error) AddImportSegment(context.Context, *AddImportSegmentRequest) (*AddImportSegmentResponse, error) + FlushChannels(context.Context, *FlushChannelsRequest) (*commonpb.Status, error) + NotifyChannelOperation(context.Context, *ChannelOperationsRequest) (*commonpb.Status, error) + CheckChannelOperationProgress(context.Context, *ChannelWatchInfo) (*ChannelOperationProgressResponse, error) } // UnimplementedDataNodeServer can be embedded to have forward compatible implementations. 
@@ -7223,6 +7585,15 @@ func (*UnimplementedDataNodeServer) ResendSegmentStats(ctx context.Context, req func (*UnimplementedDataNodeServer) AddImportSegment(ctx context.Context, req *AddImportSegmentRequest) (*AddImportSegmentResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method AddImportSegment not implemented") } +func (*UnimplementedDataNodeServer) FlushChannels(ctx context.Context, req *FlushChannelsRequest) (*commonpb.Status, error) { + return nil, status.Errorf(codes.Unimplemented, "method FlushChannels not implemented") +} +func (*UnimplementedDataNodeServer) NotifyChannelOperation(ctx context.Context, req *ChannelOperationsRequest) (*commonpb.Status, error) { + return nil, status.Errorf(codes.Unimplemented, "method NotifyChannelOperation not implemented") +} +func (*UnimplementedDataNodeServer) CheckChannelOperationProgress(ctx context.Context, req *ChannelWatchInfo) (*ChannelOperationProgressResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method CheckChannelOperationProgress not implemented") +} func RegisterDataNodeServer(s *grpc.Server, srv DataNodeServer) { s.RegisterService(&_DataNode_serviceDesc, srv) @@ -7444,6 +7815,60 @@ func _DataNode_AddImportSegment_Handler(srv interface{}, ctx context.Context, de return interceptor(ctx, in, info, handler) } +func _DataNode_FlushChannels_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(FlushChannelsRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DataNodeServer).FlushChannels(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/milvus.proto.data.DataNode/FlushChannels", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DataNodeServer).FlushChannels(ctx, req.(*FlushChannelsRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func 
_DataNode_NotifyChannelOperation_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ChannelOperationsRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DataNodeServer).NotifyChannelOperation(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/milvus.proto.data.DataNode/NotifyChannelOperation", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DataNodeServer).NotifyChannelOperation(ctx, req.(*ChannelOperationsRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _DataNode_CheckChannelOperationProgress_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ChannelWatchInfo) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DataNodeServer).CheckChannelOperationProgress(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/milvus.proto.data.DataNode/CheckChannelOperationProgress", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DataNodeServer).CheckChannelOperationProgress(ctx, req.(*ChannelWatchInfo)) + } + return interceptor(ctx, in, info, handler) +} + var _DataNode_serviceDesc = grpc.ServiceDesc{ ServiceName: "milvus.proto.data.DataNode", HandlerType: (*DataNodeServer)(nil), @@ -7496,6 +7921,18 @@ var _DataNode_serviceDesc = grpc.ServiceDesc{ MethodName: "AddImportSegment", Handler: _DataNode_AddImportSegment_Handler, }, + { + MethodName: "FlushChannels", + Handler: _DataNode_FlushChannels_Handler, + }, + { + MethodName: "NotifyChannelOperation", + Handler: _DataNode_NotifyChannelOperation_Handler, + }, + { + MethodName: "CheckChannelOperationProgress", + Handler: _DataNode_CheckChannelOperationProgress_Handler, + }, 
}, Streams: []grpc.StreamDesc{}, Metadata: "data_coord.proto", diff --git a/internal/proto/index_coord.proto b/internal/proto/index_coord.proto index e8d001f43a22c..500b8477b6678 100644 --- a/internal/proto/index_coord.proto +++ b/internal/proto/index_coord.proto @@ -102,21 +102,22 @@ message FieldIndex { } message SegmentIndex { - int64 collectionID = 1; - int64 partitionID = 2; - int64 segmentID = 3; - int64 num_rows = 4; - int64 indexID = 5; - int64 buildID = 6; - int64 nodeID = 7; - int64 index_version = 8; - common.IndexState state = 9; - string fail_reason = 10; - repeated string index_file_keys = 11; - bool deleted = 12; - uint64 create_time = 13; - uint64 serialize_size = 14; - bool write_handoff = 15; + int64 collectionID = 1; + int64 partitionID = 2; + int64 segmentID = 3; + int64 num_rows = 4; + int64 indexID = 5; + int64 buildID = 6; + int64 nodeID = 7; + int64 index_version = 8; + common.IndexState state = 9; + string fail_reason = 10; + repeated string index_file_keys = 11; + bool deleted = 12; + uint64 create_time = 13; + uint64 serialize_size = 14; + bool write_handoff = 15; + int32 current_index_version = 16; } message RegisterNodeRequest { @@ -176,16 +177,17 @@ message GetIndexInfoRequest { } message IndexFilePathInfo { - int64 segmentID = 1; - int64 fieldID = 2; - int64 indexID = 3; - int64 buildID = 4; - string index_name = 5; - repeated common.KeyValuePair index_params = 6; - repeated string index_file_paths = 7; - uint64 serialized_size = 8; - int64 index_version = 9; - int64 num_rows = 10; + int64 segmentID = 1; + int64 fieldID = 2; + int64 indexID = 3; + int64 buildID = 4; + string index_name = 5; + repeated common.KeyValuePair index_params = 6; + repeated string index_file_paths = 7; + uint64 serialized_size = 8; + int64 index_version = 9; + int64 num_rows = 10; + int32 current_index_version = 11; } message SegmentInfo { @@ -242,6 +244,8 @@ message StorageConfig { string storage_type = 9; bool use_virtual_host = 10; string region = 11; + 
string cloud_provider = 12; + int64 request_timeout_ms = 13; } message CreateJobRequest { @@ -256,29 +260,32 @@ message CreateJobRequest { repeated common.KeyValuePair index_params = 9; repeated common.KeyValuePair type_params = 10; int64 num_rows = 11; - int64 collectionID = 12; - int64 partitionID = 13; - int64 segmentID = 14; - int64 fieldID = 15; - string field_name = 16; - schema.DataType field_type = 17; - string store_path = 18; - int64 store_version = 19; - string index_store_path = 20; - int64 dim = 21; + int32 current_index_version = 12; + int64 collectionID = 13; + int64 partitionID = 14; + int64 segmentID = 15; + int64 fieldID = 16; + string field_name = 17; + schema.DataType field_type = 18; + string store_path = 19; + int64 store_version = 20; + string index_store_path = 21; + int64 dim = 22; } + message QueryJobsRequest { string clusterID = 1; repeated int64 buildIDs = 2; } message IndexTaskInfo { - int64 buildID = 1; - common.IndexState state = 2; - repeated string index_file_keys = 3; - uint64 serialized_size = 4; - string fail_reason = 5; + int64 buildID = 1; + common.IndexState state = 2; + repeated string index_file_keys = 3; + uint64 serialized_size = 4; + string fail_reason = 5; + int32 current_index_version = 6; } message QueryJobsResponse { diff --git a/internal/proto/indexpb/index_coord.pb.go b/internal/proto/indexpb/index_coord.pb.go index 22866766bf975..18c28895f4b1c 100644 --- a/internal/proto/indexpb/index_coord.pb.go +++ b/internal/proto/indexpb/index_coord.pb.go @@ -237,6 +237,7 @@ type SegmentIndex struct { CreateTime uint64 `protobuf:"varint,13,opt,name=create_time,json=createTime,proto3" json:"create_time,omitempty"` SerializeSize uint64 `protobuf:"varint,14,opt,name=serialize_size,json=serializeSize,proto3" json:"serialize_size,omitempty"` WriteHandoff bool `protobuf:"varint,15,opt,name=write_handoff,json=writeHandoff,proto3" json:"write_handoff,omitempty"` + CurrentIndexVersion int32 
`protobuf:"varint,16,opt,name=current_index_version,json=currentIndexVersion,proto3" json:"current_index_version,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` @@ -372,6 +373,13 @@ func (m *SegmentIndex) GetWriteHandoff() bool { return false } +func (m *SegmentIndex) GetCurrentIndexVersion() int32 { + if m != nil { + return m.CurrentIndexVersion + } + return 0 +} + type RegisterNodeRequest struct { Base *commonpb.MsgBase `protobuf:"bytes,1,opt,name=base,proto3" json:"base,omitempty"` Address *commonpb.Address `protobuf:"bytes,2,opt,name=address,proto3" json:"address,omitempty"` @@ -894,6 +902,7 @@ type IndexFilePathInfo struct { SerializedSize uint64 `protobuf:"varint,8,opt,name=serialized_size,json=serializedSize,proto3" json:"serialized_size,omitempty"` IndexVersion int64 `protobuf:"varint,9,opt,name=index_version,json=indexVersion,proto3" json:"index_version,omitempty"` NumRows int64 `protobuf:"varint,10,opt,name=num_rows,json=numRows,proto3" json:"num_rows,omitempty"` + CurrentIndexVersion int32 `protobuf:"varint,11,opt,name=current_index_version,json=currentIndexVersion,proto3" json:"current_index_version,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` @@ -994,6 +1003,13 @@ func (m *IndexFilePathInfo) GetNumRows() int64 { return 0 } +func (m *IndexFilePathInfo) GetCurrentIndexVersion() int32 { + if m != nil { + return m.CurrentIndexVersion + } + return 0 +} + type SegmentInfo struct { CollectionID int64 `protobuf:"varint,1,opt,name=collectionID,proto3" json:"collectionID,omitempty"` SegmentID int64 `protobuf:"varint,2,opt,name=segmentID,proto3" json:"segmentID,omitempty"` @@ -1391,6 +1407,8 @@ type StorageConfig struct { StorageType string `protobuf:"bytes,9,opt,name=storage_type,json=storageType,proto3" json:"storage_type,omitempty"` UseVirtualHost bool 
`protobuf:"varint,10,opt,name=use_virtual_host,json=useVirtualHost,proto3" json:"use_virtual_host,omitempty"` Region string `protobuf:"bytes,11,opt,name=region,proto3" json:"region,omitempty"` + CloudProvider string `protobuf:"bytes,12,opt,name=cloud_provider,json=cloudProvider,proto3" json:"cloud_provider,omitempty"` + RequestTimeoutMs int64 `protobuf:"varint,13,opt,name=request_timeout_ms,json=requestTimeoutMs,proto3" json:"request_timeout_ms,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` @@ -1498,6 +1516,20 @@ func (m *StorageConfig) GetRegion() string { return "" } +func (m *StorageConfig) GetCloudProvider() string { + if m != nil { + return m.CloudProvider + } + return "" +} + +func (m *StorageConfig) GetRequestTimeoutMs() int64 { + if m != nil { + return m.RequestTimeoutMs + } + return 0 +} + type CreateJobRequest struct { ClusterID string `protobuf:"bytes,1,opt,name=clusterID,proto3" json:"clusterID,omitempty"` IndexFilePrefix string `protobuf:"bytes,2,opt,name=index_file_prefix,json=indexFilePrefix,proto3" json:"index_file_prefix,omitempty"` @@ -1510,16 +1542,17 @@ type CreateJobRequest struct { IndexParams []*commonpb.KeyValuePair `protobuf:"bytes,9,rep,name=index_params,json=indexParams,proto3" json:"index_params,omitempty"` TypeParams []*commonpb.KeyValuePair `protobuf:"bytes,10,rep,name=type_params,json=typeParams,proto3" json:"type_params,omitempty"` NumRows int64 `protobuf:"varint,11,opt,name=num_rows,json=numRows,proto3" json:"num_rows,omitempty"` - CollectionID int64 `protobuf:"varint,12,opt,name=collectionID,proto3" json:"collectionID,omitempty"` - PartitionID int64 `protobuf:"varint,13,opt,name=partitionID,proto3" json:"partitionID,omitempty"` - SegmentID int64 `protobuf:"varint,14,opt,name=segmentID,proto3" json:"segmentID,omitempty"` - FieldID int64 `protobuf:"varint,15,opt,name=fieldID,proto3" json:"fieldID,omitempty"` - FieldName string 
`protobuf:"bytes,16,opt,name=field_name,json=fieldName,proto3" json:"field_name,omitempty"` - FieldType schemapb.DataType `protobuf:"varint,17,opt,name=field_type,json=fieldType,proto3,enum=milvus.proto.schema.DataType" json:"field_type,omitempty"` - StorePath string `protobuf:"bytes,18,opt,name=store_path,json=storePath,proto3" json:"store_path,omitempty"` - StoreVersion int64 `protobuf:"varint,19,opt,name=store_version,json=storeVersion,proto3" json:"store_version,omitempty"` - IndexStorePath string `protobuf:"bytes,20,opt,name=index_store_path,json=indexStorePath,proto3" json:"index_store_path,omitempty"` - Dim int64 `protobuf:"varint,21,opt,name=dim,proto3" json:"dim,omitempty"` + CurrentIndexVersion int32 `protobuf:"varint,12,opt,name=current_index_version,json=currentIndexVersion,proto3" json:"current_index_version,omitempty"` + CollectionID int64 `protobuf:"varint,13,opt,name=collectionID,proto3" json:"collectionID,omitempty"` + PartitionID int64 `protobuf:"varint,14,opt,name=partitionID,proto3" json:"partitionID,omitempty"` + SegmentID int64 `protobuf:"varint,15,opt,name=segmentID,proto3" json:"segmentID,omitempty"` + FieldID int64 `protobuf:"varint,16,opt,name=fieldID,proto3" json:"fieldID,omitempty"` + FieldName string `protobuf:"bytes,17,opt,name=field_name,json=fieldName,proto3" json:"field_name,omitempty"` + FieldType schemapb.DataType `protobuf:"varint,18,opt,name=field_type,json=fieldType,proto3,enum=milvus.proto.schema.DataType" json:"field_type,omitempty"` + StorePath string `protobuf:"bytes,19,opt,name=store_path,json=storePath,proto3" json:"store_path,omitempty"` + StoreVersion int64 `protobuf:"varint,20,opt,name=store_version,json=storeVersion,proto3" json:"store_version,omitempty"` + IndexStorePath string `protobuf:"bytes,21,opt,name=index_store_path,json=indexStorePath,proto3" json:"index_store_path,omitempty"` + Dim int64 `protobuf:"varint,22,opt,name=dim,proto3" json:"dim,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized 
[]byte `json:"-"` XXX_sizecache int32 `json:"-"` @@ -1627,6 +1660,13 @@ func (m *CreateJobRequest) GetNumRows() int64 { return 0 } +func (m *CreateJobRequest) GetCurrentIndexVersion() int32 { + if m != nil { + return m.CurrentIndexVersion + } + return 0 +} + func (m *CreateJobRequest) GetCollectionID() int64 { if m != nil { return m.CollectionID @@ -1750,6 +1790,7 @@ type IndexTaskInfo struct { IndexFileKeys []string `protobuf:"bytes,3,rep,name=index_file_keys,json=indexFileKeys,proto3" json:"index_file_keys,omitempty"` SerializedSize uint64 `protobuf:"varint,4,opt,name=serialized_size,json=serializedSize,proto3" json:"serialized_size,omitempty"` FailReason string `protobuf:"bytes,5,opt,name=fail_reason,json=failReason,proto3" json:"fail_reason,omitempty"` + CurrentIndexVersion int32 `protobuf:"varint,6,opt,name=current_index_version,json=currentIndexVersion,proto3" json:"current_index_version,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` @@ -1815,6 +1856,13 @@ func (m *IndexTaskInfo) GetFailReason() string { return "" } +func (m *IndexTaskInfo) GetCurrentIndexVersion() int32 { + if m != nil { + return m.CurrentIndexVersion + } + return 0 +} + type QueryJobsResponse struct { Status *commonpb.Status `protobuf:"bytes,1,opt,name=status,proto3" json:"status,omitempty"` ClusterID string `protobuf:"bytes,2,opt,name=clusterID,proto3" json:"clusterID,omitempty"` @@ -2246,158 +2294,163 @@ func init() { func init() { proto.RegisterFile("index_coord.proto", fileDescriptor_f9e019eb3fda53c2) } var fileDescriptor_f9e019eb3fda53c2 = []byte{ - // 2403 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe4, 0x5a, 0xcd, 0x6f, 0x1c, 0x49, - 0x15, 0x4f, 0xcf, 0x8c, 0xed, 0xe9, 0xd7, 0x33, 0xfe, 0xa8, 0x38, 0x30, 0x99, 0x24, 0xc4, 0xe9, - 0x6c, 0x12, 0x83, 0x88, 0x13, 0xbc, 0x2c, 0x5a, 0xd0, 0x82, 0xe4, 0xd8, 0x9b, 0x64, 0x92, 0x4d, - 0x64, 0xda, 0x51, 0x24, 0x56, 
0x88, 0xa6, 0x67, 0xba, 0xc6, 0xae, 0x75, 0x4f, 0xd7, 0xa4, 0xab, - 0x3a, 0x89, 0x83, 0x84, 0xe0, 0xb0, 0x07, 0xd0, 0x4a, 0x08, 0x84, 0xc4, 0x3f, 0xc0, 0x69, 0xff, - 0x04, 0x24, 0xc4, 0x85, 0x23, 0x27, 0xee, 0xfc, 0x09, 0x9c, 0xb9, 0xa2, 0xfa, 0xe8, 0x9e, 0xee, - 0x9e, 0x1e, 0xcf, 0xf8, 0x03, 0x21, 0xc1, 0x6d, 0xea, 0xd5, 0xab, 0x8f, 0x7e, 0xef, 0xf7, 0xde, - 0xfb, 0xbd, 0xb2, 0x61, 0x85, 0x84, 0x3e, 0x7e, 0xeb, 0xf6, 0x28, 0x8d, 0xfc, 0x8d, 0x61, 0x44, - 0x39, 0x45, 0x68, 0x40, 0x82, 0xd7, 0x31, 0x53, 0xa3, 0x0d, 0x39, 0xdf, 0x6e, 0xf4, 0xe8, 0x60, - 0x40, 0x43, 0x25, 0x6b, 0x2f, 0x92, 0x90, 0xe3, 0x28, 0xf4, 0x02, 0x3d, 0x6e, 0x64, 0x57, 0xb4, - 0x1b, 0xac, 0x77, 0x80, 0x07, 0x9e, 0x1a, 0xd9, 0xff, 0xa8, 0x81, 0xd9, 0x11, 0x7b, 0x74, 0xc2, - 0x3e, 0x45, 0x36, 0x34, 0x7a, 0x34, 0x08, 0x70, 0x8f, 0x13, 0x1a, 0x76, 0x76, 0x5a, 0xc6, 0x9a, - 0xb1, 0x5e, 0x75, 0x72, 0x32, 0xd4, 0x82, 0x85, 0x3e, 0xc1, 0x81, 0xdf, 0xd9, 0x69, 0x55, 0xe4, - 0x74, 0x32, 0x44, 0xd7, 0x00, 0xd4, 0x75, 0x43, 0x6f, 0x80, 0x5b, 0xd5, 0x35, 0x63, 0xdd, 0x74, - 0x4c, 0x29, 0x79, 0xee, 0x0d, 0xb0, 0x58, 0x28, 0x07, 0x9d, 0x9d, 0x56, 0x4d, 0x2d, 0xd4, 0x43, - 0xf4, 0x00, 0x2c, 0x7e, 0x34, 0xc4, 0xee, 0xd0, 0x8b, 0xbc, 0x01, 0x6b, 0xcd, 0xad, 0x55, 0xd7, - 0xad, 0xcd, 0x1b, 0x1b, 0xb9, 0x0f, 0xd5, 0x5f, 0xf8, 0x14, 0x1f, 0xbd, 0xf4, 0x82, 0x18, 0xef, - 0x7a, 0x24, 0x72, 0x40, 0xac, 0xda, 0x95, 0x8b, 0xd0, 0x0e, 0x34, 0xd4, 0xe1, 0x7a, 0x93, 0xf9, - 0x59, 0x37, 0xb1, 0xe4, 0x32, 0xbd, 0xcb, 0x0d, 0xbd, 0x0b, 0xf6, 0xdd, 0x88, 0xbe, 0x61, 0xad, - 0x05, 0x79, 0x51, 0x4b, 0xcb, 0x1c, 0xfa, 0x86, 0x89, 0xaf, 0xe4, 0x94, 0x7b, 0x81, 0x52, 0xa8, - 0x4b, 0x05, 0x53, 0x4a, 0xe4, 0xf4, 0x07, 0x30, 0xc7, 0xb8, 0xc7, 0x71, 0xcb, 0x5c, 0x33, 0xd6, - 0x17, 0x37, 0xaf, 0x97, 0x5e, 0x40, 0x5a, 0x7c, 0x4f, 0xa8, 0x39, 0x4a, 0x1b, 0x7d, 0x00, 0x5f, - 0x55, 0xd7, 0x97, 0x43, 0xb7, 0xef, 0x91, 0xc0, 0x8d, 0xb0, 0xc7, 0x68, 0xd8, 0x02, 0x69, 0xc8, - 0x55, 0x92, 0xae, 0x79, 0xe8, 0x91, 0xc0, 0x91, 0x73, 0xc8, 0x86, 
0x26, 0x61, 0xae, 0x17, 0x73, - 0xea, 0xca, 0xf9, 0x96, 0xb5, 0x66, 0xac, 0xd7, 0x1d, 0x8b, 0xb0, 0xad, 0x98, 0x53, 0x79, 0x0c, - 0x7a, 0x06, 0x2b, 0x31, 0xc3, 0x91, 0x9b, 0x33, 0x4f, 0x63, 0x56, 0xf3, 0x2c, 0x89, 0xb5, 0x9d, - 0x8c, 0x89, 0xbe, 0x09, 0x68, 0x88, 0x43, 0x9f, 0x84, 0xfb, 0x7a, 0x47, 0x69, 0x87, 0xa6, 0xb4, - 0xc3, 0xb2, 0x9e, 0x91, 0xfa, 0xc2, 0x1c, 0xf6, 0xe7, 0x06, 0xc0, 0x43, 0x89, 0x0f, 0x79, 0x97, - 0x8f, 0x12, 0x88, 0x90, 0xb0, 0x4f, 0x25, 0xbc, 0xac, 0xcd, 0x6b, 0x1b, 0xe3, 0x88, 0xde, 0x48, - 0x31, 0xa9, 0x11, 0x24, 0xe1, 0xd9, 0x82, 0x05, 0x1f, 0x07, 0x98, 0x63, 0x5f, 0x42, 0xaf, 0xee, - 0x24, 0x43, 0x74, 0x1d, 0xac, 0x5e, 0x84, 0x85, 0xe5, 0x38, 0xd1, 0xd8, 0xab, 0x39, 0xa0, 0x44, - 0x2f, 0xc8, 0x00, 0xdb, 0x9f, 0xd7, 0xa0, 0xb1, 0x87, 0xf7, 0x07, 0x38, 0xe4, 0xea, 0x26, 0xb3, - 0x40, 0x7d, 0x0d, 0xac, 0xa1, 0x17, 0x71, 0xa2, 0x55, 0x14, 0xdc, 0xb3, 0x22, 0x74, 0x15, 0x4c, - 0xa6, 0x77, 0xdd, 0x91, 0xa7, 0x56, 0x9d, 0x91, 0x00, 0x5d, 0x86, 0x7a, 0x18, 0x0f, 0x94, 0x81, - 0x34, 0xe4, 0xc3, 0x78, 0x20, 0x61, 0x92, 0x09, 0x86, 0xb9, 0x7c, 0x30, 0xb4, 0x60, 0xa1, 0x1b, - 0x13, 0x19, 0x5f, 0xf3, 0x6a, 0x46, 0x0f, 0xd1, 0x57, 0x60, 0x3e, 0xa4, 0x3e, 0xee, 0xec, 0x68, - 0x58, 0xea, 0x11, 0xba, 0x09, 0x4d, 0x65, 0xd4, 0xd7, 0x38, 0x62, 0x84, 0x86, 0x1a, 0x94, 0x0a, - 0xc9, 0x2f, 0x95, 0xec, 0xb4, 0xb8, 0xbc, 0x0e, 0xd6, 0x38, 0x16, 0xa1, 0x3f, 0x42, 0xe0, 0x6d, - 0x58, 0x52, 0x87, 0xf7, 0x49, 0x80, 0xdd, 0x43, 0x7c, 0xc4, 0x5a, 0xd6, 0x5a, 0x75, 0xdd, 0x74, - 0xd4, 0x9d, 0x1e, 0x92, 0x00, 0x3f, 0xc5, 0x47, 0x2c, 0xeb, 0xbb, 0xc6, 0xb1, 0xbe, 0x6b, 0x16, - 0x7d, 0x87, 0x6e, 0xc1, 0x22, 0xc3, 0x11, 0xf1, 0x02, 0xf2, 0x0e, 0xbb, 0x8c, 0xbc, 0xc3, 0xad, - 0x45, 0xa9, 0xd3, 0x4c, 0xa5, 0x7b, 0xe4, 0x1d, 0x16, 0x66, 0x78, 0x13, 0x11, 0x8e, 0xdd, 0x03, - 0x2f, 0xf4, 0x69, 0xbf, 0xdf, 0x5a, 0x92, 0xe7, 0x34, 0xa4, 0xf0, 0xb1, 0x92, 0xd9, 0x7f, 0x30, - 0xe0, 0xa2, 0x83, 0xf7, 0x09, 0xe3, 0x38, 0x7a, 0x4e, 0x7d, 0xec, 0xe0, 0x57, 0x31, 0x66, 0x1c, - 0xdd, 
0x87, 0x5a, 0xd7, 0x63, 0x58, 0x43, 0xf2, 0x6a, 0xa9, 0x75, 0x9e, 0xb1, 0xfd, 0x07, 0x1e, - 0xc3, 0x8e, 0xd4, 0x44, 0xdf, 0x81, 0x05, 0xcf, 0xf7, 0x23, 0xcc, 0x98, 0x04, 0xc6, 0xa4, 0x45, - 0x5b, 0x4a, 0xc7, 0x49, 0x94, 0x33, 0x5e, 0xac, 0x66, 0xbd, 0x68, 0xff, 0xc6, 0x80, 0xd5, 0xfc, - 0xcd, 0xd8, 0x90, 0x86, 0x0c, 0xa3, 0xf7, 0x61, 0x5e, 0xf8, 0x22, 0x66, 0xfa, 0x72, 0x57, 0x4a, - 0xcf, 0xd9, 0x93, 0x2a, 0x8e, 0x56, 0x15, 0x29, 0x95, 0x84, 0x84, 0x27, 0xe1, 0xae, 0x6e, 0x78, - 0xa3, 0x18, 0x69, 0xba, 0x4c, 0x74, 0x42, 0xc2, 0x55, 0x74, 0x3b, 0x40, 0xd2, 0xdf, 0xf6, 0x8f, - 0x60, 0xf5, 0x11, 0xe6, 0x19, 0x4c, 0x68, 0x5b, 0xcd, 0x12, 0x3a, 0xf9, 0x5a, 0x50, 0x29, 0xd4, - 0x02, 0xfb, 0x8f, 0x06, 0x5c, 0x2a, 0xec, 0x7d, 0x96, 0xaf, 0x4d, 0xc1, 0x5d, 0x39, 0x0b, 0xb8, - 0xab, 0x45, 0x70, 0xdb, 0xbf, 0x30, 0xe0, 0xca, 0x23, 0xcc, 0xb3, 0x89, 0xe3, 0x9c, 0x2d, 0x81, - 0xbe, 0x06, 0x90, 0x26, 0x0c, 0xd6, 0xaa, 0xae, 0x55, 0xd7, 0xab, 0x4e, 0x46, 0x62, 0xff, 0xca, - 0x80, 0x95, 0xb1, 0xf3, 0xf3, 0x79, 0xc7, 0x28, 0xe6, 0x9d, 0xff, 0x94, 0x39, 0x7e, 0x67, 0xc0, - 0xd5, 0x72, 0x73, 0x9c, 0xc5, 0x79, 0xdf, 0x57, 0x8b, 0xb0, 0x40, 0xa9, 0x28, 0x4a, 0xb7, 0xca, - 0xea, 0xc1, 0xf8, 0x99, 0x7a, 0x91, 0xfd, 0x45, 0x15, 0xd0, 0xb6, 0x4c, 0x16, 0xaa, 0xea, 0x9c, - 0xc0, 0x35, 0xa7, 0xa6, 0x32, 0x05, 0xc2, 0x52, 0x3b, 0x0f, 0xc2, 0x32, 0x77, 0x2a, 0xc2, 0x72, - 0x15, 0x4c, 0x91, 0x35, 0x19, 0xf7, 0x06, 0x43, 0x59, 0x2f, 0x6a, 0xce, 0x48, 0x30, 0x4e, 0x0f, - 0x16, 0x66, 0xa4, 0x07, 0xf5, 0xd3, 0xd2, 0x03, 0xfb, 0x2d, 0x5c, 0x4c, 0x02, 0x5b, 0x96, 0xef, - 0x13, 0xb8, 0x23, 0x1f, 0x0a, 0x95, 0x62, 0x28, 0x4c, 0x71, 0x8a, 0xfd, 0xaf, 0x0a, 0xac, 0x74, - 0x92, 0x9a, 0xb3, 0xeb, 0xf1, 0x03, 0xc9, 0x19, 0x8e, 0x8f, 0x94, 0xc9, 0x08, 0xc8, 0x14, 0xe8, - 0xea, 0xc4, 0x02, 0x5d, 0xcb, 0x17, 0xe8, 0xfc, 0x05, 0xe7, 0x8a, 0xa8, 0x39, 0x1f, 0x8a, 0xba, - 0x0e, 0xcb, 0x99, 0x82, 0x3b, 0xf4, 0xf8, 0x81, 0xa0, 0xa9, 0xa2, 0xe2, 0x2e, 0x92, 0xec, 0xd7, - 0x33, 0x74, 0x07, 0x96, 0xd2, 0x0a, 0xe9, 
0xab, 0xc2, 0x59, 0x97, 0x08, 0x19, 0x95, 0x53, 0x3f, - 0xa9, 0x9c, 0x79, 0x02, 0x61, 0x96, 0x10, 0x88, 0x2c, 0x99, 0x81, 0x1c, 0x99, 0xb1, 0xff, 0x64, - 0x80, 0x95, 0x06, 0xe8, 0x8c, 0x6d, 0x44, 0xce, 0x2f, 0x95, 0xa2, 0x5f, 0x6e, 0x40, 0x03, 0x87, - 0x5e, 0x37, 0xc0, 0x1a, 0xb7, 0x55, 0x85, 0x5b, 0x25, 0x53, 0xb8, 0x7d, 0x08, 0xd6, 0x88, 0x4a, - 0x26, 0x31, 0x78, 0x6b, 0x22, 0x97, 0xcc, 0x82, 0xc2, 0x81, 0x94, 0x53, 0x32, 0xfb, 0xd7, 0x95, - 0x51, 0x99, 0x53, 0x88, 0x3d, 0x4b, 0x32, 0xfb, 0x31, 0x34, 0xf4, 0x57, 0x28, 0x8a, 0xab, 0x52, - 0xda, 0x77, 0xcb, 0xae, 0x55, 0x76, 0xe8, 0x46, 0xc6, 0x8c, 0x1f, 0x87, 0x3c, 0x3a, 0x72, 0x2c, - 0x36, 0x92, 0xb4, 0x5d, 0x58, 0x2e, 0x2a, 0xa0, 0x65, 0xa8, 0x1e, 0xe2, 0x23, 0x6d, 0x63, 0xf1, - 0x53, 0xa4, 0xff, 0xd7, 0x02, 0x3b, 0xba, 0xea, 0x5f, 0x3f, 0x36, 0x9f, 0xf6, 0xa9, 0xa3, 0xb4, - 0xbf, 0x57, 0xf9, 0xd0, 0xb0, 0x7f, 0x6f, 0xc0, 0xf2, 0x4e, 0x44, 0x87, 0x27, 0x4e, 0xa5, 0x36, - 0x34, 0x32, 0xbc, 0x38, 0x89, 0xde, 0x9c, 0x6c, 0x5a, 0x52, 0xbd, 0x0c, 0x75, 0x3f, 0xa2, 0x43, - 0xd7, 0x0b, 0x02, 0x19, 0x58, 0x82, 0x22, 0x46, 0x74, 0xb8, 0x15, 0x04, 0xf6, 0x1b, 0x58, 0xdd, - 0xc1, 0xac, 0x17, 0x91, 0xee, 0xc9, 0x93, 0xfc, 0x94, 0xfa, 0x9b, 0x4b, 0xa0, 0xd5, 0x42, 0x02, - 0xb5, 0xbf, 0x30, 0xe0, 0x52, 0xe1, 0xe4, 0xb3, 0xa0, 0xe3, 0x07, 0x79, 0xcc, 0x2a, 0x70, 0x4c, - 0xe9, 0x7f, 0xb2, 0x58, 0xf5, 0x64, 0xfd, 0x95, 0x73, 0x0f, 0x44, 0xce, 0xd9, 0x8d, 0xe8, 0xbe, - 0x64, 0x97, 0xe7, 0xc7, 0xcc, 0xfe, 0x6a, 0xc0, 0xb5, 0x09, 0x67, 0x9c, 0xe5, 0xcb, 0x8b, 0x8d, - 0x75, 0x65, 0x5a, 0x63, 0x5d, 0x2d, 0x36, 0xd6, 0xe5, 0x7d, 0x67, 0x6d, 0x42, 0xdf, 0xf9, 0xcf, - 0x0a, 0x34, 0xf7, 0x38, 0x8d, 0xbc, 0x7d, 0xbc, 0x4d, 0xc3, 0x3e, 0xd9, 0x17, 0x69, 0x3b, 0xe1, - 0xeb, 0x86, 0xfc, 0xe8, 0x94, 0x91, 0xdf, 0x80, 0x86, 0xd7, 0xeb, 0x61, 0xc6, 0x44, 0xfb, 0xa2, - 0xb3, 0x91, 0xe9, 0x58, 0x4a, 0xf6, 0x54, 0x88, 0xd0, 0x37, 0x60, 0x85, 0xe1, 0x5e, 0x84, 0xb9, - 0x3b, 0xd2, 0xd4, 0x08, 0x5e, 0x52, 0x13, 0x5b, 0x89, 0xb6, 0x20, 0xf8, 0x31, 
0xc3, 0x7b, 0x7b, - 0x9f, 0x68, 0x14, 0xeb, 0x91, 0xa0, 0x57, 0xdd, 0xb8, 0x77, 0x88, 0x79, 0xb6, 0x3c, 0x80, 0x12, - 0x49, 0x28, 0x5e, 0x01, 0x33, 0xa2, 0x94, 0xcb, 0x9c, 0x2e, 0x6b, 0xb9, 0xe9, 0xd4, 0x85, 0x40, - 0xa4, 0x2d, 0xbd, 0x6b, 0x67, 0xeb, 0x99, 0xae, 0xe1, 0x7a, 0x24, 0x7a, 0xd4, 0xce, 0xd6, 0xb3, - 0x8f, 0x43, 0x7f, 0x48, 0x49, 0xc8, 0x65, 0x82, 0x37, 0x9d, 0xac, 0x48, 0x7c, 0x1e, 0x53, 0x96, - 0x70, 0x05, 0xfd, 0x90, 0xc9, 0xdd, 0x74, 0x2c, 0x2d, 0x7b, 0x71, 0x34, 0xc4, 0xa2, 0xa6, 0xc4, - 0x0c, 0xbb, 0xaf, 0x49, 0xc4, 0x63, 0x2f, 0x70, 0x0f, 0x28, 0xe3, 0x32, 0xc7, 0xd7, 0x9d, 0xc5, - 0x98, 0xe1, 0x97, 0x4a, 0xfc, 0x98, 0x32, 0x2e, 0xae, 0x11, 0xe1, 0x7d, 0x51, 0x23, 0x2c, 0xb9, - 0x8d, 0x1e, 0xd9, 0x7f, 0x9e, 0x87, 0x65, 0xc5, 0xc2, 0x9e, 0xd0, 0x6e, 0x02, 0xc7, 0xab, 0x60, - 0xf6, 0x82, 0x58, 0x34, 0x34, 0x1a, 0x8b, 0xa6, 0x33, 0x12, 0x08, 0x9b, 0x66, 0x0b, 0x59, 0x84, - 0xfb, 0xe4, 0xad, 0xb6, 0xfd, 0xd2, 0xa8, 0x92, 0x49, 0x71, 0xb6, 0xe6, 0x56, 0xc7, 0x6a, 0xae, - 0xef, 0x71, 0x4f, 0x17, 0xc2, 0x9a, 0x2c, 0x84, 0xa6, 0x90, 0xa8, 0x1a, 0x38, 0x56, 0xda, 0xe6, - 0x4a, 0x4a, 0x5b, 0xa6, 0xd6, 0xcf, 0xe7, 0x6b, 0x7d, 0x3e, 0x58, 0x16, 0x8a, 0xc9, 0xe3, 0x31, - 0x2c, 0x26, 0xa6, 0xed, 0x49, 0x94, 0x49, 0xfb, 0x97, 0x34, 0x5a, 0x32, 0xe5, 0x66, 0xe1, 0xe8, - 0x34, 0x59, 0x0e, 0x9d, 0x45, 0x6e, 0x60, 0x9e, 0x8a, 0x1b, 0x14, 0x78, 0x29, 0x9c, 0x86, 0x97, - 0x66, 0xeb, 0xbc, 0x95, 0x7f, 0xb4, 0x28, 0xa6, 0x97, 0xc6, 0xf4, 0x37, 0x93, 0xe6, 0x94, 0x37, - 0x93, 0xc5, 0x63, 0x18, 0xd9, 0xd2, 0x18, 0x27, 0x97, 0x3f, 0x95, 0x2f, 0x96, 0x95, 0x2f, 0xa4, - 0x44, 0xfa, 0xe2, 0xa3, 0x64, 0x5a, 0x82, 0x7c, 0x45, 0x76, 0x3e, 0x85, 0xd4, 0xaa, 0x5f, 0x3e, - 0x77, 0x3c, 0xee, 0x09, 0xd8, 0xeb, 0xd5, 0x32, 0x02, 0xae, 0x01, 0x08, 0x87, 0x28, 0x42, 0xd5, - 0x42, 0x6a, 0x73, 0x29, 0x91, 0xd1, 0x77, 0x13, 0x9a, 0x6a, 0x3a, 0x81, 0xd1, 0x45, 0xf5, 0xe9, - 0x52, 0x98, 0xc0, 0x28, 0x65, 0x66, 0x99, 0x9d, 0x56, 0xe5, 0x4e, 0x8b, 0xfa, 0xf1, 0x2e, 0xd9, - 0x6e, 0x19, 0xaa, 
0x3e, 0x19, 0xb4, 0x2e, 0xa9, 0x9a, 0xed, 0x93, 0x81, 0xfd, 0x09, 0x2c, 0xff, - 0x30, 0xc6, 0xd1, 0xd1, 0x13, 0xda, 0x65, 0xb3, 0x85, 0x4f, 0x1b, 0xea, 0x3a, 0x06, 0x92, 0x6a, - 0x9b, 0x8e, 0xed, 0xbf, 0x1b, 0xd0, 0x94, 0xb9, 0xf0, 0x85, 0xc7, 0x0e, 0x93, 0xa7, 0xb3, 0x24, - 0x80, 0x8c, 0x7c, 0x00, 0x9d, 0xb2, 0x59, 0x2c, 0x79, 0xf7, 0xa9, 0x96, 0xbd, 0xfb, 0x94, 0x90, - 0xd0, 0x5a, 0x29, 0x09, 0x2d, 0x74, 0x9f, 0x73, 0x63, 0xdd, 0xe7, 0x97, 0x06, 0xac, 0x64, 0x6c, - 0x74, 0x96, 0x6a, 0x94, 0xb3, 0x6c, 0xa5, 0x68, 0xd9, 0x07, 0xf9, 0x2a, 0x5d, 0x2d, 0x8b, 0xa2, - 0x4c, 0x95, 0x4e, 0x6c, 0x9c, 0xab, 0xd4, 0x4f, 0x61, 0x49, 0xf0, 0xa8, 0xf3, 0x71, 0xe7, 0xdf, - 0x0c, 0x58, 0x78, 0x42, 0xbb, 0xd2, 0x91, 0xd9, 0xf0, 0x34, 0xf2, 0xe1, 0xa9, 0x51, 0x55, 0x49, - 0x51, 0xa5, 0x50, 0xed, 0x45, 0x7c, 0xf4, 0x2a, 0x2a, 0x62, 0x4d, 0x48, 0xe4, 0xc3, 0xda, 0x65, - 0xa8, 0xe3, 0xd0, 0x57, 0x93, 0xba, 0x95, 0xc1, 0xa1, 0x2f, 0xa7, 0xce, 0xa7, 0x3b, 0x5d, 0x85, - 0xb9, 0x21, 0x1d, 0xbd, 0x64, 0xaa, 0x81, 0xbd, 0x0a, 0xe8, 0x11, 0xe6, 0x4f, 0x68, 0x57, 0x78, - 0x25, 0x31, 0x8f, 0xfd, 0x97, 0x8a, 0xec, 0x1c, 0x47, 0xe2, 0xb3, 0x38, 0xd8, 0x86, 0xa6, 0xe2, - 0x12, 0x9f, 0xd1, 0xae, 0x1b, 0xc6, 0x89, 0x51, 0x2c, 0x29, 0x7c, 0x42, 0xbb, 0xcf, 0xe3, 0x01, - 0xba, 0x0b, 0x17, 0x49, 0xe8, 0x0e, 0x35, 0xbd, 0x49, 0x35, 0x95, 0x95, 0x96, 0x49, 0x98, 0x10, - 0x1f, 0xad, 0x7e, 0x1b, 0x96, 0x70, 0xf8, 0x2a, 0xc6, 0x31, 0x4e, 0x55, 0x95, 0xcd, 0x9a, 0x5a, - 0xac, 0xf5, 0x04, 0x8d, 0xf1, 0xd8, 0xa1, 0xcb, 0x02, 0xca, 0x99, 0x2e, 0x37, 0xa6, 0x90, 0xec, - 0x09, 0x01, 0xfa, 0x10, 0x4c, 0xb1, 0x5c, 0x41, 0x4b, 0x75, 0x80, 0x57, 0xca, 0xa0, 0xa5, 0xfd, - 0xed, 0xd4, 0x3f, 0x53, 0x3f, 0x98, 0x08, 0x10, 0xdd, 0x13, 0xf9, 0x84, 0x1d, 0x6a, 0x1a, 0x00, - 0x4a, 0xb4, 0x43, 0xd8, 0xa1, 0xfd, 0x13, 0xb8, 0x9c, 0x7d, 0x53, 0x23, 0x8c, 0x93, 0xde, 0x79, - 0x52, 0xc3, 0xdf, 0x1a, 0xd0, 0x2e, 0x3b, 0xe0, 0xbf, 0xc8, 0x88, 0x37, 0x7f, 0x69, 0x01, 0xc8, - 0x99, 0x6d, 0x4a, 0x23, 0x1f, 0x05, 0x12, 0x5a, 0xdb, 
0x74, 0x30, 0xa4, 0x21, 0x0e, 0xb9, 0xcc, - 0x58, 0x0c, 0x6d, 0xe4, 0xf7, 0xd3, 0x83, 0x71, 0x45, 0x6d, 0xab, 0xf6, 0x7b, 0xa5, 0xfa, 0x05, - 0x65, 0xfb, 0x02, 0x7a, 0x25, 0x3b, 0xc7, 0x91, 0x29, 0xb6, 0x0f, 0xbc, 0x30, 0xc4, 0x01, 0xda, - 0x9c, 0xf0, 0xce, 0x5a, 0xa6, 0x9c, 0x9c, 0x79, 0xb3, 0xf4, 0xcc, 0x3d, 0x1e, 0x91, 0x70, 0x3f, - 0x31, 0xb1, 0x7d, 0x01, 0xbd, 0x00, 0x2b, 0xf3, 0xd8, 0x85, 0x6e, 0x97, 0x59, 0x6a, 0xfc, 0x35, - 0xac, 0x7d, 0x9c, 0x2f, 0xec, 0x0b, 0xa8, 0x0f, 0xcd, 0xdc, 0x6b, 0x2c, 0x5a, 0x3f, 0xae, 0x61, - 0xcd, 0x3e, 0x81, 0xb6, 0xbf, 0x3e, 0x83, 0x66, 0x7a, 0xfb, 0x9f, 0x29, 0x83, 0x8d, 0x3d, 0x67, - 0xde, 0x9b, 0xb0, 0xc9, 0xa4, 0x87, 0xd7, 0xf6, 0xfd, 0xd9, 0x17, 0xa4, 0x87, 0xfb, 0xa3, 0x8f, - 0x54, 0x01, 0x75, 0x67, 0x7a, 0x57, 0xae, 0x4e, 0x5b, 0x9f, 0xb5, 0x7d, 0xb7, 0x2f, 0xa0, 0x5d, - 0x30, 0xd3, 0x06, 0x1a, 0xbd, 0x57, 0xb6, 0xb0, 0xd8, 0x5f, 0xcf, 0xe0, 0x9c, 0x5c, 0x0b, 0x5a, - 0xee, 0x9c, 0xb2, 0xfe, 0xb8, 0xdc, 0x39, 0xa5, 0xfd, 0xac, 0x7d, 0x01, 0xc5, 0x32, 0x76, 0x0a, - 0xd1, 0x8d, 0xee, 0x4e, 0xf3, 0x6f, 0x2e, 0xcd, 0xb4, 0x37, 0x66, 0x55, 0x4f, 0x8f, 0xfd, 0xf9, - 0xe8, 0x2f, 0x01, 0xb9, 0x7e, 0x13, 0xdd, 0x3f, 0x6e, 0xab, 0xb2, 0xf6, 0xb7, 0xfd, 0xad, 0x13, - 0xac, 0xc8, 0x60, 0x12, 0xed, 0x1d, 0xd0, 0x37, 0x8a, 0x87, 0xc7, 0x91, 0x27, 0x72, 0x61, 0xc9, - 0xe1, 0x3a, 0x84, 0xc7, 0x55, 0x27, 0x1e, 0x7e, 0xcc, 0x8a, 0xf4, 0x70, 0x17, 0xe0, 0x11, 0xe6, - 0xcf, 0x30, 0x8f, 0x84, 0xad, 0x6f, 0x4f, 0xca, 0x53, 0x5a, 0x21, 0x39, 0xea, 0xce, 0x54, 0xbd, - 0xf4, 0x80, 0x2e, 0x58, 0xdb, 0x07, 0xb8, 0x77, 0xf8, 0x18, 0x7b, 0x01, 0x3f, 0x40, 0xe5, 0x2b, - 0x33, 0x1a, 0x13, 0x20, 0x5f, 0xa6, 0x98, 0x9c, 0xb1, 0xf9, 0xe5, 0xbc, 0xfe, 0x1f, 0x82, 0xe7, - 0xd4, 0xc7, 0xff, 0xfb, 0x29, 0x78, 0x17, 0xcc, 0xb4, 0xd3, 0x2d, 0x8f, 0xf0, 0x62, 0x23, 0x3c, - 0x2d, 0xc2, 0x3f, 0x05, 0x33, 0x25, 0xb6, 0xe5, 0x3b, 0x16, 0x7b, 0x83, 0xf6, 0xad, 0x29, 0x5a, - 0xe9, 0x6d, 0x9f, 0x43, 0x3d, 0x21, 0xa2, 0xe8, 0xe6, 0xa4, 0x74, 0x94, 0xdd, 0x79, 0xca, 
0x5d, - 0x7f, 0x0a, 0x56, 0x86, 0xa5, 0x95, 0x17, 0xa0, 0x71, 0x76, 0xd7, 0xbe, 0x33, 0x55, 0xef, 0xff, - 0x23, 0x20, 0x1f, 0x7c, 0xfb, 0xd3, 0xcd, 0x7d, 0xc2, 0x0f, 0xe2, 0xae, 0xb0, 0xec, 0x3d, 0xa5, - 0x79, 0x97, 0x50, 0xfd, 0xeb, 0x5e, 0x72, 0xcb, 0x7b, 0x72, 0xa7, 0x7b, 0xd2, 0x4e, 0xc3, 0x6e, - 0x77, 0x5e, 0x0e, 0xdf, 0xff, 0x77, 0x00, 0x00, 0x00, 0xff, 0xff, 0x20, 0xf8, 0x84, 0x75, 0x10, - 0x24, 0x00, 0x00, + // 2486 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe4, 0x5a, 0xdd, 0x6e, 0x1c, 0x49, + 0xf5, 0x4f, 0xcf, 0xf8, 0x63, 0xfa, 0xf4, 0x8c, 0x3f, 0x2a, 0xce, 0xfe, 0x27, 0x93, 0xe4, 0x1f, + 0xa7, 0xb3, 0x49, 0x06, 0x44, 0x9c, 0xe0, 0x65, 0xd1, 0x82, 0x16, 0x24, 0xc7, 0xde, 0x24, 0x93, + 0xac, 0x23, 0xd3, 0x8e, 0x22, 0xb1, 0x42, 0x34, 0x3d, 0xd3, 0x35, 0x76, 0xad, 0x7b, 0xba, 0x26, + 0x5d, 0xd5, 0x49, 0x1c, 0x24, 0x04, 0x17, 0x20, 0x81, 0x56, 0x42, 0x20, 0x24, 0x5e, 0x80, 0xab, + 0xe5, 0x09, 0xe0, 0x06, 0x09, 0x71, 0xc9, 0x2b, 0xf0, 0x10, 0xbc, 0x01, 0xaa, 0x8f, 0xee, 0xe9, + 0xee, 0xe9, 0xf1, 0x4c, 0x6c, 0x23, 0x24, 0xb8, 0x9b, 0x3a, 0x75, 0xea, 0xa3, 0xcf, 0xf9, 0x9d, + 0x73, 0x7e, 0xa7, 0x6c, 0x58, 0x25, 0xa1, 0x8f, 0xdf, 0xb8, 0x3d, 0x4a, 0x23, 0x7f, 0x63, 0x18, + 0x51, 0x4e, 0x11, 0x1a, 0x90, 0xe0, 0x55, 0xcc, 0xd4, 0x68, 0x43, 0xce, 0xb7, 0xea, 0x3d, 0x3a, + 0x18, 0xd0, 0x50, 0xc9, 0x5a, 0x4b, 0x24, 0xe4, 0x38, 0x0a, 0xbd, 0x40, 0x8f, 0xeb, 0xd9, 0x15, + 0xad, 0x3a, 0xeb, 0x1d, 0xe2, 0x81, 0xa7, 0x46, 0xf6, 0x3f, 0xe6, 0xc0, 0xec, 0x88, 0x3d, 0x3a, + 0x61, 0x9f, 0x22, 0x1b, 0xea, 0x3d, 0x1a, 0x04, 0xb8, 0xc7, 0x09, 0x0d, 0x3b, 0x3b, 0x4d, 0x63, + 0xdd, 0x68, 0x57, 0x9d, 0x9c, 0x0c, 0x35, 0x61, 0xb1, 0x4f, 0x70, 0xe0, 0x77, 0x76, 0x9a, 0x15, + 0x39, 0x9d, 0x0c, 0xd1, 0x35, 0x00, 0x75, 0xdd, 0xd0, 0x1b, 0xe0, 0x66, 0x75, 0xdd, 0x68, 0x9b, + 0x8e, 0x29, 0x25, 0xcf, 0xbc, 0x01, 0x16, 0x0b, 0xe5, 0xa0, 0xb3, 0xd3, 0x9c, 0x53, 0x0b, 0xf5, + 0x10, 0x3d, 0x00, 0x8b, 0x1f, 0x0f, 0xb1, 0x3b, 0xf4, 0x22, 
0x6f, 0xc0, 0x9a, 0xf3, 0xeb, 0xd5, + 0xb6, 0xb5, 0x79, 0x63, 0x23, 0xf7, 0xa1, 0xfa, 0x0b, 0x9f, 0xe2, 0xe3, 0x17, 0x5e, 0x10, 0xe3, + 0x3d, 0x8f, 0x44, 0x0e, 0x88, 0x55, 0x7b, 0x72, 0x11, 0xda, 0x81, 0xba, 0x3a, 0x5c, 0x6f, 0xb2, + 0x30, 0xeb, 0x26, 0x96, 0x5c, 0xa6, 0x77, 0xb9, 0xa1, 0x77, 0xc1, 0xbe, 0x1b, 0xd1, 0xd7, 0xac, + 0xb9, 0x28, 0x2f, 0x6a, 0x69, 0x99, 0x43, 0x5f, 0x33, 0xf1, 0x95, 0x9c, 0x72, 0x2f, 0x50, 0x0a, + 0x35, 0xa9, 0x60, 0x4a, 0x89, 0x9c, 0xfe, 0x10, 0xe6, 0x19, 0xf7, 0x38, 0x6e, 0x9a, 0xeb, 0x46, + 0x7b, 0x69, 0xf3, 0x7a, 0xe9, 0x05, 0xa4, 0xc5, 0xf7, 0x85, 0x9a, 0xa3, 0xb4, 0xd1, 0x87, 0xf0, + 0x7f, 0xea, 0xfa, 0x72, 0xe8, 0xf6, 0x3d, 0x12, 0xb8, 0x11, 0xf6, 0x18, 0x0d, 0x9b, 0x20, 0x0d, + 0xb9, 0x46, 0xd2, 0x35, 0x0f, 0x3d, 0x12, 0x38, 0x72, 0x0e, 0xd9, 0xd0, 0x20, 0xcc, 0xf5, 0x62, + 0x4e, 0x5d, 0x39, 0xdf, 0xb4, 0xd6, 0x8d, 0x76, 0xcd, 0xb1, 0x08, 0xdb, 0x8a, 0x39, 0x95, 0xc7, + 0xa0, 0x5d, 0x58, 0x8d, 0x19, 0x8e, 0xdc, 0x9c, 0x79, 0xea, 0xb3, 0x9a, 0x67, 0x59, 0xac, 0xed, + 0x64, 0x4c, 0xf4, 0x35, 0x40, 0x43, 0x1c, 0xfa, 0x24, 0x3c, 0xd0, 0x3b, 0x4a, 0x3b, 0x34, 0xa4, + 0x1d, 0x56, 0xf4, 0x8c, 0xd4, 0x17, 0xe6, 0xb0, 0x7f, 0x6e, 0x00, 0x3c, 0x94, 0xf8, 0x90, 0x77, + 0xf9, 0x38, 0x81, 0x08, 0x09, 0xfb, 0x54, 0xc2, 0xcb, 0xda, 0xbc, 0xb6, 0x31, 0x8e, 0xe8, 0x8d, + 0x14, 0x93, 0x1a, 0x41, 0x12, 0x9e, 0x4d, 0x58, 0xf4, 0x71, 0x80, 0x39, 0xf6, 0x25, 0xf4, 0x6a, + 0x4e, 0x32, 0x44, 0xd7, 0xc1, 0xea, 0x45, 0x58, 0x58, 0x8e, 0x13, 0x8d, 0xbd, 0x39, 0x07, 0x94, + 0xe8, 0x39, 0x19, 0x60, 0xfb, 0x4f, 0x73, 0x50, 0xdf, 0xc7, 0x07, 0x03, 0x1c, 0x72, 0x75, 0x93, + 0x59, 0xa0, 0xbe, 0x0e, 0xd6, 0xd0, 0x8b, 0x38, 0xd1, 0x2a, 0x0a, 0xee, 0x59, 0x11, 0xba, 0x0a, + 0x26, 0xd3, 0xbb, 0xee, 0xc8, 0x53, 0xab, 0xce, 0x48, 0x80, 0x2e, 0x43, 0x2d, 0x8c, 0x07, 0xca, + 0x40, 0x1a, 0xf2, 0x61, 0x3c, 0x90, 0x30, 0xc9, 0x04, 0xc3, 0x7c, 0x3e, 0x18, 0x9a, 0xb0, 0xd8, + 0x8d, 0x89, 0x8c, 0xaf, 0x05, 0x35, 0xa3, 0x87, 0xe8, 0x3d, 0x58, 0x08, 0xa9, 0x8f, 0x3b, 0x3b, + 
0x1a, 0x96, 0x7a, 0x84, 0x6e, 0x42, 0x43, 0x19, 0xf5, 0x15, 0x8e, 0x18, 0xa1, 0xa1, 0x06, 0xa5, + 0x42, 0xf2, 0x0b, 0x25, 0x3b, 0x2d, 0x2e, 0xaf, 0x83, 0x35, 0x8e, 0x45, 0xe8, 0x8f, 0x10, 0x78, + 0x1b, 0x96, 0xd5, 0xe1, 0x7d, 0x12, 0x60, 0xf7, 0x08, 0x1f, 0xb3, 0xa6, 0xb5, 0x5e, 0x6d, 0x9b, + 0x8e, 0xba, 0xd3, 0x43, 0x12, 0xe0, 0xa7, 0xf8, 0x98, 0x65, 0x7d, 0x57, 0x3f, 0xd1, 0x77, 0x8d, + 0xa2, 0xef, 0xd0, 0x2d, 0x58, 0x62, 0x38, 0x22, 0x5e, 0x40, 0xde, 0x62, 0x97, 0x91, 0xb7, 0xb8, + 0xb9, 0x24, 0x75, 0x1a, 0xa9, 0x74, 0x9f, 0xbc, 0xc5, 0xc2, 0x0c, 0xaf, 0x23, 0xc2, 0xb1, 0x7b, + 0xe8, 0x85, 0x3e, 0xed, 0xf7, 0x9b, 0xcb, 0xf2, 0x9c, 0xba, 0x14, 0x3e, 0x56, 0x32, 0xb4, 0x09, + 0x97, 0x7a, 0x71, 0x14, 0xe1, 0x90, 0xbb, 0x79, 0x9b, 0xad, 0xac, 0x1b, 0xed, 0x79, 0xe7, 0xa2, + 0x9e, 0xec, 0x64, 0x4c, 0x67, 0xff, 0xde, 0x80, 0x8b, 0x0e, 0x3e, 0x20, 0x8c, 0xe3, 0xe8, 0x19, + 0xf5, 0xb1, 0x83, 0x5f, 0xc6, 0x98, 0x71, 0x74, 0x1f, 0xe6, 0xba, 0x1e, 0xc3, 0x1a, 0xc6, 0x57, + 0x4b, 0x2d, 0xba, 0xcb, 0x0e, 0x1e, 0x78, 0x0c, 0x3b, 0x52, 0x13, 0x7d, 0x13, 0x16, 0x3d, 0xdf, + 0x8f, 0x30, 0x63, 0x12, 0x4c, 0x93, 0x16, 0x6d, 0x29, 0x1d, 0x27, 0x51, 0xce, 0x78, 0xbe, 0x9a, + 0xf5, 0xbc, 0xfd, 0x6b, 0x03, 0xd6, 0xf2, 0x37, 0x63, 0x43, 0x1a, 0x32, 0x8c, 0x3e, 0x80, 0x05, + 0xe1, 0xbf, 0x98, 0xe9, 0xcb, 0x5d, 0x29, 0x3d, 0x67, 0x5f, 0xaa, 0x38, 0x5a, 0x55, 0xa4, 0x61, + 0x12, 0x12, 0x9e, 0xa4, 0x08, 0x75, 0xc3, 0x1b, 0xc5, 0xe8, 0xd4, 0xa5, 0xa5, 0x13, 0x12, 0xae, + 0x32, 0x82, 0x03, 0x24, 0xfd, 0x6d, 0x7f, 0x1f, 0xd6, 0x1e, 0x61, 0x9e, 0xc1, 0x91, 0xb6, 0xd5, + 0x2c, 0xe1, 0x96, 0xaf, 0x1f, 0x95, 0x42, 0xfd, 0xb0, 0xff, 0x60, 0xc0, 0xa5, 0xc2, 0xde, 0x67, + 0xf9, 0xda, 0x34, 0x20, 0x2a, 0x67, 0x09, 0x88, 0x6a, 0x31, 0x20, 0xec, 0x9f, 0x1a, 0x70, 0xe5, + 0x11, 0xe6, 0xd9, 0x64, 0x73, 0xce, 0x96, 0x40, 0xff, 0x0f, 0x90, 0x26, 0x19, 0xd6, 0xac, 0xae, + 0x57, 0xdb, 0x55, 0x27, 0x23, 0xb1, 0x7f, 0x69, 0xc0, 0xea, 0xd8, 0xf9, 0xf9, 0x5c, 0x65, 0x14, + 0x73, 0xd5, 0xbf, 0xcb, 0x1c, 0xbf, 
0x35, 0xe0, 0x6a, 0xb9, 0x39, 0xce, 0xe2, 0xbc, 0xef, 0xa8, + 0x45, 0x58, 0xa0, 0x54, 0x14, 0xb2, 0x5b, 0x65, 0x35, 0x64, 0xfc, 0x4c, 0xbd, 0xc8, 0xfe, 0xa2, + 0x0a, 0x68, 0x5b, 0x26, 0x18, 0x55, 0xa9, 0xde, 0xc1, 0x35, 0xa7, 0xa6, 0x3f, 0x05, 0x92, 0x33, + 0x77, 0x1e, 0x24, 0x67, 0xfe, 0x54, 0x24, 0xe7, 0x2a, 0x98, 0x22, 0xd3, 0x32, 0xee, 0x0d, 0x86, + 0xb2, 0xc6, 0xcc, 0x39, 0x23, 0xc1, 0x38, 0xa5, 0x58, 0x9c, 0x91, 0x52, 0xd4, 0x4e, 0x4b, 0x29, + 0xec, 0x37, 0x70, 0x31, 0x09, 0x6c, 0x59, 0xf2, 0xdf, 0xc1, 0x1d, 0xf9, 0x50, 0xa8, 0x14, 0x43, + 0x61, 0x8a, 0x53, 0xec, 0x3f, 0x56, 0x61, 0xb5, 0x93, 0xd4, 0xa9, 0x3d, 0x8f, 0x1f, 0x4a, 0x9e, + 0x71, 0x72, 0xa4, 0x4c, 0x46, 0x40, 0xa6, 0xa8, 0x57, 0x27, 0x16, 0xf5, 0xb9, 0x7c, 0x51, 0xcf, + 0x5f, 0x70, 0xbe, 0x88, 0x9a, 0xf3, 0xa1, 0xb5, 0x6d, 0x58, 0xc9, 0x14, 0xe9, 0xa1, 0xc7, 0x0f, + 0x05, 0xb5, 0x15, 0x55, 0x7a, 0x89, 0x64, 0xbf, 0x9e, 0xa1, 0x3b, 0xb0, 0x9c, 0x56, 0x55, 0x5f, + 0x15, 0xdb, 0x9a, 0x44, 0xc8, 0xa8, 0x04, 0xfb, 0x49, 0xb5, 0xcd, 0x17, 0x50, 0xb3, 0x84, 0x74, + 0x64, 0x09, 0x10, 0xe4, 0x09, 0xd0, 0xc4, 0x42, 0x6c, 0x4d, 0x2e, 0xc4, 0x7f, 0x36, 0xc0, 0x4a, + 0x83, 0x7a, 0xc6, 0x76, 0x25, 0xe7, 0xcb, 0x4a, 0xd1, 0x97, 0x37, 0xa0, 0x8e, 0x43, 0xaf, 0x1b, + 0x60, 0x8d, 0xf5, 0xaa, 0xc2, 0xba, 0x92, 0x29, 0xac, 0x3f, 0x04, 0x6b, 0x44, 0x59, 0x93, 0xb8, + 0xbd, 0x35, 0x91, 0xb3, 0x66, 0x81, 0xe4, 0x40, 0xca, 0x5d, 0x99, 0xfd, 0xab, 0xca, 0xa8, 0x34, + 0x2a, 0x94, 0x9f, 0x25, 0x01, 0xfe, 0x00, 0xea, 0xfa, 0x2b, 0x14, 0x95, 0x56, 0x69, 0xf0, 0x5b, + 0x65, 0xd7, 0x2a, 0x3b, 0x74, 0x23, 0x63, 0xc6, 0x4f, 0x42, 0x1e, 0x1d, 0x3b, 0x16, 0x1b, 0x49, + 0x5a, 0x2e, 0xac, 0x14, 0x15, 0xd0, 0x0a, 0x54, 0x8f, 0xf0, 0xb1, 0xb6, 0xb1, 0xf8, 0x29, 0x4a, + 0xc6, 0x2b, 0x81, 0x37, 0xcd, 0x14, 0xae, 0x9f, 0x98, 0x83, 0xfb, 0xd4, 0x51, 0xda, 0xdf, 0xae, + 0x7c, 0x64, 0xd8, 0xbf, 0x33, 0x60, 0x65, 0x27, 0xa2, 0xc3, 0x77, 0x4e, 0xbf, 0x36, 0xd4, 0x33, + 0xfc, 0x3b, 0x89, 0xf8, 0x9c, 0x6c, 0x5a, 0x22, 0xbe, 0x0c, 0x35, 0x3f, 
0xa2, 0x43, 0xd7, 0x0b, + 0x02, 0x19, 0x8c, 0x82, 0x8a, 0x46, 0x74, 0xb8, 0x15, 0x04, 0xf6, 0x6b, 0x58, 0xdb, 0xc1, 0xac, + 0x17, 0x91, 0xee, 0xbb, 0x17, 0x86, 0x29, 0x35, 0x3b, 0x97, 0x74, 0xab, 0x85, 0xa4, 0x6b, 0x7f, + 0x61, 0xc0, 0xa5, 0xc2, 0xc9, 0x67, 0x41, 0xc7, 0x77, 0xf3, 0x98, 0x55, 0xe0, 0x98, 0xd2, 0x67, + 0x65, 0xb1, 0xea, 0xc9, 0x9a, 0x2d, 0xe7, 0x1e, 0x88, 0x3c, 0xb5, 0x17, 0xd1, 0x03, 0xc9, 0x48, + 0xcf, 0x8f, 0xcd, 0xfd, 0xcd, 0x80, 0x6b, 0x13, 0xce, 0x38, 0xcb, 0x97, 0x17, 0x1b, 0xf8, 0xca, + 0xb4, 0x06, 0xbe, 0x5a, 0x6c, 0xe0, 0xcb, 0xfb, 0xdb, 0xb9, 0x09, 0xfd, 0xed, 0x5f, 0xab, 0xd0, + 0xd8, 0xe7, 0x34, 0xf2, 0x0e, 0xf0, 0x36, 0x0d, 0xfb, 0xe4, 0x40, 0xa4, 0xfa, 0x84, 0xe3, 0x1b, + 0xf2, 0xa3, 0x53, 0x16, 0x7f, 0x03, 0xea, 0x5e, 0xaf, 0x87, 0x19, 0x13, 0x6d, 0x92, 0xce, 0x46, + 0xa6, 0x63, 0x29, 0xd9, 0x53, 0x21, 0x42, 0x5f, 0x85, 0x55, 0x86, 0x7b, 0x11, 0xe6, 0xee, 0x48, + 0x53, 0x23, 0x78, 0x59, 0x4d, 0x6c, 0x25, 0xda, 0xa2, 0x29, 0x88, 0x19, 0xde, 0xdf, 0xff, 0x54, + 0xa3, 0x58, 0x8f, 0x04, 0x25, 0xeb, 0xc6, 0xbd, 0x23, 0xcc, 0xb3, 0x25, 0x05, 0x94, 0x48, 0x42, + 0xf1, 0x0a, 0x98, 0x11, 0xa5, 0x5c, 0xd6, 0x01, 0x59, 0xff, 0x4d, 0xa7, 0x26, 0x04, 0x22, 0x6d, + 0xe9, 0x5d, 0x3b, 0x5b, 0xbb, 0xba, 0xee, 0xeb, 0x91, 0xe8, 0x85, 0x3b, 0x5b, 0xbb, 0x9f, 0x84, + 0xfe, 0x90, 0x92, 0x90, 0xcb, 0xa2, 0x60, 0x3a, 0x59, 0x91, 0xf8, 0x3c, 0xa6, 0x2c, 0xe1, 0x0a, + 0xca, 0x22, 0x0b, 0x82, 0xe9, 0x58, 0x5a, 0xf6, 0xfc, 0x78, 0x88, 0x45, 0x1d, 0x8a, 0x19, 0x76, + 0x5f, 0x91, 0x88, 0xc7, 0x5e, 0xe0, 0x1e, 0x52, 0xc6, 0x65, 0x5d, 0xa8, 0x39, 0x4b, 0x31, 0xc3, + 0x2f, 0x94, 0xf8, 0x31, 0x65, 0x5c, 0x5c, 0x23, 0xc2, 0x07, 0x49, 0x3d, 0x30, 0x1d, 0x3d, 0x12, + 0xbd, 0x60, 0x2f, 0xa0, 0xb1, 0xef, 0x0e, 0x23, 0xfa, 0x8a, 0xf8, 0x38, 0x92, 0xdd, 0xa4, 0xe9, + 0x34, 0xa4, 0x74, 0x4f, 0x0b, 0x85, 0x13, 0x23, 0x85, 0x55, 0xd9, 0x54, 0xd2, 0x98, 0xbb, 0x83, + 0xf4, 0x91, 0x42, 0xcf, 0x3c, 0x57, 0x13, 0xbb, 0xcc, 0xfe, 0xe7, 0x02, 0xac, 0x28, 0x3a, 0xf8, + 0x84, 0x76, 
0x13, 0x8c, 0x5f, 0x05, 0xb3, 0x17, 0xc4, 0xa2, 0xb3, 0xd2, 0x00, 0x37, 0x9d, 0x91, + 0x40, 0x38, 0x2a, 0x5b, 0x51, 0x23, 0xdc, 0x27, 0x6f, 0xb4, 0x43, 0x97, 0x47, 0x25, 0x55, 0x8a, + 0xb3, 0xc5, 0xbf, 0x3a, 0x56, 0xfc, 0x7d, 0x8f, 0x7b, 0xba, 0x22, 0xcf, 0xc9, 0x8a, 0x6c, 0x0a, + 0x89, 0x2a, 0xc6, 0x63, 0x35, 0x76, 0xbe, 0xa4, 0xc6, 0x66, 0x48, 0xc7, 0x42, 0x9e, 0x74, 0xe4, + 0x23, 0x70, 0xb1, 0x98, 0x91, 0x1e, 0xc3, 0x52, 0xe2, 0xaf, 0x9e, 0x84, 0xae, 0x74, 0x6a, 0x49, + 0xc7, 0x27, 0xf3, 0x78, 0x16, 0xe3, 0x4e, 0x83, 0xe5, 0x20, 0x5f, 0x24, 0x29, 0xe6, 0xa9, 0x48, + 0x4a, 0x81, 0x20, 0xc3, 0x69, 0x08, 0x72, 0x96, 0x70, 0x58, 0x33, 0x12, 0x8e, 0xfa, 0x44, 0xc2, + 0x31, 0x96, 0xe7, 0x1a, 0xd3, 0x1f, 0x89, 0x96, 0xa6, 0x3c, 0x12, 0x2d, 0x9f, 0x40, 0x27, 0x57, + 0xc6, 0x1a, 0x0a, 0xf9, 0x53, 0xf9, 0x6f, 0x55, 0xf9, 0x4f, 0x4a, 0xa4, 0xff, 0x3e, 0x4e, 0xa6, + 0x65, 0xb4, 0x21, 0xd9, 0xb6, 0x15, 0x72, 0xbc, 0x7e, 0xea, 0xdd, 0xf1, 0xb8, 0x27, 0xe2, 0x4f, + 0xaf, 0x96, 0xa1, 0x78, 0x0d, 0x40, 0x38, 0x51, 0xb1, 0xc1, 0xe6, 0x45, 0xb5, 0xb9, 0x94, 0xc8, + 0x34, 0x70, 0x13, 0x1a, 0x6a, 0x3a, 0xb1, 0xd2, 0x9a, 0xfa, 0x74, 0x29, 0x4c, 0xcc, 0x93, 0xd2, + 0xca, 0xcc, 0x4e, 0x97, 0xe4, 0x4e, 0x4b, 0xfa, 0xb5, 0x32, 0xd9, 0x6e, 0x05, 0xaa, 0x3e, 0x19, + 0x34, 0xdf, 0x53, 0xe4, 0xc1, 0x27, 0x03, 0xfb, 0x53, 0x58, 0xf9, 0x5e, 0x8c, 0xa3, 0xe3, 0x27, + 0xb4, 0xcb, 0x66, 0x0b, 0xb9, 0x16, 0xd4, 0x74, 0xdc, 0x24, 0x65, 0x3f, 0x1d, 0xdb, 0xbf, 0xa8, + 0x40, 0x43, 0x7a, 0xee, 0xb9, 0xc7, 0x8e, 0x92, 0xb7, 0xc2, 0x24, 0xe8, 0x8c, 0x7c, 0xd0, 0x9d, + 0xb2, 0xd3, 0x2d, 0x79, 0xe8, 0xaa, 0x96, 0x3d, 0x74, 0x95, 0x30, 0xe8, 0xb9, 0x52, 0x06, 0x5d, + 0x68, 0x9d, 0xe7, 0xc7, 0x9e, 0xd6, 0x26, 0x22, 0x76, 0x61, 0x32, 0x45, 0xfe, 0xd2, 0x80, 0xd5, + 0x8c, 0x5d, 0xcf, 0x52, 0x4a, 0x73, 0xde, 0xa8, 0x14, 0xbd, 0xf1, 0x20, 0x4f, 0x31, 0xaa, 0x65, + 0xd1, 0x9a, 0xa1, 0x18, 0x89, 0x5f, 0x72, 0x34, 0xe3, 0x29, 0x2c, 0x0b, 0x12, 0x78, 0x3e, 0x10, + 0xf8, 0xbb, 0x01, 0x8b, 0x4f, 0x68, 0x57, 0x3a, 
0x3f, 0x9b, 0x06, 0x8c, 0x7c, 0x1a, 0xd0, 0x48, + 0xac, 0xa4, 0x48, 0x54, 0x91, 0xe0, 0x45, 0x7c, 0xf4, 0x74, 0x2c, 0xe2, 0x53, 0x48, 0xe4, 0xeb, + 0xe3, 0x65, 0xa8, 0xe1, 0xd0, 0x57, 0x93, 0xba, 0x77, 0xc3, 0xa1, 0x2f, 0xa7, 0xce, 0xa7, 0x1d, + 0x5f, 0x83, 0xf9, 0x21, 0x1d, 0x3d, 0xf7, 0xaa, 0x81, 0xbd, 0x06, 0xe8, 0x11, 0xe6, 0x4f, 0x68, + 0x57, 0x78, 0x25, 0x31, 0x8f, 0xfd, 0x97, 0x8a, 0x6c, 0x95, 0x47, 0xe2, 0xb3, 0x38, 0xd8, 0x86, + 0x86, 0x22, 0x42, 0x9f, 0xd3, 0xae, 0x1b, 0xc6, 0x89, 0x51, 0x2c, 0x29, 0x7c, 0x42, 0xbb, 0xcf, + 0xe2, 0x01, 0xba, 0x0b, 0x17, 0x49, 0x28, 0x8a, 0xad, 0xe4, 0x66, 0xa9, 0xa6, 0xb2, 0xd2, 0x0a, + 0x09, 0x13, 0xd6, 0xa6, 0xd5, 0x6f, 0xc3, 0x32, 0x0e, 0x5f, 0xc6, 0x38, 0xc6, 0xa9, 0xaa, 0xb2, + 0x59, 0x43, 0x8b, 0xb5, 0x9e, 0xe0, 0x60, 0x1e, 0x3b, 0x72, 0x59, 0x40, 0x39, 0xd3, 0x65, 0xcd, + 0x14, 0x92, 0x7d, 0x21, 0x40, 0x1f, 0x81, 0x29, 0x96, 0x2b, 0x68, 0xa9, 0x96, 0xf7, 0x4a, 0x19, + 0xb4, 0xb4, 0xbf, 0x9d, 0xda, 0xe7, 0xea, 0x07, 0x13, 0x41, 0xa5, 0x1b, 0x3a, 0x9f, 0xb0, 0x23, + 0xcd, 0x61, 0x40, 0x89, 0x76, 0x08, 0x3b, 0xb2, 0x7f, 0x08, 0x97, 0xb3, 0x8f, 0x88, 0x84, 0x71, + 0xd2, 0x3b, 0x4f, 0x5e, 0xfb, 0x1b, 0x03, 0x5a, 0x65, 0x07, 0xfc, 0x07, 0xe9, 0xfc, 0xe6, 0xcf, + 0x2c, 0x00, 0x39, 0xb3, 0x4d, 0x69, 0xe4, 0xa3, 0x40, 0x42, 0x6b, 0x9b, 0x0e, 0x86, 0x34, 0xc4, + 0x21, 0x97, 0x59, 0x8e, 0xa1, 0x8d, 0xfc, 0x7e, 0x7a, 0x30, 0xae, 0xa8, 0x6d, 0xd5, 0x7a, 0xbf, + 0x54, 0xbf, 0xa0, 0x6c, 0x5f, 0x40, 0x2f, 0x65, 0xdb, 0x3b, 0x32, 0xc5, 0xf6, 0xa1, 0x17, 0x86, + 0x38, 0x40, 0x9b, 0x13, 0x1e, 0x96, 0xcb, 0x94, 0x93, 0x33, 0x6f, 0x96, 0x9e, 0xb9, 0xcf, 0x23, + 0x12, 0x1e, 0x24, 0x26, 0xb6, 0x2f, 0xa0, 0xe7, 0x60, 0x65, 0x5e, 0xf7, 0xd0, 0xed, 0x32, 0x4b, + 0x8d, 0x3f, 0xff, 0xb5, 0x4e, 0xf2, 0x85, 0x7d, 0x01, 0xf5, 0xa1, 0x91, 0x7b, 0x7e, 0x46, 0xed, + 0x93, 0xba, 0xed, 0xec, 0x9b, 0x6f, 0xeb, 0x2b, 0x33, 0x68, 0xa6, 0xb7, 0xff, 0xb1, 0x32, 0xd8, + 0xd8, 0xfb, 0xed, 0xbd, 0x09, 0x9b, 0x4c, 0x7a, 0x69, 0x6e, 0xdd, 0x9f, 0x7d, 0x41, 
0x7a, 0xb8, + 0x3f, 0xfa, 0x48, 0x15, 0x50, 0x77, 0xa6, 0x3f, 0x29, 0xa8, 0xd3, 0xda, 0xb3, 0xbe, 0x3d, 0xd8, + 0x17, 0xd0, 0x1e, 0x98, 0x69, 0xf7, 0x8f, 0xde, 0x2f, 0x5b, 0x58, 0x7c, 0x1c, 0x98, 0xc1, 0x39, + 0xb9, 0xfe, 0xb9, 0xdc, 0x39, 0x65, 0xcd, 0x7d, 0xb9, 0x73, 0x4a, 0x9b, 0x71, 0xfb, 0x02, 0x8a, + 0x65, 0xec, 0x14, 0xa2, 0x1b, 0xdd, 0x9d, 0xe6, 0xdf, 0x5c, 0x9a, 0x69, 0x6d, 0xcc, 0xaa, 0x9e, + 0x1e, 0xfb, 0x93, 0xd1, 0x9f, 0x3e, 0x72, 0xcd, 0x32, 0xba, 0x7f, 0xd2, 0x56, 0x65, 0xbd, 0x7b, + 0xeb, 0xeb, 0xef, 0xb0, 0x22, 0x83, 0x49, 0xb4, 0x7f, 0x48, 0x5f, 0x2b, 0xbe, 0x1f, 0x47, 0x9e, + 0xc8, 0x85, 0x25, 0x87, 0xeb, 0x10, 0x1e, 0x57, 0x9d, 0x78, 0xf8, 0x09, 0x2b, 0xd2, 0xc3, 0x5d, + 0x80, 0x47, 0x98, 0xef, 0x62, 0x1e, 0x09, 0x5b, 0xdf, 0x9e, 0x94, 0xa7, 0xb4, 0x42, 0x72, 0xd4, + 0x9d, 0xa9, 0x7a, 0xe9, 0x01, 0x5d, 0xb0, 0xb6, 0x0f, 0x71, 0xef, 0xe8, 0x31, 0xf6, 0x02, 0x7e, + 0x88, 0xca, 0x57, 0x66, 0x34, 0x26, 0x40, 0xbe, 0x4c, 0x31, 0x39, 0x63, 0xf3, 0xcb, 0x05, 0xfd, + 0x8f, 0x16, 0xcf, 0xa8, 0x8f, 0xff, 0xfb, 0x53, 0xf0, 0x1e, 0x98, 0x69, 0x47, 0x5d, 0x1e, 0xe1, + 0xc5, 0x86, 0x7b, 0x5a, 0x84, 0x7f, 0x06, 0x66, 0x4a, 0x6c, 0xcb, 0x77, 0x2c, 0xf6, 0x13, 0xad, + 0x5b, 0x53, 0xb4, 0xd2, 0xdb, 0x3e, 0x83, 0x5a, 0x42, 0x44, 0xd1, 0xcd, 0x49, 0xe9, 0x28, 0xbb, + 0xf3, 0x94, 0xbb, 0xfe, 0x08, 0xac, 0x0c, 0x4b, 0x2b, 0x2f, 0x40, 0xe3, 0xec, 0xae, 0x75, 0x67, + 0xaa, 0xde, 0xff, 0x46, 0x40, 0x3e, 0xf8, 0xc6, 0x67, 0x9b, 0x07, 0x84, 0x1f, 0xc6, 0x5d, 0x61, + 0xd9, 0x7b, 0x4a, 0xf3, 0x2e, 0xa1, 0xfa, 0xd7, 0xbd, 0xe4, 0x96, 0xf7, 0xe4, 0x4e, 0xf7, 0xa4, + 0x9d, 0x86, 0xdd, 0xee, 0x82, 0x1c, 0x7e, 0xf0, 0xaf, 0x00, 0x00, 0x00, 0xff, 0xff, 0x07, 0x7d, + 0x82, 0x6d, 0x35, 0x25, 0x00, 0x00, } // Reference imports to suppress errors if they are not otherwise used. 
diff --git a/internal/proto/internal.proto b/internal/proto/internal.proto index 43ead7d37c5e7..9f768a7796594 100644 --- a/internal/proto/internal.proto +++ b/internal/proto/internal.proto @@ -144,6 +144,7 @@ message RetrieveRequest { bool is_count = 13; int64 iteration_extension_reduce_rate = 14; string username = 15; + bool reduce_stop_for_best = 16; } diff --git a/internal/proto/internalpb/internal.pb.go b/internal/proto/internalpb/internal.pb.go index 591fa1ce69591..30d24043d6c28 100644 --- a/internal/proto/internalpb/internal.pb.go +++ b/internal/proto/internalpb/internal.pb.go @@ -1108,6 +1108,7 @@ type RetrieveRequest struct { IsCount bool `protobuf:"varint,13,opt,name=is_count,json=isCount,proto3" json:"is_count,omitempty"` IterationExtensionReduceRate int64 `protobuf:"varint,14,opt,name=iteration_extension_reduce_rate,json=iterationExtensionReduceRate,proto3" json:"iteration_extension_reduce_rate,omitempty"` Username string `protobuf:"bytes,15,opt,name=username,proto3" json:"username,omitempty"` + ReduceStopForBest bool `protobuf:"varint,16,opt,name=reduce_stop_for_best,json=reduceStopForBest,proto3" json:"reduce_stop_for_best,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` @@ -1243,6 +1244,13 @@ func (m *RetrieveRequest) GetUsername() string { return "" } +func (m *RetrieveRequest) GetReduceStopForBest() bool { + if m != nil { + return m.ReduceStopForBest + } + return false +} + type RetrieveResults struct { Base *commonpb.MsgBase `protobuf:"bytes,1,opt,name=base,proto3" json:"base,omitempty"` Status *commonpb.Status `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"` @@ -1999,126 +2007,128 @@ func init() { func init() { proto.RegisterFile("internal.proto", fileDescriptor_41f4a519b878ee3b) } var fileDescriptor_41f4a519b878ee3b = []byte{ - // 1936 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xd4, 0x58, 0x4b, 0x73, 
0x24, 0x39, - 0x11, 0xa6, 0xfa, 0xdd, 0xd9, 0xed, 0x76, 0x5b, 0xe3, 0x19, 0x6a, 0x1e, 0xbb, 0xe3, 0x2d, 0x5e, - 0x66, 0x89, 0xb5, 0x17, 0x6f, 0xec, 0x0e, 0x07, 0x02, 0x62, 0xec, 0xf2, 0x3a, 0x3a, 0xb6, 0x3d, - 0x78, 0xaa, 0x87, 0x8d, 0x80, 0x4b, 0x85, 0xba, 0x4a, 0x6e, 0x8b, 0xa9, 0x97, 0x25, 0x95, 0x1f, - 0x73, 0xe6, 0x46, 0x04, 0x37, 0x2e, 0x44, 0xc0, 0x3f, 0xe0, 0x4c, 0x70, 0xe2, 0x1f, 0x10, 0xfc, - 0x0f, 0xfe, 0xc1, 0x9e, 0x08, 0x3d, 0xaa, 0xbb, 0xba, 0xdd, 0x36, 0x1e, 0x0f, 0x8f, 0xe5, 0x26, - 0x65, 0x7e, 0x4a, 0x29, 0x53, 0xa9, 0x4f, 0x29, 0x41, 0x8f, 0x26, 0x82, 0xb0, 0x04, 0x47, 0x5b, - 0x19, 0x4b, 0x45, 0x8a, 0xee, 0xc7, 0x34, 0x3a, 0xcb, 0xb9, 0xee, 0x6d, 0x15, 0xca, 0x47, 0xdd, - 0x20, 0x8d, 0xe3, 0x34, 0xd1, 0xe2, 0x47, 0x5d, 0x1e, 0x9c, 0x90, 0x18, 0xeb, 0x9e, 0xf3, 0x18, - 0x1e, 0x1e, 0x10, 0xf1, 0x8a, 0xc6, 0xe4, 0x15, 0x0d, 0x5e, 0xef, 0x9d, 0xe0, 0x24, 0x21, 0x91, - 0x47, 0x4e, 0x73, 0xc2, 0x85, 0xf3, 0x1e, 0x3c, 0x3e, 0x20, 0x62, 0x24, 0xb0, 0xa0, 0x5c, 0xd0, - 0x80, 0x2f, 0xa8, 0xef, 0xc3, 0xbd, 0x03, 0x22, 0xdc, 0x70, 0x41, 0xfc, 0x25, 0xb4, 0x5e, 0xa4, - 0x21, 0x19, 0x24, 0xc7, 0x29, 0xfa, 0x0c, 0x9a, 0x38, 0x0c, 0x19, 0xe1, 0xdc, 0xb6, 0x36, 0xac, - 0xcd, 0xce, 0xce, 0x93, 0xad, 0xb9, 0x35, 0x9a, 0x95, 0x3d, 0xd7, 0x18, 0xaf, 0x00, 0x23, 0x04, - 0x35, 0x96, 0x46, 0xc4, 0xae, 0x6c, 0x58, 0x9b, 0x6d, 0x4f, 0xb5, 0x9d, 0x5f, 0x01, 0x0c, 0x12, - 0x2a, 0x8e, 0x30, 0xc3, 0x31, 0x47, 0x0f, 0xa0, 0x91, 0xc8, 0x59, 0x5c, 0x65, 0xb8, 0xea, 0x99, - 0x1e, 0x72, 0xa1, 0xcb, 0x05, 0x66, 0xc2, 0xcf, 0x14, 0xce, 0xae, 0x6c, 0x54, 0x37, 0x3b, 0x3b, - 0x1f, 0x2c, 0x9d, 0xf6, 0x0b, 0x72, 0xf9, 0x25, 0x8e, 0x72, 0x72, 0x84, 0x29, 0xf3, 0x3a, 0x6a, - 0x98, 0xb6, 0xee, 0xfc, 0x02, 0x60, 0x24, 0x18, 0x4d, 0x26, 0x43, 0xca, 0x85, 0x9c, 0xeb, 0x4c, - 0xe2, 0xa4, 0x13, 0xd5, 0xcd, 0xb6, 0x67, 0x7a, 0xe8, 0x13, 0x68, 0x70, 0x81, 0x45, 0xce, 0xd5, - 0x3a, 0x3b, 0x3b, 0x8f, 0x97, 0xce, 0x32, 0x52, 0x10, 0xcf, 0x40, 0x9d, 0x3f, 0x55, 0x60, 0x7d, - 0x2e, 0xaa, 0x26, 0x6e, 
0xe8, 0x63, 0xa8, 0x8d, 0x31, 0x27, 0x37, 0x06, 0xea, 0x90, 0x4f, 0x76, - 0x31, 0x27, 0x9e, 0x42, 0xca, 0x28, 0x85, 0xe3, 0x81, 0xab, 0x66, 0xaf, 0x7a, 0xaa, 0x8d, 0x1c, - 0xe8, 0x06, 0x69, 0x14, 0x91, 0x40, 0xd0, 0x34, 0x19, 0xb8, 0x76, 0x55, 0xe9, 0xe6, 0x64, 0x12, - 0x93, 0x61, 0x26, 0xa8, 0xee, 0x72, 0xbb, 0xb6, 0x51, 0x95, 0x98, 0xb2, 0x0c, 0x7d, 0x1f, 0xfa, - 0x82, 0xe1, 0x33, 0x12, 0xf9, 0x82, 0xc6, 0x84, 0x0b, 0x1c, 0x67, 0x76, 0x7d, 0xc3, 0xda, 0xac, - 0x79, 0xab, 0x5a, 0xfe, 0xaa, 0x10, 0xa3, 0x6d, 0xb8, 0x37, 0xc9, 0x31, 0xc3, 0x89, 0x20, 0xa4, - 0x84, 0x6e, 0x28, 0x34, 0x9a, 0xaa, 0x66, 0x03, 0x7e, 0x00, 0x6b, 0x12, 0x96, 0xe6, 0xa2, 0x04, - 0x6f, 0x2a, 0x78, 0xdf, 0x28, 0xa6, 0x60, 0xe7, 0xcf, 0x16, 0xdc, 0x5f, 0x88, 0x17, 0xcf, 0xd2, - 0x84, 0x93, 0x3b, 0x04, 0xec, 0x2e, 0x1b, 0x86, 0x9e, 0x41, 0x5d, 0xb6, 0xb8, 0x5d, 0xbd, 0x6d, - 0x2a, 0x69, 0xbc, 0xf3, 0x47, 0x0b, 0xd0, 0x1e, 0x23, 0x58, 0x90, 0xe7, 0x11, 0xc5, 0xef, 0xb0, - 0xcf, 0xdf, 0x84, 0x66, 0x38, 0xf6, 0x13, 0x1c, 0x17, 0x07, 0xa2, 0x11, 0x8e, 0x5f, 0xe0, 0x98, - 0xa0, 0xef, 0xc1, 0xea, 0x6c, 0x63, 0x35, 0xa0, 0xaa, 0x00, 0xbd, 0x99, 0x58, 0x01, 0xd7, 0xa1, - 0x8e, 0xe5, 0x1a, 0xec, 0x9a, 0x52, 0xeb, 0x8e, 0xc3, 0xa1, 0xef, 0xb2, 0x34, 0xfb, 0x4f, 0xad, - 0x6e, 0x3a, 0x69, 0xb5, 0x3c, 0xe9, 0x1f, 0x2c, 0x58, 0x7b, 0x1e, 0x09, 0xc2, 0xbe, 0xa6, 0x41, - 0xf9, 0x6b, 0xa5, 0xd8, 0xb5, 0x41, 0x12, 0x92, 0x8b, 0xff, 0xe5, 0x02, 0xdf, 0x03, 0x38, 0xa6, - 0x24, 0x0a, 0x35, 0x46, 0xaf, 0xb2, 0xad, 0x24, 0x4a, 0x5d, 0x1c, 0xff, 0xfa, 0x0d, 0xc7, 0xbf, - 0xb1, 0xe4, 0xf8, 0xdb, 0xd0, 0x54, 0x46, 0x06, 0xae, 0x3a, 0x74, 0x55, 0xaf, 0xe8, 0x4a, 0xf2, - 0x24, 0x17, 0x82, 0xe1, 0x82, 0x3c, 0x5b, 0xb7, 0x26, 0x4f, 0x35, 0xcc, 0x90, 0xe7, 0x3f, 0x6a, - 0xb0, 0x32, 0x22, 0x98, 0x05, 0x27, 0x77, 0x0f, 0xde, 0x3a, 0xd4, 0x19, 0x39, 0x9d, 0x72, 0x9b, - 0xee, 0x4c, 0x3d, 0xae, 0xde, 0xe0, 0x71, 0xed, 0x16, 0x84, 0x57, 0x5f, 0x42, 0x78, 0x7d, 0xa8, - 0x86, 0x3c, 0x52, 0x01, 0x6b, 0x7b, 0xb2, 0x29, 0x69, 0x2a, 
0x8b, 0x70, 0x40, 0x4e, 0xd2, 0x28, - 0x24, 0xcc, 0x9f, 0xb0, 0x34, 0xd7, 0x34, 0xd5, 0xf5, 0xfa, 0x25, 0xc5, 0x81, 0x94, 0xa3, 0x67, - 0xd0, 0x0a, 0x79, 0xe4, 0x8b, 0xcb, 0x8c, 0xd8, 0xad, 0x0d, 0x6b, 0xb3, 0x77, 0x8d, 0x9b, 0x2e, - 0x8f, 0x5e, 0x5d, 0x66, 0xc4, 0x6b, 0x86, 0xba, 0x81, 0x3e, 0x86, 0x75, 0x4e, 0x18, 0xc5, 0x11, - 0x7d, 0x43, 0x42, 0x9f, 0x5c, 0x64, 0xcc, 0xcf, 0x22, 0x9c, 0xd8, 0x6d, 0x35, 0x11, 0x9a, 0xe9, - 0xf6, 0x2f, 0x32, 0x76, 0x14, 0xe1, 0x04, 0x6d, 0x42, 0x3f, 0xcd, 0x45, 0x96, 0x0b, 0x5f, 0xed, - 0x1b, 0xf7, 0x69, 0x68, 0x83, 0xf2, 0xa8, 0xa7, 0xe5, 0x9f, 0x2b, 0xf1, 0x20, 0xbc, 0x8e, 0x99, - 0xbb, 0x6f, 0xc7, 0xcc, 0x2b, 0xcb, 0x99, 0x19, 0xf5, 0xa0, 0x92, 0x9c, 0xda, 0x3d, 0x15, 0xef, - 0x4a, 0x72, 0x2a, 0x77, 0x47, 0xa4, 0xd9, 0x6b, 0x7b, 0x55, 0xef, 0x8e, 0x6c, 0xa3, 0xf7, 0x01, - 0x62, 0x22, 0x18, 0x0d, 0xa4, 0xaf, 0x76, 0x5f, 0x05, 0xb7, 0x24, 0x41, 0xdf, 0x86, 0x15, 0x3a, - 0x49, 0x52, 0x46, 0x0e, 0x58, 0x7a, 0x4e, 0x93, 0x89, 0xbd, 0xb6, 0x61, 0x6d, 0xb6, 0xbc, 0x79, - 0x21, 0x7a, 0x04, 0xad, 0x9c, 0xcb, 0x62, 0x26, 0x26, 0x36, 0x52, 0x36, 0xa6, 0x7d, 0xe7, 0x6f, - 0xa5, 0x6c, 0xe3, 0x79, 0x24, 0xf8, 0x7f, 0xeb, 0x5e, 0x98, 0xa6, 0x68, 0xb5, 0x9c, 0xa2, 0x4f, - 0xa1, 0xa3, 0xdd, 0xd3, 0xa9, 0x50, 0xbb, 0xe2, 0xf1, 0x53, 0xe8, 0x24, 0x79, 0xec, 0x9f, 0xe6, - 0x84, 0x51, 0xc2, 0xcd, 0xe1, 0x85, 0x24, 0x8f, 0x5f, 0x6a, 0x09, 0xba, 0x07, 0x75, 0x91, 0x66, - 0xfe, 0x6b, 0x73, 0x76, 0x65, 0x1c, 0xbf, 0x40, 0x3f, 0x86, 0x47, 0x9c, 0xe0, 0x88, 0x84, 0x3e, - 0x27, 0x93, 0x98, 0x24, 0x62, 0xe0, 0x72, 0x9f, 0x2b, 0xb7, 0x49, 0x68, 0x37, 0xd5, 0xee, 0xdb, - 0x1a, 0x31, 0x9a, 0x02, 0x46, 0x46, 0x2f, 0xf3, 0x20, 0xd0, 0x45, 0xda, 0xdc, 0xb0, 0x96, 0xaa, - 0x66, 0xd0, 0x4c, 0x35, 0x1d, 0xf0, 0x23, 0xb0, 0x27, 0x51, 0x3a, 0xc6, 0x91, 0x7f, 0x65, 0x56, - 0xbb, 0xad, 0x26, 0x7b, 0xa0, 0xf5, 0xa3, 0x85, 0x29, 0xa5, 0x7b, 0x3c, 0xa2, 0x01, 0x09, 0xfd, - 0x71, 0x94, 0x8e, 0x6d, 0x50, 0x59, 0x0c, 0x5a, 0xb4, 0x1b, 0xa5, 0x63, 0x99, 0xbd, 0x06, 0x20, - 
0xc3, 0x10, 0xa4, 0x79, 0x22, 0xec, 0x8e, 0xf2, 0xb4, 0xa7, 0xe5, 0x2f, 0xf2, 0x78, 0x4f, 0x4a, - 0xd1, 0xb7, 0x60, 0xc5, 0x20, 0xd3, 0xe3, 0x63, 0x4e, 0x84, 0xca, 0xdb, 0xaa, 0xd7, 0xd5, 0xc2, - 0x9f, 0x29, 0x19, 0x3a, 0x92, 0x64, 0xca, 0xc5, 0xf3, 0xc9, 0x84, 0x91, 0x09, 0x96, 0x87, 0x59, - 0xe5, 0x6b, 0x67, 0xe7, 0xbb, 0x5b, 0x4b, 0xab, 0xe1, 0xad, 0xbd, 0x79, 0xb4, 0xb7, 0x38, 0xdc, - 0x39, 0x85, 0xd5, 0x05, 0x8c, 0xe4, 0x0f, 0x66, 0xaa, 0x0e, 0x99, 0xfe, 0xa6, 0xe4, 0x9c, 0x93, - 0xa1, 0x0d, 0xe8, 0x70, 0xc2, 0xce, 0x68, 0xa0, 0x21, 0x9a, 0xb7, 0xca, 0x22, 0xc9, 0xbb, 0x22, - 0x15, 0x38, 0x7a, 0xf1, 0xd2, 0xa4, 0x4c, 0xd1, 0x75, 0xfe, 0x5e, 0x83, 0x55, 0x4f, 0xa6, 0x08, - 0x39, 0x23, 0xff, 0x4f, 0x9c, 0x79, 0x1d, 0x77, 0x35, 0xde, 0x8a, 0xbb, 0x9a, 0x4b, 0xb9, 0xeb, - 0x3b, 0xd0, 0x8b, 0xcf, 0x82, 0xa0, 0xc4, 0x43, 0x2d, 0xc5, 0x43, 0x2b, 0x52, 0xfa, 0x2f, 0x8b, - 0xcf, 0xf6, 0xdb, 0x51, 0x1c, 0x5c, 0x43, 0x71, 0xeb, 0x50, 0x8f, 0x68, 0x4c, 0x8b, 0x0c, 0xd5, - 0x9d, 0xab, 0xa4, 0xd5, 0x5d, 0x46, 0x5a, 0x0f, 0xa1, 0x45, 0xb9, 0x49, 0xf0, 0x15, 0x05, 0x68, - 0x52, 0xae, 0x33, 0x7b, 0x1f, 0x9e, 0x52, 0x41, 0x98, 0x4a, 0x2e, 0x9f, 0x5c, 0x08, 0x92, 0x70, - 0xd9, 0x62, 0x24, 0xcc, 0x03, 0xe2, 0x33, 0x2c, 0x88, 0xa1, 0xd5, 0x27, 0x53, 0xd8, 0x7e, 0x81, - 0xf2, 0x14, 0xc8, 0xc3, 0x82, 0xcc, 0xd1, 0xe2, 0xea, 0x02, 0x2d, 0x7e, 0x55, 0x2d, 0xa7, 0xd4, - 0xd7, 0x80, 0x18, 0x3f, 0x84, 0x2a, 0x0d, 0x75, 0xad, 0xd5, 0xd9, 0xb1, 0xe7, 0xed, 0x98, 0x27, - 0xe9, 0xc0, 0xe5, 0x9e, 0x04, 0xa1, 0x9f, 0x42, 0xc7, 0xa4, 0x47, 0x88, 0x05, 0x56, 0xa9, 0xd7, - 0xd9, 0x79, 0x7f, 0xe9, 0x18, 0x95, 0x2f, 0x2e, 0x16, 0xd8, 0xd3, 0xb5, 0x12, 0x97, 0x6d, 0xf4, - 0x13, 0x78, 0x7c, 0x95, 0x2e, 0x99, 0x09, 0x47, 0x68, 0x37, 0x54, 0xc6, 0x3d, 0x5c, 0xe4, 0xcb, - 0x22, 0x5e, 0x21, 0xfa, 0x21, 0xac, 0x97, 0x08, 0x73, 0x36, 0xb0, 0xa9, 0x18, 0xb3, 0x44, 0xa6, - 0xb3, 0x21, 0x37, 0x51, 0x66, 0xeb, 0x46, 0xca, 0xfc, 0xf7, 0x53, 0xd8, 0x57, 0x16, 0xb4, 0x87, - 0x29, 0x0e, 0x55, 0x05, 0x7b, 0x87, 
0x6d, 0x7f, 0x02, 0xed, 0xe9, 0xea, 0x0d, 0x9b, 0xcc, 0x04, - 0x52, 0x3b, 0x2d, 0x42, 0x4d, 0xe5, 0x5a, 0xaa, 0x4a, 0x4b, 0xd5, 0x65, 0x6d, 0xbe, 0xba, 0x7c, - 0x0a, 0x1d, 0x2a, 0x17, 0xe4, 0x67, 0x58, 0x9c, 0x68, 0x42, 0x69, 0x7b, 0xa0, 0x44, 0x47, 0x52, - 0x22, 0xcb, 0xcf, 0x02, 0xa0, 0xca, 0xcf, 0xc6, 0xad, 0xcb, 0x4f, 0x63, 0x44, 0x95, 0x9f, 0xbf, - 0xb6, 0x00, 0x94, 0xe3, 0x32, 0x2d, 0xaf, 0x1a, 0xb5, 0xee, 0x62, 0x54, 0x32, 0x9d, 0xbc, 0xae, - 0x18, 0x89, 0xb0, 0x98, 0xed, 0x2d, 0x37, 0xc1, 0x41, 0x49, 0x1e, 0x7b, 0x5a, 0x65, 0xf6, 0x95, - 0x3b, 0xbf, 0xb5, 0x00, 0x54, 0x72, 0xea, 0x65, 0x2c, 0x52, 0xae, 0x75, 0x73, 0x61, 0x5e, 0x99, - 0x0f, 0xdd, 0x6e, 0x11, 0xba, 0x1b, 0x5e, 0xa2, 0xd3, 0xf4, 0x98, 0x39, 0x6f, 0xa2, 0xab, 0xda, - 0xce, 0xef, 0x2c, 0xe8, 0x9a, 0xd5, 0xe9, 0x25, 0xcd, 0xed, 0xb2, 0xb5, 0xb8, 0xcb, 0xaa, 0x90, - 0x89, 0x53, 0x76, 0xe9, 0x73, 0xfa, 0xa6, 0xb8, 0xcf, 0x40, 0x8b, 0x46, 0xf4, 0x0d, 0x91, 0xfc, - 0xa6, 0x42, 0x92, 0x9e, 0xf3, 0xe2, 0x3e, 0x93, 0x61, 0x48, 0xcf, 0xb9, 0xe4, 0x58, 0x46, 0x02, - 0x92, 0x88, 0xe8, 0xd2, 0x8f, 0xd3, 0x90, 0x1e, 0x53, 0x12, 0xaa, 0x6c, 0x68, 0x79, 0xfd, 0x42, - 0x71, 0x68, 0xe4, 0xf2, 0x81, 0x8f, 0xcc, 0x17, 0x52, 0xf1, 0x0f, 0x75, 0xc8, 0x27, 0x77, 0xc8, - 0x5a, 0x19, 0x62, 0x6d, 0x47, 0x26, 0xa2, 0xfe, 0xfa, 0x69, 0x7b, 0x73, 0x32, 0x59, 0x8f, 0x4e, - 0x59, 0x5f, 0xc7, 0xb1, 0xe6, 0x95, 0x24, 0x72, 0xe5, 0x21, 0x39, 0xc6, 0x79, 0x54, 0xbe, 0x1d, - 0x6a, 0xfa, 0x76, 0x30, 0x8a, 0xb9, 0xaf, 0x89, 0xde, 0x1e, 0x23, 0x21, 0x49, 0x04, 0xc5, 0x91, - 0xfa, 0xf0, 0x2a, 0x53, 0xb2, 0x35, 0x4f, 0xc9, 0xe8, 0x23, 0x40, 0x24, 0x09, 0xd8, 0x65, 0x26, - 0x33, 0x28, 0xc3, 0x9c, 0x9f, 0xa7, 0x2c, 0x34, 0x6f, 0xc3, 0xb5, 0xa9, 0xe6, 0xc8, 0x28, 0xd0, - 0x03, 0x68, 0x08, 0x92, 0xe0, 0x44, 0x98, 0x33, 0x66, 0x7a, 0xe6, 0x5e, 0xe1, 0x79, 0x46, 0x98, - 0x89, 0x69, 0x93, 0xf2, 0x91, 0xec, 0xca, 0x97, 0x25, 0x3f, 0xc1, 0x3b, 0x9f, 0x7e, 0x36, 0x33, - 0x5f, 0xd7, 0x2f, 0x4b, 0x2d, 0x2e, 0x6c, 0x3b, 0xfb, 0xb0, 0x36, 0xa4, 
0x5c, 0x1c, 0xa5, 0x11, - 0x0d, 0x2e, 0xef, 0x5c, 0x71, 0x38, 0xbf, 0xb1, 0x00, 0x95, 0xed, 0x98, 0x8f, 0x99, 0xd9, 0xad, - 0x61, 0xdd, 0xfe, 0xd6, 0xf8, 0x00, 0xba, 0x99, 0x32, 0xe3, 0xd3, 0xe4, 0x38, 0x2d, 0x76, 0xaf, - 0xa3, 0x65, 0x32, 0xb6, 0x5c, 0xbe, 0x87, 0x65, 0x30, 0x7d, 0x96, 0x46, 0x44, 0x6f, 0x5e, 0xdb, - 0x6b, 0x4b, 0x89, 0x27, 0x05, 0xce, 0x04, 0x1e, 0x8e, 0x4e, 0xd2, 0xf3, 0xbd, 0x34, 0x39, 0xa6, - 0x93, 0x5c, 0x5f, 0x9b, 0xef, 0xf0, 0xc1, 0x60, 0x43, 0x33, 0xc3, 0x42, 0x9e, 0x29, 0xb3, 0x47, - 0x45, 0xd7, 0xf9, 0xbd, 0x05, 0x8f, 0x96, 0xcd, 0xf4, 0x2e, 0xee, 0x1f, 0xc0, 0x4a, 0xa0, 0xcd, - 0x69, 0x6b, 0xb7, 0xff, 0xb8, 0x9c, 0x1f, 0xe7, 0xec, 0x43, 0x4d, 0x15, 0x07, 0xdb, 0x50, 0x61, - 0x42, 0xad, 0xa0, 0xb7, 0xf3, 0xf4, 0x1a, 0xa6, 0x90, 0x40, 0xf5, 0x1a, 0xad, 0x30, 0x81, 0xba, - 0x60, 0x31, 0xe5, 0xa9, 0xe5, 0x59, 0xec, 0xc3, 0xbf, 0x58, 0xd0, 0x2a, 0xd4, 0x68, 0x0d, 0x56, - 0x5c, 0x77, 0xb8, 0x37, 0xe5, 0xaa, 0xfe, 0x37, 0x50, 0x1f, 0xba, 0xae, 0x3b, 0x3c, 0x2a, 0xaa, - 0xc1, 0xbe, 0x85, 0xba, 0xd0, 0x72, 0xdd, 0xa1, 0x22, 0x9f, 0x7e, 0xc5, 0xf4, 0x3e, 0x8f, 0x72, - 0x7e, 0xd2, 0xaf, 0x4e, 0x0d, 0xc4, 0x19, 0xd6, 0x06, 0x6a, 0x68, 0x05, 0xda, 0xee, 0xe1, 0x70, - 0x90, 0x70, 0xc2, 0x44, 0xbf, 0x6e, 0xba, 0x2e, 0x89, 0x88, 0x20, 0xfd, 0x06, 0x5a, 0x85, 0x8e, - 0x7b, 0x38, 0xdc, 0xcd, 0xa3, 0xd7, 0xf2, 0x1e, 0xeb, 0x37, 0x95, 0xfe, 0xe5, 0x50, 0x3f, 0x50, - 0xfa, 0x2d, 0x65, 0xfe, 0xe5, 0x50, 0x3e, 0x99, 0x2e, 0xfb, 0x6d, 0x33, 0xf8, 0xe7, 0x99, 0xb2, - 0x05, 0xbb, 0xcf, 0x7e, 0xf9, 0xe9, 0x84, 0x8a, 0x93, 0x7c, 0x2c, 0xe3, 0xb5, 0xad, 0x5d, 0xff, - 0x88, 0xa6, 0xa6, 0xb5, 0x5d, 0xb8, 0xbf, 0xad, 0xa2, 0x31, 0xed, 0x66, 0xe3, 0x71, 0x43, 0x49, - 0x3e, 0xf9, 0x67, 0x00, 0x00, 0x00, 0xff, 0xff, 0xec, 0xd1, 0x6e, 0x73, 0x59, 0x17, 0x00, 0x00, + // 1966 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xd4, 0x58, 0xcb, 0x6e, 0x1c, 0xb9, + 0x15, 0x4d, 0xf5, 0xbb, 0x6f, 0xb7, 0x5a, 0x2d, 0x5a, 0x76, 
0xca, 0x8f, 0x19, 0x6b, 0x2a, 0x2f, + 0x65, 0x82, 0xb1, 0x26, 0x1a, 0xcc, 0x38, 0x8b, 0x20, 0x81, 0xa5, 0xb2, 0x85, 0xc6, 0xb4, 0x1d, + 0xb9, 0xda, 0x19, 0x20, 0xd9, 0x14, 0xd8, 0x55, 0x54, 0x8b, 0x71, 0x55, 0xb1, 0x44, 0xb2, 0x64, + 0xcb, 0xbb, 0x00, 0xd9, 0x05, 0xc8, 0x2e, 0x9b, 0x00, 0xc9, 0x1f, 0x64, 0x3d, 0xc8, 0x2a, 0x7f, + 0x90, 0x1f, 0xc9, 0x1f, 0xcc, 0x2a, 0xe0, 0xa3, 0xfa, 0xa5, 0x96, 0x22, 0xcb, 0x79, 0xcc, 0xec, + 0x8a, 0xe7, 0x1e, 0x5e, 0x92, 0x97, 0x97, 0x87, 0x97, 0x05, 0x3d, 0x9a, 0x49, 0xc2, 0x33, 0x9c, + 0x3c, 0xc8, 0x39, 0x93, 0x0c, 0xdd, 0x4c, 0x69, 0x72, 0x5a, 0x08, 0xd3, 0x7a, 0x50, 0x1a, 0xef, + 0x74, 0x23, 0x96, 0xa6, 0x2c, 0x33, 0xf0, 0x9d, 0xae, 0x88, 0x8e, 0x49, 0x8a, 0x4d, 0xcb, 0xbb, + 0x0b, 0xb7, 0x0f, 0x88, 0x7c, 0x41, 0x53, 0xf2, 0x82, 0x46, 0x2f, 0xf7, 0x8f, 0x71, 0x96, 0x91, + 0x24, 0x20, 0x27, 0x05, 0x11, 0xd2, 0x7b, 0x0f, 0xee, 0x1e, 0x10, 0x39, 0x92, 0x58, 0x52, 0x21, + 0x69, 0x24, 0x96, 0xcc, 0x37, 0xe1, 0xc6, 0x01, 0x91, 0x7e, 0xbc, 0x04, 0x7f, 0x01, 0xad, 0x67, + 0x2c, 0x26, 0x83, 0xec, 0x88, 0xa1, 0xcf, 0xa0, 0x89, 0xe3, 0x98, 0x13, 0x21, 0x5c, 0x67, 0xcb, + 0xd9, 0xee, 0xec, 0xde, 0x7b, 0xb0, 0x30, 0x47, 0x3b, 0xb3, 0x47, 0x86, 0x13, 0x94, 0x64, 0x84, + 0xa0, 0xc6, 0x59, 0x42, 0xdc, 0xca, 0x96, 0xb3, 0xdd, 0x0e, 0xf4, 0xb7, 0xf7, 0x1b, 0x80, 0x41, + 0x46, 0xe5, 0x21, 0xe6, 0x38, 0x15, 0xe8, 0x16, 0x34, 0x32, 0x35, 0x8a, 0xaf, 0x1d, 0x57, 0x03, + 0xdb, 0x42, 0x3e, 0x74, 0x85, 0xc4, 0x5c, 0x86, 0xb9, 0xe6, 0xb9, 0x95, 0xad, 0xea, 0x76, 0x67, + 0xf7, 0x83, 0x95, 0xc3, 0x7e, 0x4e, 0xce, 0xbe, 0xc0, 0x49, 0x41, 0x0e, 0x31, 0xe5, 0x41, 0x47, + 0x77, 0x33, 0xde, 0xbd, 0x5f, 0x01, 0x8c, 0x24, 0xa7, 0xd9, 0x64, 0x48, 0x85, 0x54, 0x63, 0x9d, + 0x2a, 0x9e, 0x5a, 0x44, 0x75, 0xbb, 0x1d, 0xd8, 0x16, 0xfa, 0x04, 0x1a, 0x42, 0x62, 0x59, 0x08, + 0x3d, 0xcf, 0xce, 0xee, 0xdd, 0x95, 0xa3, 0x8c, 0x34, 0x25, 0xb0, 0x54, 0xef, 0xaf, 0x15, 0xd8, + 0x5c, 0x88, 0xaa, 0x8d, 0x1b, 0xfa, 0x18, 0x6a, 0x63, 0x2c, 0xc8, 0xa5, 0x81, 0x7a, 0x2a, 0x26, + 
0x7b, 0x58, 0x90, 0x40, 0x33, 0x55, 0x94, 0xe2, 0xf1, 0xc0, 0xd7, 0xa3, 0x57, 0x03, 0xfd, 0x8d, + 0x3c, 0xe8, 0x46, 0x2c, 0x49, 0x48, 0x24, 0x29, 0xcb, 0x06, 0xbe, 0x5b, 0xd5, 0xb6, 0x05, 0x4c, + 0x71, 0x72, 0xcc, 0x25, 0x35, 0x4d, 0xe1, 0xd6, 0xb6, 0xaa, 0x8a, 0x33, 0x8f, 0xa1, 0x1f, 0x42, + 0x5f, 0x72, 0x7c, 0x4a, 0x92, 0x50, 0xd2, 0x94, 0x08, 0x89, 0xd3, 0xdc, 0xad, 0x6f, 0x39, 0xdb, + 0xb5, 0x60, 0xdd, 0xe0, 0x2f, 0x4a, 0x18, 0xed, 0xc0, 0x8d, 0x49, 0x81, 0x39, 0xce, 0x24, 0x21, + 0x73, 0xec, 0x86, 0x66, 0xa3, 0xa9, 0x69, 0xd6, 0xe1, 0x47, 0xb0, 0xa1, 0x68, 0xac, 0x90, 0x73, + 0xf4, 0xa6, 0xa6, 0xf7, 0xad, 0x61, 0x4a, 0xf6, 0xbe, 0x74, 0xe0, 0xe6, 0x52, 0xbc, 0x44, 0xce, + 0x32, 0x41, 0xae, 0x11, 0xb0, 0xeb, 0x6c, 0x18, 0x7a, 0x08, 0x75, 0xf5, 0x25, 0xdc, 0xea, 0x55, + 0x53, 0xc9, 0xf0, 0xbd, 0xbf, 0x38, 0x80, 0xf6, 0x39, 0xc1, 0x92, 0x3c, 0x4a, 0x28, 0x7e, 0x87, + 0x7d, 0xfe, 0x36, 0x34, 0xe3, 0x71, 0x98, 0xe1, 0xb4, 0x3c, 0x10, 0x8d, 0x78, 0xfc, 0x0c, 0xa7, + 0x04, 0xfd, 0x00, 0xd6, 0x67, 0x1b, 0x6b, 0x08, 0x55, 0x4d, 0xe8, 0xcd, 0x60, 0x4d, 0xdc, 0x84, + 0x3a, 0x56, 0x73, 0x70, 0x6b, 0xda, 0x6c, 0x1a, 0x9e, 0x80, 0xbe, 0xcf, 0x59, 0xfe, 0xdf, 0x9a, + 0xdd, 0x74, 0xd0, 0xea, 0xfc, 0xa0, 0x7f, 0x76, 0x60, 0xe3, 0x51, 0x22, 0x09, 0xff, 0x9a, 0x06, + 0xe5, 0xef, 0x95, 0x72, 0xd7, 0x06, 0x59, 0x4c, 0x5e, 0xff, 0x3f, 0x27, 0xf8, 0x1e, 0xc0, 0x11, + 0x25, 0x49, 0x6c, 0x38, 0x66, 0x96, 0x6d, 0x8d, 0x68, 0x73, 0x79, 0xfc, 0xeb, 0x97, 0x1c, 0xff, + 0xc6, 0x8a, 0xe3, 0xef, 0x42, 0x53, 0x3b, 0x19, 0xf8, 0xfa, 0xd0, 0x55, 0x83, 0xb2, 0xa9, 0xc4, + 0x93, 0xbc, 0x96, 0x1c, 0x97, 0xe2, 0xd9, 0xba, 0xb2, 0x78, 0xea, 0x6e, 0x56, 0x3c, 0xff, 0x59, + 0x83, 0xb5, 0x11, 0xc1, 0x3c, 0x3a, 0xbe, 0x7e, 0xf0, 0x36, 0xa1, 0xce, 0xc9, 0xc9, 0x54, 0xdb, + 0x4c, 0x63, 0xba, 0xe2, 0xea, 0x25, 0x2b, 0xae, 0x5d, 0x41, 0xf0, 0xea, 0x2b, 0x04, 0xaf, 0x0f, + 0xd5, 0x58, 0x24, 0x3a, 0x60, 0xed, 0x40, 0x7d, 0x2a, 0x99, 0xca, 0x13, 0x1c, 0x91, 0x63, 0x96, + 0xc4, 0x84, 0x87, 0x13, 0xce, 0x0a, 
0x23, 0x53, 0xdd, 0xa0, 0x3f, 0x67, 0x38, 0x50, 0x38, 0x7a, + 0x08, 0xad, 0x58, 0x24, 0xa1, 0x3c, 0xcb, 0x89, 0xdb, 0xda, 0x72, 0xb6, 0x7b, 0x17, 0x2c, 0xd3, + 0x17, 0xc9, 0x8b, 0xb3, 0x9c, 0x04, 0xcd, 0xd8, 0x7c, 0xa0, 0x8f, 0x61, 0x53, 0x10, 0x4e, 0x71, + 0x42, 0xdf, 0x90, 0x38, 0x24, 0xaf, 0x73, 0x1e, 0xe6, 0x09, 0xce, 0xdc, 0xb6, 0x1e, 0x08, 0xcd, + 0x6c, 0x8f, 0x5f, 0xe7, 0xfc, 0x30, 0xc1, 0x19, 0xda, 0x86, 0x3e, 0x2b, 0x64, 0x5e, 0xc8, 0x50, + 0xef, 0x9b, 0x08, 0x69, 0xec, 0x82, 0x5e, 0x51, 0xcf, 0xe0, 0x4f, 0x34, 0x3c, 0x88, 0x2f, 0x52, + 0xe6, 0xee, 0xdb, 0x29, 0xf3, 0xda, 0x6a, 0x65, 0x46, 0x3d, 0xa8, 0x64, 0x27, 0x6e, 0x4f, 0xc7, + 0xbb, 0x92, 0x9d, 0xa8, 0xdd, 0x91, 0x2c, 0x7f, 0xe9, 0xae, 0x9b, 0xdd, 0x51, 0xdf, 0xe8, 0x7d, + 0x80, 0x94, 0x48, 0x4e, 0x23, 0xb5, 0x56, 0xb7, 0xaf, 0x83, 0x3b, 0x87, 0xa0, 0xef, 0xc2, 0x1a, + 0x9d, 0x64, 0x8c, 0x93, 0x03, 0xce, 0x5e, 0xd1, 0x6c, 0xe2, 0x6e, 0x6c, 0x39, 0xdb, 0xad, 0x60, + 0x11, 0x44, 0x77, 0xa0, 0x55, 0x08, 0x55, 0xcc, 0xa4, 0xc4, 0x45, 0xda, 0xc7, 0xb4, 0xed, 0xfd, + 0x63, 0x2e, 0xdb, 0x44, 0x91, 0x48, 0xf1, 0xbf, 0xba, 0x17, 0xa6, 0x29, 0x5a, 0x9d, 0x4f, 0xd1, + 0xfb, 0xd0, 0x31, 0xcb, 0x33, 0xa9, 0x50, 0x3b, 0xb7, 0xe2, 0xfb, 0xd0, 0xc9, 0x8a, 0x34, 0x3c, + 0x29, 0x08, 0xa7, 0x44, 0xd8, 0xc3, 0x0b, 0x59, 0x91, 0x3e, 0x37, 0x08, 0xba, 0x01, 0x75, 0xc9, + 0xf2, 0xf0, 0xa5, 0x3d, 0xbb, 0x2a, 0x8e, 0x9f, 0xa3, 0x9f, 0xc2, 0x1d, 0x41, 0x70, 0x42, 0xe2, + 0x50, 0x90, 0x49, 0x4a, 0x32, 0x39, 0xf0, 0x45, 0x28, 0xf4, 0xb2, 0x49, 0xec, 0x36, 0xf5, 0xee, + 0xbb, 0x86, 0x31, 0x9a, 0x12, 0x46, 0xd6, 0xae, 0xf2, 0x20, 0x32, 0x45, 0xda, 0x42, 0xb7, 0x96, + 0xae, 0x66, 0xd0, 0xcc, 0x34, 0xed, 0xf0, 0x13, 0x70, 0x27, 0x09, 0x1b, 0xe3, 0x24, 0x3c, 0x37, + 0xaa, 0xdb, 0xd6, 0x83, 0xdd, 0x32, 0xf6, 0xd1, 0xd2, 0x90, 0x6a, 0x79, 0x22, 0xa1, 0x11, 0x89, + 0xc3, 0x71, 0xc2, 0xc6, 0x2e, 0xe8, 0x2c, 0x06, 0x03, 0xed, 0x25, 0x6c, 0xac, 0xb2, 0xd7, 0x12, + 0x54, 0x18, 0x22, 0x56, 0x64, 0xd2, 0xed, 0xe8, 0x95, 0xf6, 0x0c, 0xfe, 
0xac, 0x48, 0xf7, 0x15, + 0x8a, 0xbe, 0x03, 0x6b, 0x96, 0xc9, 0x8e, 0x8e, 0x04, 0x91, 0x3a, 0x6f, 0xab, 0x41, 0xd7, 0x80, + 0xbf, 0xd0, 0x18, 0x3a, 0x54, 0x62, 0x2a, 0xe4, 0xa3, 0xc9, 0x84, 0x93, 0x09, 0x56, 0x87, 0x59, + 0xe7, 0x6b, 0x67, 0xf7, 0xfb, 0x0f, 0x56, 0x56, 0xc3, 0x0f, 0xf6, 0x17, 0xd9, 0xc1, 0x72, 0x77, + 0xef, 0x04, 0xd6, 0x97, 0x38, 0x4a, 0x3f, 0xb8, 0xad, 0x3a, 0x54, 0xfa, 0xdb, 0x92, 0x73, 0x01, + 0x43, 0x5b, 0xd0, 0x11, 0x84, 0x9f, 0xd2, 0xc8, 0x50, 0x8c, 0x6e, 0xcd, 0x43, 0x4a, 0x77, 0x25, + 0x93, 0x38, 0x79, 0xf6, 0xdc, 0xa6, 0x4c, 0xd9, 0xf4, 0x7e, 0x5b, 0x87, 0xf5, 0x40, 0xa5, 0x08, + 0x39, 0x25, 0xdf, 0x24, 0xcd, 0xbc, 0x48, 0xbb, 0x1a, 0x6f, 0xa5, 0x5d, 0xcd, 0x95, 0xda, 0xf5, + 0x3d, 0xe8, 0xa5, 0xa7, 0x51, 0x34, 0xa7, 0x43, 0x2d, 0xad, 0x43, 0x6b, 0x0a, 0xfd, 0xb7, 0xc5, + 0x67, 0xfb, 0xed, 0x24, 0x0e, 0x2e, 0x90, 0xb8, 0x4d, 0xa8, 0x27, 0x34, 0xa5, 0x65, 0x86, 0x9a, + 0xc6, 0x79, 0xd1, 0xea, 0xae, 0x12, 0xad, 0xdb, 0xd0, 0xa2, 0xc2, 0x26, 0xf8, 0x9a, 0x26, 0x34, + 0xa9, 0x30, 0x99, 0xfd, 0x18, 0xee, 0x53, 0x49, 0xb8, 0x4e, 0xae, 0x90, 0xbc, 0x96, 0x24, 0x13, + 0xea, 0x8b, 0x93, 0xb8, 0x88, 0x48, 0xc8, 0xb1, 0x24, 0x56, 0x56, 0xef, 0x4d, 0x69, 0x8f, 0x4b, + 0x56, 0xa0, 0x49, 0x01, 0x96, 0x64, 0x41, 0x16, 0xd7, 0x17, 0x65, 0x11, 0xed, 0xc0, 0xa6, 0x75, + 0x27, 0x94, 0x9a, 0x1c, 0x31, 0x1e, 0x8e, 0x89, 0x90, 0x5a, 0x82, 0x5b, 0xc1, 0x86, 0xb1, 0x8d, + 0x24, 0xcb, 0x9f, 0x30, 0xbe, 0xa7, 0x9e, 0x6d, 0x5f, 0x55, 0xe7, 0x73, 0xf0, 0x6b, 0xa0, 0xa4, + 0x1f, 0x42, 0x95, 0xc6, 0xa6, 0x38, 0xeb, 0xec, 0xba, 0x8b, 0x7e, 0xec, 0x1b, 0x76, 0xe0, 0x8b, + 0x40, 0x91, 0xd0, 0xcf, 0xa1, 0x63, 0xf3, 0x29, 0xc6, 0x12, 0xeb, 0x5c, 0xed, 0xec, 0xbe, 0xbf, + 0xb2, 0x8f, 0x4e, 0x30, 0x1f, 0x4b, 0x1c, 0x98, 0xe2, 0x4a, 0xa8, 0x6f, 0xf4, 0x33, 0xb8, 0x7b, + 0x5e, 0x5f, 0xb9, 0x0d, 0x47, 0xec, 0x36, 0x74, 0x8a, 0xde, 0x5e, 0x16, 0xd8, 0x32, 0x5e, 0x31, + 0xfa, 0x31, 0x6c, 0xce, 0x29, 0xec, 0xac, 0x63, 0x53, 0x4b, 0xec, 0x9c, 0xfa, 0xce, 0xba, 0x5c, + 0xa6, 0xb1, 
0xad, 0x4b, 0x35, 0xf6, 0x3f, 0xaf, 0x79, 0x5f, 0x39, 0xd0, 0x1e, 0x32, 0x1c, 0xeb, + 0x92, 0xf7, 0x1a, 0xdb, 0x7e, 0x0f, 0xda, 0xd3, 0xd9, 0x5b, 0xf9, 0x99, 0x01, 0xca, 0x3a, 0xad, + 0x5a, 0x6d, 0xa9, 0x3b, 0x57, 0xc6, 0xce, 0x95, 0xa3, 0xb5, 0xc5, 0x72, 0xf4, 0x3e, 0x74, 0xa8, + 0x9a, 0x50, 0x98, 0x63, 0x79, 0x6c, 0x14, 0xa8, 0x1d, 0x80, 0x86, 0x0e, 0x15, 0xa2, 0xea, 0xd5, + 0x92, 0xa0, 0xeb, 0xd5, 0xc6, 0x95, 0xeb, 0x55, 0xeb, 0x44, 0xd7, 0xab, 0xbf, 0x73, 0x00, 0xf4, + 0xc2, 0x55, 0x5a, 0x9e, 0x77, 0xea, 0x5c, 0xc7, 0xa9, 0x92, 0x46, 0x75, 0xbf, 0x71, 0x92, 0x60, + 0x39, 0xdb, 0x5b, 0x61, 0x83, 0x83, 0xb2, 0x22, 0x0d, 0x8c, 0xc9, 0xee, 0xab, 0xf0, 0xfe, 0xe0, + 0x00, 0xe8, 0xe4, 0x34, 0xd3, 0x58, 0xd6, 0x68, 0xe7, 0xf2, 0x4a, 0xbe, 0xb2, 0x18, 0xba, 0xbd, + 0x32, 0x74, 0x97, 0x3c, 0x5d, 0xa7, 0xe9, 0x31, 0x5b, 0xbc, 0x8d, 0xae, 0xfe, 0xf6, 0xfe, 0xe8, + 0x40, 0xd7, 0xce, 0xce, 0x4c, 0x69, 0x61, 0x97, 0x9d, 0xe5, 0x5d, 0xd6, 0x95, 0x4f, 0xca, 0xf8, + 0x59, 0x28, 0xe8, 0x9b, 0xf2, 0x02, 0x04, 0x03, 0x8d, 0xe8, 0x1b, 0xa2, 0x04, 0x51, 0x87, 0x84, + 0xbd, 0x12, 0xe5, 0x05, 0xa8, 0xc2, 0xc0, 0x5e, 0x09, 0x25, 0xca, 0x9c, 0x44, 0x24, 0x93, 0xc9, + 0x59, 0x98, 0xb2, 0x98, 0x1e, 0x51, 0x12, 0xeb, 0x6c, 0x68, 0x05, 0xfd, 0xd2, 0xf0, 0xd4, 0xe2, + 0xde, 0x97, 0xea, 0x5d, 0x6d, 0x0e, 0x54, 0xf9, 0xe3, 0xea, 0xa9, 0x98, 0x5c, 0x23, 0x6b, 0x55, + 0x88, 0x8d, 0x1f, 0x95, 0x88, 0xe6, 0x5f, 0x51, 0x3b, 0x58, 0xc0, 0x54, 0x01, 0x3b, 0xbd, 0x26, + 0x4c, 0x1c, 0x6b, 0xc1, 0x1c, 0xa2, 0x66, 0x1e, 0x93, 0x23, 0x5c, 0x24, 0xf3, 0xd7, 0x49, 0xcd, + 0x5c, 0x27, 0xd6, 0xb0, 0xf0, 0x2f, 0xa3, 0xb7, 0xcf, 0x49, 0x4c, 0x32, 0x49, 0x71, 0xa2, 0xff, + 0x90, 0xcd, 0x6b, 0xb8, 0xb3, 0xa4, 0xe1, 0x1f, 0x01, 0x22, 0x59, 0xc4, 0xcf, 0x72, 0x95, 0x41, + 0x39, 0x16, 0xe2, 0x15, 0xe3, 0xb1, 0x7d, 0x4c, 0x6e, 0x4c, 0x2d, 0x87, 0xd6, 0x80, 0x6e, 0x41, + 0x43, 0x92, 0x0c, 0x67, 0xd2, 0x9e, 0x31, 0xdb, 0xb2, 0x17, 0x91, 0x28, 0x72, 0xc2, 0x6d, 0x4c, + 0x9b, 0x54, 0x8c, 0x54, 0x53, 0x3d, 0x45, 0xc5, 
0x31, 0xde, 0xfd, 0xf4, 0xb3, 0x99, 0xfb, 0xba, + 0x79, 0x8a, 0x1a, 0xb8, 0xf4, 0xed, 0x3d, 0x86, 0x8d, 0x21, 0x15, 0xf2, 0x90, 0x25, 0x34, 0x3a, + 0xbb, 0x76, 0x89, 0xe2, 0xfd, 0xde, 0x01, 0x34, 0xef, 0xc7, 0xfe, 0xc9, 0x99, 0xdd, 0x1a, 0xce, + 0xd5, 0x6f, 0x8d, 0x0f, 0xa0, 0x9b, 0x6b, 0x37, 0x21, 0xcd, 0x8e, 0x58, 0xb9, 0x7b, 0x1d, 0x83, + 0xa9, 0xd8, 0x0a, 0xf5, 0x80, 0x56, 0xc1, 0x0c, 0x39, 0x4b, 0x88, 0xd9, 0xbc, 0x76, 0xd0, 0x56, + 0x48, 0xa0, 0x00, 0x6f, 0x02, 0xb7, 0x47, 0xc7, 0xec, 0xd5, 0x3e, 0xcb, 0x8e, 0xe8, 0xa4, 0x30, + 0xf7, 0xec, 0x3b, 0xfc, 0x91, 0x70, 0xa1, 0x99, 0x63, 0xa9, 0xce, 0x94, 0xdd, 0xa3, 0xb2, 0xe9, + 0xfd, 0xc9, 0x81, 0x3b, 0xab, 0x46, 0x7a, 0x97, 0xe5, 0x1f, 0xc0, 0x5a, 0x64, 0xdc, 0x19, 0x6f, + 0x57, 0xff, 0xd3, 0xb9, 0xd8, 0xcf, 0x7b, 0x0c, 0x35, 0x5d, 0x4d, 0xec, 0x40, 0x85, 0x4b, 0x3d, + 0x83, 0xde, 0xee, 0xfd, 0x0b, 0x94, 0x42, 0x11, 0xf5, 0xf3, 0xb5, 0xc2, 0x25, 0xea, 0x82, 0xc3, + 0xf5, 0x4a, 0x9d, 0xc0, 0xe1, 0x1f, 0xfe, 0xcd, 0x81, 0x56, 0x69, 0x46, 0x1b, 0xb0, 0xe6, 0xfb, + 0xc3, 0xfd, 0xa9, 0x56, 0xf5, 0xbf, 0x85, 0xfa, 0xd0, 0xf5, 0xfd, 0xe1, 0x61, 0x59, 0x3e, 0xf6, + 0x1d, 0xd4, 0x85, 0x96, 0xef, 0x0f, 0xb5, 0xf8, 0xf4, 0x2b, 0xb6, 0xf5, 0x24, 0x29, 0xc4, 0x71, + 0xbf, 0x3a, 0x75, 0x90, 0xe6, 0xd8, 0x38, 0xa8, 0xa1, 0x35, 0x68, 0xfb, 0x4f, 0x87, 0x83, 0x4c, + 0x10, 0x2e, 0xfb, 0x75, 0xdb, 0xf4, 0x49, 0x42, 0x24, 0xe9, 0x37, 0xd0, 0x3a, 0x74, 0xfc, 0xa7, + 0xc3, 0xbd, 0x22, 0x79, 0xa9, 0xee, 0xb1, 0x7e, 0x53, 0xdb, 0x9f, 0x0f, 0xcd, 0x8b, 0xa6, 0xdf, + 0xd2, 0xee, 0x9f, 0x0f, 0xd5, 0x1b, 0xeb, 0xac, 0xdf, 0xb6, 0x9d, 0x7f, 0x99, 0x6b, 0x5f, 0xb0, + 0xf7, 0xf0, 0xd7, 0x9f, 0x4e, 0xa8, 0x3c, 0x2e, 0xc6, 0x2a, 0x5e, 0x3b, 0x66, 0xe9, 0x1f, 0x51, + 0x66, 0xbf, 0x76, 0xca, 0xe5, 0xef, 0xe8, 0x68, 0x4c, 0x9b, 0xf9, 0x78, 0xdc, 0xd0, 0xc8, 0x27, + 0xff, 0x0a, 0x00, 0x00, 0xff, 0xff, 0x39, 0x3a, 0x8f, 0x34, 0x8a, 0x17, 0x00, 0x00, } diff --git a/internal/proto/plan.proto b/internal/proto/plan.proto index 
1521385c6689b..b80fd3f349120 100644 --- a/internal/proto/plan.proto +++ b/internal/proto/plan.proto @@ -27,6 +27,13 @@ enum ArithOpType { Mul = 3; Div = 4; Mod = 5; + ArrayLength = 6; +}; + +enum VectorType { + BinaryVector = 0; + FloatVector = 1; + Float16Vector = 2; }; message GenericValue { @@ -42,6 +49,7 @@ message GenericValue { message Array { repeated GenericValue array = 1; bool same_type = 2; + schema.DataType element_type = 3; } message QueryInfo { @@ -58,6 +66,7 @@ message ColumnInfo { bool is_autoID = 4; repeated string nested_path = 5; bool is_partition_key = 6; + schema.DataType element_type = 7; } message ColumnExpr { @@ -102,9 +111,9 @@ message JSONContainsExpr { ColumnInfo column_info = 1; repeated GenericValue elements = 2; // 0: invalid - // 1: json_contains - // 2: json_contains_all - // 3: json_contains_any + // 1: json_contains | array_contains + // 2: json_contains_all | array_contains_all + // 3: json_contains_any | array_contains_any enum JSONOp { Invalid = 0; Contains = 1; @@ -176,7 +185,7 @@ message Expr { } message VectorANNS { - bool is_binary = 1; + VectorType vector_type = 1; int64 field_id = 2; Expr predicates = 3; QueryInfo query_info = 4; diff --git a/internal/proto/planpb/plan.pb.go b/internal/proto/planpb/plan.pb.go index 3a8466533dc68..2edfb4c92706d 100644 --- a/internal/proto/planpb/plan.pb.go +++ b/internal/proto/planpb/plan.pb.go @@ -82,12 +82,13 @@ func (OpType) EnumDescriptor() ([]byte, []int) { type ArithOpType int32 const ( - ArithOpType_Unknown ArithOpType = 0 - ArithOpType_Add ArithOpType = 1 - ArithOpType_Sub ArithOpType = 2 - ArithOpType_Mul ArithOpType = 3 - ArithOpType_Div ArithOpType = 4 - ArithOpType_Mod ArithOpType = 5 + ArithOpType_Unknown ArithOpType = 0 + ArithOpType_Add ArithOpType = 1 + ArithOpType_Sub ArithOpType = 2 + ArithOpType_Mul ArithOpType = 3 + ArithOpType_Div ArithOpType = 4 + ArithOpType_Mod ArithOpType = 5 + ArithOpType_ArrayLength ArithOpType = 6 ) var ArithOpType_name = map[int32]string{ @@ 
-97,15 +98,17 @@ var ArithOpType_name = map[int32]string{ 3: "Mul", 4: "Div", 5: "Mod", + 6: "ArrayLength", } var ArithOpType_value = map[string]int32{ - "Unknown": 0, - "Add": 1, - "Sub": 2, - "Mul": 3, - "Div": 4, - "Mod": 5, + "Unknown": 0, + "Add": 1, + "Sub": 2, + "Mul": 3, + "Div": 4, + "Mod": 5, + "ArrayLength": 6, } func (x ArithOpType) String() string { @@ -116,10 +119,38 @@ func (ArithOpType) EnumDescriptor() ([]byte, []int) { return fileDescriptor_2d655ab2f7683c23, []int{1} } +type VectorType int32 + +const ( + VectorType_BinaryVector VectorType = 0 + VectorType_FloatVector VectorType = 1 + VectorType_Float16Vector VectorType = 2 +) + +var VectorType_name = map[int32]string{ + 0: "BinaryVector", + 1: "FloatVector", + 2: "Float16Vector", +} + +var VectorType_value = map[string]int32{ + "BinaryVector": 0, + "FloatVector": 1, + "Float16Vector": 2, +} + +func (x VectorType) String() string { + return proto.EnumName(VectorType_name, int32(x)) +} + +func (VectorType) EnumDescriptor() ([]byte, []int) { + return fileDescriptor_2d655ab2f7683c23, []int{2} +} + // 0: invalid -// 1: json_contains -// 2: json_contains_all -// 3: json_contains_any +// 1: json_contains | array_contains +// 2: json_contains_all | array_contains_all +// 3: json_contains_any | array_contains_any type JSONContainsExpr_JSONOp int32 const ( @@ -331,11 +362,12 @@ func (*GenericValue) XXX_OneofWrappers() []interface{} { } type Array struct { - Array []*GenericValue `protobuf:"bytes,1,rep,name=array,proto3" json:"array,omitempty"` - SameType bool `protobuf:"varint,2,opt,name=same_type,json=sameType,proto3" json:"same_type,omitempty"` - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` + Array []*GenericValue `protobuf:"bytes,1,rep,name=array,proto3" json:"array,omitempty"` + SameType bool `protobuf:"varint,2,opt,name=same_type,json=sameType,proto3" json:"same_type,omitempty"` + ElementType schemapb.DataType 
`protobuf:"varint,3,opt,name=element_type,json=elementType,proto3,enum=milvus.proto.schema.DataType" json:"element_type,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` } func (m *Array) Reset() { *m = Array{} } @@ -377,6 +409,13 @@ func (m *Array) GetSameType() bool { return false } +func (m *Array) GetElementType() schemapb.DataType { + if m != nil { + return m.ElementType + } + return schemapb.DataType_None +} + type QueryInfo struct { Topk int64 `protobuf:"varint,1,opt,name=topk,proto3" json:"topk,omitempty"` MetricType string `protobuf:"bytes,3,opt,name=metric_type,json=metricType,proto3" json:"metric_type,omitempty"` @@ -447,6 +486,7 @@ type ColumnInfo struct { IsAutoID bool `protobuf:"varint,4,opt,name=is_autoID,json=isAutoID,proto3" json:"is_autoID,omitempty"` NestedPath []string `protobuf:"bytes,5,rep,name=nested_path,json=nestedPath,proto3" json:"nested_path,omitempty"` IsPartitionKey bool `protobuf:"varint,6,opt,name=is_partition_key,json=isPartitionKey,proto3" json:"is_partition_key,omitempty"` + ElementType schemapb.DataType `protobuf:"varint,7,opt,name=element_type,json=elementType,proto3,enum=milvus.proto.schema.DataType" json:"element_type,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` @@ -519,6 +559,13 @@ func (m *ColumnInfo) GetIsPartitionKey() bool { return false } +func (m *ColumnInfo) GetElementType() schemapb.DataType { + if m != nil { + return m.ElementType + } + return schemapb.DataType_None +} + type ColumnExpr struct { Info *ColumnInfo `protobuf:"bytes,1,opt,name=info,proto3" json:"info,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` @@ -1496,7 +1543,7 @@ func (*Expr) XXX_OneofWrappers() []interface{} { } type VectorANNS struct { - IsBinary bool `protobuf:"varint,1,opt,name=is_binary,json=isBinary,proto3" json:"is_binary,omitempty"` + VectorType VectorType 
`protobuf:"varint,1,opt,name=vector_type,json=vectorType,proto3,enum=milvus.proto.plan.VectorType" json:"vector_type,omitempty"` FieldId int64 `protobuf:"varint,2,opt,name=field_id,json=fieldId,proto3" json:"field_id,omitempty"` Predicates *Expr `protobuf:"bytes,3,opt,name=predicates,proto3" json:"predicates,omitempty"` QueryInfo *QueryInfo `protobuf:"bytes,4,opt,name=query_info,json=queryInfo,proto3" json:"query_info,omitempty"` @@ -1531,11 +1578,11 @@ func (m *VectorANNS) XXX_DiscardUnknown() { var xxx_messageInfo_VectorANNS proto.InternalMessageInfo -func (m *VectorANNS) GetIsBinary() bool { +func (m *VectorANNS) GetVectorType() VectorType { if m != nil { - return m.IsBinary + return m.VectorType } - return false + return VectorType_BinaryVector } func (m *VectorANNS) GetFieldId() int64 { @@ -1728,6 +1775,7 @@ func (*PlanNode) XXX_OneofWrappers() []interface{} { func init() { proto.RegisterEnum("milvus.proto.plan.OpType", OpType_name, OpType_value) proto.RegisterEnum("milvus.proto.plan.ArithOpType", ArithOpType_name, ArithOpType_value) + proto.RegisterEnum("milvus.proto.plan.VectorType", VectorType_name, VectorType_value) proto.RegisterEnum("milvus.proto.plan.JSONContainsExpr_JSONOp", JSONContainsExpr_JSONOp_name, JSONContainsExpr_JSONOp_value) proto.RegisterEnum("milvus.proto.plan.UnaryExpr_UnaryOp", UnaryExpr_UnaryOp_name, UnaryExpr_UnaryOp_value) proto.RegisterEnum("milvus.proto.plan.BinaryExpr_BinaryOp", BinaryExpr_BinaryOp_name, BinaryExpr_BinaryOp_value) @@ -1758,116 +1806,120 @@ func init() { func init() { proto.RegisterFile("plan.proto", fileDescriptor_2d655ab2f7683c23) } var fileDescriptor_2d655ab2f7683c23 = []byte{ - // 1762 bytes of a gzipped FileDescriptorProto + // 1826 bytes of a gzipped FileDescriptorProto 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xc4, 0x58, 0x4f, 0x93, 0xdb, 0x48, - 0x15, 0xb7, 0x2c, 0x7b, 0x2c, 0x3d, 0x79, 0x3c, 0x8a, 0x8a, 0x2a, 0x26, 0x09, 0x9b, 0x19, 0xb4, - 0x5b, 0xec, 0x10, 0xc8, 0xa4, 0x36, 0xbb, 0x9b, 
0xb0, 0xd9, 0x5a, 0x18, 0xcf, 0x9f, 0x8d, 0xcd, - 0x6e, 0x66, 0x06, 0xcd, 0x6c, 0x0e, 0x70, 0x50, 0xb5, 0xa5, 0x9e, 0x71, 0x13, 0xb9, 0xa5, 0x48, - 0x2d, 0x27, 0xfe, 0x0a, 0xdc, 0xf8, 0x00, 0x9c, 0x38, 0x70, 0xe7, 0xc8, 0x05, 0xce, 0x14, 0x07, - 0x8e, 0x1c, 0xa9, 0xe2, 0x13, 0x50, 0x7c, 0x01, 0xaa, 0x5f, 0x4b, 0x96, 0x3d, 0xd8, 0x19, 0x0f, - 0xa4, 0x8a, 0x5b, 0xeb, 0xf5, 0x7b, 0xbf, 0x7e, 0xff, 0xfa, 0xbd, 0xd7, 0x02, 0x48, 0x22, 0xc2, - 0x77, 0x93, 0x34, 0x16, 0xb1, 0x73, 0x6b, 0xc4, 0xa2, 0x71, 0x9e, 0xa9, 0xaf, 0x5d, 0xb9, 0x71, - 0xa7, 0x9d, 0x05, 0x43, 0x3a, 0x22, 0x8a, 0xe4, 0xfe, 0x59, 0x83, 0xf6, 0x33, 0xca, 0x69, 0xca, - 0x82, 0x17, 0x24, 0xca, 0xa9, 0x73, 0x17, 0x8c, 0x41, 0x1c, 0x47, 0xfe, 0x98, 0x44, 0x9b, 0xda, - 0xb6, 0xb6, 0x63, 0xf4, 0x6a, 0x5e, 0x4b, 0x52, 0x5e, 0x90, 0xc8, 0x79, 0x0f, 0x4c, 0xc6, 0xc5, - 0xe3, 0x4f, 0x70, 0xb7, 0xbe, 0xad, 0xed, 0xe8, 0xbd, 0x9a, 0x67, 0x20, 0xa9, 0xd8, 0xbe, 0x88, - 0x62, 0x22, 0x70, 0x5b, 0xdf, 0xd6, 0x76, 0x34, 0xb9, 0x8d, 0x24, 0xb9, 0xbd, 0x05, 0x90, 0x89, - 0x94, 0xf1, 0x4b, 0xdc, 0x6f, 0x6c, 0x6b, 0x3b, 0x66, 0xaf, 0xe6, 0x99, 0x8a, 0x26, 0x19, 0x9e, - 0x80, 0x49, 0xd2, 0x94, 0x4c, 0x70, 0xbf, 0xb9, 0xad, 0xed, 0x58, 0x8f, 0x36, 0x77, 0xff, 0xc3, - 0x82, 0xdd, 0xae, 0xe4, 0x91, 0xc8, 0xc8, 0xfc, 0x82, 0x44, 0xfb, 0x4d, 0xd0, 0xc7, 0x24, 0x72, - 0x7f, 0x01, 0x4d, 0xdc, 0x73, 0x3e, 0x85, 0x26, 0xee, 0x6d, 0x6a, 0xdb, 0xfa, 0x8e, 0xf5, 0x68, - 0x6b, 0x01, 0xc8, 0xac, 0xd1, 0x9e, 0xe2, 0x76, 0xee, 0x82, 0x99, 0x91, 0x11, 0xf5, 0xc5, 0x24, - 0xa1, 0x68, 0x9e, 0xe1, 0x19, 0x92, 0x70, 0x3e, 0x49, 0xa8, 0xfb, 0x2b, 0x0d, 0xcc, 0x9f, 0xe5, - 0x34, 0x9d, 0xf4, 0xf9, 0x45, 0xec, 0x38, 0xd0, 0x10, 0x71, 0xf2, 0x12, 0x5d, 0xa4, 0x7b, 0xb8, - 0x76, 0xb6, 0xc0, 0x1a, 0x51, 0x91, 0xb2, 0x40, 0x01, 0x48, 0x07, 0x98, 0x1e, 0x28, 0x92, 0x84, - 0x70, 0xde, 0x87, 0xf5, 0x8c, 0x92, 0x34, 0x18, 0xfa, 0x09, 0x49, 0xc9, 0x28, 0x53, 0x3e, 0xf0, - 0xda, 0x8a, 0x78, 0x8a, 0x34, 0xc9, 0x94, 0xc6, 0x39, 0x0f, 0xfd, 0x90, 0x06, 0x6c, 
0x54, 0x38, - 0x42, 0xf7, 0xda, 0x48, 0x3c, 0x54, 0x34, 0xf7, 0x9f, 0x1a, 0xc0, 0x41, 0x1c, 0xe5, 0x23, 0x8e, - 0xda, 0xdc, 0x06, 0xe3, 0x82, 0xd1, 0x28, 0xf4, 0x59, 0x58, 0x68, 0xd4, 0xc2, 0xef, 0x7e, 0xe8, - 0x3c, 0x05, 0x33, 0x24, 0x82, 0x54, 0x36, 0x75, 0x1e, 0xbd, 0x37, 0xef, 0x8e, 0x22, 0x1f, 0x0e, - 0x89, 0x20, 0x52, 0x4b, 0xcf, 0x08, 0x8b, 0x95, 0xf3, 0x01, 0x74, 0x58, 0xe6, 0x27, 0x29, 0x1b, - 0x91, 0x74, 0xe2, 0xbf, 0xa4, 0x13, 0xb4, 0xc9, 0xf0, 0xda, 0x2c, 0x3b, 0x55, 0xc4, 0xaf, 0x28, - 0x7a, 0x8d, 0x65, 0x3e, 0xc9, 0x45, 0xdc, 0x3f, 0x44, 0x8b, 0x0c, 0xcf, 0x60, 0x59, 0x17, 0xbf, - 0xa5, 0x4f, 0x38, 0xcd, 0x04, 0x0d, 0xfd, 0x84, 0x88, 0xe1, 0x66, 0x73, 0x5b, 0x97, 0x3e, 0x51, - 0xa4, 0x53, 0x22, 0x86, 0xce, 0x0e, 0xd8, 0xf2, 0x0c, 0x92, 0x0a, 0x26, 0x58, 0xcc, 0xf1, 0x94, - 0x35, 0x04, 0xe9, 0xb0, 0xec, 0xb4, 0x24, 0x7f, 0x45, 0x27, 0xee, 0x4f, 0x4a, 0x93, 0x8f, 0xde, - 0x24, 0xa9, 0xf3, 0x11, 0x34, 0x18, 0xbf, 0x88, 0xd1, 0x5c, 0xeb, 0xaa, 0x49, 0x18, 0xe1, 0xca, - 0x3f, 0x1e, 0xb2, 0x4a, 0x80, 0xa3, 0x37, 0x2c, 0x13, 0xd9, 0x7f, 0x0b, 0xb0, 0x0f, 0x26, 0xe6, - 0x0b, 0xca, 0x7f, 0x0a, 0xcd, 0xb1, 0xfc, 0x28, 0x00, 0xae, 0xcf, 0x31, 0xe4, 0x76, 0x7f, 0xaf, - 0x41, 0xe7, 0x1b, 0x4e, 0xd2, 0x89, 0x47, 0xf8, 0xa5, 0x42, 0xfa, 0x31, 0x58, 0x01, 0x1e, 0xe5, - 0xaf, 0xae, 0x10, 0x04, 0x55, 0xf4, 0xbf, 0x0f, 0xf5, 0x38, 0x29, 0x62, 0x7b, 0x7b, 0x81, 0xd8, - 0x49, 0x82, 0x71, 0xad, 0xc7, 0x49, 0xa5, 0xb4, 0x7e, 0x23, 0xa5, 0x7f, 0x57, 0x87, 0x8d, 0x7d, - 0xf6, 0x6e, 0xb5, 0xfe, 0x10, 0x36, 0xa2, 0xf8, 0x35, 0x4d, 0x7d, 0xc6, 0x83, 0x28, 0xcf, 0xd8, - 0xb8, 0xbc, 0x72, 0x1d, 0x24, 0xf7, 0x4b, 0xaa, 0x64, 0xcc, 0x93, 0x64, 0x8e, 0x51, 0xa5, 0x61, - 0x07, 0xc9, 0x15, 0xe3, 0x1e, 0x58, 0x0a, 0x51, 0x99, 0xd8, 0x58, 0xcd, 0x44, 0x40, 0x19, 0x55, - 0xfc, 0xf6, 0xc0, 0x52, 0x47, 0x29, 0x84, 0xe6, 0x8a, 0x08, 0x28, 0x83, 0x6b, 0xf7, 0x2f, 0x1a, - 0x58, 0x07, 0xf1, 0x28, 0x21, 0xa9, 0xf2, 0xd2, 0x33, 0xb0, 0x23, 0x7a, 0x21, 0xfc, 0x1b, 0xbb, - 0xaa, 0x23, 0xc5, 0x66, 
0xae, 0x78, 0x1f, 0x6e, 0xa5, 0xec, 0x72, 0x38, 0x8f, 0x54, 0x5f, 0x05, - 0x69, 0x03, 0xe5, 0x0e, 0xae, 0xe6, 0x8b, 0xbe, 0x42, 0xbe, 0xb8, 0xbf, 0xd5, 0xc0, 0x38, 0xa7, - 0xe9, 0xe8, 0x9d, 0x44, 0xfc, 0x09, 0xac, 0xa1, 0x5f, 0xb3, 0xcd, 0xfa, 0x6a, 0x65, 0xb9, 0x60, - 0x77, 0xee, 0x81, 0xc5, 0x32, 0x9f, 0x71, 0x1f, 0x8b, 0x5a, 0x11, 0x7d, 0x93, 0x65, 0x7d, 0xfe, - 0xa5, 0x24, 0xb8, 0x7f, 0xaa, 0x83, 0xfd, 0xd3, 0xb3, 0x93, 0xe3, 0x83, 0x98, 0x0b, 0xc2, 0x78, - 0xf6, 0x4e, 0xb4, 0xfd, 0x1c, 0x0c, 0x1a, 0xd1, 0x11, 0xe5, 0x62, 0x65, 0x7d, 0xa7, 0x02, 0xce, - 0xd3, 0x19, 0x17, 0xdf, 0x5f, 0x20, 0x76, 0x55, 0x5b, 0x24, 0x9c, 0x24, 0x78, 0x47, 0x7f, 0x08, - 0x4e, 0x89, 0xe3, 0x57, 0xed, 0x48, 0x15, 0x56, 0xbb, 0xdc, 0x39, 0x2b, 0xdb, 0xd2, 0x11, 0xac, - 0x29, 0x59, 0xc7, 0x82, 0x56, 0x9f, 0x8f, 0x49, 0xc4, 0x42, 0xbb, 0xe6, 0xb4, 0xc1, 0x28, 0xf1, - 0x6d, 0xcd, 0xd9, 0x90, 0x49, 0xa9, 0xbe, 0xba, 0x51, 0x64, 0xd7, 0xe7, 0x08, 0x7c, 0x62, 0xeb, - 0xee, 0xaf, 0x35, 0x30, 0xb1, 0x2c, 0xa1, 0xef, 0x3e, 0x41, 0xf5, 0x35, 0x54, 0xff, 0x83, 0x05, - 0xea, 0x4f, 0x39, 0xd5, 0xaa, 0x50, 0xfc, 0x01, 0x34, 0x83, 0x21, 0x8b, 0xc2, 0x22, 0x2d, 0xbf, - 0xbd, 0x40, 0x50, 0xca, 0x78, 0x8a, 0xcb, 0xdd, 0x82, 0x56, 0x21, 0x3d, 0xaf, 0x7a, 0x0b, 0xf4, - 0xe3, 0x58, 0xd8, 0x9a, 0xfb, 0x37, 0x0d, 0x40, 0x55, 0x1d, 0x54, 0xea, 0xf1, 0x8c, 0x52, 0xdf, - 0x5b, 0x80, 0x5d, 0xb1, 0x16, 0xcb, 0x42, 0xad, 0x1f, 0x40, 0x43, 0xde, 0xa5, 0xeb, 0xb4, 0x42, - 0x26, 0x69, 0x03, 0x5e, 0x97, 0xa2, 0x40, 0x2e, 0xb7, 0x01, 0xb9, 0xdc, 0xc7, 0x60, 0x94, 0x67, - 0xcd, 0x1b, 0xd1, 0x01, 0xf8, 0x3a, 0xbe, 0x64, 0x01, 0x89, 0xba, 0x3c, 0xb4, 0x35, 0x67, 0x1d, - 0xcc, 0xe2, 0xfb, 0x24, 0xb5, 0xeb, 0xee, 0x5f, 0x35, 0x58, 0x57, 0x82, 0xdd, 0x94, 0x89, 0xe1, - 0x49, 0xf2, 0x3f, 0xa7, 0xeb, 0x67, 0x60, 0x10, 0x09, 0xe5, 0x4f, 0x5b, 0xc1, 0xbd, 0x85, 0xa3, - 0x13, 0x9e, 0x86, 0xf7, 0xbb, 0x45, 0x8a, 0xa3, 0x0f, 0x61, 0x5d, 0x95, 0x96, 0x38, 0xa1, 0x29, - 0xe1, 0xe1, 0xaa, 0xcd, 0xa1, 0x8d, 0x52, 0x27, 0x4a, 0xc8, 
0xfd, 0x8d, 0x56, 0xf6, 0x08, 0x3c, - 0x04, 0x43, 0x56, 0xba, 0x5e, 0xbb, 0x91, 0xeb, 0xeb, 0xab, 0xb8, 0xde, 0xd9, 0x9d, 0xb9, 0x62, - 0xd7, 0x99, 0x2a, 0x4b, 0xd9, 0x1f, 0xeb, 0x70, 0x67, 0xce, 0xe5, 0x47, 0x63, 0x12, 0xbd, 0xbb, - 0x76, 0xf6, 0xff, 0xf6, 0x7f, 0x51, 0xd5, 0x1b, 0x37, 0x9a, 0x02, 0x9a, 0x37, 0x9a, 0x02, 0x6c, - 0xe8, 0x74, 0xa3, 0xd7, 0x64, 0x92, 0x9d, 0xa7, 0x6a, 0x06, 0x72, 0xff, 0xde, 0x82, 0x06, 0x7a, - 0xef, 0x29, 0x98, 0x82, 0xa6, 0x23, 0x9f, 0xbe, 0x49, 0xd2, 0xc2, 0x77, 0x77, 0x17, 0xa0, 0x96, - 0xad, 0x44, 0x0e, 0xef, 0xa2, 0x6c, 0x2b, 0x5f, 0x00, 0xe4, 0x32, 0x2c, 0x4a, 0x58, 0x05, 0xff, - 0x3b, 0x6f, 0x2b, 0x3a, 0xf2, 0xd1, 0x90, 0x4f, 0xcb, 0xc2, 0x1e, 0x58, 0x03, 0x56, 0xc9, 0xeb, - 0x4b, 0x03, 0x57, 0xd5, 0x87, 0x5e, 0xcd, 0x83, 0x41, 0x55, 0x58, 0x0e, 0xa0, 0x1d, 0xa8, 0x96, - 0xad, 0x20, 0xd4, 0xe0, 0x70, 0x6f, 0x61, 0xec, 0xa7, 0x9d, 0xbd, 0x57, 0xf3, 0xac, 0x60, 0xa6, - 0xd1, 0x3f, 0x07, 0x5b, 0x59, 0x91, 0xca, 0x94, 0x52, 0x40, 0xca, 0xbd, 0xdf, 0x5d, 0x66, 0xcb, - 0x34, 0xf9, 0x7a, 0x35, 0xaf, 0x93, 0xcf, 0x4f, 0x57, 0xa7, 0x70, 0xab, 0xb0, 0x6a, 0x06, 0x6f, - 0x0d, 0xf1, 0xdc, 0xa5, 0xb6, 0xcd, 0x02, 0x6e, 0x0c, 0xae, 0xcc, 0x6b, 0x02, 0xb6, 0x0a, 0xc4, - 0x32, 0x4f, 0x7d, 0x3a, 0x26, 0xd1, 0x2c, 0x7e, 0x0b, 0xf1, 0x1f, 0x2c, 0xc5, 0x5f, 0x74, 0x71, - 0x7a, 0x35, 0xef, 0xce, 0x60, 0xf9, 0xb5, 0xaa, 0xec, 0x50, 0xa7, 0xe2, 0x39, 0xc6, 0x35, 0x76, - 0x4c, 0x0b, 0x48, 0x65, 0x47, 0x55, 0x53, 0xbe, 0x00, 0xc0, 0x74, 0x54, 0x50, 0xe6, 0xd2, 0x74, - 0x99, 0x4e, 0xea, 0x32, 0x5d, 0xc6, 0xd3, 0xb1, 0x7d, 0x6f, 0x7a, 0xcf, 0x51, 0x1e, 0xae, 0xb9, - 0xe7, 0x65, 0xba, 0x04, 0xd5, 0xcb, 0x63, 0x0f, 0x2c, 0x8a, 0xcf, 0x08, 0x85, 0x60, 0x2d, 0x45, - 0xa8, 0x1e, 0x1b, 0x12, 0x81, 0x56, 0x4f, 0x8f, 0xe7, 0x60, 0x13, 0xbc, 0x48, 0xbe, 0x48, 0x4b, - 0x43, 0xda, 0x4b, 0x73, 0x65, 0xfe, 0xce, 0xc9, 0x5c, 0x21, 0x73, 0x14, 0xe7, 0x0c, 0x9c, 0x5f, - 0x66, 0x31, 0xf7, 0x83, 0xa2, 0xa3, 0x2b, 0xc0, 0x75, 0x04, 0x7c, 0x7f, 0x85, 0xe1, 0xa3, 0x57, - 
0xf3, 0x6c, 0x09, 0x30, 0x4b, 0xdb, 0x5f, 0x83, 0x86, 0x84, 0x71, 0xff, 0xa1, 0x01, 0xbc, 0xa0, - 0x81, 0x88, 0xd3, 0xee, 0xf1, 0xf1, 0x59, 0xf1, 0xd8, 0x53, 0x31, 0x51, 0xff, 0x07, 0xe4, 0x63, - 0x4f, 0x85, 0x6d, 0xee, 0x19, 0x5a, 0x9f, 0x7f, 0x86, 0x3e, 0x01, 0x48, 0x52, 0x1a, 0xb2, 0x80, - 0x08, 0x9a, 0x5d, 0xd7, 0x5c, 0x67, 0x58, 0x9d, 0xcf, 0x01, 0x5e, 0xc9, 0x57, 0xb7, 0x2a, 0xcb, - 0x8d, 0xa5, 0xe1, 0x9e, 0x3e, 0xcd, 0x3d, 0xf3, 0xd5, 0xf4, 0x95, 0xfe, 0x21, 0x6c, 0x24, 0x11, - 0x09, 0xe8, 0x30, 0x8e, 0x42, 0x9a, 0xfa, 0x82, 0x5c, 0xe2, 0x9d, 0x34, 0xbd, 0xce, 0x0c, 0xf9, - 0x9c, 0x5c, 0xba, 0x13, 0x58, 0x47, 0x80, 0xd3, 0x88, 0xf0, 0xe3, 0x38, 0xa4, 0x57, 0xf4, 0xd5, - 0x56, 0xd7, 0xf7, 0x36, 0x18, 0x2c, 0xf3, 0x83, 0x38, 0xe7, 0xa2, 0x78, 0xcf, 0xb4, 0x58, 0x76, - 0x20, 0x3f, 0x9d, 0x6f, 0x41, 0x33, 0x62, 0x23, 0xa6, 0x66, 0x0b, 0xdd, 0x53, 0x1f, 0xee, 0xbf, - 0x34, 0x30, 0xa6, 0xc7, 0xee, 0x81, 0x35, 0x46, 0x67, 0xfb, 0x84, 0xf3, 0xec, 0x2d, 0x5d, 0xa8, - 0x0a, 0x89, 0xcc, 0x2d, 0x25, 0xd3, 0xe5, 0x3c, 0x73, 0x3e, 0x9b, 0x53, 0xfc, 0xed, 0xad, 0x54, - 0x8a, 0xce, 0xa8, 0xfe, 0x23, 0x68, 0xa2, 0xeb, 0x0a, 0x2f, 0x6f, 0x2f, 0xf3, 0x72, 0xa9, 0x6d, - 0xaf, 0xe6, 0x29, 0x01, 0xf9, 0x88, 0x8f, 0x73, 0x91, 0xe4, 0xc2, 0x2f, 0xe3, 0x2f, 0x63, 0xac, - 0xef, 0xe8, 0x5e, 0x47, 0xd1, 0xbf, 0x54, 0x69, 0x90, 0xc9, 0xb4, 0xe2, 0x71, 0x48, 0xef, 0xff, - 0x41, 0x83, 0x35, 0xd5, 0x91, 0xe6, 0xe7, 0xa6, 0x0d, 0xb0, 0x9e, 0xa5, 0x94, 0x08, 0x9a, 0x9e, - 0x0f, 0x09, 0xb7, 0x35, 0xc7, 0x86, 0x76, 0x41, 0x38, 0x7a, 0x95, 0x13, 0x39, 0xbb, 0xb6, 0xc1, - 0xf8, 0x9a, 0x66, 0x19, 0xee, 0xeb, 0x38, 0x58, 0xd1, 0x2c, 0x53, 0x9b, 0x0d, 0xc7, 0x84, 0xa6, - 0x5a, 0x36, 0x25, 0xdf, 0x71, 0x2c, 0xd4, 0xd7, 0x9a, 0x04, 0x3e, 0x4d, 0xe9, 0x05, 0x7b, 0xf3, - 0x9c, 0x88, 0x60, 0x68, 0xb7, 0x24, 0xf0, 0x69, 0x9c, 0x89, 0x29, 0xc5, 0x90, 0xb2, 0x6a, 0x69, - 0xca, 0x25, 0xd6, 0x30, 0x1b, 0x9c, 0x35, 0xa8, 0xf7, 0xb9, 0x6d, 0x49, 0xd2, 0x71, 0x2c, 0xfa, - 0xdc, 0x6e, 0xdf, 0x7f, 0x06, 0xd6, 
0x4c, 0x23, 0x97, 0x06, 0x7c, 0xc3, 0x5f, 0xf2, 0xf8, 0x35, - 0x57, 0xd3, 0x6b, 0x37, 0x94, 0x13, 0x5f, 0x0b, 0xf4, 0xb3, 0x7c, 0x60, 0xd7, 0xe5, 0xe2, 0x79, - 0x1e, 0xd9, 0xba, 0x5c, 0x1c, 0xb2, 0xb1, 0xdd, 0x40, 0x4a, 0x1c, 0xda, 0xcd, 0xfd, 0x8f, 0x7f, - 0xfe, 0xd1, 0x25, 0x13, 0xc3, 0x7c, 0xb0, 0x1b, 0xc4, 0xa3, 0x87, 0xca, 0xdd, 0x0f, 0x58, 0x5c, - 0xac, 0x1e, 0x32, 0x2e, 0x68, 0xca, 0x49, 0xf4, 0x10, 0x23, 0xf0, 0x50, 0x46, 0x20, 0x19, 0x0c, - 0xd6, 0xf0, 0xeb, 0xe3, 0x7f, 0x07, 0x00, 0x00, 0xff, 0xff, 0x48, 0x4f, 0x55, 0x7b, 0xe8, 0x13, + 0x15, 0xb7, 0x2c, 0xff, 0x91, 0x9e, 0x3c, 0x1e, 0x45, 0x45, 0x15, 0x93, 0x84, 0xcd, 0x0c, 0xda, + 0x2d, 0x76, 0x08, 0x64, 0x52, 0xc9, 0xee, 0x26, 0x6c, 0xb6, 0x16, 0xc6, 0xf3, 0x27, 0xb1, 0x21, + 0x99, 0x19, 0x34, 0x93, 0x14, 0xc5, 0x45, 0xd5, 0x96, 0x7a, 0xc6, 0x4d, 0xe4, 0x96, 0x22, 0xb5, + 0x9c, 0xf8, 0x0b, 0x70, 0xe0, 0xc6, 0x07, 0xe0, 0xc4, 0x81, 0x3b, 0xdc, 0xb8, 0xc0, 0x99, 0xe2, + 0xc0, 0x91, 0x23, 0xdf, 0x81, 0x2f, 0x40, 0xf5, 0x6b, 0xd9, 0xb2, 0x07, 0x3b, 0xe3, 0x81, 0x54, + 0xed, 0x4d, 0xfd, 0xfa, 0xbd, 0x5f, 0xbf, 0x7f, 0xfd, 0xde, 0x6b, 0x01, 0x24, 0x11, 0xe1, 0x3b, + 0x49, 0x1a, 0x8b, 0xd8, 0xb9, 0x31, 0x64, 0xd1, 0x28, 0xcf, 0xd4, 0x6a, 0x47, 0x6e, 0xdc, 0x6a, + 0x65, 0xc1, 0x80, 0x0e, 0x89, 0x22, 0xb9, 0x7f, 0xd3, 0xa0, 0xf5, 0x8c, 0x72, 0x9a, 0xb2, 0xe0, + 0x15, 0x89, 0x72, 0xea, 0xdc, 0x06, 0xa3, 0x1f, 0xc7, 0x91, 0x3f, 0x22, 0xd1, 0x86, 0xb6, 0xa5, + 0x6d, 0x1b, 0xdd, 0x8a, 0xd7, 0x94, 0x94, 0x57, 0x24, 0x72, 0x3e, 0x02, 0x93, 0x71, 0xf1, 0xe8, + 0x73, 0xdc, 0xad, 0x6e, 0x69, 0xdb, 0x7a, 0xb7, 0xe2, 0x19, 0x48, 0x2a, 0xb6, 0xcf, 0xa3, 0x98, + 0x08, 0xdc, 0xd6, 0xb7, 0xb4, 0x6d, 0x4d, 0x6e, 0x23, 0x49, 0x6e, 0x6f, 0x02, 0x64, 0x22, 0x65, + 0xfc, 0x02, 0xf7, 0x6b, 0x5b, 0xda, 0xb6, 0xd9, 0xad, 0x78, 0xa6, 0xa2, 0x49, 0x86, 0xc7, 0x60, + 0x92, 0x34, 0x25, 0x63, 0xdc, 0xaf, 0x6f, 0x69, 0xdb, 0xd6, 0xc3, 0x8d, 0x9d, 0xff, 0xb2, 0x60, + 0xa7, 0x23, 0x79, 0x24, 0x32, 0x32, 0xbf, 0x22, 0xd1, 0x5e, 0x1d, 0xf4, 
0x11, 0x89, 0xdc, 0xdf, + 0x69, 0x50, 0xc7, 0x4d, 0xe7, 0x0b, 0xa8, 0xe3, 0xe6, 0x86, 0xb6, 0xa5, 0x6f, 0x5b, 0x0f, 0x37, + 0x17, 0xa0, 0xcc, 0x5a, 0xed, 0x29, 0x6e, 0xe7, 0x36, 0x98, 0x19, 0x19, 0x52, 0x5f, 0x8c, 0x13, + 0x8a, 0xf6, 0x19, 0x9e, 0x21, 0x09, 0x67, 0xe3, 0x84, 0x3a, 0xbb, 0xd0, 0xa2, 0x11, 0x1d, 0x52, + 0x2e, 0xd4, 0xbe, 0x34, 0xb0, 0xfd, 0xf0, 0xa3, 0x79, 0xe8, 0xc2, 0xb9, 0x07, 0x44, 0x10, 0x29, + 0xe4, 0x59, 0x85, 0x88, 0x5c, 0xb8, 0xbf, 0xd1, 0xc0, 0xfc, 0x79, 0x4e, 0xd3, 0x71, 0x8f, 0x9f, + 0xc7, 0x8e, 0x03, 0x35, 0x11, 0x27, 0xaf, 0xd1, 0xcb, 0xba, 0x87, 0xdf, 0xce, 0x26, 0x58, 0x43, + 0x2a, 0x52, 0x16, 0x94, 0x47, 0x98, 0x1e, 0x28, 0x12, 0x2a, 0xf1, 0x31, 0xac, 0x65, 0x94, 0xa4, + 0xc1, 0xc0, 0x4f, 0x48, 0x4a, 0x86, 0x99, 0x72, 0xa3, 0xd7, 0x52, 0xc4, 0x13, 0xa4, 0x49, 0xa6, + 0x34, 0xce, 0x79, 0xe8, 0x87, 0x34, 0x60, 0xc3, 0xc2, 0x97, 0xba, 0xd7, 0x42, 0xe2, 0x81, 0xa2, + 0xb9, 0x7f, 0xaa, 0x02, 0xec, 0xc7, 0x51, 0x3e, 0xe4, 0xa8, 0xcd, 0x4d, 0x30, 0xce, 0x19, 0x8d, + 0x42, 0x9f, 0x85, 0x85, 0x46, 0x4d, 0x5c, 0xf7, 0x42, 0xe7, 0x09, 0x98, 0x21, 0x11, 0xa4, 0xf4, + 0xca, 0x95, 0x56, 0x1b, 0x61, 0xf1, 0xe5, 0x7c, 0x02, 0x6d, 0x96, 0xf9, 0x49, 0xca, 0x86, 0x24, + 0x1d, 0xfb, 0xaf, 0xe9, 0x18, 0x6d, 0x32, 0xbc, 0x16, 0xcb, 0x4e, 0x14, 0xf1, 0x67, 0x14, 0xfd, + 0xce, 0x32, 0x9f, 0xe4, 0x22, 0xee, 0x1d, 0xa0, 0x45, 0x86, 0x67, 0xb0, 0xac, 0x83, 0x6b, 0xe9, + 0x13, 0x4e, 0x33, 0x41, 0x43, 0x3f, 0x21, 0x62, 0xb0, 0x51, 0xdf, 0xd2, 0xa5, 0x4f, 0x14, 0xe9, + 0x84, 0x88, 0x81, 0xb3, 0x0d, 0xb6, 0x3c, 0x83, 0xa4, 0x82, 0x09, 0x16, 0x73, 0x3c, 0xa5, 0x81, + 0x20, 0x6d, 0x96, 0x9d, 0x4c, 0xc8, 0xf2, 0x9c, 0xcb, 0x21, 0x6c, 0x5e, 0x3b, 0x84, 0x3f, 0x99, + 0x38, 0xed, 0xf0, 0x5d, 0x92, 0x3a, 0x0f, 0xa0, 0xc6, 0xf8, 0x79, 0x8c, 0x0e, 0xb3, 0x2e, 0xe3, + 0x60, 0x96, 0x95, 0x1e, 0xf6, 0x90, 0x55, 0x02, 0x1c, 0xbe, 0x63, 0x99, 0xc8, 0xfe, 0x57, 0x80, + 0x3d, 0x30, 0x31, 0x67, 0x51, 0xfe, 0x0b, 0xa8, 0x8f, 0xe4, 0xa2, 0x00, 0xb8, 0x3a, 0xcf, 0x91, + 0xdb, 0xfd, 
0xa3, 0x06, 0xed, 0x97, 0x9c, 0xa4, 0x63, 0x8f, 0xf0, 0x0b, 0x85, 0xf4, 0x63, 0xb0, + 0x02, 0x3c, 0xca, 0x5f, 0x5d, 0x21, 0x08, 0xca, 0xfc, 0xf9, 0x3e, 0x54, 0xe3, 0xa4, 0xc8, 0x8e, + 0x9b, 0x0b, 0xc4, 0x8e, 0x13, 0x74, 0x66, 0x35, 0x4e, 0x4a, 0xa5, 0xf5, 0x6b, 0x29, 0xfd, 0x87, + 0x2a, 0xac, 0xef, 0xb1, 0x0f, 0xab, 0xf5, 0xa7, 0xb0, 0x1e, 0xc5, 0x6f, 0x69, 0xea, 0x33, 0x1e, + 0x44, 0x79, 0xc6, 0x46, 0x93, 0x6b, 0xdf, 0x46, 0x72, 0x6f, 0x42, 0x95, 0x8c, 0x79, 0x92, 0xcc, + 0x31, 0xaa, 0x44, 0x6e, 0x23, 0xb9, 0x64, 0xdc, 0x05, 0x4b, 0x21, 0x2a, 0x13, 0x6b, 0xab, 0x99, + 0x08, 0x28, 0xa3, 0x2a, 0xf0, 0x2e, 0x58, 0xea, 0x28, 0x85, 0x50, 0x5f, 0x11, 0x01, 0x65, 0xf0, + 0xdb, 0xfd, 0xbb, 0x06, 0xd6, 0x7e, 0x3c, 0x4c, 0x48, 0xaa, 0xbc, 0xf4, 0x0c, 0xec, 0x88, 0x9e, + 0x0b, 0xff, 0xda, 0xae, 0x6a, 0x4b, 0xb1, 0x99, 0x22, 0xd1, 0x83, 0x1b, 0x29, 0xbb, 0x18, 0xcc, + 0x23, 0x55, 0x57, 0x41, 0x5a, 0x47, 0xb9, 0xfd, 0xcb, 0xf9, 0xa2, 0xaf, 0x90, 0x2f, 0xee, 0xef, + 0x35, 0x30, 0xce, 0x68, 0x3a, 0xfc, 0x20, 0x11, 0x7f, 0x0c, 0x0d, 0xf4, 0x6b, 0xb6, 0x51, 0x5d, + 0xad, 0x35, 0x14, 0xec, 0xce, 0x1d, 0xb0, 0x58, 0xe6, 0x33, 0xee, 0x63, 0x59, 0x2c, 0xa2, 0x6f, + 0xb2, 0xac, 0xc7, 0x9f, 0x4a, 0x82, 0xfb, 0xd7, 0x2a, 0xd8, 0x3f, 0x3d, 0x3d, 0x3e, 0xda, 0x8f, + 0xb9, 0x20, 0x8c, 0x67, 0x1f, 0x44, 0xdb, 0xaf, 0xc0, 0x28, 0xaa, 0xcf, 0xca, 0xfa, 0x4e, 0x05, + 0x9c, 0x27, 0x33, 0x2e, 0xbe, 0xbb, 0x40, 0xec, 0xb2, 0xb6, 0x48, 0x38, 0x4e, 0xf0, 0x8e, 0xfe, + 0x10, 0x9c, 0x09, 0x8e, 0x5f, 0xb6, 0x44, 0x55, 0x9a, 0xed, 0xc9, 0xce, 0x69, 0xd1, 0x1a, 0xdd, + 0x43, 0x68, 0x28, 0x59, 0xc7, 0x82, 0x66, 0x8f, 0x8f, 0x48, 0xc4, 0x42, 0xbb, 0xe2, 0xb4, 0xc0, + 0x98, 0xe0, 0xdb, 0x9a, 0xb3, 0x2e, 0x93, 0x52, 0xad, 0x3a, 0x51, 0x64, 0x57, 0xe7, 0x08, 0x7c, + 0x6c, 0xeb, 0xee, 0x6f, 0x35, 0x30, 0xb1, 0x2c, 0xa1, 0xef, 0x3e, 0x47, 0xf5, 0x35, 0x54, 0xff, + 0x93, 0x05, 0xea, 0x4f, 0x39, 0xd5, 0x57, 0xa1, 0xf8, 0x3d, 0xa8, 0x07, 0x03, 0x16, 0x85, 0x45, + 0x5a, 0x7e, 0x7b, 0x81, 0xa0, 0x94, 0xf1, 0x14, 
0x97, 0xbb, 0x09, 0xcd, 0x42, 0x7a, 0x5e, 0xf5, + 0x26, 0xe8, 0x47, 0xb1, 0xb0, 0x35, 0xf7, 0x9f, 0x1a, 0x80, 0xaa, 0x3a, 0xa8, 0xd4, 0xa3, 0x19, + 0xa5, 0xbe, 0xb7, 0x00, 0xbb, 0x64, 0x2d, 0x3e, 0x0b, 0xb5, 0x7e, 0x00, 0x35, 0x79, 0x97, 0xae, + 0xd2, 0x0a, 0x99, 0xa4, 0x0d, 0x78, 0x5d, 0x8a, 0x02, 0xb9, 0xdc, 0x06, 0xe4, 0x72, 0x1f, 0x81, + 0x31, 0x39, 0x6b, 0xde, 0x88, 0x36, 0xc0, 0xf3, 0xf8, 0x82, 0x05, 0x24, 0xea, 0xf0, 0xd0, 0xd6, + 0x9c, 0x35, 0x30, 0x8b, 0xf5, 0x71, 0x6a, 0x57, 0xdd, 0x7f, 0x68, 0xb0, 0xa6, 0x04, 0x3b, 0x29, + 0x13, 0x83, 0xe3, 0xe4, 0xff, 0x4e, 0xd7, 0x2f, 0xc1, 0x20, 0x12, 0xca, 0x9f, 0xb6, 0x82, 0x3b, + 0x0b, 0xe7, 0x37, 0x3c, 0x0d, 0xef, 0x77, 0x93, 0x14, 0x47, 0x1f, 0xc0, 0x9a, 0x2a, 0x2d, 0x71, + 0x42, 0x53, 0xc2, 0xc3, 0x55, 0x9b, 0x43, 0x0b, 0xa5, 0x8e, 0x95, 0x90, 0x9c, 0x00, 0xd7, 0x67, + 0x4c, 0xc2, 0x90, 0x4d, 0x5c, 0xaf, 0x5d, 0xcb, 0xf5, 0xd5, 0x55, 0x5c, 0xef, 0xec, 0xcc, 0x5c, + 0xb1, 0xab, 0x4c, 0x95, 0xa5, 0xec, 0x2f, 0x55, 0xb8, 0x35, 0xe7, 0xf2, 0xc3, 0x11, 0x89, 0x3e, + 0x5c, 0x3b, 0xfb, 0xa6, 0xfd, 0x5f, 0x54, 0xf5, 0xda, 0xb5, 0xa6, 0x80, 0xfa, 0xb5, 0xa6, 0x00, + 0x1b, 0xda, 0x9d, 0xe8, 0x2d, 0x19, 0x67, 0x67, 0xa9, 0x9a, 0x81, 0xdc, 0x7f, 0x35, 0xa1, 0x86, + 0xde, 0x7b, 0x02, 0xa6, 0xa0, 0xe9, 0xd0, 0xa7, 0xef, 0x92, 0xb4, 0xf0, 0xdd, 0xed, 0x05, 0xa8, + 0x93, 0x56, 0x22, 0x5f, 0x10, 0x62, 0xd2, 0x56, 0xbe, 0x06, 0xc8, 0x65, 0x58, 0x94, 0xb0, 0x0a, + 0xfe, 0x77, 0xde, 0x57, 0x74, 0xe4, 0xcb, 0x25, 0x9f, 0x96, 0x85, 0x5d, 0xb0, 0xfa, 0xac, 0x94, + 0xd7, 0x97, 0x06, 0xae, 0xac, 0x0f, 0xdd, 0x8a, 0x07, 0xfd, 0xb2, 0xb0, 0xec, 0x43, 0x2b, 0x50, + 0x2d, 0x5b, 0x41, 0xa8, 0xc1, 0xe1, 0xce, 0xc2, 0xd8, 0x4f, 0x3b, 0x7b, 0xb7, 0xe2, 0x59, 0xc1, + 0x4c, 0xa3, 0x7f, 0x01, 0xb6, 0xb2, 0x22, 0x95, 0x29, 0xa5, 0x80, 0x94, 0x7b, 0xbf, 0xbb, 0xcc, + 0x96, 0x69, 0xf2, 0x75, 0x2b, 0x5e, 0x3b, 0x9f, 0x9f, 0xae, 0x4e, 0xe0, 0x46, 0x61, 0xd5, 0x0c, + 0x5e, 0x03, 0xf1, 0xdc, 0xa5, 0xb6, 0xcd, 0x02, 0xae, 0xf7, 0x2f, 0xcd, 0x6b, 0x02, 
0x36, 0x0b, + 0xc4, 0x49, 0x9e, 0xfa, 0x74, 0x44, 0xa2, 0x59, 0xfc, 0x26, 0xe2, 0xdf, 0x5b, 0x8a, 0xbf, 0xe8, + 0xe2, 0x74, 0x2b, 0xde, 0xad, 0xfe, 0xf2, 0x6b, 0x55, 0xda, 0xa1, 0x4e, 0xc5, 0x73, 0x8c, 0x2b, + 0xec, 0x98, 0x16, 0x90, 0xd2, 0x8e, 0xb2, 0xa6, 0x7c, 0x0d, 0x80, 0xe9, 0xa8, 0xa0, 0xcc, 0xa5, + 0xe9, 0x32, 0x9d, 0xd4, 0x65, 0xba, 0x8c, 0xa6, 0x63, 0xfb, 0xee, 0xf4, 0x9e, 0xa3, 0x3c, 0x5c, + 0x71, 0xcf, 0x27, 0xe9, 0x12, 0x94, 0x2f, 0x8f, 0x5d, 0xb0, 0x28, 0x3e, 0x23, 0x14, 0x82, 0xb5, + 0x14, 0xa1, 0x7c, 0x6c, 0x48, 0x04, 0x5a, 0x3e, 0x3d, 0x5e, 0x80, 0x4d, 0xf0, 0x22, 0xf9, 0x22, + 0x9d, 0x18, 0xd2, 0x5a, 0x9a, 0x2b, 0xf3, 0x77, 0x4e, 0xe6, 0x0a, 0x99, 0xa3, 0x38, 0xa7, 0xe0, + 0xfc, 0x2a, 0x8b, 0xb9, 0x1f, 0x14, 0x1d, 0x5d, 0x01, 0xae, 0x21, 0xe0, 0xc7, 0x2b, 0x0c, 0x1f, + 0xdd, 0x8a, 0x67, 0x4b, 0x80, 0x59, 0xda, 0x5e, 0x03, 0x6a, 0x12, 0xc6, 0xfd, 0x75, 0x15, 0xe0, + 0x15, 0x0d, 0x44, 0x9c, 0x76, 0x8e, 0x8e, 0x4e, 0x65, 0x99, 0x1c, 0xe1, 0x4a, 0x4d, 0x25, 0xda, + 0xa2, 0x57, 0x9c, 0x72, 0x3f, 0x72, 0x61, 0xc9, 0x81, 0xd1, 0xf4, 0x7b, 0xee, 0xad, 0x5b, 0x9d, + 0x7f, 0xeb, 0x3e, 0x06, 0x48, 0x52, 0x1a, 0xb2, 0x80, 0x08, 0x9a, 0x5d, 0xd5, 0x7f, 0x67, 0x58, + 0x9d, 0xaf, 0x00, 0xde, 0xc8, 0xa7, 0xbd, 0xaa, 0xdc, 0xb5, 0xa5, 0x19, 0x31, 0x7d, 0xff, 0x7b, + 0xe6, 0x9b, 0xe9, 0xaf, 0x80, 0x4f, 0x61, 0x3d, 0x89, 0x48, 0x40, 0x07, 0x71, 0x14, 0xd2, 0xd4, + 0x17, 0xe4, 0x02, 0xaf, 0xad, 0xe9, 0xb5, 0x67, 0xc8, 0x67, 0xe4, 0xc2, 0x1d, 0xc3, 0x1a, 0x02, + 0x9c, 0x44, 0x84, 0x1f, 0xc5, 0x21, 0xbd, 0xa4, 0xaf, 0xb6, 0xba, 0xbe, 0x37, 0xc1, 0x60, 0x99, + 0x1f, 0xc4, 0x39, 0x17, 0xc5, 0x93, 0xa7, 0xc9, 0xb2, 0x7d, 0xb9, 0x74, 0xbe, 0x05, 0xf5, 0x88, + 0x0d, 0x99, 0x1a, 0x3f, 0x74, 0x4f, 0x2d, 0xdc, 0x7f, 0x6b, 0x60, 0x4c, 0x8f, 0xdd, 0x9d, 0x46, + 0x80, 0x70, 0x9e, 0xbd, 0xa7, 0x51, 0x95, 0x51, 0x93, 0xe9, 0xa7, 0x64, 0x3a, 0x9c, 0x67, 0xce, + 0x97, 0x73, 0x8a, 0xbf, 0xbf, 0xdb, 0x4a, 0xd1, 0x19, 0xd5, 0x7f, 0x04, 0x75, 0x74, 0x5d, 0xe1, + 0xe5, 0xad, 0x65, 0x5e, 
0x9e, 0x68, 0xdb, 0xad, 0x78, 0x4a, 0xc0, 0xd9, 0x06, 0x3b, 0xce, 0x45, + 0x92, 0x0b, 0x7f, 0x12, 0x7f, 0x19, 0x63, 0x7d, 0x5b, 0xf7, 0xda, 0x8a, 0xfe, 0x54, 0xa5, 0x41, + 0x26, 0x33, 0x8f, 0xc7, 0x21, 0xbd, 0xfb, 0x67, 0x0d, 0x1a, 0xaa, 0x69, 0xcd, 0x8f, 0x56, 0xeb, + 0x60, 0x3d, 0x4b, 0x29, 0x11, 0x34, 0x3d, 0x1b, 0x10, 0x6e, 0x6b, 0x8e, 0x0d, 0xad, 0x82, 0x70, + 0xf8, 0x26, 0x27, 0x72, 0xbc, 0x6d, 0x81, 0xf1, 0x9c, 0x66, 0x19, 0xee, 0xeb, 0x38, 0x7b, 0xd1, + 0x2c, 0x53, 0x9b, 0x35, 0xc7, 0x84, 0xba, 0xfa, 0xac, 0x4b, 0xbe, 0xa3, 0x58, 0xa8, 0x55, 0x43, + 0x02, 0x9f, 0xa4, 0xf4, 0x9c, 0xbd, 0x7b, 0x41, 0x44, 0x30, 0xb0, 0x9b, 0x12, 0xf8, 0x24, 0xce, + 0xc4, 0x94, 0x62, 0x48, 0x59, 0xf5, 0x69, 0xca, 0x4f, 0x2c, 0x73, 0x36, 0x38, 0x0d, 0xa8, 0xf6, + 0xb8, 0x6d, 0x49, 0xd2, 0x51, 0x2c, 0x7a, 0xdc, 0x6e, 0xdd, 0xfd, 0x05, 0x58, 0x33, 0xbd, 0x5e, + 0x1a, 0xf0, 0x92, 0xbf, 0xe6, 0xf1, 0x5b, 0xae, 0x06, 0xdc, 0x4e, 0x28, 0x87, 0xc2, 0x26, 0xe8, + 0xa7, 0x79, 0xdf, 0xae, 0xca, 0x8f, 0x17, 0x79, 0x64, 0xeb, 0xf2, 0xe3, 0x80, 0x8d, 0xec, 0x1a, + 0x52, 0xe2, 0xd0, 0xae, 0x4b, 0xa5, 0xf0, 0xbf, 0xda, 0x73, 0xca, 0x2f, 0xc4, 0xc0, 0x6e, 0xdc, + 0xdd, 0x9b, 0xdc, 0x47, 0x04, 0xb6, 0xa1, 0xa5, 0x6a, 0xa6, 0xa2, 0x29, 0xf7, 0x3c, 0xc5, 0xdf, + 0x7e, 0x8a, 0xa0, 0x39, 0x37, 0x60, 0x0d, 0x09, 0x0f, 0x1e, 0x15, 0xa4, 0xea, 0xde, 0x67, 0xbf, + 0x7c, 0x70, 0xc1, 0xc4, 0x20, 0xef, 0xef, 0x04, 0xf1, 0xf0, 0xbe, 0x8a, 0xe1, 0x3d, 0x16, 0x17, + 0x5f, 0xf7, 0x19, 0x17, 0x34, 0xe5, 0x24, 0xba, 0x8f, 0x61, 0xbd, 0x2f, 0xc3, 0x9a, 0xf4, 0xfb, + 0x0d, 0x5c, 0x7d, 0xf6, 0x9f, 0x00, 0x00, 0x00, 0xff, 0xff, 0x4f, 0xa9, 0x51, 0x74, 0xe5, 0x14, 0x00, 0x00, } diff --git a/internal/proto/query_coord.proto b/internal/proto/query_coord.proto index 0973aec4f1453..6f1bc40b6dd9c 100644 --- a/internal/proto/query_coord.proto +++ b/internal/proto/query_coord.proto @@ -67,7 +67,9 @@ service QueryNode { rpc Search(SearchRequest) returns (internal.SearchResults) {} rpc SearchSegments(SearchRequest) returns 
(internal.SearchResults) {} rpc Query(QueryRequest) returns (internal.RetrieveResults) {} + rpc QueryStream(QueryRequest) returns (stream internal.RetrieveResults){} rpc QuerySegments(QueryRequest) returns (internal.RetrieveResults) {} + rpc QueryStreamSegments(QueryRequest) returns (stream internal.RetrieveResults){} rpc ShowConfigurations(internal.ShowConfigurationsRequest) returns (internal.ShowConfigurationsResponse){} // https://wiki.lfaidata.foundation/display/MIL/MEP+8+--+Add+metrics+for+proxy @@ -78,7 +80,8 @@ service QueryNode { rpc Delete(DeleteRequest) returns (common.Status) {} } -//--------------------QueryCoord grpc request and response proto------------------ +// --------------------QueryCoord grpc request and response proto------------------ + message ShowCollectionsRequest { common.MsgBase base = 1; // Not useful for now @@ -204,7 +207,8 @@ message SyncNewCreatedPartitionRequest { int64 partitionID = 3; } -//-----------------query node grpc request and response proto---------------- +// -----------------query node grpc request and response proto---------------- + message LoadMetaInfo { LoadType load_type = 1; int64 collectionID = 2; @@ -269,6 +273,7 @@ message FieldIndexInfo { int64 index_size = 8; int64 index_version = 9; int64 num_rows = 10; + int32 current_index_version = 11; } enum LoadScope { @@ -348,7 +353,8 @@ message GetLoadInfoResponse { repeated int64 partitions = 4; } -//----------------request auto triggered by QueryCoord----------------- +// ----------------request auto triggered by QueryCoord----------------- + message HandoffSegmentsRequest { common.MsgBase base = 1; repeated SegmentInfo segmentInfos = 2; @@ -364,7 +370,7 @@ message LoadBalanceRequest { int64 collectionID = 6; } -//-------------------- internal meta proto------------------ +// -------------------- internal meta proto------------------ enum DataScope { UnKnown = 0; @@ -462,7 +468,8 @@ message UnsubscribeChannelInfo { repeated UnsubscribeChannels collection_channels = 
2; } -//---- synchronize messages proto between QueryCoord and QueryNode ----- +// ---- synchronize messages proto between QueryCoord and QueryNode ----- + message SegmentChangeInfo { int64 online_nodeID = 1; repeated SegmentInfo online_segments = 2; @@ -532,6 +539,7 @@ message CollectionLoadInfo { LoadStatus status = 4; map field_indexID = 5; LoadType load_type = 6; + int32 recover_times = 7; } message PartitionLoadInfo { @@ -540,6 +548,7 @@ message PartitionLoadInfo { int32 replica_number = 3; // Deprecated: No longer used; kept for compatibility. LoadStatus status = 4; map field_indexID = 5; // Deprecated: No longer used; kept for compatibility. + int32 recover_times = 7; } message Replica { diff --git a/internal/proto/querypb/query_coord.pb.go b/internal/proto/querypb/query_coord.pb.go index 7d796b86e3252..313c55f6dc665 100644 --- a/internal/proto/querypb/query_coord.pb.go +++ b/internal/proto/querypb/query_coord.pb.go @@ -251,7 +251,6 @@ func (SyncType) EnumDescriptor() ([]byte, []int) { return fileDescriptor_aab7cc9a69ed26e8, []int{6} } -// --------------------QueryCoord grpc request and response proto------------------ type ShowCollectionsRequest struct { Base *commonpb.MsgBase `protobuf:"bytes,1,opt,name=base,proto3" json:"base,omitempty"` // Not useful for now @@ -1336,7 +1335,6 @@ func (m *SyncNewCreatedPartitionRequest) GetPartitionID() int64 { return 0 } -// -----------------query node grpc request and response proto---------------- type LoadMetaInfo struct { LoadType LoadType `protobuf:"varint,1,opt,name=load_type,json=loadType,proto3,enum=milvus.proto.query.LoadType" json:"load_type,omitempty"` CollectionID int64 `protobuf:"varint,2,opt,name=collectionID,proto3" json:"collectionID,omitempty"` @@ -1779,6 +1777,7 @@ type FieldIndexInfo struct { IndexSize int64 `protobuf:"varint,8,opt,name=index_size,json=indexSize,proto3" json:"index_size,omitempty"` IndexVersion int64 `protobuf:"varint,9,opt,name=index_version,json=indexVersion,proto3" 
json:"index_version,omitempty"` NumRows int64 `protobuf:"varint,10,opt,name=num_rows,json=numRows,proto3" json:"num_rows,omitempty"` + CurrentIndexVersion int32 `protobuf:"varint,11,opt,name=current_index_version,json=currentIndexVersion,proto3" json:"current_index_version,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` @@ -1879,6 +1878,13 @@ func (m *FieldIndexInfo) GetNumRows() int64 { return 0 } +func (m *FieldIndexInfo) GetCurrentIndexVersion() int32 { + if m != nil { + return m.CurrentIndexVersion + } + return 0 +} + type LoadSegmentsRequest struct { Base *commonpb.MsgBase `protobuf:"bytes,1,opt,name=base,proto3" json:"base,omitempty"` DstNodeID int64 `protobuf:"varint,2,opt,name=dst_nodeID,json=dstNodeID,proto3" json:"dst_nodeID,omitempty"` @@ -2496,7 +2502,6 @@ func (m *GetLoadInfoResponse) GetPartitions() []int64 { return nil } -// ----------------request auto triggered by QueryCoord----------------- type HandoffSegmentsRequest struct { Base *commonpb.MsgBase `protobuf:"bytes,1,opt,name=base,proto3" json:"base,omitempty"` SegmentInfos []*SegmentInfo `protobuf:"bytes,2,rep,name=segmentInfos,proto3" json:"segmentInfos,omitempty"` @@ -3193,7 +3198,6 @@ func (m *UnsubscribeChannelInfo) GetCollectionChannels() []*UnsubscribeChannels return nil } -// ---- synchronize messages proto between QueryCoord and QueryNode ----- type SegmentChangeInfo struct { OnlineNodeID int64 `protobuf:"varint,1,opt,name=online_nodeID,json=onlineNodeID,proto3" json:"online_nodeID,omitempty"` OnlineSegments []*SegmentInfo `protobuf:"bytes,2,rep,name=online_segments,json=onlineSegments,proto3" json:"online_segments,omitempty"` @@ -3697,6 +3701,7 @@ type CollectionLoadInfo struct { Status LoadStatus `protobuf:"varint,4,opt,name=status,proto3,enum=milvus.proto.query.LoadStatus" json:"status,omitempty"` FieldIndexID map[int64]int64 `protobuf:"bytes,5,rep,name=field_indexID,json=fieldIndexID,proto3" 
json:"field_indexID,omitempty" protobuf_key:"varint,1,opt,name=key,proto3" protobuf_val:"varint,2,opt,name=value,proto3"` LoadType LoadType `protobuf:"varint,6,opt,name=load_type,json=loadType,proto3,enum=milvus.proto.query.LoadType" json:"load_type,omitempty"` + RecoverTimes int32 `protobuf:"varint,7,opt,name=recover_times,json=recoverTimes,proto3" json:"recover_times,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` @@ -3769,12 +3774,20 @@ func (m *CollectionLoadInfo) GetLoadType() LoadType { return LoadType_UnKnownType } +func (m *CollectionLoadInfo) GetRecoverTimes() int32 { + if m != nil { + return m.RecoverTimes + } + return 0 +} + type PartitionLoadInfo struct { CollectionID int64 `protobuf:"varint,1,opt,name=collectionID,proto3" json:"collectionID,omitempty"` PartitionID int64 `protobuf:"varint,2,opt,name=partitionID,proto3" json:"partitionID,omitempty"` ReplicaNumber int32 `protobuf:"varint,3,opt,name=replica_number,json=replicaNumber,proto3" json:"replica_number,omitempty"` Status LoadStatus `protobuf:"varint,4,opt,name=status,proto3,enum=milvus.proto.query.LoadStatus" json:"status,omitempty"` FieldIndexID map[int64]int64 `protobuf:"bytes,5,rep,name=field_indexID,json=fieldIndexID,proto3" json:"field_indexID,omitempty" protobuf_key:"varint,1,opt,name=key,proto3" protobuf_val:"varint,2,opt,name=value,proto3"` + RecoverTimes int32 `protobuf:"varint,7,opt,name=recover_times,json=recoverTimes,proto3" json:"recover_times,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` @@ -3840,6 +3853,13 @@ func (m *PartitionLoadInfo) GetFieldIndexID() map[int64]int64 { return nil } +func (m *PartitionLoadInfo) GetRecoverTimes() int32 { + if m != nil { + return m.RecoverTimes + } + return 0 +} + type Replica struct { ID int64 `protobuf:"varint,1,opt,name=ID,proto3" json:"ID,omitempty"` CollectionID int64 
`protobuf:"varint,2,opt,name=collectionID,proto3" json:"collectionID,omitempty"` @@ -4582,306 +4602,310 @@ func init() { func init() { proto.RegisterFile("query_coord.proto", fileDescriptor_aab7cc9a69ed26e8) } var fileDescriptor_aab7cc9a69ed26e8 = []byte{ - // 4784 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xec, 0x3c, 0x4b, 0x6f, 0x1c, 0x47, - 0x7a, 0x9a, 0x27, 0x67, 0xbe, 0x79, 0x35, 0x8b, 0xa4, 0x34, 0x3b, 0x2b, 0xc9, 0x74, 0xcb, 0x0f, - 0x2e, 0x6d, 0x93, 0x5a, 0x6a, 0xed, 0xd5, 0xae, 0x6d, 0x38, 0x12, 0x69, 0xc9, 0x5c, 0xcb, 0x34, - 0xb7, 0x49, 0x69, 0x03, 0xc7, 0xbb, 0xe3, 0xe6, 0x74, 0x71, 0xd8, 0x50, 0x3f, 0x46, 0xdd, 0x3d, - 0xa4, 0xe8, 0x00, 0x41, 0x0e, 0xb9, 0x64, 0x93, 0x4d, 0x82, 0xe4, 0x90, 0x1c, 0x16, 0x39, 0x24, - 0x08, 0xb0, 0x09, 0x92, 0x4b, 0x90, 0x00, 0x39, 0xe4, 0x90, 0x5b, 0x4e, 0x79, 0xdc, 0xf2, 0x07, - 0x92, 0x5b, 0x80, 0x9c, 0x16, 0x81, 0x6f, 0x41, 0x3d, 0xfa, 0x51, 0xdd, 0x35, 0x9c, 0x26, 0x47, - 0x5e, 0xdb, 0xc1, 0xde, 0xa6, 0xbf, 0x7a, 0x7c, 0x5f, 0x7d, 0xaf, 0xfa, 0xbe, 0xaf, 0xaa, 0x06, - 0xe6, 0x9f, 0x8c, 0xb1, 0x77, 0xda, 0x1f, 0xb8, 0xae, 0x67, 0xac, 0x8d, 0x3c, 0x37, 0x70, 0x11, - 0xb2, 0x4d, 0xeb, 0x78, 0xec, 0xb3, 0xaf, 0x35, 0xda, 0xde, 0x6b, 0x0e, 0x5c, 0xdb, 0x76, 0x1d, - 0x06, 0xeb, 0x35, 0x93, 0x3d, 0x7a, 0x6d, 0xd3, 0x09, 0xb0, 0xe7, 0xe8, 0x56, 0xd8, 0xea, 0x0f, - 0x8e, 0xb0, 0xad, 0xf3, 0xaf, 0xba, 0xed, 0x0f, 0xf9, 0x4f, 0xc5, 0xd0, 0x03, 0x3d, 0x89, 0xaa, - 0x37, 0x6f, 0x3a, 0x06, 0x7e, 0x9a, 0x04, 0xa9, 0xbf, 0x55, 0x80, 0xcb, 0x7b, 0x47, 0xee, 0xc9, - 0xa6, 0x6b, 0x59, 0x78, 0x10, 0x98, 0xae, 0xe3, 0x6b, 0xf8, 0xc9, 0x18, 0xfb, 0x01, 0xba, 0x09, - 0xe5, 0x03, 0xdd, 0xc7, 0xdd, 0xc2, 0x72, 0x61, 0xa5, 0xb1, 0x71, 0x75, 0x4d, 0xa0, 0x93, 0x13, - 0xf8, 0x81, 0x3f, 0xbc, 0xab, 0xfb, 0x58, 0xa3, 0x3d, 0x11, 0x82, 0xb2, 0x71, 0xb0, 0xbd, 0xd5, - 0x2d, 0x2e, 0x17, 0x56, 0x4a, 0x1a, 0xfd, 0x8d, 0x5e, 0x80, 0xd6, 0x20, 0x9a, 0x7b, 0x7b, 0xcb, - 0xef, 0x96, 0x96, 0x4b, 
0x2b, 0x25, 0x4d, 0x04, 0xaa, 0x3f, 0x2e, 0xc2, 0x95, 0x0c, 0x19, 0xfe, - 0xc8, 0x75, 0x7c, 0x8c, 0x6e, 0x41, 0xd5, 0x0f, 0xf4, 0x60, 0xec, 0x73, 0x4a, 0xbe, 0x2e, 0xa5, - 0x64, 0x8f, 0x76, 0xd1, 0x78, 0xd7, 0x2c, 0xda, 0xa2, 0x04, 0x2d, 0xfa, 0x26, 0x2c, 0x9a, 0xce, - 0x07, 0xd8, 0x76, 0xbd, 0xd3, 0xfe, 0x08, 0x7b, 0x03, 0xec, 0x04, 0xfa, 0x10, 0x87, 0x34, 0x2e, - 0x84, 0x6d, 0xbb, 0x71, 0x13, 0x7a, 0x03, 0xae, 0x30, 0x19, 0xfa, 0xd8, 0x3b, 0x36, 0x07, 0xb8, - 0xaf, 0x1f, 0xeb, 0xa6, 0xa5, 0x1f, 0x58, 0xb8, 0x5b, 0x5e, 0x2e, 0xad, 0xd4, 0xb4, 0x25, 0xda, - 0xbc, 0xc7, 0x5a, 0xef, 0x84, 0x8d, 0xe8, 0x1b, 0xa0, 0x78, 0xf8, 0xd0, 0xc3, 0xfe, 0x51, 0x7f, - 0xe4, 0xb9, 0x43, 0x0f, 0xfb, 0x7e, 0xb7, 0x42, 0xd1, 0x74, 0x38, 0x7c, 0x97, 0x83, 0xd5, 0xbf, - 0x28, 0xc0, 0x12, 0x61, 0xc6, 0xae, 0xee, 0x05, 0xe6, 0xe7, 0x20, 0x12, 0x15, 0x9a, 0x49, 0x36, - 0x74, 0x4b, 0xb4, 0x4d, 0x80, 0x91, 0x3e, 0xa3, 0x10, 0x3d, 0x61, 0x5f, 0x99, 0x92, 0x2a, 0xc0, - 0xd4, 0x7f, 0xe3, 0xba, 0x93, 0xa4, 0x73, 0x16, 0x99, 0xa5, 0x71, 0x16, 0xb3, 0x38, 0x2f, 0x22, - 0x31, 0x19, 0xe7, 0xcb, 0x72, 0xce, 0xff, 0x4b, 0x09, 0x96, 0x1e, 0xb8, 0xba, 0x11, 0xab, 0xe1, - 0x2f, 0x9e, 0xf3, 0x6f, 0x43, 0x95, 0x59, 0x74, 0xb7, 0x4c, 0x71, 0xbd, 0x28, 0xe2, 0xe2, 0xd6, - 0x1e, 0x53, 0xb8, 0x47, 0x01, 0x1a, 0x1f, 0x84, 0x5e, 0x84, 0xb6, 0x87, 0x47, 0x96, 0x39, 0xd0, - 0xfb, 0xce, 0xd8, 0x3e, 0xc0, 0x5e, 0xb7, 0xb2, 0x5c, 0x58, 0xa9, 0x68, 0x2d, 0x0e, 0xdd, 0xa1, - 0x40, 0xf4, 0x09, 0xb4, 0x0e, 0x4d, 0x6c, 0x19, 0x7d, 0xea, 0x12, 0xb6, 0xb7, 0xba, 0xd5, 0xe5, - 0xd2, 0x4a, 0x63, 0xe3, 0xcd, 0xb5, 0xac, 0x37, 0x5a, 0x93, 0x72, 0x64, 0xed, 0x1e, 0x19, 0xbe, - 0xcd, 0x46, 0xbf, 0xeb, 0x04, 0xde, 0xa9, 0xd6, 0x3c, 0x4c, 0x80, 0x50, 0x17, 0xe6, 0x38, 0x7b, - 0xbb, 0x73, 0xcb, 0x85, 0x95, 0x9a, 0x16, 0x7e, 0xa2, 0x97, 0xa1, 0xe3, 0x61, 0xdf, 0x1d, 0x7b, - 0x03, 0xdc, 0x1f, 0x7a, 0xee, 0x78, 0xe4, 0x77, 0x6b, 0xcb, 0xa5, 0x95, 0xba, 0xd6, 0x0e, 0xc1, - 0xf7, 0x29, 0xb4, 0xf7, 0x0e, 0xcc, 0x67, 0xb0, 0x20, 0x05, 
0x4a, 0x8f, 0xf1, 0x29, 0x15, 0x44, - 0x49, 0x23, 0x3f, 0xd1, 0x22, 0x54, 0x8e, 0x75, 0x6b, 0x8c, 0x39, 0xab, 0xd9, 0xc7, 0x77, 0x8b, - 0xb7, 0x0b, 0xea, 0x4f, 0x0b, 0xd0, 0xd5, 0xb0, 0x85, 0x75, 0x1f, 0x7f, 0x91, 0x22, 0xbd, 0x0c, - 0x55, 0xc7, 0x35, 0xf0, 0xf6, 0x16, 0x15, 0x69, 0x49, 0xe3, 0x5f, 0xea, 0x67, 0x05, 0x58, 0xbc, - 0x8f, 0x03, 0x62, 0x06, 0xa6, 0x1f, 0x98, 0x83, 0xc8, 0xce, 0xdf, 0x86, 0x92, 0x87, 0x9f, 0x70, - 0xca, 0x5e, 0x11, 0x29, 0x8b, 0xdc, 0xbf, 0x6c, 0xa4, 0x46, 0xc6, 0xa1, 0xe7, 0xa1, 0x69, 0xd8, - 0x56, 0x7f, 0x70, 0xa4, 0x3b, 0x0e, 0xb6, 0x98, 0x21, 0xd5, 0xb5, 0x86, 0x61, 0x5b, 0x9b, 0x1c, - 0x84, 0xae, 0x03, 0xf8, 0x78, 0x68, 0x63, 0x27, 0x88, 0x7d, 0x72, 0x02, 0x82, 0x56, 0x61, 0xfe, - 0xd0, 0x73, 0xed, 0xbe, 0x7f, 0xa4, 0x7b, 0x46, 0xdf, 0xc2, 0xba, 0x81, 0x3d, 0x4a, 0x7d, 0x4d, - 0xeb, 0x90, 0x86, 0x3d, 0x02, 0x7f, 0x40, 0xc1, 0xe8, 0x16, 0x54, 0xfc, 0x81, 0x3b, 0xc2, 0x54, - 0xd3, 0xda, 0x1b, 0xd7, 0x64, 0x3a, 0xb4, 0xa5, 0x07, 0xfa, 0x1e, 0xe9, 0xa4, 0xb1, 0xbe, 0xea, - 0x3f, 0x94, 0x99, 0xa9, 0x7d, 0xc9, 0x9d, 0x5c, 0xc2, 0x1c, 0x2b, 0xcf, 0xc6, 0x1c, 0xab, 0xb9, - 0xcc, 0x71, 0xee, 0x6c, 0x73, 0xcc, 0x70, 0xed, 0x3c, 0xe6, 0x58, 0x9b, 0x6a, 0x8e, 0x75, 0x99, - 0x39, 0xa2, 0x77, 0xa1, 0xc3, 0x02, 0x08, 0xd3, 0x39, 0x74, 0xfb, 0x96, 0xe9, 0x07, 0x5d, 0xa0, - 0x64, 0x5e, 0x4b, 0x6b, 0xa8, 0x81, 0x9f, 0xae, 0x31, 0xc4, 0xce, 0xa1, 0xab, 0xb5, 0xcc, 0xf0, - 0xe7, 0x03, 0xd3, 0x0f, 0x66, 0xb7, 0xea, 0x7f, 0x8a, 0xad, 0xfa, 0xcb, 0xae, 0x3d, 0xb1, 0xe5, - 0x57, 0x04, 0xcb, 0xff, 0xcb, 0x02, 0x7c, 0xed, 0x3e, 0x0e, 0x22, 0xf2, 0x89, 0x21, 0xe3, 0x2f, - 0xe9, 0x36, 0xff, 0x37, 0x05, 0xe8, 0xc9, 0x68, 0x9d, 0x65, 0xab, 0xff, 0x08, 0x2e, 0x47, 0x38, - 0xfa, 0x06, 0xf6, 0x07, 0x9e, 0x39, 0xa2, 0x62, 0xa4, 0xbe, 0xaa, 0xb1, 0x71, 0x43, 0xa6, 0xf8, - 0x69, 0x0a, 0x96, 0xa2, 0x29, 0xb6, 0x12, 0x33, 0xa8, 0x3f, 0x29, 0xc0, 0x12, 0xf1, 0x8d, 0xdc, - 0x99, 0x11, 0x0d, 0xbc, 0x30, 0x5f, 0x45, 0x37, 0x59, 0xcc, 0xb8, 0xc9, 0x1c, 0x3c, 0xa6, 0x21, - 
0x76, 0x9a, 0x9e, 0x59, 0x78, 0xf7, 0x3a, 0x54, 0x88, 0x01, 0x86, 0xac, 0x7a, 0x4e, 0xc6, 0xaa, - 0x24, 0x32, 0xd6, 0x5b, 0x75, 0x18, 0x15, 0xb1, 0xdf, 0x9e, 0x41, 0xdd, 0xd2, 0xcb, 0x2e, 0x4a, - 0x96, 0xfd, 0xbb, 0x05, 0xb8, 0x92, 0x41, 0x38, 0xcb, 0xba, 0xdf, 0x82, 0x2a, 0xdd, 0x8d, 0xc2, - 0x85, 0xbf, 0x20, 0x5d, 0x78, 0x02, 0x1d, 0xf1, 0x36, 0x1a, 0x1f, 0xa3, 0xba, 0xa0, 0xa4, 0xdb, - 0xc8, 0x3e, 0xc9, 0xf7, 0xc8, 0xbe, 0xa3, 0xdb, 0x8c, 0x01, 0x75, 0xad, 0xc1, 0x61, 0x3b, 0xba, - 0x8d, 0xd1, 0xd7, 0xa0, 0x46, 0x4c, 0xb6, 0x6f, 0x1a, 0xa1, 0xf8, 0xe7, 0xa8, 0x09, 0x1b, 0x3e, - 0xba, 0x06, 0x40, 0x9b, 0x74, 0xc3, 0xf0, 0xd8, 0x16, 0x5a, 0xd7, 0xea, 0x04, 0x72, 0x87, 0x00, - 0xd4, 0x3f, 0x29, 0xc0, 0xf5, 0xbd, 0x53, 0x67, 0xb0, 0x83, 0x4f, 0x36, 0x3d, 0xac, 0x07, 0x38, - 0x76, 0xda, 0x9f, 0x2b, 0xe3, 0xd1, 0x32, 0x34, 0x12, 0xf6, 0xcb, 0x55, 0x32, 0x09, 0x52, 0xff, - 0xb6, 0x00, 0x4d, 0xb2, 0x8b, 0x7c, 0x80, 0x03, 0x9d, 0xa8, 0x08, 0xfa, 0x0e, 0xd4, 0x2d, 0x57, - 0x37, 0xfa, 0xc1, 0xe9, 0x88, 0x51, 0xd3, 0x4e, 0x53, 0x13, 0x6f, 0x3d, 0xfb, 0xa7, 0x23, 0xac, - 0xd5, 0x2c, 0xfe, 0x2b, 0x17, 0x45, 0x69, 0x2f, 0x53, 0x92, 0x78, 0xca, 0xe7, 0xa0, 0x61, 0xe3, - 0xc0, 0x33, 0x07, 0x8c, 0x88, 0x32, 0x15, 0x05, 0x30, 0x10, 0x41, 0xa4, 0xfe, 0xa4, 0x0a, 0x97, - 0x7f, 0xa0, 0x07, 0x83, 0xa3, 0x2d, 0x3b, 0x8c, 0x62, 0x2e, 0xce, 0xc7, 0xd8, 0x2f, 0x17, 0x93, - 0x7e, 0xf9, 0x99, 0xf9, 0xfd, 0xc8, 0x46, 0x2b, 0x32, 0x1b, 0x25, 0x89, 0xf9, 0xda, 0x23, 0xae, - 0x66, 0x09, 0x1b, 0x4d, 0x04, 0x1b, 0xd5, 0x8b, 0x04, 0x1b, 0x9b, 0xd0, 0xc2, 0x4f, 0x07, 0xd6, - 0x98, 0xe8, 0x2b, 0xc5, 0xce, 0xa2, 0x88, 0xeb, 0x12, 0xec, 0x49, 0x07, 0xd1, 0xe4, 0x83, 0xb6, - 0x39, 0x0d, 0x4c, 0x17, 0x6c, 0x1c, 0xe8, 0x34, 0x54, 0x68, 0x6c, 0x2c, 0x4f, 0xd2, 0x85, 0x50, - 0x81, 0x98, 0x3e, 0x90, 0x2f, 0x74, 0x15, 0xea, 0x3c, 0xb4, 0xd9, 0xde, 0xea, 0xd6, 0x29, 0xfb, - 0x62, 0x00, 0xd2, 0xa1, 0xc5, 0xbd, 0x27, 0xa7, 0x90, 0x05, 0x10, 0x6f, 0xc9, 0x10, 0xc8, 0x85, - 0x9d, 0xa4, 0xdc, 0xe7, 0x81, 0x8e, 
0x9f, 0x00, 0x91, 0xcc, 0xdf, 0x3d, 0x3c, 0xb4, 0x4c, 0x07, - 0xef, 0x30, 0x09, 0x37, 0x28, 0x11, 0x22, 0x90, 0x84, 0x43, 0xc7, 0xd8, 0xf3, 0x4d, 0xd7, 0xe9, - 0x36, 0x69, 0x7b, 0xf8, 0x29, 0x8b, 0x72, 0x5a, 0x17, 0x88, 0x72, 0xfa, 0x30, 0x9f, 0xa1, 0x54, - 0x12, 0xe5, 0x7c, 0x2b, 0x19, 0xe5, 0x4c, 0x17, 0x55, 0x22, 0x0a, 0xfa, 0x59, 0x01, 0x96, 0x1e, - 0x3a, 0xfe, 0xf8, 0x20, 0x62, 0xd1, 0x17, 0x63, 0x0e, 0x69, 0x27, 0x5a, 0xce, 0x38, 0x51, 0xf5, - 0xa7, 0x55, 0xe8, 0xf0, 0x55, 0x10, 0xad, 0xa1, 0x2e, 0xe7, 0x2a, 0xd4, 0xa3, 0x7d, 0x94, 0x33, - 0x24, 0x06, 0xa4, 0x7d, 0x58, 0x31, 0xe3, 0xc3, 0x72, 0x91, 0x16, 0x46, 0x45, 0xe5, 0x44, 0x54, - 0x74, 0x0d, 0xe0, 0xd0, 0x1a, 0xfb, 0x47, 0xfd, 0xc0, 0xb4, 0x31, 0x8f, 0xca, 0xea, 0x14, 0xb2, - 0x6f, 0xda, 0x18, 0xdd, 0x81, 0xe6, 0x81, 0xe9, 0x58, 0xee, 0xb0, 0x3f, 0xd2, 0x83, 0x23, 0x9f, - 0xa7, 0xc5, 0x32, 0xb1, 0xd0, 0x18, 0xf6, 0x2e, 0xed, 0xab, 0x35, 0xd8, 0x98, 0x5d, 0x32, 0x04, - 0x5d, 0x87, 0x86, 0x33, 0xb6, 0xfb, 0xee, 0x61, 0xdf, 0x73, 0x4f, 0x7c, 0x9a, 0xfc, 0x96, 0xb4, - 0xba, 0x33, 0xb6, 0x3f, 0x3c, 0xd4, 0xdc, 0x13, 0xb2, 0x8f, 0xd5, 0xc9, 0x8e, 0xe6, 0x5b, 0xee, - 0x90, 0x25, 0xbe, 0xd3, 0xe7, 0x8f, 0x07, 0x90, 0xd1, 0x06, 0xb6, 0x02, 0x9d, 0x8e, 0xae, 0xe7, - 0x1b, 0x1d, 0x0d, 0x40, 0x2f, 0x41, 0x7b, 0xe0, 0xda, 0x23, 0x9d, 0x72, 0xe8, 0x9e, 0xe7, 0xda, - 0xd4, 0x00, 0x4b, 0x5a, 0x0a, 0x8a, 0x36, 0xa1, 0x11, 0x1b, 0x81, 0xdf, 0x6d, 0x50, 0x3c, 0xaa, - 0xcc, 0x4a, 0x13, 0xa1, 0x3c, 0x51, 0x50, 0x88, 0xac, 0xc0, 0x27, 0x9a, 0x11, 0x1a, 0xbb, 0x6f, - 0x7e, 0x8a, 0xb9, 0xa1, 0x35, 0x38, 0x6c, 0xcf, 0xfc, 0x14, 0x93, 0xf4, 0xc8, 0x74, 0x7c, 0xec, - 0x05, 0x61, 0xb2, 0xda, 0x6d, 0x51, 0xf5, 0x69, 0x31, 0x28, 0x57, 0x6c, 0xb4, 0x05, 0x6d, 0x3f, - 0xd0, 0xbd, 0xa0, 0x3f, 0x72, 0x7d, 0xaa, 0x00, 0xdd, 0x36, 0xd5, 0xed, 0x94, 0x49, 0xda, 0xfe, - 0x90, 0x28, 0xf6, 0x2e, 0xef, 0xa4, 0xb5, 0xe8, 0xa0, 0xf0, 0x93, 0xcc, 0x42, 0x39, 0x11, 0xcf, - 0xd2, 0xc9, 0x35, 0x0b, 0x1d, 0x14, 0xcd, 0xb2, 0x42, 0xd2, 0x25, 0xdd, 
0xd0, 0x0f, 0x2c, 0xfc, - 0x88, 0x7b, 0x10, 0x85, 0x2e, 0x2c, 0x0d, 0x26, 0xcc, 0xf6, 0x03, 0xd7, 0xd3, 0x87, 0x51, 0xc7, - 0x79, 0xda, 0x31, 0x05, 0x55, 0xff, 0xa7, 0x08, 0x6d, 0x91, 0x8d, 0xc4, 0x3d, 0xb1, 0xec, 0x2d, - 0xb4, 0x8d, 0xf0, 0x93, 0x30, 0x15, 0x3b, 0x04, 0x0b, 0x4b, 0x15, 0xa9, 0x69, 0xd4, 0xb4, 0x06, - 0x83, 0xd1, 0x09, 0x88, 0x8a, 0x33, 0xe1, 0x51, 0x7b, 0x2c, 0x51, 0x86, 0xd6, 0x29, 0x84, 0x86, - 0x34, 0x5d, 0x98, 0x0b, 0xb3, 0x4c, 0x66, 0x18, 0xe1, 0x27, 0x69, 0x39, 0x18, 0x9b, 0x14, 0x2b, - 0x33, 0x8c, 0xf0, 0x13, 0x6d, 0x41, 0x93, 0x4d, 0x39, 0xd2, 0x3d, 0xdd, 0x0e, 0xcd, 0xe2, 0x79, - 0xa9, 0x6b, 0x79, 0x1f, 0x9f, 0x3e, 0x22, 0x5e, 0x6a, 0x57, 0x37, 0x3d, 0x8d, 0xa9, 0xd1, 0x2e, - 0x1d, 0x85, 0x56, 0x40, 0x61, 0xb3, 0x1c, 0x9a, 0x16, 0xe6, 0x06, 0x36, 0xc7, 0x52, 0x4d, 0x0a, - 0xbf, 0x67, 0x5a, 0x98, 0xd9, 0x50, 0xb4, 0x04, 0xaa, 0x38, 0x35, 0x66, 0x42, 0x14, 0x42, 0xd5, - 0xe6, 0x06, 0x30, 0x6f, 0xdb, 0x0f, 0x7d, 0x38, 0xdb, 0x68, 0x18, 0x8d, 0x21, 0xfb, 0x49, 0xe8, - 0x36, 0xb6, 0x99, 0x11, 0x02, 0x5b, 0x8e, 0x33, 0xb6, 0x89, 0x09, 0xaa, 0x7f, 0x58, 0x81, 0x05, - 0xe2, 0x89, 0xb8, 0x53, 0x9a, 0x21, 0x90, 0xb8, 0x06, 0x60, 0xf8, 0x41, 0x5f, 0xf0, 0x9e, 0x75, - 0xc3, 0x0f, 0xf8, 0x36, 0xf3, 0x9d, 0x30, 0x0e, 0x28, 0x4d, 0x4e, 0x6b, 0x52, 0x9e, 0x31, 0x1b, - 0x0b, 0x5c, 0xa8, 0x0e, 0x78, 0x03, 0x5a, 0x3c, 0xa7, 0x17, 0x12, 0xd0, 0x26, 0x03, 0xee, 0xc8, - 0xfd, 0x7b, 0x55, 0x5a, 0x8f, 0x4c, 0xc4, 0x03, 0x73, 0xb3, 0xc5, 0x03, 0xb5, 0x74, 0x3c, 0x70, - 0x0f, 0x3a, 0xa2, 0x49, 0x86, 0x3e, 0x6d, 0x8a, 0x4d, 0xb6, 0x05, 0x9b, 0xf4, 0x93, 0xdb, 0x39, - 0x88, 0xdb, 0xf9, 0x0d, 0x68, 0x39, 0x18, 0x1b, 0xfd, 0xc0, 0xd3, 0x1d, 0xff, 0x10, 0x7b, 0x34, - 0x1c, 0xa8, 0x69, 0x4d, 0x02, 0xdc, 0xe7, 0x30, 0xf4, 0x16, 0x00, 0x5d, 0x23, 0x2b, 0x63, 0x35, - 0x27, 0x97, 0xb1, 0xa8, 0xd2, 0xd0, 0x32, 0x16, 0x65, 0x0a, 0xfd, 0xf9, 0x8c, 0x22, 0x06, 0xf5, - 0x5f, 0x8b, 0x70, 0x99, 0x97, 0x35, 0x66, 0xd7, 0xcb, 0x49, 0x3b, 0x7a, 0xb8, 0x25, 0x96, 0xce, - 0x28, 0x14, 
0x94, 0x73, 0x04, 0xbd, 0x15, 0x49, 0xd0, 0x2b, 0x26, 0xcb, 0xd5, 0x4c, 0xb2, 0x1c, - 0xd5, 0x09, 0xe7, 0xf2, 0xd7, 0x09, 0xd1, 0x22, 0x54, 0x68, 0x06, 0x47, 0x75, 0xa7, 0xae, 0xb1, - 0x8f, 0x5c, 0x52, 0x55, 0xff, 0xb8, 0x08, 0xad, 0x3d, 0xac, 0x7b, 0x83, 0xa3, 0x90, 0x8f, 0x6f, - 0x24, 0xeb, 0xaa, 0x2f, 0x4c, 0xa8, 0xab, 0x0a, 0x43, 0xbe, 0x32, 0x05, 0x55, 0x82, 0x20, 0x70, - 0x03, 0x3d, 0xa2, 0xb2, 0xef, 0x8c, 0x6d, 0x5e, 0x6c, 0xec, 0xd0, 0x06, 0x4e, 0xea, 0xce, 0xd8, - 0x56, 0xff, 0xbb, 0x00, 0xcd, 0xef, 0x93, 0x69, 0x42, 0xc6, 0xdc, 0x4e, 0x32, 0xe6, 0xa5, 0x09, - 0x8c, 0xd1, 0x48, 0x32, 0x86, 0x8f, 0xf1, 0x57, 0xae, 0xd6, 0xfc, 0xcf, 0x05, 0xe8, 0x91, 0x54, - 0x5c, 0x63, 0x7e, 0x67, 0x76, 0xeb, 0xba, 0x01, 0xad, 0x63, 0x21, 0xe8, 0x2d, 0x52, 0xe5, 0x6c, - 0x1e, 0x27, 0x4b, 0x07, 0x1a, 0x28, 0x61, 0xe9, 0x97, 0x2f, 0x36, 0xdc, 0x06, 0x5e, 0x96, 0x51, - 0x9d, 0x22, 0x8e, 0x7a, 0x88, 0x8e, 0x27, 0x02, 0xd5, 0xdf, 0x2b, 0xc0, 0x82, 0xa4, 0x23, 0xba, - 0x02, 0x73, 0xbc, 0x4c, 0xc1, 0xe3, 0x05, 0x66, 0xef, 0x06, 0x11, 0x4f, 0x5c, 0x68, 0x33, 0x8d, - 0x6c, 0x24, 0x6d, 0x90, 0xcc, 0x3b, 0xca, 0xc9, 0x8c, 0x8c, 0x7c, 0x0c, 0x1f, 0xf5, 0xa0, 0xc6, - 0xbd, 0x69, 0x98, 0xec, 0x46, 0xdf, 0xea, 0x63, 0x40, 0xf7, 0x71, 0xbc, 0x77, 0xcd, 0xc2, 0xd1, - 0xd8, 0xdf, 0xc4, 0x84, 0x26, 0x9d, 0x90, 0xa1, 0xfe, 0x67, 0x01, 0x16, 0x04, 0x6c, 0xb3, 0x94, - 0x93, 0xe2, 0xfd, 0xb5, 0x78, 0x91, 0xfd, 0x55, 0x28, 0x99, 0x94, 0xce, 0x55, 0x32, 0xb9, 0x0e, - 0x10, 0xf1, 0x3f, 0xe4, 0x68, 0x02, 0xa2, 0xfe, 0x63, 0x01, 0x2e, 0xbf, 0xa7, 0x3b, 0x86, 0x7b, - 0x78, 0x38, 0xbb, 0xaa, 0x6e, 0x82, 0x90, 0x1e, 0xe7, 0x2d, 0x1a, 0x8a, 0x39, 0xf5, 0x2b, 0x30, - 0xef, 0xb1, 0x9d, 0xc9, 0x10, 0x75, 0xb9, 0xa4, 0x29, 0x61, 0x43, 0xa4, 0xa3, 0x7f, 0x5d, 0x04, - 0x44, 0x56, 0x7d, 0x57, 0xb7, 0x74, 0x67, 0x80, 0x2f, 0x4e, 0xfa, 0x8b, 0xd0, 0x16, 0x42, 0x98, - 0xe8, 0x10, 0x3f, 0x19, 0xc3, 0xf8, 0xe8, 0x7d, 0x68, 0x1f, 0x30, 0x54, 0x7d, 0x0f, 0xeb, 0xbe, - 0xeb, 0x70, 0x71, 0x48, 0xeb, 0x83, 0xfb, 0x9e, 
0x39, 0x1c, 0x62, 0x6f, 0xd3, 0x75, 0x0c, 0x1e, - 0xdd, 0x1f, 0x84, 0x64, 0x92, 0xa1, 0xc4, 0x18, 0xe2, 0x78, 0x2e, 0x12, 0x4e, 0x14, 0xd0, 0x51, - 0x56, 0xf8, 0x58, 0xb7, 0x62, 0x46, 0xc4, 0xbb, 0xa1, 0xc2, 0x1a, 0xf6, 0x26, 0x97, 0x87, 0x25, - 0xf1, 0x95, 0xfa, 0x77, 0x05, 0x40, 0x51, 0x0a, 0x4f, 0x6b, 0x1e, 0xd4, 0xa2, 0xd3, 0x43, 0x0b, - 0x92, 0x4d, 0xf9, 0x2a, 0xd4, 0x8d, 0x70, 0x24, 0x77, 0x41, 0x31, 0x80, 0xee, 0x91, 0x94, 0xe8, - 0x3e, 0xd1, 0x3c, 0x6c, 0x84, 0x29, 0x32, 0x03, 0x3e, 0xa0, 0x30, 0x31, 0x3c, 0x2b, 0xa7, 0xc3, - 0xb3, 0x64, 0xf5, 0xb3, 0x22, 0x54, 0x3f, 0xd5, 0x9f, 0x15, 0x41, 0xa1, 0x5b, 0xc8, 0x66, 0x5c, - 0xc6, 0xca, 0x45, 0xf4, 0x0d, 0x68, 0xf1, 0x4b, 0x30, 0x02, 0xe1, 0xcd, 0x27, 0x89, 0xc9, 0xd0, - 0x4d, 0x58, 0x64, 0x9d, 0x3c, 0xec, 0x8f, 0xad, 0x38, 0x3b, 0x64, 0xc9, 0x0c, 0x7a, 0xc2, 0xf6, - 0x2e, 0xd2, 0x14, 0x8e, 0x78, 0x08, 0x97, 0x87, 0x96, 0x7b, 0xa0, 0x5b, 0x7d, 0x51, 0x3c, 0x4c, - 0x86, 0x39, 0x34, 0x7e, 0x91, 0x0d, 0xdf, 0x4b, 0xca, 0xd0, 0x47, 0x77, 0xa1, 0xe5, 0x63, 0xfc, - 0x38, 0x4e, 0x19, 0x2b, 0x79, 0x52, 0xc6, 0x26, 0x19, 0x13, 0x7e, 0xa9, 0x7f, 0x5a, 0x80, 0x4e, - 0xea, 0xec, 0x22, 0x5d, 0xe0, 0x28, 0x64, 0x0b, 0x1c, 0xb7, 0xa1, 0x42, 0x3c, 0x15, 0xdb, 0x5b, - 0xda, 0xf2, 0xe4, 0x5b, 0x9c, 0x55, 0x63, 0x03, 0xd0, 0x3a, 0x2c, 0x48, 0xee, 0x48, 0x70, 0xf1, - 0xa3, 0xec, 0x15, 0x09, 0xf5, 0xe7, 0x65, 0x68, 0x24, 0x58, 0x31, 0xa5, 0x36, 0xf3, 0x4c, 0x6a, - 0xd0, 0x93, 0xce, 0xc4, 0x89, 0xca, 0xd9, 0xd8, 0x66, 0x79, 0x1f, 0x4f, 0x42, 0x6d, 0x6c, 0xd3, - 0xac, 0x2f, 0x99, 0xd0, 0x55, 0x85, 0x84, 0x2e, 0x95, 0xf2, 0xce, 0x9d, 0x91, 0xf2, 0xd6, 0xc4, - 0x94, 0x57, 0x30, 0xa1, 0x7a, 0xda, 0x84, 0xf2, 0x96, 0x4b, 0x6e, 0xc2, 0xc2, 0x80, 0xd5, 0xf8, - 0xef, 0x9e, 0x6e, 0x46, 0x4d, 0x3c, 0x28, 0x95, 0x35, 0xa1, 0x7b, 0x71, 0x21, 0x94, 0x49, 0x99, - 0x25, 0x1d, 0xf2, 0x8c, 0x9a, 0xcb, 0x86, 0x09, 0x39, 0xf4, 0xcc, 0xf4, 0x2b, 0x5d, 0xa8, 0x69, - 0x5d, 0xa8, 0x50, 0xf3, 0x1c, 0x34, 0xc2, 0x48, 0x85, 0x58, 0x7a, 0x9b, 0x39, 0xbd, 
0xd0, 0x0d, - 0x18, 0xbe, 0xe0, 0x07, 0x3a, 0xe2, 0x29, 0x48, 0xba, 0x1e, 0xa1, 0x64, 0xeb, 0x11, 0x57, 0x60, - 0xce, 0xf4, 0xfb, 0x87, 0xfa, 0x63, 0x4c, 0x0b, 0x20, 0x35, 0xad, 0x6a, 0xfa, 0xf7, 0xf4, 0xc7, - 0x58, 0xfd, 0xf7, 0x12, 0xb4, 0xe3, 0x0d, 0x36, 0xb7, 0x07, 0xc9, 0x73, 0x4f, 0x68, 0x07, 0x94, - 0x38, 0xee, 0xa1, 0x1c, 0x3e, 0x33, 0x07, 0x4f, 0x1f, 0x2d, 0x76, 0x46, 0x29, 0x7b, 0x15, 0xb6, - 0xfb, 0xf2, 0xb9, 0xb6, 0xfb, 0x19, 0x6f, 0x10, 0xdc, 0x82, 0xa5, 0x68, 0xef, 0x15, 0x96, 0xcd, - 0x12, 0xac, 0xc5, 0xb0, 0x71, 0x37, 0xb9, 0xfc, 0x09, 0x2e, 0x60, 0x6e, 0x92, 0x0b, 0x48, 0xab, - 0x40, 0x2d, 0xa3, 0x02, 0xd9, 0x8b, 0x0c, 0x75, 0xc9, 0x45, 0x06, 0xf5, 0x21, 0x2c, 0xd0, 0xa2, - 0xb4, 0x3f, 0xf0, 0xcc, 0x03, 0x1c, 0xa5, 0x00, 0x79, 0xc4, 0xda, 0x83, 0x5a, 0x2a, 0x8b, 0x88, - 0xbe, 0xd5, 0x1f, 0x17, 0xe0, 0x72, 0x76, 0x5e, 0xaa, 0x31, 0xb1, 0x23, 0x29, 0x08, 0x8e, 0xe4, - 0x57, 0x61, 0x21, 0x11, 0x51, 0x0a, 0x33, 0x4f, 0x88, 0xc0, 0x25, 0x84, 0x6b, 0x28, 0x9e, 0x23, - 0x84, 0xa9, 0x3f, 0x2f, 0x44, 0xb5, 0x7d, 0x02, 0x1b, 0xd2, 0x83, 0x13, 0xb2, 0xaf, 0xb9, 0x8e, - 0x65, 0x3a, 0x51, 0xc1, 0x85, 0xaf, 0x91, 0x01, 0x79, 0xc1, 0xe5, 0x3d, 0xe8, 0xf0, 0x4e, 0xd1, - 0xf6, 0x94, 0x33, 0x20, 0x6b, 0xb3, 0x71, 0xd1, 0xc6, 0xf4, 0x22, 0xb4, 0xf9, 0x89, 0x46, 0x88, - 0xaf, 0x24, 0x3b, 0xe7, 0xf8, 0x1e, 0x28, 0x61, 0xb7, 0xf3, 0x6e, 0x88, 0x1d, 0x3e, 0x30, 0x0a, - 0xec, 0x7e, 0xbb, 0x00, 0x5d, 0x71, 0x7b, 0x4c, 0x2c, 0xff, 0xfc, 0xe1, 0xdd, 0x9b, 0xe2, 0x39, - 0xf6, 0x8b, 0x67, 0xd0, 0x13, 0xe3, 0x09, 0x4f, 0xb3, 0xff, 0xa0, 0x48, 0x2f, 0x25, 0x90, 0x54, - 0x6f, 0xcb, 0xf4, 0x03, 0xcf, 0x3c, 0x18, 0xcf, 0x76, 0xb2, 0xaa, 0x43, 0x63, 0x70, 0x84, 0x07, - 0x8f, 0x47, 0xae, 0x19, 0x4b, 0xe5, 0x1d, 0x19, 0x4d, 0x93, 0xd1, 0xae, 0x6d, 0xc6, 0x33, 0xb0, - 0xa3, 0xa9, 0xe4, 0x9c, 0xbd, 0x1f, 0x82, 0x92, 0xee, 0x90, 0x3c, 0x11, 0xaa, 0xb3, 0x13, 0xa1, - 0x5b, 0xe2, 0x89, 0xd0, 0x94, 0x48, 0x23, 0x71, 0x20, 0xf4, 0xf7, 0x45, 0xf8, 0xba, 0x94, 0xb6, - 0x59, 0xb2, 0xa4, 0x49, 
0x75, 0xa4, 0xbb, 0x50, 0x4b, 0x25, 0xb5, 0x2f, 0x9d, 0x21, 0x3f, 0x5e, - 0x92, 0x65, 0xa5, 0x41, 0x3f, 0x8e, 0xad, 0x62, 0x83, 0x2f, 0x4f, 0x9e, 0x83, 0xdb, 0x9d, 0x30, - 0x47, 0x38, 0x0e, 0xdd, 0x81, 0x26, 0x2b, 0x18, 0xf4, 0x8f, 0x4d, 0x7c, 0x12, 0x9e, 0xb7, 0x5e, - 0x97, 0xba, 0x66, 0xda, 0xef, 0x91, 0x89, 0x4f, 0xb4, 0x86, 0x15, 0xfd, 0xf6, 0xd5, 0x3f, 0x2a, - 0x03, 0xc4, 0x6d, 0x24, 0x3b, 0x8b, 0x6d, 0x9e, 0x1b, 0x71, 0x02, 0x42, 0x62, 0x09, 0x31, 0x72, - 0x0d, 0x3f, 0x91, 0x16, 0x9f, 0x77, 0x18, 0xa6, 0x1f, 0x70, 0xbe, 0xac, 0x9f, 0x4d, 0x4b, 0xc8, - 0x22, 0x22, 0x32, 0xae, 0x33, 0x7e, 0x0c, 0x41, 0xaf, 0x01, 0x1a, 0x7a, 0xee, 0x89, 0xe9, 0x0c, - 0x93, 0xf9, 0x06, 0x4b, 0x4b, 0xe6, 0x79, 0x4b, 0x22, 0xe1, 0xf8, 0x11, 0x28, 0xa9, 0xee, 0x21, - 0x4b, 0x6e, 0x4d, 0x21, 0xe3, 0xbe, 0x30, 0x17, 0x57, 0xdf, 0x8e, 0x88, 0x81, 0x1e, 0xae, 0xee, - 0xeb, 0xde, 0x10, 0x87, 0x12, 0xe5, 0x71, 0x98, 0x08, 0xec, 0xf5, 0x41, 0x49, 0xaf, 0x4a, 0x72, - 0xf4, 0xf9, 0xba, 0xa8, 0xe8, 0x67, 0xf9, 0x23, 0x32, 0x4d, 0x42, 0xd5, 0x7b, 0x3a, 0x2c, 0xca, - 0xe8, 0x95, 0x20, 0xb9, 0xb0, 0x35, 0xbd, 0x13, 0x85, 0xc4, 0x54, 0x0e, 0x93, 0x76, 0x99, 0x44, - 0xe1, 0xb9, 0x28, 0x14, 0x9e, 0xd5, 0xdf, 0x2c, 0x01, 0xca, 0xaa, 0x3f, 0x6a, 0x43, 0x31, 0x9a, - 0xa4, 0xb8, 0xbd, 0x95, 0x52, 0xb7, 0x62, 0x46, 0xdd, 0xae, 0x42, 0x3d, 0xda, 0xf5, 0xb9, 0x8b, - 0x8f, 0x01, 0x49, 0x65, 0x2c, 0x8b, 0xca, 0x98, 0x20, 0xac, 0x22, 0x56, 0xc4, 0x6f, 0xc2, 0xa2, - 0xa5, 0xfb, 0x41, 0x9f, 0x15, 0xde, 0x03, 0xd3, 0xc6, 0x7e, 0xa0, 0xdb, 0x23, 0x2a, 0xca, 0xb2, - 0x86, 0x48, 0xdb, 0x16, 0x69, 0xda, 0x0f, 0x5b, 0xd0, 0x7e, 0x18, 0x5d, 0x13, 0xdf, 0xcb, 0x2f, - 0x15, 0xbc, 0x9e, 0xcf, 0xdc, 0xe3, 0x72, 0x37, 0xd3, 0xa8, 0x7a, 0x14, 0x76, 0xf6, 0x3e, 0x81, - 0xb6, 0xd8, 0x28, 0x11, 0xdf, 0x6d, 0x51, 0x7c, 0x79, 0x02, 0xdb, 0x84, 0x0c, 0x8f, 0x00, 0x65, - 0x9d, 0x47, 0x92, 0x67, 0x05, 0x91, 0x67, 0xd3, 0x64, 0x91, 0xe0, 0x69, 0x49, 0x14, 0xf6, 0x9f, - 0x97, 0x00, 0xc5, 0x11, 0x5c, 0x74, 0xc8, 0x9d, 0x27, 0xec, 
0x59, 0x87, 0x85, 0x6c, 0x7c, 0x17, - 0x06, 0xb5, 0x28, 0x13, 0xdd, 0xc9, 0x22, 0xb1, 0x92, 0xec, 0x4a, 0xe9, 0x1b, 0x91, 0xbb, 0x67, - 0xe1, 0xea, 0xf5, 0x89, 0xe7, 0x19, 0xa2, 0xc7, 0xff, 0x61, 0xfa, 0x2a, 0x2a, 0xf3, 0x1f, 0xb7, - 0xa5, 0xae, 0x39, 0xb3, 0xe4, 0xa9, 0xf7, 0x50, 0x85, 0x40, 0xba, 0x7a, 0x9e, 0x40, 0x7a, 0xf6, - 0x8b, 0xa3, 0xff, 0x51, 0x84, 0xf9, 0x88, 0x91, 0xe7, 0x12, 0xd2, 0xf4, 0xfb, 0x08, 0x9f, 0xb3, - 0x54, 0x3e, 0x96, 0x4b, 0xe5, 0xdb, 0x67, 0x26, 0x33, 0x79, 0x85, 0x32, 0x3b, 0x67, 0x3f, 0x85, - 0x39, 0x5e, 0x96, 0xce, 0x38, 0xb8, 0x3c, 0xe5, 0x82, 0x45, 0xa8, 0x10, 0x7f, 0x1a, 0xd6, 0x14, - 0xd9, 0x07, 0x63, 0x69, 0xf2, 0x62, 0x32, 0xf7, 0x71, 0x2d, 0xe1, 0x5e, 0xb2, 0xfa, 0x3b, 0x25, - 0x80, 0xbd, 0x53, 0x67, 0x70, 0x87, 0x19, 0xe9, 0x4d, 0x28, 0x4f, 0xbb, 0xc6, 0x46, 0x7a, 0x53, - 0xdd, 0xa2, 0x3d, 0x73, 0x08, 0x57, 0x28, 0x88, 0x94, 0xd2, 0x05, 0x91, 0x49, 0xa5, 0x8c, 0xc9, - 0x2e, 0xf8, 0xdb, 0x50, 0xa6, 0xae, 0x94, 0xdd, 0xf2, 0xca, 0x75, 0x2a, 0x4c, 0x07, 0xa0, 0x15, - 0x08, 0xb7, 0xe4, 0x6d, 0x87, 0xed, 0xb9, 0xd4, 0x1d, 0x97, 0xb4, 0x34, 0x98, 0x5e, 0x3e, 0xa0, - 0xb1, 0x7a, 0xd4, 0x91, 0xe5, 0x74, 0x29, 0x68, 0x76, 0x47, 0xaf, 0x4b, 0x76, 0x74, 0x82, 0xd7, - 0xf0, 0xdc, 0xd1, 0x28, 0x31, 0x1d, 0xab, 0x84, 0xa4, 0xc1, 0xea, 0x67, 0x45, 0xb8, 0x42, 0xf8, - 0xfb, 0x6c, 0xa2, 0xf2, 0x3c, 0xca, 0x93, 0xf0, 0xe7, 0x25, 0xd1, 0x9f, 0xdf, 0x86, 0x39, 0x56, - 0x6e, 0x09, 0xe3, 0xcb, 0xeb, 0x93, 0xb4, 0x81, 0xe9, 0x8e, 0x16, 0x76, 0x9f, 0x35, 0x67, 0x17, - 0xce, 0xcc, 0xab, 0xb3, 0x9d, 0x99, 0xcf, 0xa5, 0x8b, 0xb2, 0x09, 0xb5, 0xaa, 0x89, 0xbb, 0xd0, - 0x43, 0x68, 0x69, 0x49, 0xd3, 0x40, 0x08, 0xca, 0x89, 0x8b, 0xad, 0xf4, 0x37, 0x4d, 0xb3, 0xf5, - 0x91, 0x3e, 0x30, 0x83, 0x53, 0xca, 0xce, 0x8a, 0x16, 0x7d, 0xcb, 0xed, 0x50, 0xfd, 0xdf, 0x02, - 0x5c, 0x0e, 0x0f, 0x55, 0xb9, 0x95, 0x5f, 0x5c, 0xa2, 0x1b, 0xb0, 0xc4, 0x4d, 0x3a, 0x65, 0xdb, - 0x2c, 0x98, 0x5e, 0x60, 0x30, 0x71, 0x19, 0x1b, 0xb0, 0x14, 0x50, 0xed, 0x4a, 0x8f, 0x61, 0xf2, - 
0x5e, 0x60, 0x8d, 0xe2, 0x98, 0x3c, 0x87, 0xda, 0xcf, 0xb1, 0x9b, 0x5a, 0x9c, 0xb5, 0xdc, 0x48, - 0xc1, 0x19, 0xdb, 0x7c, 0x95, 0xea, 0x09, 0x5c, 0x65, 0x57, 0xcb, 0x0f, 0x44, 0x8a, 0x66, 0x3a, - 0xd3, 0x90, 0xae, 0x3b, 0xe5, 0xd3, 0xfe, 0xac, 0x00, 0xd7, 0x26, 0x60, 0x9e, 0x25, 0x9b, 0x7b, - 0x20, 0xc5, 0x3e, 0x21, 0xf7, 0x16, 0xf0, 0xb2, 0x0b, 0x0b, 0x22, 0x91, 0x9f, 0x95, 0x61, 0x3e, - 0xd3, 0xe9, 0xdc, 0x3a, 0xf7, 0x2a, 0x20, 0x22, 0x84, 0xe8, 0x19, 0x25, 0x2d, 0x67, 0xf0, 0xcd, - 0x53, 0x71, 0xc6, 0x76, 0xf4, 0x84, 0x72, 0xc7, 0x35, 0x30, 0x32, 0x59, 0x6f, 0x76, 0xa2, 0x11, - 0x49, 0xae, 0x3c, 0xf9, 0xb5, 0x4c, 0x86, 0xc0, 0xb5, 0x9d, 0xb1, 0xcd, 0x0e, 0x3f, 0xb8, 0x94, - 0xd9, 0x86, 0x48, 0x50, 0x09, 0x60, 0x74, 0x08, 0xf3, 0xf4, 0x1e, 0xdf, 0x38, 0x18, 0xba, 0x24, - 0xa1, 0xa2, 0x74, 0xb1, 0x6d, 0xf7, 0xbb, 0xb9, 0x31, 0x7d, 0xc8, 0x47, 0x13, 0xe2, 0x79, 0x4e, - 0xe5, 0x88, 0xd0, 0x10, 0x8f, 0xe9, 0x0c, 0x5c, 0x3b, 0xc2, 0x53, 0x3d, 0x27, 0x9e, 0x6d, 0x3e, - 0x5a, 0xc4, 0x93, 0x84, 0xf6, 0x36, 0x61, 0x49, 0xba, 0xf4, 0x69, 0x1b, 0x7d, 0x25, 0x99, 0x79, - 0xdd, 0x85, 0x45, 0xd9, 0xaa, 0x2e, 0x30, 0x47, 0x86, 0xe2, 0xf3, 0xcc, 0xa1, 0xfe, 0x55, 0x11, - 0x5a, 0x5b, 0xd8, 0xc2, 0x01, 0xfe, 0x7c, 0xcf, 0x9c, 0x33, 0x07, 0xe8, 0xa5, 0xec, 0x01, 0x7a, - 0xe6, 0x36, 0x40, 0x59, 0x72, 0x1b, 0xe0, 0x5a, 0x74, 0x09, 0x82, 0xcc, 0x52, 0x11, 0x63, 0x08, - 0x03, 0xbd, 0x09, 0xcd, 0x91, 0x67, 0xda, 0xba, 0x77, 0xda, 0x7f, 0x8c, 0x4f, 0x7d, 0xbe, 0x69, - 0x74, 0xa5, 0xdb, 0xce, 0xf6, 0x96, 0xaf, 0x35, 0x78, 0xef, 0xf7, 0xf1, 0x29, 0xbd, 0x60, 0x11, - 0xa5, 0x71, 0xec, 0x46, 0x5d, 0x59, 0x4b, 0x40, 0x56, 0x5f, 0x81, 0x7a, 0x74, 0x71, 0x09, 0xd5, - 0xa0, 0x7c, 0x6f, 0x6c, 0x59, 0xca, 0x25, 0x54, 0x87, 0x0a, 0x4d, 0xf4, 0x94, 0x02, 0xf9, 0x49, - 0x63, 0x3f, 0xa5, 0xb8, 0xfa, 0x2b, 0x50, 0x8f, 0x2e, 0x50, 0xa0, 0x06, 0xcc, 0x3d, 0x74, 0xde, - 0x77, 0xdc, 0x13, 0x47, 0xb9, 0x84, 0xe6, 0xa0, 0x74, 0xc7, 0xb2, 0x94, 0x02, 0x6a, 0x41, 0x7d, - 0x2f, 0xf0, 0xb0, 0x4e, 0xc4, 0xa7, 
0x14, 0x51, 0x1b, 0xe0, 0x3d, 0xd3, 0x0f, 0x5c, 0xcf, 0x1c, - 0xe8, 0x96, 0x52, 0x5a, 0xfd, 0x14, 0xda, 0x62, 0x3d, 0x1d, 0x35, 0xa1, 0xb6, 0xe3, 0x06, 0xef, - 0x3e, 0x35, 0xfd, 0x40, 0xb9, 0x44, 0xfa, 0xef, 0xb8, 0xc1, 0xae, 0x87, 0x7d, 0xec, 0x04, 0x4a, - 0x01, 0x01, 0x54, 0x3f, 0x74, 0xb6, 0x4c, 0xff, 0xb1, 0x52, 0x44, 0x0b, 0xfc, 0xa8, 0x4c, 0xb7, - 0xb6, 0x79, 0x91, 0x5a, 0x29, 0x91, 0xe1, 0xd1, 0x57, 0x19, 0x29, 0xd0, 0x8c, 0xba, 0xdc, 0xdf, - 0x7d, 0xa8, 0x54, 0x18, 0xf5, 0xe4, 0x67, 0x75, 0xd5, 0x00, 0x25, 0x7d, 0xc4, 0x4b, 0xe6, 0x64, - 0x8b, 0x88, 0x40, 0xca, 0x25, 0xb2, 0x32, 0x7e, 0xc6, 0xae, 0x14, 0x50, 0x07, 0x1a, 0x89, 0x13, - 0x6b, 0xa5, 0x48, 0x00, 0xf7, 0xbd, 0xd1, 0x80, 0xeb, 0x16, 0x23, 0x81, 0x28, 0xea, 0x16, 0xe1, - 0x44, 0x79, 0xf5, 0x2e, 0xd4, 0xc2, 0xfc, 0x84, 0x74, 0xe5, 0x2c, 0x22, 0x9f, 0xca, 0x25, 0x34, - 0x0f, 0x2d, 0xe1, 0x89, 0x9e, 0x52, 0x40, 0x08, 0xda, 0xe2, 0x23, 0x5a, 0xa5, 0xb8, 0xba, 0x01, - 0x10, 0xc7, 0xf9, 0x84, 0x9c, 0x6d, 0xe7, 0x58, 0xb7, 0x4c, 0x83, 0xd1, 0x46, 0x9a, 0x08, 0x77, - 0x29, 0x77, 0x98, 0xcd, 0x2a, 0xc5, 0xd5, 0xb7, 0xa1, 0x16, 0xc6, 0xae, 0x04, 0xae, 0x61, 0xdb, - 0x3d, 0xc6, 0x4c, 0x32, 0x7b, 0x38, 0x60, 0x72, 0xbc, 0x63, 0x63, 0xc7, 0x50, 0x8a, 0x84, 0x8c, - 0x87, 0x23, 0x43, 0x0f, 0xc2, 0x6b, 0xa6, 0x4a, 0x69, 0xe3, 0xbf, 0x16, 0x00, 0xd8, 0x99, 0xad, - 0xeb, 0x7a, 0x06, 0xb2, 0xe8, 0xdd, 0x8d, 0x4d, 0xd7, 0x1e, 0xb9, 0x4e, 0x78, 0xa0, 0xe4, 0xa3, - 0xb5, 0x54, 0x89, 0x84, 0x7d, 0x64, 0x3b, 0x72, 0xde, 0xf4, 0x5e, 0x90, 0xf6, 0x4f, 0x75, 0x56, - 0x2f, 0x21, 0x9b, 0x62, 0xdb, 0x37, 0x6d, 0xbc, 0x6f, 0x0e, 0x1e, 0x47, 0x07, 0xbd, 0x93, 0x1f, - 0xb7, 0xa6, 0xba, 0x86, 0xf8, 0x6e, 0x48, 0xf1, 0xed, 0x05, 0x9e, 0xe9, 0x0c, 0xc3, 0xdd, 0x51, - 0xbd, 0x84, 0x9e, 0xa4, 0x9e, 0xd6, 0x86, 0x08, 0x37, 0xf2, 0xbc, 0xa6, 0xbd, 0x18, 0x4a, 0x0b, - 0x3a, 0xa9, 0xff, 0x30, 0x40, 0xab, 0xf2, 0x37, 0x4a, 0xb2, 0xff, 0x5b, 0xe8, 0xbd, 0x92, 0xab, - 0x6f, 0x84, 0xcd, 0x84, 0xb6, 0xf8, 0xf8, 0x1e, 0x7d, 0x63, 0xd2, 0x04, 
0x99, 0x57, 0x92, 0xbd, - 0xd5, 0x3c, 0x5d, 0x23, 0x54, 0x1f, 0x31, 0xf5, 0x9d, 0x86, 0x4a, 0xfa, 0x30, 0xb5, 0x77, 0x56, - 0x60, 0xa2, 0x5e, 0x42, 0x9f, 0x90, 0x18, 0x22, 0xf5, 0x96, 0x13, 0xbd, 0x2a, 0xdf, 0xf7, 0xe4, - 0x4f, 0x3e, 0xa7, 0x61, 0xf8, 0x28, 0x6d, 0x7c, 0x93, 0xa9, 0xcf, 0x3c, 0x12, 0xcf, 0x4f, 0x7d, - 0x62, 0xfa, 0xb3, 0xa8, 0x3f, 0x37, 0x06, 0x8b, 0xa5, 0x53, 0x92, 0x57, 0x64, 0x69, 0x55, 0x8e, - 0xb3, 0x99, 0xc9, 0x4f, 0xce, 0xa6, 0x61, 0x1b, 0x53, 0x23, 0x4d, 0x5f, 0x56, 0x78, 0x6d, 0xc2, - 0x31, 0x88, 0xfc, 0xf9, 0x6a, 0x6f, 0x2d, 0x6f, 0xf7, 0xa4, 0x2e, 0x8b, 0x2f, 0x24, 0xe5, 0x22, - 0x92, 0xbe, 0xea, 0x94, 0xeb, 0xb2, 0xfc, 0xc1, 0xa5, 0x7a, 0x09, 0xed, 0x0b, 0xae, 0x1e, 0xbd, - 0x34, 0x49, 0x15, 0xc4, 0xdb, 0x4b, 0xd3, 0xf8, 0xf6, 0xeb, 0x80, 0x98, 0xa5, 0x3a, 0x87, 0xe6, - 0x70, 0xec, 0xe9, 0x4c, 0x8d, 0x27, 0x39, 0xb7, 0x6c, 0xd7, 0x10, 0xcd, 0x37, 0xcf, 0x31, 0x22, - 0x5a, 0x52, 0x1f, 0xe0, 0x3e, 0x0e, 0x3e, 0xa0, 0x4f, 0xe5, 0xfc, 0xf4, 0x8a, 0x62, 0xff, 0xcd, - 0x3b, 0x84, 0xa8, 0x5e, 0x9e, 0xda, 0x2f, 0x42, 0x70, 0x00, 0x8d, 0xfb, 0x24, 0xbf, 0xa2, 0x31, - 0xa3, 0x8f, 0x26, 0x8e, 0x0c, 0x7b, 0x84, 0x28, 0x56, 0xa6, 0x77, 0x4c, 0x3a, 0xcf, 0xd4, 0x6b, - 0x51, 0x34, 0x51, 0xb0, 0xd9, 0x37, 0xac, 0x72, 0xe7, 0x39, 0xe1, 0xf9, 0x29, 0x5b, 0x11, 0x3d, - 0x8a, 0x7b, 0x0f, 0xeb, 0x56, 0x70, 0x34, 0x61, 0x45, 0x89, 0x1e, 0x67, 0xaf, 0x48, 0xe8, 0x18, - 0xe1, 0xc0, 0xb0, 0xc0, 0xac, 0x50, 0x4c, 0x4c, 0xd7, 0xe5, 0x53, 0x64, 0x7b, 0xe6, 0x54, 0x3d, - 0x1d, 0xe6, 0xb7, 0x3c, 0x77, 0x24, 0x22, 0x79, 0x4d, 0x8a, 0x24, 0xd3, 0x2f, 0x27, 0x8a, 0x1f, - 0x40, 0x33, 0xcc, 0xff, 0x69, 0xc6, 0x22, 0xe7, 0x42, 0xb2, 0x4b, 0xce, 0x89, 0x3f, 0x86, 0x4e, - 0xaa, 0xb0, 0x20, 0x17, 0xba, 0xbc, 0xfa, 0x30, 0x6d, 0xf6, 0x13, 0x40, 0xf4, 0x09, 0xb0, 0xf8, - 0x2f, 0x06, 0xf2, 0xf8, 0x26, 0xdb, 0x31, 0x44, 0xb2, 0x9e, 0xbb, 0x7f, 0x24, 0xf9, 0xdf, 0x80, - 0x25, 0x69, 0xf2, 0x9e, 0x76, 0x08, 0xfc, 0xba, 0xf2, 0x19, 0x15, 0x86, 0xb4, 0x43, 0x38, 0x73, - 0x44, 0x88, 
0x7f, 0xe3, 0xf7, 0xe7, 0xa1, 0x4e, 0xe3, 0x3c, 0x2a, 0xad, 0x5f, 0x86, 0x79, 0xcf, - 0x36, 0xcc, 0xfb, 0x18, 0x3a, 0xa9, 0xa7, 0xa9, 0x72, 0xa5, 0x95, 0xbf, 0x5f, 0xcd, 0x11, 0xad, - 0x88, 0xaf, 0x3a, 0xe5, 0x5b, 0xa1, 0xf4, 0xe5, 0xe7, 0xb4, 0xb9, 0x1f, 0xb1, 0x67, 0xdf, 0xd1, - 0x69, 0xee, 0xcb, 0x13, 0x0f, 0x1f, 0xc4, 0x6b, 0xc7, 0x5f, 0x7c, 0x14, 0xf4, 0xd5, 0x8e, 0x40, - 0x3f, 0x86, 0x4e, 0xea, 0x61, 0x8f, 0x5c, 0x63, 0xe4, 0xaf, 0x7f, 0xa6, 0xcd, 0xfe, 0x0b, 0x0c, - 0x9e, 0x0c, 0x58, 0x90, 0xbc, 0xa3, 0x40, 0x6b, 0x93, 0x02, 0x51, 0xf9, 0x83, 0x8b, 0xe9, 0x0b, - 0x6a, 0x09, 0x66, 0x9a, 0xde, 0x6f, 0x62, 0x22, 0xd3, 0x7f, 0x7f, 0xd4, 0x7b, 0x35, 0xdf, 0x7f, - 0x25, 0x45, 0x0b, 0xda, 0x83, 0x2a, 0x7b, 0xee, 0x83, 0x9e, 0x97, 0x1f, 0xc2, 0x24, 0x9e, 0x02, - 0xf5, 0xa6, 0x3d, 0x18, 0xf2, 0xc7, 0x56, 0x40, 0xe8, 0xff, 0x35, 0x68, 0x33, 0x50, 0xc4, 0xa0, - 0x67, 0x38, 0xf9, 0x1e, 0x54, 0xa8, 0x6b, 0x47, 0xd2, 0x03, 0x85, 0xe4, 0xa3, 0x9e, 0xde, 0xf4, - 0x77, 0x3c, 0x31, 0xc5, 0xad, 0xef, 0xb3, 0x7f, 0xad, 0xe3, 0x04, 0x3f, 0xcb, 0xc9, 0xff, 0x7f, - 0xc7, 0xc6, 0x4f, 0xe9, 0x93, 0x94, 0xf4, 0xa5, 0x2b, 0xb4, 0x76, 0xbe, 0x9b, 0x63, 0xbd, 0xf5, - 0xdc, 0xfd, 0x23, 0xcc, 0x3f, 0x02, 0x25, 0x7d, 0xd0, 0x86, 0x5e, 0x99, 0x64, 0x89, 0x32, 0x9c, - 0x53, 0xcc, 0xf0, 0x7b, 0x50, 0x65, 0x15, 0x56, 0xb9, 0xfa, 0x0a, 0xd5, 0xd7, 0x29, 0x73, 0xdd, - 0xfd, 0xd6, 0x47, 0x1b, 0x43, 0x33, 0x38, 0x1a, 0x1f, 0x90, 0x96, 0x75, 0xd6, 0xf5, 0x35, 0xd3, - 0xe5, 0xbf, 0xd6, 0x43, 0x59, 0xae, 0xd3, 0xd1, 0xeb, 0x14, 0xc1, 0xe8, 0xe0, 0xa0, 0x4a, 0x3f, - 0x6f, 0xfd, 0x5f, 0x00, 0x00, 0x00, 0xff, 0xff, 0x1e, 0x2c, 0x9a, 0x3c, 0x36, 0x53, 0x00, 0x00, + // 4842 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xec, 0x7c, 0xcb, 0x6f, 0x1c, 0x47, + 0x7a, 0xb8, 0xe6, 0xc9, 0x99, 0x6f, 0x1e, 0x6c, 0x16, 0x49, 0x69, 0x76, 0x56, 0x92, 0xe5, 0x96, + 0x1f, 0x5c, 0xd9, 0xa6, 0xb4, 0xd4, 0xda, 0xab, 0x5d, 0xdb, 0xf0, 0x4f, 0x22, 0x2d, 0x99, 0x6b, + 
0x9b, 0xd6, 0xaf, 0x29, 0x79, 0x03, 0xaf, 0x77, 0xc7, 0xcd, 0xe9, 0xe2, 0xb0, 0xa1, 0x7e, 0x8c, + 0xba, 0x7b, 0x48, 0xd3, 0x01, 0x82, 0x1c, 0x72, 0xc9, 0x26, 0x1b, 0x04, 0xc9, 0x21, 0x39, 0x2c, + 0x72, 0x48, 0x10, 0x60, 0x13, 0x24, 0x97, 0x20, 0x01, 0x72, 0xc8, 0x21, 0xb7, 0xe4, 0x92, 0xd7, + 0xff, 0x90, 0xdc, 0x72, 0x5d, 0x04, 0x06, 0x72, 0x08, 0xea, 0xd5, 0xdd, 0xd5, 0x5d, 0xc3, 0x69, + 0x72, 0x24, 0x3f, 0x82, 0xdc, 0xa6, 0xbf, 0x7a, 0x7c, 0x5f, 0x7d, 0xaf, 0xfa, 0xbe, 0xaf, 0xaa, + 0x06, 0x96, 0x1e, 0x4f, 0x70, 0x70, 0x3c, 0x18, 0xfa, 0x7e, 0x60, 0xad, 0x8f, 0x03, 0x3f, 0xf2, + 0x11, 0x72, 0x6d, 0xe7, 0x70, 0x12, 0xb2, 0xaf, 0x75, 0xda, 0xde, 0x6f, 0x0f, 0x7d, 0xd7, 0xf5, + 0x3d, 0x06, 0xeb, 0xb7, 0xd3, 0x3d, 0xfa, 0x5d, 0xdb, 0x8b, 0x70, 0xe0, 0x99, 0x8e, 0x68, 0x0d, + 0x87, 0x07, 0xd8, 0x35, 0xf9, 0x57, 0xd3, 0x0d, 0x47, 0xfc, 0xa7, 0x66, 0x99, 0x91, 0x99, 0x46, + 0xd5, 0x5f, 0xb2, 0x3d, 0x0b, 0x7f, 0x9a, 0x06, 0xe9, 0xbf, 0x51, 0x82, 0xf3, 0xbb, 0x07, 0xfe, + 0xd1, 0xa6, 0xef, 0x38, 0x78, 0x18, 0xd9, 0xbe, 0x17, 0x1a, 0xf8, 0xf1, 0x04, 0x87, 0x11, 0xba, + 0x01, 0xd5, 0x3d, 0x33, 0xc4, 0xbd, 0xd2, 0x95, 0xd2, 0x5a, 0x6b, 0xe3, 0xe2, 0xba, 0x44, 0x27, + 0x27, 0xf0, 0xfd, 0x70, 0x74, 0xc7, 0x0c, 0xb1, 0x41, 0x7b, 0x22, 0x04, 0x55, 0x6b, 0x6f, 0x7b, + 0xab, 0x57, 0xbe, 0x52, 0x5a, 0xab, 0x18, 0xf4, 0x37, 0x7a, 0x0e, 0x3a, 0xc3, 0x78, 0xee, 0xed, + 0xad, 0xb0, 0x57, 0xb9, 0x52, 0x59, 0xab, 0x18, 0x32, 0x50, 0xff, 0x69, 0x19, 0x2e, 0xe4, 0xc8, + 0x08, 0xc7, 0xbe, 0x17, 0x62, 0x74, 0x13, 0xea, 0x61, 0x64, 0x46, 0x93, 0x90, 0x53, 0xf2, 0x4d, + 0x25, 0x25, 0xbb, 0xb4, 0x8b, 0xc1, 0xbb, 0xe6, 0xd1, 0x96, 0x15, 0x68, 0xd1, 0xb7, 0x61, 0xc5, + 0xf6, 0xde, 0xc7, 0xae, 0x1f, 0x1c, 0x0f, 0xc6, 0x38, 0x18, 0x62, 0x2f, 0x32, 0x47, 0x58, 0xd0, + 0xb8, 0x2c, 0xda, 0xee, 0x27, 0x4d, 0xe8, 0x35, 0xb8, 0xc0, 0x64, 0x18, 0xe2, 0xe0, 0xd0, 0x1e, + 0xe2, 0x81, 0x79, 0x68, 0xda, 0x8e, 0xb9, 0xe7, 0xe0, 0x5e, 0xf5, 0x4a, 0x65, 0xad, 0x61, 0xac, + 0xd2, 0xe6, 0x5d, 0xd6, 0x7a, 0x5b, 
0x34, 0xa2, 0x6f, 0x81, 0x16, 0xe0, 0xfd, 0x00, 0x87, 0x07, + 0x83, 0x71, 0xe0, 0x8f, 0x02, 0x1c, 0x86, 0xbd, 0x1a, 0x45, 0xb3, 0xc8, 0xe1, 0xf7, 0x39, 0x58, + 0xff, 0xd3, 0x12, 0xac, 0x12, 0x66, 0xdc, 0x37, 0x83, 0xc8, 0x7e, 0x0a, 0x22, 0xd1, 0xa1, 0x9d, + 0x66, 0x43, 0xaf, 0x42, 0xdb, 0x24, 0x18, 0xe9, 0x33, 0x16, 0xe8, 0x09, 0xfb, 0xaa, 0x94, 0x54, + 0x09, 0xa6, 0xff, 0x0b, 0xd7, 0x9d, 0x34, 0x9d, 0xf3, 0xc8, 0x2c, 0x8b, 0xb3, 0x9c, 0xc7, 0x79, + 0x16, 0x89, 0xa9, 0x38, 0x5f, 0x55, 0x73, 0xfe, 0x9f, 0x2a, 0xb0, 0xfa, 0x9e, 0x6f, 0x5a, 0x89, + 0x1a, 0x7e, 0xf1, 0x9c, 0x7f, 0x13, 0xea, 0xcc, 0xa2, 0x7b, 0x55, 0x8a, 0xeb, 0x79, 0x19, 0x17, + 0xb7, 0xf6, 0x84, 0xc2, 0x5d, 0x0a, 0x30, 0xf8, 0x20, 0xf4, 0x3c, 0x74, 0x03, 0x3c, 0x76, 0xec, + 0xa1, 0x39, 0xf0, 0x26, 0xee, 0x1e, 0x0e, 0x7a, 0xb5, 0x2b, 0xa5, 0xb5, 0x9a, 0xd1, 0xe1, 0xd0, + 0x1d, 0x0a, 0x44, 0x9f, 0x40, 0x67, 0xdf, 0xc6, 0x8e, 0x35, 0xa0, 0x2e, 0x61, 0x7b, 0xab, 0x57, + 0xbf, 0x52, 0x59, 0x6b, 0x6d, 0xbc, 0xbe, 0x9e, 0xf7, 0x46, 0xeb, 0x4a, 0x8e, 0xac, 0xdf, 0x25, + 0xc3, 0xb7, 0xd9, 0xe8, 0xb7, 0xbd, 0x28, 0x38, 0x36, 0xda, 0xfb, 0x29, 0x10, 0xea, 0xc1, 0x02, + 0x67, 0x6f, 0x6f, 0xe1, 0x4a, 0x69, 0xad, 0x61, 0x88, 0x4f, 0xf4, 0x22, 0x2c, 0x06, 0x38, 0xf4, + 0x27, 0xc1, 0x10, 0x0f, 0x46, 0x81, 0x3f, 0x19, 0x87, 0xbd, 0xc6, 0x95, 0xca, 0x5a, 0xd3, 0xe8, + 0x0a, 0xf0, 0x3d, 0x0a, 0xed, 0xbf, 0x05, 0x4b, 0x39, 0x2c, 0x48, 0x83, 0xca, 0x23, 0x7c, 0x4c, + 0x05, 0x51, 0x31, 0xc8, 0x4f, 0xb4, 0x02, 0xb5, 0x43, 0xd3, 0x99, 0x60, 0xce, 0x6a, 0xf6, 0xf1, + 0xfd, 0xf2, 0xad, 0x92, 0xfe, 0xf3, 0x12, 0xf4, 0x0c, 0xec, 0x60, 0x33, 0xc4, 0x5f, 0xa6, 0x48, + 0xcf, 0x43, 0xdd, 0xf3, 0x2d, 0xbc, 0xbd, 0x45, 0x45, 0x5a, 0x31, 0xf8, 0x97, 0xfe, 0x79, 0x09, + 0x56, 0xee, 0xe1, 0x88, 0x98, 0x81, 0x1d, 0x46, 0xf6, 0x30, 0xb6, 0xf3, 0x37, 0xa1, 0x12, 0xe0, + 0xc7, 0x9c, 0xb2, 0x97, 0x64, 0xca, 0x62, 0xf7, 0xaf, 0x1a, 0x69, 0x90, 0x71, 0xe8, 0x59, 0x68, + 0x5b, 0xae, 0x33, 0x18, 0x1e, 0x98, 0x9e, 0x87, 0x1d, 0x66, 0x48, 0x4d, 
0xa3, 0x65, 0xb9, 0xce, + 0x26, 0x07, 0xa1, 0xcb, 0x00, 0x21, 0x1e, 0xb9, 0xd8, 0x8b, 0x12, 0x9f, 0x9c, 0x82, 0xa0, 0x6b, + 0xb0, 0xb4, 0x1f, 0xf8, 0xee, 0x20, 0x3c, 0x30, 0x03, 0x6b, 0xe0, 0x60, 0xd3, 0xc2, 0x01, 0xa5, + 0xbe, 0x61, 0x2c, 0x92, 0x86, 0x5d, 0x02, 0x7f, 0x8f, 0x82, 0xd1, 0x4d, 0xa8, 0x85, 0x43, 0x7f, + 0x8c, 0xa9, 0xa6, 0x75, 0x37, 0x2e, 0xa9, 0x74, 0x68, 0xcb, 0x8c, 0xcc, 0x5d, 0xd2, 0xc9, 0x60, + 0x7d, 0xf5, 0xbf, 0xad, 0x32, 0x53, 0xfb, 0x8a, 0x3b, 0xb9, 0x94, 0x39, 0xd6, 0x9e, 0x8c, 0x39, + 0xd6, 0x0b, 0x99, 0xe3, 0xc2, 0xc9, 0xe6, 0x98, 0xe3, 0xda, 0x69, 0xcc, 0xb1, 0x31, 0xd3, 0x1c, + 0x9b, 0x2a, 0x73, 0x44, 0x6f, 0xc3, 0x22, 0x0b, 0x20, 0x6c, 0x6f, 0xdf, 0x1f, 0x38, 0x76, 0x18, + 0xf5, 0x80, 0x92, 0x79, 0x29, 0xab, 0xa1, 0x16, 0xfe, 0x74, 0x9d, 0x21, 0xf6, 0xf6, 0x7d, 0xa3, + 0x63, 0x8b, 0x9f, 0xef, 0xd9, 0x61, 0x34, 0xbf, 0x55, 0xff, 0x7d, 0x62, 0xd5, 0x5f, 0x75, 0xed, + 0x49, 0x2c, 0xbf, 0x26, 0x59, 0xfe, 0x9f, 0x95, 0xe0, 0x1b, 0xf7, 0x70, 0x14, 0x93, 0x4f, 0x0c, + 0x19, 0x7f, 0x45, 0xb7, 0xf9, 0xbf, 0x2c, 0x41, 0x5f, 0x45, 0xeb, 0x3c, 0x5b, 0xfd, 0x47, 0x70, + 0x3e, 0xc6, 0x31, 0xb0, 0x70, 0x38, 0x0c, 0xec, 0x31, 0x15, 0x23, 0xf5, 0x55, 0xad, 0x8d, 0xab, + 0x2a, 0xc5, 0xcf, 0x52, 0xb0, 0x1a, 0x4f, 0xb1, 0x95, 0x9a, 0x41, 0xff, 0x59, 0x09, 0x56, 0x89, + 0x6f, 0xe4, 0xce, 0x8c, 0x68, 0xe0, 0x99, 0xf9, 0x2a, 0xbb, 0xc9, 0x72, 0xce, 0x4d, 0x16, 0xe0, + 0x31, 0x0d, 0xb1, 0xb3, 0xf4, 0xcc, 0xc3, 0xbb, 0x57, 0xa1, 0x46, 0x0c, 0x50, 0xb0, 0xea, 0x19, + 0x15, 0xab, 0xd2, 0xc8, 0x58, 0x6f, 0xdd, 0x63, 0x54, 0x24, 0x7e, 0x7b, 0x0e, 0x75, 0xcb, 0x2e, + 0xbb, 0xac, 0x58, 0xf6, 0x6f, 0x97, 0xe0, 0x42, 0x0e, 0xe1, 0x3c, 0xeb, 0x7e, 0x03, 0xea, 0x74, + 0x37, 0x12, 0x0b, 0x7f, 0x4e, 0xb9, 0xf0, 0x14, 0x3a, 0xe2, 0x6d, 0x0c, 0x3e, 0x46, 0xf7, 0x41, + 0xcb, 0xb6, 0x91, 0x7d, 0x92, 0xef, 0x91, 0x03, 0xcf, 0x74, 0x19, 0x03, 0x9a, 0x46, 0x8b, 0xc3, + 0x76, 0x4c, 0x17, 0xa3, 0x6f, 0x40, 0x83, 0x98, 0xec, 0xc0, 0xb6, 0x84, 0xf8, 0x17, 0xa8, 0x09, + 0x5b, 0x21, 
0xba, 0x04, 0x40, 0x9b, 0x4c, 0xcb, 0x0a, 0xd8, 0x16, 0xda, 0x34, 0x9a, 0x04, 0x72, + 0x9b, 0x00, 0xf4, 0x3f, 0x2c, 0xc1, 0xe5, 0xdd, 0x63, 0x6f, 0xb8, 0x83, 0x8f, 0x36, 0x03, 0x6c, + 0x46, 0x38, 0x71, 0xda, 0x4f, 0x95, 0xf1, 0xe8, 0x0a, 0xb4, 0x52, 0xf6, 0xcb, 0x55, 0x32, 0x0d, + 0xd2, 0xff, 0xaa, 0x04, 0x6d, 0xb2, 0x8b, 0xbc, 0x8f, 0x23, 0x93, 0xa8, 0x08, 0xfa, 0x1e, 0x34, + 0x1d, 0xdf, 0xb4, 0x06, 0xd1, 0xf1, 0x98, 0x51, 0xd3, 0xcd, 0x52, 0x93, 0x6c, 0x3d, 0x0f, 0x8e, + 0xc7, 0xd8, 0x68, 0x38, 0xfc, 0x57, 0x21, 0x8a, 0xb2, 0x5e, 0xa6, 0xa2, 0xf0, 0x94, 0xcf, 0x40, + 0xcb, 0xc5, 0x51, 0x60, 0x0f, 0x19, 0x11, 0x55, 0x2a, 0x0a, 0x60, 0x20, 0x82, 0x48, 0xff, 0x59, + 0x1d, 0xce, 0xff, 0xd0, 0x8c, 0x86, 0x07, 0x5b, 0xae, 0x88, 0x62, 0xce, 0xce, 0xc7, 0xc4, 0x2f, + 0x97, 0xd3, 0x7e, 0xf9, 0x89, 0xf9, 0xfd, 0xd8, 0x46, 0x6b, 0x2a, 0x1b, 0x25, 0x89, 0xf9, 0xfa, + 0x87, 0x5c, 0xcd, 0x52, 0x36, 0x9a, 0x0a, 0x36, 0xea, 0x67, 0x09, 0x36, 0x36, 0xa1, 0x83, 0x3f, + 0x1d, 0x3a, 0x13, 0xa2, 0xaf, 0x14, 0x3b, 0x8b, 0x22, 0x2e, 0x2b, 0xb0, 0xa7, 0x1d, 0x44, 0x9b, + 0x0f, 0xda, 0xe6, 0x34, 0x30, 0x5d, 0x70, 0x71, 0x64, 0xd2, 0x50, 0xa1, 0xb5, 0x71, 0x65, 0x9a, + 0x2e, 0x08, 0x05, 0x62, 0xfa, 0x40, 0xbe, 0xd0, 0x45, 0x68, 0xf2, 0xd0, 0x66, 0x7b, 0xab, 0xd7, + 0xa4, 0xec, 0x4b, 0x00, 0xc8, 0x84, 0x0e, 0xf7, 0x9e, 0x9c, 0x42, 0x16, 0x40, 0xbc, 0xa1, 0x42, + 0xa0, 0x16, 0x76, 0x9a, 0xf2, 0x90, 0x07, 0x3a, 0x61, 0x0a, 0x44, 0x32, 0x7f, 0x7f, 0x7f, 0xdf, + 0xb1, 0x3d, 0xbc, 0xc3, 0x24, 0xdc, 0xa2, 0x44, 0xc8, 0x40, 0x12, 0x0e, 0x1d, 0xe2, 0x20, 0xb4, + 0x7d, 0xaf, 0xd7, 0xa6, 0xed, 0xe2, 0x53, 0x15, 0xe5, 0x74, 0xce, 0x10, 0xe5, 0x0c, 0x60, 0x29, + 0x47, 0xa9, 0x22, 0xca, 0xf9, 0x4e, 0x3a, 0xca, 0x99, 0x2d, 0xaa, 0x54, 0x14, 0xf4, 0x8b, 0x12, + 0xac, 0x3e, 0xf4, 0xc2, 0xc9, 0x5e, 0xcc, 0xa2, 0x2f, 0xc7, 0x1c, 0xb2, 0x4e, 0xb4, 0x9a, 0x73, + 0xa2, 0xfa, 0xcf, 0xeb, 0xb0, 0xc8, 0x57, 0x41, 0xb4, 0x86, 0xba, 0x9c, 0x8b, 0xd0, 0x8c, 0xf7, + 0x51, 0xce, 0x90, 0x04, 0x90, 0xf5, 0x61, 0xe5, 
0x9c, 0x0f, 0x2b, 0x44, 0x9a, 0x88, 0x8a, 0xaa, + 0xa9, 0xa8, 0xe8, 0x12, 0xc0, 0xbe, 0x33, 0x09, 0x0f, 0x06, 0x91, 0xed, 0x62, 0x1e, 0x95, 0x35, + 0x29, 0xe4, 0x81, 0xed, 0x62, 0x74, 0x1b, 0xda, 0x7b, 0xb6, 0xe7, 0xf8, 0xa3, 0xc1, 0xd8, 0x8c, + 0x0e, 0x42, 0x9e, 0x16, 0xab, 0xc4, 0x42, 0x63, 0xd8, 0x3b, 0xb4, 0xaf, 0xd1, 0x62, 0x63, 0xee, + 0x93, 0x21, 0xe8, 0x32, 0xb4, 0xbc, 0x89, 0x3b, 0xf0, 0xf7, 0x07, 0x81, 0x7f, 0x14, 0xd2, 0xe4, + 0xb7, 0x62, 0x34, 0xbd, 0x89, 0xfb, 0xc1, 0xbe, 0xe1, 0x1f, 0x91, 0x7d, 0xac, 0x49, 0x76, 0xb4, + 0xd0, 0xf1, 0x47, 0x2c, 0xf1, 0x9d, 0x3d, 0x7f, 0x32, 0x80, 0x8c, 0xb6, 0xb0, 0x13, 0x99, 0x74, + 0x74, 0xb3, 0xd8, 0xe8, 0x78, 0x00, 0x7a, 0x01, 0xba, 0x43, 0xdf, 0x1d, 0x9b, 0x94, 0x43, 0x77, + 0x03, 0xdf, 0xa5, 0x06, 0x58, 0x31, 0x32, 0x50, 0xb4, 0x09, 0xad, 0xc4, 0x08, 0xc2, 0x5e, 0x8b, + 0xe2, 0xd1, 0x55, 0x56, 0x9a, 0x0a, 0xe5, 0x89, 0x82, 0x42, 0x6c, 0x05, 0x21, 0xd1, 0x0c, 0x61, + 0xec, 0xa1, 0xfd, 0x19, 0xe6, 0x86, 0xd6, 0xe2, 0xb0, 0x5d, 0xfb, 0x33, 0x4c, 0xd2, 0x23, 0xdb, + 0x0b, 0x71, 0x10, 0x89, 0x64, 0xb5, 0xd7, 0xa1, 0xea, 0xd3, 0x61, 0x50, 0xae, 0xd8, 0x68, 0x0b, + 0xba, 0x61, 0x64, 0x06, 0xd1, 0x60, 0xec, 0x87, 0x54, 0x01, 0x7a, 0x5d, 0xaa, 0xdb, 0x19, 0x93, + 0x74, 0xc3, 0x11, 0x51, 0xec, 0xfb, 0xbc, 0x93, 0xd1, 0xa1, 0x83, 0xc4, 0x27, 0x99, 0x85, 0x72, + 0x22, 0x99, 0x65, 0xb1, 0xd0, 0x2c, 0x74, 0x50, 0x3c, 0xcb, 0x1a, 0x49, 0x97, 0x4c, 0xcb, 0xdc, + 0x73, 0xf0, 0x87, 0xdc, 0x83, 0x68, 0x74, 0x61, 0x59, 0x30, 0x61, 0x76, 0x18, 0xf9, 0x81, 0x39, + 0x8a, 0x3b, 0x2e, 0xd1, 0x8e, 0x19, 0xa8, 0xfe, 0x27, 0x15, 0xe8, 0xca, 0x6c, 0x24, 0xee, 0x89, + 0x65, 0x6f, 0xc2, 0x36, 0xc4, 0x27, 0x61, 0x2a, 0xf6, 0x08, 0x16, 0x96, 0x2a, 0x52, 0xd3, 0x68, + 0x18, 0x2d, 0x06, 0xa3, 0x13, 0x10, 0x15, 0x67, 0xc2, 0xa3, 0xf6, 0x58, 0xa1, 0x0c, 0x6d, 0x52, + 0x08, 0x0d, 0x69, 0x7a, 0xb0, 0x20, 0xb2, 0x4c, 0x66, 0x18, 0xe2, 0x93, 0xb4, 0xec, 0x4d, 0x6c, + 0x8a, 0x95, 0x19, 0x86, 0xf8, 0x44, 0x5b, 0xd0, 0x66, 0x53, 0x8e, 0xcd, 0xc0, 0x74, 
0x85, 0x59, + 0x3c, 0xab, 0x74, 0x2d, 0xef, 0xe2, 0xe3, 0x0f, 0x89, 0x97, 0xba, 0x6f, 0xda, 0x81, 0xc1, 0xd4, + 0xe8, 0x3e, 0x1d, 0x85, 0xd6, 0x40, 0x63, 0xb3, 0xec, 0xdb, 0x0e, 0xe6, 0x06, 0xb6, 0xc0, 0x52, + 0x4d, 0x0a, 0xbf, 0x6b, 0x3b, 0x98, 0xd9, 0x50, 0xbc, 0x04, 0xaa, 0x38, 0x0d, 0x66, 0x42, 0x14, + 0x42, 0xd5, 0xe6, 0x2a, 0x30, 0x6f, 0x3b, 0x10, 0x3e, 0x9c, 0x6d, 0x34, 0x8c, 0x46, 0xc1, 0x7e, + 0x12, 0xba, 0x4d, 0x5c, 0x66, 0x84, 0xc0, 0x96, 0xe3, 0x4d, 0x5c, 0x6a, 0x82, 0x1b, 0xb0, 0x3a, + 0x9c, 0x04, 0x01, 0xdb, 0x86, 0xd2, 0xf3, 0xb4, 0x68, 0x72, 0xbe, 0xcc, 0x1b, 0xb7, 0x53, 0xd3, + 0xe9, 0xbf, 0x57, 0x83, 0x65, 0xe2, 0xbd, 0xb8, 0x23, 0x9b, 0x23, 0xf8, 0xb8, 0x04, 0x60, 0x85, + 0xd1, 0x40, 0xf2, 0xb8, 0x4d, 0x2b, 0x8c, 0xf8, 0xd6, 0xf4, 0x3d, 0x11, 0x3b, 0x54, 0xa6, 0xa7, + 0x42, 0x19, 0x6f, 0x9a, 0x8f, 0x1f, 0xce, 0x54, 0x3b, 0xbc, 0x0a, 0x1d, 0x5e, 0x07, 0x90, 0x92, + 0xd6, 0x36, 0x03, 0xee, 0xa8, 0xf7, 0x84, 0xba, 0xb2, 0x86, 0x99, 0x8a, 0x21, 0x16, 0xe6, 0x8b, + 0x21, 0x1a, 0xd9, 0x18, 0xe2, 0x2e, 0x2c, 0xca, 0x66, 0x2c, 0xfc, 0xe0, 0x0c, 0x3b, 0xee, 0x4a, + 0x76, 0x1c, 0xa6, 0x43, 0x00, 0x90, 0x43, 0x80, 0xab, 0xd0, 0xf1, 0x30, 0xb6, 0x06, 0x51, 0x60, + 0x7a, 0xe1, 0x3e, 0x0e, 0xa8, 0x5a, 0x34, 0x8c, 0x36, 0x01, 0x3e, 0xe0, 0x30, 0xf4, 0x06, 0x00, + 0x5d, 0x23, 0x2b, 0x7d, 0xb5, 0xa7, 0x97, 0xbe, 0xa8, 0xd2, 0xd0, 0xd2, 0x17, 0x65, 0x0a, 0xfd, + 0xf9, 0x84, 0xa2, 0x0c, 0xfd, 0x9f, 0xcb, 0x70, 0x9e, 0x97, 0x42, 0xe6, 0xd7, 0xcb, 0x69, 0x51, + 0x80, 0xd8, 0x46, 0x2b, 0x27, 0x14, 0x17, 0xaa, 0x05, 0x02, 0xe5, 0x9a, 0x22, 0x50, 0x96, 0x13, + 0xec, 0x7a, 0x2e, 0xc1, 0x8e, 0x6b, 0x8b, 0x0b, 0xc5, 0x6b, 0x8b, 0x68, 0x05, 0x6a, 0x34, 0xeb, + 0xa3, 0xba, 0xd3, 0x34, 0xd8, 0x47, 0x21, 0xa9, 0xea, 0x7f, 0x50, 0x86, 0xce, 0x2e, 0x36, 0x83, + 0xe1, 0x81, 0xe0, 0xe3, 0x6b, 0xe9, 0x5a, 0xec, 0x73, 0x53, 0x6a, 0xb1, 0xd2, 0x90, 0xaf, 0x4d, + 0x11, 0x96, 0x20, 0x88, 0xfc, 0xc8, 0x8c, 0xa9, 0x1c, 0x78, 0x13, 0x97, 0x17, 0x28, 0x17, 0x69, + 0x03, 0x27, 0x75, 0x67, 
0xe2, 0xea, 0xff, 0x59, 0x82, 0xf6, 0xff, 0x27, 0xd3, 0x08, 0xc6, 0xdc, + 0x4a, 0x33, 0xe6, 0x85, 0x29, 0x8c, 0x31, 0x48, 0x02, 0x87, 0x0f, 0xf1, 0xd7, 0xae, 0x3e, 0xfd, + 0x0f, 0x25, 0xe8, 0x93, 0xf4, 0xdd, 0x60, 0x7e, 0x67, 0x7e, 0xeb, 0xba, 0x0a, 0x9d, 0x43, 0x29, + 0x50, 0x2e, 0x53, 0xe5, 0x6c, 0x1f, 0xa6, 0xcb, 0x0d, 0x06, 0x68, 0xa2, 0x5c, 0xcc, 0x17, 0x2b, + 0xb6, 0x81, 0x17, 0x55, 0x54, 0x67, 0x88, 0xa3, 0x1e, 0x62, 0x31, 0x90, 0x81, 0xfa, 0xef, 0x94, + 0x60, 0x59, 0xd1, 0x11, 0x5d, 0x80, 0x05, 0x5e, 0xda, 0xe0, 0x31, 0x06, 0xb3, 0x77, 0x8b, 0x88, + 0x27, 0x29, 0xce, 0xd9, 0x56, 0x3e, 0xfa, 0xb6, 0x48, 0xb6, 0x1e, 0xe7, 0x71, 0x56, 0x4e, 0x3e, + 0x56, 0x88, 0xfa, 0xd0, 0xe0, 0xde, 0x54, 0x24, 0xc8, 0xf1, 0xb7, 0xfe, 0x08, 0xd0, 0x3d, 0x9c, + 0xec, 0x5d, 0xf3, 0x70, 0x34, 0xf1, 0x37, 0x09, 0xa1, 0x69, 0x27, 0x64, 0xe9, 0xff, 0x5e, 0x82, + 0x65, 0x09, 0xdb, 0x3c, 0x25, 0xa8, 0x64, 0x7f, 0x2d, 0x9f, 0x65, 0x7f, 0x95, 0xca, 0x2c, 0x95, + 0x53, 0x95, 0x59, 0x2e, 0x03, 0xc4, 0xfc, 0x17, 0x1c, 0x4d, 0x41, 0xf4, 0xbf, 0x2b, 0xc1, 0xf9, + 0x77, 0x4c, 0xcf, 0xf2, 0xf7, 0xf7, 0xe7, 0x57, 0xd5, 0x4d, 0x90, 0x52, 0xea, 0xa2, 0x85, 0x46, + 0x39, 0x0f, 0x7f, 0x09, 0x96, 0x02, 0xb6, 0x33, 0x59, 0xb2, 0x2e, 0x57, 0x0c, 0x4d, 0x34, 0xc4, + 0x3a, 0xfa, 0x17, 0x65, 0x40, 0x64, 0xd5, 0x77, 0x4c, 0xc7, 0xf4, 0x86, 0xf8, 0xec, 0xa4, 0x3f, + 0x0f, 0x5d, 0x29, 0x84, 0x89, 0x0f, 0xfe, 0xd3, 0x31, 0x4c, 0x88, 0xde, 0x85, 0xee, 0x1e, 0x43, + 0x35, 0x08, 0xb0, 0x19, 0xfa, 0x1e, 0x17, 0x87, 0xb2, 0xa6, 0xf8, 0x20, 0xb0, 0x47, 0x23, 0x1c, + 0x6c, 0xfa, 0x9e, 0xc5, 0x33, 0x82, 0x3d, 0x41, 0x26, 0x19, 0x4a, 0x8c, 0x21, 0x89, 0xe7, 0x62, + 0xe1, 0xc4, 0x01, 0x1d, 0x65, 0x45, 0x88, 0x4d, 0x27, 0x61, 0x44, 0xb2, 0x1b, 0x6a, 0xac, 0x61, + 0x77, 0x7a, 0x49, 0x59, 0x11, 0x5f, 0xe9, 0x7f, 0x5d, 0x02, 0x14, 0xa7, 0xfd, 0xb4, 0x4e, 0x42, + 0x2d, 0x3a, 0x3b, 0xb4, 0xa4, 0xd8, 0x94, 0x2f, 0x42, 0xd3, 0x12, 0x23, 0xb9, 0x0b, 0x4a, 0x00, + 0x74, 0x8f, 0xa4, 0x44, 0x0f, 0x88, 0xe6, 0x61, 0x4b, 0xa4, 
0xd5, 0x0c, 0xf8, 0x1e, 0x85, 0xc9, + 0xe1, 0x59, 0x35, 0x1b, 0x9e, 0xa5, 0x2b, 0xa6, 0x35, 0xa9, 0x62, 0xaa, 0xff, 0xa2, 0x0c, 0x1a, + 0xdd, 0x42, 0x36, 0x93, 0xd2, 0x57, 0x21, 0xa2, 0xaf, 0x42, 0x87, 0x5f, 0x9c, 0x91, 0x08, 0x6f, + 0x3f, 0x4e, 0x4d, 0x86, 0x6e, 0xc0, 0x0a, 0xeb, 0x14, 0xe0, 0x70, 0xe2, 0x24, 0x19, 0x25, 0x4b, + 0x80, 0xd0, 0x63, 0xb6, 0x77, 0x91, 0x26, 0x31, 0xe2, 0x21, 0x9c, 0x1f, 0x39, 0xfe, 0x9e, 0xe9, + 0x0c, 0x64, 0xf1, 0x30, 0x19, 0x16, 0xd0, 0xf8, 0x15, 0x36, 0x7c, 0x37, 0x2d, 0xc3, 0x10, 0xdd, + 0x81, 0x4e, 0x88, 0xf1, 0xa3, 0x24, 0xcd, 0xac, 0x15, 0x49, 0x33, 0xdb, 0x64, 0x8c, 0xf8, 0xd2, + 0xff, 0xa8, 0x04, 0x8b, 0x99, 0xf3, 0x8e, 0x6c, 0x51, 0xa4, 0x94, 0x2f, 0x8a, 0xdc, 0x82, 0x1a, + 0xf1, 0x54, 0x6c, 0x6f, 0xe9, 0xaa, 0x13, 0x76, 0x79, 0x56, 0x83, 0x0d, 0x40, 0xd7, 0x61, 0x59, + 0x71, 0xaf, 0x82, 0x8b, 0x1f, 0xe5, 0xaf, 0x55, 0xe8, 0xbf, 0xac, 0x42, 0x2b, 0xc5, 0x8a, 0x19, + 0xf5, 0x9c, 0x27, 0x52, 0xb7, 0x9e, 0x76, 0x8e, 0x4e, 0x54, 0xce, 0xc5, 0x2e, 0xcb, 0x15, 0x79, + 0xe2, 0xea, 0x62, 0x97, 0x66, 0x8a, 0xe9, 0x24, 0xb0, 0x2e, 0x27, 0x81, 0x72, 0x9a, 0xbc, 0x70, + 0x42, 0x9a, 0xdc, 0x90, 0xd3, 0x64, 0xc9, 0x84, 0x9a, 0x59, 0x13, 0x2a, 0x5a, 0x62, 0xb9, 0x01, + 0xcb, 0x43, 0x76, 0x2e, 0x70, 0xe7, 0x78, 0x33, 0x6e, 0xe2, 0x41, 0xa9, 0xaa, 0x09, 0xdd, 0x4d, + 0x8a, 0xa7, 0x4c, 0xca, 0x2c, 0xe9, 0x50, 0x67, 0xe1, 0x5c, 0x36, 0x4c, 0xc8, 0xc2, 0x33, 0xd3, + 0xaf, 0x6c, 0x71, 0xa7, 0x73, 0xa6, 0xe2, 0xce, 0x33, 0xd0, 0x12, 0x91, 0x0a, 0xb1, 0xf4, 0x2e, + 0x73, 0x7a, 0xc2, 0x0d, 0x58, 0xa1, 0xe4, 0x07, 0x16, 0xe5, 0x93, 0x93, 0x6c, 0x0d, 0x43, 0xcb, + 0xd7, 0x30, 0x2e, 0xc0, 0x82, 0x1d, 0x0e, 0xf6, 0xcd, 0x47, 0x98, 0x16, 0x4d, 0x1a, 0x46, 0xdd, + 0x0e, 0xef, 0x9a, 0x8f, 0xb0, 0xfe, 0xaf, 0x15, 0xe8, 0x26, 0x1b, 0x6c, 0x61, 0x0f, 0x52, 0xe4, + 0x6e, 0xd1, 0x0e, 0x68, 0x49, 0xdc, 0x43, 0x39, 0x7c, 0x62, 0x0e, 0x9e, 0x3d, 0x8e, 0x5c, 0x1c, + 0x67, 0xec, 0x55, 0xda, 0xee, 0xab, 0xa7, 0xda, 0xee, 0xe7, 0xbc, 0x75, 0x70, 0x13, 0x56, 0xe3, + 
0xbd, 0x57, 0x5a, 0x36, 0x4b, 0xb0, 0x56, 0x44, 0xe3, 0xfd, 0xf4, 0xf2, 0xa7, 0xb8, 0x80, 0x85, + 0x69, 0x2e, 0x20, 0xab, 0x02, 0x8d, 0x9c, 0x0a, 0xe4, 0x2f, 0x3f, 0x34, 0x15, 0x97, 0x1f, 0xf4, + 0x87, 0xb0, 0x4c, 0x0b, 0xd9, 0xe1, 0x30, 0xb0, 0xf7, 0x70, 0x9c, 0x02, 0x14, 0x11, 0x6b, 0x1f, + 0x1a, 0x99, 0x2c, 0x22, 0xfe, 0xd6, 0x7f, 0x5a, 0x82, 0xf3, 0xf9, 0x79, 0xa9, 0xc6, 0x24, 0x8e, + 0xa4, 0x24, 0x39, 0x92, 0x5f, 0x81, 0xe5, 0x54, 0x44, 0x29, 0xcd, 0x3c, 0x25, 0x02, 0x57, 0x10, + 0x6e, 0xa0, 0x64, 0x0e, 0x01, 0xd3, 0x7f, 0x59, 0x8a, 0xcf, 0x03, 0x08, 0x6c, 0x44, 0x0f, 0x5b, + 0xc8, 0xbe, 0xe6, 0x7b, 0x8e, 0xed, 0xc5, 0x05, 0x17, 0xbe, 0x46, 0x06, 0xe4, 0x05, 0x97, 0x77, + 0x60, 0x91, 0x77, 0x8a, 0xb7, 0xa7, 0x82, 0x01, 0x59, 0x97, 0x8d, 0x8b, 0x37, 0xa6, 0xe7, 0xa1, + 0xcb, 0x4f, 0x41, 0x04, 0xbe, 0x8a, 0xea, 0x6c, 0xe4, 0x07, 0xa0, 0x89, 0x6e, 0xa7, 0xdd, 0x10, + 0x17, 0xf9, 0xc0, 0x38, 0xb0, 0xfb, 0xcd, 0x12, 0xf4, 0xe4, 0xed, 0x31, 0xb5, 0xfc, 0xd3, 0x87, + 0x77, 0xaf, 0xcb, 0x67, 0xdf, 0xcf, 0x9f, 0x40, 0x4f, 0x82, 0x47, 0x9c, 0x80, 0xff, 0x6e, 0x99, + 0x5e, 0x64, 0x20, 0xa9, 0xde, 0x96, 0x1d, 0x46, 0x81, 0xbd, 0x37, 0x99, 0xef, 0x34, 0xd6, 0x84, + 0xd6, 0xf0, 0x00, 0x0f, 0x1f, 0x8d, 0x7d, 0x3b, 0x91, 0xca, 0x5b, 0x2a, 0x9a, 0xa6, 0xa3, 0x5d, + 0xdf, 0x4c, 0x66, 0x60, 0xc7, 0x59, 0xe9, 0x39, 0xfb, 0x3f, 0x06, 0x2d, 0xdb, 0x21, 0x7d, 0x8a, + 0xd4, 0x64, 0xa7, 0x48, 0x37, 0xe5, 0x53, 0xa4, 0x19, 0x91, 0x46, 0xea, 0x10, 0xe9, 0x6f, 0xca, + 0xf0, 0x4d, 0x25, 0x6d, 0xf3, 0x64, 0x49, 0xd3, 0xea, 0x48, 0x77, 0xa0, 0x91, 0x49, 0x6a, 0x5f, + 0x38, 0x41, 0x7e, 0xbc, 0xee, 0xca, 0x4a, 0x83, 0x61, 0x12, 0x5b, 0x25, 0x06, 0x5f, 0x9d, 0x3e, + 0x07, 0xb7, 0x3b, 0x69, 0x0e, 0x31, 0x0e, 0xdd, 0x86, 0x36, 0x2b, 0x18, 0x0c, 0x0e, 0x6d, 0x7c, + 0x24, 0xce, 0x68, 0x2f, 0x2b, 0x5d, 0x33, 0xed, 0xf7, 0xa1, 0x8d, 0x8f, 0x8c, 0x96, 0x13, 0xff, + 0x0e, 0xf5, 0xdf, 0xaf, 0x02, 0x24, 0x6d, 0x24, 0x3b, 0x4b, 0x6c, 0x9e, 0x1b, 0x71, 0x0a, 0x42, + 0x62, 0x09, 0x39, 0x72, 0x15, 0x9f, 
0xc8, 0x48, 0xce, 0x48, 0x2c, 0x3b, 0x8c, 0x38, 0x5f, 0xae, + 0x9f, 0x4c, 0x8b, 0x60, 0x11, 0x11, 0x19, 0xd7, 0x99, 0x30, 0x81, 0xa0, 0x57, 0x00, 0x8d, 0x02, + 0xff, 0xc8, 0xf6, 0x46, 0xe9, 0x7c, 0x83, 0xa5, 0x25, 0x4b, 0xbc, 0x25, 0x95, 0x70, 0xfc, 0x04, + 0xb4, 0x4c, 0x77, 0xc1, 0x92, 0x9b, 0x33, 0xc8, 0xb8, 0x27, 0xcd, 0xc5, 0xd5, 0x77, 0x51, 0xc6, + 0x40, 0x0f, 0x64, 0x1f, 0x98, 0xc1, 0x08, 0x0b, 0x89, 0xf2, 0x38, 0x4c, 0x06, 0xf6, 0x07, 0xa0, + 0x65, 0x57, 0xa5, 0x38, 0x2e, 0x7d, 0x55, 0x56, 0xf4, 0x93, 0xfc, 0x11, 0x99, 0x26, 0xa5, 0xea, + 0x7d, 0x13, 0x56, 0x54, 0xf4, 0x2a, 0x90, 0x9c, 0xd9, 0x9a, 0xde, 0x8a, 0x43, 0x62, 0x2a, 0x87, + 0x69, 0xbb, 0x4c, 0xaa, 0xf0, 0x5c, 0x96, 0x0a, 0xcf, 0xfa, 0xaf, 0x57, 0x00, 0xe5, 0xd5, 0x1f, + 0x75, 0xa1, 0x1c, 0x4f, 0x52, 0xde, 0xde, 0xca, 0xa8, 0x5b, 0x39, 0xa7, 0x6e, 0x17, 0xa1, 0x19, + 0xef, 0xfa, 0xdc, 0xc5, 0x27, 0x80, 0xb4, 0x32, 0x56, 0x65, 0x65, 0x4c, 0x11, 0x56, 0x93, 0x2b, + 0xe2, 0x37, 0x60, 0xc5, 0x31, 0xc3, 0x68, 0xc0, 0x0a, 0xef, 0x91, 0xed, 0xe2, 0x30, 0x32, 0xdd, + 0x31, 0x15, 0x65, 0xd5, 0x40, 0xa4, 0x6d, 0x8b, 0x34, 0x3d, 0x10, 0x2d, 0xe8, 0x81, 0x88, 0xae, + 0x89, 0xef, 0xe5, 0x17, 0x11, 0x5e, 0x2d, 0x66, 0xee, 0x49, 0xb9, 0x9b, 0x69, 0x54, 0x33, 0x0e, + 0x3b, 0xfb, 0x9f, 0x40, 0x57, 0x6e, 0x54, 0x88, 0xef, 0x96, 0x2c, 0xbe, 0x22, 0x81, 0x6d, 0x4a, + 0x86, 0x07, 0x80, 0xf2, 0xce, 0x23, 0xcd, 0xb3, 0x92, 0xcc, 0xb3, 0x59, 0xb2, 0x48, 0xf1, 0xb4, + 0x22, 0x0b, 0xfb, 0x1f, 0x2b, 0x80, 0x92, 0x08, 0x2e, 0x3e, 0x18, 0x2f, 0x12, 0xf6, 0x5c, 0x87, + 0xe5, 0x7c, 0x7c, 0x27, 0x82, 0x5a, 0x94, 0x8b, 0xee, 0x54, 0x91, 0x58, 0x45, 0x75, 0x0d, 0xf5, + 0xb5, 0xd8, 0xdd, 0xb3, 0x70, 0xf5, 0xf2, 0xd4, 0xf3, 0x0c, 0xd9, 0xe3, 0xff, 0x38, 0x7b, 0x7d, + 0x95, 0xf9, 0x8f, 0x5b, 0x4a, 0xd7, 0x9c, 0x5b, 0xf2, 0xcc, 0xbb, 0xab, 0x52, 0x20, 0x5d, 0x3f, + 0x55, 0x20, 0x7d, 0x15, 0x3a, 0x01, 0x1e, 0xfa, 0x87, 0x38, 0x60, 0x5a, 0x4b, 0xc3, 0xd9, 0x9a, + 0xd1, 0xe6, 0x40, 0xaa, 0xaf, 0xf3, 0xdf, 0x48, 0xfd, 0xef, 0x32, 0x2c, 
0xc5, 0xdc, 0x3e, 0x95, + 0x24, 0x67, 0x5f, 0x74, 0x78, 0xca, 0xa2, 0xfb, 0x58, 0x2d, 0xba, 0xef, 0x9e, 0x98, 0xf1, 0x14, + 0x96, 0xdc, 0x17, 0xc3, 0xfe, 0xcf, 0x60, 0x81, 0x17, 0xb8, 0x73, 0xae, 0xb2, 0x48, 0xe1, 0x61, + 0x05, 0x6a, 0xc4, 0x33, 0x8b, 0xea, 0x24, 0xfb, 0x60, 0x7c, 0x4f, 0x5f, 0x8b, 0xe6, 0xde, 0xb2, + 0x23, 0xdd, 0x8a, 0xd6, 0x7f, 0xab, 0x02, 0xb0, 0x7b, 0xec, 0x0d, 0x6f, 0x33, 0x73, 0xbf, 0x01, + 0xd5, 0x59, 0x97, 0xe8, 0x48, 0x6f, 0xaa, 0xa5, 0xb4, 0x67, 0x01, 0x0d, 0x90, 0x4a, 0x2b, 0x95, + 0x6c, 0x69, 0x65, 0x5a, 0x51, 0x64, 0xba, 0x33, 0xff, 0x2e, 0x54, 0xa9, 0x53, 0x66, 0x77, 0xcc, + 0x0a, 0x9d, 0x2f, 0xd3, 0x01, 0x68, 0x0d, 0xc4, 0xe6, 0xbe, 0xed, 0xb1, 0xdd, 0x9b, 0x3a, 0xf6, + 0x8a, 0x91, 0x05, 0xd3, 0xab, 0x0f, 0x34, 0xea, 0x8f, 0x3b, 0xb2, 0xec, 0x30, 0x03, 0xcd, 0xc7, + 0x06, 0x4d, 0x45, 0x6c, 0x40, 0xf0, 0x5a, 0x81, 0x3f, 0x1e, 0xa7, 0xa6, 0x63, 0x35, 0x95, 0x2c, + 0x58, 0xff, 0xbc, 0x0c, 0x17, 0x08, 0x7f, 0x9f, 0x4c, 0x7c, 0x5f, 0x44, 0x79, 0x52, 0x3b, 0x43, + 0x45, 0xde, 0x19, 0x6e, 0xc1, 0x02, 0x2b, 0xdc, 0x88, 0x48, 0xf5, 0xf2, 0x34, 0x6d, 0x60, 0xba, + 0x63, 0x88, 0xee, 0xf3, 0x66, 0xff, 0xd2, 0xe9, 0x7b, 0x7d, 0xbe, 0xd3, 0xf7, 0x85, 0x6c, 0x79, + 0x37, 0xa5, 0x56, 0x0d, 0x79, 0x3f, 0x7b, 0x08, 0x1d, 0x23, 0x6d, 0x1a, 0x08, 0x41, 0x35, 0x75, + 0xad, 0x96, 0xfe, 0xa6, 0x09, 0xbb, 0x39, 0x36, 0x87, 0x76, 0x74, 0x4c, 0xd9, 0x59, 0x33, 0xe2, + 0x6f, 0xb5, 0x1d, 0xea, 0xff, 0x55, 0x82, 0xf3, 0xe2, 0x78, 0x96, 0x5b, 0xf9, 0xd9, 0x25, 0xba, + 0x01, 0xab, 0xdc, 0xa4, 0x33, 0xb6, 0xcd, 0xc2, 0xf2, 0x65, 0x06, 0x93, 0x97, 0xb1, 0x01, 0xab, + 0x11, 0xd5, 0xae, 0xec, 0x18, 0x26, 0xef, 0x65, 0xd6, 0x28, 0x8f, 0x29, 0x72, 0x3c, 0xfe, 0x0c, + 0xbb, 0x27, 0xc6, 0x59, 0xcb, 0x8d, 0x14, 0xbc, 0x89, 0xcb, 0x57, 0xa9, 0x1f, 0xc1, 0x45, 0x76, + 0xb1, 0x7d, 0x4f, 0xa6, 0x68, 0xae, 0xd3, 0x11, 0xe5, 0xba, 0x33, 0x3e, 0xed, 0x8f, 0x4b, 0x70, + 0x69, 0x0a, 0xe6, 0x79, 0xf2, 0xc2, 0xf7, 0x94, 0xd8, 0xa7, 0x64, 0xf1, 0x12, 0x5e, 0x76, 0xf5, + 0x41, 0x26, 
0xf2, 0xf3, 0x2a, 0x2c, 0xe5, 0x3a, 0x9d, 0x5a, 0xe7, 0x5e, 0x06, 0x44, 0x84, 0x10, + 0x3f, 0xe2, 0xa4, 0x85, 0x11, 0xbe, 0xc3, 0x6a, 0xde, 0xc4, 0x8d, 0x1f, 0x70, 0xee, 0xf8, 0x16, + 0x46, 0x36, 0xeb, 0xcd, 0xce, 0x46, 0x62, 0xc9, 0x55, 0xa7, 0xbf, 0xd5, 0xc9, 0x11, 0xb8, 0xbe, + 0x33, 0x71, 0xd9, 0x31, 0x0a, 0x97, 0x32, 0xdb, 0x35, 0x09, 0x2a, 0x09, 0x8c, 0xf6, 0x61, 0x89, + 0xde, 0x22, 0x9c, 0x44, 0x23, 0x9f, 0xa4, 0x66, 0x94, 0x2e, 0xb6, 0x37, 0x7f, 0xbf, 0x30, 0xa6, + 0x0f, 0xf8, 0x68, 0x42, 0x3c, 0xcf, 0xce, 0x3c, 0x19, 0x2a, 0xf0, 0xd8, 0xde, 0xd0, 0x77, 0x63, + 0x3c, 0xf5, 0x53, 0xe2, 0xd9, 0xe6, 0xa3, 0x65, 0x3c, 0x69, 0x68, 0x7f, 0x13, 0x56, 0x95, 0x4b, + 0x9f, 0xb5, 0xd1, 0xd7, 0xd2, 0x39, 0xdc, 0x1d, 0x58, 0x51, 0xad, 0xea, 0x0c, 0x73, 0xe4, 0x28, + 0x3e, 0xcd, 0x1c, 0xfa, 0x9f, 0x97, 0xa1, 0xb3, 0x85, 0x1d, 0x1c, 0xe1, 0xa7, 0x7b, 0x7a, 0x9d, + 0x3b, 0x8a, 0xaf, 0xe4, 0x8f, 0xe2, 0x73, 0xf7, 0x0a, 0xaa, 0x8a, 0x7b, 0x05, 0x97, 0xe2, 0xeb, + 0x14, 0x64, 0x96, 0x9a, 0x1c, 0x43, 0x58, 0xe8, 0x75, 0x68, 0x8f, 0x03, 0xdb, 0x35, 0x83, 0xe3, + 0xc1, 0x23, 0x7c, 0x1c, 0xf2, 0x4d, 0xa3, 0xa7, 0xdc, 0x76, 0xb6, 0xb7, 0x42, 0xa3, 0xc5, 0x7b, + 0xbf, 0x8b, 0x8f, 0xe9, 0x55, 0x8d, 0x38, 0x21, 0x64, 0xf7, 0xf9, 0xaa, 0x46, 0x0a, 0x72, 0xed, + 0x25, 0x68, 0xc6, 0x57, 0xa0, 0x50, 0x03, 0xaa, 0x77, 0x27, 0x8e, 0xa3, 0x9d, 0x43, 0x4d, 0xa8, + 0xd1, 0x94, 0x51, 0x2b, 0x91, 0x9f, 0x34, 0xf6, 0xd3, 0xca, 0xd7, 0xfe, 0x1f, 0x34, 0xe3, 0xab, + 0x18, 0xa8, 0x05, 0x0b, 0x0f, 0xbd, 0x77, 0x3d, 0xff, 0xc8, 0xd3, 0xce, 0xa1, 0x05, 0xa8, 0xdc, + 0x76, 0x1c, 0xad, 0x84, 0x3a, 0xd0, 0xdc, 0x8d, 0x02, 0x6c, 0x12, 0xf1, 0x69, 0x65, 0xd4, 0x05, + 0x78, 0xc7, 0x0e, 0x23, 0x3f, 0xb0, 0x87, 0xa6, 0xa3, 0x55, 0xae, 0x7d, 0x06, 0x5d, 0xb9, 0x32, + 0x8f, 0xda, 0xd0, 0xd8, 0xf1, 0xa3, 0xb7, 0x3f, 0xb5, 0xc3, 0x48, 0x3b, 0x47, 0xfa, 0xef, 0xf8, + 0xd1, 0xfd, 0x00, 0x87, 0xd8, 0x8b, 0xb4, 0x12, 0x02, 0xa8, 0x7f, 0xe0, 0x6d, 0xd9, 0xe1, 0x23, + 0xad, 0x8c, 0x96, 0xf9, 0xa1, 0x9b, 0xe9, 0x6c, 
0xf3, 0x72, 0xb7, 0x56, 0x21, 0xc3, 0xe3, 0xaf, + 0x2a, 0xd2, 0xa0, 0x1d, 0x77, 0xb9, 0x77, 0xff, 0xa1, 0x56, 0x63, 0xd4, 0x93, 0x9f, 0xf5, 0x6b, + 0x16, 0x68, 0xd9, 0xc3, 0x62, 0x32, 0x27, 0x5b, 0x44, 0x0c, 0xd2, 0xce, 0x91, 0x95, 0xf1, 0xd3, + 0x7a, 0xad, 0x84, 0x16, 0xa1, 0x95, 0x3a, 0xfb, 0xd6, 0xca, 0x04, 0x70, 0x2f, 0x18, 0x0f, 0xb9, + 0x6e, 0x31, 0x12, 0x88, 0xa2, 0x6e, 0x11, 0x4e, 0x54, 0xaf, 0xdd, 0x81, 0x86, 0xc8, 0x74, 0x48, + 0x57, 0xce, 0x22, 0xf2, 0xa9, 0x9d, 0x43, 0x4b, 0xd0, 0x91, 0x1e, 0x08, 0x6a, 0x25, 0x84, 0xa0, + 0x2b, 0x3f, 0xe1, 0xd5, 0xca, 0xd7, 0x36, 0x00, 0x92, 0x64, 0x80, 0x90, 0xb3, 0xed, 0x1d, 0x9a, + 0x8e, 0x6d, 0x31, 0xda, 0x48, 0x13, 0xe1, 0x2e, 0xe5, 0x0e, 0xb3, 0x59, 0xad, 0x7c, 0xed, 0x4d, + 0x68, 0x88, 0xd8, 0x95, 0xc0, 0x0d, 0xec, 0xfa, 0x87, 0x98, 0x49, 0x66, 0x17, 0x47, 0x4c, 0x8e, + 0xb7, 0x5d, 0xec, 0x59, 0x5a, 0x99, 0x90, 0xf1, 0x70, 0x6c, 0x99, 0x91, 0xb8, 0xe4, 0xaa, 0x55, + 0x36, 0xfe, 0x63, 0x19, 0x80, 0x9d, 0xfe, 0xfa, 0x7e, 0x60, 0x21, 0x87, 0xde, 0x02, 0xd9, 0xf4, + 0xdd, 0xb1, 0xef, 0x89, 0xa3, 0xa9, 0x10, 0xad, 0x67, 0x8a, 0x2d, 0xec, 0x23, 0xdf, 0x91, 0xf3, + 0xa6, 0xff, 0x9c, 0xb2, 0x7f, 0xa6, 0xb3, 0x7e, 0x0e, 0xb9, 0x14, 0x1b, 0xc9, 0x37, 0x1e, 0xd8, + 0xc3, 0x47, 0xf1, 0x91, 0xf1, 0xf4, 0xa7, 0xb5, 0x99, 0xae, 0x02, 0xdf, 0x55, 0x25, 0xbe, 0xdd, + 0x28, 0xb0, 0xbd, 0x91, 0xd8, 0x1d, 0xf5, 0x73, 0xe8, 0x71, 0xe6, 0x61, 0xaf, 0x40, 0xb8, 0x51, + 0xe4, 0x2d, 0xef, 0xd9, 0x50, 0x3a, 0xb0, 0x98, 0xf9, 0x07, 0x05, 0x74, 0x4d, 0xfd, 0x42, 0x4a, + 0xf5, 0x6f, 0x0f, 0xfd, 0x97, 0x0a, 0xf5, 0x8d, 0xb1, 0xd9, 0xd0, 0x95, 0x9f, 0xfe, 0xa3, 0x6f, + 0x4d, 0x9b, 0x20, 0xf7, 0x46, 0xb3, 0x7f, 0xad, 0x48, 0xd7, 0x18, 0xd5, 0x47, 0x4c, 0x7d, 0x67, + 0xa1, 0x52, 0x3e, 0x8b, 0xed, 0x9f, 0x14, 0x98, 0xe8, 0xe7, 0xd0, 0x27, 0x24, 0x86, 0xc8, 0xbc, + 0x24, 0x45, 0x2f, 0xab, 0xf7, 0x3d, 0xf5, 0x83, 0xd3, 0x59, 0x18, 0x3e, 0xca, 0x1a, 0xdf, 0x74, + 0xea, 0x73, 0x4f, 0xd4, 0x8b, 0x53, 0x9f, 0x9a, 0xfe, 0x24, 0xea, 0x4f, 0x8d, 0xc1, 
0x61, 0xe9, + 0x94, 0xe2, 0x0d, 0x5b, 0x56, 0x95, 0x93, 0x6c, 0x66, 0xfa, 0x83, 0xb7, 0x59, 0xd8, 0x26, 0xd4, + 0x48, 0xb3, 0xd7, 0x1e, 0x5e, 0x99, 0x72, 0xa0, 0xa2, 0x7e, 0x3c, 0xdb, 0x5f, 0x2f, 0xda, 0x3d, + 0xad, 0xcb, 0xf2, 0xfb, 0x4c, 0xb5, 0x88, 0x94, 0x6f, 0x4a, 0xd5, 0xba, 0xac, 0x7e, 0xee, 0xa9, + 0x9f, 0x43, 0x0f, 0x24, 0x57, 0x8f, 0x5e, 0x98, 0xa6, 0x0a, 0xf2, 0x3d, 0xa8, 0x59, 0x7c, 0xfb, + 0x55, 0x40, 0xcc, 0x52, 0xbd, 0x7d, 0x7b, 0x34, 0x09, 0x4c, 0xa6, 0xc6, 0xd3, 0x9c, 0x5b, 0xbe, + 0xab, 0x40, 0xf3, 0xed, 0x53, 0x8c, 0x88, 0x97, 0x34, 0x00, 0xb8, 0x87, 0xa3, 0xf7, 0xe9, 0x43, + 0xbd, 0x30, 0xbb, 0xa2, 0xc4, 0x7f, 0xf3, 0x0e, 0x02, 0xd5, 0x8b, 0x33, 0xfb, 0xc5, 0x08, 0xf6, + 0xa0, 0x75, 0x8f, 0xe4, 0x57, 0x34, 0x66, 0x0c, 0xd1, 0xd4, 0x91, 0xa2, 0x87, 0x40, 0xb1, 0x36, + 0xbb, 0x63, 0xda, 0x79, 0x66, 0xde, 0xaa, 0xa2, 0xa9, 0x82, 0xcd, 0xbf, 0xa0, 0x55, 0x3b, 0xcf, + 0x29, 0x8f, 0x5f, 0xd9, 0x8a, 0xe8, 0xa1, 0xde, 0x3b, 0xd8, 0x74, 0xa2, 0x83, 0x29, 0x2b, 0x4a, + 0xf5, 0x38, 0x79, 0x45, 0x52, 0xc7, 0x18, 0x07, 0x86, 0x65, 0x66, 0x85, 0x72, 0x62, 0x7a, 0x5d, + 0x3d, 0x45, 0xbe, 0x67, 0x41, 0xd5, 0x33, 0x61, 0x69, 0x2b, 0xf0, 0xc7, 0x32, 0x92, 0x57, 0x94, + 0x48, 0x72, 0xfd, 0x0a, 0xa2, 0xf8, 0x21, 0xb4, 0x45, 0xfe, 0x4f, 0x33, 0x16, 0x35, 0x17, 0xd2, + 0x5d, 0x0a, 0x4e, 0xfc, 0x31, 0x2c, 0x66, 0x0a, 0x0b, 0x6a, 0xa1, 0xab, 0xab, 0x0f, 0xb3, 0x66, + 0x3f, 0x02, 0x44, 0x1f, 0x20, 0xcb, 0xff, 0xa1, 0xa0, 0x8e, 0x6f, 0xf2, 0x1d, 0x05, 0x92, 0xeb, + 0x85, 0xfb, 0xc7, 0x92, 0xff, 0x35, 0x58, 0x55, 0x26, 0xef, 0x59, 0x87, 0xc0, 0x2f, 0x3e, 0x9f, + 0x50, 0x61, 0xc8, 0x3a, 0x84, 0x13, 0x47, 0x08, 0xfc, 0x1b, 0xff, 0x86, 0xa0, 0x49, 0xe3, 0x3c, + 0x2a, 0xad, 0xff, 0x0b, 0xf3, 0x9e, 0x6c, 0x98, 0xf7, 0x31, 0x2c, 0x66, 0x1e, 0xc6, 0xaa, 0x95, + 0x56, 0xfd, 0x7a, 0xb6, 0x40, 0xb4, 0x22, 0xbf, 0x29, 0x55, 0x6f, 0x85, 0xca, 0x77, 0xa7, 0xb3, + 0xe6, 0xfe, 0x90, 0x3d, 0x3a, 0x8f, 0xcf, 0x85, 0x5f, 0x9c, 0x7a, 0x42, 0x21, 0x5f, 0x60, 0xfe, + 0xf2, 0xa3, 0xa0, 0xaf, 
0x77, 0x04, 0xfa, 0x31, 0x2c, 0x66, 0x9e, 0x08, 0xa9, 0x35, 0x46, 0xfd, + 0x8e, 0x68, 0xd6, 0xec, 0x5f, 0x60, 0xf0, 0x64, 0xc1, 0xb2, 0xe2, 0x45, 0x06, 0x5a, 0x9f, 0x16, + 0x88, 0xaa, 0x9f, 0x6e, 0xcc, 0x5e, 0x50, 0x47, 0x32, 0xd3, 0xec, 0x7e, 0x93, 0x10, 0x99, 0xfd, + 0xf3, 0xa5, 0xfe, 0xcb, 0xc5, 0xfe, 0xa9, 0x29, 0x5e, 0xd0, 0x2e, 0xd4, 0xd9, 0xc3, 0x21, 0xf4, + 0xac, 0xfa, 0x10, 0x26, 0xf5, 0xa8, 0xa8, 0x3f, 0xeb, 0xe9, 0x51, 0x38, 0x71, 0x22, 0x42, 0xff, + 0x8f, 0xa0, 0xcb, 0x40, 0x31, 0x83, 0x9e, 0xe0, 0xe4, 0xbb, 0x50, 0xa3, 0xae, 0x1d, 0x29, 0x0f, + 0x14, 0xd2, 0xcf, 0x83, 0xfa, 0xb3, 0x5f, 0x04, 0x25, 0x14, 0xb7, 0xe8, 0x48, 0x56, 0xd5, 0x79, + 0x92, 0x53, 0xdf, 0x28, 0xa1, 0x1f, 0x41, 0x87, 0x4d, 0x2e, 0xb8, 0xf1, 0x24, 0x29, 0x1f, 0xc2, + 0x72, 0x8a, 0xf2, 0xa7, 0x81, 0xe2, 0x46, 0xe9, 0x7f, 0x79, 0x74, 0xff, 0x29, 0x7d, 0x9e, 0x93, + 0xbd, 0x80, 0x86, 0xd6, 0x4f, 0x77, 0x8b, 0xae, 0x7f, 0xbd, 0x70, 0xff, 0x18, 0xf3, 0x4f, 0x40, + 0xcb, 0x1e, 0x15, 0xa2, 0x97, 0xa6, 0xf9, 0x12, 0x15, 0xce, 0x19, 0x8e, 0xe4, 0x07, 0x50, 0x67, + 0x35, 0x62, 0xb5, 0x01, 0x4a, 0xf5, 0xe3, 0x19, 0x73, 0xdd, 0xf9, 0xce, 0x47, 0x1b, 0x23, 0x3b, + 0x3a, 0x98, 0xec, 0x91, 0x96, 0xeb, 0xac, 0xeb, 0x2b, 0xb6, 0xcf, 0x7f, 0x5d, 0x17, 0xb2, 0xbc, + 0x4e, 0x47, 0x5f, 0xa7, 0x08, 0xc6, 0x7b, 0x7b, 0x75, 0xfa, 0x79, 0xf3, 0x7f, 0x02, 0x00, 0x00, + 0xff, 0xff, 0x02, 0x5c, 0xf0, 0x6a, 0x76, 0x54, 0x00, 0x00, } // Reference imports to suppress errors if they are not otherwise used. 
@@ -5816,7 +5840,9 @@ type QueryNodeClient interface { Search(ctx context.Context, in *SearchRequest, opts ...grpc.CallOption) (*internalpb.SearchResults, error) SearchSegments(ctx context.Context, in *SearchRequest, opts ...grpc.CallOption) (*internalpb.SearchResults, error) Query(ctx context.Context, in *QueryRequest, opts ...grpc.CallOption) (*internalpb.RetrieveResults, error) + QueryStream(ctx context.Context, in *QueryRequest, opts ...grpc.CallOption) (QueryNode_QueryStreamClient, error) QuerySegments(ctx context.Context, in *QueryRequest, opts ...grpc.CallOption) (*internalpb.RetrieveResults, error) + QueryStreamSegments(ctx context.Context, in *QueryRequest, opts ...grpc.CallOption) (QueryNode_QueryStreamSegmentsClient, error) ShowConfigurations(ctx context.Context, in *internalpb.ShowConfigurationsRequest, opts ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error) // https://wiki.lfaidata.foundation/display/MIL/MEP+8+--+Add+metrics+for+proxy GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) @@ -5977,6 +6003,38 @@ func (c *queryNodeClient) Query(ctx context.Context, in *QueryRequest, opts ...g return out, nil } +func (c *queryNodeClient) QueryStream(ctx context.Context, in *QueryRequest, opts ...grpc.CallOption) (QueryNode_QueryStreamClient, error) { + stream, err := c.cc.NewStream(ctx, &_QueryNode_serviceDesc.Streams[0], "/milvus.proto.query.QueryNode/QueryStream", opts...) 
+ if err != nil { + return nil, err + } + x := &queryNodeQueryStreamClient{stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +type QueryNode_QueryStreamClient interface { + Recv() (*internalpb.RetrieveResults, error) + grpc.ClientStream +} + +type queryNodeQueryStreamClient struct { + grpc.ClientStream +} + +func (x *queryNodeQueryStreamClient) Recv() (*internalpb.RetrieveResults, error) { + m := new(internalpb.RetrieveResults) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + func (c *queryNodeClient) QuerySegments(ctx context.Context, in *QueryRequest, opts ...grpc.CallOption) (*internalpb.RetrieveResults, error) { out := new(internalpb.RetrieveResults) err := c.cc.Invoke(ctx, "/milvus.proto.query.QueryNode/QuerySegments", in, out, opts...) @@ -5986,6 +6044,38 @@ func (c *queryNodeClient) QuerySegments(ctx context.Context, in *QueryRequest, o return out, nil } +func (c *queryNodeClient) QueryStreamSegments(ctx context.Context, in *QueryRequest, opts ...grpc.CallOption) (QueryNode_QueryStreamSegmentsClient, error) { + stream, err := c.cc.NewStream(ctx, &_QueryNode_serviceDesc.Streams[1], "/milvus.proto.query.QueryNode/QueryStreamSegments", opts...) 
+ if err != nil { + return nil, err + } + x := &queryNodeQueryStreamSegmentsClient{stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +type QueryNode_QueryStreamSegmentsClient interface { + Recv() (*internalpb.RetrieveResults, error) + grpc.ClientStream +} + +type queryNodeQueryStreamSegmentsClient struct { + grpc.ClientStream +} + +func (x *queryNodeQueryStreamSegmentsClient) Recv() (*internalpb.RetrieveResults, error) { + m := new(internalpb.RetrieveResults) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + func (c *queryNodeClient) ShowConfigurations(ctx context.Context, in *internalpb.ShowConfigurationsRequest, opts ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error) { out := new(internalpb.ShowConfigurationsResponse) err := c.cc.Invoke(ctx, "/milvus.proto.query.QueryNode/ShowConfigurations", in, out, opts...) 
@@ -6049,7 +6139,9 @@ type QueryNodeServer interface { Search(context.Context, *SearchRequest) (*internalpb.SearchResults, error) SearchSegments(context.Context, *SearchRequest) (*internalpb.SearchResults, error) Query(context.Context, *QueryRequest) (*internalpb.RetrieveResults, error) + QueryStream(*QueryRequest, QueryNode_QueryStreamServer) error QuerySegments(context.Context, *QueryRequest) (*internalpb.RetrieveResults, error) + QueryStreamSegments(*QueryRequest, QueryNode_QueryStreamSegmentsServer) error ShowConfigurations(context.Context, *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) // https://wiki.lfaidata.foundation/display/MIL/MEP+8+--+Add+metrics+for+proxy GetMetrics(context.Context, *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) @@ -6110,9 +6202,15 @@ func (*UnimplementedQueryNodeServer) SearchSegments(ctx context.Context, req *Se func (*UnimplementedQueryNodeServer) Query(ctx context.Context, req *QueryRequest) (*internalpb.RetrieveResults, error) { return nil, status.Errorf(codes.Unimplemented, "method Query not implemented") } +func (*UnimplementedQueryNodeServer) QueryStream(req *QueryRequest, srv QueryNode_QueryStreamServer) error { + return status.Errorf(codes.Unimplemented, "method QueryStream not implemented") +} func (*UnimplementedQueryNodeServer) QuerySegments(ctx context.Context, req *QueryRequest) (*internalpb.RetrieveResults, error) { return nil, status.Errorf(codes.Unimplemented, "method QuerySegments not implemented") } +func (*UnimplementedQueryNodeServer) QueryStreamSegments(req *QueryRequest, srv QueryNode_QueryStreamSegmentsServer) error { + return status.Errorf(codes.Unimplemented, "method QueryStreamSegments not implemented") +} func (*UnimplementedQueryNodeServer) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method ShowConfigurations not 
implemented") } @@ -6421,6 +6519,27 @@ func _QueryNode_Query_Handler(srv interface{}, ctx context.Context, dec func(int return interceptor(ctx, in, info, handler) } +func _QueryNode_QueryStream_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(QueryRequest) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(QueryNodeServer).QueryStream(m, &queryNodeQueryStreamServer{stream}) +} + +type QueryNode_QueryStreamServer interface { + Send(*internalpb.RetrieveResults) error + grpc.ServerStream +} + +type queryNodeQueryStreamServer struct { + grpc.ServerStream +} + +func (x *queryNodeQueryStreamServer) Send(m *internalpb.RetrieveResults) error { + return x.ServerStream.SendMsg(m) +} + func _QueryNode_QuerySegments_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(QueryRequest) if err := dec(in); err != nil { @@ -6439,6 +6558,27 @@ func _QueryNode_QuerySegments_Handler(srv interface{}, ctx context.Context, dec return interceptor(ctx, in, info, handler) } +func _QueryNode_QueryStreamSegments_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(QueryRequest) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(QueryNodeServer).QueryStreamSegments(m, &queryNodeQueryStreamSegmentsServer{stream}) +} + +type QueryNode_QueryStreamSegmentsServer interface { + Send(*internalpb.RetrieveResults) error + grpc.ServerStream +} + +type queryNodeQueryStreamSegmentsServer struct { + grpc.ServerStream +} + +func (x *queryNodeQueryStreamSegmentsServer) Send(m *internalpb.RetrieveResults) error { + return x.ServerStream.SendMsg(m) +} + func _QueryNode_ShowConfigurations_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(internalpb.ShowConfigurationsRequest) if err := dec(in); err != nil { @@ -6622,6 +6762,17 @@ var 
_QueryNode_serviceDesc = grpc.ServiceDesc{ Handler: _QueryNode_Delete_Handler, }, }, - Streams: []grpc.StreamDesc{}, + Streams: []grpc.StreamDesc{ + { + StreamName: "QueryStream", + Handler: _QueryNode_QueryStream_Handler, + ServerStreams: true, + }, + { + StreamName: "QueryStreamSegments", + Handler: _QueryNode_QueryStreamSegments_Handler, + ServerStreams: true, + }, + }, Metadata: "query_coord.proto", } diff --git a/internal/proxy/accesslog/access_log.go b/internal/proxy/accesslog/access_log.go index 8541a94d279ea..6cb7961b70122 100644 --- a/internal/proxy/accesslog/access_log.go +++ b/internal/proxy/accesslog/access_log.go @@ -23,24 +23,27 @@ import ( "sync" "sync/atomic" - "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/util/paramtable" "go.uber.org/zap" "go.uber.org/zap/zapcore" - "google.golang.org/grpc" + + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) const ( clientRequestIDKey = "client_request_id" ) -var _globalL, _globalW atomic.Value -var once sync.Once +var ( + _globalL, _globalW atomic.Value + once sync.Once +) func A() *zap.Logger { return _globalL.Load().(*zap.Logger) } + func W() *RotateLogger { return _globalW.Load().(*RotateLogger) } @@ -109,11 +112,11 @@ func PrintAccessInfo(ctx context.Context, resp interface{}, err error, rpcInfo * } fields := []zap.Field{ - //format time cost of task + // format time cost of task zap.String("timeCost", fmt.Sprintf("%d ms", timeCost)), } - //get trace ID of task + // get trace ID of task traceID, ok := getTraceID(ctx) if !ok { log.Warn("access log print failed: could not get trace ID") @@ -121,7 +124,7 @@ func PrintAccessInfo(ctx context.Context, resp interface{}, err error, rpcInfo * } fields = append(fields, zap.String("traceId", traceID)) - //get response size of task + // get response size of task responseSize, ok := getResponseSize(resp) if !ok { log.Warn("access log print failed: could not get response size") @@ -129,7 +132,7 @@ func 
PrintAccessInfo(ctx context.Context, resp interface{}, err error, rpcInfo * } fields = append(fields, zap.Int("responseSize", responseSize)) - //get err code of task + // get err code of task errCode, ok := getErrCode(resp) if !ok { // unknown error code @@ -137,13 +140,13 @@ func PrintAccessInfo(ctx context.Context, resp interface{}, err error, rpcInfo * } fields = append(fields, zap.Int("errorCode", errCode)) - //get status of grpc + // get status of grpc Status := getGrpcStatus(err) if Status == "OK" && errCode > 0 { Status = "TaskFailed" } - //get method name of grpc + // get method name of grpc _, methodName := path.Split(rpcInfo.FullMethod) A().Info(fmt.Sprintf("%v: %s-%s", Status, getAccessAddr(ctx), methodName), fields...) diff --git a/internal/proxy/accesslog/access_log_test.go b/internal/proxy/accesslog/access_log_test.go index 755b9f20c5335..c16204f9e78d2 100644 --- a/internal/proxy/accesslog/access_log_test.go +++ b/internal/proxy/accesslog/access_log_test.go @@ -23,13 +23,14 @@ import ( "testing" "time" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" - "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/stretchr/testify/assert" "google.golang.org/grpc" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) func TestAccessLogger_NotEnable(t *testing.T) { @@ -126,6 +127,7 @@ func TestAccessLogger_Stdout(t *testing.T) { ok := PrintAccessInfo(ctx, resp, nil, rpcInfo, 0) assert.True(t, ok) } + func TestAccessLogger_WithMinio(t *testing.T) { var Params paramtable.ComponentParam diff --git a/internal/proxy/accesslog/log_writer.go b/internal/proxy/accesslog/log_writer.go index f49042bc29441..4c7b8b57b6116 100644 --- a/internal/proxy/accesslog/log_writer.go +++ b/internal/proxy/accesslog/log_writer.go @@ -19,36 
+19,38 @@ package accesslog import ( "context" "fmt" - "io/ioutil" "os" "path" "sync" "time" + "go.uber.org/zap" + "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/paramtable" - "go.uber.org/zap" ) const megabyte = 1024 * 1024 -var CheckBucketRetryAttempts uint = 20 -var timeFormat = ".2006-01-02T15-04-05.000" +var ( + CheckBucketRetryAttempts uint = 20 + timeFormat = ".2006-01-02T15-04-05.000" +) // a rotated file logger for zap.log and could upload sealed log file to minIO type RotateLogger struct { - //local path is the path to save log before update to minIO - //use os.TempDir()/accesslog if empty + // local path is the path to save log before update to minIO + // use os.TempDir()/accesslog if empty localPath string fileName string - //the interval time of update log to minIO + // the interval time of update log to minIO rotatedTime int64 - //the max size(Mb) of log file - //if local file large than maxSize will update immediately - //close if empty(zero) + // the max size(Mb) of log file + // if local file large than maxSize will update immediately + // close if empty(zero) maxSize int - //MaxBackups is the maximum number of old log files to retain - //close retention limit if empty(zero) + // MaxBackups is the maximum number of old log files to retain + // close retention limit if empty(zero) maxBackups int handler *minioHandler @@ -162,7 +164,7 @@ func (l *RotateLogger) openFileExistingOrNew() error { return fmt.Errorf("file to get log file info: %s", err) } - file, err := os.OpenFile(filename, os.O_APPEND|os.O_WRONLY, 0644) + file, err := os.OpenFile(filename, os.O_APPEND|os.O_WRONLY, 0o644) if err != nil { return l.openNewFile() } @@ -173,13 +175,13 @@ func (l *RotateLogger) openFileExistingOrNew() error { } func (l *RotateLogger) openNewFile() error { - err := os.MkdirAll(l.dir(), 0744) + err := os.MkdirAll(l.dir(), 0o744) if err != nil { return fmt.Errorf("make directories for new log file filed: %s", err) } name := 
l.filename() - mode := os.FileMode(0644) + mode := os.FileMode(0o644) info, err := os.Stat(name) if err == nil { mode = info.Mode() @@ -269,7 +271,6 @@ func (l *RotateLogger) timeRotating() { case <-ticker.C: l.Rotate() } - } } @@ -318,7 +319,7 @@ func (l *RotateLogger) newBackupName() string { } func (l *RotateLogger) oldLogFiles() ([]logInfo, error) { - files, err := ioutil.ReadDir(l.dir()) + files, err := os.ReadDir(l.dir()) if err != nil { return nil, fmt.Errorf("can't read log file directory: %s", err) } diff --git a/internal/proxy/accesslog/log_writer_test.go b/internal/proxy/accesslog/log_writer_test.go index e5476c897f111..98eea96468e91 100644 --- a/internal/proxy/accesslog/log_writer_test.go +++ b/internal/proxy/accesslog/log_writer_test.go @@ -22,20 +22,22 @@ import ( "testing" "time" - "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/stretchr/testify/assert" "go.uber.org/zap" + + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) func getText(size int) []byte { - var text = make([]byte, size) + text := make([]byte, size) for i := 0; i < size; i++ { text[i] = byte('-') } return text } + func TestRotateLogger_Basic(t *testing.T) { var Params paramtable.ComponentParam Params.Init(paramtable.NewBaseTable(paramtable.SkipRemote(true))) @@ -145,7 +147,6 @@ func TestRotateLogger_LocalRetention(t *testing.T) { logFiles, err := logger.oldLogFiles() assert.NoError(t, err) assert.Equal(t, 1, len(logFiles)) - } func TestRotateLogger_BasicError(t *testing.T) { @@ -161,7 +162,7 @@ func TestRotateLogger_BasicError(t *testing.T) { logger.openFileExistingOrNew() - os.Mkdir(path.Join(logger.dir(), "test"), 0744) + os.Mkdir(path.Join(logger.dir(), "test"), 0o744) logfile, err := logger.oldLogFiles() assert.NoError(t, err) assert.Equal(t, 0, len(logfile)) @@ -179,7 +180,7 @@ func TestRotateLogger_InitError(t *testing.T) { params.Save(params.ProxyCfg.AccessLog.LocalPath.Key, 
testPath) params.Save(params.ProxyCfg.AccessLog.MinioEnable.Key, "true") params.Save(params.MinioCfg.Address.Key, "") - //init err with invalid minio address + // init err with invalid minio address _, err := NewRotateLogger(¶ms.ProxyCfg.AccessLog, ¶ms.MinioCfg) assert.Error(t, err) } diff --git a/internal/proxy/accesslog/minio_handler.go b/internal/proxy/accesslog/minio_handler.go index 2a6372d443786..df1638741233d 100644 --- a/internal/proxy/accesslog/minio_handler.go +++ b/internal/proxy/accesslog/minio_handler.go @@ -24,12 +24,13 @@ import ( "sync" "time" - "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/milvus-io/milvus/pkg/util/retry" "github.com/minio/minio-go/v7" "github.com/minio/minio-go/v7/pkg/credentials" "go.uber.org/zap" + + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/retry" ) type config struct { @@ -43,14 +44,16 @@ type config struct { iamEndpoint string } -//minIO client for upload access log -//TODO file retention on minio +// minIO client for upload access log +// TODO file retention on minio +type ( + RetentionFunc func(object minio.ObjectInfo) bool + task struct { + objectName string + filePath string + } +) -type RetentionFunc func(object minio.ObjectInfo) bool -type task struct { - objectName string - filePath string -} type minioHandler struct { bucketName string rootPath string @@ -152,7 +155,6 @@ func (c *minioHandler) scheduler() { log.Warn("close minio logger handler") return } - } } diff --git a/internal/proxy/accesslog/minio_handler_test.go b/internal/proxy/accesslog/minio_handler_test.go index a298f506c898f..d5c09333dfdc5 100644 --- a/internal/proxy/accesslog/minio_handler_test.go +++ b/internal/proxy/accesslog/minio_handler_test.go @@ -23,8 +23,9 @@ import ( "testing" "time" - "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/stretchr/testify/assert" + + 
"github.com/milvus-io/milvus/pkg/util/paramtable" ) func TestMinioHandler_ConnectError(t *testing.T) { @@ -54,14 +55,14 @@ func TestMinHandler_Basic(t *testing.T) { Params.Save(Params.ProxyCfg.AccessLog.MinioEnable.Key, "true") Params.Save(Params.ProxyCfg.AccessLog.RemotePath.Key, "accesslog") Params.Save(Params.ProxyCfg.AccessLog.MaxBackups.Key, "8") - //close retention + // close retention Params.Save(Params.ProxyCfg.AccessLog.RemoteMaxTime.Key, "0") - err := os.MkdirAll(testPath, 0744) + err := os.MkdirAll(testPath, 0o744) assert.NoError(t, err) defer os.RemoveAll(testPath) - //init MinioHandler + // init MinioHandler handler, err := NewMinioHandler( context.Background(), &Params.MinioCfg, @@ -72,22 +73,22 @@ func TestMinHandler_Basic(t *testing.T) { defer handler.Clean() prefix, ext := "accesslog", ".log" - //create a log file to upload + // create a log file to upload err = createAndUpdateFile(handler, time.Now(), testPath, prefix, ext) assert.NoError(t, err) time.Sleep(500 * time.Millisecond) - //check if upload success + // check if upload success lists, err := handler.listAll() assert.NoError(t, err) assert.Equal(t, 1, len(lists)) - //delete file from minio + // delete file from minio err = handler.removeWithPrefix(prefix) assert.NoError(t, err) time.Sleep(500 * time.Millisecond) - //check if delete success + // check if delete success lists, err = handler.listAll() assert.NoError(t, err) assert.Equal(t, 0, len(lists)) @@ -102,7 +103,7 @@ func TestMinioHandler_WithTimeRetention(t *testing.T) { Params.Save(Params.ProxyCfg.AccessLog.MaxBackups.Key, "8") Params.Save(Params.ProxyCfg.AccessLog.RemoteMaxTime.Key, "168") - err := os.MkdirAll(testPath, 0744) + err := os.MkdirAll(testPath, 0o744) assert.NoError(t, err) defer os.RemoveAll(testPath) @@ -118,16 +119,16 @@ func TestMinioHandler_WithTimeRetention(t *testing.T) { prefix, ext := "accesslog", ".log" handler.retentionPolicy = getTimeRetentionFunc(Params.ProxyCfg.AccessLog.RemoteMaxTime.GetAsInt(), prefix, 
ext) - //create a log file + // create a log file err = createAndUpdateFile(handler, time.Now(), testPath, prefix, ext) assert.NoError(t, err) - //mock a log file like time interval was large than RemoteMaxTime + // mock a log file like time interval was large than RemoteMaxTime oldTime := time.Now().Add(-1 * time.Duration(Params.ProxyCfg.AccessLog.RemoteMaxTime.GetAsInt()+1) * time.Hour) err = createAndUpdateFile(handler, oldTime, testPath, prefix, ext) assert.NoError(t, err) - //create a irrelevant file + // create a irrelevant file err = createAndUpdateFile(handler, time.Now(), testPath, "irrelevant", ext) assert.NoError(t, err) @@ -139,17 +140,16 @@ func TestMinioHandler_WithTimeRetention(t *testing.T) { handler.Retention() time.Sleep(500 * time.Millisecond) - //after retention the old file will be removed + // after retention the old file will be removed lists, err = handler.listAll() assert.NoError(t, err) assert.Equal(t, 2, len(lists)) } func createAndUpdateFile(handler *minioHandler, t time.Time, rootPath, prefix, ext string) error { - oldFileName := prefix + t.Format(timeFormat) + ext oldFilePath := path.Join(rootPath, oldFileName) - oldFileMode := os.FileMode(0644) + oldFileMode := os.FileMode(0o644) _, err := os.OpenFile(oldFilePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, oldFileMode) if err != nil { return err diff --git a/internal/proxy/accesslog/util.go b/internal/proxy/accesslog/util.go index 5cb034823ccdc..9de75fc80da3a 100644 --- a/internal/proxy/accesslog/util.go +++ b/internal/proxy/accesslog/util.go @@ -23,15 +23,15 @@ import ( "time" "github.com/cockroachdb/errors" - "github.com/golang/protobuf/proto" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "go.opentelemetry.io/otel/trace" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/status" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" ) type BaseResponse interface { diff --git 
a/internal/proxy/accesslog/util_test.go b/internal/proxy/accesslog/util_test.go index dfec23c6c9bd7..2d832255d813e 100644 --- a/internal/proxy/accesslog/util_test.go +++ b/internal/proxy/accesslog/util_test.go @@ -21,16 +21,17 @@ import ( "net" "testing" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" - "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/tracer" - "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/stretchr/testify/assert" "go.opentelemetry.io/otel" "go.uber.org/zap" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/tracer" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) func TestGetAccessAddr(t *testing.T) { diff --git a/internal/proxy/authentication_interceptor.go b/internal/proxy/authentication_interceptor.go index d1c1e2152a626..66d22fbc17b4b 100644 --- a/internal/proxy/authentication_interceptor.go +++ b/internal/proxy/authentication_interceptor.go @@ -15,19 +15,7 @@ import ( "github.com/milvus-io/milvus/pkg/util/merr" ) -func parseMD(authorization []string) (username, password string) { - if len(authorization) < 1 { - log.Warn("key not found in header") - return - } - // token format: base64 - //token := strings.TrimPrefix(authorization[0], "Bearer ") - token := authorization[0] - rawToken, err := crypto.Base64Decode(token) - if err != nil { - log.Warn("fail to decode the token", zap.Error(err)) - return - } +func parseMD(rawToken string) (username, password string) { secrets := strings.SplitN(rawToken, util.CredentialSeperator, 2) if len(secrets) < 2 { log.Warn("invalid token format, length of secrets less than 2") @@ -40,7 +28,7 @@ func parseMD(authorization []string) (username, password string) { func validSourceID(ctx context.Context, 
authorization []string) bool { if len(authorization) < 1 { - //log.Warn("key not found in header", zap.String("key", util.HeaderSourceID)) + // log.Warn("key not found in header", zap.String("key", util.HeaderSourceID)) return false } // token format: base64 @@ -68,12 +56,41 @@ func AuthenticationInterceptor(ctx context.Context) (context.Context, error) { // 2. if rpc call from sdk if Params.CommonCfg.AuthorizationEnabled.GetAsBool() { if !validSourceID(ctx, md[strings.ToLower(util.HeaderSourceID)]) { - username, password := parseMD(md[strings.ToLower(util.HeaderAuthorize)]) - if !passwordVerify(ctx, username, password, globalMetaCache) { - msg := fmt.Sprintf("username: %s, password: %s", username, password) - return nil, merr.WrapErrParameterInvalid("vaild username and password", msg, "auth check failure, please check username and password are correct") + authStrArr := md[strings.ToLower(util.HeaderAuthorize)] + + if len(authStrArr) < 1 { + log.Warn("key not found in header") + return nil, merr.WrapErrParameterInvalidMsg("missing authorization in header") + } + + // token format: base64 + // token := strings.TrimPrefix(authorization[0], "Bearer ") + token := authStrArr[0] + rawToken, err := crypto.Base64Decode(token) + if err != nil { + log.Warn("fail to decode the token", zap.Error(err)) + return nil, merr.WrapErrParameterInvalidMsg("invalid token format") + } + + if !strings.Contains(rawToken, util.CredentialSeperator) { + user, err := VerifyAPIKey(rawToken) + if err != nil { + log.Warn("fail to verify apikey", zap.Error(err)) + return nil, err + } + metrics.UserRPCCounter.WithLabelValues(user).Inc() + userToken := fmt.Sprintf("%s%s%s", user, util.CredentialSeperator, "___") + md[strings.ToLower(util.HeaderAuthorize)] = []string{crypto.Base64Encode(userToken)} + ctx = metadata.NewIncomingContext(ctx, md) + } else { + // username+password authentication + username, password := parseMD(rawToken) + if !passwordVerify(ctx, username, password, globalMetaCache) { + 
msg := fmt.Sprintf("username: %s, password: %s", username, password) + return nil, merr.WrapErrParameterInvalid("vaild username and password", msg, "auth check failure, please check username and password are correct") + } + metrics.UserRPCCounter.WithLabelValues(username).Inc() } - metrics.UserRPCCounter.WithLabelValues(username).Inc() } } return ctx, nil diff --git a/internal/proxy/authentication_interceptor_test.go b/internal/proxy/authentication_interceptor_test.go index 37c9405935aaf..7eda478be7db9 100644 --- a/internal/proxy/authentication_interceptor_test.go +++ b/internal/proxy/authentication_interceptor_test.go @@ -2,8 +2,10 @@ package proxy import ( "context" + "strings" "testing" + "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" "google.golang.org/grpc/metadata" @@ -16,7 +18,12 @@ import ( // validAuth validates the authentication func TestValidAuth(t *testing.T) { validAuth := func(ctx context.Context, authorization []string) bool { - username, password := parseMD(authorization) + if len(authorization) < 1 { + return false + } + token := authorization[0] + rawToken, _ := crypto.Base64Decode(token) + username, password := parseMD(rawToken) if username == "" || password == "" { return false } @@ -32,7 +39,7 @@ func TestValidAuth(t *testing.T) { assert.False(t, res) // normal metadata rootCoord := &MockRootCoordClientInterface{} - queryCoord := &mocks.MockQueryCoord{} + queryCoord := &mocks.MockQueryCoordClient{} mgr := newShardClientMgr() err := InitMetaCache(ctx, rootCoord, queryCoord, mgr) assert.NoError(t, err) @@ -65,7 +72,7 @@ func TestAuthenticationInterceptor(t *testing.T) { assert.Error(t, err) // mock metacache rootCoord := &MockRootCoordClientInterface{} - queryCoord := &mocks.MockQueryCoord{} + queryCoord := &mocks.MockQueryCoordClient{} mgr := newShardClientMgr() err = InitMetaCache(ctx, rootCoord, queryCoord, mgr) assert.NoError(t, err) @@ -84,4 +91,54 @@ func TestAuthenticationInterceptor(t *testing.T) { ctx = 
metadata.NewIncomingContext(ctx, md) _, err = AuthenticationInterceptor(ctx) assert.NoError(t, err) + + { + // wrong authorization style + md = metadata.Pairs(util.HeaderAuthorize, "123456") + ctx = metadata.NewIncomingContext(ctx, md) + _, err = AuthenticationInterceptor(ctx) + assert.Error(t, err) + } + + { + // invalid user + md = metadata.Pairs(util.HeaderAuthorize, crypto.Base64Encode("mockUser2:mockPass")) + ctx = metadata.NewIncomingContext(ctx, md) + _, err = AuthenticationInterceptor(ctx) + assert.Error(t, err) + } + + { + // default hook + md = metadata.Pairs(util.HeaderAuthorize, crypto.Base64Encode("mockapikey")) + ctx = metadata.NewIncomingContext(ctx, md) + _, err = AuthenticationInterceptor(ctx) + assert.Error(t, err) + } + + { + // verify apikey error + SetMockAPIHook("", errors.New("err")) + md = metadata.Pairs(util.HeaderAuthorize, crypto.Base64Encode("mockapikey")) + ctx = metadata.NewIncomingContext(ctx, md) + _, err = AuthenticationInterceptor(ctx) + assert.Error(t, err) + } + + { + SetMockAPIHook("mockUser", nil) + md = metadata.Pairs(util.HeaderAuthorize, crypto.Base64Encode("mockapikey")) + ctx = metadata.NewIncomingContext(ctx, md) + authCtx, err := AuthenticationInterceptor(ctx) + assert.NoError(t, err) + md, ok := metadata.FromIncomingContext(authCtx) + assert.True(t, ok) + authStrArr := md[strings.ToLower(util.HeaderAuthorize)] + token := authStrArr[0] + rawToken, err := crypto.Base64Decode(token) + assert.NoError(t, err) + user, _ := parseMD(rawToken) + assert.Equal(t, "mockUser", user) + } + hoo = defaultHook{} } diff --git a/internal/proxy/channels_mgr.go b/internal/proxy/channels_mgr.go index ea8737cc19081..3ecdce564c02a 100644 --- a/internal/proxy/channels_mgr.go +++ b/internal/proxy/channels_mgr.go @@ -23,7 +23,7 @@ import ( "strconv" "sync" - "github.com/cockroachdb/errors" + "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" @@ -32,9 +32,8 @@ import ( 
"github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/commonpbutil" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" - - "go.uber.org/zap" ) // channelsMgr manages the pchans, vchans and related message stream of collections. @@ -88,7 +87,7 @@ type getChannelsFuncType = func(collectionID UniqueID) (channelInfos, error) type repackFuncType = func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) // getDmlChannelsFunc returns a function about how to get dml channels of a collection. -func getDmlChannelsFunc(ctx context.Context, rc types.RootCoord) getChannelsFuncType { +func getDmlChannelsFunc(ctx context.Context, rc types.RootCoordClient) getChannelsFuncType { return func(collectionID UniqueID) (channelInfos, error) { req := &milvuspb.DescribeCollectionRequest{ Base: commonpbutil.NewMsgBase(commonpbutil.WithMsgType(commonpb.MsgType_DescribeCollection)), @@ -105,7 +104,7 @@ func getDmlChannelsFunc(ctx context.Context, rc types.RootCoord) getChannelsFunc log.Error("failed to describe collection", zap.String("error_code", resp.GetStatus().GetErrorCode().String()), zap.String("reason", resp.GetStatus().GetReason())) - return channelInfos{}, errors.New(resp.GetStatus().GetReason()) + return channelInfos{}, merr.Error(resp.GetStatus()) } return newChannels(resp.GetVirtualChannelNames(), resp.GetPhysicalChannelNames()) @@ -257,7 +256,7 @@ func (mgr *singleTypeChannelsMgr) lockGetStream(collectionID UniqueID) (msgstrea } // getOrCreateStream get message stream of specified collection. -// If stream don't exists, call createMsgStream to create for it. +// If stream doesn't exist, call createMsgStream to create for it. 
func (mgr *singleTypeChannelsMgr) getOrCreateStream(collectionID UniqueID) (msgstream.MsgStream, error) { if stream, err := mgr.lockGetStream(collectionID); err == nil { return stream, nil diff --git a/internal/proxy/channels_mgr_test.go b/internal/proxy/channels_mgr_test.go index 39c1bd14dbfa9..a35c4a3e4576f 100644 --- a/internal/proxy/channels_mgr_test.go +++ b/internal/proxy/channels_mgr_test.go @@ -21,14 +21,13 @@ import ( "testing" "github.com/cockroachdb/errors" - - "github.com/milvus-io/milvus/pkg/mq/msgstream" - "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" - - "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) func Test_removeDuplicate(t *testing.T) { @@ -56,7 +55,7 @@ func Test_getDmlChannelsFunc(t *testing.T) { t.Run("failed to describe collection", func(t *testing.T) { ctx := context.Background() rc := newMockRootCoord() - rc.DescribeCollectionFunc = func(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { + rc.DescribeCollectionFunc = func(ctx context.Context, request *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) { return nil, errors.New("mock") } f := getDmlChannelsFunc(ctx, rc) @@ -67,7 +66,7 @@ func Test_getDmlChannelsFunc(t *testing.T) { t.Run("error code not success", func(t *testing.T) { ctx := context.Background() rc := newMockRootCoord() - rc.DescribeCollectionFunc = func(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { + rc.DescribeCollectionFunc = func(ctx context.Context, request *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) { 
return &milvuspb.DescribeCollectionResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}}, nil } f := getDmlChannelsFunc(ctx, rc) @@ -78,11 +77,12 @@ func Test_getDmlChannelsFunc(t *testing.T) { t.Run("normal case", func(t *testing.T) { ctx := context.Background() rc := newMockRootCoord() - rc.DescribeCollectionFunc = func(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { + rc.DescribeCollectionFunc = func(ctx context.Context, request *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) { return &milvuspb.DescribeCollectionResponse{ VirtualChannelNames: []string{"111", "222"}, PhysicalChannelNames: []string{"111", "111"}, - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}}, nil + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + }, nil } f := getDmlChannelsFunc(ctx, rc) got, err := f(100) diff --git a/internal/proxy/channels_time_ticker.go b/internal/proxy/channels_time_ticker.go index 4ce450603a5c7..c229010989627 100644 --- a/internal/proxy/channels_time_ticker.go +++ b/internal/proxy/channels_time_ticker.go @@ -219,7 +219,6 @@ func newChannelsTimeTicker( getStatisticsFunc getPChanStatisticsFuncType, tso tsoAllocator, ) *channelsTimeTickerImpl { - ctx1, cancel := context.WithCancel(ctx) ticker := &channelsTimeTickerImpl{ diff --git a/internal/proxy/channels_time_ticker_test.go b/internal/proxy/channels_time_ticker_test.go index 6c6d10d27495a..a5600e9a5faa9 100644 --- a/internal/proxy/channels_time_ticker_test.go +++ b/internal/proxy/channels_time_ticker_test.go @@ -23,12 +23,12 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/typeutil" - - "github.com/stretchr/testify/assert" - "go.uber.org/zap" ) func 
newGetStatisticsFunc(pchans []pChan) getPChanStatisticsFuncType { diff --git a/internal/proxy/client_info.go b/internal/proxy/client_info.go index a6bfcecffd279..5b60c92cce2df 100644 --- a/internal/proxy/client_info.go +++ b/internal/proxy/client_info.go @@ -4,9 +4,10 @@ import ( "context" "time" + "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/pkg/log" - "go.uber.org/zap" ) type clientInfo struct { diff --git a/internal/proxy/condition_test.go b/internal/proxy/condition_test.go index 90ceea30958f8..13dba5b651ffa 100644 --- a/internal/proxy/condition_test.go +++ b/internal/proxy/condition_test.go @@ -23,11 +23,10 @@ import ( "time" "github.com/cockroachdb/errors" - "github.com/stretchr/testify/assert" + "go.uber.org/zap" "github.com/milvus-io/milvus/pkg/log" - "go.uber.org/zap" ) func TestTaskCondition_Ctx(t *testing.T) { diff --git a/internal/proxy/connection_manager.go b/internal/proxy/connection_manager.go index 581d24e6a4d0d..de47517fb1b63 100644 --- a/internal/proxy/connection_manager.go +++ b/internal/proxy/connection_manager.go @@ -9,7 +9,6 @@ import ( "github.com/golang/protobuf/proto" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus/pkg/log" ) diff --git a/internal/proxy/connection_manager_test.go b/internal/proxy/connection_manager_test.go index 346542b2c041d..25f8c98b2793e 100644 --- a/internal/proxy/connection_manager_test.go +++ b/internal/proxy/connection_manager_test.go @@ -5,9 +5,9 @@ import ( "testing" "time" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" ) func Test_withDuration(t *testing.T) { diff --git a/internal/proxy/count_reducer.go b/internal/proxy/count_reducer.go index af0e5d6161183..41a476ab9950f 100644 --- a/internal/proxy/count_reducer.go +++ b/internal/proxy/count_reducer.go @@ -6,8 +6,7 @@ import ( 
"github.com/milvus-io/milvus/internal/util/funcutil" ) -type cntReducer struct { -} +type cntReducer struct{} func (r *cntReducer) Reduce(results []*internalpb.RetrieveResults) (*milvuspb.QueryResults, error) { cnt := int64(0) diff --git a/internal/proxy/count_reducer_test.go b/internal/proxy/count_reducer_test.go index d7a2f10d8823d..4e1fd436777a7 100644 --- a/internal/proxy/count_reducer_test.go +++ b/internal/proxy/count_reducer_test.go @@ -3,11 +3,11 @@ package proxy import ( "testing" - "github.com/milvus-io/milvus/internal/util/funcutil" "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/util/funcutil" ) func Test_cntReducer_Reduce(t *testing.T) { diff --git a/internal/proxy/data_coord_mock_test.go b/internal/proxy/data_coord_mock_test.go index 622d308b37954..5815437933c2d 100644 --- a/internal/proxy/data_coord_mock_test.go +++ b/internal/proxy/data_coord_mock_test.go @@ -19,6 +19,9 @@ package proxy import ( "context" + "go.uber.org/atomic" + "google.golang.org/grpc" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/proto/datapb" @@ -26,13 +29,13 @@ import ( "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/typeutil" "github.com/milvus-io/milvus/pkg/util/uniquegenerator" - "go.uber.org/atomic" ) type DataCoordMock struct { - types.DataCoord + types.DataCoordClient nodeID typeutil.UniqueID address string @@ -43,9 +46,9 @@ type DataCoordMock struct { showConfigurationsFunc showConfigurationsFuncType statisticsChannel string timeTickChannel string - checkHealthFunc func(ctx context.Context, req *milvuspb.CheckHealthRequest) 
(*milvuspb.CheckHealthResponse, error) - GetIndexStateFunc func(ctx context.Context, request *indexpb.GetIndexStateRequest) (*indexpb.GetIndexStateResponse, error) - DescribeIndexFunc func(ctx context.Context, request *indexpb.DescribeIndexRequest) (*indexpb.DescribeIndexResponse, error) + checkHealthFunc func(ctx context.Context, req *milvuspb.CheckHealthRequest, opts ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error) + GetIndexStateFunc func(ctx context.Context, request *indexpb.GetIndexStateRequest, opts ...grpc.CallOption) (*indexpb.GetIndexStateResponse, error) + DescribeIndexFunc func(ctx context.Context, request *indexpb.DescribeIndexRequest, opts ...grpc.CallOption) (*indexpb.DescribeIndexResponse, error) } func (coord *DataCoordMock) updateState(state commonpb.StateCode) { @@ -60,24 +63,7 @@ func (coord *DataCoordMock) healthy() bool { return coord.getState() == commonpb.StateCode_Healthy } -func (coord *DataCoordMock) Init() error { - coord.updateState(commonpb.StateCode_Initializing) - return nil -} - -func (coord *DataCoordMock) Start() error { - defer coord.updateState(commonpb.StateCode_Healthy) - - return nil -} - -func (coord *DataCoordMock) Stop() error { - defer coord.updateState(commonpb.StateCode_Abnormal) - - return nil -} - -func (coord *DataCoordMock) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { +func (coord *DataCoordMock) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { return &milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{ NodeID: coord.nodeID, @@ -86,20 +72,14 @@ func (coord *DataCoordMock) GetComponentStates(ctx context.Context) (*milvuspb.C ExtraInfo: nil, }, SubcomponentStates: nil, - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), }, nil } -func (coord *DataCoordMock) GetStatisticsChannel(ctx context.Context) 
(*milvuspb.StringResponse, error) { +func (coord *DataCoordMock) GetStatisticsChannel(ctx context.Context, req *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { return &milvuspb.StringResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, - Value: coord.statisticsChannel, + Status: merr.Success(), + Value: coord.statisticsChannel, }, nil } @@ -107,88 +87,85 @@ func (coord *DataCoordMock) Register() error { return nil } -func (coord *DataCoordMock) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (coord *DataCoordMock) GetTimeTickChannel(ctx context.Context, req *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { return &milvuspb.StringResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, - Value: coord.timeTickChannel, + Status: merr.Success(), + Value: coord.timeTickChannel, }, nil } -func (coord *DataCoordMock) Flush(ctx context.Context, req *datapb.FlushRequest) (*datapb.FlushResponse, error) { +func (coord *DataCoordMock) Flush(ctx context.Context, req *datapb.FlushRequest, opts ...grpc.CallOption) (*datapb.FlushResponse, error) { panic("implement me") } -func (coord *DataCoordMock) SaveImportSegment(ctx context.Context, req *datapb.SaveImportSegmentRequest) (*commonpb.Status, error) { +func (coord *DataCoordMock) SaveImportSegment(ctx context.Context, req *datapb.SaveImportSegmentRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { panic("implement me") } -func (coord *DataCoordMock) UnsetIsImportingState(context.Context, *datapb.UnsetIsImportingStateRequest) (*commonpb.Status, error) { +func (coord *DataCoordMock) UnsetIsImportingState(ctx context.Context, in *datapb.UnsetIsImportingStateRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { panic("implement me") } -func (coord *DataCoordMock) MarkSegmentsDropped(ctx 
context.Context, req *datapb.MarkSegmentsDroppedRequest) (*commonpb.Status, error) { +func (coord *DataCoordMock) MarkSegmentsDropped(ctx context.Context, req *datapb.MarkSegmentsDroppedRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { panic("implement me") } -func (coord *DataCoordMock) BroadcastAlteredCollection(ctx context.Context, req *datapb.AlterCollectionRequest) (*commonpb.Status, error) { +func (coord *DataCoordMock) BroadcastAlteredCollection(ctx context.Context, req *datapb.AlterCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { panic("implement me") } -func (coord *DataCoordMock) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { +func (coord *DataCoordMock) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest, opts ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error) { if coord.checkHealthFunc != nil { return coord.checkHealthFunc(ctx, req) } return &milvuspb.CheckHealthResponse{IsHealthy: true}, nil } -func (coord *DataCoordMock) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest) (*datapb.AssignSegmentIDResponse, error) { +func (coord *DataCoordMock) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest, opts ...grpc.CallOption) (*datapb.AssignSegmentIDResponse, error) { panic("implement me") } -func (coord *DataCoordMock) GetSegmentStates(ctx context.Context, req *datapb.GetSegmentStatesRequest) (*datapb.GetSegmentStatesResponse, error) { +func (coord *DataCoordMock) GetSegmentStates(ctx context.Context, req *datapb.GetSegmentStatesRequest, opts ...grpc.CallOption) (*datapb.GetSegmentStatesResponse, error) { panic("implement me") } -func (coord *DataCoordMock) GetInsertBinlogPaths(ctx context.Context, req *datapb.GetInsertBinlogPathsRequest) (*datapb.GetInsertBinlogPathsResponse, error) { +func (coord *DataCoordMock) GetInsertBinlogPaths(ctx context.Context, req *datapb.GetInsertBinlogPathsRequest, opts 
...grpc.CallOption) (*datapb.GetInsertBinlogPathsResponse, error) { panic("implement me") } -func (coord *DataCoordMock) GetSegmentInfoChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (coord *DataCoordMock) GetSegmentInfoChannel(ctx context.Context, in *datapb.GetSegmentInfoChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { panic("implement me") } -func (coord *DataCoordMock) GetCollectionStatistics(ctx context.Context, req *datapb.GetCollectionStatisticsRequest) (*datapb.GetCollectionStatisticsResponse, error) { +func (coord *DataCoordMock) GetCollectionStatistics(ctx context.Context, req *datapb.GetCollectionStatisticsRequest, opts ...grpc.CallOption) (*datapb.GetCollectionStatisticsResponse, error) { panic("implement me") } -func (coord *DataCoordMock) GetPartitionStatistics(ctx context.Context, req *datapb.GetPartitionStatisticsRequest) (*datapb.GetPartitionStatisticsResponse, error) { +func (coord *DataCoordMock) GetPartitionStatistics(ctx context.Context, req *datapb.GetPartitionStatisticsRequest, opts ...grpc.CallOption) (*datapb.GetPartitionStatisticsResponse, error) { panic("implement me") } -func (coord *DataCoordMock) GetSegmentInfo(ctx context.Context, req *datapb.GetSegmentInfoRequest) (*datapb.GetSegmentInfoResponse, error) { +func (coord *DataCoordMock) GetSegmentInfo(ctx context.Context, req *datapb.GetSegmentInfoRequest, opts ...grpc.CallOption) (*datapb.GetSegmentInfoResponse, error) { panic("implement me") } -func (coord *DataCoordMock) GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInfoRequest) (*datapb.GetRecoveryInfoResponse, error) { +func (coord *DataCoordMock) GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInfoRequest, opts ...grpc.CallOption) (*datapb.GetRecoveryInfoResponse, error) { panic("implement me") } -func (coord *DataCoordMock) SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPathsRequest) (*commonpb.Status, error) { +func (coord *DataCoordMock) 
SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPathsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { panic("implement me") } -func (coord *DataCoordMock) GetFlushedSegments(ctx context.Context, req *datapb.GetFlushedSegmentsRequest) (*datapb.GetFlushedSegmentsResponse, error) { +func (coord *DataCoordMock) GetFlushedSegments(ctx context.Context, req *datapb.GetFlushedSegmentsRequest, opts ...grpc.CallOption) (*datapb.GetFlushedSegmentsResponse, error) { panic("implement me") } -func (coord *DataCoordMock) GetSegmentsByStates(ctx context.Context, req *datapb.GetSegmentsByStatesRequest) (*datapb.GetSegmentsByStatesResponse, error) { +func (coord *DataCoordMock) GetSegmentsByStates(ctx context.Context, req *datapb.GetSegmentsByStatesRequest, opts ...grpc.CallOption) (*datapb.GetSegmentsByStatesResponse, error) { panic("implement me") } -func (coord *DataCoordMock) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { +func (coord *DataCoordMock) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest, opts ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error) { if !coord.healthy() { return &internalpb.ShowConfigurationsResponse{ Status: &commonpb.Status{ @@ -210,7 +187,7 @@ func (coord *DataCoordMock) ShowConfigurations(ctx context.Context, req *interna }, nil } -func (coord *DataCoordMock) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { +func (coord *DataCoordMock) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { if !coord.healthy() { return &milvuspb.GetMetricsResponse{ Status: &commonpb.Status{ @@ -234,136 +211,113 @@ func (coord *DataCoordMock) GetMetrics(ctx context.Context, req *milvuspb.GetMet }, nil } -func (coord *DataCoordMock) CompleteCompaction(ctx context.Context, req 
*datapb.CompactionResult) (*commonpb.Status, error) { +func (coord *DataCoordMock) CompleteCompaction(ctx context.Context, req *datapb.CompactionResult, opts ...grpc.CallOption) (*commonpb.Status, error) { return &commonpb.Status{}, nil } -func (coord *DataCoordMock) ManualCompaction(ctx context.Context, req *milvuspb.ManualCompactionRequest) (*milvuspb.ManualCompactionResponse, error) { +func (coord *DataCoordMock) ManualCompaction(ctx context.Context, req *milvuspb.ManualCompactionRequest, opts ...grpc.CallOption) (*milvuspb.ManualCompactionResponse, error) { return &milvuspb.ManualCompactionResponse{}, nil } -func (coord *DataCoordMock) GetCompactionState(ctx context.Context, req *milvuspb.GetCompactionStateRequest) (*milvuspb.GetCompactionStateResponse, error) { +func (coord *DataCoordMock) GetCompactionState(ctx context.Context, req *milvuspb.GetCompactionStateRequest, opts ...grpc.CallOption) (*milvuspb.GetCompactionStateResponse, error) { return &milvuspb.GetCompactionStateResponse{}, nil } -func (coord *DataCoordMock) GetCompactionStateWithPlans(ctx context.Context, req *milvuspb.GetCompactionPlansRequest) (*milvuspb.GetCompactionPlansResponse, error) { +func (coord *DataCoordMock) GetCompactionStateWithPlans(ctx context.Context, req *milvuspb.GetCompactionPlansRequest, opts ...grpc.CallOption) (*milvuspb.GetCompactionPlansResponse, error) { return &milvuspb.GetCompactionPlansResponse{}, nil } -func (coord *DataCoordMock) WatchChannels(ctx context.Context, req *datapb.WatchChannelsRequest) (*datapb.WatchChannelsResponse, error) { +func (coord *DataCoordMock) WatchChannels(ctx context.Context, req *datapb.WatchChannelsRequest, opts ...grpc.CallOption) (*datapb.WatchChannelsResponse, error) { return &datapb.WatchChannelsResponse{}, nil } -func (coord *DataCoordMock) GetFlushState(ctx context.Context, req *milvuspb.GetFlushStateRequest) (*milvuspb.GetFlushStateResponse, error) { +func (coord *DataCoordMock) GetFlushState(ctx context.Context, req 
*datapb.GetFlushStateRequest, opts ...grpc.CallOption) (*milvuspb.GetFlushStateResponse, error) { return &milvuspb.GetFlushStateResponse{}, nil } -func (coord *DataCoordMock) GetFlushAllState(ctx context.Context, req *milvuspb.GetFlushAllStateRequest) (*milvuspb.GetFlushAllStateResponse, error) { +func (coord *DataCoordMock) GetFlushAllState(ctx context.Context, req *milvuspb.GetFlushAllStateRequest, opts ...grpc.CallOption) (*milvuspb.GetFlushAllStateResponse, error) { return &milvuspb.GetFlushAllStateResponse{}, nil } -func (coord *DataCoordMock) DropVirtualChannel(ctx context.Context, req *datapb.DropVirtualChannelRequest) (*datapb.DropVirtualChannelResponse, error) { +func (coord *DataCoordMock) DropVirtualChannel(ctx context.Context, req *datapb.DropVirtualChannelRequest, opts ...grpc.CallOption) (*datapb.DropVirtualChannelResponse, error) { return &datapb.DropVirtualChannelResponse{}, nil } -func (coord *DataCoordMock) SetSegmentState(ctx context.Context, req *datapb.SetSegmentStateRequest) (*datapb.SetSegmentStateResponse, error) { +func (coord *DataCoordMock) SetSegmentState(ctx context.Context, req *datapb.SetSegmentStateRequest, opts ...grpc.CallOption) (*datapb.SetSegmentStateResponse, error) { return &datapb.SetSegmentStateResponse{}, nil } -func (coord *DataCoordMock) Import(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { +func (coord *DataCoordMock) Import(ctx context.Context, req *datapb.ImportTaskRequest, opts ...grpc.CallOption) (*datapb.ImportTaskResponse, error) { return &datapb.ImportTaskResponse{}, nil } -func (coord *DataCoordMock) UpdateSegmentStatistics(ctx context.Context, req *datapb.UpdateSegmentStatisticsRequest) (*commonpb.Status, error) { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, nil +func (coord *DataCoordMock) UpdateSegmentStatistics(ctx context.Context, req *datapb.UpdateSegmentStatisticsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) 
{ + return merr.Success(), nil } -func (coord *DataCoordMock) UpdateChannelCheckpoint(ctx context.Context, req *datapb.UpdateChannelCheckpointRequest) (*commonpb.Status, error) { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, nil +func (coord *DataCoordMock) UpdateChannelCheckpoint(ctx context.Context, req *datapb.UpdateChannelCheckpointRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return merr.Success(), nil } -func (coord *DataCoordMock) CreateIndex(ctx context.Context, req *indexpb.CreateIndexRequest) (*commonpb.Status, error) { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, nil +func (coord *DataCoordMock) CreateIndex(ctx context.Context, req *indexpb.CreateIndexRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return merr.Success(), nil } -func (coord *DataCoordMock) DropIndex(ctx context.Context, req *indexpb.DropIndexRequest) (*commonpb.Status, error) { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, nil +func (coord *DataCoordMock) DropIndex(ctx context.Context, req *indexpb.DropIndexRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return merr.Success(), nil } -func (coord *DataCoordMock) GetIndexState(ctx context.Context, req *indexpb.GetIndexStateRequest) (*indexpb.GetIndexStateResponse, error) { +func (coord *DataCoordMock) GetIndexState(ctx context.Context, req *indexpb.GetIndexStateRequest, opts ...grpc.CallOption) (*indexpb.GetIndexStateResponse, error) { return &indexpb.GetIndexStateResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), State: commonpb.IndexState_Finished, FailReason: "", }, nil } // GetSegmentIndexState gets the index state of the segments in the request from RootCoord. 
-func (coord *DataCoordMock) GetSegmentIndexState(ctx context.Context, req *indexpb.GetSegmentIndexStateRequest) (*indexpb.GetSegmentIndexStateResponse, error) { +func (coord *DataCoordMock) GetSegmentIndexState(ctx context.Context, req *indexpb.GetSegmentIndexStateRequest, opts ...grpc.CallOption) (*indexpb.GetSegmentIndexStateResponse, error) { return &indexpb.GetSegmentIndexStateResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), }, nil } // GetIndexInfos gets the index files of the IndexBuildIDs in the request from RootCoordinator. -func (coord *DataCoordMock) GetIndexInfos(ctx context.Context, req *indexpb.GetIndexInfoRequest) (*indexpb.GetIndexInfoResponse, error) { +func (coord *DataCoordMock) GetIndexInfos(ctx context.Context, req *indexpb.GetIndexInfoRequest, opts ...grpc.CallOption) (*indexpb.GetIndexInfoResponse, error) { return &indexpb.GetIndexInfoResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), }, nil } // DescribeIndex describe the index info of the collection. -func (coord *DataCoordMock) DescribeIndex(ctx context.Context, req *indexpb.DescribeIndexRequest) (*indexpb.DescribeIndexResponse, error) { +func (coord *DataCoordMock) DescribeIndex(ctx context.Context, req *indexpb.DescribeIndexRequest, opts ...grpc.CallOption) (*indexpb.DescribeIndexResponse, error) { return &indexpb.DescribeIndexResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), IndexInfos: nil, }, nil } // GetIndexStatistics get the statistics of the index. 
-func (coord *DataCoordMock) GetIndexStatistics(ctx context.Context, req *indexpb.GetIndexStatisticsRequest) (*indexpb.GetIndexStatisticsResponse, error) { +func (coord *DataCoordMock) GetIndexStatistics(ctx context.Context, req *indexpb.GetIndexStatisticsRequest, opts ...grpc.CallOption) (*indexpb.GetIndexStatisticsResponse, error) { return &indexpb.GetIndexStatisticsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), IndexInfos: nil, }, nil } // GetIndexBuildProgress get the index building progress by num rows. -func (coord *DataCoordMock) GetIndexBuildProgress(ctx context.Context, req *indexpb.GetIndexBuildProgressRequest) (*indexpb.GetIndexBuildProgressResponse, error) { +func (coord *DataCoordMock) GetIndexBuildProgress(ctx context.Context, req *indexpb.GetIndexBuildProgressRequest, opts ...grpc.CallOption) (*indexpb.GetIndexBuildProgressResponse, error) { return &indexpb.GetIndexBuildProgressResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, nil } +func (coord *DataCoordMock) Close() error { + return nil +} + func NewDataCoordMock() *DataCoordMock { - return &DataCoordMock{ + dc := &DataCoordMock{ nodeID: typeutil.UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()), address: funcutil.GenRandomStr(), // random address state: atomic.Value{}, @@ -371,4 +325,6 @@ func NewDataCoordMock() *DataCoordMock { statisticsChannel: funcutil.GenRandomStr(), timeTickChannel: funcutil.GenRandomStr(), } + dc.updateState(commonpb.StateCode_Healthy) + return dc } diff --git a/internal/proxy/database_interceptor.go b/internal/proxy/database_interceptor.go index 52eb30ff84e39..fa97af0f4709d 100644 --- a/internal/proxy/database_interceptor.go +++ b/internal/proxy/database_interceptor.go @@ -238,6 +238,11 @@ func fillDatabase(ctx context.Context, req interface{}) (context.Context, interf r.DbName = GetCurDBNameFromContextOrDefault(ctx) } return ctx, r + 
case *milvuspb.GetFlushStateRequest: + if r.DbName == "" { + r.DbName = GetCurDBNameFromContextOrDefault(ctx) + } + return ctx, r default: return ctx, req } diff --git a/internal/proxy/database_interceptor_test.go b/internal/proxy/database_interceptor_test.go index 8a901e820ba4e..77f62c68431ca 100644 --- a/internal/proxy/database_interceptor_test.go +++ b/internal/proxy/database_interceptor_test.go @@ -74,6 +74,7 @@ func TestDatabaseInterceptor(t *testing.T) { &milvuspb.DeleteRequest{}, &milvuspb.SearchRequest{}, &milvuspb.FlushRequest{}, + &milvuspb.GetFlushStateRequest{}, &milvuspb.QueryRequest{}, &milvuspb.CreateAliasRequest{}, &milvuspb.DropAliasRequest{}, @@ -113,7 +114,6 @@ func TestDatabaseInterceptor(t *testing.T) { &milvuspb.GetCompactionStateRequest{}, &milvuspb.ManualCompactionRequest{}, &milvuspb.GetCompactionPlansRequest{}, - &milvuspb.GetFlushStateRequest{}, &milvuspb.GetFlushAllStateRequest{}, &milvuspb.GetImportStateRequest{}, } @@ -133,5 +133,4 @@ func TestDatabaseInterceptor(t *testing.T) { } } }) - } diff --git a/internal/proxy/default_limit_reducer.go b/internal/proxy/default_limit_reducer.go index d74440567bfc3..0f70f49bd1ccb 100644 --- a/internal/proxy/default_limit_reducer.go +++ b/internal/proxy/default_limit_reducer.go @@ -3,7 +3,6 @@ package proxy import ( "context" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/internalpb" @@ -39,42 +38,38 @@ func (r *defaultLimitReducer) afterReduce(result *milvuspb.QueryResults) error { outputFieldsID := r.req.GetOutputFieldsId() result.CollectionName = collectionName + var err error - if len(result.FieldsData) > 0 { - result.Status = merr.Status(nil) - } else { - result.Status = &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_EmptyCollection, - Reason: "empty collection", // TODO - } - return nil - } - - for i := 0; i < 
len(result.FieldsData); i++ { + for i := 0; i < len(result.GetFieldsData()); i++ { + // drop ts column if outputFieldsID[i] == common.TimeStampField { result.FieldsData = append(result.FieldsData[:i], result.FieldsData[(i+1):]...) + outputFieldsID = append(outputFieldsID[:i], outputFieldsID[i+1:]...) i-- continue } - for _, field := range schema.Fields { - if field.FieldID == outputFieldsID[i] { - // deal with the situation that offset equal to or greater than the number of entities - if result.FieldsData[i] == nil { - var err error - result.FieldsData[i], err = typeutil.GenEmptyFieldData(field) - if err != nil { - return err - } - } - result.FieldsData[i].FieldName = field.Name - result.FieldsData[i].FieldId = field.FieldID - result.FieldsData[i].Type = field.DataType - result.FieldsData[i].IsDynamic = field.IsDynamic + field := typeutil.GetField(schema, outputFieldsID[i]) + if field == nil { + err = merr.WrapErrFieldNotFound(outputFieldsID[i]) + break + } + + if result.FieldsData[i] == nil { + result.FieldsData[i], err = typeutil.GenEmptyFieldData(field) + if err != nil { + break } + continue } + + result.FieldsData[i].FieldName = field.GetName() + result.FieldsData[i].FieldId = field.GetFieldID() + result.FieldsData[i].Type = field.GetDataType() + result.FieldsData[i].IsDynamic = field.GetIsDynamic() } - return nil + result.Status = merr.Status(err) + return err } func newDefaultLimitReducer(ctx context.Context, params *queryParams, req *internalpb.RetrieveRequest, schema *schemapb.CollectionSchema, collectionName string) *defaultLimitReducer { diff --git a/internal/proxy/dummyreq_test.go b/internal/proxy/dummyreq_test.go index ee1d1fa552032..a86ec6ad83d33 100644 --- a/internal/proxy/dummyreq_test.go +++ b/internal/proxy/dummyreq_test.go @@ -20,10 +20,10 @@ import ( "encoding/json" "testing" - "github.com/milvus-io/milvus/pkg/log" + "github.com/stretchr/testify/assert" "go.uber.org/zap" - "github.com/stretchr/testify/assert" + 
"github.com/milvus-io/milvus/pkg/log" ) func Test_parseDummyRequestType(t *testing.T) { diff --git a/internal/proxy/error.go b/internal/proxy/error.go deleted file mode 100644 index 88558c3ad00a3..0000000000000 --- a/internal/proxy/error.go +++ /dev/null @@ -1,42 +0,0 @@ -// Licensed to the LF AI & Data foundation under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package proxy - -import ( - "fmt" - - "github.com/cockroachdb/errors" - - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" -) - -// Keep this error temporarily -// this error belongs to ErrServiceMemoryLimitExceeded -// but in the error returned by querycoord,the collection id is given -// which can not be thrown out -// the error will be deleted after reaching an agreement on collection name and id in qn - -// ErrInsufficientMemory returns insufficient memory error. -var ErrInsufficientMemory = errors.New("InsufficientMemoryToLoad") - -// InSufficientMemoryStatus returns insufficient memory status. 
-func InSufficientMemoryStatus(collectionName string) *commonpb.Status { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_InsufficientMemoryToLoad, - Reason: fmt.Sprintf("deny to load, insufficient memory, please allocate more resources, collectionName: %s", collectionName), - } -} diff --git a/internal/proxy/error_test.go b/internal/proxy/error_test.go deleted file mode 100644 index d5641f4e29457..0000000000000 --- a/internal/proxy/error_test.go +++ /dev/null @@ -1,35 +0,0 @@ -// Licensed to the LF AI & Data foundation under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package proxy - -import ( - "fmt" - "testing" - - "github.com/cockroachdb/errors" - "github.com/stretchr/testify/assert" - - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" -) - -func Test_ErrInsufficientMemory(t *testing.T) { - err := fmt.Errorf("%w, mock insufficient memory error", ErrInsufficientMemory) - assert.True(t, errors.Is(err, ErrInsufficientMemory)) - - status := InSufficientMemoryStatus("collection1") - assert.Equal(t, commonpb.ErrorCode_InsufficientMemoryToLoad, status.GetErrorCode()) -} diff --git a/internal/proxy/hook_interceptor.go b/internal/proxy/hook_interceptor.go index a86e640cf65cc..fa4a314e44c36 100644 --- a/internal/proxy/hook_interceptor.go +++ b/internal/proxy/hook_interceptor.go @@ -7,16 +7,21 @@ import ( "strconv" "strings" + "github.com/cockroachdb/errors" + "go.uber.org/zap" + "google.golang.org/grpc" + "github.com/milvus-io/milvus-proto/go-api/v2/hook" "github.com/milvus-io/milvus/pkg/config" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/paramtable" - "go.uber.org/zap" - "google.golang.org/grpc" ) -type defaultHook struct { +type defaultHook struct{} + +func (d defaultHook) VerifyAPIKey(key string) (string, error) { + return "", errors.New("default hook, can't verify api key") } func (d defaultHook) Init(params map[string]string) error { @@ -123,10 +128,11 @@ func UnaryServerHookInterceptor() grpc.UnaryServerInterceptor { } func updateProxyFunctionCallMetric(fullMethod string) { - if fullMethod == "" { + strs := strings.Split(fullMethod, "/") + method := strs[len(strs)-1] + if method == "" { return } - method := strings.Split(fullMethod, "/")[0] metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.TotalLabel).Inc() metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() } @@ -138,3 +144,25 @@ func getCurrentUser(ctx context.Context) 
string { } return username } + +// MockAPIHook is a mock hook for api key verification, ONLY FOR TEST +type MockAPIHook struct { + defaultHook + mockErr error + apiUser string +} + +func (m MockAPIHook) VerifyAPIKey(apiKey string) (string, error) { + return m.apiUser, m.mockErr +} + +func SetMockAPIHook(apiUser string, mockErr error) { + if apiUser == "" && mockErr == nil { + hoo = defaultHook{} + return + } + hoo = MockAPIHook{ + mockErr: mockErr, + apiUser: apiUser, + } +} diff --git a/internal/proxy/hook_interceptor_test.go b/internal/proxy/hook_interceptor_test.go index b0341fdceeecc..a387053b8eeb4 100644 --- a/internal/proxy/hook_interceptor_test.go +++ b/internal/proxy/hook_interceptor_test.go @@ -5,11 +5,10 @@ import ( "testing" "github.com/cockroachdb/errors" - + "github.com/stretchr/testify/assert" "google.golang.org/grpc" "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/stretchr/testify/assert" ) func TestInitHook(t *testing.T) { @@ -143,7 +142,19 @@ func TestHookInterceptor(t *testing.T) { func TestDefaultHook(t *testing.T) { d := defaultHook{} assert.NoError(t, d.Init(nil)) + { + _, err := d.VerifyAPIKey("key") + assert.Error(t, err) + } assert.NotPanics(t, func() { d.Release() }) } + +func TestUpdateProxyFunctionCallMetric(t *testing.T) { + assert.NotPanics(t, func() { + updateProxyFunctionCallMetric("/milvus.proto.milvus.MilvusService/Flush") + updateProxyFunctionCallMetric("Flush") + updateProxyFunctionCallMetric("") + }) +} diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index cf53c925697fe..ff9e578d96c72 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -18,6 +18,7 @@ package proxy import ( "context" + "encoding/base64" "fmt" "os" "strconv" @@ -48,7 +49,6 @@ import ( "github.com/milvus-io/milvus/pkg/util" "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/crypto" - "github.com/milvus-io/milvus/pkg/util/errorutil" "github.com/milvus-io/milvus/pkg/util/funcutil" 
"github.com/milvus-io/milvus/pkg/util/logutil" "github.com/milvus-io/milvus/pkg/util/merr" @@ -63,25 +63,12 @@ const moduleName = "Proxy" const SlowReadSpan = time.Second * 5 -// UpdateStateCode updates the state code of Proxy. -func (node *Proxy) UpdateStateCode(code commonpb.StateCode) { - node.stateCode.Store(code) -} - -// GetComponentStates get state of Proxy. -func (node *Proxy) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { +// GetComponentStates gets the state of Proxy. +func (node *Proxy) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) { stats := &milvuspb.ComponentStates{ - Status: merr.Status(nil), - } - code, ok := node.stateCode.Load().(commonpb.StateCode) - if !ok { - errMsg := "unexpected error in type assertion" - stats.Status = &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: errMsg, - } - return stats, nil + Status: merr.Success(), } + code := node.GetStateCode() nodeID := common.NotRegisteredID if node.session != nil && node.session.Registered() { nodeID = node.session.ServerID @@ -97,17 +84,17 @@ func (node *Proxy) GetComponentStates(ctx context.Context) (*milvuspb.ComponentS } // GetStatisticsChannel gets statistics channel of Proxy. -func (node *Proxy) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (node *Proxy) GetStatisticsChannel(ctx context.Context, req *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error) { return &milvuspb.StringResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Value: "", }, nil } // InvalidateCollectionMetaCache invalidate the meta cache of specific collection. 
func (node *Proxy) InvalidateCollectionMetaCache(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { - if !node.checkHealthy() { - return unhealthyStatus(), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } ctx = logutil.WithModule(ctx, moduleName) @@ -117,7 +104,8 @@ func (node *Proxy) InvalidateCollectionMetaCache(ctx context.Context, request *p zap.String("role", typeutil.ProxyRole), zap.String("db", request.DbName), zap.String("collectionName", request.CollectionName), - zap.Int64("collectionID", request.CollectionID)) + zap.Int64("collectionID", request.CollectionID), + ) log.Info("received request to invalidate collection meta cache") @@ -144,12 +132,12 @@ func (node *Proxy) InvalidateCollectionMetaCache(ctx context.Context, request *p } log.Info("complete to invalidate collection meta cache") - return merr.Status(nil), nil + return merr.Success(), nil } func (node *Proxy) CreateDatabase(ctx context.Context, request *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) { - if !node.checkHealthy() { - return unhealthyStatus(), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-CreateDatabase") @@ -158,18 +146,25 @@ func (node *Proxy) CreateDatabase(ctx context.Context, request *milvuspb.CreateD method := "CreateDatabase" tr := timerecord.NewTimeRecorder(method) - metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.TotalLabel).Inc() + metrics.ProxyFunctionCall.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + method, + metrics.TotalLabel, + ).Inc() cct := &createDatabaseTask{ ctx: ctx, Condition: NewTaskCondition(ctx), CreateDatabaseRequest: request, rootCoord: node.rootCoord, + replicateMsgStream: node.replicateMsgStream, } - log := log.With(zap.String("traceID", 
sp.SpanContext().TraceID().String()), + log := log.With( + zap.String("traceID", sp.SpanContext().TraceID().String()), zap.String("role", typeutil.ProxyRole), - zap.String("dbName", request.DbName)) + zap.String("dbName", request.DbName), + ) log.Info(rpcReceived(method)) if err := node.sched.ddQueue.Enqueue(cct); err != nil { @@ -188,14 +183,23 @@ func (node *Proxy) CreateDatabase(ctx context.Context, request *milvuspb.CreateD } log.Info(rpcDone(method)) - metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.SuccessLabel).Inc() - metrics.ProxyReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method).Observe(float64(tr.ElapseSpan().Milliseconds())) + metrics.ProxyFunctionCall.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + method, + metrics.SuccessLabel, + ).Inc() + + metrics.ProxyReqLatency.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + method, + ).Observe(float64(tr.ElapseSpan().Milliseconds())) + return cct.result, nil } func (node *Proxy) DropDatabase(ctx context.Context, request *milvuspb.DropDatabaseRequest) (*commonpb.Status, error) { - if !node.checkHealthy() { - return unhealthyStatus(), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-DropDatabase") @@ -203,18 +207,25 @@ func (node *Proxy) DropDatabase(ctx context.Context, request *milvuspb.DropDatab method := "DropDatabase" tr := timerecord.NewTimeRecorder(method) - metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.TotalLabel).Inc() + metrics.ProxyFunctionCall.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + method, + metrics.TotalLabel, + ).Inc() dct := &dropDatabaseTask{ ctx: ctx, Condition: NewTaskCondition(ctx), DropDatabaseRequest: request, rootCoord: node.rootCoord, + replicateMsgStream: 
node.replicateMsgStream, } - log := log.With(zap.String("traceID", sp.SpanContext().TraceID().String()), + log := log.With( + zap.String("traceID", sp.SpanContext().TraceID().String()), zap.String("role", typeutil.ProxyRole), - zap.String("dbName", request.DbName)) + zap.String("dbName", request.DbName), + ) log.Info(rpcReceived(method)) if err := node.sched.ddQueue.Enqueue(dct); err != nil { @@ -231,15 +242,24 @@ func (node *Proxy) DropDatabase(ctx context.Context, request *milvuspb.DropDatab } log.Info(rpcDone(method)) - metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.SuccessLabel).Inc() - metrics.ProxyReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method).Observe(float64(tr.ElapseSpan().Milliseconds())) + metrics.ProxyFunctionCall.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + method, + metrics.SuccessLabel, + ).Inc() + + metrics.ProxyReqLatency.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + method, + ).Observe(float64(tr.ElapseSpan().Milliseconds())) + return dct.result, nil } func (node *Proxy) ListDatabases(ctx context.Context, request *milvuspb.ListDatabasesRequest) (*milvuspb.ListDatabasesResponse, error) { resp := &milvuspb.ListDatabasesResponse{} - if !node.checkHealthy() { - resp.Status = unhealthyStatus() + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + resp.Status = merr.Status(err) return resp, nil } @@ -248,7 +268,11 @@ func (node *Proxy) ListDatabases(ctx context.Context, request *milvuspb.ListData method := "ListDatabases" tr := timerecord.NewTimeRecorder(method) - metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.TotalLabel).Inc() + metrics.ProxyFunctionCall.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + method, + metrics.TotalLabel, + ).Inc() dct := &listDatabaseTask{ ctx: ctx, @@ -257,8 +281,10 @@ func (node *Proxy) 
ListDatabases(ctx context.Context, request *milvuspb.ListData rootCoord: node.rootCoord, } - log := log.With(zap.String("traceID", sp.SpanContext().TraceID().String()), - zap.String("role", typeutil.ProxyRole)) + log := log.With( + zap.String("traceID", sp.SpanContext().TraceID().String()), + zap.String("role", typeutil.ProxyRole), + ) log.Info(rpcReceived(method)) @@ -278,16 +304,25 @@ func (node *Proxy) ListDatabases(ctx context.Context, request *milvuspb.ListData } log.Info(rpcDone(method), zap.Int("num of db", len(dct.result.DbNames))) - metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.SuccessLabel).Inc() - metrics.ProxyReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method).Observe(float64(tr.ElapseSpan().Milliseconds())) + metrics.ProxyFunctionCall.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + method, + metrics.SuccessLabel, + ).Inc() + + metrics.ProxyReqLatency.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + method, + ).Observe(float64(tr.ElapseSpan().Milliseconds())) + return dct.result, nil } // CreateCollection create a collection by the schema. // TODO(dragondriver): add more detailed ut for ConsistencyLevel, should we support multiple consistency level in Proxy? 
func (node *Proxy) CreateCollection(ctx context.Context, request *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) { - if !node.checkHealthy() { - return unhealthyStatus(), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-CreateCollection") @@ -295,7 +330,11 @@ func (node *Proxy) CreateCollection(ctx context.Context, request *milvuspb.Creat method := "CreateCollection" tr := timerecord.NewTimeRecorder(method) - metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.TotalLabel).Inc() + metrics.ProxyFunctionCall.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + method, + metrics.TotalLabel, + ).Inc() cct := &createCollectionTask{ ctx: ctx, @@ -313,7 +352,8 @@ func (node *Proxy) CreateCollection(ctx context.Context, request *milvuspb.Creat zap.String("collection", request.CollectionName), zap.Int("len(schema)", lenOfSchema), zap.Int32("shards_num", request.ShardsNum), - zap.String("consistency_level", request.ConsistencyLevel.String())) + zap.String("consistency_level", request.ConsistencyLevel.String()), + ) log.Debug(rpcReceived(method)) @@ -330,7 +370,8 @@ func (node *Proxy) CreateCollection(ctx context.Context, request *milvuspb.Creat rpcEnqueued(method), zap.Uint64("BeginTs", cct.BeginTs()), zap.Uint64("EndTs", cct.EndTs()), - zap.Uint64("timestamp", request.Base.Timestamp)) + zap.Uint64("timestamp", request.Base.Timestamp), + ) if err := cct.WaitToFinish(); err != nil { log.Warn( @@ -346,24 +387,37 @@ func (node *Proxy) CreateCollection(ctx context.Context, request *milvuspb.Creat log.Debug( rpcDone(method), zap.Uint64("BeginTs", cct.BeginTs()), - zap.Uint64("EndTs", cct.EndTs())) + zap.Uint64("EndTs", cct.EndTs()), + ) + + metrics.ProxyFunctionCall.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + method, + metrics.SuccessLabel, + ).Inc() + 
metrics.ProxyReqLatency.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + method, + ).Observe(float64(tr.ElapseSpan().Milliseconds())) - metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.SuccessLabel).Inc() - metrics.ProxyReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method).Observe(float64(tr.ElapseSpan().Milliseconds())) return cct.result, nil } // DropCollection drop a collection. func (node *Proxy) DropCollection(ctx context.Context, request *milvuspb.DropCollectionRequest) (*commonpb.Status, error) { - if !node.checkHealthy() { - return unhealthyStatus(), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-DropCollection") defer sp.End() method := "DropCollection" tr := timerecord.NewTimeRecorder(method) - metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.TotalLabel).Inc() + metrics.ProxyFunctionCall.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + method, + metrics.TotalLabel, + ).Inc() dct := &dropCollectionTask{ ctx: ctx, @@ -377,7 +431,8 @@ func (node *Proxy) DropCollection(ctx context.Context, request *milvuspb.DropCol log := log.Ctx(ctx).With( zap.String("role", typeutil.ProxyRole), zap.String("db", request.DbName), - zap.String("collection", request.CollectionName)) + zap.String("collection", request.CollectionName), + ) log.Debug("DropCollection received") @@ -389,9 +444,11 @@ func (node *Proxy) DropCollection(ctx context.Context, request *milvuspb.DropCol return merr.Status(err), nil } - log.Debug("DropCollection enqueued", + log.Debug( + "DropCollection enqueued", zap.Uint64("BeginTs", dct.BeginTs()), - zap.Uint64("EndTs", dct.EndTs())) + zap.Uint64("EndTs", dct.EndTs()), + ) if err := dct.WaitToFinish(); err != nil { log.Warn("DropCollection failed to WaitToFinish", @@ 
-403,20 +460,30 @@ func (node *Proxy) DropCollection(ctx context.Context, request *milvuspb.DropCol return merr.Status(err), nil } - log.Debug("DropCollection done", + log.Debug( + "DropCollection done", zap.Uint64("BeginTs", dct.BeginTs()), - zap.Uint64("EndTs", dct.EndTs())) + zap.Uint64("EndTs", dct.EndTs()), + ) + + metrics.ProxyFunctionCall.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + method, + metrics.SuccessLabel, + ).Inc() + metrics.ProxyReqLatency.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + method, + ).Observe(float64(tr.ElapseSpan().Milliseconds())) - metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.SuccessLabel).Inc() - metrics.ProxyReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method).Observe(float64(tr.ElapseSpan().Milliseconds())) return dct.result, nil } // HasCollection check if the specific collection exists in Milvus. func (node *Proxy) HasCollection(ctx context.Context, request *milvuspb.HasCollectionRequest) (*milvuspb.BoolResponse, error) { - if !node.checkHealthy() { + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { return &milvuspb.BoolResponse{ - Status: unhealthyStatus(), + Status: merr.Status(err), }, nil } @@ -424,13 +491,17 @@ func (node *Proxy) HasCollection(ctx context.Context, request *milvuspb.HasColle defer sp.End() method := "HasCollection" tr := timerecord.NewTimeRecorder(method) - metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, - metrics.TotalLabel).Inc() + metrics.ProxyFunctionCall.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + method, + metrics.TotalLabel, + ).Inc() log := log.Ctx(ctx).With( zap.String("role", typeutil.ProxyRole), zap.String("db", request.DbName), - zap.String("collection", request.CollectionName)) + zap.String("collection", request.CollectionName), + ) log.Debug("HasCollection received") @@ -452,9 
+523,11 @@ func (node *Proxy) HasCollection(ctx context.Context, request *milvuspb.HasColle }, nil } - log.Debug("HasCollection enqueued", + log.Debug( + "HasCollection enqueued", zap.Uint64("BeginTS", hct.BeginTs()), - zap.Uint64("EndTS", hct.EndTs())) + zap.Uint64("EndTS", hct.EndTs()), + ) if err := hct.WaitToFinish(); err != nil { log.Warn("HasCollection failed to WaitToFinish", @@ -469,41 +542,56 @@ func (node *Proxy) HasCollection(ctx context.Context, request *milvuspb.HasColle }, nil } - log.Debug("HasCollection done", + log.Debug( + "HasCollection done", zap.Uint64("BeginTS", hct.BeginTs()), - zap.Uint64("EndTS", hct.EndTs())) + zap.Uint64("EndTS", hct.EndTs()), + ) + + metrics.ProxyFunctionCall.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + method, + metrics.SuccessLabel, + ).Inc() + metrics.ProxyReqLatency.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + method, + ).Observe(float64(tr.ElapseSpan().Milliseconds())) - metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, - metrics.SuccessLabel).Inc() - metrics.ProxyReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method).Observe(float64(tr.ElapseSpan().Milliseconds())) return hct.result, nil } // LoadCollection load a collection into query nodes. 
func (node *Proxy) LoadCollection(ctx context.Context, request *milvuspb.LoadCollectionRequest) (*commonpb.Status, error) { - if !node.checkHealthy() { - return unhealthyStatus(), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-LoadCollection") defer sp.End() method := "LoadCollection" tr := timerecord.NewTimeRecorder(method) - metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, - metrics.TotalLabel).Inc() + metrics.ProxyFunctionCall.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + method, + metrics.TotalLabel, + ).Inc() + lct := &loadCollectionTask{ ctx: ctx, Condition: NewTaskCondition(ctx), LoadCollectionRequest: request, queryCoord: node.queryCoord, datacoord: node.dataCoord, + replicateMsgStream: node.replicateMsgStream, } log := log.Ctx(ctx).With( zap.String("role", typeutil.ProxyRole), zap.String("db", request.DbName), zap.String("collection", request.CollectionName), - zap.Bool("refreshMode", request.Refresh)) + zap.Bool("refreshMode", request.Refresh), + ) log.Debug("LoadCollection received") @@ -516,9 +604,11 @@ func (node *Proxy) LoadCollection(ctx context.Context, request *milvuspb.LoadCol return merr.Status(err), nil } - log.Debug("LoadCollection enqueued", + log.Debug( + "LoadCollection enqueued", zap.Uint64("BeginTS", lct.BeginTs()), - zap.Uint64("EndTS", lct.EndTs())) + zap.Uint64("EndTS", lct.EndTs()), + ) if err := lct.WaitToFinish(); err != nil { log.Warn("LoadCollection failed to WaitToFinish", @@ -530,20 +620,29 @@ func (node *Proxy) LoadCollection(ctx context.Context, request *milvuspb.LoadCol return merr.Status(err), nil } - log.Debug("LoadCollection done", + log.Debug( + "LoadCollection done", zap.Uint64("BeginTS", lct.BeginTs()), - zap.Uint64("EndTS", lct.EndTs())) + zap.Uint64("EndTS", lct.EndTs()), + ) + + metrics.ProxyFunctionCall.WithLabelValues( + 
strconv.FormatInt(paramtable.GetNodeID(), 10), + method, + metrics.SuccessLabel, + ).Inc() + metrics.ProxyReqLatency.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + method, + ).Observe(float64(tr.ElapseSpan().Milliseconds())) - metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, - metrics.SuccessLabel).Inc() - metrics.ProxyReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method).Observe(float64(tr.ElapseSpan().Milliseconds())) return lct.result, nil } // ReleaseCollection remove the loaded collection from query nodes. func (node *Proxy) ReleaseCollection(ctx context.Context, request *milvuspb.ReleaseCollectionRequest) (*commonpb.Status, error) { - if !node.checkHealthy() { - return unhealthyStatus(), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-ReleaseCollection") @@ -557,7 +656,7 @@ func (node *Proxy) ReleaseCollection(ctx context.Context, request *milvuspb.Rele Condition: NewTaskCondition(ctx), ReleaseCollectionRequest: request, queryCoord: node.queryCoord, - chMgr: node.chMgr, + replicateMsgStream: node.replicateMsgStream, } log := log.Ctx(ctx).With( @@ -607,9 +706,9 @@ func (node *Proxy) ReleaseCollection(ctx context.Context, request *milvuspb.Rele // DescribeCollection get the meta information of specific collection, such as schema, created timestamp and etc. 
func (node *Proxy) DescribeCollection(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { - if !node.checkHealthy() { + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { return &milvuspb.DescribeCollectionResponse{ - Status: unhealthyStatus(), + Status: merr.Status(err), }, nil } @@ -679,9 +778,9 @@ func (node *Proxy) DescribeCollection(ctx context.Context, request *milvuspb.Des // GetStatistics get the statistics, such as `num_rows`. // WARNING: It is an experimental API func (node *Proxy) GetStatistics(ctx context.Context, request *milvuspb.GetStatisticsRequest) (*milvuspb.GetStatisticsResponse, error) { - if !node.checkHealthy() { + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { return &milvuspb.GetStatisticsResponse{ - Status: unhealthyStatus(), + Status: merr.Status(err), }, nil } @@ -759,9 +858,9 @@ func (node *Proxy) GetStatistics(ctx context.Context, request *milvuspb.GetStati // GetCollectionStatistics get the collection statistics, such as `num_rows`. func (node *Proxy) GetCollectionStatistics(ctx context.Context, request *milvuspb.GetCollectionStatisticsRequest) (*milvuspb.GetCollectionStatisticsResponse, error) { - if !node.checkHealthy() { + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { return &milvuspb.GetCollectionStatisticsResponse{ - Status: unhealthyStatus(), + Status: merr.Status(err), }, nil } @@ -831,9 +930,9 @@ func (node *Proxy) GetCollectionStatistics(ctx context.Context, request *milvusp // ShowCollections list all collections in Milvus. 
func (node *Proxy) ShowCollections(ctx context.Context, request *milvuspb.ShowCollectionsRequest) (*milvuspb.ShowCollectionsResponse, error) { - if !node.checkHealthy() { + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { return &milvuspb.ShowCollectionsResponse{ - Status: unhealthyStatus(), + Status: merr.Status(err), }, nil } ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-ShowCollections") @@ -897,8 +996,8 @@ func (node *Proxy) ShowCollections(ctx context.Context, request *milvuspb.ShowCo } func (node *Proxy) AlterCollection(ctx context.Context, request *milvuspb.AlterCollectionRequest) (*commonpb.Status, error) { - if !node.checkHealthy() { - return unhealthyStatus(), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-AlterCollection") @@ -961,8 +1060,8 @@ func (node *Proxy) AlterCollection(ctx context.Context, request *milvuspb.AlterC // CreatePartition create a partition in specific collection. func (node *Proxy) CreatePartition(ctx context.Context, request *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) { - if !node.checkHealthy() { - return unhealthyStatus(), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-CreatePartition") @@ -1026,8 +1125,8 @@ func (node *Proxy) CreatePartition(ctx context.Context, request *milvuspb.Create // DropPartition drop a partition in specific collection. 
func (node *Proxy) DropPartition(ctx context.Context, request *milvuspb.DropPartitionRequest) (*commonpb.Status, error) { - if !node.checkHealthy() { - return unhealthyStatus(), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-DropPartition") @@ -1092,9 +1191,9 @@ func (node *Proxy) DropPartition(ctx context.Context, request *milvuspb.DropPart // HasPartition check if partition exist. func (node *Proxy) HasPartition(ctx context.Context, request *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error) { - if !node.checkHealthy() { + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { return &milvuspb.BoolResponse{ - Status: unhealthyStatus(), + Status: merr.Status(err), }, nil } @@ -1102,7 +1201,7 @@ func (node *Proxy) HasPartition(ctx context.Context, request *milvuspb.HasPartit defer sp.End() method := "HasPartition" tr := timerecord.NewTimeRecorder(method) - //TODO: use collectionID instead of collectionName + // TODO: use collectionID instead of collectionName metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.TotalLabel).Inc() @@ -1170,8 +1269,8 @@ func (node *Proxy) HasPartition(ctx context.Context, request *milvuspb.HasPartit // LoadPartitions load specific partitions into query nodes. func (node *Proxy) LoadPartitions(ctx context.Context, request *milvuspb.LoadPartitionsRequest) (*commonpb.Status, error) { - if !node.checkHealthy() { - return unhealthyStatus(), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-LoadPartitions") @@ -1239,8 +1338,8 @@ func (node *Proxy) LoadPartitions(ctx context.Context, request *milvuspb.LoadPar // ReleasePartitions release specific partitions from query nodes. 
func (node *Proxy) ReleasePartitions(ctx context.Context, request *milvuspb.ReleasePartitionsRequest) (*commonpb.Status, error) { - if !node.checkHealthy() { - return unhealthyStatus(), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-ReleasePartitions") @@ -1308,9 +1407,9 @@ func (node *Proxy) ReleasePartitions(ctx context.Context, request *milvuspb.Rele // GetPartitionStatistics get the statistics of partition, such as num_rows. func (node *Proxy) GetPartitionStatistics(ctx context.Context, request *milvuspb.GetPartitionStatisticsRequest) (*milvuspb.GetPartitionStatisticsResponse, error) { - if !node.checkHealthy() { + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { return &milvuspb.GetPartitionStatisticsResponse{ - Status: unhealthyStatus(), + Status: merr.Status(err), }, nil } @@ -1382,9 +1481,9 @@ func (node *Proxy) GetPartitionStatistics(ctx context.Context, request *milvuspb // ShowPartitions list all partitions in the specific collection. 
func (node *Proxy) ShowPartitions(ctx context.Context, request *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) { - if !node.checkHealthy() { + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { return &milvuspb.ShowPartitionsResponse{ - Status: unhealthyStatus(), + Status: merr.Status(err), }, nil } @@ -1402,7 +1501,7 @@ func (node *Proxy) ShowPartitions(ctx context.Context, request *milvuspb.ShowPar method := "ShowPartitions" tr := timerecord.NewTimeRecorder(method) - //TODO: use collectionID instead of collectionName + // TODO: use collectionID instead of collectionName metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.TotalLabel).Inc() @@ -1467,8 +1566,8 @@ func (node *Proxy) ShowPartitions(ctx context.Context, request *milvuspb.ShowPar } func (node *Proxy) GetLoadingProgress(ctx context.Context, request *milvuspb.GetLoadingProgressRequest) (*milvuspb.GetLoadingProgressResponse, error) { - if !node.checkHealthy() { - return &milvuspb.GetLoadingProgressResponse{Status: unhealthyStatus()}, nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return &milvuspb.GetLoadingProgressResponse{Status: merr.Status(err)}, nil } method := "GetLoadingProgress" tr := timerecord.NewTimeRecorder(method) @@ -1483,13 +1582,13 @@ func (node *Proxy) GetLoadingProgress(ctx context.Context, request *milvuspb.Get getErrResponse := func(err error) *milvuspb.GetLoadingProgressResponse { log.Warn("fail to get loading progress", - zap.String("collection_name", request.CollectionName), - zap.Strings("partition_name", request.PartitionNames), + zap.String("collectionName", request.CollectionName), + zap.Strings("partitionName", request.PartitionNames), zap.Error(err)) metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() - if errors.Is(err, ErrInsufficientMemory) { + if errors.Is(err, 
merr.ErrServiceMemoryLimitExceeded) { return &milvuspb.GetLoadingProgressResponse{ - Status: InSufficientMemoryStatus(request.GetCollectionName()), + Status: merr.Status(err), } } return &milvuspb.GetLoadingProgressResponse{ @@ -1538,15 +1637,15 @@ func (node *Proxy) GetLoadingProgress(ctx context.Context, request *milvuspb.Get metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.SuccessLabel).Inc() metrics.ProxyReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method).Observe(float64(tr.ElapseSpan().Milliseconds())) return &milvuspb.GetLoadingProgressResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Progress: loadProgress, RefreshProgress: refreshProgress, }, nil } func (node *Proxy) GetLoadState(ctx context.Context, request *milvuspb.GetLoadStateRequest) (*milvuspb.GetLoadStateResponse, error) { - if !node.checkHealthy() { - return &milvuspb.GetLoadStateResponse{Status: unhealthyStatus()}, nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return &milvuspb.GetLoadStateResponse{Status: merr.Status(err)}, nil } method := "GetLoadState" tr := timerecord.NewTimeRecorder(method) @@ -1574,16 +1673,8 @@ func (node *Proxy) GetLoadState(ctx context.Context, request *milvuspb.GetLoadSt return getErrResponse(err), nil } - // TODO(longjiquan): https://github.com/milvus-io/milvus/issues/21485, Remove `GetComponentStates` after error code - // is ready to distinguish case whether the querycoord is not healthy or the collection is not even loaded. 
- if statesResp, err := node.queryCoord.GetComponentStates(ctx); err != nil { - return getErrResponse(err), nil - } else if statesResp.State == nil || statesResp.State.StateCode != commonpb.StateCode_Healthy { - return getErrResponse(fmt.Errorf("the querycoord server isn't healthy, state: %v", statesResp.State)), nil - } - successResponse := &milvuspb.GetLoadStateResponse{ - Status: merr.Status(nil), + Status: merr.Success(), } defer func() { log.Debug( @@ -1615,24 +1706,30 @@ func (node *Proxy) GetLoadState(ctx context.Context, request *milvuspb.GetLoadSt var progress int64 if len(request.GetPartitionNames()) == 0 { if progress, _, err = getCollectionProgress(ctx, node.queryCoord, request.GetBase(), collectionID); err != nil { - if errors.Is(err, ErrInsufficientMemory) { + if err != nil { + if errors.Is(err, merr.ErrCollectionNotLoaded) { + successResponse.State = commonpb.LoadState_LoadStateNotLoad + return successResponse, nil + } return &milvuspb.GetLoadStateResponse{ - Status: InSufficientMemoryStatus(request.GetCollectionName()), + Status: merr.Status(err), }, nil } - successResponse.State = commonpb.LoadState_LoadStateNotLoad - return successResponse, nil } } else { if progress, _, err = getPartitionProgress(ctx, node.queryCoord, request.GetBase(), request.GetPartitionNames(), request.GetCollectionName(), collectionID, request.GetDbName()); err != nil { - if errors.Is(err, ErrInsufficientMemory) { + if err != nil { + if errors.IsAny(err, + merr.ErrCollectionNotLoaded, + merr.ErrPartitionNotLoaded) { + successResponse.State = commonpb.LoadState_LoadStateNotLoad + return successResponse, nil + } return &milvuspb.GetLoadStateResponse{ - Status: InSufficientMemoryStatus(request.GetCollectionName()), + Status: merr.Status(err), }, nil } - successResponse.State = commonpb.LoadState_LoadStateNotLoad - return successResponse, nil } } if progress >= 100 { @@ -1645,19 +1742,20 @@ func (node *Proxy) GetLoadState(ctx context.Context, request *milvuspb.GetLoadSt // 
CreateIndex create index for collection. func (node *Proxy) CreateIndex(ctx context.Context, request *milvuspb.CreateIndexRequest) (*commonpb.Status, error) { - if !node.checkHealthy() { - return unhealthyStatus(), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-CreateIndex") defer sp.End() cit := &createIndexTask{ - ctx: ctx, - Condition: NewTaskCondition(ctx), - req: request, - rootCoord: node.rootCoord, - datacoord: node.dataCoord, + ctx: ctx, + Condition: NewTaskCondition(ctx), + req: request, + rootCoord: node.rootCoord, + datacoord: node.dataCoord, + replicateMsgStream: node.replicateMsgStream, } method := "CreateIndex" @@ -1716,9 +1814,9 @@ func (node *Proxy) CreateIndex(ctx context.Context, request *milvuspb.CreateInde // DescribeIndex get the meta information of index, such as index state, index id and etc. func (node *Proxy) DescribeIndex(ctx context.Context, request *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error) { - if !node.checkHealthy() { + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { return &milvuspb.DescribeIndexResponse{ - Status: unhealthyStatus(), + Status: merr.Status(err), }, nil } @@ -1772,18 +1870,11 @@ func (node *Proxy) DescribeIndex(ctx context.Context, request *milvuspb.Describe zap.Uint64("BeginTs", dit.BeginTs()), zap.Uint64("EndTs", dit.EndTs())) - errCode := commonpb.ErrorCode_UnexpectedError - if dit.result != nil { - errCode = dit.result.Status.GetErrorCode() - } metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() return &milvuspb.DescribeIndexResponse{ - Status: &commonpb.Status{ - ErrorCode: errCode, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -1800,8 +1891,7 @@ func (node *Proxy) DescribeIndex(ctx context.Context, request *milvuspb.Describe // GetIndexStatistics get the information 
of index. func (node *Proxy) GetIndexStatistics(ctx context.Context, request *milvuspb.GetIndexStatisticsRequest) (*milvuspb.GetIndexStatisticsResponse, error) { - if !node.checkHealthy() { - err := merr.WrapErrServiceNotReady(fmt.Sprintf("proxy %d is unhealthy", paramtable.GetNodeID())) + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { return &milvuspb.GetIndexStatisticsResponse{ Status: merr.Status(err), }, nil @@ -1851,16 +1941,9 @@ func (node *Proxy) GetIndexStatistics(ctx context.Context, request *milvuspb.Get if err := dit.WaitToFinish(); err != nil { log.Warn(rpcFailedToWaitToFinish(method), zap.Error(err), zap.Uint64("BeginTs", dit.BeginTs()), zap.Uint64("EndTs", dit.EndTs())) - errCode := commonpb.ErrorCode_UnexpectedError - if dit.result != nil { - errCode = dit.result.Status.GetErrorCode() - } metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(node.session.ServerID, 10), method, metrics.FailLabel).Inc() return &milvuspb.GetIndexStatisticsResponse{ - Status: &commonpb.Status{ - ErrorCode: errCode, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -1878,19 +1961,20 @@ func (node *Proxy) GetIndexStatistics(ctx context.Context, request *milvuspb.Get // DropIndex drop the index of collection. 
func (node *Proxy) DropIndex(ctx context.Context, request *milvuspb.DropIndexRequest) (*commonpb.Status, error) { - if !node.checkHealthy() { - return unhealthyStatus(), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-DropIndex") defer sp.End() dit := &dropIndexTask{ - ctx: ctx, - Condition: NewTaskCondition(ctx), - DropIndexRequest: request, - dataCoord: node.dataCoord, - queryCoord: node.queryCoord, + ctx: ctx, + Condition: NewTaskCondition(ctx), + DropIndexRequest: request, + dataCoord: node.dataCoord, + queryCoord: node.queryCoord, + replicateMsgStream: node.replicateMsgStream, } method := "DropIndex" @@ -1950,9 +2034,9 @@ func (node *Proxy) DropIndex(ctx context.Context, request *milvuspb.DropIndexReq // IndexRows is the num of indexed rows. And TotalRows is the total number of segment rows. // Deprecated: use DescribeIndex instead func (node *Proxy) GetIndexBuildProgress(ctx context.Context, request *milvuspb.GetIndexBuildProgressRequest) (*milvuspb.GetIndexBuildProgressResponse, error) { - if !node.checkHealthy() { + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { return &milvuspb.GetIndexBuildProgressResponse{ - Status: unhealthyStatus(), + Status: merr.Status(err), }, nil } @@ -2026,9 +2110,9 @@ func (node *Proxy) GetIndexBuildProgress(ctx context.Context, request *milvuspb. // GetIndexState get the build-state of index. 
// Deprecated: use DescribeIndex instead func (node *Proxy) GetIndexState(ctx context.Context, request *milvuspb.GetIndexStateRequest) (*milvuspb.GetIndexStateResponse, error) { - if !node.checkHealthy() { + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { return &milvuspb.GetIndexStateResponse{ - Status: unhealthyStatus(), + Status: merr.Status(err), }, nil } @@ -2105,9 +2189,9 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest) ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Insert") defer sp.End() - if !node.checkHealthy() { + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { return &milvuspb.MutationResult{ - Status: unhealthyStatus(), + Status: merr.Status(err), }, nil } log := log.Ctx(ctx).With( @@ -2184,7 +2268,7 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest) return constructFailedResponse(err), nil } - if it.result.Status.ErrorCode != commonpb.ErrorCode_Success { + if it.result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { setErrorIndex := func() { numRows := request.NumRows errIndex := make([]uint32, numRows) @@ -2201,8 +2285,7 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest) // InsertCnt always equals to the number of entities in the request it.result.InsertCnt = int64(request.NumRows) - receiveSize := proto.Size(it.insertMsg) - rateCol.Add(internalpb.RateType_DMLInsert.String(), float64(receiveSize)) + rateCol.Add(internalpb.RateType_DMLInsert.String(), float64(it.insertMsg.Size())) metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.SuccessLabel).Inc() @@ -2231,9 +2314,9 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest) strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.DeleteLabel, request.GetCollectionName()).Add(float64(proto.Size(request))) - if !node.checkHealthy() { + if err := 
merr.CheckHealthy(node.GetStateCode()); err != nil { return &milvuspb.MutationResult{ - Status: unhealthyStatus(), + Status: merr.Status(err), }, nil } @@ -2243,27 +2326,13 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest) metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.TotalLabel).Inc() dt := &deleteTask{ - ctx: ctx, - Condition: NewTaskCondition(ctx), - deleteExpr: request.Expr, - deleteMsg: &BaseDeleteTask{ - BaseMsg: msgstream.BaseMsg{ - HashValues: request.HashKeys, - }, - DeleteRequest: msgpb.DeleteRequest{ - Base: commonpbutil.NewMsgBase( - commonpbutil.WithMsgType(commonpb.MsgType_Delete), - commonpbutil.WithMsgID(0), - ), - DbName: request.DbName, - CollectionName: request.CollectionName, - PartitionName: request.PartitionName, - // RowData: transfer column based request to this - }, - }, + ctx: ctx, + Condition: NewTaskCondition(ctx), + req: request, idAllocator: node.rowIDAllocator, chMgr: node.chMgr, chTicker: node.chTicker, + lb: node.lbPolicy, } log.Debug("Enqueue delete request in Proxy") @@ -2290,7 +2359,7 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest) }, nil } - receiveSize := proto.Size(dt.deleteMsg) + receiveSize := proto.Size(dt.req) rateCol.Add(internalpb.RateType_DMLDelete.String(), float64(receiveSize)) metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, @@ -2314,9 +2383,9 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest) ) log.Debug("Start processing upsert request in Proxy") - if !node.checkHealthy() { + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { return &milvuspb.MutationResult{ - Status: unhealthyStatus(), + Status: merr.Status(err), }, nil } method := "Upsert" @@ -2340,7 +2409,7 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest) Condition: NewTaskCondition(ctx), req: request, result: 
&milvuspb.MutationResult{ - Status: merr.Status(nil), + Status: merr.Success(), IDs: &schemapb.IDs{ IdField: nil, }, @@ -2352,22 +2421,6 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest) chTicker: node.chTicker, } - constructFailedResponse := func(err error, errCode commonpb.ErrorCode) *milvuspb.MutationResult { - numRows := request.NumRows - errIndex := make([]uint32, numRows) - for i := uint32(0); i < numRows; i++ { - errIndex[i] = i - } - - return &milvuspb.MutationResult{ - Status: &commonpb.Status{ - ErrorCode: errCode, - Reason: err.Error(), - }, - ErrIndex: errIndex, - } - } - log.Debug("Enqueue upsert request in Proxy", zap.Int("len(FieldsData)", len(request.FieldsData)), zap.Int("len(HashKeys)", len(request.HashKeys))) @@ -2393,13 +2446,23 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest) metrics.FailLabel).Inc() // Not every error case changes the status internally // change status there to handle it - if it.result.Status.ErrorCode == commonpb.ErrorCode_Success { - it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError + if it.result.GetStatus().GetErrorCode() == commonpb.ErrorCode_Success { + it.result.Status = merr.Status(err) + } + + numRows := request.NumRows + errIndex := make([]uint32, numRows) + for i := uint32(0); i < numRows; i++ { + errIndex[i] = i } - return constructFailedResponse(err, it.result.Status.ErrorCode), nil + + return &milvuspb.MutationResult{ + Status: merr.Status(err), + ErrIndex: errIndex, + }, nil } - if it.result.Status.ErrorCode != commonpb.ErrorCode_Success { + if it.result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { setErrorIndex := func() { numRows := request.NumRows errIndex := make([]uint32, numRows) @@ -2411,10 +2474,7 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest) setErrorIndex() } - insertReceiveSize := proto.Size(it.upsertMsg.InsertMsg) - deleteReceiveSize := proto.Size(it.upsertMsg.DeleteMsg) 
- - rateCol.Add(internalpb.RateType_DMLUpsert.String(), float64(insertReceiveSize+deleteReceiveSize)) + rateCol.Add(internalpb.RateType_DMLUpsert.String(), float64(it.upsertMsg.DeleteMsg.Size()+it.upsertMsg.DeleteMsg.Size())) metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.SuccessLabel).Inc() @@ -2430,27 +2490,46 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest) receiveSize := proto.Size(request) metrics.ProxyReceiveBytes.WithLabelValues( strconv.FormatInt(paramtable.GetNodeID(), 10), - metrics.SearchLabel, request.GetCollectionName()).Add(float64(receiveSize)) + metrics.SearchLabel, + request.GetCollectionName(), + ).Add(float64(receiveSize)) metrics.ProxyReceivedNQ.WithLabelValues( strconv.FormatInt(paramtable.GetNodeID(), 10), - metrics.SearchLabel, request.GetCollectionName()).Add(float64(request.GetNq())) + metrics.SearchLabel, + request.GetCollectionName(), + ).Add(float64(request.GetNq())) rateCol.Add(internalpb.RateType_DQLSearch.String(), float64(request.GetNq())) - if !node.checkHealthy() { + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { return &milvuspb.SearchResults{ - Status: unhealthyStatus(), + Status: merr.Status(err), }, nil } + method := "Search" tr := timerecord.NewTimeRecorder(method) - metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, - metrics.TotalLabel).Inc() + metrics.ProxyFunctionCall.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + method, + metrics.TotalLabel, + ).Inc() ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Search") defer sp.End() + if request.SearchByPrimaryKeys { + placeholderGroupBytes, err := node.getVectorPlaceholderGroupForSearchByPks(ctx, request) + if err != nil { + return &milvuspb.SearchResults{ + Status: merr.Status(err), + }, nil + } + + request.PlaceholderGroup = placeholderGroupBytes + } + qt := &searchTask{ ctx: ctx, Condition: 
NewTaskCondition(ctx), @@ -2479,7 +2558,8 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest) zap.Any("len(PlaceholderGroup)", len(request.PlaceholderGroup)), zap.Any("OutputFields", request.OutputFields), zap.Any("search_params", request.SearchParams), - zap.Uint64("guarantee_timestamp", guaranteeTs)) + zap.Uint64("guarantee_timestamp", guaranteeTs), + ) defer func() { span := tr.ElapseSpan() @@ -2493,10 +2573,14 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest) if err := node.sched.dqQueue.Enqueue(qt); err != nil { log.Warn( rpcFailedToEnqueue(method), - zap.Error(err)) + zap.Error(err), + ) - metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, - metrics.AbandonLabel).Inc() + metrics.ProxyFunctionCall.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + method, + metrics.AbandonLabel, + ).Inc() return &milvuspb.SearchResults{ Status: merr.Status(err), @@ -2506,15 +2590,20 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest) log.Debug( rpcEnqueued(method), - zap.Uint64("timestamp", qt.Base.Timestamp)) + zap.Uint64("timestamp", qt.Base.Timestamp), + ) if err := qt.WaitToFinish(); err != nil { log.Warn( rpcFailedToWaitToFinish(method), - zap.Error(err)) + zap.Error(err), + ) - metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, - metrics.FailLabel).Inc() + metrics.ProxyFunctionCall.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + method, + metrics.FailLabel, + ).Inc() return &milvuspb.SearchResults{ Status: merr.Status(err), @@ -2522,19 +2611,34 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest) } span := tr.CtxRecord(ctx, "wait search result") - metrics.ProxyWaitForSearchResultLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), - metrics.SearchLabel).Observe(float64(span.Milliseconds())) + 
metrics.ProxyWaitForSearchResultLatency.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + metrics.SearchLabel, + ).Observe(float64(span.Milliseconds())) + tr.CtxRecord(ctx, "wait search result") log.Debug(rpcDone(method)) - metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, - metrics.SuccessLabel).Inc() + metrics.ProxyFunctionCall.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + method, + metrics.SuccessLabel, + ).Inc() + metrics.ProxySearchVectors.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(qt.result.GetResults().GetNumQueries())) + searchDur := tr.ElapseSpan().Milliseconds() - metrics.ProxySQLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), - metrics.SearchLabel).Observe(float64(searchDur)) - metrics.ProxyCollectionSQLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), - metrics.SearchLabel, request.CollectionName).Observe(float64(searchDur)) + metrics.ProxySQLatency.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + metrics.SearchLabel, + ).Observe(float64(searchDur)) + + metrics.ProxyCollectionSQLatency.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + metrics.SearchLabel, + request.CollectionName, + ).Observe(float64(searchDur)) + if qt.result != nil { sentSize := proto.Size(qt.result) metrics.ProxyReadReqSendBytes.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(sentSize)) @@ -2543,16 +2647,68 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest) return qt.result, nil } +func (node *Proxy) getVectorPlaceholderGroupForSearchByPks(ctx context.Context, request *milvuspb.SearchRequest) ([]byte, error) { + placeholderGroup := &commonpb.PlaceholderGroup{} + err := proto.Unmarshal(request.PlaceholderGroup, placeholderGroup) + if err != nil { + return nil, err + } + + if len(placeholderGroup.Placeholders) != 1 || 
len(placeholderGroup.Placeholders[0].Values) != 1 { + return nil, merr.WrapErrParameterInvalidMsg("please provide primary key") + } + queryExpr := string(placeholderGroup.Placeholders[0].Values[0]) + + annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, request.SearchParams) + if err != nil { + return nil, err + } + + queryRequest := &milvuspb.QueryRequest{ + Base: request.Base, + DbName: request.DbName, + CollectionName: request.CollectionName, + Expr: queryExpr, + OutputFields: []string{annsField}, + PartitionNames: request.PartitionNames, + TravelTimestamp: request.TravelTimestamp, + GuaranteeTimestamp: request.GuaranteeTimestamp, + QueryParams: nil, + NotReturnAllMeta: request.NotReturnAllMeta, + ConsistencyLevel: request.ConsistencyLevel, + UseDefaultConsistency: request.UseDefaultConsistency, + } + + queryResults, _ := node.Query(ctx, queryRequest) + + err = merr.Error(queryResults.GetStatus()) + if err != nil { + return nil, err + } + + var vectorFieldsData *schemapb.FieldData + for _, fieldsData := range queryResults.GetFieldsData() { + if fieldsData.GetFieldName() == annsField { + vectorFieldsData = fieldsData + break + } + } + + placeholderGroupBytes, err := funcutil.FieldDataToPlaceholderGroupBytes(vectorFieldsData) + if err != nil { + return nil, err + } + + return placeholderGroupBytes, nil +} + // Flush notify data nodes to persist the data of collection. 
func (node *Proxy) Flush(ctx context.Context, request *milvuspb.FlushRequest) (*milvuspb.FlushResponse, error) { resp := &milvuspb.FlushResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "", - }, + Status: merr.Success(), } - if !node.checkHealthy() { - resp.Status.Reason = "proxy is not healthy" + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + resp.Status = merr.Status(err) return resp, nil } @@ -2560,10 +2716,11 @@ func (node *Proxy) Flush(ctx context.Context, request *milvuspb.FlushRequest) (* defer sp.End() ft := &flushTask{ - ctx: ctx, - Condition: NewTaskCondition(ctx), - FlushRequest: request, - dataCoord: node.dataCoord, + ctx: ctx, + Condition: NewTaskCondition(ctx), + FlushRequest: request, + dataCoord: node.dataCoord, + replicateMsgStream: node.replicateMsgStream, } method := "Flush" @@ -2577,14 +2734,14 @@ func (node *Proxy) Flush(ctx context.Context, request *milvuspb.FlushRequest) (* log.Debug(rpcReceived(method)) - if err := node.sched.ddQueue.Enqueue(ft); err != nil { + if err := node.sched.dcQueue.Enqueue(ft); err != nil { log.Warn( rpcFailedToEnqueue(method), zap.Error(err)) metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.AbandonLabel).Inc() - resp.Status.Reason = err.Error() + resp.Status = merr.Status(err) return resp, nil } @@ -2602,8 +2759,7 @@ func (node *Proxy) Flush(ctx context.Context, request *milvuspb.FlushRequest) (* metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() - resp.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError - resp.Status.Reason = err.Error() + resp.Status = merr.Status(err) return resp, nil } @@ -2622,17 +2778,21 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (* receiveSize := proto.Size(request) metrics.ProxyReceiveBytes.WithLabelValues( strconv.FormatInt(paramtable.GetNodeID(), 10), - 
metrics.QueryLabel, request.GetCollectionName()).Add(float64(receiveSize)) + metrics.QueryLabel, + request.GetCollectionName(), + ).Add(float64(receiveSize)) metrics.ProxyReceivedNQ.WithLabelValues( strconv.FormatInt(paramtable.GetNodeID(), 10), - metrics.SearchLabel, request.GetCollectionName()).Add(float64(1)) + metrics.SearchLabel, + request.GetCollectionName(), + ).Add(float64(1)) rateCol.Add(internalpb.RateType_DQLQuery.String(), 1) - if !node.checkHealthy() { + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { return &milvuspb.QueryResults{ - Status: unhealthyStatus(), + Status: merr.Status(err), }, nil } @@ -2657,14 +2817,18 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (* method := "Query" - metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, - metrics.TotalLabel).Inc() + metrics.ProxyFunctionCall.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + method, + metrics.TotalLabel, + ).Inc() log := log.Ctx(ctx).With( zap.String("role", typeutil.ProxyRole), zap.String("db", request.DbName), zap.String("collection", request.CollectionName), - zap.Strings("partitions", request.PartitionNames)) + zap.Strings("partitions", request.PartitionNames), + ) defer func() { span := tr.ElapseSpan() @@ -2684,15 +2848,20 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (* zap.String("expr", request.Expr), zap.Strings("OutputFields", request.OutputFields), zap.Uint64("travel_timestamp", request.TravelTimestamp), - zap.Uint64("guarantee_timestamp", request.GuaranteeTimestamp)) + zap.Uint64("guarantee_timestamp", request.GuaranteeTimestamp), + ) if err := node.sched.dqQueue.Enqueue(qt); err != nil { log.Warn( rpcFailedToEnqueue(method), - zap.Error(err)) + zap.Error(err), + ) - metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, - metrics.AbandonLabel).Inc() + 
metrics.ProxyFunctionCall.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + method, + metrics.AbandonLabel, + ).Inc() return &milvuspb.QueryResults{ Status: merr.Status(err), @@ -2715,28 +2884,41 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (* }, nil } span := tr.CtxRecord(ctx, "wait query result") - metrics.ProxyWaitForSearchResultLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), - metrics.QueryLabel).Observe(float64(span.Milliseconds())) + metrics.ProxyWaitForSearchResultLatency.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + metrics.QueryLabel, + ).Observe(float64(span.Milliseconds())) log.Debug(rpcDone(method)) - metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, - metrics.SuccessLabel).Inc() + metrics.ProxyFunctionCall.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + method, + metrics.SuccessLabel, + ).Inc() + + metrics.ProxySQLatency.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + metrics.QueryLabel, + ).Observe(float64(tr.ElapseSpan().Milliseconds())) + + metrics.ProxyCollectionSQLatency.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + metrics.QueryLabel, + request.CollectionName, + ).Observe(float64(tr.ElapseSpan().Milliseconds())) - metrics.ProxySQLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), - metrics.QueryLabel).Observe(float64(tr.ElapseSpan().Milliseconds())) - metrics.ProxyCollectionSQLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), - metrics.QueryLabel, request.CollectionName).Observe(float64(tr.ElapseSpan().Milliseconds())) sentSize := proto.Size(qt.result) rateCol.Add(metricsinfo.ReadResultThroughput, float64(sentSize)) metrics.ProxyReadReqSendBytes.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(sentSize)) + return qt.result, nil } // CreateAlias create alias for collection, then 
you can search the collection with alias. func (node *Proxy) CreateAlias(ctx context.Context, request *milvuspb.CreateAliasRequest) (*commonpb.Status, error) { - if !node.checkHealthy() { - return unhealthyStatus(), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-CreateAlias") @@ -2799,26 +2981,20 @@ func (node *Proxy) CreateAlias(ctx context.Context, request *milvuspb.CreateAlia func (node *Proxy) DescribeAlias(ctx context.Context, request *milvuspb.DescribeAliasRequest) (*milvuspb.DescribeAliasResponse, error) { return &milvuspb.DescribeAliasResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "TODO: implement me", - }, + Status: merr.Status(merr.WrapErrServiceUnavailable("DescribeAlias unimplemented")), }, nil } func (node *Proxy) ListAliases(ctx context.Context, request *milvuspb.ListAliasesRequest) (*milvuspb.ListAliasesResponse, error) { return &milvuspb.ListAliasesResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "TODO: implement me", - }, + Status: merr.Status(merr.WrapErrServiceUnavailable("ListAliases unimplemented")), }, nil } // DropAlias alter the alias of collection. func (node *Proxy) DropAlias(ctx context.Context, request *milvuspb.DropAliasRequest) (*commonpb.Status, error) { - if !node.checkHealthy() { - return unhealthyStatus(), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-DropAlias") @@ -2880,8 +3056,8 @@ func (node *Proxy) DropAlias(ctx context.Context, request *milvuspb.DropAliasReq // AlterAlias alter alias of collection. 
func (node *Proxy) AlterAlias(ctx context.Context, request *milvuspb.AlterAliasRequest) (*commonpb.Status, error) { - if !node.checkHealthy() { - return unhealthyStatus(), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-AlterAlias") @@ -2945,10 +3121,7 @@ func (node *Proxy) AlterAlias(ctx context.Context, request *milvuspb.AlterAliasR // CalcDistance calculates the distances between vectors. func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDistanceRequest) (*milvuspb.CalcDistanceResults, error) { return &milvuspb.CalcDistanceResults{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "interface obsolete", - }, + Status: merr.Status(merr.WrapErrServiceUnavailable("CalcDistance deprecated")), }, nil } @@ -2959,10 +3132,10 @@ func (node *Proxy) FlushAll(ctx context.Context, req *milvuspb.FlushAllRequest) log := log.With(zap.String("db", req.GetDbName())) resp := &milvuspb.FlushAllResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, + Status: merr.Success(), } - if !node.checkHealthy() { - resp.Status.Reason = "proxy is not healthy" + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + resp.Status = merr.Status(err) return resp, nil } log.Info(rpcReceived("FlushAll")) @@ -2970,10 +3143,10 @@ func (node *Proxy) FlushAll(ctx context.Context, req *milvuspb.FlushAllRequest) hasError := func(status *commonpb.Status, err error) bool { if err != nil { resp.Status = merr.Status(err) - log.Warn("FlushAll failed", zap.String("err", err.Error())) + log.Warn("FlushAll failed", zap.Error(err)) return true } - if status != nil && status.ErrorCode != commonpb.ErrorCode_Success { + if status.GetErrorCode() != commonpb.ErrorCode_Success { log.Warn("FlushAll failed", zap.String("err", status.GetReason())) resp.Status = status return true @@ -2993,7 +3166,7 @@ func 
(node *Proxy) FlushAll(ctx context.Context, req *milvuspb.FlushAllRequest) return dbName == req.GetDbName() }) if len(dbNames) == 0 { - resp.Status.Reason = fmt.Sprintf("failed to get db %s", req.GetDbName()) + resp.Status = merr.Status(merr.WrapErrDatabaseNotFound(req.GetDbName())) return resp, nil } } @@ -3036,12 +3209,11 @@ func (node *Proxy) FlushAll(ctx context.Context, req *milvuspb.FlushAllRequest) ts, err := node.tsoAllocator.AllocOne(ctx) if err != nil { log.Warn("FlushAll failed", zap.Error(err)) - resp.Status.Reason = err.Error() + resp.Status = merr.Status(err) return resp, nil } resp.FlushAllTs = ts - resp.Status.ErrorCode = commonpb.ErrorCode_Success log.Info(rpcDone("FlushAll"), zap.Uint64("FlushAllTs", ts), zap.Time("FlushAllTime", tsoutil.PhysicalTime(ts))) @@ -3051,10 +3223,7 @@ func (node *Proxy) FlushAll(ctx context.Context, req *milvuspb.FlushAllRequest) // GetDdChannel returns the used channel for dd operations. func (node *Proxy) GetDdChannel(ctx context.Context, request *internalpb.GetDdChannelRequest) (*milvuspb.StringResponse, error) { return &milvuspb.StringResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "TODO: implement me", - }, + Status: merr.Status(merr.WrapErrServiceUnavailable("unimp")), }, nil } @@ -3071,12 +3240,10 @@ func (node *Proxy) GetPersistentSegmentInfo(ctx context.Context, req *milvuspb.G zap.Any("collection", req.CollectionName)) resp := &milvuspb.GetPersistentSegmentInfoResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - }, + Status: merr.Success(), } - if !node.checkHealthy() { - resp.Status = unhealthyStatus() + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + resp.Status = merr.Status(err) return resp, nil } method := "GetPersistentSegmentInfo" @@ -3088,7 +3255,7 @@ func (node *Proxy) GetPersistentSegmentInfo(ctx context.Context, req *milvuspb.G collectionID, err := globalMetaCache.GetCollectionID(ctx, 
req.GetDbName(), req.GetCollectionName()) if err != nil { metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() - resp.Status.Reason = fmt.Errorf("getCollectionID failed, err:%w", err).Error() + resp.Status = merr.Status(err) return resp, nil } @@ -3100,7 +3267,7 @@ func (node *Proxy) GetPersistentSegmentInfo(ctx context.Context, req *milvuspb.G }) if err != nil { metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() - resp.Status.Reason = fmt.Errorf("getSegmentsOfCollection, err:%w", err).Error() + resp.Status = merr.Status(err) return resp, nil } @@ -3118,18 +3285,19 @@ func (node *Proxy) GetPersistentSegmentInfo(ctx context.Context, req *milvuspb.G metrics.FailLabel).Inc() log.Warn("GetPersistentSegmentInfo fail", zap.Error(err)) - resp.Status.Reason = fmt.Errorf("dataCoord:GetSegmentInfo, err:%w", err).Error() + resp.Status = merr.Status(err) return resp, nil } - log.Debug("GetPersistentSegmentInfo", - zap.Int("len(infos)", len(infoResp.Infos)), - zap.Any("status", infoResp.Status)) - if infoResp.Status.ErrorCode != commonpb.ErrorCode_Success { + err = merr.Error(infoResp.GetStatus()) + if err != nil { metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() - resp.Status.Reason = infoResp.Status.Reason + resp.Status = merr.Status(err) return resp, nil } + log.Debug("GetPersistentSegmentInfo", + zap.Int("len(infos)", len(infoResp.Infos)), + zap.Any("status", infoResp.Status)) persistentInfos := make([]*milvuspb.PersistentSegmentInfo, len(infoResp.Infos)) for i, info := range infoResp.Infos { persistentInfos[i] = &milvuspb.PersistentSegmentInfo{ @@ -3143,7 +3311,6 @@ func (node *Proxy) GetPersistentSegmentInfo(ctx context.Context, req *milvuspb.G metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, 
metrics.SuccessLabel).Inc() metrics.ProxyReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method).Observe(float64(tr.ElapseSpan().Milliseconds())) - resp.Status.ErrorCode = commonpb.ErrorCode_Success resp.Infos = persistentInfos return resp, nil } @@ -3161,12 +3328,10 @@ func (node *Proxy) GetQuerySegmentInfo(ctx context.Context, req *milvuspb.GetQue zap.Any("collection", req.CollectionName)) resp := &milvuspb.GetQuerySegmentInfoResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - }, + Status: merr.Success(), } - if !node.checkHealthy() { - resp.Status = unhealthyStatus() + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + resp.Status = merr.Status(err) return resp, nil } @@ -3178,7 +3343,7 @@ func (node *Proxy) GetQuerySegmentInfo(ctx context.Context, req *milvuspb.GetQue collID, err := globalMetaCache.GetCollectionID(ctx, req.GetDbName(), req.CollectionName) if err != nil { metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() - resp.Status.Reason = err.Error() + resp.Status = merr.Status(err) return resp, nil } infoResp, err := node.queryCoord.GetSegmentInfo(ctx, &querypb.GetSegmentInfoRequest{ @@ -3189,23 +3354,19 @@ func (node *Proxy) GetQuerySegmentInfo(ctx context.Context, req *milvuspb.GetQue ), CollectionID: collID, }) + if err == nil { + err = merr.Error(infoResp.GetStatus()) + } if err != nil { metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() log.Error("Failed to get segment info from QueryCoord", zap.Error(err)) - resp.Status.Reason = err.Error() + resp.Status = merr.Status(err) return resp, nil } log.Debug("GetQuerySegmentInfo", zap.Any("infos", infoResp.Infos), zap.Any("status", infoResp.Status)) - if infoResp.Status.ErrorCode != commonpb.ErrorCode_Success { - 
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() - log.Error("Failed to get segment info from QueryCoord", - zap.String("errMsg", infoResp.Status.Reason)) - resp.Status.Reason = infoResp.Status.Reason - return resp, nil - } queryInfos := make([]*milvuspb.QuerySegmentInfo, len(infoResp.Infos)) for i, info := range infoResp.Infos { queryInfos[i] = &milvuspb.QuerySegmentInfo{ @@ -3223,7 +3384,6 @@ func (node *Proxy) GetQuerySegmentInfo(ctx context.Context, req *milvuspb.GetQue metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.SuccessLabel).Inc() metrics.ProxyReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method).Observe(float64(tr.ElapseSpan().Milliseconds())) - resp.Status.ErrorCode = commonpb.ErrorCode_Success resp.Infos = queryInfos return resp, nil } @@ -3281,33 +3441,25 @@ func (node *Proxy) Dummy(ctx context.Context, req *milvuspb.DummyRequest) (*milv // RegisterLink registers a link func (node *Proxy) RegisterLink(ctx context.Context, req *milvuspb.RegisterLinkRequest) (*milvuspb.RegisterLinkResponse, error) { - code := node.stateCode.Load().(commonpb.StateCode) + code := node.GetStateCode() ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-RegisterLink") defer sp.End() log := log.Ctx(ctx).With( zap.String("role", typeutil.ProxyRole), - zap.Any("state code of proxy", code)) + zap.String("state", code.String())) log.Debug("RegisterLink") - if code != commonpb.StateCode_Healthy { + if err := merr.CheckHealthy(code); err != nil { return &milvuspb.RegisterLinkResponse{ - Address: nil, - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "proxy not healthy", - }, + Status: merr.Status(err), }, nil } - //metrics.ProxyLinkedSDKs.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Inc() + // 
metrics.ProxyLinkedSDKs.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Inc() return &milvuspb.RegisterLinkResponse{ - Address: nil, - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: os.Getenv(metricsinfo.DeployModeEnvKey), - }, + Status: merr.Success(os.Getenv(metricsinfo.DeployModeEnvKey)), }, nil } @@ -3323,16 +3475,14 @@ func (node *Proxy) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsReque zap.Int64("nodeID", paramtable.GetNodeID()), zap.String("req", req.Request)) - if !node.checkHealthy() { - err := merr.WrapErrServiceNotReady(fmt.Sprintf("proxy %d is unhealthy", paramtable.GetNodeID())) + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { log.Warn("Proxy.GetMetrics failed", zap.Int64("nodeID", paramtable.GetNodeID()), zap.String("req", req.Request), zap.Error(err)) return &milvuspb.GetMetricsResponse{ - Status: merr.Status(err), - Response: "", + Status: merr.Status(err), }, nil } @@ -3344,8 +3494,7 @@ func (node *Proxy) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsReque zap.Error(err)) return &milvuspb.GetMetricsResponse{ - Status: merr.Status(err), - Response: "", + Status: merr.Status(err), }, nil } @@ -3378,11 +3527,7 @@ func (node *Proxy) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsReque zap.String("metricType", metricType)) return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: metricsinfo.MsgUnimplementedMetric, - }, - Response: "", + Status: merr.Status(merr.WrapErrMetricNotFound(metricType)), }, nil } @@ -3396,8 +3541,7 @@ func (node *Proxy) GetProxyMetrics(ctx context.Context, req *milvuspb.GetMetrics zap.Int64("nodeID", paramtable.GetNodeID()), zap.String("req", req.Request)) - if !node.checkHealthy() { - err := merr.WrapErrServiceNotReady(fmt.Sprintf("proxy %d is unhealthy", paramtable.GetNodeID())) + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { 
log.Warn("Proxy.GetProxyMetrics failed", zap.Error(err)) @@ -3429,14 +3573,11 @@ func (node *Proxy) GetProxyMetrics(ctx context.Context, req *milvuspb.GetMetrics zap.Error(err)) return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } - //log.Debug("Proxy.GetProxyMetrics", + // log.Debug("Proxy.GetProxyMetrics", // zap.String("metricType", metricType)) return proxyMetrics, nil @@ -3446,10 +3587,7 @@ func (node *Proxy) GetProxyMetrics(ctx context.Context, req *milvuspb.GetMetrics zap.String("metricType", metricType)) return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: metricsinfo.MsgUnimplementedMetric, - }, + Status: merr.Status(merr.WrapErrMetricNotFound(metricType)), }, nil } @@ -3464,20 +3602,18 @@ func (node *Proxy) LoadBalance(ctx context.Context, req *milvuspb.LoadBalanceReq zap.Int64("proxy_id", paramtable.GetNodeID()), zap.Any("req", req)) - if !node.checkHealthy() { - return unhealthyStatus(), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } - status := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - } + status := merr.Success() collectionID, err := globalMetaCache.GetCollectionID(ctx, req.GetDbName(), req.GetCollectionName()) if err != nil { log.Warn("failed to get collection id", zap.String("collectionName", req.GetCollectionName()), zap.Error(err)) - status.Reason = err.Error() + status = merr.Status(err) return status, nil } infoResp, err := node.queryCoord.LoadBalance(ctx, &querypb.LoadBalanceRequest{ @@ -3496,19 +3632,18 @@ func (node *Proxy) LoadBalance(ctx context.Context, req *milvuspb.LoadBalanceReq log.Warn("Failed to LoadBalance from Query Coordinator", zap.Any("req", req), zap.Error(err)) - status.Reason = err.Error() + status = merr.Status(err) return status, nil } if 
infoResp.ErrorCode != commonpb.ErrorCode_Success { log.Warn("Failed to LoadBalance from Query Coordinator", zap.String("errMsg", infoResp.Reason)) - status.Reason = infoResp.Reason + status = infoResp return status, nil } log.Debug("LoadBalance Done", zap.Any("req", req), zap.Any("status", infoResp)) - status.ErrorCode = commonpb.ErrorCode_Success return status, nil } @@ -3523,8 +3658,8 @@ func (node *Proxy) GetReplicas(ctx context.Context, req *milvuspb.GetReplicasReq zap.Int64("collection", req.GetCollectionID()), zap.Bool("with shard nodes", req.GetWithShardNodes())) resp := &milvuspb.GetReplicasResponse{} - if !node.checkHealthy() { - resp.Status = unhealthyStatus() + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + resp.Status = merr.Status(err) return resp, nil } @@ -3534,7 +3669,12 @@ func (node *Proxy) GetReplicas(ctx context.Context, req *milvuspb.GetReplicasReq ) if req.GetCollectionName() != "" { - req.CollectionID, _ = globalMetaCache.GetCollectionID(ctx, req.GetDbName(), req.GetCollectionName()) + var err error + req.CollectionID, err = globalMetaCache.GetCollectionID(ctx, req.GetDbName(), req.GetCollectionName()) + if err != nil { + resp.Status = merr.Status(err) + return resp, nil + } } r, err := node.queryCoord.GetReplicas(ctx, req) @@ -3558,8 +3698,8 @@ func (node *Proxy) GetCompactionState(ctx context.Context, req *milvuspb.GetComp log.Debug("received GetCompactionState request") resp := &milvuspb.GetCompactionStateResponse{} - if !node.checkHealthy() { - resp.Status = unhealthyStatus() + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + resp.Status = merr.Status(err) return resp, nil } @@ -3580,8 +3720,8 @@ func (node *Proxy) ManualCompaction(ctx context.Context, req *milvuspb.ManualCom log.Info("received ManualCompaction request") resp := &milvuspb.ManualCompactionResponse{} - if !node.checkHealthy() { - resp.Status = unhealthyStatus() + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + resp.Status = 
merr.Status(err) return resp, nil } @@ -3602,8 +3742,8 @@ func (node *Proxy) GetCompactionStateWithPlans(ctx context.Context, req *milvusp log.Debug("received GetCompactionStateWithPlans request") resp := &milvuspb.GetCompactionPlansResponse{} - if !node.checkHealthy() { - resp.Status = unhealthyStatus() + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + resp.Status = merr.Status(err) return resp, nil } @@ -3614,7 +3754,7 @@ func (node *Proxy) GetCompactionStateWithPlans(ctx context.Context, req *milvusp return resp, err } -// GetFlushState gets the flush state of multiple segments +// GetFlushState gets the flush state of the collection based on the provided flush ts and segment IDs. func (node *Proxy) GetFlushState(ctx context.Context, req *milvuspb.GetFlushStateRequest) (*milvuspb.GetFlushStateResponse, error) { ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-GetFlushState") defer sp.End() @@ -3624,18 +3764,37 @@ func (node *Proxy) GetFlushState(ctx context.Context, req *milvuspb.GetFlushStat log.Debug("received get flush state request", zap.Any("request", req)) var err error - resp := &milvuspb.GetFlushStateResponse{} - if !node.checkHealthy() { - resp.Status = unhealthyStatus() + failResp := &milvuspb.GetFlushStateResponse{} + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + failResp.Status = merr.Status(err) log.Warn("unable to get flush state because of closed server") - return resp, nil + return failResp, nil + } + + stateReq := &datapb.GetFlushStateRequest{ + SegmentIDs: req.GetSegmentIDs(), + FlushTs: req.GetFlushTs(), + } + + if len(req.GetCollectionName()) > 0 { // For compatibility with old client + if err = validateCollectionName(req.GetCollectionName()); err != nil { + failResp.Status = merr.Status(err) + return failResp, nil + } + collectionID, err := globalMetaCache.GetCollectionID(ctx, req.GetDbName(), req.GetCollectionName()) + if err != nil { + failResp.Status = merr.Status(err) + return failResp, nil 
+ } + stateReq.CollectionID = collectionID } - resp, err = node.dataCoord.GetFlushState(ctx, req) + resp, err := node.dataCoord.GetFlushState(ctx, stateReq) if err != nil { log.Warn("failed to get flush state response", zap.Error(err)) - return nil, err + failResp.Status = merr.Status(err) + return failResp, nil } log.Debug("received get flush state response", zap.Any("response", resp)) @@ -3653,8 +3812,8 @@ func (node *Proxy) GetFlushAllState(ctx context.Context, req *milvuspb.GetFlushA var err error resp := &milvuspb.GetFlushAllStateResponse{} - if !node.checkHealthy() { - resp.Status = unhealthyStatus() + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + resp.Status = merr.Status(err) log.Warn("GetFlushAllState failed, closed server") return resp, nil } @@ -3662,7 +3821,7 @@ func (node *Proxy) GetFlushAllState(ctx context.Context, req *milvuspb.GetFlushA resp, err = node.dataCoord.GetFlushAllState(ctx, req) if err != nil { resp.Status = merr.Status(err) - log.Warn("GetFlushAllState failed", zap.String("err", err.Error())) + log.Warn("GetFlushAllState failed", zap.Error(err)) return resp, nil } log.Debug("GetFlushAllState done", zap.Bool("flushed", resp.GetFlushed())) @@ -3671,23 +3830,10 @@ func (node *Proxy) GetFlushAllState(ctx context.Context, req *milvuspb.GetFlushA // checkHealthy checks proxy state is Healthy func (node *Proxy) checkHealthy() bool { - code := node.stateCode.Load().(commonpb.StateCode) + code := node.GetStateCode() return code == commonpb.StateCode_Healthy } -func (node *Proxy) checkHealthyAndReturnCode() (commonpb.StateCode, bool) { - code := node.stateCode.Load().(commonpb.StateCode) - return code, code == commonpb.StateCode_Healthy -} - -// unhealthyStatus returns the proxy not healthy status -func unhealthyStatus() *commonpb.Status { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "proxy not healthy", - } -} - // Import data files(json, numpy, etc.) 
on MinIO/S3 storage, read and parse them into sealed segments func (node *Proxy) Import(ctx context.Context, req *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) { ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Import") @@ -3700,10 +3846,10 @@ func (node *Proxy) Import(ctx context.Context, req *milvuspb.ImportRequest) (*mi zap.String("partition name", req.GetPartitionName()), zap.Strings("files", req.GetFiles())) resp := &milvuspb.ImportResponse{ - Status: merr.Status(nil), + Status: merr.Success(), } - if !node.checkHealthy() { - resp.Status = unhealthyStatus() + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + resp.Status = merr.Status(err) return resp, nil } @@ -3711,8 +3857,7 @@ func (node *Proxy) Import(ctx context.Context, req *milvuspb.ImportRequest) (*mi if err != nil { log.Error("failed to execute import request", zap.Error(err)) - resp.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError - resp.Status.Reason = "request options is not illegal \n" + err.Error() + " \nIllegal option format \n" + importutil.OptionFormat + resp.Status = merr.Status(err) return resp, nil } @@ -3727,8 +3872,7 @@ func (node *Proxy) Import(ctx context.Context, req *milvuspb.ImportRequest) (*mi metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() log.Error("failed to execute bulk insert request", zap.Error(err)) - resp.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError - resp.Status.Reason = err.Error() + resp.Status = merr.Status(err) return resp, nil } @@ -3746,9 +3890,11 @@ func (node *Proxy) GetImportState(ctx context.Context, req *milvuspb.GetImportSt log.Debug("received get import state request", zap.Int64("taskID", req.GetTask())) - resp := &milvuspb.GetImportStateResponse{} - if !node.checkHealthy() { - resp.Status = unhealthyStatus() + resp := &milvuspb.GetImportStateResponse{ + Status: merr.Success(), + } + if err := merr.CheckHealthy(node.GetStateCode()); err 
!= nil { + resp.Status = merr.Status(err) return resp, nil } method := "GetImportState" @@ -3761,8 +3907,7 @@ func (node *Proxy) GetImportState(ctx context.Context, req *milvuspb.GetImportSt metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() log.Error("failed to execute get import state", zap.Error(err)) - resp.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError - resp.Status.Reason = err.Error() + resp.Status = merr.Status(err) return resp, nil } @@ -3782,9 +3927,11 @@ func (node *Proxy) ListImportTasks(ctx context.Context, req *milvuspb.ListImport log := log.Ctx(ctx) log.Debug("received list import tasks request") - resp := &milvuspb.ListImportTasksResponse{} - if !node.checkHealthy() { - resp.Status = unhealthyStatus() + resp := &milvuspb.ListImportTasksResponse{ + Status: merr.Success(), + } + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + resp.Status = merr.Status(err) return resp, nil } method := "ListImportTasks" @@ -3796,8 +3943,7 @@ func (node *Proxy) ListImportTasks(ctx context.Context, req *milvuspb.ListImport metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() log.Error("failed to execute list import tasks", zap.Error(err)) - resp.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError - resp.Status.Reason = err.Error() + resp.Status = merr.Status(err) return resp, nil } @@ -3819,8 +3965,8 @@ func (node *Proxy) InvalidateCredentialCache(ctx context.Context, request *proxy zap.String("username", request.Username)) log.Debug("received request to invalidate credential cache") - if !node.checkHealthy() { - return unhealthyStatus(), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } username := request.Username @@ -3829,7 +3975,7 @@ func (node *Proxy) InvalidateCredentialCache(ctx context.Context, request *proxy } log.Debug("complete to invalidate 
credential cache") - return merr.Status(nil), nil + return merr.Success(), nil } // UpdateCredentialCache update the credential cache of specified username. @@ -3842,8 +3988,8 @@ func (node *Proxy) UpdateCredentialCache(ctx context.Context, request *proxypb.U zap.String("username", request.Username)) log.Debug("received request to update credential cache") - if !node.checkHealthy() { - return unhealthyStatus(), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } credInfo := &internalpb.CredentialInfo{ @@ -3855,7 +4001,7 @@ func (node *Proxy) UpdateCredentialCache(ctx context.Context, request *proxypb.U } log.Debug("complete to update credential cache") - return merr.Status(nil), nil + return merr.Success(), nil } func (node *Proxy) CreateCredential(ctx context.Context, req *milvuspb.CreateCredentialRequest) (*commonpb.Status, error) { @@ -3867,42 +4013,32 @@ func (node *Proxy) CreateCredential(ctx context.Context, req *milvuspb.CreateCre log.Debug("CreateCredential", zap.String("role", typeutil.ProxyRole)) - if !node.checkHealthy() { - return unhealthyStatus(), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } // validate params username := req.Username if err := ValidateUsername(username); err != nil { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_IllegalArgument, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } rawPassword, err := crypto.Base64Decode(req.Password) if err != nil { log.Error("decode password fail", zap.Error(err)) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_CreateCredentialFailure, - Reason: "decode password fail key:" + req.Username, - }, nil + err = errors.Wrap(err, "decode password fail") + return merr.Status(err), nil } if err = ValidatePassword(rawPassword); err != nil { log.Error("illegal password", zap.Error(err)) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_IllegalArgument, - Reason: 
err.Error(), - }, nil + return merr.Status(err), nil } encryptedPassword, err := crypto.PasswordEncrypt(rawPassword) if err != nil { log.Error("encrypt password fail", zap.Error(err)) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_CreateCredentialFailure, - Reason: "encrypt password fail key:" + req.Username, - }, nil + err = errors.Wrap(err, "encrypt password failed") + return merr.Status(err), nil } credInfo := &internalpb.CredentialInfo{ @@ -3928,35 +4064,28 @@ func (node *Proxy) UpdateCredential(ctx context.Context, req *milvuspb.UpdateCre log.Debug("UpdateCredential", zap.String("role", typeutil.ProxyRole)) - if !node.checkHealthy() { - return unhealthyStatus(), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } rawOldPassword, err := crypto.Base64Decode(req.OldPassword) if err != nil { log.Error("decode old password fail", zap.Error(err)) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UpdateCredentialFailure, - Reason: "decode old password fail when updating:" + req.Username, - }, nil + err = errors.Wrap(err, "decode old password failed") + return merr.Status(err), nil } rawNewPassword, err := crypto.Base64Decode(req.NewPassword) if err != nil { log.Error("decode password fail", zap.Error(err)) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UpdateCredentialFailure, - Reason: "decode password fail when updating:" + req.Username, - }, nil + err = errors.Wrap(err, "decode password failed") + return merr.Status(err), nil } // valid new password if err = ValidatePassword(rawNewPassword); err != nil { log.Error("illegal password", zap.Error(err)) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_IllegalArgument, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } skipPasswordVerify := false @@ -3969,20 +4098,16 @@ func (node *Proxy) UpdateCredential(ctx context.Context, req *milvuspb.UpdateCre } if !skipPasswordVerify && !passwordVerify(ctx, req.Username, 
rawOldPassword, globalMetaCache) { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UpdateCredentialFailure, - Reason: "old password is not correct:" + req.Username, - }, nil + err := merr.WrapErrPrivilegeNotAuthenticated("old password not correct for %s", req.GetUsername()) + return merr.Status(err), nil } // update meta data encryptedPassword, err := crypto.PasswordEncrypt(rawNewPassword) if err != nil { log.Error("encrypt password fail", zap.Error(err)) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UpdateCredentialFailure, - Reason: "encrypt password fail when updating:" + req.Username, - }, nil + err = errors.Wrap(err, "encrypt password failed") + return merr.Status(err), nil } updateCredReq := &internalpb.CredentialInfo{ Username: req.Username, @@ -4007,15 +4132,13 @@ func (node *Proxy) DeleteCredential(ctx context.Context, req *milvuspb.DeleteCre log.Debug("DeleteCredential", zap.String("role", typeutil.ProxyRole)) - if !node.checkHealthy() { - return unhealthyStatus(), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } if req.Username == util.UserRoot { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_DeleteCredentialFailure, - Reason: "user root cannot be deleted", - }, nil + err := merr.WrapErrPrivilegeNotPermitted("root user cannot be deleted") + return merr.Status(err), nil } result, err := node.rootCoord.DeleteCredential(ctx, req) if err != nil { // for error like conntext timeout etc. 
@@ -4034,8 +4157,8 @@ func (node *Proxy) ListCredUsers(ctx context.Context, req *milvuspb.ListCredUser zap.String("role", typeutil.ProxyRole)) log.Debug("ListCredUsers") - if !node.checkHealthy() { - return &milvuspb.ListCredUsersResponse{Status: unhealthyStatus()}, nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return &milvuspb.ListCredUsersResponse{Status: merr.Status(err)}, nil } rootCoordReq := &milvuspb.ListCredUsersRequest{ Base: commonpbutil.NewMsgBase( @@ -4049,7 +4172,7 @@ func (node *Proxy) ListCredUsers(ctx context.Context, req *milvuspb.ListCredUser }, nil } return &milvuspb.ListCredUsersResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Usernames: resp.Usernames, }, nil } @@ -4061,8 +4184,8 @@ func (node *Proxy) CreateRole(ctx context.Context, req *milvuspb.CreateRoleReque log := log.Ctx(ctx) log.Debug("CreateRole", zap.Any("req", req)) - if code, ok := node.checkHealthyAndReturnCode(); !ok { - return errorutil.UnhealthyStatus(code), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } var roleName string @@ -4070,10 +4193,7 @@ func (node *Proxy) CreateRole(ctx context.Context, req *milvuspb.CreateRoleReque roleName = req.Entity.Name } if err := ValidateRoleName(roleName); err != nil { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_IllegalArgument, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } result, err := node.rootCoord.CreateRole(ctx, req) @@ -4092,21 +4212,15 @@ func (node *Proxy) DropRole(ctx context.Context, req *milvuspb.DropRoleRequest) log.Debug("DropRole", zap.Any("req", req)) - if code, ok := node.checkHealthyAndReturnCode(); !ok { - return errorutil.UnhealthyStatus(code), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } if err := ValidateRoleName(req.RoleName); err != nil { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_IllegalArgument, - Reason: err.Error(), - 
}, nil + return merr.Status(err), nil } if IsDefaultRole(req.RoleName) { - errMsg := fmt.Sprintf("the role[%s] is a default role, which can't be droped", req.RoleName) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_IllegalArgument, - Reason: errMsg, - }, nil + err := merr.WrapErrPrivilegeNotPermitted("the role[%s] is a default role, which can't be droped", req.GetRoleName()) + return merr.Status(err), nil } result, err := node.rootCoord.DropRole(ctx, req) if err != nil { @@ -4125,20 +4239,14 @@ func (node *Proxy) OperateUserRole(ctx context.Context, req *milvuspb.OperateUse log := log.Ctx(ctx) log.Debug("OperateUserRole", zap.Any("req", req)) - if code, ok := node.checkHealthyAndReturnCode(); !ok { - return errorutil.UnhealthyStatus(code), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } if err := ValidateUsername(req.Username); err != nil { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_IllegalArgument, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } if err := ValidateRoleName(req.RoleName); err != nil { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_IllegalArgument, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } result, err := node.rootCoord.OperateUserRole(ctx, req) @@ -4156,8 +4264,8 @@ func (node *Proxy) SelectRole(ctx context.Context, req *milvuspb.SelectRoleReque log := log.Ctx(ctx) log.Debug("SelectRole", zap.Any("req", req)) - if code, ok := node.checkHealthyAndReturnCode(); !ok { - return &milvuspb.SelectRoleResponse{Status: errorutil.UnhealthyStatus(code)}, nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return &milvuspb.SelectRoleResponse{Status: merr.Status(err)}, nil } if req.Role != nil { @@ -4185,8 +4293,8 @@ func (node *Proxy) SelectUser(ctx context.Context, req *milvuspb.SelectUserReque log := log.Ctx(ctx) log.Debug("SelectUser", zap.Any("req", req)) - if code, ok := node.checkHealthyAndReturnCode(); 
!ok { - return &milvuspb.SelectUserResponse{Status: errorutil.UnhealthyStatus(code)}, nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return &milvuspb.SelectUserResponse{Status: merr.Status(err)}, nil } if req.User != nil { @@ -4248,22 +4356,16 @@ func (node *Proxy) OperatePrivilege(ctx context.Context, req *milvuspb.OperatePr log.Debug("OperatePrivilege", zap.Any("req", req)) - if code, ok := node.checkHealthyAndReturnCode(); !ok { - return errorutil.UnhealthyStatus(code), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } if err := node.validPrivilegeParams(req); err != nil { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_IllegalArgument, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } curUser, err := GetCurUserFromContext(ctx) if err != nil { log.Warn("fail to get current user", zap.Error(err)) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "fail to get current user, please make sure the authorizationEnabled setting in the milvus.yaml is true", - }, nil + return merr.Status(err), nil } req.Entity.Grantor.User = &milvuspb.UserEntity{Name: curUser} result, err := node.rootCoord.OperatePrivilege(ctx, req) @@ -4276,7 +4378,7 @@ func (node *Proxy) OperatePrivilege(ctx context.Context, req *milvuspb.OperatePr func (node *Proxy) validGrantParams(req *milvuspb.SelectGrantRequest) error { if req.Entity == nil { - return fmt.Errorf("the grant entity in the request is nil") + return merr.WrapErrParameterInvalidMsg("the grant entity in the request is nil") } if req.Entity.Object != nil { @@ -4290,7 +4392,7 @@ func (node *Proxy) validGrantParams(req *milvuspb.SelectGrantRequest) error { } if req.Entity.Role == nil { - return fmt.Errorf("the role entity in the grant entity is nil") + return merr.WrapErrParameterInvalidMsg("the role entity in the grant entity is nil") } if err := ValidateRoleName(req.Entity.Role.Name); err != nil { @@ 
-4308,16 +4410,13 @@ func (node *Proxy) SelectGrant(ctx context.Context, req *milvuspb.SelectGrantReq log.Debug("SelectGrant", zap.Any("req", req)) - if code, ok := node.checkHealthyAndReturnCode(); !ok { - return &milvuspb.SelectGrantResponse{Status: errorutil.UnhealthyStatus(code)}, nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return &milvuspb.SelectGrantResponse{Status: merr.Status(err)}, nil } if err := node.validGrantParams(req); err != nil { return &milvuspb.SelectGrantResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_IllegalArgument, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -4339,8 +4438,8 @@ func (node *Proxy) RefreshPolicyInfoCache(ctx context.Context, req *proxypb.Refr log.Debug("RefreshPrivilegeInfoCache", zap.Any("req", req)) - if code, ok := node.checkHealthyAndReturnCode(); !ok { - return merr.Status(merr.WrapErrServiceNotReady(code.String())), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } if globalMetaCache != nil { @@ -4349,46 +4448,41 @@ func (node *Proxy) RefreshPolicyInfoCache(ctx context.Context, req *proxypb.Refr OpKey: req.OpKey, }) if err != nil { - log.Error("fail to refresh policy info", + log.Warn("fail to refresh policy info", zap.Error(err)) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_RefreshPolicyInfoCacheFailure, - Reason: err.Error(), - }, err + return merr.Status(err), nil } } log.Debug("RefreshPrivilegeInfoCache success") - return merr.Status(nil), nil + return merr.Success(), nil } // SetRates limits the rates of requests. 
func (node *Proxy) SetRates(ctx context.Context, request *proxypb.SetRatesRequest) (*commonpb.Status, error) { - resp := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - } - if !node.checkHealthy() { - resp = unhealthyStatus() + resp := merr.Success() + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + resp = merr.Status(err) return resp, nil } err := node.multiRateLimiter.SetRates(request.GetRates()) // TODO: set multiple rate limiter rates if err != nil { - resp.Reason = err.Error() + resp = merr.Status(err) return resp, nil } - resp.ErrorCode = commonpb.ErrorCode_Success + return resp, nil } func (node *Proxy) CheckHealth(ctx context.Context, request *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { - if !node.checkHealthy() { - reason := errorutil.UnHealthReason("proxy", node.session.ServerID, "proxy is unhealthy") + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { return &milvuspb.CheckHealthResponse{ - Status: unhealthyStatus(), + Status: merr.Status(err), IsHealthy: false, - Reasons: []string{reason}}, nil + Reasons: []string{err.Error()}, + }, nil } group, ctx := errgroup.WithContext(ctx) @@ -4436,7 +4530,7 @@ func (node *Proxy) CheckHealth(ctx context.Context, request *milvuspb.CheckHealt err := group.Wait() if err != nil || len(errReasons) != 0 { return &milvuspb.CheckHealthResponse{ - Status: merr.Status(nil), + Status: merr.Success(), IsHealthy: false, Reasons: errReasons, }, nil @@ -4444,7 +4538,7 @@ func (node *Proxy) CheckHealth(ctx context.Context, request *milvuspb.CheckHealt states, reasons := node.multiRateLimiter.GetQuotaStates() return &milvuspb.CheckHealthResponse{ - Status: merr.Status(nil), + Status: merr.Success(), QuotaStates: states, Reasons: reasons, IsHealthy: true, @@ -4463,16 +4557,13 @@ func (node *Proxy) RenameCollection(ctx context.Context, req *milvuspb.RenameCol log.Info("received rename collection request") var err error - if !node.checkHealthy() { - return 
unhealthyStatus(), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } if err := validateCollectionName(req.GetNewName()); err != nil { log.Warn("validate new collection name fail", zap.Error(err)) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_IllegalCollectionName, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } req.Base = commonpbutil.NewMsgBase( @@ -4483,18 +4574,15 @@ func (node *Proxy) RenameCollection(ctx context.Context, req *milvuspb.RenameCol resp, err := node.rootCoord.RenameCollection(ctx, req) if err != nil { log.Warn("failed to rename collection", zap.Error(err)) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, err + return merr.Status(err), err } return resp, nil } func (node *Proxy) CreateResourceGroup(ctx context.Context, request *milvuspb.CreateResourceGroupRequest) (*commonpb.Status, error) { - if !node.checkHealthy() { - return unhealthyStatus(), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } method := "CreateResourceGroup" @@ -4554,15 +4642,12 @@ func (node *Proxy) CreateResourceGroup(ctx context.Context, request *milvuspb.Cr func getErrResponse(err error, method string) *commonpb.Status { metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_IllegalArgument, - Reason: err.Error(), - } + return merr.Status(err) } func (node *Proxy) DropResourceGroup(ctx context.Context, request *milvuspb.DropResourceGroupRequest) (*commonpb.Status, error) { - if !node.checkHealthy() { - return unhealthyStatus(), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } method := "DropResourceGroup" @@ -4614,8 +4699,8 @@ func (node *Proxy) DropResourceGroup(ctx context.Context, request *milvuspb.Drop } func 
(node *Proxy) TransferNode(ctx context.Context, request *milvuspb.TransferNodeRequest) (*commonpb.Status, error) { - if !node.checkHealthy() { - return unhealthyStatus(), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } method := "TransferNode" @@ -4681,8 +4766,8 @@ func (node *Proxy) TransferNode(ctx context.Context, request *milvuspb.TransferN } func (node *Proxy) TransferReplica(ctx context.Context, request *milvuspb.TransferReplicaRequest) (*commonpb.Status, error) { - if !node.checkHealthy() { - return unhealthyStatus(), nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return merr.Status(err), nil } method := "TransferReplica" @@ -4748,9 +4833,9 @@ func (node *Proxy) TransferReplica(ctx context.Context, request *milvuspb.Transf } func (node *Proxy) ListResourceGroups(ctx context.Context, request *milvuspb.ListResourceGroupsRequest) (*milvuspb.ListResourceGroupsResponse, error) { - if !node.checkHealthy() { + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { return &milvuspb.ListResourceGroupsResponse{ - Status: unhealthyStatus(), + Status: merr.Status(err), }, nil } @@ -4811,9 +4896,9 @@ func (node *Proxy) ListResourceGroups(ctx context.Context, request *milvuspb.Lis } func (node *Proxy) DescribeResourceGroup(ctx context.Context, request *milvuspb.DescribeResourceGroupRequest) (*milvuspb.DescribeResourceGroupResponse, error) { - if !node.checkHealthy() { + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { return &milvuspb.DescribeResourceGroupResponse{ - Status: unhealthyStatus(), + Status: merr.Status(err), }, nil } @@ -4875,25 +4960,19 @@ func (node *Proxy) DescribeResourceGroup(ctx context.Context, request *milvuspb. 
func (node *Proxy) ListIndexedSegment(ctx context.Context, request *federpb.ListIndexedSegmentRequest) (*federpb.ListIndexedSegmentResponse, error) { return &federpb.ListIndexedSegmentResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "TODO: implement me", - }, + Status: merr.Status(merr.WrapErrServiceUnavailable("unimp")), }, nil } func (node *Proxy) DescribeSegmentIndexData(ctx context.Context, request *federpb.DescribeSegmentIndexDataRequest) (*federpb.DescribeSegmentIndexDataResponse, error) { return &federpb.DescribeSegmentIndexDataResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "TODO: implement me", - }, + Status: merr.Status(merr.WrapErrServiceUnavailable("unimp")), }, nil } func (node *Proxy) Connect(ctx context.Context, request *milvuspb.ConnectRequest) (*milvuspb.ConnectResponse, error) { - if !node.checkHealthy() { - return &milvuspb.ConnectResponse{Status: unhealthyStatus()}, nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return &milvuspb.ConnectResponse{Status: merr.Status(err)}, nil } db := GetCurDBNameFromContextOrDefault(ctx) @@ -4907,6 +4986,9 @@ func (node *Proxy) Connect(ctx context.Context, request *milvuspb.ConnectRequest commonpbutil.WithMsgType(commonpb.MsgType_ListDatabases), ), }) + if err == nil { + err = merr.Error(resp.GetStatus()) + } if err != nil { log.Info("connect failed, failed to list databases", zap.Error(err)) @@ -4915,22 +4997,10 @@ func (node *Proxy) Connect(ctx context.Context, request *milvuspb.ConnectRequest }, nil } - if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { - log.Info("connect failed, failed to list databases", - zap.String("code", resp.GetStatus().GetErrorCode().String()), - zap.String("reason", resp.GetStatus().GetReason())) - return &milvuspb.ConnectResponse{ - Status: proto.Clone(resp.GetStatus()).(*commonpb.Status), - }, nil - } - if !funcutil.SliceContain(resp.GetDbNames(), 
db) { log.Info("connect failed, target database not exist") return &milvuspb.ConnectResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, // DatabaseNotExist? - Reason: fmt.Sprintf("database not found: %s", db), - }, + Status: merr.Status(merr.WrapErrDatabaseNotFound(db)), }, nil } @@ -4954,28 +5024,126 @@ func (node *Proxy) Connect(ctx context.Context, request *milvuspb.ConnectRequest GetConnectionManager().register(ctx, int64(ts), request.GetClientInfo()) return &milvuspb.ConnectResponse{ - Status: merr.Status(nil), + Status: merr.Success(), ServerInfo: serverInfo, Identifier: int64(ts), }, nil } +func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.ReplicateMessageRequest) (*milvuspb.ReplicateMessageResponse, error) { + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil + } + + if paramtable.Get().CommonCfg.TTMsgEnabled.GetAsBool() { + return &milvuspb.ReplicateMessageResponse{ + Status: merr.Status(merr.ErrDenyReplicateMessage), + }, nil + } + var err error + ctxLog := log.Ctx(ctx) + + if req.GetChannelName() == "" { + ctxLog.Warn("channel name is empty") + return &milvuspb.ReplicateMessageResponse{ + Status: merr.Status(merr.WrapErrParameterInvalidMsg("invalid channel name for the replicate message request")), + }, nil + } + + // get the latest position of the replicate msg channel + replicateMsgChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue() + if req.GetChannelName() == replicateMsgChannel { + msgID, err := msgstream.GetChannelLatestMsgID(ctx, node.factory, replicateMsgChannel) + if err != nil { + ctxLog.Warn("failed to get the latest message id of the replicate msg channel", zap.Error(err)) + return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil + } + position := base64.StdEncoding.EncodeToString(msgID) + return &milvuspb.ReplicateMessageResponse{Status: merr.Status(nil), Position: position}, 
nil + } + + msgPack := &msgstream.MsgPack{ + BeginTs: req.BeginTs, + EndTs: req.EndTs, + Msgs: make([]msgstream.TsMsg, 0), + StartPositions: req.StartPositions, + EndPositions: req.EndPositions, + } + // getTsMsgFromConsumerMsg + for i, msgBytes := range req.Msgs { + header := commonpb.MsgHeader{} + err = proto.Unmarshal(msgBytes, &header) + if err != nil { + ctxLog.Warn("failed to unmarshal msg header", zap.Int("index", i), zap.Error(err)) + return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil + } + if header.GetBase() == nil { + ctxLog.Warn("msg header base is nil", zap.Int("index", i)) + return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.ErrInvalidMsgBytes)}, nil + } + tsMsg, err := node.replicateStreamManager.GetMsgDispatcher().Unmarshal(msgBytes, header.GetBase().GetMsgType()) + if err != nil { + ctxLog.Warn("failed to unmarshal msg", zap.Int("index", i), zap.Error(err)) + return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.ErrInvalidMsgBytes)}, nil + } + switch realMsg := tsMsg.(type) { + case *msgstream.InsertMsg: + assignedSegmentInfos, err := node.segAssigner.GetSegmentID(realMsg.GetCollectionID(), realMsg.GetPartitionID(), + realMsg.GetShardName(), uint32(realMsg.NumRows), req.EndTs) + if err != nil { + ctxLog.Warn("failed to get segment id", zap.Error(err)) + return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil + } + if len(assignedSegmentInfos) == 0 { + ctxLog.Warn("no segment id assigned") + return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.ErrNoAssignSegmentID)}, nil + } + for assignSegmentID := range assignedSegmentInfos { + realMsg.SegmentID = assignSegmentID + break + } + } + msgPack.Msgs = append(msgPack.Msgs, tsMsg) + } + + msgStream, err := node.replicateStreamManager.GetReplicateMsgStream(ctx, req.ChannelName) + if err != nil { + ctxLog.Warn("failed to get msg stream from the replicate stream manager", zap.Error(err)) + return 
&milvuspb.ReplicateMessageResponse{ + Status: merr.Status(err), + }, nil + } + messageIDsMap, err := msgStream.Broadcast(msgPack) + if err != nil { + ctxLog.Warn("failed to produce msg", zap.Error(err)) + return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil + } + var position string + if len(messageIDsMap[req.GetChannelName()]) == 0 { + ctxLog.Warn("no message id returned") + } else { + messageIDs := messageIDsMap[req.GetChannelName()] + position = base64.StdEncoding.EncodeToString(messageIDs[len(messageIDs)-1].Serialize()) + } + return &milvuspb.ReplicateMessageResponse{Status: merr.Status(nil), Position: position}, nil +} + func (node *Proxy) ListClientInfos(ctx context.Context, req *proxypb.ListClientInfosRequest) (*proxypb.ListClientInfosResponse, error) { - if !node.checkHealthy() { - return &proxypb.ListClientInfosResponse{Status: unhealthyStatus()}, nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return &proxypb.ListClientInfosResponse{Status: merr.Status(err)}, nil } clients := GetConnectionManager().list() return &proxypb.ListClientInfosResponse{ - Status: merr.Status(nil), + Status: merr.Success(), ClientInfos: clients, }, nil } func (node *Proxy) AllocTimestamp(ctx context.Context, req *milvuspb.AllocTimestampRequest) (*milvuspb.AllocTimestampResponse, error) { - if !node.checkHealthy() { - return &milvuspb.AllocTimestampResponse{Status: unhealthyStatus()}, nil + if err := merr.CheckHealthy(node.GetStateCode()); err != nil { + return &milvuspb.AllocTimestampResponse{Status: merr.Status(err)}, nil } log.Info("AllocTimestamp request receive") @@ -4990,7 +5158,14 @@ func (node *Proxy) AllocTimestamp(ctx context.Context, req *milvuspb.AllocTimest log.Info("AllocTimestamp request success", zap.Uint64("timestamp", ts)) return &milvuspb.AllocTimestampResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Timestamp: ts, }, nil } + +func (node *Proxy) GetVersion(ctx context.Context, request 
*milvuspb.GetVersionRequest) (*milvuspb.GetVersionResponse, error) { + // TODO implement me + return &milvuspb.GetVersionResponse{ + Status: merr.Success(), + }, nil +} diff --git a/internal/proxy/impl_test.go b/internal/proxy/impl_test.go index ca02c7dc91e8f..ffab642dbb4f5 100644 --- a/internal/proxy/impl_test.go +++ b/internal/proxy/impl_test.go @@ -18,15 +18,20 @@ package proxy import ( "context" + "encoding/base64" "testing" + "time" "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "go.uber.org/zap" + "google.golang.org/grpc" "google.golang.org/grpc/metadata" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/proxypb" @@ -35,7 +40,12 @@ import ( "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" + "github.com/milvus-io/milvus/pkg/util/commonpbutil" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/resource" ) func TestProxy_InvalidateCollectionMetaCache_remove_stream(t *testing.T) { @@ -44,14 +54,11 @@ func TestProxy_InvalidateCollectionMetaCache_remove_stream(t *testing.T) { globalMetaCache = nil defer func() { globalMetaCache = cache }() - chMgr := newMockChannelsMgr() - chMgr.removeDMLStreamFuncType = func(collectionID UniqueID) error { - log.Debug("TestProxy_InvalidateCollectionMetaCache_remove_stream, remove dml stream") - return nil - } + chMgr := NewMockChannelsMgr(t) + chMgr.EXPECT().removeDMLStream(mock.Anything).Return() node := &Proxy{chMgr: chMgr} - 
node.stateCode.Store(commonpb.StateCode_Healthy) + node.UpdateStateCode(commonpb.StateCode_Healthy) ctx := context.Background() req := &proxypb.InvalidateCollMetaCacheRequest{ @@ -65,9 +72,9 @@ func TestProxy_InvalidateCollectionMetaCache_remove_stream(t *testing.T) { func TestProxy_CheckHealth(t *testing.T) { t.Run("not healthy", func(t *testing.T) { - node := &Proxy{session: &sessionutil.Session{ServerID: 1}} + node := &Proxy{session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}} node.multiRateLimiter = NewMultiRateLimiter() - node.stateCode.Store(commonpb.StateCode_Abnormal) + node.UpdateStateCode(commonpb.StateCode_Abnormal) ctx := context.Background() resp, err := node.CheckHealth(ctx, &milvuspb.CheckHealthRequest{}) assert.NoError(t, err) @@ -76,16 +83,16 @@ func TestProxy_CheckHealth(t *testing.T) { }) t.Run("proxy health check is ok", func(t *testing.T) { - qc := &mocks.MockQueryCoord{} + qc := &mocks.MockQueryCoordClient{} qc.EXPECT().CheckHealth(mock.Anything, mock.Anything).Return(&milvuspb.CheckHealthResponse{IsHealthy: true}, nil) node := &Proxy{ rootCoord: NewRootCoordMock(), queryCoord: qc, dataCoord: NewDataCoordMock(), - session: &sessionutil.Session{ServerID: 1}, + session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}, } node.multiRateLimiter = NewMultiRateLimiter() - node.stateCode.Store(commonpb.StateCode_Healthy) + node.UpdateStateCode(commonpb.StateCode_Healthy) ctx := context.Background() resp, err := node.CheckHealth(ctx, &milvuspb.CheckHealthRequest{}) assert.NoError(t, err) @@ -95,7 +102,9 @@ func TestProxy_CheckHealth(t *testing.T) { t.Run("proxy health check is fail", func(t *testing.T) { checkHealthFunc1 := func(ctx context.Context, - req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { + req *milvuspb.CheckHealthRequest, + opts ...grpc.CallOption, + ) (*milvuspb.CheckHealthResponse, error) { return &milvuspb.CheckHealthResponse{ IsHealthy: false, Reasons: 
[]string{"unHealth"}, @@ -105,17 +114,18 @@ func TestProxy_CheckHealth(t *testing.T) { dataCoordMock := NewDataCoordMock() dataCoordMock.checkHealthFunc = checkHealthFunc1 - qc := &mocks.MockQueryCoord{} + qc := &mocks.MockQueryCoordClient{} qc.EXPECT().CheckHealth(mock.Anything, mock.Anything).Return(nil, errors.New("test")) node := &Proxy{ - session: &sessionutil.Session{ServerID: 1}, + session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}, rootCoord: NewRootCoordMock(func(mock *RootCoordMock) { mock.checkHealthFunc = checkHealthFunc1 }), queryCoord: qc, - dataCoord: dataCoordMock} + dataCoord: dataCoordMock, + } node.multiRateLimiter = NewMultiRateLimiter() - node.stateCode.Store(commonpb.StateCode_Healthy) + node.UpdateStateCode(commonpb.StateCode_Healthy) ctx := context.Background() resp, err := node.CheckHealth(ctx, &milvuspb.CheckHealthRequest{}) assert.NoError(t, err) @@ -124,7 +134,7 @@ func TestProxy_CheckHealth(t *testing.T) { }) t.Run("check quota state", func(t *testing.T) { - qc := &mocks.MockQueryCoord{} + qc := &mocks.MockQueryCoordClient{} qc.EXPECT().CheckHealth(mock.Anything, mock.Anything).Return(&milvuspb.CheckHealthResponse{IsHealthy: true}, nil) node := &Proxy{ rootCoord: NewRootCoordMock(), @@ -132,7 +142,7 @@ func TestProxy_CheckHealth(t *testing.T) { queryCoord: qc, } node.multiRateLimiter = NewMultiRateLimiter() - node.stateCode.Store(commonpb.StateCode_Healthy) + node.UpdateStateCode(commonpb.StateCode_Healthy) resp, err := node.CheckHealth(context.Background(), &milvuspb.CheckHealthRequest{}) assert.NoError(t, err) assert.Equal(t, true, resp.IsHealthy) @@ -158,32 +168,32 @@ func TestProxy_CheckHealth(t *testing.T) { func TestProxyRenameCollection(t *testing.T) { t.Run("not healthy", func(t *testing.T) { - node := &Proxy{session: &sessionutil.Session{ServerID: 1}} - node.stateCode.Store(commonpb.StateCode_Abnormal) + node := &Proxy{session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}} + 
node.UpdateStateCode(commonpb.StateCode_Abnormal) ctx := context.Background() resp, err := node.RenameCollection(ctx, &milvuspb.RenameCollectionRequest{}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetErrorCode()) + assert.ErrorIs(t, merr.Error(resp), merr.ErrServiceNotReady) }) t.Run("rename with illegal new collection name", func(t *testing.T) { - node := &Proxy{session: &sessionutil.Session{ServerID: 1}} - node.stateCode.Store(commonpb.StateCode_Healthy) + node := &Proxy{session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}} + node.UpdateStateCode(commonpb.StateCode_Healthy) ctx := context.Background() resp, err := node.RenameCollection(ctx, &milvuspb.RenameCollectionRequest{NewName: "$#^%#&#$*!)#@!"}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_IllegalCollectionName, resp.GetErrorCode()) + assert.ErrorIs(t, merr.Error(resp), merr.ErrParameterInvalid) }) t.Run("rename fail", func(t *testing.T) { - rc := mocks.NewRootCoord(t) + rc := mocks.NewMockRootCoordClient(t) rc.On("RenameCollection", mock.Anything, mock.Anything). Return(nil, errors.New("fail")) node := &Proxy{ - session: &sessionutil.Session{ServerID: 1}, + session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}, rootCoord: rc, } - node.stateCode.Store(commonpb.StateCode_Healthy) + node.UpdateStateCode(commonpb.StateCode_Healthy) ctx := context.Background() resp, err := node.RenameCollection(ctx, &milvuspb.RenameCollectionRequest{NewName: "new"}) @@ -192,16 +202,14 @@ func TestProxyRenameCollection(t *testing.T) { }) t.Run("rename ok", func(t *testing.T) { - rc := mocks.NewRootCoord(t) + rc := mocks.NewMockRootCoordClient(t) rc.On("RenameCollection", mock.Anything, mock.Anything). 
- Return(&commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil) + Return(merr.Success(), nil) node := &Proxy{ - session: &sessionutil.Session{ServerID: 1}, + session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}, rootCoord: rc, } - node.stateCode.Store(commonpb.StateCode_Healthy) + node.UpdateStateCode(commonpb.StateCode_Healthy) ctx := context.Background() resp, err := node.RenameCollection(ctx, &milvuspb.RenameCollectionRequest{NewName: "new"}) @@ -217,9 +225,9 @@ func TestProxy_ResourceGroup(t *testing.T) { node, err := NewProxy(ctx, factory) assert.NoError(t, err) node.multiRateLimiter = NewMultiRateLimiter() - node.stateCode.Store(commonpb.StateCode_Healthy) + node.UpdateStateCode(commonpb.StateCode_Healthy) - qc := mocks.NewMockQueryCoord(t) + qc := mocks.NewMockQueryCoordClient(t) node.SetQueryCoordClient(qc) tsoAllocatorIns := newMockTsoAllocator() @@ -279,7 +287,7 @@ func TestProxy_ResourceGroup(t *testing.T) { qc.EXPECT().ListResourceGroups(mock.Anything, mock.Anything).Return(&milvuspb.ListResourceGroupsResponse{Status: successStatus}, nil) resp, err := node.ListResourceGroups(ctx, &milvuspb.ListResourceGroupsRequest{}) assert.NoError(t, err) - assert.Equal(t, resp.Status.ErrorCode, commonpb.ErrorCode_Success) + assert.True(t, merr.Ok(resp.GetStatus())) }) t.Run("describe resource group", func(t *testing.T) { @@ -298,7 +306,7 @@ func TestProxy_ResourceGroup(t *testing.T) { ResourceGroup: "rg", }) assert.NoError(t, err) - assert.Equal(t, resp.Status.ErrorCode, commonpb.ErrorCode_Success) + assert.True(t, merr.Ok(resp.GetStatus())) }) } @@ -309,9 +317,9 @@ func TestProxy_InvalidResourceGroupName(t *testing.T) { node, err := NewProxy(ctx, factory) assert.NoError(t, err) node.multiRateLimiter = NewMultiRateLimiter() - node.stateCode.Store(commonpb.StateCode_Healthy) + node.UpdateStateCode(commonpb.StateCode_Healthy) - qc := mocks.NewMockQueryCoord(t) + qc := mocks.NewMockQueryCoordClient(t) node.SetQueryCoordClient(qc) 
qc.EXPECT().DropResourceGroup(mock.Anything, mock.Anything).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil) @@ -330,7 +338,7 @@ func TestProxy_InvalidResourceGroupName(t *testing.T) { ResourceGroup: "...", }) assert.NoError(t, err) - assert.Equal(t, resp.ErrorCode, commonpb.ErrorCode_IllegalArgument) + assert.ErrorIs(t, merr.Error(resp), merr.ErrParameterInvalid) }) t.Run("drop resource group", func(t *testing.T) { @@ -348,7 +356,7 @@ func TestProxy_InvalidResourceGroupName(t *testing.T) { NumNode: 1, }) assert.NoError(t, err) - assert.Equal(t, resp.ErrorCode, commonpb.ErrorCode_IllegalArgument) + assert.ErrorIs(t, merr.Error(resp), merr.ErrParameterInvalid) }) t.Run("transfer replica", func(t *testing.T) { @@ -359,7 +367,7 @@ func TestProxy_InvalidResourceGroupName(t *testing.T) { CollectionName: "collection1", }) assert.NoError(t, err) - assert.Equal(t, resp.ErrorCode, commonpb.ErrorCode_IllegalArgument) + assert.ErrorIs(t, merr.Error(resp), merr.ErrParameterInvalid) }) } @@ -373,55 +381,61 @@ func TestProxy_FlushAll_DbCollection(t *testing.T) { {"flushAll set db", &milvuspb.FlushAllRequest{DbName: "default"}, true}, {"flushAll set db, db not exist", &milvuspb.FlushAllRequest{DbName: "default2"}, false}, } - for _, test := range tests { - factory := dependency.NewDefaultFactory(true) - ctx := context.Background() - paramtable.Init() - node, err := NewProxy(ctx, factory) - assert.NoError(t, err) - node.stateCode.Store(commonpb.StateCode_Healthy) - node.tsoAllocator = ×tampAllocator{ - tso: newMockTimestampAllocatorInterface(), - } + cacheBak := globalMetaCache + defer func() { globalMetaCache = cacheBak }() + // set expectations + cache := NewMockCache(t) + cache.On("GetCollectionID", + mock.Anything, // context.Context + mock.AnythingOfType("string"), + mock.AnythingOfType("string"), + ).Return(UniqueID(0), nil).Maybe() - Params.Save(Params.ProxyCfg.MaxTaskNum.Key, "1000") - node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, 
node.factory) - assert.NoError(t, err) - err = node.sched.Start() - assert.NoError(t, err) - defer node.sched.Close() - node.dataCoord = mocks.NewMockDataCoord(t) - node.rootCoord = mocks.NewRootCoord(t) - successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success} + cache.On("RemoveDatabase", + mock.Anything, // context.Context + mock.AnythingOfType("string"), + ).Maybe() - // set expectations - cache := NewMockCache(t) - cache.On("GetCollectionID", - mock.Anything, // context.Context - mock.AnythingOfType("string"), - mock.AnythingOfType("string"), - ).Return(UniqueID(0), nil).Maybe() + globalMetaCache = cache - cache.On("RemoveDatabase", - mock.Anything, // context.Context - mock.AnythingOfType("string"), - ).Maybe() + for _, test := range tests { + t.Run(test.testName, func(t *testing.T) { + factory := dependency.NewDefaultFactory(true) + ctx := context.Background() + paramtable.Init() - globalMetaCache = cache + node, err := NewProxy(ctx, factory) + assert.NoError(t, err) + node.UpdateStateCode(commonpb.StateCode_Healthy) + node.tsoAllocator = ×tampAllocator{ + tso: newMockTimestampAllocatorInterface(), + } + rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue() + node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx) + assert.NoError(t, err) + node.replicateMsgStream.AsProducer([]string{rpcRequestChannel}) - node.dataCoord.(*mocks.MockDataCoord).EXPECT().Flush(mock.Anything, mock.Anything). - Return(&datapb.FlushResponse{Status: successStatus}, nil).Maybe() - node.rootCoord.(*mocks.RootCoord).EXPECT().ShowCollections(mock.Anything, mock.Anything). - Return(&milvuspb.ShowCollectionsResponse{Status: successStatus, CollectionNames: []string{"col-0"}}, nil).Maybe() - node.rootCoord.(*mocks.RootCoord).EXPECT().ListDatabases(mock.Anything, mock.Anything). 
- Return(&milvuspb.ListDatabasesResponse{Status: successStatus, DbNames: []string{"default"}}, nil).Maybe() + Params.Save(Params.ProxyCfg.MaxTaskNum.Key, "1000") + node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory) + assert.NoError(t, err) + err = node.sched.Start() + assert.NoError(t, err) + defer node.sched.Close() + node.dataCoord = mocks.NewMockDataCoordClient(t) + node.rootCoord = mocks.NewMockRootCoordClient(t) + successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success} + node.dataCoord.(*mocks.MockDataCoordClient).EXPECT().Flush(mock.Anything, mock.Anything). + Return(&datapb.FlushResponse{Status: successStatus}, nil).Maybe() + node.rootCoord.(*mocks.MockRootCoordClient).EXPECT().ShowCollections(mock.Anything, mock.Anything). + Return(&milvuspb.ShowCollectionsResponse{Status: successStatus, CollectionNames: []string{"col-0"}}, nil).Maybe() + node.rootCoord.(*mocks.MockRootCoordClient).EXPECT().ListDatabases(mock.Anything, mock.Anything). + Return(&milvuspb.ListDatabasesResponse{Status: successStatus, DbNames: []string{"default"}}, nil).Maybe() - t.Run(test.testName, func(t *testing.T) { resp, err := node.FlushAll(ctx, test.FlushRequest) assert.NoError(t, err) if test.ExpectedSuccess { - assert.Equal(t, resp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success) + assert.True(t, merr.Ok(resp.GetStatus())) } else { assert.NotEqual(t, resp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success) } @@ -436,10 +450,14 @@ func TestProxy_FlushAll(t *testing.T) { node, err := NewProxy(ctx, factory) assert.NoError(t, err) - node.stateCode.Store(commonpb.StateCode_Healthy) + node.UpdateStateCode(commonpb.StateCode_Healthy) node.tsoAllocator = ×tampAllocator{ tso: newMockTimestampAllocatorInterface(), } + rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue() + node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx) + assert.NoError(t, err) + node.replicateMsgStream.AsProducer([]string{rpcRequestChannel}) 
Params.Save(Params.ProxyCfg.MaxTaskNum.Key, "1000") node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory) @@ -447,8 +465,11 @@ func TestProxy_FlushAll(t *testing.T) { err = node.sched.Start() assert.NoError(t, err) defer node.sched.Close() - node.dataCoord = mocks.NewMockDataCoord(t) - node.rootCoord = mocks.NewRootCoord(t) + node.dataCoord = mocks.NewMockDataCoordClient(t) + node.rootCoord = mocks.NewMockRootCoordClient(t) + + cacheBak := globalMetaCache + defer func() { globalMetaCache = cacheBak }() // set expectations cache := NewMockCache(t) @@ -465,25 +486,25 @@ func TestProxy_FlushAll(t *testing.T) { globalMetaCache = cache successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success} - node.dataCoord.(*mocks.MockDataCoord).EXPECT().Flush(mock.Anything, mock.Anything). + node.dataCoord.(*mocks.MockDataCoordClient).EXPECT().Flush(mock.Anything, mock.Anything). Return(&datapb.FlushResponse{Status: successStatus}, nil).Maybe() - node.rootCoord.(*mocks.RootCoord).EXPECT().ShowCollections(mock.Anything, mock.Anything). + node.rootCoord.(*mocks.MockRootCoordClient).EXPECT().ShowCollections(mock.Anything, mock.Anything). Return(&milvuspb.ShowCollectionsResponse{Status: successStatus, CollectionNames: []string{"col-0"}}, nil).Maybe() - node.rootCoord.(*mocks.RootCoord).EXPECT().ListDatabases(mock.Anything, mock.Anything). + node.rootCoord.(*mocks.MockRootCoordClient).EXPECT().ListDatabases(mock.Anything, mock.Anything). 
Return(&milvuspb.ListDatabasesResponse{Status: successStatus, DbNames: []string{"default"}}, nil).Maybe() t.Run("FlushAll", func(t *testing.T) { resp, err := node.FlushAll(ctx, &milvuspb.FlushAllRequest{}) assert.NoError(t, err) - assert.Equal(t, resp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success) + assert.True(t, merr.Ok(resp.GetStatus())) }) t.Run("FlushAll failed, server is abnormal", func(t *testing.T) { - node.stateCode.Store(commonpb.StateCode_Abnormal) + node.UpdateStateCode(commonpb.StateCode_Abnormal) resp, err := node.FlushAll(ctx, &milvuspb.FlushAllRequest{}) assert.NoError(t, err) - assert.Equal(t, resp.GetStatus().GetErrorCode(), commonpb.ErrorCode_UnexpectedError) - node.stateCode.Store(commonpb.StateCode_Healthy) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) + node.UpdateStateCode(commonpb.StateCode_Healthy) }) t.Run("FlushAll failed, get id failed", func(t *testing.T) { @@ -503,8 +524,8 @@ func TestProxy_FlushAll(t *testing.T) { }) t.Run("FlushAll failed, DataCoord flush failed", func(t *testing.T) { - node.dataCoord.(*mocks.MockDataCoord).ExpectedCalls = nil - node.dataCoord.(*mocks.MockDataCoord).EXPECT().Flush(mock.Anything, mock.Anything). + node.dataCoord.(*mocks.MockDataCoordClient).ExpectedCalls = nil + node.dataCoord.(*mocks.MockDataCoordClient).EXPECT().Flush(mock.Anything, mock.Anything). Return(&datapb.FlushResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -517,10 +538,10 @@ func TestProxy_FlushAll(t *testing.T) { }) t.Run("FlushAll failed, RootCoord showCollections failed", func(t *testing.T) { - node.rootCoord.(*mocks.RootCoord).ExpectedCalls = nil - node.rootCoord.(*mocks.RootCoord).EXPECT().ListDatabases(mock.Anything, mock.Anything). + node.rootCoord.(*mocks.MockRootCoordClient).ExpectedCalls = nil + node.rootCoord.(*mocks.MockRootCoordClient).EXPECT().ListDatabases(mock.Anything, mock.Anything). 
Return(&milvuspb.ListDatabasesResponse{Status: successStatus, DbNames: []string{"default"}}, nil).Maybe() - node.rootCoord.(*mocks.RootCoord).EXPECT().ShowCollections(mock.Anything, mock.Anything). + node.rootCoord.(*mocks.MockRootCoordClient).EXPECT().ShowCollections(mock.Anything, mock.Anything). Return(&milvuspb.ShowCollectionsResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -533,8 +554,8 @@ func TestProxy_FlushAll(t *testing.T) { }) t.Run("FlushAll failed, RootCoord showCollections failed", func(t *testing.T) { - node.rootCoord.(*mocks.RootCoord).ExpectedCalls = nil - node.rootCoord.(*mocks.RootCoord).EXPECT().ListDatabases(mock.Anything, mock.Anything). + node.rootCoord.(*mocks.MockRootCoordClient).ExpectedCalls = nil + node.rootCoord.(*mocks.MockRootCoordClient).EXPECT().ListDatabases(mock.Anything, mock.Anything). Return(&milvuspb.ListDatabasesResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -553,35 +574,35 @@ func TestProxy_GetFlushAllState(t *testing.T) { node, err := NewProxy(ctx, factory) assert.NoError(t, err) - node.stateCode.Store(commonpb.StateCode_Healthy) + node.UpdateStateCode(commonpb.StateCode_Healthy) node.tsoAllocator = ×tampAllocator{ tso: newMockTimestampAllocatorInterface(), } - node.dataCoord = mocks.NewMockDataCoord(t) - node.rootCoord = mocks.NewRootCoord(t) + node.dataCoord = mocks.NewMockDataCoordClient(t) + node.rootCoord = mocks.NewMockRootCoordClient(t) // set expectations successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success} - node.dataCoord.(*mocks.MockDataCoord).EXPECT().GetFlushAllState(mock.Anything, mock.Anything). + node.dataCoord.(*mocks.MockDataCoordClient).EXPECT().GetFlushAllState(mock.Anything, mock.Anything). 
Return(&milvuspb.GetFlushAllStateResponse{Status: successStatus}, nil).Maybe() t.Run("GetFlushAllState success", func(t *testing.T) { resp, err := node.GetFlushAllState(ctx, &milvuspb.GetFlushAllStateRequest{}) assert.NoError(t, err) - assert.Equal(t, resp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success) + assert.True(t, merr.Ok(resp.GetStatus())) }) t.Run("GetFlushAllState failed, server is abnormal", func(t *testing.T) { - node.stateCode.Store(commonpb.StateCode_Abnormal) + node.UpdateStateCode(commonpb.StateCode_Abnormal) resp, err := node.GetFlushAllState(ctx, &milvuspb.GetFlushAllStateRequest{}) assert.NoError(t, err) - assert.Equal(t, resp.GetStatus().GetErrorCode(), commonpb.ErrorCode_UnexpectedError) - node.stateCode.Store(commonpb.StateCode_Healthy) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) + node.UpdateStateCode(commonpb.StateCode_Healthy) }) t.Run("DataCoord GetFlushAllState failed", func(t *testing.T) { - node.dataCoord.(*mocks.MockDataCoord).ExpectedCalls = nil - node.dataCoord.(*mocks.MockDataCoord).EXPECT().GetFlushAllState(mock.Anything, mock.Anything). + node.dataCoord.(*mocks.MockDataCoordClient).ExpectedCalls = nil + node.dataCoord.(*mocks.MockDataCoordClient).EXPECT().GetFlushAllState(mock.Anything, mock.Anything). 
Return(&milvuspb.GetFlushAllStateResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -594,18 +615,98 @@ func TestProxy_GetFlushAllState(t *testing.T) { }) } +func TestProxy_GetFlushState(t *testing.T) { + factory := dependency.NewDefaultFactory(true) + ctx := context.Background() + + node, err := NewProxy(ctx, factory) + assert.NoError(t, err) + node.UpdateStateCode(commonpb.StateCode_Healthy) + node.tsoAllocator = ×tampAllocator{ + tso: newMockTimestampAllocatorInterface(), + } + node.dataCoord = mocks.NewMockDataCoordClient(t) + node.rootCoord = mocks.NewMockRootCoordClient(t) + + // set expectations + successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success} + node.dataCoord.(*mocks.MockDataCoordClient).EXPECT().GetFlushState(mock.Anything, mock.Anything, mock.Anything). + Return(&milvuspb.GetFlushStateResponse{Status: successStatus}, nil).Maybe() + + t.Run("GetFlushState success", func(t *testing.T) { + resp, err := node.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{}) + assert.NoError(t, err) + assert.Equal(t, resp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success) + }) + + t.Run("GetFlushState failed, server is abnormal", func(t *testing.T) { + node.UpdateStateCode(commonpb.StateCode_Abnormal) + resp, err := node.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{}) + assert.NoError(t, err) + assert.Equal(t, resp.GetStatus().GetErrorCode(), commonpb.ErrorCode_NotReadyServe) + node.UpdateStateCode(commonpb.StateCode_Healthy) + }) + + t.Run("GetFlushState with collection name", func(t *testing.T) { + resp, err := node.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{ + CollectionName: "*", + }) + assert.NoError(t, err) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrParameterInvalid) + + cacheBak := globalMetaCache + defer func() { globalMetaCache = cacheBak }() + cache := NewMockCache(t) + cache.On("GetCollectionID", + mock.Anything, // context.Context + mock.AnythingOfType("string"), + 
mock.AnythingOfType("string"), + ).Return(UniqueID(0), nil).Maybe() + globalMetaCache = cache + + resp, err = node.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{ + CollectionName: "collection1", + }) + assert.NoError(t, err) + assert.Equal(t, resp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success) + }) + + t.Run("DataCoord GetFlushState failed", func(t *testing.T) { + node.dataCoord.(*mocks.MockDataCoordClient).ExpectedCalls = nil + node.dataCoord.(*mocks.MockDataCoordClient).EXPECT().GetFlushState(mock.Anything, mock.Anything, mock.Anything). + Return(&milvuspb.GetFlushStateResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "mock err", + }, + }, nil) + resp, err := node.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{}) + assert.NoError(t, err) + assert.Equal(t, resp.GetStatus().GetErrorCode(), commonpb.ErrorCode_UnexpectedError) + }) + + t.Run("GetFlushState return error", func(t *testing.T) { + node.dataCoord.(*mocks.MockDataCoordClient).ExpectedCalls = nil + node.dataCoord.(*mocks.MockDataCoordClient).EXPECT().GetFlushState(mock.Anything, mock.Anything, mock.Anything). 
+ Return(nil, errors.New("fake error")) + resp, err := node.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{}) + assert.NoError(t, err) + assert.Equal(t, resp.GetStatus().GetErrorCode(), commonpb.ErrorCode_UnexpectedError) + }) +} + func TestProxy_GetReplicas(t *testing.T) { factory := dependency.NewDefaultFactory(true) ctx := context.Background() node, err := NewProxy(ctx, factory) assert.NoError(t, err) - node.stateCode.Store(commonpb.StateCode_Healthy) + node.UpdateStateCode(commonpb.StateCode_Healthy) node.tsoAllocator = ×tampAllocator{ tso: newMockTimestampAllocatorInterface(), } - mockQC := mocks.NewMockQueryCoord(t) - mockRC := mocks.NewRootCoord(t) + mockQC := mocks.NewMockQueryCoordClient(t) + mockRC := mocks.NewMockRootCoordClient(t) node.queryCoord = mockQC node.rootCoord = mockRC @@ -617,17 +718,17 @@ func TestProxy_GetReplicas(t *testing.T) { CollectionID: 1000, }) assert.NoError(t, err) - assert.Equal(t, resp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success) + assert.True(t, merr.Ok(resp.GetStatus())) }) t.Run("proxy_not_healthy", func(t *testing.T) { - node.stateCode.Store(commonpb.StateCode_Abnormal) + node.UpdateStateCode(commonpb.StateCode_Abnormal) resp, err := node.GetReplicas(ctx, &milvuspb.GetReplicasRequest{ CollectionID: 1000, }) assert.NoError(t, err) - assert.Equal(t, resp.GetStatus().GetErrorCode(), commonpb.ErrorCode_UnexpectedError) - node.stateCode.Store(commonpb.StateCode_Healthy) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) + node.UpdateStateCode(commonpb.StateCode_Healthy) }) t.Run("QueryCoordClient_returnsError", func(t *testing.T) { @@ -653,7 +754,7 @@ func TestProxy_Connect(t *testing.T) { }) t.Run("failed to list database", func(t *testing.T) { - r := mocks.NewRootCoord(t) + r := mocks.NewMockRootCoordClient(t) r.On("ListDatabases", mock.Anything, mock.Anything, @@ -668,12 +769,12 @@ func TestProxy_Connect(t *testing.T) { }) t.Run("list database error", func(t *testing.T) { - r := 
mocks.NewRootCoord(t) + r := mocks.NewMockRootCoordClient(t) r.On("ListDatabases", mock.Anything, mock.Anything, ).Return(&milvuspb.ListDatabasesResponse{ - Status: unhealthyStatus(), + Status: merr.Status(merr.WrapErrServiceNotReady(paramtable.GetRole(), paramtable.GetNodeID(), "initialization")), }, nil) node := &Proxy{rootCoord: r} @@ -690,14 +791,12 @@ func TestProxy_Connect(t *testing.T) { }) ctx := metadata.NewIncomingContext(context.TODO(), md) - r := mocks.NewRootCoord(t) + r := mocks.NewMockRootCoordClient(t) r.On("ListDatabases", mock.Anything, mock.Anything, ).Return(&milvuspb.ListDatabasesResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), DbNames: []string{}, }, nil) @@ -715,14 +814,12 @@ func TestProxy_Connect(t *testing.T) { }) ctx := metadata.NewIncomingContext(context.TODO(), md) - r := mocks.NewRootCoord(t) + r := mocks.NewMockRootCoordClient(t) r.On("ListDatabases", mock.Anything, mock.Anything, ).Return(&milvuspb.ListDatabasesResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), DbNames: []string{"20230525"}, }, nil) @@ -748,14 +845,12 @@ func TestProxy_Connect(t *testing.T) { }) ctx := metadata.NewIncomingContext(context.TODO(), md) - r := mocks.NewRootCoord(t) + r := mocks.NewMockRootCoordClient(t) r.On("ListDatabases", mock.Anything, mock.Anything, ).Return(&milvuspb.ListDatabasesResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), DbNames: []string{"20230525"}, }, nil) @@ -764,9 +859,7 @@ func TestProxy_Connect(t *testing.T) { mock.Anything, mock.Anything, ).Return(&rootcoordpb.AllocTimestampResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), Timestamp: 20230518, Count: 1, }, nil) @@ -801,7 +894,6 @@ func TestProxy_ListClientInfos(t *testing.T) { resp, err := node.ListClientInfos(context.TODO(), nil) 
assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) - }) } @@ -809,12 +901,12 @@ func TestProxyCreateDatabase(t *testing.T) { paramtable.Init() t.Run("not healthy", func(t *testing.T) { - node := &Proxy{session: &sessionutil.Session{ServerID: 1}} - node.stateCode.Store(commonpb.StateCode_Abnormal) + node := &Proxy{session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}} + node.UpdateStateCode(commonpb.StateCode_Abnormal) ctx := context.Background() resp, err := node.CreateDatabase(ctx, &milvuspb.CreateDatabaseRequest{}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetErrorCode()) + assert.ErrorIs(t, merr.Error(resp), merr.ErrServiceNotReady) }) factory := dependency.NewDefaultFactory(true) @@ -826,7 +918,7 @@ func TestProxyCreateDatabase(t *testing.T) { tso: newMockTimestampAllocatorInterface(), } node.multiRateLimiter = NewMultiRateLimiter() - node.stateCode.Store(commonpb.StateCode_Healthy) + node.UpdateStateCode(commonpb.StateCode_Healthy) node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory) node.sched.ddQueue.setMaxTaskNum(10) assert.NoError(t, err) @@ -834,8 +926,13 @@ func TestProxyCreateDatabase(t *testing.T) { assert.NoError(t, err) defer node.sched.Close() + rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue() + node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx) + assert.NoError(t, err) + node.replicateMsgStream.AsProducer([]string{rpcRequestChannel}) + t.Run("create database fail", func(t *testing.T) { - rc := mocks.NewRootCoord(t) + rc := mocks.NewMockRootCoordClient(t) rc.On("CreateDatabase", mock.Anything, mock.Anything). 
Return(nil, errors.New("fail")) node.rootCoord = rc @@ -846,13 +943,11 @@ func TestProxyCreateDatabase(t *testing.T) { }) t.Run("create database ok", func(t *testing.T) { - rc := mocks.NewRootCoord(t) + rc := mocks.NewMockRootCoordClient(t) rc.On("CreateDatabase", mock.Anything, mock.Anything). - Return(&commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil) + Return(merr.Success(), nil) node.rootCoord = rc - node.stateCode.Store(commonpb.StateCode_Healthy) + node.UpdateStateCode(commonpb.StateCode_Healthy) ctx := context.Background() resp, err := node.CreateDatabase(ctx, &milvuspb.CreateDatabaseRequest{DbName: "db"}) @@ -865,12 +960,12 @@ func TestProxyDropDatabase(t *testing.T) { paramtable.Init() t.Run("not healthy", func(t *testing.T) { - node := &Proxy{session: &sessionutil.Session{ServerID: 1}} - node.stateCode.Store(commonpb.StateCode_Abnormal) + node := &Proxy{session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}} + node.UpdateStateCode(commonpb.StateCode_Abnormal) ctx := context.Background() resp, err := node.DropDatabase(ctx, &milvuspb.DropDatabaseRequest{}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetErrorCode()) + assert.ErrorIs(t, merr.Error(resp), merr.ErrServiceNotReady) }) factory := dependency.NewDefaultFactory(true) @@ -882,7 +977,7 @@ func TestProxyDropDatabase(t *testing.T) { tso: newMockTimestampAllocatorInterface(), } node.multiRateLimiter = NewMultiRateLimiter() - node.stateCode.Store(commonpb.StateCode_Healthy) + node.UpdateStateCode(commonpb.StateCode_Healthy) node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory) node.sched.ddQueue.setMaxTaskNum(10) assert.NoError(t, err) @@ -890,8 +985,13 @@ func TestProxyDropDatabase(t *testing.T) { assert.NoError(t, err) defer node.sched.Close() + rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue() + node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx) + assert.NoError(t, err) 
+ node.replicateMsgStream.AsProducer([]string{rpcRequestChannel}) + t.Run("drop database fail", func(t *testing.T) { - rc := mocks.NewRootCoord(t) + rc := mocks.NewMockRootCoordClient(t) rc.On("DropDatabase", mock.Anything, mock.Anything). Return(nil, errors.New("fail")) node.rootCoord = rc @@ -902,13 +1002,11 @@ func TestProxyDropDatabase(t *testing.T) { }) t.Run("drop database ok", func(t *testing.T) { - rc := mocks.NewRootCoord(t) + rc := mocks.NewMockRootCoordClient(t) rc.On("DropDatabase", mock.Anything, mock.Anything). - Return(&commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil) + Return(merr.Success(), nil) node.rootCoord = rc - node.stateCode.Store(commonpb.StateCode_Healthy) + node.UpdateStateCode(commonpb.StateCode_Healthy) ctx := context.Background() resp, err := node.DropDatabase(ctx, &milvuspb.DropDatabaseRequest{DbName: "db"}) @@ -921,12 +1019,12 @@ func TestProxyListDatabase(t *testing.T) { paramtable.Init() t.Run("not healthy", func(t *testing.T) { - node := &Proxy{session: &sessionutil.Session{ServerID: 1}} - node.stateCode.Store(commonpb.StateCode_Abnormal) + node := &Proxy{session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}} + node.UpdateStateCode(commonpb.StateCode_Abnormal) ctx := context.Background() resp, err := node.ListDatabases(ctx, &milvuspb.ListDatabasesRequest{}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetStatus().GetErrorCode()) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) }) factory := dependency.NewDefaultFactory(true) @@ -938,7 +1036,7 @@ func TestProxyListDatabase(t *testing.T) { tso: newMockTimestampAllocatorInterface(), } node.multiRateLimiter = NewMultiRateLimiter() - node.stateCode.Store(commonpb.StateCode_Healthy) + node.UpdateStateCode(commonpb.StateCode_Healthy) node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory) node.sched.ddQueue.setMaxTaskNum(10) assert.NoError(t, err) @@ -947,7 +1045,7 
@@ func TestProxyListDatabase(t *testing.T) { defer node.sched.Close() t.Run("list database fail", func(t *testing.T) { - rc := mocks.NewRootCoord(t) + rc := mocks.NewMockRootCoordClient(t) rc.On("ListDatabases", mock.Anything, mock.Anything). Return(nil, errors.New("fail")) node.rootCoord = rc @@ -958,14 +1056,13 @@ func TestProxyListDatabase(t *testing.T) { }) t.Run("list database ok", func(t *testing.T) { - rc := mocks.NewRootCoord(t) + rc := mocks.NewMockRootCoordClient(t) rc.On("ListDatabases", mock.Anything, mock.Anything). Return(&milvuspb.ListDatabasesResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }}, nil) + Status: merr.Success(), + }, nil) node.rootCoord = rc - node.stateCode.Store(commonpb.StateCode_Healthy) + node.UpdateStateCode(commonpb.StateCode_Healthy) ctx := context.Background() resp, err := node.ListDatabases(ctx, &milvuspb.ListDatabasesRequest{}) @@ -1022,3 +1119,261 @@ func TestProxy_AllocTimestamp(t *testing.T) { assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetStatus().GetErrorCode()) }) } + +func TestProxy_ReplicateMessage(t *testing.T) { + paramtable.Init() + defer paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "true") + t.Run("proxy unhealthy", func(t *testing.T) { + node := &Proxy{} + node.UpdateStateCode(commonpb.StateCode_Abnormal) + + resp, err := node.ReplicateMessage(context.TODO(), nil) + assert.NoError(t, err) + assert.NotEqual(t, 0, resp.GetStatus().GetCode()) + }) + + t.Run("not backup instance", func(t *testing.T) { + node := &Proxy{} + node.UpdateStateCode(commonpb.StateCode_Healthy) + + resp, err := node.ReplicateMessage(context.TODO(), nil) + assert.NoError(t, err) + assert.NotEqual(t, 0, resp.GetStatus().GetCode()) + }) + + t.Run("empty channel name", func(t *testing.T) { + node := &Proxy{} + node.UpdateStateCode(commonpb.StateCode_Healthy) + paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "false") + + resp, err := 
node.ReplicateMessage(context.TODO(), nil) + assert.NoError(t, err) + assert.NotEqual(t, 0, resp.GetStatus().GetCode()) + }) + + t.Run("fail to get msg stream", func(t *testing.T) { + factory := newMockMsgStreamFactory() + factory.f = func(ctx context.Context) (msgstream.MsgStream, error) { + return nil, errors.New("mock error: get msg stream") + } + resourceManager := resource.NewManager(time.Second, 2*time.Second, nil) + manager := NewReplicateStreamManager(context.Background(), factory, resourceManager) + + node := &Proxy{ + replicateStreamManager: manager, + } + node.UpdateStateCode(commonpb.StateCode_Healthy) + paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "false") + + resp, err := node.ReplicateMessage(context.TODO(), &milvuspb.ReplicateMessageRequest{ChannelName: "unit_test_replicate_message"}) + assert.NoError(t, err) + assert.NotEqual(t, 0, resp.GetStatus().GetCode()) + }) + + t.Run("get latest position", func(t *testing.T) { + paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "false") + defer paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "true") + + factory := dependency.NewMockFactory(t) + stream := msgstream.NewMockMsgStream(t) + mockMsgID := mqwrapper.NewMockMessageID(t) + + factory.EXPECT().NewMsgStream(mock.Anything).Return(stream, nil).Once() + mockMsgID.EXPECT().Serialize().Return([]byte("mock")).Once() + stream.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() + stream.EXPECT().GetLatestMsgID(mock.Anything).Return(mockMsgID, nil).Once() + stream.EXPECT().Close().Return() + node := &Proxy{ + factory: factory, + } + node.UpdateStateCode(commonpb.StateCode_Healthy) + resp, err := node.ReplicateMessage(context.TODO(), &milvuspb.ReplicateMessageRequest{ + ChannelName: Params.CommonCfg.ReplicateMsgChannel.GetValue(), + }) + assert.NoError(t, err) + assert.EqualValues(t, 0, resp.GetStatus().GetCode()) + assert.Equal(t, 
base64.StdEncoding.EncodeToString([]byte("mock")), resp.GetPosition()) + + factory.EXPECT().NewMsgStream(mock.Anything).Return(nil, errors.New("mock")).Once() + resp, err = node.ReplicateMessage(context.TODO(), &milvuspb.ReplicateMessageRequest{ + ChannelName: Params.CommonCfg.ReplicateMsgChannel.GetValue(), + }) + assert.NoError(t, err) + assert.NotEqualValues(t, 0, resp.GetStatus().GetCode()) + }) + + t.Run("invalid msg pack", func(t *testing.T) { + node := &Proxy{ + replicateStreamManager: NewReplicateStreamManager(context.Background(), nil, nil), + } + node.UpdateStateCode(commonpb.StateCode_Healthy) + paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "false") + { + resp, err := node.ReplicateMessage(context.TODO(), &milvuspb.ReplicateMessageRequest{ + ChannelName: "unit_test_replicate_message", + Msgs: [][]byte{{1, 2, 3}}, + }) + assert.NoError(t, err) + assert.NotEqual(t, 0, resp.GetStatus().GetCode()) + } + + { + timeTickResult := msgpb.TimeTickMsg{} + timeTickMsg := &msgstream.TimeTickMsg{ + BaseMsg: msgstream.BaseMsg{ + BeginTimestamp: 1, + EndTimestamp: 10, + HashValues: []uint32{0}, + }, + TimeTickMsg: timeTickResult, + } + msgBytes, _ := timeTickMsg.Marshal(timeTickMsg) + resp, err := node.ReplicateMessage(context.TODO(), &milvuspb.ReplicateMessageRequest{ + ChannelName: "unit_test_replicate_message", + Msgs: [][]byte{msgBytes.([]byte)}, + }) + assert.NoError(t, err) + log.Info("resp", zap.Any("resp", resp)) + assert.NotEqual(t, 0, resp.GetStatus().GetCode()) + } + + { + timeTickResult := msgpb.TimeTickMsg{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType(-1)), + commonpbutil.WithMsgID(0), + commonpbutil.WithTimeStamp(10), + commonpbutil.WithSourceID(-1), + ), + } + timeTickMsg := &msgstream.TimeTickMsg{ + BaseMsg: msgstream.BaseMsg{ + BeginTimestamp: 1, + EndTimestamp: 10, + HashValues: []uint32{0}, + }, + TimeTickMsg: timeTickResult, + } + msgBytes, _ := timeTickMsg.Marshal(timeTickMsg) + resp, err := 
node.ReplicateMessage(context.TODO(), &milvuspb.ReplicateMessageRequest{ + ChannelName: "unit_test_replicate_message", + Msgs: [][]byte{msgBytes.([]byte)}, + }) + assert.NoError(t, err) + log.Info("resp", zap.Any("resp", resp)) + assert.NotEqual(t, 0, resp.GetStatus().GetCode()) + } + }) + + t.Run("success", func(t *testing.T) { + paramtable.Init() + factory := newMockMsgStreamFactory() + msgStreamObj := msgstream.NewMockMsgStream(t) + msgStreamObj.EXPECT().SetRepackFunc(mock.Anything).Return() + msgStreamObj.EXPECT().AsProducer(mock.Anything).Return() + msgStreamObj.EXPECT().EnableProduce(mock.Anything).Return() + msgStreamObj.EXPECT().Close().Return() + mockMsgID1 := mqwrapper.NewMockMessageID(t) + mockMsgID2 := mqwrapper.NewMockMessageID(t) + mockMsgID2.EXPECT().Serialize().Return([]byte("mock message id 2")) + broadcastMock := msgStreamObj.EXPECT().Broadcast(mock.Anything).Return(map[string][]mqwrapper.MessageID{ + "unit_test_replicate_message": {mockMsgID1, mockMsgID2}, + }, nil) + + factory.f = func(ctx context.Context) (msgstream.MsgStream, error) { + return msgStreamObj, nil + } + resourceManager := resource.NewManager(time.Second, 2*time.Second, nil) + manager := NewReplicateStreamManager(context.Background(), factory, resourceManager) + + ctx := context.Background() + dataCoord := &mockDataCoord{} + dataCoord.expireTime = Timestamp(1000) + segAllocator, err := newSegIDAssigner(ctx, dataCoord, getLastTick1) + assert.NoError(t, err) + segAllocator.Start() + + node := &Proxy{ + replicateStreamManager: manager, + segAssigner: segAllocator, + } + node.UpdateStateCode(commonpb.StateCode_Healthy) + paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "false") + + insertMsg := &msgstream.InsertMsg{ + BaseMsg: msgstream.BaseMsg{ + BeginTimestamp: 4, + EndTimestamp: 10, + HashValues: []uint32{0}, + MsgPosition: &msgstream.MsgPosition{ + ChannelName: "unit_test_replicate_message", + MsgID: []byte("mock message id 2"), + }, + }, + InsertRequest: 
msgpb.InsertRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Insert, + MsgID: 10001, + Timestamp: 10, + SourceID: -1, + }, + ShardName: "unit_test_replicate_message_v1", + DbName: "default", + CollectionName: "foo_collection", + PartitionName: "_default", + DbID: 1, + CollectionID: 11, + PartitionID: 22, + SegmentID: 33, + Timestamps: []uint64{10}, + RowIDs: []int64{66}, + NumRows: 1, + }, + } + msgBytes, _ := insertMsg.Marshal(insertMsg) + + replicateRequest := &milvuspb.ReplicateMessageRequest{ + ChannelName: "unit_test_replicate_message", + BeginTs: 1, + EndTs: 10, + Msgs: [][]byte{msgBytes.([]byte)}, + StartPositions: []*msgpb.MsgPosition{ + {ChannelName: "unit_test_replicate_message", MsgID: []byte("mock message id 1")}, + }, + EndPositions: []*msgpb.MsgPosition{ + {ChannelName: "unit_test_replicate_message", MsgID: []byte("mock message id 2")}, + }, + } + resp, err := node.ReplicateMessage(context.TODO(), replicateRequest) + assert.NoError(t, err) + assert.EqualValues(t, 0, resp.GetStatus().GetCode()) + assert.Equal(t, base64.StdEncoding.EncodeToString([]byte("mock message id 2")), resp.GetPosition()) + + res := resourceManager.Delete(ReplicateMsgStreamTyp, replicateRequest.GetChannelName()) + assert.NotNil(t, res) + time.Sleep(2 * time.Second) + + { + broadcastMock.Unset() + broadcastMock = msgStreamObj.EXPECT().Broadcast(mock.Anything).Return(nil, errors.New("mock error: broadcast")) + resp, err := node.ReplicateMessage(context.TODO(), replicateRequest) + assert.NoError(t, err) + assert.NotEqualValues(t, 0, resp.GetStatus().GetCode()) + resourceManager.Delete(ReplicateMsgStreamTyp, replicateRequest.GetChannelName()) + time.Sleep(2 * time.Second) + } + { + broadcastMock.Unset() + broadcastMock = msgStreamObj.EXPECT().Broadcast(mock.Anything).Return(map[string][]mqwrapper.MessageID{ + "unit_test_replicate_message": {}, + }, nil) + resp, err := node.ReplicateMessage(context.TODO(), replicateRequest) + assert.NoError(t, err) + 
assert.EqualValues(t, 0, resp.GetStatus().GetCode()) + assert.Empty(t, resp.GetPosition()) + resourceManager.Delete(ReplicateMsgStreamTyp, replicateRequest.GetChannelName()) + time.Sleep(2 * time.Second) + broadcastMock.Unset() + } + }) +} diff --git a/internal/proxy/interface_def.go b/internal/proxy/interface_def.go index 3a908e3a62b06..e9551129bb60c 100644 --- a/internal/proxy/interface_def.go +++ b/internal/proxy/interface_def.go @@ -19,6 +19,8 @@ package proxy import ( "context" + "google.golang.org/grpc" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" ) @@ -32,5 +34,5 @@ type tsoAllocator interface { // //go:generate mockery --name=timestampAllocatorInterface --filename=mock_tso_test.go --outpkg=proxy --output=. --inpackage --structname=mockTimestampAllocator --with-expecter type timestampAllocatorInterface interface { - AllocTimestamp(ctx context.Context, req *rootcoordpb.AllocTimestampRequest) (*rootcoordpb.AllocTimestampResponse, error) + AllocTimestamp(ctx context.Context, req *rootcoordpb.AllocTimestampRequest, opts ...grpc.CallOption) (*rootcoordpb.AllocTimestampResponse, error) } diff --git a/internal/proxy/keep_active_interceptor.go b/internal/proxy/keep_active_interceptor.go index 0359279517c36..8e536f8cbbdbc 100644 --- a/internal/proxy/keep_active_interceptor.go +++ b/internal/proxy/keep_active_interceptor.go @@ -5,13 +5,11 @@ import ( "fmt" "strconv" - "github.com/milvus-io/milvus/pkg/util" - - "github.com/milvus-io/milvus/pkg/util/funcutil" - + "google.golang.org/grpc" "google.golang.org/grpc/metadata" - "google.golang.org/grpc" + "github.com/milvus-io/milvus/pkg/util" + "github.com/milvus-io/milvus/pkg/util/funcutil" ) func getIdentifierFromContext(ctx context.Context) (int64, error) { diff --git a/internal/proxy/keep_active_interceptor_test.go b/internal/proxy/keep_active_interceptor_test.go index 1fc5139f3f6cf..3de5e19f0bf36 100644 --- a/internal/proxy/keep_active_interceptor_test.go +++ 
b/internal/proxy/keep_active_interceptor_test.go @@ -4,9 +4,8 @@ import ( "context" "testing" - "google.golang.org/grpc" - "github.com/stretchr/testify/assert" + "google.golang.org/grpc" "google.golang.org/grpc/metadata" ) diff --git a/internal/proxy/lb_policy.go b/internal/proxy/lb_policy.go index 060664d6392ed..f81cab8c4a3f2 100644 --- a/internal/proxy/lb_policy.go +++ b/internal/proxy/lb_policy.go @@ -32,7 +32,7 @@ import ( "github.com/milvus-io/milvus/pkg/util/typeutil" ) -type executeFunc func(context.Context, UniqueID, types.QueryNode, ...string) error +type executeFunc func(context.Context, UniqueID, types.QueryNodeClient, ...string) error type ChannelWorkload struct { db string @@ -91,7 +91,7 @@ func (lb *LBPolicyImpl) Start(ctx context.Context) { // try to select the best node from the available nodes func (lb *LBPolicyImpl) selectNode(ctx context.Context, workload ChannelWorkload, excludeNodes typeutil.UniqueSet) (int64, error) { - log := log.With( + log := log.Ctx(ctx).With( zap.Int64("collectionID", workload.collectionID), zap.String("collectionName", workload.collectionName), zap.String("channelName", workload.channel), diff --git a/internal/proxy/lb_policy_test.go b/internal/proxy/lb_policy_test.go index 956969cc4974b..19da3d6b0dbbd 100644 --- a/internal/proxy/lb_policy_test.go +++ b/internal/proxy/lb_policy_test.go @@ -41,9 +41,9 @@ import ( type LBPolicySuite struct { suite.Suite - rc types.RootCoord - qc *mocks.MockQueryCoord - qn *mocks.MockQueryNode + rc types.RootCoordClient + qc *mocks.MockQueryCoordClient + qn *mocks.MockQueryNodeClient mgr *MockShardClientManager lbBalancer *MockLBBalancer @@ -65,7 +65,7 @@ func (s *LBPolicySuite) SetupTest() { s.nodes = []int64{1, 2, 3, 4, 5} s.channels = []string{"channel1", "channel2"} successStatus := commonpb.Status{ErrorCode: commonpb.ErrorCode_Success} - qc := mocks.NewMockQueryCoord(s.T()) + qc := mocks.NewMockQueryCoordClient(s.T()) qc.EXPECT().LoadCollection(mock.Anything, 
mock.Anything).Return(&successStatus, nil) qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{ @@ -84,19 +84,15 @@ func (s *LBPolicySuite) SetupTest() { }, }, nil).Maybe() qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&querypb.ShowPartitionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), PartitionIDs: []int64{1, 2, 3}, }, nil).Maybe() s.qc = qc s.rc = NewRootCoordMock() - s.rc.Start() - s.qn = mocks.NewMockQueryNode(s.T()) - s.qn.EXPECT().GetAddress().Return("localhost").Maybe() - s.qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe() + s.qn = mocks.NewMockQueryNodeClient(s.T()) + s.qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() s.mgr = NewMockShardClientManager(s.T()) s.mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe() @@ -252,7 +248,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { channel: s.channels[0], shardLeaders: s.nodes, nq: 1, - exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error { + exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, s ...string) error { return nil }, retryTimes: 1, @@ -269,7 +265,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { channel: s.channels[0], shardLeaders: s.nodes, nq: 1, - exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error { + exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, s ...string) error { return nil }, retryTimes: 1, @@ -289,7 +285,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { channel: s.channels[0], shardLeaders: s.nodes, nq: 1, - exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error { + exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, s ...string) error { return nil }, retryTimes: 1, @@ -307,7 +303,7 @@ func (s *LBPolicySuite) 
TestExecuteWithRetry() { channel: s.channels[0], shardLeaders: s.nodes, nq: 1, - exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error { + exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, s ...string) error { return nil }, retryTimes: 2, @@ -328,7 +324,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { channel: s.channels[0], shardLeaders: s.nodes, nq: 1, - exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error { + exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, s ...string) error { counter++ if counter == 1 { return errors.New("fake error") @@ -343,8 +339,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { s.mgr.ExpectedCalls = nil s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil) s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything) - s.qn.EXPECT().GetAddress().Return("localhost").Maybe() - s.qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe() + s.qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() s.qn.EXPECT().Search(mock.Anything, mock.Anything).Return(nil, context.Canceled).Times(1) s.qn.EXPECT().Search(mock.Anything, mock.Anything).Return(nil, context.DeadlineExceeded) err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{ @@ -354,7 +349,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { channel: s.channels[0], shardLeaders: s.nodes, nq: 1, - exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error { + exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, s ...string) error { _, err := qn.Search(ctx, nil) return err }, @@ -375,7 +370,7 @@ func (s *LBPolicySuite) TestExecute() { collectionName: s.collectionName, collectionID: s.collectionID, nq: 1, - exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error { + exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, s 
...string) error { return nil }, }) @@ -388,7 +383,7 @@ func (s *LBPolicySuite) TestExecute() { collectionName: s.collectionName, collectionID: s.collectionID, nq: 1, - exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error { + exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, s ...string) error { // succeed in first execute if counter.Add(1) == 1 { return nil @@ -409,7 +404,7 @@ func (s *LBPolicySuite) TestExecute() { collectionName: s.collectionName, collectionID: s.collectionID, nq: 1, - exec: func(ctx context.Context, ui UniqueID, qn types.QueryNode, s ...string) error { + exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, s ...string) error { return nil }, }) diff --git a/internal/proxy/look_aside_balancer.go b/internal/proxy/look_aside_balancer.go index 99a47dd1afc85..51d4a4b9a8c86 100644 --- a/internal/proxy/look_aside_balancer.go +++ b/internal/proxy/look_aside_balancer.go @@ -24,7 +24,11 @@ import ( "sync" "time" + "go.uber.org/atomic" + "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" @@ -32,8 +36,6 @@ import ( "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" - "go.uber.org/atomic" - "go.uber.org/zap" ) type LookAsideBalancer struct { @@ -225,7 +227,7 @@ func (b *LookAsideBalancer) checkQueryNodeHealthLoop(ctx context.Context) { return struct{}{}, nil } - resp, err := qn.GetComponentStates(ctx) + resp, err := qn.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) if err != nil { if b.trySetQueryNodeUnReachable(node, err) { log.Warn("get component status failed, set node unreachable", zap.Int64("node", node), zap.Error(err)) diff --git 
a/internal/proxy/look_aside_balancer_test.go b/internal/proxy/look_aside_balancer_test.go index 84360de537531..cfb7b6ec195a0 100644 --- a/internal/proxy/look_aside_balancer_test.go +++ b/internal/proxy/look_aside_balancer_test.go @@ -23,14 +23,15 @@ import ( "time" "github.com/cockroachdb/errors" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "go.uber.org/atomic" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/pkg/util/merr" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/suite" - "go.uber.org/atomic" ) type LookAsideBalancerSuite struct { @@ -45,9 +46,9 @@ func (suite *LookAsideBalancerSuite) SetupTest() { suite.balancer = NewLookAsideBalancer(suite.clientMgr) suite.balancer.Start(context.Background()) - qn := mocks.NewMockQueryNode(suite.T()) + qn := mocks.NewMockQueryNodeClient(suite.T()) suite.clientMgr.EXPECT().GetClient(mock.Anything, int64(1)).Return(qn, nil).Maybe() - qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, errors.New("fake error")).Maybe() + qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Maybe() } func (suite *LookAsideBalancerSuite) TearDownTest() { @@ -305,9 +306,9 @@ func (suite *LookAsideBalancerSuite) TestCancelWorkload() { } func (suite *LookAsideBalancerSuite) TestCheckHealthLoop() { - qn2 := mocks.NewMockQueryNode(suite.T()) + qn2 := mocks.NewMockQueryNodeClient(suite.T()) suite.clientMgr.EXPECT().GetClient(mock.Anything, int64(2)).Return(qn2, nil) - qn2.EXPECT().GetComponentStates(mock.Anything).Return(&milvuspb.ComponentStates{ + qn2.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{ StateCode: commonpb.StateCode_Healthy, }, @@ -335,15 +336,15 @@ func (suite 
*LookAsideBalancerSuite) TestCheckHealthLoop() { func (suite *LookAsideBalancerSuite) TestNodeRecover() { // mock qn down for a while and then recover - qn3 := mocks.NewMockQueryNode(suite.T()) + qn3 := mocks.NewMockQueryNodeClient(suite.T()) suite.clientMgr.EXPECT().GetClient(mock.Anything, int64(3)).Return(qn3, nil) - qn3.EXPECT().GetComponentStates(mock.Anything).Return(&milvuspb.ComponentStates{ + qn3.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{ StateCode: commonpb.StateCode_Abnormal, }, }, nil).Times(3) - qn3.EXPECT().GetComponentStates(mock.Anything).Return(&milvuspb.ComponentStates{ + qn3.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{ StateCode: commonpb.StateCode_Healthy, }, @@ -363,9 +364,9 @@ func (suite *LookAsideBalancerSuite) TestNodeOffline() { Params.Save(Params.CommonCfg.SessionTTL.Key, "10") Params.Save(Params.ProxyCfg.HealthCheckTimeout.Key, "1000") // mock qn down for a while and then recover - qn3 := mocks.NewMockQueryNode(suite.T()) + qn3 := mocks.NewMockQueryNodeClient(suite.T()) suite.clientMgr.EXPECT().GetClient(mock.Anything, int64(3)).Return(qn3, nil) - qn3.EXPECT().GetComponentStates(mock.Anything).Return(&milvuspb.ComponentStates{ + qn3.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{ StateCode: commonpb.StateCode_Abnormal, }, diff --git a/internal/proxy/meta_cache.go b/internal/proxy/meta_cache.go index 6e387b7e87cff..a2fe911e256f9 100644 --- a/internal/proxy/meta_cache.go +++ b/internal/proxy/meta_cache.go @@ -192,8 +192,8 @@ var _ Cache = (*MetaCache)(nil) // MetaCache implements Cache, provides collection meta cache based on internal RootCoord type MetaCache struct { - rootCoord types.RootCoord - queryCoord types.QueryCoord + rootCoord types.RootCoordClient + queryCoord types.QueryCoordClient collInfo 
map[string]map[string]*collectionInfo // database -> collection -> collection_info credMap map[string]*internalpb.CredentialInfo // cache for credential, lazy load @@ -209,7 +209,7 @@ type MetaCache struct { var globalMetaCache Cache // InitMetaCache initializes globalMetaCache -func InitMetaCache(ctx context.Context, rootCoord types.RootCoord, queryCoord types.QueryCoord, shardMgr shardClientMgr) error { +func InitMetaCache(ctx context.Context, rootCoord types.RootCoordClient, queryCoord types.QueryCoordClient, shardMgr shardClientMgr) error { var err error globalMetaCache, err = NewMetaCache(rootCoord, queryCoord, shardMgr) if err != nil { @@ -229,7 +229,7 @@ func InitMetaCache(ctx context.Context, rootCoord types.RootCoord, queryCoord ty } // NewMetaCache creates a MetaCache with provided RootCoord and QueryNode -func NewMetaCache(rootCoord types.RootCoord, queryCoord types.QueryCoord, shardMgr shardClientMgr) (*MetaCache, error) { +func NewMetaCache(rootCoord types.RootCoordClient, queryCoord types.QueryCoordClient, shardMgr shardClientMgr) (*MetaCache, error) { return &MetaCache{ rootCoord: rootCoord, queryCoord: queryCoord, @@ -497,7 +497,6 @@ func (m *MetaCache) GetPartitions(ctx context.Context, database, collectionName ret[k] = v.partitionID } return ret, nil - } defer m.mu.RUnlock() @@ -621,8 +620,8 @@ func (m *MetaCache) showPartitions(ctx context.Context, dbName string, collectio if err != nil { return nil, err } - if partitions.Status.ErrorCode != commonpb.ErrorCode_Success { - return nil, fmt.Errorf("%s", partitions.Status.Reason) + if partitions.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + return nil, fmt.Errorf("%s", partitions.GetStatus().GetReason()) } if len(partitions.PartitionIDs) != len(partitions.PartitionNames) { @@ -808,27 +807,28 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col if err != nil { return retry.Unrecoverable(err) } - if resp.Status.ErrorCode == commonpb.ErrorCode_Success { + if 
resp.GetStatus().GetErrorCode() == commonpb.ErrorCode_Success { return nil } // do not retry unless got NoReplicaAvailable from querycoord - if resp.Status.ErrorCode != commonpb.ErrorCode_NoReplicaAvailable { - return retry.Unrecoverable(fmt.Errorf("fail to get shard leaders from QueryCoord: %s", resp.Status.Reason)) + err2 := merr.Error(resp.GetStatus()) + if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_NoReplicaAvailable { + return retry.Unrecoverable(err2) } - return fmt.Errorf("fail to get shard leaders from QueryCoord: %s", resp.Status.Reason) + return err2 }) if err != nil { return nil, err } - if resp.Status.ErrorCode != commonpb.ErrorCode_Success { - return nil, fmt.Errorf("fail to get shard leaders from QueryCoord: %s", resp.Status.Reason) + if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + return nil, merr.Error(resp.GetStatus()) } shards := parseShardLeaderList2QueryNode(resp.GetShards()) info, err = m.getFullCollectionInfo(ctx, database, collectionName, collectionID) if err != nil { - return nil, fmt.Errorf("failed to get shards, collectionName %s, colectionID %d not found", collectionName, collectionID) + return nil, err } // lock leader info.leaderMutex.Lock() diff --git a/internal/proxy/meta_cache_test.go b/internal/proxy/meta_cache_test.go index d37c5b38be713..b922f7e9742bd 100644 --- a/internal/proxy/meta_cache_test.go +++ b/internal/proxy/meta_cache_test.go @@ -29,6 +29,7 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" uatomic "go.uber.org/atomic" + "google.golang.org/grpc" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" @@ -48,7 +49,7 @@ import ( var dbName = GetCurDBNameFromContextOrDefault(context.Background()) type MockRootCoordClientInterface struct { - types.RootCoord + types.RootCoordClient Error bool AccessCount int32 @@ -64,15 +65,13 @@ func (m *MockRootCoordClientInterface) GetAccessCount() int { return int(ret) 
} -func (m *MockRootCoordClientInterface) ShowPartitions(ctx context.Context, in *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) { +func (m *MockRootCoordClientInterface) ShowPartitions(ctx context.Context, in *milvuspb.ShowPartitionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowPartitionsResponse, error) { if m.Error { return nil, errors.New("mocked error") } if in.CollectionName == "collection1" { return &milvuspb.ShowPartitionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), PartitionIDs: []typeutil.UniqueID{1, 2}, CreatedTimestamps: []uint64{100, 200}, CreatedUtcTimestamps: []uint64{100, 200}, @@ -81,9 +80,7 @@ func (m *MockRootCoordClientInterface) ShowPartitions(ctx context.Context, in *m } if in.CollectionName == "collection2" { return &milvuspb.ShowPartitionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), PartitionIDs: []typeutil.UniqueID{3, 4}, CreatedTimestamps: []uint64{201, 202}, CreatedUtcTimestamps: []uint64{201, 202}, @@ -92,9 +89,7 @@ func (m *MockRootCoordClientInterface) ShowPartitions(ctx context.Context, in *m } if in.CollectionName == "errorCollection" { return &milvuspb.ShowPartitionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), PartitionIDs: []typeutil.UniqueID{5, 6}, CreatedTimestamps: []uint64{201}, CreatedUtcTimestamps: []uint64{201}, @@ -112,16 +107,14 @@ func (m *MockRootCoordClientInterface) ShowPartitions(ctx context.Context, in *m }, nil } -func (m *MockRootCoordClientInterface) DescribeCollection(ctx context.Context, in *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { +func (m *MockRootCoordClientInterface) DescribeCollection(ctx context.Context, in *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) { if m.Error { return nil, 
errors.New("mocked error") } m.IncAccessCount() if in.CollectionName == "collection1" || in.CollectionID == 1 { return &milvuspb.DescribeCollectionResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), CollectionID: typeutil.UniqueID(1), Schema: &schemapb.CollectionSchema{ AutoID: true, @@ -132,9 +125,7 @@ func (m *MockRootCoordClientInterface) DescribeCollection(ctx context.Context, i } if in.CollectionName == "collection2" || in.CollectionID == 2 { return &milvuspb.DescribeCollectionResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), CollectionID: typeutil.UniqueID(2), Schema: &schemapb.CollectionSchema{ AutoID: true, @@ -145,9 +136,7 @@ func (m *MockRootCoordClientInterface) DescribeCollection(ctx context.Context, i } if in.CollectionName == "errorCollection" { return &milvuspb.DescribeCollectionResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), CollectionID: typeutil.UniqueID(3), Schema: &schemapb.CollectionSchema{ AutoID: true, @@ -163,7 +152,7 @@ func (m *MockRootCoordClientInterface) DescribeCollection(ctx context.Context, i }, nil } -func (m *MockRootCoordClientInterface) GetCredential(ctx context.Context, req *rootcoordpb.GetCredentialRequest) (*rootcoordpb.GetCredentialResponse, error) { +func (m *MockRootCoordClientInterface) GetCredential(ctx context.Context, req *rootcoordpb.GetCredentialRequest, opts ...grpc.CallOption) (*rootcoordpb.GetCredentialResponse, error) { if m.Error { return nil, errors.New("mocked error") } @@ -171,9 +160,7 @@ func (m *MockRootCoordClientInterface) GetCredential(ctx context.Context, req *r if req.Username == "mockUser" { encryptedPassword, _ := crypto.PasswordEncrypt("mockPass") return &rootcoordpb.GetCredentialResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), Username: "mockUser", Password: 
encryptedPassword, }, nil @@ -183,27 +170,23 @@ func (m *MockRootCoordClientInterface) GetCredential(ctx context.Context, req *r return nil, err } -func (m *MockRootCoordClientInterface) ListCredUsers(ctx context.Context, req *milvuspb.ListCredUsersRequest) (*milvuspb.ListCredUsersResponse, error) { +func (m *MockRootCoordClientInterface) ListCredUsers(ctx context.Context, req *milvuspb.ListCredUsersRequest, opts ...grpc.CallOption) (*milvuspb.ListCredUsersResponse, error) { if m.Error { return nil, errors.New("mocked error") } return &milvuspb.ListCredUsersResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), Usernames: []string{"mockUser"}, }, nil } -func (m *MockRootCoordClientInterface) ListPolicy(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) { +func (m *MockRootCoordClientInterface) ListPolicy(ctx context.Context, in *internalpb.ListPolicyRequest, opts ...grpc.CallOption) (*internalpb.ListPolicyResponse, error) { if m.listPolicy != nil { return m.listPolicy(ctx, in) } return &internalpb.ListPolicyResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, nil } @@ -211,7 +194,7 @@ func (m *MockRootCoordClientInterface) ListPolicy(ctx context.Context, in *inter func TestMetaCache_GetCollection(t *testing.T) { ctx := context.Background() rootCoord := &MockRootCoordClientInterface{} - queryCoord := &mocks.MockQueryCoord{} + queryCoord := &mocks.MockQueryCoordClient{} mgr := newShardClientMgr() err := InitMetaCache(ctx, rootCoord, queryCoord, mgr) assert.NoError(t, err) @@ -261,7 +244,7 @@ func TestMetaCache_GetCollection(t *testing.T) { func TestMetaCache_GetBasicCollectionInfo(t *testing.T) { ctx := context.Background() rootCoord := &MockRootCoordClientInterface{} - queryCoord := &mocks.MockQueryCoord{} + queryCoord := &mocks.MockQueryCoordClient{} mgr := newShardClientMgr() err := InitMetaCache(ctx, 
rootCoord, queryCoord, mgr) assert.NoError(t, err) @@ -295,7 +278,7 @@ func TestMetaCache_GetBasicCollectionInfo(t *testing.T) { func TestMetaCache_GetCollectionName(t *testing.T) { ctx := context.Background() rootCoord := &MockRootCoordClientInterface{} - queryCoord := &mocks.MockQueryCoord{} + queryCoord := &mocks.MockQueryCoordClient{} mgr := newShardClientMgr() err := InitMetaCache(ctx, rootCoord, queryCoord, mgr) assert.NoError(t, err) @@ -345,7 +328,7 @@ func TestMetaCache_GetCollectionName(t *testing.T) { func TestMetaCache_GetCollectionFailure(t *testing.T) { ctx := context.Background() rootCoord := &MockRootCoordClientInterface{} - queryCoord := &mocks.MockQueryCoord{} + queryCoord := &mocks.MockQueryCoordClient{} mgr := newShardClientMgr() err := InitMetaCache(ctx, rootCoord, queryCoord, mgr) assert.NoError(t, err) @@ -378,7 +361,7 @@ func TestMetaCache_GetCollectionFailure(t *testing.T) { func TestMetaCache_GetNonExistCollection(t *testing.T) { ctx := context.Background() rootCoord := &MockRootCoordClientInterface{} - queryCoord := &mocks.MockQueryCoord{} + queryCoord := &mocks.MockQueryCoordClient{} mgr := newShardClientMgr() err := InitMetaCache(ctx, rootCoord, queryCoord, mgr) assert.NoError(t, err) @@ -394,7 +377,7 @@ func TestMetaCache_GetNonExistCollection(t *testing.T) { func TestMetaCache_GetPartitionID(t *testing.T) { ctx := context.Background() rootCoord := &MockRootCoordClientInterface{} - queryCoord := &mocks.MockQueryCoord{} + queryCoord := &mocks.MockQueryCoordClient{} mgr := newShardClientMgr() err := InitMetaCache(ctx, rootCoord, queryCoord, mgr) assert.NoError(t, err) @@ -416,7 +399,7 @@ func TestMetaCache_GetPartitionID(t *testing.T) { func TestMetaCache_ConcurrentTest1(t *testing.T) { ctx := context.Background() rootCoord := &MockRootCoordClientInterface{} - queryCoord := &mocks.MockQueryCoord{} + queryCoord := &mocks.MockQueryCoordClient{} mgr := newShardClientMgr() err := InitMetaCache(ctx, rootCoord, queryCoord, mgr) 
assert.NoError(t, err) @@ -470,7 +453,7 @@ func TestMetaCache_ConcurrentTest1(t *testing.T) { func TestMetaCache_GetPartitionError(t *testing.T) { ctx := context.Background() rootCoord := &MockRootCoordClientInterface{} - queryCoord := &mocks.MockQueryCoord{} + queryCoord := &mocks.MockQueryCoordClient{} mgr := newShardClientMgr() err := InitMetaCache(ctx, rootCoord, queryCoord, mgr) assert.NoError(t, err) @@ -503,16 +486,11 @@ func TestMetaCache_GetShards(t *testing.T) { ) rootCoord := &MockRootCoordClientInterface{} - qc := getQueryCoord() - qc.EXPECT().Init().Return(nil) + qc := getQueryCoordClient() shardMgr := newShardClientMgr() err := InitMetaCache(ctx, rootCoord, qc, shardMgr) require.Nil(t, err) - qc.Init() - qc.Start() - defer qc.Stop() - t.Run("No collection in meta cache", func(t *testing.T) { shards, err := globalMetaCache.GetShards(ctx, true, dbName, "non-exists", 0) assert.Error(t, err) @@ -536,9 +514,7 @@ func TestMetaCache_GetShards(t *testing.T) { t.Run("without shardLeaders in collection info", func(t *testing.T) { qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), Shards: []*querypb.ShardLeadersList{ { ChannelName: "channel-1", @@ -580,16 +556,11 @@ func TestMetaCache_ClearShards(t *testing.T) { ) rootCoord := &MockRootCoordClientInterface{} - qc := getQueryCoord() - qc.EXPECT().Init().Return(nil) + qc := getQueryCoordClient() mgr := newShardClientMgr() err := InitMetaCache(ctx, rootCoord, qc, mgr) require.Nil(t, err) - qc.Init() - qc.Start() - defer qc.Stop() - t.Run("Clear with no collection info", func(t *testing.T) { globalMetaCache.DeprecateShardCache(dbName, "collection_not_exist") }) @@ -600,9 +571,7 @@ func TestMetaCache_ClearShards(t *testing.T) { t.Run("Clear valid collection valid cache", func(t *testing.T) { qc.EXPECT().GetShardLeaders(mock.Anything, 
mock.Anything).Return(&querypb.GetShardLeadersResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), Shards: []*querypb.ShardLeadersList{ { ChannelName: "channel-1", @@ -636,7 +605,7 @@ func TestMetaCache_ClearShards(t *testing.T) { func TestMetaCache_PolicyInfo(t *testing.T) { client := &MockRootCoordClientInterface{} - qc := &mocks.MockQueryCoord{} + qc := &mocks.MockQueryCoordClient{} mgr := newShardClientMgr() t.Run("InitMetaCache", func(t *testing.T) { @@ -648,9 +617,7 @@ func TestMetaCache_PolicyInfo(t *testing.T) { client.listPolicy = func(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) { return &internalpb.ListPolicyResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), PolicyInfos: []string{"policy1", "policy2", "policy3"}, }, nil } @@ -661,9 +628,7 @@ func TestMetaCache_PolicyInfo(t *testing.T) { t.Run("GetPrivilegeInfo", func(t *testing.T) { client.listPolicy = func(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) { return &internalpb.ListPolicyResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), PolicyInfos: []string{"policy1", "policy2", "policy3"}, UserRoles: []string{funcutil.EncodeUserRoleCache("foo", "role1"), funcutil.EncodeUserRoleCache("foo", "role2"), funcutil.EncodeUserRoleCache("foo2", "role2")}, }, nil @@ -679,9 +644,7 @@ func TestMetaCache_PolicyInfo(t *testing.T) { t.Run("GetPrivilegeInfo", func(t *testing.T) { client.listPolicy = func(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) { return &internalpb.ListPolicyResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), PolicyInfos: []string{"policy1", "policy2", "policy3"}, UserRoles: []string{funcutil.EncodeUserRoleCache("foo", 
"role1"), funcutil.EncodeUserRoleCache("foo", "role2"), funcutil.EncodeUserRoleCache("foo2", "role2")}, }, nil @@ -718,9 +681,7 @@ func TestMetaCache_PolicyInfo(t *testing.T) { t.Run("Delete user or drop role", func(t *testing.T) { client.listPolicy = func(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) { return &internalpb.ListPolicyResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), PolicyInfos: []string{"policy1", "policy2", "policy3"}, UserRoles: []string{funcutil.EncodeUserRoleCache("foo", "role1"), funcutil.EncodeUserRoleCache("foo", "role2"), funcutil.EncodeUserRoleCache("foo2", "role2"), funcutil.EncodeUserRoleCache("foo2", "role3")}, }, nil @@ -745,9 +706,7 @@ func TestMetaCache_PolicyInfo(t *testing.T) { client.listPolicy = func(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) { return &internalpb.ListPolicyResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), PolicyInfos: []string{"policy1", "policy2", "policy3"}, UserRoles: []string{funcutil.EncodeUserRoleCache("foo", "role1"), funcutil.EncodeUserRoleCache("foo", "role2"), funcutil.EncodeUserRoleCache("foo2", "role2"), funcutil.EncodeUserRoleCache("foo2", "role3")}, }, nil @@ -762,15 +721,13 @@ func TestMetaCache_PolicyInfo(t *testing.T) { func TestMetaCache_RemoveCollection(t *testing.T) { ctx := context.Background() rootCoord := &MockRootCoordClientInterface{} - queryCoord := &mocks.MockQueryCoord{} + queryCoord := &mocks.MockQueryCoordClient{} shardMgr := newShardClientMgr() err := InitMetaCache(ctx, rootCoord, queryCoord, shardMgr) assert.NoError(t, err) queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), CollectionIDs: []UniqueID{1, 2}, 
InMemoryPercentages: []int64{100, 50}, }, nil) @@ -813,22 +770,18 @@ func TestMetaCache_ExpireShardLeaderCache(t *testing.T) { ctx := context.Background() rootCoord := &MockRootCoordClientInterface{} - queryCoord := &mocks.MockQueryCoord{} + queryCoord := &mocks.MockQueryCoordClient{} shardMgr := newShardClientMgr() err := InitMetaCache(ctx, rootCoord, queryCoord, shardMgr) assert.NoError(t, err) queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), CollectionIDs: []UniqueID{1}, InMemoryPercentages: []int64{100}, }, nil) queryCoord.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), Shards: []*querypb.ShardLeadersList{ { ChannelName: "channel-1", @@ -843,9 +796,7 @@ func TestMetaCache_ExpireShardLeaderCache(t *testing.T) { queryCoord.ExpectedCalls = nil queryCoord.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), Shards: []*querypb.ShardLeadersList{ { ChannelName: "channel-1", @@ -863,9 +814,7 @@ func TestMetaCache_ExpireShardLeaderCache(t *testing.T) { queryCoord.ExpectedCalls = nil queryCoord.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), Shards: []*querypb.ShardLeadersList{ { ChannelName: "channel-1", @@ -883,9 +832,7 @@ func TestMetaCache_ExpireShardLeaderCache(t *testing.T) { queryCoord.ExpectedCalls = nil queryCoord.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, 
- }, + Status: merr.Success(), Shards: []*querypb.ShardLeadersList{ { ChannelName: "channel-1", diff --git a/internal/proxy/metrics_info.go b/internal/proxy/metrics_info.go index f4e7021559907..c02fae5aa22aa 100644 --- a/internal/proxy/metrics_info.go +++ b/internal/proxy/metrics_info.go @@ -20,7 +20,6 @@ import ( "context" "sync" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/pkg/util/hardware" @@ -31,8 +30,10 @@ import ( "github.com/milvus-io/milvus/pkg/util/typeutil" ) -type getMetricsFuncType func(ctx context.Context, request *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) -type showConfigurationsFuncType func(ctx context.Context, request *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) +type ( + getMetricsFuncType func(ctx context.Context, request *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) + showConfigurationsFuncType func(ctx context.Context, request *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) +) // getQuotaMetrics returns ProxyQuotaMetrics. 
func getQuotaMetrics() (*metricsinfo.ProxyQuotaMetrics, error) { @@ -108,7 +109,7 @@ func getProxyMetrics(ctx context.Context, request *milvuspb.GetMetricsRequest, n } return &milvuspb.GetMetricsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Response: resp, ComponentName: metricsinfo.ConstructComponentName(typeutil.ProxyRole, paramtable.GetNodeID()), }, nil @@ -418,17 +419,14 @@ func getSystemInfoMetrics( resp, err := metricsinfo.MarshalTopology(systemTopology) if err != nil { return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), Response: "", ComponentName: metricsinfo.ConstructComponentName(typeutil.ProxyRole, paramtable.GetNodeID()), }, nil } return &milvuspb.GetMetricsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Response: resp, ComponentName: metricsinfo.ConstructComponentName(typeutil.ProxyRole, paramtable.GetNodeID()), }, nil diff --git a/internal/proxy/metrics_info_test.go b/internal/proxy/metrics_info_test.go index a7f4e468dfffc..92a9db1203055 100644 --- a/internal/proxy/metrics_info_test.go +++ b/internal/proxy/metrics_info_test.go @@ -20,18 +20,16 @@ import ( "context" "testing" - "github.com/milvus-io/milvus/internal/util/sessionutil" - "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/uniquegenerator" - - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" - "github.com/milvus-io/milvus/pkg/util/typeutil" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/util/sessionutil" + "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" + "github.com/milvus-io/milvus/pkg/util/typeutil" + 
"github.com/milvus-io/milvus/pkg/util/uniquegenerator" ) func TestProxy_metrics(t *testing.T) { @@ -40,22 +38,17 @@ func TestProxy_metrics(t *testing.T) { ctx := context.Background() rc := NewRootCoordMock() - rc.Start() - defer rc.Stop() - - qc := getQueryCoord() - qc.Start() - defer qc.Stop() + defer rc.Close() + qc := getQueryCoordClient() dc := NewDataCoordMock() - dc.Start() - defer dc.Stop() + defer dc.Close() proxy := &Proxy{ rootCoord: rc, queryCoord: qc, dataCoord: dc, - session: &sessionutil.Session{Address: funcutil.GenRandomStr()}, + session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{Address: funcutil.GenRandomStr()}}, } rc.getMetricsFunc = func(ctx context.Context, request *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { @@ -90,10 +83,7 @@ func TestProxy_metrics(t *testing.T) { resp, _ := metricsinfo.MarshalTopology(rootCoordTopology) return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), Response: resp, ComponentName: metricsinfo.ConstructComponentName(typeutil.RootCoordRole, id), }, nil @@ -142,10 +132,7 @@ func TestProxy_metrics(t *testing.T) { resp, _ := metricsinfo.MarshalTopology(coordTopology) return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), Response: resp, ComponentName: metricsinfo.ConstructComponentName(typeutil.QueryCoordRole, id), }, nil @@ -202,14 +189,10 @@ func TestProxy_metrics(t *testing.T) { resp, _ := metricsinfo.MarshalTopology(coordTopology) return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), Response: resp, ComponentName: metricsinfo.ConstructComponentName(typeutil.DataCoordRole, id), }, nil - } req, _ := metricsinfo.ConstructRequestByMetricType(metricsinfo.SystemInfoMetrics) diff --git 
a/internal/proxy/mock_channels_manager.go b/internal/proxy/mock_channels_manager.go new file mode 100644 index 0000000000000..79b8c6015de5c --- /dev/null +++ b/internal/proxy/mock_channels_manager.go @@ -0,0 +1,262 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package proxy + +import ( + msgstream "github.com/milvus-io/milvus/pkg/mq/msgstream" + mock "github.com/stretchr/testify/mock" +) + +// MockChannelsMgr is an autogenerated mock type for the channelsMgr type +type MockChannelsMgr struct { + mock.Mock +} + +type MockChannelsMgr_Expecter struct { + mock *mock.Mock +} + +func (_m *MockChannelsMgr) EXPECT() *MockChannelsMgr_Expecter { + return &MockChannelsMgr_Expecter{mock: &_m.Mock} +} + +// getChannels provides a mock function with given fields: collectionID +func (_m *MockChannelsMgr) getChannels(collectionID int64) ([]string, error) { + ret := _m.Called(collectionID) + + var r0 []string + var r1 error + if rf, ok := ret.Get(0).(func(int64) ([]string, error)); ok { + return rf(collectionID) + } + if rf, ok := ret.Get(0).(func(int64) []string); ok { + r0 = rf(collectionID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + if rf, ok := ret.Get(1).(func(int64) error); ok { + r1 = rf(collectionID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockChannelsMgr_getChannels_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'getChannels' +type MockChannelsMgr_getChannels_Call struct { + *mock.Call +} + +// getChannels is a helper method to define mock.On call +// - collectionID int64 +func (_e *MockChannelsMgr_Expecter) getChannels(collectionID interface{}) *MockChannelsMgr_getChannels_Call { + return &MockChannelsMgr_getChannels_Call{Call: _e.mock.On("getChannels", collectionID)} +} + +func (_c *MockChannelsMgr_getChannels_Call) Run(run func(collectionID int64)) *MockChannelsMgr_getChannels_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64)) + }) + 
return _c +} + +func (_c *MockChannelsMgr_getChannels_Call) Return(_a0 []string, _a1 error) *MockChannelsMgr_getChannels_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockChannelsMgr_getChannels_Call) RunAndReturn(run func(int64) ([]string, error)) *MockChannelsMgr_getChannels_Call { + _c.Call.Return(run) + return _c +} + +// getOrCreateDmlStream provides a mock function with given fields: collectionID +func (_m *MockChannelsMgr) getOrCreateDmlStream(collectionID int64) (msgstream.MsgStream, error) { + ret := _m.Called(collectionID) + + var r0 msgstream.MsgStream + var r1 error + if rf, ok := ret.Get(0).(func(int64) (msgstream.MsgStream, error)); ok { + return rf(collectionID) + } + if rf, ok := ret.Get(0).(func(int64) msgstream.MsgStream); ok { + r0 = rf(collectionID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(msgstream.MsgStream) + } + } + + if rf, ok := ret.Get(1).(func(int64) error); ok { + r1 = rf(collectionID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockChannelsMgr_getOrCreateDmlStream_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'getOrCreateDmlStream' +type MockChannelsMgr_getOrCreateDmlStream_Call struct { + *mock.Call +} + +// getOrCreateDmlStream is a helper method to define mock.On call +// - collectionID int64 +func (_e *MockChannelsMgr_Expecter) getOrCreateDmlStream(collectionID interface{}) *MockChannelsMgr_getOrCreateDmlStream_Call { + return &MockChannelsMgr_getOrCreateDmlStream_Call{Call: _e.mock.On("getOrCreateDmlStream", collectionID)} +} + +func (_c *MockChannelsMgr_getOrCreateDmlStream_Call) Run(run func(collectionID int64)) *MockChannelsMgr_getOrCreateDmlStream_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64)) + }) + return _c +} + +func (_c *MockChannelsMgr_getOrCreateDmlStream_Call) Return(_a0 msgstream.MsgStream, _a1 error) *MockChannelsMgr_getOrCreateDmlStream_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c 
*MockChannelsMgr_getOrCreateDmlStream_Call) RunAndReturn(run func(int64) (msgstream.MsgStream, error)) *MockChannelsMgr_getOrCreateDmlStream_Call { + _c.Call.Return(run) + return _c +} + +// getVChannels provides a mock function with given fields: collectionID +func (_m *MockChannelsMgr) getVChannels(collectionID int64) ([]string, error) { + ret := _m.Called(collectionID) + + var r0 []string + var r1 error + if rf, ok := ret.Get(0).(func(int64) ([]string, error)); ok { + return rf(collectionID) + } + if rf, ok := ret.Get(0).(func(int64) []string); ok { + r0 = rf(collectionID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + if rf, ok := ret.Get(1).(func(int64) error); ok { + r1 = rf(collectionID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockChannelsMgr_getVChannels_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'getVChannels' +type MockChannelsMgr_getVChannels_Call struct { + *mock.Call +} + +// getVChannels is a helper method to define mock.On call +// - collectionID int64 +func (_e *MockChannelsMgr_Expecter) getVChannels(collectionID interface{}) *MockChannelsMgr_getVChannels_Call { + return &MockChannelsMgr_getVChannels_Call{Call: _e.mock.On("getVChannels", collectionID)} +} + +func (_c *MockChannelsMgr_getVChannels_Call) Run(run func(collectionID int64)) *MockChannelsMgr_getVChannels_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64)) + }) + return _c +} + +func (_c *MockChannelsMgr_getVChannels_Call) Return(_a0 []string, _a1 error) *MockChannelsMgr_getVChannels_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockChannelsMgr_getVChannels_Call) RunAndReturn(run func(int64) ([]string, error)) *MockChannelsMgr_getVChannels_Call { + _c.Call.Return(run) + return _c +} + +// removeAllDMLStream provides a mock function with given fields: +func (_m *MockChannelsMgr) removeAllDMLStream() { + _m.Called() +} + +// 
MockChannelsMgr_removeAllDMLStream_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'removeAllDMLStream' +type MockChannelsMgr_removeAllDMLStream_Call struct { + *mock.Call +} + +// removeAllDMLStream is a helper method to define mock.On call +func (_e *MockChannelsMgr_Expecter) removeAllDMLStream() *MockChannelsMgr_removeAllDMLStream_Call { + return &MockChannelsMgr_removeAllDMLStream_Call{Call: _e.mock.On("removeAllDMLStream")} +} + +func (_c *MockChannelsMgr_removeAllDMLStream_Call) Run(run func()) *MockChannelsMgr_removeAllDMLStream_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockChannelsMgr_removeAllDMLStream_Call) Return() *MockChannelsMgr_removeAllDMLStream_Call { + _c.Call.Return() + return _c +} + +func (_c *MockChannelsMgr_removeAllDMLStream_Call) RunAndReturn(run func()) *MockChannelsMgr_removeAllDMLStream_Call { + _c.Call.Return(run) + return _c +} + +// removeDMLStream provides a mock function with given fields: collectionID +func (_m *MockChannelsMgr) removeDMLStream(collectionID int64) { + _m.Called(collectionID) +} + +// MockChannelsMgr_removeDMLStream_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'removeDMLStream' +type MockChannelsMgr_removeDMLStream_Call struct { + *mock.Call +} + +// removeDMLStream is a helper method to define mock.On call +// - collectionID int64 +func (_e *MockChannelsMgr_Expecter) removeDMLStream(collectionID interface{}) *MockChannelsMgr_removeDMLStream_Call { + return &MockChannelsMgr_removeDMLStream_Call{Call: _e.mock.On("removeDMLStream", collectionID)} +} + +func (_c *MockChannelsMgr_removeDMLStream_Call) Run(run func(collectionID int64)) *MockChannelsMgr_removeDMLStream_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64)) + }) + return _c +} + +func (_c *MockChannelsMgr_removeDMLStream_Call) Return() *MockChannelsMgr_removeDMLStream_Call { + _c.Call.Return() + return _c 
+} + +func (_c *MockChannelsMgr_removeDMLStream_Call) RunAndReturn(run func(int64)) *MockChannelsMgr_removeDMLStream_Call { + _c.Call.Return(run) + return _c +} + +// NewMockChannelsMgr creates a new instance of MockChannelsMgr. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockChannelsMgr(t interface { + mock.TestingT + Cleanup(func()) +}) *MockChannelsMgr { + mock := &MockChannelsMgr{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/proxy/mock_channels_mgr_test.go b/internal/proxy/mock_channels_mgr_test.go deleted file mode 100644 index 062b62ed42c89..0000000000000 --- a/internal/proxy/mock_channels_mgr_test.go +++ /dev/null @@ -1,35 +0,0 @@ -package proxy - -type getVChannelsFuncType = func(collectionID UniqueID) ([]vChan, error) -type removeDMLStreamFuncType = func(collectionID UniqueID) error - -type mockChannelsMgr struct { - channelsMgr - getChannelsFunc func(collectionID UniqueID) ([]pChan, error) - getVChannelsFuncType - removeDMLStreamFuncType -} - -func (m *mockChannelsMgr) getChannels(collectionID UniqueID) ([]pChan, error) { - if m.getChannelsFunc != nil { - return m.getChannelsFunc(collectionID) - } - return nil, nil -} - -func (m *mockChannelsMgr) getVChannels(collectionID UniqueID) ([]vChan, error) { - if m.getVChannelsFuncType != nil { - return m.getVChannelsFuncType(collectionID) - } - return nil, nil -} - -func (m *mockChannelsMgr) removeDMLStream(collectionID UniqueID) { - if m.removeDMLStreamFuncType != nil { - m.removeDMLStreamFuncType(collectionID) - } -} - -func newMockChannelsMgr() *mockChannelsMgr { - return &mockChannelsMgr{} -} diff --git a/internal/proxy/mock_msgstream_test.go b/internal/proxy/mock_msgstream_test.go index 95f39590a9d73..613dd97b94057 100644 --- a/internal/proxy/mock_msgstream_test.go +++ b/internal/proxy/mock_msgstream_test.go @@ 
-10,9 +10,10 @@ import ( type mockMsgStream struct { msgstream.MsgStream - asProducer func([]string) - setRepack func(repackFunc msgstream.RepackFunc) - close func() + asProducer func([]string) + setRepack func(repackFunc msgstream.RepackFunc) + close func() + enableProduce func(bool) } func (m *mockMsgStream) AsProducer(producers []string) { @@ -33,6 +34,12 @@ func (m *mockMsgStream) Close() { } } +func (m *mockMsgStream) EnableProduce(enabled bool) { + if m.enableProduce != nil { + m.enableProduce(enabled) + } +} + func newMockMsgStream() *mockMsgStream { return &mockMsgStream{} } diff --git a/internal/proxy/mock_shardclient_manager.go b/internal/proxy/mock_shardclient_manager.go index 51ed2c7c13d95..33d886a18dc02 100644 --- a/internal/proxy/mock_shardclient_manager.go +++ b/internal/proxy/mock_shardclient_manager.go @@ -55,19 +55,19 @@ func (_c *MockShardClientManager_Close_Call) RunAndReturn(run func()) *MockShard } // GetClient provides a mock function with given fields: ctx, nodeID -func (_m *MockShardClientManager) GetClient(ctx context.Context, nodeID int64) (types.QueryNode, error) { +func (_m *MockShardClientManager) GetClient(ctx context.Context, nodeID int64) (types.QueryNodeClient, error) { ret := _m.Called(ctx, nodeID) - var r0 types.QueryNode + var r0 types.QueryNodeClient var r1 error - if rf, ok := ret.Get(0).(func(context.Context, int64) (types.QueryNode, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, int64) (types.QueryNodeClient, error)); ok { return rf(ctx, nodeID) } - if rf, ok := ret.Get(0).(func(context.Context, int64) types.QueryNode); ok { + if rf, ok := ret.Get(0).(func(context.Context, int64) types.QueryNodeClient); ok { r0 = rf(ctx, nodeID) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(types.QueryNode) + r0 = ret.Get(0).(types.QueryNodeClient) } } @@ -99,12 +99,12 @@ func (_c *MockShardClientManager_GetClient_Call) Run(run func(ctx context.Contex return _c } -func (_c *MockShardClientManager_GetClient_Call) 
Return(_a0 types.QueryNode, _a1 error) *MockShardClientManager_GetClient_Call { +func (_c *MockShardClientManager_GetClient_Call) Return(_a0 types.QueryNodeClient, _a1 error) *MockShardClientManager_GetClient_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *MockShardClientManager_GetClient_Call) RunAndReturn(run func(context.Context, int64) (types.QueryNode, error)) *MockShardClientManager_GetClient_Call { +func (_c *MockShardClientManager_GetClient_Call) RunAndReturn(run func(context.Context, int64) (types.QueryNodeClient, error)) *MockShardClientManager_GetClient_Call { _c.Call.Return(run) return _c } diff --git a/internal/proxy/mock_test.go b/internal/proxy/mock_test.go index 598752ecde5df..836ad42cf40c0 100644 --- a/internal/proxy/mock_test.go +++ b/internal/proxy/mock_test.go @@ -22,6 +22,8 @@ import ( "sync" "time" + "google.golang.org/grpc" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/allocator" @@ -29,6 +31,7 @@ import ( "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/uniquegenerator" ) @@ -38,7 +41,7 @@ type mockTimestampAllocatorInterface struct { mtx sync.Mutex } -func (tso *mockTimestampAllocatorInterface) AllocTimestamp(ctx context.Context, req *rootcoordpb.AllocTimestampRequest) (*rootcoordpb.AllocTimestampResponse, error) { +func (tso *mockTimestampAllocatorInterface) AllocTimestamp(ctx context.Context, req *rootcoordpb.AllocTimestampRequest, opts ...grpc.CallOption) (*rootcoordpb.AllocTimestampResponse, error) { tso.mtx.Lock() defer tso.mtx.Unlock() @@ -49,10 +52,7 @@ func (tso *mockTimestampAllocatorInterface) AllocTimestamp(ctx context.Context, tso.lastTs = ts return &rootcoordpb.AllocTimestampResponse{ - 
Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), Timestamp: ts, Count: req.Count, }, nil @@ -81,8 +81,7 @@ func newMockTsoAllocator() tsoAllocator { return &mockTsoAllocator{} } -type mockIDAllocatorInterface struct { -} +type mockIDAllocatorInterface struct{} func (m *mockIDAllocatorInterface) AllocOne() (UniqueID, error) { return UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()), nil @@ -207,6 +206,8 @@ func newMockDmlTask(ctx context.Context) *mockDmlTask { return &mockDmlTask{ mockTask: newMockTask(ctx), + vchans: vchans, + pchans: pchans, } } @@ -252,7 +253,8 @@ func (ms *simpleMockMsgStream) Chan() <-chan *msgstream.MsgPack { func (ms *simpleMockMsgStream) AsProducer(channels []string) { } -func (ms *simpleMockMsgStream) AsConsumer(channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) { +func (ms *simpleMockMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) error { + return nil } func (ms *simpleMockMsgStream) SetRepackFunc(repackFunc msgstream.RepackFunc) { @@ -292,7 +294,7 @@ func (ms *simpleMockMsgStream) GetProduceChannels() []string { return nil } -func (ms *simpleMockMsgStream) Seek(offset []*msgstream.MsgPosition) error { +func (ms *simpleMockMsgStream) Seek(ctx context.Context, offset []*msgstream.MsgPosition) error { return nil } @@ -304,6 +306,9 @@ func (ms *simpleMockMsgStream) CheckTopicValid(topic string) error { return nil } +func (ms *simpleMockMsgStream) EnableProduce(enabled bool) { +} + func newSimpleMockMsgStream() *simpleMockMsgStream { return &simpleMockMsgStream{ msgChan: make(chan *msgstream.MsgPack, 1024), @@ -311,8 +316,7 @@ func newSimpleMockMsgStream() *simpleMockMsgStream { } } -type simpleMockMsgStreamFactory struct { -} +type simpleMockMsgStreamFactory struct{} func (factory *simpleMockMsgStreamFactory) Init(param *paramtable.ComponentParam) error { 
return nil @@ -429,7 +433,7 @@ func generateFieldData(dataType schemapb.DataType, fieldName string, numRows int }, } default: - //TODO:: + // TODO:: } return fieldData diff --git a/internal/proxy/mock_tso_test.go b/internal/proxy/mock_tso_test.go index d78aa0c68f7ea..10b0b9655ca1a 100644 --- a/internal/proxy/mock_tso_test.go +++ b/internal/proxy/mock_tso_test.go @@ -5,8 +5,11 @@ package proxy import ( context "context" - rootcoordpb "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + grpc "google.golang.org/grpc" + mock "github.com/stretchr/testify/mock" + + rootcoordpb "github.com/milvus-io/milvus/internal/proto/rootcoordpb" ) // mockTimestampAllocator is an autogenerated mock type for the timestampAllocatorInterface type @@ -22,25 +25,32 @@ func (_m *mockTimestampAllocator) EXPECT() *mockTimestampAllocator_Expecter { return &mockTimestampAllocator_Expecter{mock: &_m.Mock} } -// AllocTimestamp provides a mock function with given fields: ctx, req -func (_m *mockTimestampAllocator) AllocTimestamp(ctx context.Context, req *rootcoordpb.AllocTimestampRequest) (*rootcoordpb.AllocTimestampResponse, error) { - ret := _m.Called(ctx, req) +// AllocTimestamp provides a mock function with given fields: ctx, req, opts +func (_m *mockTimestampAllocator) AllocTimestamp(ctx context.Context, req *rootcoordpb.AllocTimestampRequest, opts ...grpc.CallOption) (*rootcoordpb.AllocTimestampResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, req) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) 
var r0 *rootcoordpb.AllocTimestampResponse var r1 error - if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.AllocTimestampRequest) (*rootcoordpb.AllocTimestampResponse, error)); ok { - return rf(ctx, req) + if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.AllocTimestampRequest, ...grpc.CallOption) (*rootcoordpb.AllocTimestampResponse, error)); ok { + return rf(ctx, req, opts...) } - if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.AllocTimestampRequest) *rootcoordpb.AllocTimestampResponse); ok { - r0 = rf(ctx, req) + if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.AllocTimestampRequest, ...grpc.CallOption) *rootcoordpb.AllocTimestampResponse); ok { + r0 = rf(ctx, req, opts...) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*rootcoordpb.AllocTimestampResponse) } } - if rf, ok := ret.Get(1).(func(context.Context, *rootcoordpb.AllocTimestampRequest) error); ok { - r1 = rf(ctx, req) + if rf, ok := ret.Get(1).(func(context.Context, *rootcoordpb.AllocTimestampRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, req, opts...) 
} else { r1 = ret.Error(1) } @@ -56,13 +66,21 @@ type mockTimestampAllocator_AllocTimestamp_Call struct { // AllocTimestamp is a helper method to define mock.On call // - ctx context.Context // - req *rootcoordpb.AllocTimestampRequest -func (_e *mockTimestampAllocator_Expecter) AllocTimestamp(ctx interface{}, req interface{}) *mockTimestampAllocator_AllocTimestamp_Call { - return &mockTimestampAllocator_AllocTimestamp_Call{Call: _e.mock.On("AllocTimestamp", ctx, req)} +// - opts ...grpc.CallOption +func (_e *mockTimestampAllocator_Expecter) AllocTimestamp(ctx interface{}, req interface{}, opts ...interface{}) *mockTimestampAllocator_AllocTimestamp_Call { + return &mockTimestampAllocator_AllocTimestamp_Call{Call: _e.mock.On("AllocTimestamp", + append([]interface{}{ctx, req}, opts...)...)} } -func (_c *mockTimestampAllocator_AllocTimestamp_Call) Run(run func(ctx context.Context, req *rootcoordpb.AllocTimestampRequest)) *mockTimestampAllocator_AllocTimestamp_Call { +func (_c *mockTimestampAllocator_AllocTimestamp_Call) Run(run func(ctx context.Context, req *rootcoordpb.AllocTimestampRequest, opts ...grpc.CallOption)) *mockTimestampAllocator_AllocTimestamp_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(*rootcoordpb.AllocTimestampRequest)) + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*rootcoordpb.AllocTimestampRequest), variadicArgs...) 
}) return _c } @@ -72,7 +90,7 @@ func (_c *mockTimestampAllocator_AllocTimestamp_Call) Return(_a0 *rootcoordpb.Al return _c } -func (_c *mockTimestampAllocator_AllocTimestamp_Call) RunAndReturn(run func(context.Context, *rootcoordpb.AllocTimestampRequest) (*rootcoordpb.AllocTimestampResponse, error)) *mockTimestampAllocator_AllocTimestamp_Call { +func (_c *mockTimestampAllocator_AllocTimestamp_Call) RunAndReturn(run func(context.Context, *rootcoordpb.AllocTimestampRequest, ...grpc.CallOption) (*rootcoordpb.AllocTimestampResponse, error)) *mockTimestampAllocator_AllocTimestamp_Call { _c.Call.Return(run) return _c } diff --git a/internal/proxy/msg_pack.go b/internal/proxy/msg_pack.go index 543127095a4c2..7d1d58b213698 100644 --- a/internal/proxy/msg_pack.go +++ b/internal/proxy/msg_pack.go @@ -44,7 +44,8 @@ func genInsertMsgsByPartition(ctx context.Context, partitionName string, rowOffsets []int, channelName string, - insertMsg *msgstream.InsertMsg) ([]msgstream.TsMsg, error) { + insertMsg *msgstream.InsertMsg, +) ([]msgstream.TsMsg, error) { threshold := Params.PulsarCfg.MaxMessageSize.GetAsInt() // create empty insert message @@ -108,7 +109,8 @@ func repackInsertDataByPartition(ctx context.Context, rowOffsets []int, channelName string, insertMsg *msgstream.InsertMsg, - segIDAssigner *segIDAssigner) ([]msgstream.TsMsg, error) { + segIDAssigner *segIDAssigner, +) ([]msgstream.TsMsg, error) { res := make([]msgstream.TsMsg, 0) maxTs := Timestamp(0) @@ -155,7 +157,8 @@ func repackInsertDataByPartition(ctx context.Context, func setMsgID(ctx context.Context, msgs []msgstream.TsMsg, - idAllocator *allocator.IDAllocator) error { + idAllocator *allocator.IDAllocator, +) error { var idBegin int64 var err error @@ -180,7 +183,8 @@ func repackInsertData(ctx context.Context, insertMsg *msgstream.InsertMsg, result *milvuspb.MutationResult, idAllocator *allocator.IDAllocator, - segIDAssigner *segIDAssigner) (*msgstream.MsgPack, error) { + segIDAssigner *segIDAssigner, +) 
(*msgstream.MsgPack, error) { msgPack := &msgstream.MsgPack{ BeginTs: insertMsg.BeginTs(), EndTs: insertMsg.EndTs(), @@ -219,7 +223,8 @@ func repackInsertDataWithPartitionKey(ctx context.Context, insertMsg *msgstream.InsertMsg, result *milvuspb.MutationResult, idAllocator *allocator.IDAllocator, - segIDAssigner *segIDAssigner) (*msgstream.MsgPack, error) { + segIDAssigner *segIDAssigner, +) (*msgstream.MsgPack, error) { msgPack := &msgstream.MsgPack{ BeginTs: insertMsg.BeginTs(), EndTs: insertMsg.EndTs(), diff --git a/internal/proxy/msg_pack_test.go b/internal/proxy/msg_pack_test.go index 194777a1a9e97..4114660666e93 100644 --- a/internal/proxy/msg_pack_test.go +++ b/internal/proxy/msg_pack_test.go @@ -45,8 +45,7 @@ func TestRepackInsertData(t *testing.T) { ctx := context.Background() rc := NewRootCoordMock() - rc.Start() - defer rc.Stop() + defer rc.Close() cache := NewMockCache(t) cache.On("GetPartitionID", @@ -152,8 +151,7 @@ func TestRepackInsertDataWithPartitionKey(t *testing.T) { dbName := GetCurDBNameFromContextOrDefault(ctx) rc := NewRootCoordMock() - rc.Start() - defer rc.Stop() + defer rc.Close() err := InitMetaCache(ctx, rc, nil, nil) assert.NoError(t, err) @@ -171,7 +169,8 @@ func TestRepackInsertDataWithPartitionKey(t *testing.T) { fieldName2Types := map[string]schemapb.DataType{ testInt64Field: schemapb.DataType_Int64, testVarCharField: schemapb.DataType_VarChar, - testFloatVecField: schemapb.DataType_FloatVector} + testFloatVecField: schemapb.DataType_FloatVector, + } t.Run("create collection with partition key", func(t *testing.T) { schema := ConstructCollectionSchemaWithPartitionKey(collectionName, fieldName2Types, testInt64Field, testVarCharField, false) diff --git a/internal/proxy/multi_rate_limiter.go b/internal/proxy/multi_rate_limiter.go index b2e08b4cf66b2..4e4ba74a5b16a 100644 --- a/internal/proxy/multi_rate_limiter.go +++ b/internal/proxy/multi_rate_limiter.go @@ -32,6 +32,7 @@ import ( "github.com/milvus-io/milvus/pkg/config" 
"github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/ratelimitutil" "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -68,37 +69,37 @@ func NewMultiRateLimiter() *MultiRateLimiter { } // Check checks if request would be limited or denied. -func (m *MultiRateLimiter) Check(collectionID int64, rt internalpb.RateType, n int) commonpb.ErrorCode { +func (m *MultiRateLimiter) Check(collectionID int64, rt internalpb.RateType, n int) error { if !Params.QuotaConfig.QuotaAndLimitsEnabled.GetAsBool() { - return commonpb.ErrorCode_Success + return nil } m.quotaStatesMu.RLock() defer m.quotaStatesMu.RUnlock() - checkFunc := func(limiter *rateLimiter) commonpb.ErrorCode { + checkFunc := func(limiter *rateLimiter) error { if limiter == nil { - return commonpb.ErrorCode_Success + return nil } limit, rate := limiter.limit(rt, n) if rate == 0 { - return limiter.getErrorCode(rt) + return limiter.getError(rt) } if limit { - return commonpb.ErrorCode_RateLimit + return merr.WrapErrServiceRateLimit(rate) } - return commonpb.ErrorCode_Success + return nil } // first, check global level rate limits ret := checkFunc(m.globalDDLLimiter) // second check collection level rate limits - if ret == commonpb.ErrorCode_Success && !IsDDLRequest(rt) { + if ret == nil && !IsDDLRequest(rt) { // only dml and dql have collection level rate limits ret = checkFunc(m.collectionLimiters[collectionID]) - if ret != commonpb.ErrorCode_Success { + if ret != nil { m.globalDDLLimiter.cancel(rt, n) } } @@ -237,18 +238,18 @@ func (rl *rateLimiter) setRates(collectionRate *proxypb.CollectionRate) error { return nil } -func (rl *rateLimiter) getErrorCode(rt internalpb.RateType) commonpb.ErrorCode { +func (rl *rateLimiter) getError(rt internalpb.RateType) error { switch rt { case internalpb.RateType_DMLInsert, internalpb.RateType_DMLUpsert, 
internalpb.RateType_DMLDelete, internalpb.RateType_DMLBulkLoad: if errCode, ok := rl.quotaStates.Get(milvuspb.QuotaState_DenyToWrite); ok { - return errCode + return merr.OldCodeToMerr(errCode) } case internalpb.RateType_DQLSearch, internalpb.RateType_DQLQuery: if errCode, ok := rl.quotaStates.Get(milvuspb.QuotaState_DenyToRead); ok { - return errCode + return merr.OldCodeToMerr(errCode) } } - return commonpb.ErrorCode_Success + return nil } // setRateGaugeByRateType sets ProxyLimiterRate metrics. diff --git a/internal/proxy/multi_rate_limiter_test.go b/internal/proxy/multi_rate_limiter_test.go index 6cb974b33f359..db80be17cacbd 100644 --- a/internal/proxy/multi_rate_limiter_test.go +++ b/internal/proxy/multi_rate_limiter_test.go @@ -24,14 +24,16 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/proxypb" "github.com/milvus-io/milvus/pkg/util/etcd" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/ratelimitutil" - "github.com/stretchr/testify/assert" ) func TestMultiRateLimiter(t *testing.T) { @@ -50,21 +52,20 @@ func TestMultiRateLimiter(t *testing.T) { } for _, rt := range internalpb.RateType_value { if IsDDLRequest(internalpb.RateType(rt)) { - errCode := multiLimiter.Check(collectionID, internalpb.RateType(rt), 1) - assert.Equal(t, commonpb.ErrorCode_Success, errCode) - errCode = multiLimiter.Check(collectionID, internalpb.RateType(rt), 5) - assert.Equal(t, commonpb.ErrorCode_Success, errCode) - errCode = multiLimiter.Check(collectionID, internalpb.RateType(rt), 5) - assert.Equal(t, commonpb.ErrorCode_RateLimit, errCode) + err := multiLimiter.Check(collectionID, internalpb.RateType(rt), 1) + assert.NoError(t, err) + err = multiLimiter.Check(collectionID, 
internalpb.RateType(rt), 5) + assert.NoError(t, err) + err = multiLimiter.Check(collectionID, internalpb.RateType(rt), 5) + assert.ErrorIs(t, err, merr.ErrServiceRateLimit) } else { - errCode := multiLimiter.Check(collectionID, internalpb.RateType(rt), 1) - assert.Equal(t, commonpb.ErrorCode_Success, errCode) - errCode = multiLimiter.Check(collectionID, internalpb.RateType(rt), math.MaxInt) - assert.Equal(t, commonpb.ErrorCode_Success, errCode) - errCode = multiLimiter.Check(collectionID, internalpb.RateType(rt), math.MaxInt) - assert.Equal(t, commonpb.ErrorCode_RateLimit, errCode) + err := multiLimiter.Check(collectionID, internalpb.RateType(rt), 1) + assert.NoError(t, err) + err = multiLimiter.Check(collectionID, internalpb.RateType(rt), math.MaxInt) + assert.NoError(t, err) + err = multiLimiter.Check(collectionID, internalpb.RateType(rt), math.MaxInt) + assert.ErrorIs(t, err, merr.ErrServiceRateLimit) } - } Params.Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, bak) }) @@ -88,19 +89,19 @@ func TestMultiRateLimiter(t *testing.T) { } for _, rt := range internalpb.RateType_value { if IsDDLRequest(internalpb.RateType(rt)) { - errCode := multiLimiter.Check(1, internalpb.RateType(rt), 1) - assert.Equal(t, commonpb.ErrorCode_Success, errCode) - errCode = multiLimiter.Check(1, internalpb.RateType(rt), 5) - assert.Equal(t, commonpb.ErrorCode_Success, errCode) - errCode = multiLimiter.Check(1, internalpb.RateType(rt), 5) - assert.Equal(t, commonpb.ErrorCode_RateLimit, errCode) + err := multiLimiter.Check(1, internalpb.RateType(rt), 1) + assert.NoError(t, err) + err = multiLimiter.Check(1, internalpb.RateType(rt), 5) + assert.NoError(t, err) + err = multiLimiter.Check(1, internalpb.RateType(rt), 5) + assert.ErrorIs(t, err, merr.ErrServiceRateLimit) } else { - errCode := multiLimiter.Check(1, internalpb.RateType(rt), 1) - assert.Equal(t, commonpb.ErrorCode_Success, errCode) - errCode = multiLimiter.Check(2, internalpb.RateType(rt), 1) - assert.Equal(t, 
commonpb.ErrorCode_Success, errCode) - errCode = multiLimiter.Check(3, internalpb.RateType(rt), 1) - assert.Equal(t, commonpb.ErrorCode_RateLimit, errCode) + err := multiLimiter.Check(1, internalpb.RateType(rt), 1) + assert.NoError(t, err) + err = multiLimiter.Check(2, internalpb.RateType(rt), 1) + assert.NoError(t, err) + err = multiLimiter.Check(3, internalpb.RateType(rt), 1) + assert.ErrorIs(t, err, merr.ErrServiceRateLimit) } } Params.Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, bak) @@ -112,8 +113,8 @@ func TestMultiRateLimiter(t *testing.T) { bak := Params.QuotaConfig.QuotaAndLimitsEnabled.GetValue() paramtable.Get().Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, "false") for _, rt := range internalpb.RateType_value { - errCode := multiLimiter.Check(collectionID, internalpb.RateType(rt), 1) - assert.Equal(t, commonpb.ErrorCode_Success, errCode) + err := multiLimiter.Check(collectionID, internalpb.RateType(rt), 1) + assert.NoError(t, err) } Params.Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, bak) }) @@ -125,8 +126,8 @@ func TestMultiRateLimiter(t *testing.T) { multiLimiter := NewMultiRateLimiter() bak := Params.QuotaConfig.QuotaAndLimitsEnabled.GetValue() paramtable.Get().Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, "true") - errCode := multiLimiter.Check(collectionID, internalpb.RateType_DMLInsert, 1*1024*1024) - assert.Equal(t, commonpb.ErrorCode_Success, errCode) + err := multiLimiter.Check(collectionID, internalpb.RateType_DMLInsert, 1*1024*1024) + assert.NoError(t, err) Params.Save(Params.QuotaConfig.QuotaAndLimitsEnabled.Key, bak) Params.Save(Params.QuotaConfig.DMLMaxInsertRate.Key, bakInsertRate) } @@ -282,8 +283,8 @@ func TestRateLimiter(t *testing.T) { }, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_ForceDeny, limiter.getErrorCode(internalpb.RateType_DQLQuery)) - assert.Equal(t, commonpb.ErrorCode_DiskQuotaExhausted, limiter.getErrorCode(internalpb.RateType_DMLInsert)) + assert.ErrorIs(t, 
limiter.getError(internalpb.RateType_DQLQuery), merr.ErrServiceForceDeny) + assert.Equal(t, limiter.getError(internalpb.RateType_DMLInsert), merr.ErrServiceDiskLimitExceeded) }) t.Run("tests refresh rate by config", func(t *testing.T) { @@ -301,7 +302,7 @@ func TestRateLimiter(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() // avoid production precision issues when comparing 0-terminated numbers - newRate := fmt.Sprintf("%.3f1", rand.Float64()) + newRate := fmt.Sprintf("%.2f1", rand.Float64()) etcdCli.KV.Put(ctx, "by-dev/config/quotaAndLimits/ddl/collectionRate", newRate) etcdCli.KV.Put(ctx, "by-dev/config/quotaAndLimits/ddl/partitionRate", "invalid") diff --git a/internal/proxy/plan_parser.go b/internal/proxy/plan_parser.go deleted file mode 100644 index cb1314383a452..0000000000000 --- a/internal/proxy/plan_parser.go +++ /dev/null @@ -1,918 +0,0 @@ -// Licensed to the LF AI & Data foundation under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package proxy - -import ( - "fmt" - "math" - "strings" - - ant_ast "github.com/antonmedv/expr/ast" - ant_parser "github.com/antonmedv/expr/parser" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/internal/proto/planpb" - "github.com/milvus-io/milvus/pkg/util/typeutil" -) - -type parserContext struct { - schema *typeutil.SchemaHelper -} - -type optimizer struct { - err error -} - -func (*optimizer) Enter(*ant_ast.Node) {} - -func (optimizer *optimizer) Exit(node *ant_ast.Node) { - patch := func(newNode ant_ast.Node) { - ant_ast.Patch(node, newNode) - } - - switch node := (*node).(type) { - case *ant_ast.UnaryNode: - switch node.Operator { - case "-": - if i, ok := node.Node.(*ant_ast.IntegerNode); ok { - patch(&ant_ast.IntegerNode{Value: -i.Value}) - } else if i, ok := node.Node.(*ant_ast.FloatNode); ok { - patch(&ant_ast.FloatNode{Value: -i.Value}) - } else { - optimizer.err = fmt.Errorf("invalid data type") - return - } - case "+": - if i, ok := node.Node.(*ant_ast.IntegerNode); ok { - patch(&ant_ast.IntegerNode{Value: i.Value}) - } else if i, ok := node.Node.(*ant_ast.FloatNode); ok { - patch(&ant_ast.FloatNode{Value: i.Value}) - } else { - optimizer.err = fmt.Errorf("invalid data type") - return - } - } - - case *ant_ast.BinaryNode: - floatNodeLeft, leftFloat := node.Left.(*ant_ast.FloatNode) - integerNodeLeft, leftInteger := node.Left.(*ant_ast.IntegerNode) - floatNodeRight, rightFloat := node.Right.(*ant_ast.FloatNode) - integerNodeRight, rightInteger := node.Right.(*ant_ast.IntegerNode) - - // Check IdentifierNodes - identifierNodeLeft, leftIdentifier := node.Left.(*ant_ast.IdentifierNode) - identifierNodeRight, rightIdentifier := node.Right.(*ant_ast.IdentifierNode) - - switch node.Operator { - case "+": - funcName, err := getFuncNameByNodeOp(node.Operator) - if err != nil { - optimizer.err = err - return - } - if leftFloat && rightFloat { - patch(&ant_ast.FloatNode{Value: floatNodeLeft.Value + floatNodeRight.Value}) 
- } else if leftFloat && rightInteger { - patch(&ant_ast.FloatNode{Value: floatNodeLeft.Value + float64(integerNodeRight.Value)}) - } else if leftInteger && rightFloat { - patch(&ant_ast.FloatNode{Value: float64(integerNodeLeft.Value) + floatNodeRight.Value}) - } else if leftInteger && rightInteger { - patch(&ant_ast.IntegerNode{Value: integerNodeLeft.Value + integerNodeRight.Value}) - } else if leftIdentifier && rightFloat { - patch(&ant_ast.FunctionNode{Name: funcName, Arguments: []ant_ast.Node{identifierNodeLeft, floatNodeRight}}) - } else if leftIdentifier && rightInteger { - patch(&ant_ast.FunctionNode{Name: funcName, Arguments: []ant_ast.Node{identifierNodeLeft, integerNodeRight}}) - } else if leftFloat && rightIdentifier { - patch(&ant_ast.FunctionNode{Name: funcName, Arguments: []ant_ast.Node{identifierNodeRight, floatNodeLeft}}) - } else if leftInteger && rightIdentifier { - patch(&ant_ast.FunctionNode{Name: funcName, Arguments: []ant_ast.Node{identifierNodeRight, integerNodeLeft}}) - } else { - optimizer.err = fmt.Errorf("invalid data type") - return - } - case "-": - funcName, err := getFuncNameByNodeOp(node.Operator) - if err != nil { - optimizer.err = err - return - } - if leftFloat && rightFloat { - patch(&ant_ast.FloatNode{Value: floatNodeLeft.Value - floatNodeRight.Value}) - } else if leftFloat && rightInteger { - patch(&ant_ast.FloatNode{Value: floatNodeLeft.Value - float64(integerNodeRight.Value)}) - } else if leftInteger && rightFloat { - patch(&ant_ast.FloatNode{Value: float64(integerNodeLeft.Value) - floatNodeRight.Value}) - } else if leftInteger && rightInteger { - patch(&ant_ast.IntegerNode{Value: integerNodeLeft.Value - integerNodeRight.Value}) - } else if leftIdentifier && rightFloat { - patch(&ant_ast.FunctionNode{Name: funcName, Arguments: []ant_ast.Node{identifierNodeLeft, floatNodeRight}}) - } else if leftIdentifier && rightInteger { - patch(&ant_ast.FunctionNode{Name: funcName, Arguments: []ant_ast.Node{identifierNodeLeft, 
integerNodeRight}}) - } else if leftFloat && rightIdentifier { - optimizer.err = fmt.Errorf("field as right operand is not yet supported for (%s) operator", node.Operator) - return - } else if leftInteger && rightIdentifier { - optimizer.err = fmt.Errorf("field as right operand is not yet supported for (%s) operator", node.Operator) - return - } else { - optimizer.err = fmt.Errorf("invalid data type") - return - } - case "*": - funcName, err := getFuncNameByNodeOp(node.Operator) - if err != nil { - optimizer.err = err - return - } - if leftFloat && rightFloat { - patch(&ant_ast.FloatNode{Value: floatNodeLeft.Value * floatNodeRight.Value}) - } else if leftFloat && rightInteger { - patch(&ant_ast.FloatNode{Value: floatNodeLeft.Value * float64(integerNodeRight.Value)}) - } else if leftInteger && rightFloat { - patch(&ant_ast.FloatNode{Value: float64(integerNodeLeft.Value) * floatNodeRight.Value}) - } else if leftInteger && rightInteger { - patch(&ant_ast.IntegerNode{Value: integerNodeLeft.Value * integerNodeRight.Value}) - } else if leftIdentifier && rightFloat { - patch(&ant_ast.FunctionNode{Name: funcName, Arguments: []ant_ast.Node{identifierNodeLeft, floatNodeRight}}) - } else if leftIdentifier && rightInteger { - patch(&ant_ast.FunctionNode{Name: funcName, Arguments: []ant_ast.Node{identifierNodeLeft, integerNodeRight}}) - } else if leftFloat && rightIdentifier { - patch(&ant_ast.FunctionNode{Name: funcName, Arguments: []ant_ast.Node{identifierNodeRight, floatNodeLeft}}) - } else if leftInteger && rightIdentifier { - patch(&ant_ast.FunctionNode{Name: funcName, Arguments: []ant_ast.Node{identifierNodeRight, integerNodeLeft}}) - } else { - optimizer.err = fmt.Errorf("invalid data type") - return - } - case "/": - funcName, err := getFuncNameByNodeOp(node.Operator) - if err != nil { - optimizer.err = err - return - } - if leftFloat && rightFloat { - if floatNodeRight.Value == 0 { - optimizer.err = fmt.Errorf("divide by zero") - return - } - 
patch(&ant_ast.FloatNode{Value: floatNodeLeft.Value / floatNodeRight.Value}) - } else if leftFloat && rightInteger { - if integerNodeRight.Value == 0 { - optimizer.err = fmt.Errorf("divide by zero") - return - } - patch(&ant_ast.FloatNode{Value: floatNodeLeft.Value / float64(integerNodeRight.Value)}) - } else if leftInteger && rightFloat { - if floatNodeRight.Value == 0 { - optimizer.err = fmt.Errorf("divide by zero") - return - } - patch(&ant_ast.FloatNode{Value: float64(integerNodeLeft.Value) / floatNodeRight.Value}) - } else if leftInteger && rightInteger { - if integerNodeRight.Value == 0 { - optimizer.err = fmt.Errorf("divide by zero") - return - } - patch(&ant_ast.IntegerNode{Value: integerNodeLeft.Value / integerNodeRight.Value}) - } else if leftIdentifier && rightFloat { - if floatNodeRight.Value == 0 { - optimizer.err = fmt.Errorf("divide by zero") - return - } - patch(&ant_ast.FunctionNode{Name: funcName, Arguments: []ant_ast.Node{identifierNodeLeft, floatNodeRight}}) - } else if leftIdentifier && rightInteger { - if integerNodeRight.Value == 0 { - optimizer.err = fmt.Errorf("divide by zero") - return - } - patch(&ant_ast.FunctionNode{Name: funcName, Arguments: []ant_ast.Node{identifierNodeLeft, integerNodeRight}}) - } else if leftFloat && rightIdentifier { - optimizer.err = fmt.Errorf("field as right operand is not yet supported for (%s) operator", node.Operator) - return - } else if leftInteger && rightIdentifier { - optimizer.err = fmt.Errorf("field as right operand is not yet supported for (%s) operator", node.Operator) - return - } else { - optimizer.err = fmt.Errorf("invalid data type") - return - } - case "%": - funcName, err := getFuncNameByNodeOp(node.Operator) - if err != nil { - optimizer.err = err - return - } - if leftInteger && rightInteger { - if integerNodeRight.Value == 0 { - optimizer.err = fmt.Errorf("modulo by zero") - return - } - patch(&ant_ast.IntegerNode{Value: integerNodeLeft.Value % integerNodeRight.Value}) - } else if 
leftIdentifier && rightInteger { - if integerNodeRight.Value == 0 { - optimizer.err = fmt.Errorf("modulo by zero") - return - } - patch(&ant_ast.FunctionNode{Name: funcName, Arguments: []ant_ast.Node{identifierNodeLeft, integerNodeRight}}) - } else if leftInteger && rightIdentifier { - optimizer.err = fmt.Errorf("field as right operand is not yet supported for (%s) operator", node.Operator) - return - } else { - optimizer.err = fmt.Errorf("invalid data type") - return - } - case "**": - if leftFloat && rightFloat { - patch(&ant_ast.FloatNode{Value: math.Pow(floatNodeLeft.Value, floatNodeRight.Value)}) - } else if leftFloat && rightInteger { - patch(&ant_ast.FloatNode{Value: math.Pow(floatNodeLeft.Value, float64(integerNodeRight.Value))}) - } else if leftInteger && rightFloat { - patch(&ant_ast.FloatNode{Value: math.Pow(float64(integerNodeLeft.Value), floatNodeRight.Value)}) - } else if leftInteger && rightInteger { - patch(&ant_ast.IntegerNode{Value: int(math.Pow(float64(integerNodeLeft.Value), float64(integerNodeRight.Value)))}) - } else { - optimizer.err = fmt.Errorf("invalid data type") - return - } - } - } -} - -func parseExpr(schema *typeutil.SchemaHelper, exprStr string) (*planpb.Expr, error) { - if exprStr == "" { - return nil, nil - } - ast, err := ant_parser.Parse(exprStr) - if err != nil { - return nil, err - } - - optimizer := &optimizer{} - ant_ast.Walk(&ast.Node, optimizer) - if optimizer.err != nil { - return nil, optimizer.err - } - - pc := parserContext{schema} - expr, err := pc.handleExpr(&ast.Node) - if err != nil { - return nil, err - } - - return expr, nil -} - -func createColumnInfo(field *schemapb.FieldSchema) *planpb.ColumnInfo { - return &planpb.ColumnInfo{ - FieldId: field.FieldID, - DataType: field.DataType, - IsPrimaryKey: field.IsPrimaryKey, - IsPartitionKey: field.IsPartitionKey, - } -} - -func isSameOrder(opStr1, opStr2 string) bool { - isLess1 := (opStr1 == "<") || (opStr1 == "<=") - isLess2 := (opStr2 == "<") || (opStr2 == "<=") - 
return isLess1 == isLess2 -} - -var opMap = map[planpb.OpType]string{ - planpb.OpType_Invalid: "invalid", - planpb.OpType_GreaterThan: ">", - planpb.OpType_GreaterEqual: ">=", - planpb.OpType_LessThan: "<", - planpb.OpType_LessEqual: "<=", - planpb.OpType_Equal: "==", - planpb.OpType_NotEqual: "!=", -} - -func getCompareOpType(opStr string, reverse bool) (op planpb.OpType) { - switch opStr { - case ">": - if reverse { - op = planpb.OpType_LessThan - } else { - op = planpb.OpType_GreaterThan - } - case "<": - if reverse { - op = planpb.OpType_GreaterThan - } else { - op = planpb.OpType_LessThan - } - case ">=": - if reverse { - op = planpb.OpType_LessEqual - } else { - op = planpb.OpType_GreaterEqual - } - case "<=": - if reverse { - op = planpb.OpType_GreaterEqual - } else { - op = planpb.OpType_LessEqual - } - case "==": - op = planpb.OpType_Equal - case "!=": - op = planpb.OpType_NotEqual - case "startsWith": - op = planpb.OpType_PrefixMatch - case "endsWith": - op = planpb.OpType_PostfixMatch - default: - op = planpb.OpType_Invalid - } - return op -} - -func getLogicalOpType(opStr string) planpb.BinaryExpr_BinaryOp { - switch opStr { - case "&&", "and": - return planpb.BinaryExpr_LogicalAnd - case "||", "or": - return planpb.BinaryExpr_LogicalOr - default: - return planpb.BinaryExpr_Invalid - } -} - -func getArithOpType(funcName string) (planpb.ArithOpType, error) { - var op planpb.ArithOpType - - switch funcName { - case "add": - op = planpb.ArithOpType_Add - case "sub": - op = planpb.ArithOpType_Sub - case "mul": - op = planpb.ArithOpType_Mul - case "div": - op = planpb.ArithOpType_Div - case "mod": - op = planpb.ArithOpType_Mod - default: - return op, fmt.Errorf("unsupported or invalid arith op type: %s", funcName) - } - return op, nil -} - -func getFuncNameByNodeOp(nodeOp string) (string, error) { - var funcName string - - switch nodeOp { - case "+": - funcName = "add" - case "-": - funcName = "sub" - case "*": - funcName = "mul" - case "/": - funcName = 
"div" - case "%": - funcName = "mod" - default: - return funcName, fmt.Errorf("no defined funcName assigned to nodeOp: %s", nodeOp) - } - return funcName, nil -} - -func parseBoolNode(nodeRaw *ant_ast.Node) *ant_ast.BoolNode { - switch node := (*nodeRaw).(type) { - case *ant_ast.IdentifierNode: - // bool node only accept value 'true' or 'false' - val := strings.ToLower(node.Value) - if val == "true" { - return &ant_ast.BoolNode{ - Value: true, - } - } else if val == "false" { - return &ant_ast.BoolNode{ - Value: false, - } - } else { - return nil - } - default: - return nil - } -} - -func (pc *parserContext) createCmpExpr(left, right ant_ast.Node, operator string) (*planpb.Expr, error) { - if boolNode := parseBoolNode(&left); boolNode != nil { - left = boolNode - } - if boolNode := parseBoolNode(&right); boolNode != nil { - right = boolNode - } - idNodeLeft, okLeft := left.(*ant_ast.IdentifierNode) - idNodeRight, okRight := right.(*ant_ast.IdentifierNode) - - if okLeft && okRight { - leftField, err := pc.handleIdentifier(idNodeLeft) - if err != nil { - return nil, err - } - rightField, err := pc.handleIdentifier(idNodeRight) - if err != nil { - return nil, err - } - op := getCompareOpType(operator, false) - if op == planpb.OpType_Invalid { - return nil, fmt.Errorf("invalid binary operator(%s)", operator) - } - expr := &planpb.Expr{ - Expr: &planpb.Expr_CompareExpr{ - CompareExpr: &planpb.CompareExpr{ - LeftColumnInfo: createColumnInfo(leftField), - RightColumnInfo: createColumnInfo(rightField), - Op: op, - }, - }, - } - return expr, nil - } - - var idNode *ant_ast.IdentifierNode - var reverse bool - var valueNode *ant_ast.Node - if okLeft { - idNode = idNodeLeft - reverse = false - valueNode = &right - } else if okRight { - idNode = idNodeRight - reverse = true - valueNode = &left - } else { - return nil, fmt.Errorf("compare expr has no identifier") - } - - field, err := pc.handleIdentifier(idNode) - if err != nil { - return nil, err - } - - val, err := 
pc.handleLeafValue(valueNode, field.DataType) - if err != nil { - return nil, err - } - - op := getCompareOpType(operator, reverse) - if op == planpb.OpType_Invalid { - return nil, fmt.Errorf("invalid binary operator(%s)", operator) - } - - expr := &planpb.Expr{ - Expr: &planpb.Expr_UnaryRangeExpr{ - UnaryRangeExpr: &planpb.UnaryRangeExpr{ - ColumnInfo: createColumnInfo(field), - Op: op, - Value: val, - }, - }, - } - return expr, nil -} - -func (pc *parserContext) createBinaryArithOpEvalExpr(left *ant_ast.FunctionNode, right *ant_ast.Node, operator string) (*planpb.Expr, error) { - switch operator { - case "==", "!=": - binArithOp, err := pc.handleFunction(left) - if err != nil { - return nil, fmt.Errorf("createBinaryArithOpEvalExpr: %v", err) - } - op := getCompareOpType(operator, false) - val, err := pc.handleLeafValue(right, binArithOp.ColumnInfo.DataType) - if err != nil { - return nil, err - } - - expr := &planpb.Expr{ - Expr: &planpb.Expr_BinaryArithOpEvalRangeExpr{ - BinaryArithOpEvalRangeExpr: &planpb.BinaryArithOpEvalRangeExpr{ - ColumnInfo: binArithOp.ColumnInfo, - ArithOp: binArithOp.ArithOp, - RightOperand: binArithOp.RightOperand, - Op: op, - Value: val, - }, - }, - } - return expr, nil - } - return nil, fmt.Errorf("operator(%s) not yet supported for function nodes", operator) -} - -func (pc *parserContext) handleCmpExpr(node *ant_ast.BinaryNode) (*planpb.Expr, error) { - return pc.createCmpExpr(node.Left, node.Right, node.Operator) -} - -func (pc *parserContext) handleBinaryArithCmpExpr(node *ant_ast.BinaryNode) (*planpb.Expr, error) { - leftNode, funcNodeLeft := node.Left.(*ant_ast.FunctionNode) - rightNode, funcNodeRight := node.Right.(*ant_ast.FunctionNode) - - if funcNodeLeft && funcNodeRight { - return nil, fmt.Errorf("left and right are both expression are not supported") - } else if funcNodeRight { - // Only the right node is a function node - op := getCompareOpType(node.Operator, true) - if op == planpb.OpType_Invalid { - return nil, 
fmt.Errorf("invalid right expression") - } - return pc.createBinaryArithOpEvalExpr(rightNode, &node.Left, opMap[op]) - } else if funcNodeLeft { - // Only the left node is a function node - return pc.createBinaryArithOpEvalExpr(leftNode, &node.Right, node.Operator) - } else { - // Both left and right are not function nodes, pass to createCmpExpr - return pc.createCmpExpr(node.Left, node.Right, node.Operator) - } -} - -func (pc *parserContext) handleLogicalExpr(node *ant_ast.BinaryNode) (*planpb.Expr, error) { - op := getLogicalOpType(node.Operator) - if op == planpb.BinaryExpr_Invalid { - return nil, fmt.Errorf("invalid logical operator(%s)", node.Operator) - } - - leftExpr, err := pc.handleExpr(&node.Left) - if err != nil { - return nil, err - } - - rightExpr, err := pc.handleExpr(&node.Right) - if err != nil { - return nil, err - } - - expr := &planpb.Expr{ - Expr: &planpb.Expr_BinaryExpr{ - BinaryExpr: &planpb.BinaryExpr{ - Op: op, - Left: leftExpr, - Right: rightExpr, - }, - }, - } - return expr, nil -} - -func (pc *parserContext) handleArrayExpr(node *ant_ast.Node, dataType schemapb.DataType) ([]*planpb.GenericValue, error) { - arrayNode, ok2 := (*node).(*ant_ast.ArrayNode) - if !ok2 { - return nil, fmt.Errorf("right operand of the InExpr must be array") - } - var arr []*planpb.GenericValue - for _, element := range arrayNode.Nodes { - // use value inside - // #nosec G601 - val, err := pc.handleLeafValue(&element, dataType) - if err != nil { - return nil, err - } - arr = append(arr, val) - } - return arr, nil -} - -func (pc *parserContext) handleInExpr(node *ant_ast.BinaryNode) (*planpb.Expr, error) { - if node.Operator != "in" && node.Operator != "not in" { - return nil, fmt.Errorf("invalid operator(%s)", node.Operator) - } - idNode, ok := node.Left.(*ant_ast.IdentifierNode) - if !ok { - return nil, fmt.Errorf("left operand of the InExpr must be identifier") - } - field, err := pc.handleIdentifier(idNode) - if err != nil { - return nil, err - } - arrayData, 
err := pc.handleArrayExpr(&node.Right, field.DataType) - if err != nil { - return nil, err - } - - expr := &planpb.Expr{ - Expr: &planpb.Expr_TermExpr{ - TermExpr: &planpb.TermExpr{ - ColumnInfo: createColumnInfo(field), - Values: arrayData, - }, - }, - } - - if node.Operator == "not in" { - return pc.createNotExpr(expr) - } - return expr, nil -} - -func (pc *parserContext) combineUnaryRangeExpr(a, b *planpb.UnaryRangeExpr) *planpb.Expr { - if a.Op == planpb.OpType_LessEqual || a.Op == planpb.OpType_LessThan { - a, b = b, a - } - - lowerInclusive := (a.Op == planpb.OpType_GreaterEqual) - upperInclusive := (b.Op == planpb.OpType_LessEqual) - - expr := &planpb.Expr{ - Expr: &planpb.Expr_BinaryRangeExpr{ - BinaryRangeExpr: &planpb.BinaryRangeExpr{ - ColumnInfo: a.ColumnInfo, - LowerInclusive: lowerInclusive, - UpperInclusive: upperInclusive, - LowerValue: a.Value, - UpperValue: b.Value, - }, - }, - } - return expr -} - -func (pc *parserContext) handleMultiCmpExpr(node *ant_ast.BinaryNode) (*planpb.Expr, error) { - exprs := []*planpb.Expr{} - curNode := node - - // handle multiple relational operators - for { - binNodeLeft, LeftOk := curNode.Left.(*ant_ast.BinaryNode) - if !LeftOk { - expr, err := pc.handleCmpExpr(curNode) - if err != nil { - return nil, err - } - exprs = append(exprs, expr) - break - } - if isSameOrder(node.Operator, binNodeLeft.Operator) { - expr, err := pc.createCmpExpr(binNodeLeft.Right, curNode.Right, curNode.Operator) - if err != nil { - return nil, err - } - exprs = append(exprs, expr) - curNode = binNodeLeft - } else { - return nil, fmt.Errorf("illegal multi-range expr") - } - } - - // combine UnaryRangeExpr to BinaryRangeExpr - var lastExpr *planpb.UnaryRangeExpr - for i := len(exprs) - 1; i >= 0; i-- { - if expr, ok := exprs[i].Expr.(*planpb.Expr_UnaryRangeExpr); ok { - if lastExpr != nil && expr.UnaryRangeExpr.ColumnInfo.FieldId == lastExpr.ColumnInfo.FieldId { - binaryRangeExpr := pc.combineUnaryRangeExpr(expr.UnaryRangeExpr, lastExpr) - 
exprs = append(exprs[0:i], append([]*planpb.Expr{binaryRangeExpr}, exprs[i+2:]...)...) - lastExpr = nil - } else { - lastExpr = expr.UnaryRangeExpr - } - } else { - lastExpr = nil - } - } - - // use `&&` to connect exprs - combinedExpr := exprs[len(exprs)-1] - for i := len(exprs) - 2; i >= 0; i-- { - expr := exprs[i] - combinedExpr = &planpb.Expr{ - Expr: &planpb.Expr_BinaryExpr{ - BinaryExpr: &planpb.BinaryExpr{ - Op: planpb.BinaryExpr_LogicalAnd, - Left: combinedExpr, - Right: expr, - }, - }, - } - } - return combinedExpr, nil -} - -func (pc *parserContext) handleBinaryExpr(node *ant_ast.BinaryNode) (*planpb.Expr, error) { - _, leftArithExpr := node.Left.(*ant_ast.FunctionNode) - _, rightArithExpr := node.Right.(*ant_ast.FunctionNode) - - if leftArithExpr || rightArithExpr { - return pc.handleBinaryArithCmpExpr(node) - } - - switch node.Operator { - case "<", "<=", ">", ">=": - return pc.handleMultiCmpExpr(node) - case "==", "!=", "startsWith", "endsWith": - return pc.handleCmpExpr(node) - case "and", "or", "&&", "||": - return pc.handleLogicalExpr(node) - case "in", "not in": - return pc.handleInExpr(node) - } - return nil, fmt.Errorf("unsupported binary operator %s", node.Operator) -} - -func (pc *parserContext) createNotExpr(childExpr *planpb.Expr) (*planpb.Expr, error) { - expr := &planpb.Expr{ - Expr: &planpb.Expr_UnaryExpr{ - UnaryExpr: &planpb.UnaryExpr{ - Op: planpb.UnaryExpr_Not, - Child: childExpr, - }, - }, - } - return expr, nil -} - -func (pc *parserContext) handleLeafValue(nodeRaw *ant_ast.Node, dataType schemapb.DataType) (gv *planpb.GenericValue, err error) { - switch node := (*nodeRaw).(type) { - case *ant_ast.FloatNode: - if typeutil.IsFloatingType(dataType) { - gv = &planpb.GenericValue{ - Val: &planpb.GenericValue_FloatVal{ - FloatVal: node.Value, - }, - } - } else { - return nil, fmt.Errorf("type mismatch") - } - case *ant_ast.IntegerNode: - if typeutil.IsFloatingType(dataType) { - gv = &planpb.GenericValue{ - Val: 
&planpb.GenericValue_FloatVal{ - FloatVal: float64(node.Value), - }, - } - } else if typeutil.IsIntegerType(dataType) { - gv = &planpb.GenericValue{ - Val: &planpb.GenericValue_Int64Val{ - Int64Val: int64(node.Value), - }, - } - } else { - return nil, fmt.Errorf("type mismatch") - } - case *ant_ast.BoolNode: - if typeutil.IsBoolType(dataType) { - gv = &planpb.GenericValue{ - Val: &planpb.GenericValue_BoolVal{ - BoolVal: node.Value, - }, - } - } else { - return nil, fmt.Errorf("type mismatch") - } - case *ant_ast.StringNode: - if typeutil.IsStringType(dataType) { - gv = &planpb.GenericValue{ - Val: &planpb.GenericValue_StringVal{ - StringVal: node.Value, - }, - } - } else { - return nil, fmt.Errorf("type mismatch") - } - default: - return nil, fmt.Errorf("unsupported leaf node") - } - - return gv, nil -} - -func (pc *parserContext) handleFunction(node *ant_ast.FunctionNode) (*planpb.BinaryArithOp, error) { - funcArithOp, err := getArithOpType(node.Name) - if err != nil { - return nil, err - } - - idNode, ok := node.Arguments[0].(*ant_ast.IdentifierNode) - if !ok { - return nil, fmt.Errorf("left operand of the function must be an identifier") - } - - field, err := pc.handleIdentifier(idNode) - if err != nil { - return nil, err - } - - valueNode := node.Arguments[1] - val, err := pc.handleLeafValue(&valueNode, field.DataType) - if err != nil { - return nil, err - } - - arithOp := &planpb.BinaryArithOp{ - ColumnInfo: createColumnInfo(field), - ArithOp: funcArithOp, - RightOperand: val, - } - - return arithOp, nil -} - -func (pc *parserContext) handleIdentifier(node *ant_ast.IdentifierNode) (*schemapb.FieldSchema, error) { - fieldName := node.Value - field, err := pc.schema.GetFieldFromName(fieldName) - return field, err -} - -func (pc *parserContext) handleUnaryExpr(node *ant_ast.UnaryNode) (*planpb.Expr, error) { - switch node.Operator { - case "!", "not": - subExpr, err := pc.handleExpr(&node.Node) - if err != nil { - return nil, err - } - return 
pc.createNotExpr(subExpr) - default: - return nil, fmt.Errorf("invalid unary operator(%s)", node.Operator) - } -} - -func (pc *parserContext) handleExpr(nodeRaw *ant_ast.Node) (*planpb.Expr, error) { - switch node := (*nodeRaw).(type) { - case *ant_ast.IdentifierNode, - *ant_ast.FloatNode, - *ant_ast.IntegerNode, - *ant_ast.BoolNode, - *ant_ast.StringNode: - return nil, fmt.Errorf("scalar expr is not supported yet") - case *ant_ast.UnaryNode: - expr, err := pc.handleUnaryExpr(node) - if err != nil { - return nil, err - } - return expr, nil - case *ant_ast.BinaryNode: - return pc.handleBinaryExpr(node) - default: - return nil, fmt.Errorf("unsupported node") - } -} - -func createQueryPlan(schemaPb *schemapb.CollectionSchema, exprStr string, vectorFieldName string, queryInfo *planpb.QueryInfo) (*planpb.PlanNode, error) { - schema, err := typeutil.CreateSchemaHelper(schemaPb) - if err != nil { - return nil, err - } - - expr, err := parseExpr(schema, exprStr) - if err != nil { - return nil, err - } - vectorField, err := schema.GetFieldFromName(vectorFieldName) - if err != nil { - return nil, err - } - fieldID := vectorField.FieldID - dataType := vectorField.DataType - - if !typeutil.IsVectorType(dataType) { - return nil, fmt.Errorf("field (%s) to search is not of vector data type", vectorFieldName) - } - - planNode := &planpb.PlanNode{ - Node: &planpb.PlanNode_VectorAnns{ - VectorAnns: &planpb.VectorANNS{ - IsBinary: dataType == schemapb.DataType_BinaryVector, - Predicates: expr, - QueryInfo: queryInfo, - PlaceholderTag: "$0", - FieldId: fieldID, - }, - }, - } - return planNode, nil -} - -func createExprPlan(schemaPb *schemapb.CollectionSchema, exprStr string) (*planpb.PlanNode, error) { - schema, err := typeutil.CreateSchemaHelper(schemaPb) - if err != nil { - return nil, err - } - - expr, err := parseExpr(schema, exprStr) - if err != nil { - return nil, err - } - - planNode := &planpb.PlanNode{ - Node: &planpb.PlanNode_Predicates{ - Predicates: expr, - }, - } - return 
planNode, nil -} diff --git a/internal/proxy/plan_parser_test.go b/internal/proxy/plan_parser_test.go deleted file mode 100644 index 43e64f4355df6..0000000000000 --- a/internal/proxy/plan_parser_test.go +++ /dev/null @@ -1,640 +0,0 @@ -// Licensed to the LF AI & Data foundation under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package proxy - -import ( - "fmt" - "testing" - - "github.com/milvus-io/milvus/internal/parser/planparserv2" - - ant_ast "github.com/antonmedv/expr/ast" - ant_parser "github.com/antonmedv/expr/parser" - - "github.com/golang/protobuf/proto" - "github.com/stretchr/testify/assert" - - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/internal/proto/planpb" - "github.com/milvus-io/milvus/pkg/util/typeutil" -) - -func newTestSchema() *schemapb.CollectionSchema { - fields := []*schemapb.FieldSchema{ - {FieldID: 0, Name: "FieldID", IsPrimaryKey: false, Description: "field no.1", DataType: schemapb.DataType_Int64}, - } - - for name, value := range schemapb.DataType_value { - dataType := schemapb.DataType(value) - if !typeutil.IsIntegerType(dataType) && !typeutil.IsFloatingType(dataType) && !typeutil.IsVectorType(dataType) && !typeutil.IsStringType(dataType) { - continue - } - newField := &schemapb.FieldSchema{ - FieldID: int64(100 + value), Name: name + "Field", IsPrimaryKey: false, Description: "", DataType: dataType, - } - fields = append(fields, newField) - } - - return &schemapb.CollectionSchema{ - Name: "test", - Description: "schema for test used", - AutoID: true, - Fields: fields, - EnableDynamicField: true, - } -} - -func assertValidExpr(t *testing.T, schema *typeutil.SchemaHelper, exprStr string) { - // t.Log("expr: ", exprStr) - - _, err := parseExpr(schema, exprStr) - assert.Nil(t, err, exprStr) - - // t.Log("AST1:") - // planparserv2.ShowExpr(expr1) -} - -func assertValidExprV2(t *testing.T, schema *typeutil.SchemaHelper, exprStr string) { - expr1, err := parseExpr(schema, exprStr) - assert.NoError(t, err) - - expr2, err := planparserv2.ParseExpr(schema, exprStr) - assert.NoError(t, err) - - if !planparserv2.CheckPredicatesIdentical(expr1, expr2) { - t.Log("expr: ", exprStr) - - t.Log("AST1:") - planparserv2.ShowExpr(expr1) - - t.Log("AST2:") - planparserv2.ShowExpr(expr2) - - t.Errorf("parsed asts are not identical") - } 
-} - -func assertInvalidExpr(t *testing.T, schema *typeutil.SchemaHelper, exprStr string) { - _, err := parseExpr(schema, exprStr) - assert.Error(t, err, exprStr) - - _, err = planparserv2.ParseExpr(schema, exprStr) - assert.Error(t, err, exprStr) -} - -func assertValidSearchPlan(t *testing.T, schema *schemapb.CollectionSchema, exprStr string, vectorFieldName string, queryInfo *planpb.QueryInfo) { - _, err := createQueryPlan(schema, exprStr, vectorFieldName, queryInfo) - assert.NoError(t, err) -} - -func assertValidSearchPlanV2(t *testing.T, schema *schemapb.CollectionSchema, exprStr string, vectorFieldName string, queryInfo *planpb.QueryInfo) { - planProto1, err := createQueryPlan(schema, exprStr, vectorFieldName, queryInfo) - assert.NoError(t, err) - - planProto2, err := planparserv2.CreateSearchPlan(schema, exprStr, vectorFieldName, queryInfo) - assert.NoError(t, err) - - expr1 := planProto1.GetVectorAnns().GetPredicates() - assert.NotNil(t, expr1) - - expr2 := planProto2.GetVectorAnns().GetPredicates() - assert.NotNil(t, expr2) - - if !planparserv2.CheckPredicatesIdentical(expr1, expr2) { - t.Log("expr: ", exprStr) - - t.Log("AST1:") - planparserv2.ShowExpr(expr1) - - t.Log("AST2:") - planparserv2.ShowExpr(expr2) - - t.Errorf("parsed asts are not identical") - } -} - -func assertInvalidSearchPlan(t *testing.T, schema *schemapb.CollectionSchema, exprStr string, vectorFieldName string, queryInfo *planpb.QueryInfo) { - _, err := createQueryPlan(schema, exprStr, vectorFieldName, queryInfo) - assert.Error(t, err, exprStr) - - _, err = planparserv2.CreateSearchPlan(schema, exprStr, vectorFieldName, queryInfo) - assert.Error(t, err, exprStr) -} - -func TestParseExpr_Naive(t *testing.T) { - schemaPb := newTestSchema() - schema, err := typeutil.CreateSchemaHelper(schemaPb) - assert.NoError(t, err) - - t.Run("test UnaryNode", func(t *testing.T) { - exprStrs := []string{ - "Int64Field > +1", - "Int64Field > -1", - "FloatField > +1.0", - "FloatField > -1.0", - 
`VarCharField > "str"`, - } - for _, exprStr := range exprStrs { - assertValidExprV2(t, schema, exprStr) - } - }) - - t.Run("test string unary", func(t *testing.T) { - exprStrs := []string{ - `VarCharField startsWith "str"`, - `VarCharField endsWith "str"`, - } - for _, exprStr := range exprStrs { - assertValidExpr(t, schema, exprStr) - } - }) - - t.Run("test UnaryNode invalid", func(t *testing.T) { - exprStrs := []string{ - "Int64Field > +aa", - "FloatField > -aa", - `VarCharField > -aa`, - } - for _, exprStr := range exprStrs { - assertInvalidExpr(t, schema, exprStr) - } - }) - - t.Run("test BinaryNode", func(t *testing.T) { - exprStrs := []string{ - // "+" - "FloatField > 1 + 2", - "FloatField > 1 + 2.0", - "FloatField > 1.0 + 2", - "FloatField > 1.0 + 2.0", - // "-" - "FloatField > 1 - 2", - "FloatField > 1 - 2.0", - "FloatField > 1.0 - 2", - "FloatField > 1.0 - 2.0", - // "*" - "FloatField > 1 * 2", - "FloatField > 1 * 2.0", - "FloatField > 1.0 * 2", - "FloatField > 1.0 * 2.0", - // "/" - "FloatField > 1 / 2", - "FloatField > 1 / 2.0", - "FloatField > 1.0 / 2", - "FloatField > 1.0 / 2.0", - // "%" - "FloatField > 1 % 2", - // "**" - "FloatField > 1 ** 2", - "FloatField > 1 ** 2.0", - "FloatField > 1.0 ** 2", - "FloatField > 1.0 ** 2.0", - } - for _, exprStr := range exprStrs { - assertValidExprV2(t, schema, exprStr) - } - }) - - t.Run("test BinaryNode invalid", func(t *testing.T) { - exprStrs := []string{ - // "+" - "FloatField > 1 + aa", - "FloatField > aa + 2.0", - // "-" - "FloatField > 1 - aa", - "FloatField > aa - 2.0", - // "*" - "FloatField > 1 * aa", - "FloatField > aa * 2.0", - // "/" - "FloatField > 1 / 0", - "FloatField > 1 / 0.0", - "FloatField > 1.0 / 0", - "FloatField > 1.0 / 0.0", - "FloatField > 1 / aa", - "FloatField > aa / 2.0", - // "%" - "FloatField > 1 % aa", - "FloatField > 1 % 0", - "FloatField > 1 % 0.0", - // "**" - "FloatField > 1 ** aa", - "FloatField > aa ** 2.0", - } - for _, exprStr := range exprStrs { - assertInvalidExpr(t, 
schema, exprStr) - } - }) - - t.Run("test BinaryArithOpNode", func(t *testing.T) { - exprStrs := []string{ - // "+" - "FloatField + 1.2 == 3", - "Int64Field + 3 == 5", - "1.2 + FloatField != 3", - "3 + Int64Field != 5", - // "-" - "FloatField - 1.2 == 3", - "Int64Field - 3 != 5", - // "*" - "FloatField * 1.2 == 3", - "Int64Field * 3 == 5", - "1.2 * FloatField != 3", - "3 * Int64Field != 5", - // "/" - "FloatField / 1.2 == 3", - "Int64Field / 3 != 5", - // "%" - "Int64Field % 7 == 5", - } - for _, exprStr := range exprStrs { - assertValidExprV2(t, schema, exprStr) - } - }) - - t.Run("test BinaryArithOpNode invalid", func(t *testing.T) { - exprStrs := []string{ - // "+" - "FloatField + FloatField == 20", - "Int64Field + Int64Field != 10", - // "-" - "FloatField - FloatField == 20.0", - "Int64Field - Int64Field != 10", - "10 - FloatField == 20", - "20 - Int64Field != 10", - // "*" - "FloatField * FloatField == 20", - "Int64Field * Int64Field != 10", - // "/" - "FloatField / FloatField == 20", - "Int64Field / Int64Field != 10", - "FloatField / 0 == 20", - "Int64Field / 0 != 10", - // "%" - "Int64Field % Int64Field != 10", - "FloatField % 0 == 20", - "Int64Field % 0 != 10", - "FloatField % 2.3 == 20", - } - for _, exprStr := range exprStrs { - exprProto, err := parseExpr(schema, exprStr) - assert.Error(t, err) - assert.Nil(t, exprProto) - } - }) -} - -func TestParsePlanNode_Naive(t *testing.T) { - exprStrs := []string{ - "not (Int64Field > 3)", - "not (3 > Int64Field)", - "Int64Field in [1, 2, 3]", - "Int64Field < 3 and (Int64Field > 2 || Int64Field == 1)", - "DoubleField in [1.0, 2, 3]", - "DoubleField in [1.0, 2, 3] && Int64Field < 3 or Int64Field > 2", - `not (VarCharField > "str")`, - `not ("str" > VarCharField)`, - `VarCharField in ["term0", "term1", "term2"]`, - `VarCharField < "str3" and (VarCharField > "str2" || VarCharField == "str1")`, - `DoubleField in [1.0, 2, 3] && VarCharField < "str3" or Int64Field > 2`, - } - - schema := newTestSchema() - queryInfo := 
&planpb.QueryInfo{ - Topk: 10, - MetricType: "L2", - SearchParams: "{\"nprobe\": 10}", - } - - for _, exprStr := range exprStrs { - assertValidSearchPlanV2(t, schema, exprStr, "FloatVectorField", queryInfo) - } - - stringFuncs := []string{ - `not (VarCharField startsWith "str")`, - `not (VarCharField endsWith "str")`, - `VarCharField < "str3" and (VarCharField startsWith "str2" || VarCharField endsWith "str1")`, - } - for _, exprStr := range stringFuncs { - assertValidSearchPlan(t, schema, exprStr, "FloatVectorField", queryInfo) - } -} - -func TestExternalParser(t *testing.T) { - ast, err := ant_parser.Parse(`!(1 < a < 2 or b in [1, 2, 3]) or (c < 3 and b > 5) and (d > "str1" or d < "str2")`) - // NOTE: probe ast here via IDE - assert.NoError(t, err) - - println(ast.Node.Location().Column) -} - -func TestExprPlan_Str(t *testing.T) { - fields := []*schemapb.FieldSchema{ - {FieldID: 100, Name: "fakevec", DataType: schemapb.DataType_FloatVector}, - {FieldID: 101, Name: "age", DataType: schemapb.DataType_Int64}, - } - - schema := &schemapb.CollectionSchema{ - Name: "default-collection", - Description: "", - AutoID: true, - Fields: fields, - } - - queryInfo := &planpb.QueryInfo{ - Topk: 10, - MetricType: "L2", - SearchParams: "{\"nprobe\": 10}", - } - - // without filter - planProto, err := createQueryPlan(schema, "", "fakevec", queryInfo) - assert.NoError(t, err) - dbgStr := proto.MarshalTextString(planProto) - println(dbgStr) - - exprStrs := []string{ - "age >= 420000 && age < 420010", // range - "age == 420000 || age == 420001 || age == 420002 || age == 420003 || age == 420004", // term - "age not in [1, 2, 3]", - } - - for _, exprStr := range exprStrs { - assertValidSearchPlanV2(t, schema, exprStr, "fakevec", queryInfo) - } -} - -func TestExprMultiRange_Str(t *testing.T) { - exprStrs := []string{ - "3 < FloatN < 4.0", - // "3 < age1 < 5 < age2 < 7 < FloatN < 9.0 < FloatN2", // no need to support this, ambiguous. 
- "1 + 1 < age1 < 2 * 2", - "1 - 1 < age1 < 3 / 2", - "1.0 - 1 < FloatN < 3 / 2", - "2 ** 10 > FloatN >= 7 % 4", - "0.1 ** 2 < FloatN < 2 ** 0.1", - "0.1 ** 1.1 < FloatN < 3.1 / 4", - "4.1 / 3 < FloatN < 0.0 / 5.0", - "BoolN1 == True", - "True == BoolN1", - "BoolN1 == False", - } - invalidExprs := []string{ - "BoolN1 == 1", - "BoolN1 == 0", - "BoolN1 > 0", - } - - fields := []*schemapb.FieldSchema{ - {FieldID: 100, Name: "fakevec", DataType: schemapb.DataType_FloatVector}, - {FieldID: 101, Name: "age1", DataType: schemapb.DataType_Int64}, - {FieldID: 102, Name: "age2", DataType: schemapb.DataType_Int64}, - {FieldID: 103, Name: "FloatN", DataType: schemapb.DataType_Float}, - {FieldID: 104, Name: "FloatN2", DataType: schemapb.DataType_Float}, - {FieldID: 105, Name: "BoolN1", DataType: schemapb.DataType_Bool}, - } - - schema := &schemapb.CollectionSchema{ - Name: "default-collection", - Description: "", - AutoID: true, - Fields: fields, - } - - queryInfo := &planpb.QueryInfo{ - Topk: 10, - MetricType: "L2", - SearchParams: "{\"nprobe\": 10}", - } - - for _, exprStr := range exprStrs { - assertValidSearchPlanV2(t, schema, exprStr, "fakevec", queryInfo) - } - for _, exprStr := range invalidExprs { - assertInvalidSearchPlan(t, schema, exprStr, "fakevec", queryInfo) - } -} - -func TestExprFieldCompare_Str(t *testing.T) { - exprStrs := []string{ - "age1 < age2", - // "3 < age1 <= age2 < 4", // no need to support this, ambiguous. 
- } - - fields := []*schemapb.FieldSchema{ - {FieldID: 100, Name: "fakevec", DataType: schemapb.DataType_FloatVector}, - {FieldID: 101, Name: "age1", DataType: schemapb.DataType_Int64}, - {FieldID: 102, Name: "age2", DataType: schemapb.DataType_Int64}, - {FieldID: 103, Name: "FloatN", DataType: schemapb.DataType_Float}, - } - - schema := &schemapb.CollectionSchema{ - Name: "default-collection", - Description: "", - AutoID: true, - Fields: fields, - } - - queryInfo := &planpb.QueryInfo{ - Topk: 10, - MetricType: "L2", - SearchParams: "{\"nprobe\": 10}", - } - - for _, exprStr := range exprStrs { - assertValidSearchPlanV2(t, schema, exprStr, "fakevec", queryInfo) - } -} - -func TestExprBinaryArithOp_Str(t *testing.T) { - exprStrs := []string{ - // Basic arithmetic - "(age1 + 5) == 2", - // Float data type - "(FloatN - 5.2) == 0", - // Other operators - "(age1 - 5) == 1", - "(age1 * 5) == 6", - "(age1 / 5) == 1", - "(age1 % 5) == 0", - // Allow for commutative property for + and * - "(6 + age1) != 2", - "(age1 * 4) != 9", - "(5 * FloatN) != 0", - "(9 * FloatN) != 0", - // Functional nodes at the right can be reversed - "0 == (age1 + 3)", - } - - unsupportedExprStrs := []string{ - // Comparison operators except for "==" and "!=" are unsupported - "(age1 + 2) > 4", - "(age1 + 2) >= 4", - "(age1 + 2) < 4", - "(age1 + 2) <= 4", - // Field as the right operand for -, /, and % operators are not supported - "(10 - age1) == 0", - "(20 / age1) == 0", - "(30 % age1) == 0", - // Modulo is not supported in the parser but the engine can handle it since fmod is used - "(FloatN % 2.1) == 0", - // Different data types are not supported - "(age1 + 20.16) == 35.16", - // Left operand of the function must be an identifier - "(10.5 / floatN) == 5.75", - } - - fields := []*schemapb.FieldSchema{ - {FieldID: 100, Name: "fakevec", DataType: schemapb.DataType_FloatVector}, - {FieldID: 101, Name: "age1", DataType: schemapb.DataType_Int64}, - {FieldID: 102, Name: "FloatN", DataType: 
schemapb.DataType_Float}, - } - - schema := &schemapb.CollectionSchema{ - Name: "default-collection", - Description: "", - AutoID: true, - Fields: fields, - } - - queryInfo := &planpb.QueryInfo{ - Topk: 10, - MetricType: "L2", - SearchParams: "{\"nprobe\": 10}", - } - - for _, exprStr := range exprStrs { - assertValidSearchPlanV2(t, schema, exprStr, "fakevec", queryInfo) - } - - for _, exprStr := range unsupportedExprStrs { - assertInvalidSearchPlan(t, schema, exprStr, "fakevec", queryInfo) - } -} - -func TestPlanParseAPIs(t *testing.T) { - t.Run("get compare op type", func(t *testing.T) { - var op planpb.OpType - var reverse bool - - reverse = false - op = getCompareOpType(">", reverse) - assert.Equal(t, planpb.OpType_GreaterThan, op) - op = getCompareOpType(">=", reverse) - assert.Equal(t, planpb.OpType_GreaterEqual, op) - op = getCompareOpType("<", reverse) - assert.Equal(t, planpb.OpType_LessThan, op) - op = getCompareOpType("<=", reverse) - assert.Equal(t, planpb.OpType_LessEqual, op) - op = getCompareOpType("==", reverse) - assert.Equal(t, planpb.OpType_Equal, op) - op = getCompareOpType("!=", reverse) - assert.Equal(t, planpb.OpType_NotEqual, op) - op = getCompareOpType("*", reverse) - assert.Equal(t, planpb.OpType_Invalid, op) - op = getCompareOpType("startsWith", reverse) - assert.Equal(t, planpb.OpType_PrefixMatch, op) - op = getCompareOpType("endsWith", reverse) - assert.Equal(t, planpb.OpType_PostfixMatch, op) - - reverse = true - op = getCompareOpType(">", reverse) - assert.Equal(t, planpb.OpType_LessThan, op) - op = getCompareOpType(">=", reverse) - assert.Equal(t, planpb.OpType_LessEqual, op) - op = getCompareOpType("<", reverse) - assert.Equal(t, planpb.OpType_GreaterThan, op) - op = getCompareOpType("<=", reverse) - assert.Equal(t, planpb.OpType_GreaterEqual, op) - op = getCompareOpType("==", reverse) - assert.Equal(t, planpb.OpType_Equal, op) - op = getCompareOpType("!=", reverse) - assert.Equal(t, planpb.OpType_NotEqual, op) - op = 
getCompareOpType("*", reverse) - assert.Equal(t, planpb.OpType_Invalid, op) - op = getCompareOpType("startsWith", reverse) - assert.Equal(t, planpb.OpType_PrefixMatch, op) - op = getCompareOpType("endsWith", reverse) - assert.Equal(t, planpb.OpType_PostfixMatch, op) - }) - - t.Run("parse bool node", func(t *testing.T) { - var nodeRaw1, nodeRaw2, nodeRaw3, nodeRaw4 ant_ast.Node - nodeRaw1 = &ant_ast.IdentifierNode{ - Value: "True", - } - boolNode1 := parseBoolNode(&nodeRaw1) - assert.Equal(t, boolNode1.Value, true) - - nodeRaw2 = &ant_ast.IdentifierNode{ - Value: "False", - } - boolNode2 := parseBoolNode(&nodeRaw2) - assert.Equal(t, boolNode2.Value, false) - - nodeRaw3 = &ant_ast.IdentifierNode{ - Value: "abcd", - } - assert.Nil(t, parseBoolNode(&nodeRaw3)) - - nodeRaw4 = &ant_ast.BoolNode{ - Value: true, - } - assert.Nil(t, parseBoolNode(&nodeRaw4)) - }) -} - -func Test_CheckIdentical(t *testing.T) { - schema := newTestSchema() - helper, err := typeutil.CreateSchemaHelper(schema) - assert.NoError(t, err) - - n := 5000 - int64s := generateInt64Array(n) - largeIntTermExpr := `Int64Field in [` - largeFloatTermExpr := `FloatField in [` - for _, i := range int64s[:n-1] { - largeIntTermExpr += fmt.Sprintf("%d, ", i) - largeFloatTermExpr += fmt.Sprintf("%d, ", i) - } - largeIntTermExpr += fmt.Sprintf("%d]", int64s[n-1]) - largeFloatTermExpr += fmt.Sprintf("%d]", int64s[n-1]) - - // cases in regression. 
- inputs := []string{ - "Int64Field > 0", - "(Int64Field > 0 && Int64Field < 400) or (Int64Field > 500 && Int64Field < 1000)", - "Int64Field not in [1, 2, 3]", - "Int64Field in [1, 2, 3] and FloatField != 2", - "Int64Field == 0 || Int64Field == 1 || Int64Field == 2", - "0 < Int64Field < 400", - "500 <= Int64Field < 1000", - "200+300 < Int64Field <= 500+500", - "Int32Field != Int64Field", - "Int64Field not in []", - `Int64Field >= 0 && VarCharField >= "0"`, - largeIntTermExpr, - largeFloatTermExpr, - } - for _, input := range inputs { - expr1, err := parseExpr(helper, input) - assert.NoError(t, err) - expr2, err := planparserv2.ParseExpr(helper, input) - assert.NoError(t, err) - assert.True(t, planparserv2.CheckPredicatesIdentical(expr1, expr2)) - } -} diff --git a/internal/proxy/privilege_interceptor.go b/internal/proxy/privilege_interceptor.go index 5bff2fb2007d1..b928ef95e9135 100644 --- a/internal/proxy/privilege_interceptor.go +++ b/internal/proxy/privilege_interceptor.go @@ -77,22 +77,6 @@ func PrivilegeInterceptor(ctx context.Context, req interface{}) (context.Context log.Warn("GetCurUserFromContext fail", zap.Error(err)) return ctx, err } - return privilegeInterceptor(ctx, privilegeExt, username, req) -} - -func PrivilegeInterceptorWithUsername(ctx context.Context, username string, req interface{}) (context.Context, error) { - if !Params.CommonCfg.AuthorizationEnabled.GetAsBool() { - return ctx, nil - } - log.Debug("PrivilegeInterceptor", zap.String("type", reflect.TypeOf(req).String())) - privilegeExt, err := funcutil.GetPrivilegeExtObj(req) - if err != nil { - log.Info("GetPrivilegeExtObj err", zap.Error(err)) - return ctx, nil - } - return privilegeInterceptor(ctx, privilegeExt, username, req) -} -func privilegeInterceptor(ctx context.Context, privilegeExt commonpb.PrivilegeExt, username string, req interface{}) (context.Context, error) { if username == util.UserRoot { return ctx, nil } diff --git a/internal/proxy/privilege_interceptor_test.go 
b/internal/proxy/privilege_interceptor_test.go index fc4aa4dde909b..e42c4df78b0f7 100644 --- a/internal/proxy/privilege_interceptor_test.go +++ b/internal/proxy/privilege_interceptor_test.go @@ -5,13 +5,15 @@ import ( "sync" "testing" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/stretchr/testify/assert" ) func TestUnaryServerInterceptor(t *testing.T) { @@ -45,14 +47,12 @@ func TestPrivilegeInterceptor(t *testing.T) { ctx = GetContext(context.Background(), "alice:123456") client := &MockRootCoordClientInterface{} - queryCoord := &mocks.MockQueryCoord{} + queryCoord := &mocks.MockQueryCoordClient{} mgr := newShardClientMgr() client.listPolicy = func(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) { return &internalpb.ListPolicyResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), PolicyInfos: []string{ funcutil.PolicyForPrivilege("role1", commonpb.ObjectType_Collection.String(), "col1", commonpb.ObjectPrivilege_PrivilegeLoad.String(), "default"), funcutil.PolicyForPrivilege("role1", commonpb.ObjectType_Collection.String(), "col1", commonpb.ObjectPrivilege_PrivilegeGetLoadState.String(), "default"), @@ -162,7 +162,6 @@ func TestPrivilegeInterceptor(t *testing.T) { getPolicyModel("foo") }) }) - } func TestResourceGroupPrivilege(t *testing.T) { @@ -176,14 +175,12 @@ func TestResourceGroupPrivilege(t *testing.T) { ctx = GetContext(context.Background(), "fooo:123456") client := &MockRootCoordClientInterface{} - queryCoord := &mocks.MockQueryCoord{} + queryCoord := &mocks.MockQueryCoordClient{} mgr := 
newShardClientMgr() client.listPolicy = func(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) { return &internalpb.ListPolicyResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), PolicyInfos: []string{ funcutil.PolicyForPrivilege("role1", commonpb.ObjectType_Global.String(), "*", commonpb.ObjectPrivilege_PrivilegeCreateResourceGroup.String(), "default"), funcutil.PolicyForPrivilege("role1", commonpb.ObjectType_Global.String(), "*", commonpb.ObjectPrivilege_PrivilegeDropResourceGroup.String(), "default"), @@ -223,5 +220,4 @@ func TestResourceGroupPrivilege(t *testing.T) { _, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.TransferReplicaRequest{}) assert.NoError(t, err) }) - } diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 8d171fc78d1b5..ef7f45698e5cc 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -23,30 +23,31 @@ import ( "os" "strconv" "sync" - "sync/atomic" "syscall" "time" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus/internal/allocator" - "github.com/milvus-io/milvus/internal/util/dependency" - "github.com/milvus-io/milvus/internal/util/sessionutil" - "github.com/milvus-io/milvus/pkg/util/tsoutil" - clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/atomic" "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus/internal/allocator" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proxy/accesslog" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/dependency" + "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/commonpbutil" 
"github.com/milvus-io/milvus/pkg/util/logutil" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/ratelimitutil" + "github.com/milvus-io/milvus/pkg/util/resource" + "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -77,18 +78,20 @@ type Proxy struct { ip string port int - stateCode atomic.Value + stateCode atomic.Int32 etcdCli *clientv3.Client address string - rootCoord types.RootCoord - dataCoord types.DataCoord - queryCoord types.QueryCoord + rootCoord types.RootCoordClient + dataCoord types.DataCoordClient + queryCoord types.QueryCoordClient multiRateLimiter *MultiRateLimiter chMgr channelsMgr + replicateMsgStream msgstream.MsgStream + sched *taskScheduler chTicker channelsTimeTicker @@ -112,6 +115,10 @@ type Proxy struct { // for load balance in replicas lbPolicy LBPolicy + + // resource manager + resourceManager resource.Manager + replicateStreamManager *ReplicateStreamManager } // NewProxy returns a Proxy struct. 
@@ -122,20 +129,33 @@ func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) { mgr := newShardClientMgr() lbPolicy := NewLBPolicyImpl(mgr) lbPolicy.Start(ctx) + resourceManager := resource.NewManager(10*time.Second, 20*time.Second, make(map[string]time.Duration)) + replicateStreamManager := NewReplicateStreamManager(ctx, factory, resourceManager) node := &Proxy{ - ctx: ctx1, - cancel: cancel, - factory: factory, - searchResultCh: make(chan *internalpb.SearchResults, n), - shardMgr: mgr, - multiRateLimiter: NewMultiRateLimiter(), - lbPolicy: lbPolicy, + ctx: ctx1, + cancel: cancel, + factory: factory, + searchResultCh: make(chan *internalpb.SearchResults, n), + shardMgr: mgr, + multiRateLimiter: NewMultiRateLimiter(), + lbPolicy: lbPolicy, + resourceManager: resourceManager, + replicateStreamManager: replicateStreamManager, } node.UpdateStateCode(commonpb.StateCode_Abnormal) logutil.Logger(ctx).Debug("create a new Proxy instance", zap.Any("state", node.stateCode.Load())) return node, nil } +// UpdateStateCode updates the state code of Proxy. 
+func (node *Proxy) UpdateStateCode(code commonpb.StateCode) { + node.stateCode.Store(int32(code)) +} + +func (node *Proxy) GetStateCode() commonpb.StateCode { + return commonpb.StateCode(node.stateCode.Load()) +} + // Register registers proxy at etcd func (node *Proxy) Register() error { node.session.Register() @@ -154,7 +174,7 @@ func (node *Proxy) Register() error { } }) // TODO Reset the logger - //Params.initLogCfg() + // Params.initLogCfg() return nil } @@ -165,6 +185,7 @@ func (node *Proxy) initSession() error { return errors.New("new session failed, maybe etcd cannot be connected") } node.session.Init(typeutil.ProxyRole, node.address, false, true) + sessionutil.SaveServerInfo(typeutil.ProxyRole, node.session.ServerID) return nil } @@ -241,6 +262,17 @@ func (node *Proxy) Init() error { node.chMgr = chMgr log.Debug("create channels manager done", zap.String("role", typeutil.ProxyRole)) + replicateMsgChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue() + node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx) + if err != nil { + log.Warn("failed to create replicate msg stream", + zap.String("role", typeutil.ProxyRole), zap.Int64("ProxyID", paramtable.GetNodeID()), + zap.Error(err)) + return err + } + node.replicateMsgStream.EnableProduce(true) + node.replicateMsgStream.AsProducer([]string{replicateMsgChannel}) + node.sched, err = newTaskScheduler(node.ctx, node.tsoAllocator, node.factory) if err != nil { log.Warn("failed to create task scheduler", zap.String("role", typeutil.ProxyRole), zap.Error(err)) @@ -278,6 +310,9 @@ func (node *Proxy) sendChannelsTimeTickLoop() { log.Info("send channels time tick loop exit") return case <-ticker.C: + if !Params.CommonCfg.TTMsgEnabled.GetAsBool() { + continue + } stats, ts, err := node.chTicker.getMinTsStatistics() if err != nil { log.Warn("sendChannelsTimeTickLoop.getMinTsStatistics", zap.Error(err)) @@ -334,7 +369,7 @@ func (node *Proxy) sendChannelsTimeTickLoop() { 
log.Warn("sendChannelsTimeTickLoop.UpdateChannelTimeTick", zap.Error(err)) continue } - if status.ErrorCode != 0 { + if status.GetErrorCode() != 0 { log.Warn("sendChannelsTimeTickLoop.UpdateChannelTimeTick", zap.Any("ErrorCode", status.ErrorCode), zap.Any("Reason", status.Reason)) @@ -433,6 +468,10 @@ func (node *Proxy) Stop() error { node.lbPolicy.Close() } + if node.resourceManager != nil { + node.resourceManager.Close() + } + // https://github.com/milvus-io/milvus/issues/12282 node.UpdateStateCode(commonpb.StateCode_Abnormal) @@ -470,21 +509,21 @@ func (node *Proxy) SetEtcdClient(client *clientv3.Client) { } // SetRootCoordClient sets RootCoord client for proxy. -func (node *Proxy) SetRootCoordClient(cli types.RootCoord) { +func (node *Proxy) SetRootCoordClient(cli types.RootCoordClient) { node.rootCoord = cli } // SetDataCoordClient sets DataCoord client for proxy. -func (node *Proxy) SetDataCoordClient(cli types.DataCoord) { +func (node *Proxy) SetDataCoordClient(cli types.DataCoordClient) { node.dataCoord = cli } // SetQueryCoordClient sets QueryCoord client for proxy. 
-func (node *Proxy) SetQueryCoordClient(cli types.QueryCoord) { +func (node *Proxy) SetQueryCoordClient(cli types.QueryCoordClient) { node.queryCoord = cli } -func (node *Proxy) SetQueryNodeCreator(f func(ctx context.Context, addr string, nodeID int64) (types.QueryNode, error)) { +func (node *Proxy) SetQueryNodeCreator(f func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error)) { node.shardMgr.SetClientCreatorFunc(f) } diff --git a/internal/proxy/proxy_rpc_test.go b/internal/proxy/proxy_rpc_test.go index 2566a35e00480..ae49378025535 100644 --- a/internal/proxy/proxy_rpc_test.go +++ b/internal/proxy/proxy_rpc_test.go @@ -7,6 +7,8 @@ import ( "sync" "testing" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" grpcproxyclient "github.com/milvus-io/milvus/internal/distributed/proxy/client" "github.com/milvus-io/milvus/internal/proto/internalpb" @@ -17,7 +19,6 @@ import ( "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" - "github.com/stretchr/testify/assert" ) func TestProxyRpcLimit(t *testing.T) { @@ -55,7 +56,7 @@ func TestProxyRpcLimit(t *testing.T) { defer testServer.grpcServer.Stop() client, err := grpcproxyclient.NewClient(ctx, "localhost:"+p.Port.GetValue(), 1) assert.NoError(t, err) - proxy.stateCode.Store(commonpb.StateCode_Healthy) + proxy.UpdateStateCode(commonpb.StateCode_Healthy) rates := make([]*internalpb.Rate, 0) diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 895162350d928..18f7ba41c80f5 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -17,7 +17,10 @@ package proxy import ( + "bytes" "context" + "encoding/binary" + "encoding/json" "fmt" "math/rand" "net" @@ -66,6 +69,7 @@ import ( "github.com/milvus-io/milvus/pkg/util/crypto" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/funcutil" + 
"github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metric" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -262,22 +266,22 @@ func (s *proxyTestServer) RenameCollection(ctx context.Context, request *milvusp } func (s *proxyTestServer) GetComponentStates(ctx context.Context, request *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) { - return s.Proxy.GetComponentStates(ctx) + return s.Proxy.GetComponentStates(ctx, request) } func (s *proxyTestServer) GetStatisticsChannel(ctx context.Context, request *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error) { - return s.Proxy.GetStatisticsChannel(ctx) + return s.Proxy.GetStatisticsChannel(ctx, request) } func (s *proxyTestServer) startGrpc(ctx context.Context, wg *sync.WaitGroup, p *paramtable.GrpcServerConfig) { defer wg.Done() - var kaep = keepalive.EnforcementPolicy{ + kaep := keepalive.EnforcementPolicy{ MinTime: 5 * time.Second, // If a client pings more than once every 5 seconds, terminate the connection PermitWithoutStream: true, // Allow pings even when there are no active streams } - var kasp = keepalive.ServerParameters{ + kasp := keepalive.ServerParameters{ Time: 60 * time.Second, // Ping the client if it is idle for 60 seconds to ensure the connection is still active Timeout: 10 * time.Second, // Wait 10 second for the ping ack before assuming the connection is dead } @@ -452,8 +456,6 @@ func TestProxy(t *testing.T) { rootCoordClient, err := rcc.NewClient(ctx, Params.EtcdCfg.MetaRootPath.GetValue(), etcdcli) assert.NoError(t, err) - err = rootCoordClient.Init() - assert.NoError(t, err) err = componentutil.WaitForComponentHealthy(ctx, rootCoordClient, typeutil.RootCoordRole, attempts, sleepDuration) assert.NoError(t, err) proxy.SetRootCoordClient(rootCoordClient) @@ -461,8 +463,6 @@ func TestProxy(t *testing.T) { dataCoordClient, err := grpcdatacoordclient2.NewClient(ctx, 
Params.EtcdCfg.MetaRootPath.GetValue(), etcdcli) assert.NoError(t, err) - err = dataCoordClient.Init() - assert.NoError(t, err) err = componentutil.WaitForComponentHealthy(ctx, dataCoordClient, typeutil.DataCoordRole, attempts, sleepDuration) assert.NoError(t, err) proxy.SetDataCoordClient(dataCoordClient) @@ -470,8 +470,6 @@ func TestProxy(t *testing.T) { queryCoordClient, err := grpcquerycoordclient.NewClient(ctx, Params.EtcdCfg.MetaRootPath.GetValue(), etcdcli) assert.NoError(t, err) - err = queryCoordClient.Init() - assert.NoError(t, err) err = componentutil.WaitForComponentHealthy(ctx, queryCoordClient, typeutil.QueryCoordRole, attempts, sleepDuration) assert.NoError(t, err) proxy.SetQueryCoordClient(queryCoordClient) @@ -484,7 +482,7 @@ func TestProxy(t *testing.T) { err = proxy.Start() assert.NoError(t, err) - assert.Equal(t, commonpb.StateCode_Healthy, proxy.stateCode.Load().(commonpb.StateCode)) + assert.Equal(t, commonpb.StateCode_Healthy, proxy.GetStateCode()) // register proxy err = proxy.Register() @@ -496,18 +494,18 @@ func TestProxy(t *testing.T) { }() t.Run("get component states", func(t *testing.T) { - states, err := proxy.GetComponentStates(ctx) + states, err := proxy.GetComponentStates(ctx, nil) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, states.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, states.GetStatus().GetErrorCode()) assert.Equal(t, paramtable.GetNodeID(), states.State.NodeID) assert.Equal(t, typeutil.ProxyRole, states.State.Role) - assert.Equal(t, proxy.stateCode.Load().(commonpb.StateCode), states.State.StateCode) + assert.Equal(t, proxy.GetStateCode(), states.State.StateCode) }) t.Run("get statistics channel", func(t *testing.T) { - resp, err := proxy.GetStatisticsChannel(ctx) + resp, err := proxy.GetStatisticsChannel(ctx, nil) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, 
resp.GetStatus().GetErrorCode()) assert.Equal(t, "", resp.Value) }) @@ -677,7 +675,6 @@ func TestProxy(t *testing.T) { resp, err = proxy.CreateCollection(ctx, reqInvalidField) assert.NoError(t, err) assert.NotEqual(t, commonpb.ErrorCode_Success, resp.ErrorCode) - }) wg.Add(1) @@ -764,7 +761,6 @@ func TestProxy(t *testing.T) { DbName: dbName, CollectionName: collectionName, }) - }) wg.Add(1) @@ -777,7 +773,7 @@ func TestProxy(t *testing.T) { TimeStamp: 0, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.True(t, resp.Value) // has other collection: false @@ -788,7 +784,7 @@ func TestProxy(t *testing.T) { TimeStamp: 0, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.False(t, resp.Value) }) @@ -806,7 +802,7 @@ func TestProxy(t *testing.T) { TimeStamp: 0, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, collectionID, resp.CollectionID) // TODO(dragondriver): shards num assert.Equal(t, len(schema.Fields), len(resp.Schema.Fields)) @@ -821,7 +817,7 @@ func TestProxy(t *testing.T) { TimeStamp: 0, }) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -833,7 +829,7 @@ func TestProxy(t *testing.T) { CollectionName: collectionName, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) // TODO(dragondriver): check num rows // get statistics of other collection -> fail @@ -843,7 +839,7 @@ func TestProxy(t *testing.T) { 
CollectionName: otherCollectionName, }) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -857,7 +853,7 @@ func TestProxy(t *testing.T) { CollectionNames: nil, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, 1, len(resp.CollectionNames), resp.CollectionNames) }) @@ -912,7 +908,7 @@ func TestProxy(t *testing.T) { PartitionName: partitionName, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.True(t, resp.Value) resp, err = proxy.HasPartition(ctx, &milvuspb.HasPartitionRequest{ @@ -922,7 +918,7 @@ func TestProxy(t *testing.T) { PartitionName: otherPartitionName, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.False(t, resp.Value) // non-exist collection -> fail @@ -933,7 +929,7 @@ func TestProxy(t *testing.T) { PartitionName: partitionName, }) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -946,7 +942,7 @@ func TestProxy(t *testing.T) { PartitionName: partitionName, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) // non-exist partition -> fail resp, err = proxy.GetPartitionStatistics(ctx, &milvuspb.GetPartitionStatisticsRequest{ @@ -956,7 +952,7 @@ func TestProxy(t *testing.T) { PartitionName: otherPartitionName, }) assert.NoError(t, err) - assert.NotEqual(t, 
commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) // non-exist collection -> fail resp, err = proxy.GetPartitionStatistics(ctx, &milvuspb.GetPartitionStatisticsRequest{ @@ -966,7 +962,7 @@ func TestProxy(t *testing.T) { PartitionName: partitionName, }) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -984,7 +980,7 @@ func TestProxy(t *testing.T) { Type: milvuspb.ShowType_All, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) // default partition assert.Equal(t, 2, len(resp.PartitionNames)) @@ -995,7 +991,7 @@ func TestProxy(t *testing.T) { PartitionNames: resp.PartitionNames, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, stateResp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, stateResp.GetStatus().GetErrorCode()) assert.Equal(t, commonpb.LoadState_LoadStateNotLoad, stateResp.State) } @@ -1009,9 +1005,10 @@ func TestProxy(t *testing.T) { Type: milvuspb.ShowType_All, }) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) + var insertedIds []int64 wg.Add(1) t.Run("insert", func(t *testing.T) { defer wg.Done() @@ -1019,13 +1016,20 @@ func TestProxy(t *testing.T) { resp, err := proxy.Insert(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, rowNum, len(resp.SuccIndex)) assert.Equal(t, 0, len(resp.ErrIndex)) assert.Equal(t, int64(rowNum), resp.InsertCnt) + + switch field := resp.GetIDs().GetIdField().(type) { + 
case *schemapb.IDs_IntId: + insertedIds = field.IntId.GetData() + default: + t.Fatalf("Unexpected ID type") + } }) - //TODO(dragondriver): proxy.Delete() + // TODO(dragondriver): proxy.Delete() flushed := true wg.Add(1) @@ -1037,7 +1041,7 @@ func TestProxy(t *testing.T) { CollectionNames: []string{collectionName}, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) segmentIDs = resp.CollSegIDs[collectionName].Data log.Info("flush collection", zap.Int64s("segments to be flushed", segmentIDs)) @@ -1079,7 +1083,7 @@ func TestProxy(t *testing.T) { CollectionName: collectionName, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) rowNumStr := funcutil.KeyValuePair2Map(resp.Stats)["row_count"] assert.Equal(t, strconv.Itoa(rowNum), rowNumStr) @@ -1090,7 +1094,7 @@ func TestProxy(t *testing.T) { CollectionName: otherCollectionName, }) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -1114,7 +1118,7 @@ func TestProxy(t *testing.T) { IndexName: "", }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) indexName = resp.IndexDescriptions[0].IndexName }) @@ -1128,7 +1132,7 @@ func TestProxy(t *testing.T) { IndexName: "", }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) indexName = resp.IndexDescriptions[0].IndexName }) @@ -1143,7 +1147,7 @@ func TestProxy(t *testing.T) { IndexName: indexName, }) assert.NoError(t, err) - assert.Equal(t, 
commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -1157,7 +1161,7 @@ func TestProxy(t *testing.T) { IndexName: indexName, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) loaded := true @@ -1170,7 +1174,7 @@ func TestProxy(t *testing.T) { CollectionName: collectionName, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, stateResp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, stateResp.GetStatus().GetErrorCode()) assert.Equal(t, commonpb.LoadState_LoadStateNotLoad, stateResp.State) } @@ -1200,7 +1204,7 @@ func TestProxy(t *testing.T) { CollectionNames: []string{collectionName}, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) for idx, name := range resp.CollectionNames { if name == collectionName && resp.InMemoryPercentages[idx] == 100 { @@ -1236,7 +1240,7 @@ func TestProxy(t *testing.T) { CollectionNames: nil, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, 1, len(resp.CollectionNames)) // get in-memory percentage @@ -1248,7 +1252,7 @@ func TestProxy(t *testing.T) { CollectionNames: []string{collectionName}, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, 1, len(resp.CollectionNames)) assert.Equal(t, 1, len(resp.InMemoryPercentages)) @@ -1261,7 +1265,7 @@ func TestProxy(t *testing.T) { CollectionNames: []string{otherCollectionName}, }) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, 
resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) { progressResp, err := proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{ @@ -1269,7 +1273,7 @@ func TestProxy(t *testing.T) { CollectionName: collectionName, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, progressResp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, progressResp.GetStatus().GetErrorCode()) assert.NotEqual(t, int64(0), progressResp.Progress) } @@ -1279,7 +1283,7 @@ func TestProxy(t *testing.T) { CollectionName: otherCollectionName, }) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, progressResp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, progressResp.GetStatus().GetErrorCode()) assert.Equal(t, int64(0), progressResp.Progress) } @@ -1289,7 +1293,7 @@ func TestProxy(t *testing.T) { CollectionName: otherCollectionName, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, stateResp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, stateResp.GetStatus().GetErrorCode()) assert.Equal(t, commonpb.LoadState_LoadStateNotExist, stateResp.State) } }) @@ -1321,7 +1325,7 @@ func TestProxy(t *testing.T) { CollectionName: collectionName, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) rowNumStr := funcutil.KeyValuePair2Map(resp.Stats)["row_count"] assert.Equal(t, strconv.Itoa(rowNum), rowNumStr) @@ -1332,7 +1336,136 @@ func TestProxy(t *testing.T) { CollectionName: otherCollectionName, }) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) + }) + + nprobe := 10 + topk := 10 + roundDecimal := 6 + expr := fmt.Sprintf("%s > 0", int64Field) + constructVectorsPlaceholderGroup := func() 
*commonpb.PlaceholderGroup { + values := make([][]byte, 0, nq) + for i := 0; i < nq; i++ { + bs := make([]byte, 0, dim*4) + for j := 0; j < dim; j++ { + var buffer bytes.Buffer + f := rand.Float32() + err := binary.Write(&buffer, common.Endian, f) + assert.NoError(t, err) + bs = append(bs, buffer.Bytes()...) + } + values = append(values, bs) + } + + return &commonpb.PlaceholderGroup{ + Placeholders: []*commonpb.PlaceholderValue{ + { + Tag: "$0", + Type: commonpb.PlaceholderType_FloatVector, + Values: values, + }, + }, + } + } + + constructSearchRequest := func() *milvuspb.SearchRequest { + plg := constructVectorsPlaceholderGroup() + plgBs, err := proto.Marshal(plg) + assert.NoError(t, err) + + params := make(map[string]string) + params["nprobe"] = strconv.Itoa(nprobe) + b, err := json.Marshal(params) + assert.NoError(t, err) + searchParams := []*commonpb.KeyValuePair{ + {Key: MetricTypeKey, Value: metric.L2}, + {Key: SearchParamsKey, Value: string(b)}, + {Key: AnnsFieldKey, Value: floatVecField}, + {Key: TopKKey, Value: strconv.Itoa(topk)}, + {Key: RoundDecimalKey, Value: strconv.Itoa(roundDecimal)}, + } + + return &milvuspb.SearchRequest{ + Base: nil, + DbName: dbName, + CollectionName: collectionName, + PartitionNames: nil, + Dsl: expr, + PlaceholderGroup: plgBs, + DslType: commonpb.DslType_BoolExprV1, + OutputFields: nil, + SearchParams: searchParams, + TravelTimestamp: 0, + GuaranteeTimestamp: 0, + SearchByPrimaryKeys: false, + } + } + + wg.Add(1) + t.Run("search", func(t *testing.T) { + defer wg.Done() + req := constructSearchRequest() + + resp, err := proxy.Search(ctx, req) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + }) + + constructPrimaryKeysPlaceholderGroup := func() *commonpb.PlaceholderGroup { + expr := fmt.Sprintf("%v in [%v]", int64Field, insertedIds[0]) + exprBytes := []byte(expr) + + return &commonpb.PlaceholderGroup{ + Placeholders: []*commonpb.PlaceholderValue{ + { + Tag: "$0", + Type: 
commonpb.PlaceholderType_None, + Values: [][]byte{exprBytes}, + }, + }, + } + } + + constructSearchByPksRequest := func() *milvuspb.SearchRequest { + plg := constructPrimaryKeysPlaceholderGroup() + plgBs, err := proto.Marshal(plg) + assert.NoError(t, err) + + params := make(map[string]string) + params["nprobe"] = strconv.Itoa(nprobe) + b, err := json.Marshal(params) + assert.NoError(t, err) + searchParams := []*commonpb.KeyValuePair{ + {Key: MetricTypeKey, Value: metric.L2}, + {Key: SearchParamsKey, Value: string(b)}, + {Key: AnnsFieldKey, Value: floatVecField}, + {Key: TopKKey, Value: strconv.Itoa(topk)}, + {Key: RoundDecimalKey, Value: strconv.Itoa(roundDecimal)}, + } + + return &milvuspb.SearchRequest{ + Base: nil, + DbName: dbName, + CollectionName: collectionName, + PartitionNames: nil, + Dsl: "", + PlaceholderGroup: plgBs, + DslType: commonpb.DslType_BoolExprV1, + OutputFields: nil, + SearchParams: searchParams, + TravelTimestamp: 0, + GuaranteeTimestamp: 0, + SearchByPrimaryKeys: true, + } + } + + wg.Add(1) + t.Run("search by primary keys", func(t *testing.T) { + defer wg.Done() + req := constructSearchByPksRequest() + resp, err := proxy.Search(ctx, req) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) }) // nprobe := 10 @@ -1418,7 +1551,7 @@ func TestProxy(t *testing.T) { // // resp, err := proxy.Search(ctx, req) // assert.NoError(t, err) - // assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + // assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) // }) // // wg.Add(1) @@ -1437,7 +1570,7 @@ func TestProxy(t *testing.T) { // }) // assert.NoError(t, err) // // FIXME(dragondriver) - // // assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + // // assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) // // TODO(dragondriver): compare query result // }) @@ -1470,7 +1603,7 @@ func TestProxy(t *testing.T) { }, } - //resp, err := 
proxy.CalcDistance(ctx, &milvuspb.CalcDistanceRequest{ + // resp, err := proxy.CalcDistance(ctx, &milvuspb.CalcDistanceRequest{ _, err := proxy.CalcDistance(ctx, &milvuspb.CalcDistanceRequest{ Base: nil, OpLeft: opLeft, @@ -1483,7 +1616,7 @@ func TestProxy(t *testing.T) { }, }) assert.NoError(t, err) - // assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + // assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) // TODO(dragondriver): compare distance // TODO(dragondriver): use primary key to calculate distance @@ -1491,7 +1624,7 @@ func TestProxy(t *testing.T) { t.Run("get dd channel", func(t *testing.T) { resp, _ := proxy.GetDdChannel(ctx, &internalpb.GetDdChannelRequest{}) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -1503,7 +1636,7 @@ func TestProxy(t *testing.T) { CollectionName: collectionName, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -1515,7 +1648,7 @@ func TestProxy(t *testing.T) { CollectionName: collectionName, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -1525,7 +1658,7 @@ func TestProxy(t *testing.T) { Base: nil, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.ErrorCode) + assert.ErrorIs(t, merr.Error(resp), merr.ErrCollectionNotFound) }) // TODO(dragondriver): dummy @@ -1535,7 +1668,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.RegisterLink(ctx, &milvuspb.RegisterLinkRequest{}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, 
resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -1545,12 +1678,12 @@ func TestProxy(t *testing.T) { assert.NoError(t, err) resp, err := proxy.GetMetrics(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) // get from cache resp, err = proxy.GetMetrics(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) // failed to parse metric type resp, err = proxy.GetMetrics(ctx, &milvuspb.GetMetricsRequest{ @@ -1558,14 +1691,14 @@ func TestProxy(t *testing.T) { Request: "not in json format", }) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) // not implemented metric notImplemented, err := metricsinfo.ConstructRequestByMetricType("not implemented") assert.NoError(t, err) resp, err = proxy.GetMetrics(ctx, notImplemented) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -1575,7 +1708,7 @@ func TestProxy(t *testing.T) { assert.NoError(t, err) resp, err := proxy.GetProxyMetrics(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) // failed to parse metric type resp, err = proxy.GetProxyMetrics(ctx, &milvuspb.GetMetricsRequest{ @@ -1583,27 +1716,27 @@ func TestProxy(t *testing.T) { Request: "not in json format", }) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) // not implemented metric notImplemented, err 
:= metricsinfo.ConstructRequestByMetricType("not implemented") assert.NoError(t, err) resp, err = proxy.GetProxyMetrics(ctx, notImplemented) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) // unhealthy proxy.UpdateStateCode(commonpb.StateCode_Abnormal) resp, err = proxy.GetProxyMetrics(ctx, req) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) proxy.UpdateStateCode(commonpb.StateCode_Healthy) // getProxyMetric failed rateCol.Deregister(internalpb.RateType_DMLInsert.String()) resp, err = proxy.GetProxyMetrics(ctx, req) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) rateCol.Register(internalpb.RateType_DMLInsert.String()) }) @@ -1615,9 +1748,9 @@ func TestProxy(t *testing.T) { CollectionName: collectionName, Files: []string{"f1.json"}, } - proxy.stateCode.Store(commonpb.StateCode_Healthy) + proxy.UpdateStateCode(commonpb.StateCode_Healthy) resp, err := proxy.Import(context.TODO(), req) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.NoError(t, err) // Wait a bit for complete import to start. 
time.Sleep(2 * time.Second) @@ -1630,10 +1763,10 @@ func TestProxy(t *testing.T) { CollectionName: "bad_collection_name", Files: []string{"f1.json"}, } - proxy.stateCode.Store(commonpb.StateCode_Healthy) + proxy.UpdateStateCode(commonpb.StateCode_Healthy) resp, err := proxy.Import(context.TODO(), req) assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_UnexpectedError, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_UnexpectedError, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -1643,10 +1776,10 @@ func TestProxy(t *testing.T) { CollectionName: "bad_collection_name", Files: []string{"f1.json"}, } - proxy.stateCode.Store(commonpb.StateCode_Healthy) + proxy.UpdateStateCode(commonpb.StateCode_Healthy) resp, err := proxy.Import(context.TODO(), req) assert.NoError(t, err) - assert.EqualValues(t, commonpb.ErrorCode_UnexpectedError, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_UnexpectedError, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -1685,7 +1818,7 @@ func TestProxy(t *testing.T) { CollectionNames: nil, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, 0, len(resp.CollectionNames)) }) @@ -1738,7 +1871,7 @@ func TestProxy(t *testing.T) { Type: milvuspb.ShowType_InMemory, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) for idx, name := range resp.PartitionNames { if name == partitionName && resp.InMemoryPercentages[idx] == 100 { @@ -1778,7 +1911,7 @@ func TestProxy(t *testing.T) { Type: milvuspb.ShowType_InMemory, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) // default partition? 
assert.Equal(t, 1, len(resp.PartitionNames)) @@ -1792,7 +1925,7 @@ func TestProxy(t *testing.T) { Type: milvuspb.ShowType_InMemory, }) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) // non-exist collection -> fail resp, err = proxy.ShowPartitions(ctx, &milvuspb.ShowPartitionsRequest{ @@ -1804,7 +1937,7 @@ func TestProxy(t *testing.T) { Type: milvuspb.ShowType_InMemory, }) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) { resp, err := proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{ @@ -1813,7 +1946,7 @@ func TestProxy(t *testing.T) { PartitionNames: []string{partitionName}, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.NotEqual(t, int64(0), resp.Progress) } @@ -1824,7 +1957,7 @@ func TestProxy(t *testing.T) { PartitionNames: []string{otherPartitionName}, }) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, int64(0), resp.Progress) } }) @@ -1836,7 +1969,7 @@ func TestProxy(t *testing.T) { resp, err := proxy.Insert(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, rowNum, len(resp.SuccIndex)) assert.Equal(t, 0, len(resp.ErrIndex)) assert.Equal(t, int64(rowNum), resp.InsertCnt) @@ -1855,7 +1988,7 @@ func TestProxy(t *testing.T) { PartitionNames: []string{partitionName}, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + 
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) rowNumStr := funcutil.KeyValuePair2Map(resp.Stats)["row_count"] assert.Equal(t, strconv.Itoa(rowNum), rowNumStr) @@ -1867,7 +2000,7 @@ func TestProxy(t *testing.T) { PartitionNames: []string{otherPartitionName}, }) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) // non-exist collection -> fail resp, err = proxy.GetStatistics(ctx, &milvuspb.GetStatisticsRequest{ @@ -1877,7 +2010,7 @@ func TestProxy(t *testing.T) { PartitionNames: []string{partitionName}, }) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -1895,7 +2028,7 @@ func TestProxy(t *testing.T) { CollectionName: collectionName, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) rowNumStr := funcutil.KeyValuePair2Map(resp.Stats)["row_count"] assert.Equal(t, strconv.Itoa(rowNum*2), rowNumStr) @@ -1906,7 +2039,7 @@ func TestProxy(t *testing.T) { CollectionName: otherCollectionName, }) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -1937,7 +2070,7 @@ func TestProxy(t *testing.T) { Type: milvuspb.ShowType_InMemory, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) // default partition assert.Equal(t, 0, len(resp.PartitionNames)) @@ -1950,7 +2083,7 @@ func TestProxy(t *testing.T) { Type: milvuspb.ShowType_InMemory, }) assert.NoError(t, err) - assert.NotEqual(t, 
commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -1996,7 +2129,7 @@ func TestProxy(t *testing.T) { PartitionName: partitionName, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.False(t, resp.Value) }) @@ -2015,7 +2148,7 @@ func TestProxy(t *testing.T) { Type: milvuspb.ShowType_All, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) // default partition assert.Equal(t, 1, len(resp.PartitionNames)) }) @@ -2054,7 +2187,7 @@ func TestProxy(t *testing.T) { resp, err := proxy.Upsert(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UpsertAutoIDTrue, resp.Status.ErrorCode) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrParameterInvalid) assert.Equal(t, 0, len(resp.SuccIndex)) assert.Equal(t, rowNum, len(resp.ErrIndex)) assert.Equal(t, int64(0), resp.UpsertCnt) @@ -2102,7 +2235,7 @@ func TestProxy(t *testing.T) { TimeStamp: 0, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.False(t, resp.Value) }) @@ -2117,7 +2250,7 @@ func TestProxy(t *testing.T) { CollectionNames: nil, }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, 0, len(resp.CollectionNames)) }) @@ -2252,13 +2385,13 @@ func TestProxy(t *testing.T) { getCredentialReq := constructGetCredentialRequest() getResp, err := rootCoordClient.GetCredential(ctx, getCredentialReq) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, getResp.Status.ErrorCode) + 
assert.Equal(t, commonpb.ErrorCode_Success, getResp.GetStatus().GetErrorCode()) assert.True(t, passwordVerify(ctx, username, newPassword, globalMetaCache)) getCredentialReq.Username = "(" getResp, err = rootCoordClient.GetCredential(ctx, getCredentialReq) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, getResp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, getResp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2302,7 +2435,7 @@ func TestProxy(t *testing.T) { // proxy unhealthy // //notStateCode := "not state code" - //proxy.stateCode.Store(notStateCode) + //proxy.UpdateStateCode(notStateCode) // //t.Run("GetComponentStates fail", func(t *testing.T) { // _, err := proxy.GetComponentStates(ctx) @@ -2332,7 +2465,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.HasCollection(ctx, &milvuspb.HasCollectionRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2356,7 +2489,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.DescribeCollection(ctx, &milvuspb.DescribeCollectionRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2364,7 +2497,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.GetCollectionStatistics(ctx, &milvuspb.GetCollectionStatisticsRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2372,7 +2505,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, 
resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2408,7 +2541,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.HasPartition(ctx, &milvuspb.HasPartitionRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2432,7 +2565,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.GetPartitionStatistics(ctx, &milvuspb.GetPartitionStatisticsRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2440,7 +2573,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.ShowPartitions(ctx, &milvuspb.ShowPartitionsRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2448,7 +2581,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.GetLoadingProgress(ctx, &milvuspb.GetLoadingProgressRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2456,7 +2589,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.GetLoadState(ctx, &milvuspb.GetLoadStateRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2472,7 +2605,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.DescribeIndex(ctx, &milvuspb.DescribeIndexRequest{}) assert.NoError(t, err) - assert.NotEqual(t, 
commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2480,7 +2613,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.GetIndexStatistics(ctx, &milvuspb.GetIndexStatisticsRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2496,7 +2629,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.GetIndexBuildProgress(ctx, &milvuspb.GetIndexBuildProgressRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2504,7 +2637,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.GetIndexState(ctx, &milvuspb.GetIndexStateRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2512,7 +2645,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.Insert(ctx, &milvuspb.InsertRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2520,7 +2653,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.Delete(ctx, &milvuspb.DeleteRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2528,7 +2661,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.Upsert(ctx, &milvuspb.UpsertRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, 
resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2536,7 +2669,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.Search(ctx, &milvuspb.SearchRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2544,7 +2677,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.Flush(ctx, &milvuspb.FlushRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2552,7 +2685,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.Query(ctx, &milvuspb.QueryRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2584,7 +2717,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.GetPersistentSegmentInfo(ctx, &milvuspb.GetPersistentSegmentInfoRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2592,7 +2725,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.GetQuerySegmentInfo(ctx, &milvuspb.GetQuerySegmentInfoRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2608,7 +2741,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.RegisterLink(ctx, &milvuspb.RegisterLinkRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + 
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2616,7 +2749,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.GetMetrics(ctx, &milvuspb.GetMetricsRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2664,7 +2797,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.ListCredUsers(ctx, &milvuspb.ListCredUsersRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) t.Run("InvalidateCollectionMetaCache failed", func(t *testing.T) { @@ -2705,7 +2838,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.HasCollection(ctx, &milvuspb.HasCollectionRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2729,7 +2862,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.DescribeCollection(ctx, &milvuspb.DescribeCollectionRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2737,7 +2870,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.GetCollectionStatistics(ctx, &milvuspb.GetCollectionStatisticsRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2745,7 +2878,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{}) assert.NoError(t, err) - 
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2781,7 +2914,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.HasPartition(ctx, &milvuspb.HasPartitionRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2805,7 +2938,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.GetPartitionStatistics(ctx, &milvuspb.GetPartitionStatisticsRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2813,7 +2946,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.ShowPartitions(ctx, &milvuspb.ShowPartitionsRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2829,7 +2962,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.DescribeIndex(ctx, &milvuspb.DescribeIndexRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2837,7 +2970,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.GetIndexStatistics(ctx, &milvuspb.GetIndexStatisticsRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2853,7 +2986,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.GetIndexBuildProgress(ctx, 
&milvuspb.GetIndexBuildProgressRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2861,15 +2994,16 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.GetIndexState(ctx, &milvuspb.GetIndexStateRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) - t.Run("Flush fail, dd queue full", func(t *testing.T) { + t.Run("Flush fail, dc queue full", func(t *testing.T) { defer wg.Done() + proxy.sched.dcQueue.setMaxTaskNum(0) resp, err := proxy.Flush(ctx, &milvuspb.FlushRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2906,7 +3040,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.Insert(ctx, &milvuspb.InsertRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2914,7 +3048,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.Delete(ctx, &milvuspb.DeleteRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2922,7 +3056,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.Upsert(ctx, &milvuspb.UpsertRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) proxy.sched.dmQueue.setMaxTaskNum(dmParallelism) @@ -2935,7 +3069,7 @@ func 
TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.Search(ctx, &milvuspb.SearchRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -2943,7 +3077,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.Query(ctx, &milvuspb.QueryRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) proxy.sched.dqQueue.setMaxTaskNum(dqParallelism) @@ -2976,7 +3110,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.HasCollection(shortCtx, &milvuspb.HasCollectionRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -3000,7 +3134,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.DescribeCollection(shortCtx, &milvuspb.DescribeCollectionRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -3008,7 +3142,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.GetCollectionStatistics(shortCtx, &milvuspb.GetCollectionStatisticsRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -3016,7 +3150,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.ShowCollections(shortCtx, &milvuspb.ShowCollectionsRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, 
resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -3052,7 +3186,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.HasPartition(shortCtx, &milvuspb.HasPartitionRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -3076,7 +3210,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.GetPartitionStatistics(shortCtx, &milvuspb.GetPartitionStatisticsRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -3084,7 +3218,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.ShowPartitions(shortCtx, &milvuspb.ShowPartitionsRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -3092,7 +3226,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.GetLoadingProgress(shortCtx, &milvuspb.GetLoadingProgressRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -3108,7 +3242,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.DescribeIndex(shortCtx, &milvuspb.DescribeIndexRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -3116,7 +3250,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.GetIndexStatistics(shortCtx, &milvuspb.GetIndexStatisticsRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, 
resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -3132,7 +3266,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.GetIndexBuildProgress(shortCtx, &milvuspb.GetIndexBuildProgressRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -3140,7 +3274,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.GetIndexState(shortCtx, &milvuspb.GetIndexStateRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -3149,7 +3283,7 @@ func TestProxy(t *testing.T) { _, err := proxy.Flush(shortCtx, &milvuspb.FlushRequest{}) assert.NoError(t, err) // FIXME(dragondriver) - // assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + // assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -3157,7 +3291,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.Insert(shortCtx, &milvuspb.InsertRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -3165,7 +3299,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.Delete(shortCtx, &milvuspb.DeleteRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -3173,7 +3307,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.Upsert(shortCtx, &milvuspb.UpsertRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, 
resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -3181,7 +3315,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.Search(shortCtx, &milvuspb.SearchRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -3189,7 +3323,7 @@ func TestProxy(t *testing.T) { defer wg.Done() resp, err := proxy.Query(shortCtx, &milvuspb.QueryRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) @@ -3369,7 +3503,6 @@ func TestProxy(t *testing.T) { resp, err = proxy.CreateCollection(ctx, reqInvalidField) assert.NoError(t, err) assert.NotEqual(t, commonpb.ErrorCode_Success, resp.ErrorCode) - }) wg.Add(1) @@ -3402,7 +3535,7 @@ func TestProxy(t *testing.T) { resp, err := proxy.Upsert(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, rowNum, len(resp.SuccIndex)) assert.Equal(t, 0, len(resp.ErrIndex)) assert.Equal(t, int64(rowNum), resp.UpsertCnt) @@ -3415,7 +3548,7 @@ func TestProxy(t *testing.T) { resp, err := proxy.Upsert(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetStatus().GetErrorCode()) assert.Equal(t, 0, len(resp.SuccIndex)) assert.Equal(t, rowNum, len(resp.ErrIndex)) assert.Equal(t, int64(0), resp.UpsertCnt) @@ -3428,7 +3561,7 @@ func TestProxy(t *testing.T) { resp, err := proxy.Upsert(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, 
resp.GetStatus().GetErrorCode()) assert.Equal(t, rowNum, len(resp.SuccIndex)) assert.Equal(t, 0, len(resp.ErrIndex)) assert.Equal(t, int64(rowNum), resp.UpsertCnt) @@ -3491,7 +3624,7 @@ func testProxyRole(ctx context.Context, t *testing.T, proxy *Proxy) { assert.Equal(t, commonpb.ErrorCode_Success, privilegeResp.ErrorCode) userResp, _ := proxy.SelectUser(ctx, &milvuspb.SelectUserRequest{User: &milvuspb.UserEntity{Name: username}, IncludeRoleInfo: true}) - assert.Equal(t, commonpb.ErrorCode_Success, userResp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, userResp.GetStatus().GetErrorCode()) roleNumOfUser := len(userResp.Results[0].Roles) roleResp, _ = proxy.OperateUserRole(ctx, &milvuspb.OperateUserRoleRequest{ @@ -3512,7 +3645,7 @@ func testProxyRole(ctx context.Context, t *testing.T, proxy *Proxy) { assert.Equal(t, commonpb.ErrorCode_Success, roleResp.ErrorCode) userResp, _ = proxy.SelectUser(ctx, &milvuspb.SelectUserRequest{User: &milvuspb.UserEntity{Name: username}, IncludeRoleInfo: true}) - assert.Equal(t, commonpb.ErrorCode_Success, userResp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, userResp.GetStatus().GetErrorCode()) assert.Equal(t, roleNumOfUser, len(userResp.Results[0].Roles)) }) @@ -3550,14 +3683,14 @@ func testProxyRole(ctx context.Context, t *testing.T, proxy *Proxy) { defer wg.Done() resp, _ := proxy.SelectRole(ctx, &milvuspb.SelectRoleRequest{Role: &milvuspb.RoleEntity{Name: " "}}) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) resp, _ = proxy.SelectRole(ctx, &milvuspb.SelectRoleRequest{}) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) roleNum := len(resp.Results) resp, _ = proxy.SelectRole(ctx, &milvuspb.SelectRoleRequest{Role: &milvuspb.RoleEntity{Name: "not_existed"}}) - assert.Equal(t, 
commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, 0, len(resp.Results)) roleName := "unit_test" @@ -3565,7 +3698,7 @@ func testProxyRole(ctx context.Context, t *testing.T, proxy *Proxy) { assert.Equal(t, commonpb.ErrorCode_Success, roleResp.ErrorCode) resp, _ = proxy.SelectRole(ctx, &milvuspb.SelectRoleRequest{}) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, roleNum+1, len(resp.Results)) roleResp, _ = proxy.DropRole(ctx, &milvuspb.DropRoleRequest{RoleName: roleName}) @@ -3575,7 +3708,7 @@ func testProxyRole(ctx context.Context, t *testing.T, proxy *Proxy) { assert.Equal(t, commonpb.ErrorCode_Success, opResp.ErrorCode) resp, _ = proxy.SelectRole(ctx, &milvuspb.SelectRoleRequest{Role: &milvuspb.RoleEntity{Name: "admin"}, IncludeUserInfo: true}) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.NotEqual(t, 0, len(resp.Results)) assert.NotEqual(t, 0, len(resp.Results[0].Users)) @@ -3589,23 +3722,23 @@ func testProxyRole(ctx context.Context, t *testing.T, proxy *Proxy) { entity := &milvuspb.UserEntity{Name: " "} resp, _ := proxy.SelectUser(ctx, &milvuspb.SelectUserRequest{User: entity}) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) entity.Name = "not_existed" resp, _ = proxy.SelectUser(ctx, &milvuspb.SelectUserRequest{User: entity}) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, 0, len(resp.Results)) entity.Name = "root" resp, _ = proxy.SelectUser(ctx, &milvuspb.SelectUserRequest{User: entity}) - assert.Equal(t, 
commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.NotEqual(t, 0, len(resp.Results)) opResp, _ := proxy.OperateUserRole(ctx, &milvuspb.OperateUserRoleRequest{Username: "root", RoleName: "admin"}) assert.Equal(t, commonpb.ErrorCode_Success, opResp.ErrorCode) resp, _ = proxy.SelectUser(ctx, &milvuspb.SelectUserRequest{User: entity, IncludeRoleInfo: true}) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.NotEqual(t, 0, len(resp.Results)) assert.NotEqual(t, 0, len(resp.Results[0].Roles)) @@ -3632,7 +3765,6 @@ func testProxyRole(ctx context.Context, t *testing.T, proxy *Proxy) { resp, err := proxy.OperateUserRole(ctx, &milvuspb.OperateUserRoleRequest{Username: username, RoleName: roleName}) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode) - } { resp, err := proxy.OperateUserRole(ctx, &milvuspb.OperateUserRoleRequest{Username: username, RoleName: "admin"}) @@ -3642,12 +3774,12 @@ func testProxyRole(ctx context.Context, t *testing.T, proxy *Proxy) { { selectUserResp, err := proxy.SelectUser(ctx, &milvuspb.SelectUserRequest{User: &milvuspb.UserEntity{Name: username}, IncludeRoleInfo: true}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, selectUserResp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, selectUserResp.GetStatus().GetErrorCode()) assert.Equal(t, 2, len(selectUserResp.Results[0].Roles)) selectRoleResp, err := proxy.SelectRole(ctx, &milvuspb.SelectRoleRequest{Role: &milvuspb.RoleEntity{Name: roleName}, IncludeUserInfo: true}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, selectRoleResp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, selectRoleResp.GetStatus().GetErrorCode()) assert.Equal(t, 1, len(selectRoleResp.Results[0].Users)) } { @@ -3658,12 +3790,12 @@ func 
testProxyRole(ctx context.Context, t *testing.T, proxy *Proxy) { { selectUserResp, err := proxy.SelectUser(ctx, &milvuspb.SelectUserRequest{User: &milvuspb.UserEntity{Name: username}, IncludeRoleInfo: true}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, selectUserResp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, selectUserResp.GetStatus().GetErrorCode()) assert.Equal(t, 1, len(selectUserResp.Results[0].Roles)) selectRoleResp, err := proxy.SelectRole(ctx, &milvuspb.SelectRoleRequest{Role: &milvuspb.RoleEntity{Name: roleName}, IncludeUserInfo: true}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, selectRoleResp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, selectRoleResp.GetStatus().GetErrorCode()) assert.Equal(t, 0, len(selectRoleResp.Results)) } { @@ -3674,7 +3806,7 @@ func testProxyRole(ctx context.Context, t *testing.T, proxy *Proxy) { { selectUserResp, err := proxy.SelectUser(ctx, &milvuspb.SelectUserRequest{User: &milvuspb.UserEntity{Name: username}, IncludeRoleInfo: true}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, selectUserResp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, selectUserResp.GetStatus().GetErrorCode()) assert.Equal(t, 0, len(selectUserResp.Results)) } }) @@ -3719,14 +3851,14 @@ func testProxyRoleFail(ctx context.Context, t *testing.T, proxy *Proxy, reason s t.Run(fmt.Sprintf("SelectRole fail, %s", reason), func(t *testing.T) { defer wg.Done() resp, _ := proxy.SelectRole(ctx, &milvuspb.SelectRoleRequest{}) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Add(1) t.Run(fmt.Sprintf("SelectUser fail, %s", reason), func(t *testing.T) { defer wg.Done() resp, _ := proxy.SelectUser(ctx, &milvuspb.SelectUserRequest{}) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, 
commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Wait() @@ -3812,49 +3944,49 @@ func testProxyPrivilege(ctx context.Context, t *testing.T, proxy *Proxy) { // select grant selectReq := &milvuspb.SelectGrantRequest{} results, _ := proxy.SelectGrant(ctx, selectReq) - assert.NotEqual(t, commonpb.ErrorCode_Success, results.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, results.GetStatus().GetErrorCode()) selectReq.Entity = &milvuspb.GrantEntity{} results, _ = proxy.SelectGrant(ctx, selectReq) - assert.NotEqual(t, commonpb.ErrorCode_Success, results.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, results.GetStatus().GetErrorCode()) selectReq.Entity.Object = &milvuspb.ObjectEntity{} results, _ = proxy.SelectGrant(ctx, selectReq) - assert.NotEqual(t, commonpb.ErrorCode_Success, results.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, results.GetStatus().GetErrorCode()) selectReq.Entity.Object.Name = commonpb.ObjectType_Collection.String() results, _ = proxy.SelectGrant(ctx, selectReq) - assert.NotEqual(t, commonpb.ErrorCode_Success, results.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, results.GetStatus().GetErrorCode()) selectReq.Entity.ObjectName = "col1" results, _ = proxy.SelectGrant(ctx, selectReq) - assert.NotEqual(t, commonpb.ErrorCode_Success, results.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, results.GetStatus().GetErrorCode()) selectReq.Entity.Role = &milvuspb.RoleEntity{} results, _ = proxy.SelectGrant(ctx, selectReq) - assert.NotEqual(t, commonpb.ErrorCode_Success, results.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, results.GetStatus().GetErrorCode()) selectReq.Entity.Role = &milvuspb.RoleEntity{Name: "public"} results, _ = proxy.SelectGrant(ctx, selectReq) - assert.Equal(t, commonpb.ErrorCode_Success, results.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, results.GetStatus().GetErrorCode()) 
assert.NotEqual(t, 0, len(results.Entities)) selectReq.Entity.Object.Name = "not existed" results, _ = proxy.SelectGrant(ctx, selectReq) - assert.NotEqual(t, commonpb.ErrorCode_Success, results.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, results.GetStatus().GetErrorCode()) selectReq.Entity.Object.Name = commonpb.ObjectType_Collection.String() selectReq.Entity.Role = &milvuspb.RoleEntity{Name: "not existed"} results, _ = proxy.SelectGrant(ctx, selectReq) - assert.NotEqual(t, commonpb.ErrorCode_Success, results.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, results.GetStatus().GetErrorCode()) results, _ = proxy.SelectGrant(ctx, &milvuspb.SelectGrantRequest{ Entity: &milvuspb.GrantEntity{ Role: &milvuspb.RoleEntity{Name: "public"}, }, }) - assert.Equal(t, commonpb.ErrorCode_Success, results.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, results.GetStatus().GetErrorCode()) assert.NotEqual(t, 0, len(results.Entities)) req.Type = milvuspb.OperatePrivilegeType_Revoke @@ -3934,7 +4066,7 @@ func testProxyPrivilegeFail(ctx context.Context, t *testing.T, proxy *Proxy, rea Role: &milvuspb.RoleEntity{Name: "admin"}, }, }) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) wg.Wait() } @@ -3951,14 +4083,16 @@ func testProxyRefreshPolicyInfoCache(ctx context.Context, t *testing.T, proxy *P }) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode) - _, err = proxy.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{}) - assert.Error(t, err) + resp, err = proxy.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{}) + assert.NoError(t, err) + assert.Error(t, merr.Error(resp)) - _, err = proxy.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{ + resp, err = proxy.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{ OpType: 100, 
OpKey: funcutil.EncodeUserRoleCache("foo", "public"), }) - assert.Error(t, err) + assert.NoError(t, err) + assert.Error(t, merr.Error(resp)) }) wg.Wait() } @@ -3985,7 +4119,7 @@ func Test_GetCompactionState(t *testing.T) { t.Run("get compaction state", func(t *testing.T) { datacoord := &DataCoordMock{} proxy := &Proxy{dataCoord: datacoord} - proxy.stateCode.Store(commonpb.StateCode_Healthy) + proxy.UpdateStateCode(commonpb.StateCode_Healthy) resp, err := proxy.GetCompactionState(context.TODO(), nil) assert.EqualValues(t, &milvuspb.GetCompactionStateResponse{}, resp) assert.NoError(t, err) @@ -3994,9 +4128,9 @@ func Test_GetCompactionState(t *testing.T) { t.Run("get compaction state with unhealthy proxy", func(t *testing.T) { datacoord := &DataCoordMock{} proxy := &Proxy{dataCoord: datacoord} - proxy.stateCode.Store(commonpb.StateCode_Abnormal) + proxy.UpdateStateCode(commonpb.StateCode_Abnormal) resp, err := proxy.GetCompactionState(context.TODO(), nil) - assert.EqualValues(t, unhealthyStatus(), resp.Status) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) assert.NoError(t, err) }) } @@ -4005,7 +4139,7 @@ func Test_ManualCompaction(t *testing.T) { t.Run("test manual compaction", func(t *testing.T) { datacoord := &DataCoordMock{} proxy := &Proxy{dataCoord: datacoord} - proxy.stateCode.Store(commonpb.StateCode_Healthy) + proxy.UpdateStateCode(commonpb.StateCode_Healthy) resp, err := proxy.ManualCompaction(context.TODO(), nil) assert.EqualValues(t, &milvuspb.ManualCompactionResponse{}, resp) assert.NoError(t, err) @@ -4013,9 +4147,9 @@ func Test_ManualCompaction(t *testing.T) { t.Run("test manual compaction with unhealthy", func(t *testing.T) { datacoord := &DataCoordMock{} proxy := &Proxy{dataCoord: datacoord} - proxy.stateCode.Store(commonpb.StateCode_Abnormal) + proxy.UpdateStateCode(commonpb.StateCode_Abnormal) resp, err := proxy.ManualCompaction(context.TODO(), nil) - assert.EqualValues(t, unhealthyStatus(), resp.Status) + 
assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) assert.NoError(t, err) }) } @@ -4024,7 +4158,7 @@ func Test_GetCompactionStateWithPlans(t *testing.T) { t.Run("test get compaction state with plans", func(t *testing.T) { datacoord := &DataCoordMock{} proxy := &Proxy{dataCoord: datacoord} - proxy.stateCode.Store(commonpb.StateCode_Healthy) + proxy.UpdateStateCode(commonpb.StateCode_Healthy) resp, err := proxy.GetCompactionStateWithPlans(context.TODO(), nil) assert.EqualValues(t, &milvuspb.GetCompactionPlansResponse{}, resp) assert.NoError(t, err) @@ -4032,19 +4166,33 @@ func Test_GetCompactionStateWithPlans(t *testing.T) { t.Run("test get compaction state with plans with unhealthy proxy", func(t *testing.T) { datacoord := &DataCoordMock{} proxy := &Proxy{dataCoord: datacoord} - proxy.stateCode.Store(commonpb.StateCode_Abnormal) + proxy.UpdateStateCode(commonpb.StateCode_Abnormal) resp, err := proxy.GetCompactionStateWithPlans(context.TODO(), nil) - assert.EqualValues(t, unhealthyStatus(), resp.Status) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) assert.NoError(t, err) }) } func Test_GetFlushState(t *testing.T) { t.Run("normal test", func(t *testing.T) { + originCache := globalMetaCache + m := NewMockCache(t) + m.On("GetCollectionID", + mock.Anything, + mock.AnythingOfType("string"), + mock.AnythingOfType("string"), + ).Return(UniqueID(1), nil) + globalMetaCache = m + defer func() { + globalMetaCache = originCache + }() + datacoord := &DataCoordMock{} proxy := &Proxy{dataCoord: datacoord} - proxy.stateCode.Store(commonpb.StateCode_Healthy) - resp, err := proxy.GetFlushState(context.TODO(), nil) + proxy.UpdateStateCode(commonpb.StateCode_Healthy) + resp, err := proxy.GetFlushState(context.TODO(), &milvuspb.GetFlushStateRequest{ + CollectionName: "coll", + }) assert.EqualValues(t, &milvuspb.GetFlushStateResponse{}, resp) assert.NoError(t, err) }) @@ -4052,33 +4200,25 @@ func Test_GetFlushState(t *testing.T) { 
t.Run("test get flush state with unhealthy proxy", func(t *testing.T) { datacoord := &DataCoordMock{} proxy := &Proxy{dataCoord: datacoord} - proxy.stateCode.Store(commonpb.StateCode_Abnormal) + proxy.UpdateStateCode(commonpb.StateCode_Abnormal) resp, err := proxy.GetFlushState(context.TODO(), nil) - assert.EqualValues(t, unhealthyStatus(), resp.Status) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) assert.NoError(t, err) }) } func TestProxy_GetComponentStates(t *testing.T) { n := &Proxy{} - n.stateCode.Store(commonpb.StateCode_Healthy) - resp, err := n.GetComponentStates(context.Background()) + n.UpdateStateCode(commonpb.StateCode_Healthy) + resp, err := n.GetComponentStates(context.Background(), nil) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, common.NotRegisteredID, resp.State.NodeID) n.session = &sessionutil.Session{} n.session.UpdateRegistered(true) - resp, err = n.GetComponentStates(context.Background()) - assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) -} - -func TestProxy_GetComponentStates_state_code(t *testing.T) { - p := &Proxy{} - p.stateCode.Store("not commonpb.StateCode") - states, err := p.GetComponentStates(context.Background()) + resp, err = n.GetComponentStates(context.Background(), nil) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, states.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) } func TestProxy_Import(t *testing.T) { @@ -4094,7 +4234,7 @@ func TestProxy_Import(t *testing.T) { proxy.UpdateStateCode(commonpb.StateCode_Abnormal) resp, err := proxy.Import(context.TODO(), req) assert.NoError(t, err) - assert.EqualValues(t, unhealthyStatus(), resp.GetStatus()) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) }) wg.Add(1) @@ -4102,10 
+4242,10 @@ func TestProxy_Import(t *testing.T) { defer wg.Done() proxy := &Proxy{} proxy.UpdateStateCode(commonpb.StateCode_Healthy) - chMgr := newMockChannelsMgr() + chMgr := NewMockChannelsMgr(t) proxy.chMgr = chMgr rc := newMockRootCoord() - rc.ImportFunc = func(ctx context.Context, req *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) { + rc.ImportFunc = func(ctx context.Context, req *milvuspb.ImportRequest, opts ...grpc.CallOption) (*milvuspb.ImportResponse, error) { return nil, errors.New("mock") } proxy.rootCoord = rc @@ -4122,10 +4262,10 @@ func TestProxy_Import(t *testing.T) { defer wg.Done() proxy := &Proxy{} proxy.UpdateStateCode(commonpb.StateCode_Healthy) - chMgr := newMockChannelsMgr() + chMgr := NewMockChannelsMgr(t) proxy.chMgr = chMgr rc := newMockRootCoord() - rc.ImportFunc = func(ctx context.Context, req *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) { + rc.ImportFunc = func(ctx context.Context, req *milvuspb.ImportRequest, opts ...grpc.CallOption) (*milvuspb.ImportResponse, error) { return &milvuspb.ImportResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}}, nil } proxy.rootCoord = rc @@ -4142,10 +4282,10 @@ func TestProxy_Import(t *testing.T) { defer wg.Done() proxy := &Proxy{} proxy.UpdateStateCode(commonpb.StateCode_Healthy) - chMgr := newMockChannelsMgr() + chMgr := NewMockChannelsMgr(t) proxy.chMgr = chMgr rc := newMockRootCoord() - rc.ImportFunc = func(ctx context.Context, req *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) { + rc.ImportFunc = func(ctx context.Context, req *milvuspb.ImportRequest, opts ...grpc.CallOption) (*milvuspb.ImportResponse, error) { return &milvuspb.ImportResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}}, nil } proxy.rootCoord = rc @@ -4177,17 +4317,17 @@ func TestProxy_GetImportState(t *testing.T) { rootCoord.state.Store(commonpb.StateCode_Healthy) t.Run("test get import state", func(t *testing.T) { proxy := &Proxy{rootCoord: rootCoord} 
- proxy.stateCode.Store(commonpb.StateCode_Healthy) + proxy.UpdateStateCode(commonpb.StateCode_Healthy) resp, err := proxy.GetImportState(context.TODO(), req) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.NoError(t, err) }) t.Run("test get import state with unhealthy", func(t *testing.T) { proxy := &Proxy{rootCoord: rootCoord} - proxy.stateCode.Store(commonpb.StateCode_Abnormal) + proxy.UpdateStateCode(commonpb.StateCode_Abnormal) resp, err := proxy.GetImportState(context.TODO(), req) - assert.EqualValues(t, unhealthyStatus(), resp.Status) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) assert.NoError(t, err) }) } @@ -4198,23 +4338,22 @@ func TestProxy_ListImportTasks(t *testing.T) { rootCoord.state.Store(commonpb.StateCode_Healthy) t.Run("test list import tasks", func(t *testing.T) { proxy := &Proxy{rootCoord: rootCoord} - proxy.stateCode.Store(commonpb.StateCode_Healthy) + proxy.UpdateStateCode(commonpb.StateCode_Healthy) resp, err := proxy.ListImportTasks(context.TODO(), req) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.NoError(t, err) }) t.Run("test list import tasks with unhealthy", func(t *testing.T) { proxy := &Proxy{rootCoord: rootCoord} - proxy.stateCode.Store(commonpb.StateCode_Abnormal) + proxy.UpdateStateCode(commonpb.StateCode_Abnormal) resp, err := proxy.ListImportTasks(context.TODO(), req) - assert.EqualValues(t, unhealthyStatus(), resp.Status) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrServiceNotReady) assert.NoError(t, err) }) } func TestProxy_GetStatistics(t *testing.T) { - } func TestProxy_GetLoadState(t *testing.T) { @@ -4237,114 +4376,54 @@ func TestProxy_GetLoadState(t *testing.T) { }() { - qc := getQueryCoord() - 
qc.EXPECT().GetComponentStates(mock.Anything).Return(&milvuspb.ComponentStates{ - State: &milvuspb.ComponentInfo{ - NodeID: 0, - Role: typeutil.QueryCoordRole, - StateCode: commonpb.StateCode_Abnormal, - ExtraInfo: nil, - }, - SubcomponentStates: nil, - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, - }, nil) + qc := getQueryCoordClient() qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Status: merr.Status(merr.WrapErrServiceNotReady(paramtable.GetRole(), paramtable.GetNodeID(), "initialization")), CollectionIDs: nil, InMemoryPercentages: []int64{}, }, nil) proxy := &Proxy{queryCoord: qc} - proxy.stateCode.Store(commonpb.StateCode_Healthy) + proxy.UpdateStateCode(commonpb.StateCode_Healthy) stateResp, err := proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo"}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, stateResp.Status.ErrorCode) + assert.ErrorIs(t, merr.Error(stateResp.GetStatus()), merr.ErrServiceNotReady) progressResp, err := proxy.GetLoadingProgress(context.Background(), &milvuspb.GetLoadingProgressRequest{CollectionName: "foo"}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, progressResp.Status.ErrorCode) + assert.ErrorIs(t, merr.Error(progressResp.GetStatus()), merr.ErrServiceNotReady) } { - qc := getQueryCoord() - qc.EXPECT().GetComponentStates(mock.Anything).Return(&milvuspb.ComponentStates{ - State: &milvuspb.ComponentInfo{ - NodeID: 0, - Role: typeutil.QueryCoordRole, - StateCode: commonpb.StateCode_Healthy, - ExtraInfo: nil, - }, - SubcomponentStates: nil, - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, - }, nil) - qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(nil, errors.New("test")) - qc.EXPECT().ShowPartitions(mock.Anything, 
mock.Anything).Return(nil, errors.New("test")) + qc := getQueryCoordClient() + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(nil, merr.WrapErrCollectionNotLoaded("foo")) + qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(nil, merr.WrapErrPartitionNotLoaded("p1")) proxy := &Proxy{queryCoord: qc} - proxy.stateCode.Store(commonpb.StateCode_Healthy) + proxy.UpdateStateCode(commonpb.StateCode_Healthy) stateResp, err := proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo"}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, stateResp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, stateResp.GetStatus().GetErrorCode()) assert.Equal(t, commonpb.LoadState_LoadStateNotLoad, stateResp.State) stateResp, err = proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo", PartitionNames: []string{"p1"}}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, stateResp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, stateResp.GetStatus().GetErrorCode()) assert.Equal(t, commonpb.LoadState_LoadStateNotLoad, stateResp.State) progressResp, err := proxy.GetLoadingProgress(context.Background(), &milvuspb.GetLoadingProgressRequest{CollectionName: "foo"}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, progressResp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, progressResp.GetStatus().GetErrorCode()) assert.Equal(t, int64(0), progressResp.Progress) progressResp, err = proxy.GetLoadingProgress(context.Background(), &milvuspb.GetLoadingProgressRequest{CollectionName: "foo", PartitionNames: []string{"p1"}}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, progressResp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, progressResp.GetStatus().GetErrorCode()) assert.Equal(t, int64(0), progressResp.Progress) } { - qc := 
getQueryCoord() - qc.EXPECT().GetComponentStates(mock.Anything).Return(&milvuspb.ComponentStates{ - State: &milvuspb.ComponentInfo{ - NodeID: 0, - Role: typeutil.QueryCoordRole, - StateCode: commonpb.StateCode_Healthy, - ExtraInfo: nil, - }, - SubcomponentStates: nil, - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, - }, nil) - qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, - CollectionIDs: nil, - InMemoryPercentages: []int64{}, - }, nil) - qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(nil, errors.New("test")) - proxy := &Proxy{queryCoord: qc} - proxy.stateCode.Store(commonpb.StateCode_Healthy) - - stateResp, err := proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo"}) - assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, stateResp.Status.ErrorCode) - assert.Equal(t, commonpb.LoadState_LoadStateNotLoad, stateResp.State) - - progressResp, err := proxy.GetLoadingProgress(context.Background(), &milvuspb.GetLoadingProgressRequest{CollectionName: "foo"}) - assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, progressResp.Status.ErrorCode) - } - - { - qc := getQueryCoord() - qc.EXPECT().GetComponentStates(mock.Anything).Return(&milvuspb.ComponentStates{ + qc := getQueryCoordClient() + qc.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{ NodeID: 0, Role: typeutil.QueryCoordRole, @@ -4352,10 +4431,7 @@ func TestProxy_GetLoadState(t *testing.T) { ExtraInfo: nil, }, SubcomponentStates: nil, - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), }, nil) qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{ Status: &commonpb.Status{ErrorCode: 
commonpb.ErrorCode_Success}, @@ -4363,26 +4439,26 @@ func TestProxy_GetLoadState(t *testing.T) { InMemoryPercentages: []int64{100}, }, nil) proxy := &Proxy{queryCoord: qc} - proxy.stateCode.Store(commonpb.StateCode_Healthy) + proxy.UpdateStateCode(commonpb.StateCode_Healthy) stateResp, err := proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo", Base: &commonpb.MsgBase{}}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, stateResp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, stateResp.GetStatus().GetErrorCode()) assert.Equal(t, commonpb.LoadState_LoadStateLoaded, stateResp.State) stateResp, err = proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: ""}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, stateResp.Status.ErrorCode) + assert.ErrorIs(t, merr.Error(stateResp.GetStatus()), merr.ErrParameterInvalid) progressResp, err := proxy.GetLoadingProgress(context.Background(), &milvuspb.GetLoadingProgressRequest{CollectionName: "foo"}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, progressResp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, progressResp.GetStatus().GetErrorCode()) assert.Equal(t, int64(100), progressResp.Progress) } { - qc := getQueryCoord() - qc.EXPECT().GetComponentStates(mock.Anything).Return(&milvuspb.ComponentStates{ + qc := getQueryCoordClient() + qc.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{ NodeID: 0, Role: typeutil.QueryCoordRole, @@ -4390,10 +4466,7 @@ func TestProxy_GetLoadState(t *testing.T) { ExtraInfo: nil, }, SubcomponentStates: nil, - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), }, nil) qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{ Status: 
&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, @@ -4401,22 +4474,22 @@ func TestProxy_GetLoadState(t *testing.T) { InMemoryPercentages: []int64{50}, }, nil) proxy := &Proxy{queryCoord: qc} - proxy.stateCode.Store(commonpb.StateCode_Healthy) + proxy.UpdateStateCode(commonpb.StateCode_Healthy) stateResp, err := proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo"}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, stateResp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, stateResp.GetStatus().GetErrorCode()) assert.Equal(t, commonpb.LoadState_LoadStateLoading, stateResp.State) progressResp, err := proxy.GetLoadingProgress(context.Background(), &milvuspb.GetLoadingProgressRequest{CollectionName: "foo"}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, progressResp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, progressResp.GetStatus().GetErrorCode()) assert.Equal(t, int64(50), progressResp.Progress) } t.Run("test insufficient memory", func(t *testing.T) { - qc := getQueryCoord() - qc.EXPECT().GetComponentStates(mock.Anything).Return(&milvuspb.ComponentStates{ + qc := getQueryCoordClient() + qc.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{ NodeID: 0, Role: typeutil.QueryCoordRole, @@ -4424,31 +4497,30 @@ func TestProxy_GetLoadState(t *testing.T) { ExtraInfo: nil, }, SubcomponentStates: nil, - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), }, nil) + + mockErr := merr.WrapErrServiceMemoryLimitExceeded(110, 100) qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_InsufficientMemoryToLoad}, + Status: merr.Status(mockErr), }, nil) qc.EXPECT().ShowPartitions(mock.Anything, 
mock.Anything).Return(&querypb.ShowPartitionsResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_InsufficientMemoryToLoad}, + Status: merr.Status(mockErr), }, nil) proxy := &Proxy{queryCoord: qc} - proxy.stateCode.Store(commonpb.StateCode_Healthy) + proxy.UpdateStateCode(commonpb.StateCode_Healthy) stateResp, err := proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo"}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_InsufficientMemoryToLoad, stateResp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_InsufficientMemoryToLoad, stateResp.GetStatus().GetErrorCode()) progressResp, err := proxy.GetLoadingProgress(context.Background(), &milvuspb.GetLoadingProgressRequest{CollectionName: "foo"}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_InsufficientMemoryToLoad, progressResp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_InsufficientMemoryToLoad, progressResp.GetStatus().GetErrorCode()) progressResp, err = proxy.GetLoadingProgress(context.Background(), &milvuspb.GetLoadingProgressRequest{CollectionName: "foo", PartitionNames: []string{"p1"}}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_InsufficientMemoryToLoad, progressResp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_InsufficientMemoryToLoad, progressResp.GetStatus().GetErrorCode()) }) } @@ -4467,6 +4539,6 @@ func TestUnhealthProxy_GetIndexStatistics(t *testing.T) { IndexName: "", }) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_NotReadyServe, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_NotReadyServe, resp.GetStatus().GetErrorCode()) }) } diff --git a/internal/proxy/rate_limit_interceptor.go b/internal/proxy/rate_limit_interceptor.go index 54f3d1f7487d2..53a6b6331b08a 100644 --- a/internal/proxy/rate_limit_interceptor.go +++ b/internal/proxy/rate_limit_interceptor.go @@ -21,13 +21,14 @@ import ( "fmt" "reflect" + "github.com/cockroachdb/errors" 
"github.com/golang/protobuf/proto" "google.golang.org/grpc" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/util/merr" ) // RateLimitInterceptor returns a new unary server interceptors that performs request rate limiting. @@ -38,9 +39,9 @@ func RateLimitInterceptor(limiter types.Limiter) grpc.UnaryServerInterceptor { return handler(ctx, req) } - code := limiter.Check(collectionID, rt, n) - if code != commonpb.ErrorCode_Success { - rsp := getFailedResponse(req, rt, code, info.FullMethod) + err = limiter.Check(collectionID, rt, n) + if err != nil { + rsp := getFailedResponse(req, rt, err, info.FullMethod) if rsp != nil { return rsp, nil } @@ -113,31 +114,16 @@ func getRequestInfo(req interface{}) (int64, internalpb.RateType, int, error) { } } -// failedStatus returns failed status. -func failedStatus(code commonpb.ErrorCode, reason string) *commonpb.Status { - return &commonpb.Status{ - ErrorCode: code, - Reason: reason, - } -} - // failedMutationResult returns failed mutation result. -func failedMutationResult(code commonpb.ErrorCode, reason string) *milvuspb.MutationResult { +func failedMutationResult(err error) *milvuspb.MutationResult { return &milvuspb.MutationResult{ - Status: failedStatus(code, reason), + Status: merr.Status(err), } } -// failedBoolResponse returns failed boolean response. 
-func failedBoolResponse(code commonpb.ErrorCode, reason string) *milvuspb.BoolResponse { - return &milvuspb.BoolResponse{ - Status: failedStatus(code, reason), - } -} - -func wrapQuotaError(rt internalpb.RateType, errCode commonpb.ErrorCode, fullMethod string) error { - if errCode == commonpb.ErrorCode_RateLimit { - return fmt.Errorf("request is rejected by grpc RateLimiter middleware, please retry later, req: %s", fullMethod) +func wrapQuotaError(rt internalpb.RateType, err error, fullMethod string) error { + if errors.Is(err, merr.ErrServiceRateLimit) { + return errors.Wrapf(err, "request %s is rejected by grpc RateLimiter middleware, please retry later", fullMethod) } // deny to write/read @@ -148,40 +134,41 @@ func wrapQuotaError(rt internalpb.RateType, errCode commonpb.ErrorCode, fullMeth case internalpb.RateType_DQLSearch, internalpb.RateType_DQLQuery: op = "read" } - return fmt.Errorf("deny to %s, reason: %s, req: %s", op, GetQuotaErrorString(errCode), fullMethod) + + return merr.WrapErrServiceForceDeny(op, err, fullMethod) } // getFailedResponse returns failed response. 
-func getFailedResponse(req interface{}, rt internalpb.RateType, errCode commonpb.ErrorCode, fullMethod string) interface{} { - err := wrapQuotaError(rt, errCode, fullMethod) +func getFailedResponse(req any, rt internalpb.RateType, err error, fullMethod string) any { + err = wrapQuotaError(rt, err, fullMethod) switch req.(type) { case *milvuspb.InsertRequest, *milvuspb.DeleteRequest, *milvuspb.UpsertRequest: - return failedMutationResult(errCode, err.Error()) + return failedMutationResult(err) case *milvuspb.ImportRequest: return &milvuspb.ImportResponse{ - Status: failedStatus(errCode, err.Error()), + Status: merr.Status(err), } case *milvuspb.SearchRequest: return &milvuspb.SearchResults{ - Status: failedStatus(errCode, err.Error()), + Status: merr.Status(err), } case *milvuspb.QueryRequest: return &milvuspb.QueryResults{ - Status: failedStatus(errCode, err.Error()), + Status: merr.Status(err), } case *milvuspb.CreateCollectionRequest, *milvuspb.DropCollectionRequest, *milvuspb.LoadCollectionRequest, *milvuspb.ReleaseCollectionRequest, *milvuspb.CreatePartitionRequest, *milvuspb.DropPartitionRequest, *milvuspb.LoadPartitionsRequest, *milvuspb.ReleasePartitionsRequest, *milvuspb.CreateIndexRequest, *milvuspb.DropIndexRequest: - return failedStatus(errCode, err.Error()) + return merr.Status(err) case *milvuspb.FlushRequest: return &milvuspb.FlushResponse{ - Status: failedStatus(errCode, err.Error()), + Status: merr.Status(err), } case *milvuspb.ManualCompactionRequest: return &milvuspb.ManualCompactionResponse{ - Status: failedStatus(errCode, err.Error()), + Status: merr.Status(err), } } return nil diff --git a/internal/proxy/rate_limit_interceptor_test.go b/internal/proxy/rate_limit_interceptor_test.go index e15eea87fd71b..4300b9c0b6f61 100644 --- a/internal/proxy/rate_limit_interceptor_test.go +++ b/internal/proxy/rate_limit_interceptor_test.go @@ -28,6 +28,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" 
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/pkg/util/merr" ) type limiterMock struct { @@ -37,14 +38,14 @@ type limiterMock struct { quotaStateReasons []commonpb.ErrorCode } -func (l *limiterMock) Check(collection int64, rt internalpb.RateType, n int) commonpb.ErrorCode { +func (l *limiterMock) Check(collection int64, rt internalpb.RateType, n int) error { if l.rate == 0 { - return commonpb.ErrorCode_ForceDeny + return merr.ErrServiceForceDeny } if l.limit { - return commonpb.ErrorCode_RateLimit + return merr.ErrServiceRateLimit } - return commonpb.ErrorCode_Success + return nil } func TestRateLimitInterceptor(t *testing.T) { @@ -165,24 +166,24 @@ func TestRateLimitInterceptor(t *testing.T) { }) t.Run("test getFailedResponse", func(t *testing.T) { - testGetFailedResponse := func(req interface{}, rt internalpb.RateType, errCode commonpb.ErrorCode, fullMethod string) { - rsp := getFailedResponse(req, rt, errCode, fullMethod) + testGetFailedResponse := func(req interface{}, rt internalpb.RateType, err error, fullMethod string) { + rsp := getFailedResponse(req, rt, err, fullMethod) assert.NotNil(t, rsp) } - testGetFailedResponse(&milvuspb.DeleteRequest{}, internalpb.RateType_DMLDelete, commonpb.ErrorCode_ForceDeny, "delete") - testGetFailedResponse(&milvuspb.UpsertRequest{}, internalpb.RateType_DMLUpsert, commonpb.ErrorCode_ForceDeny, "upsert") - testGetFailedResponse(&milvuspb.ImportRequest{}, internalpb.RateType_DMLBulkLoad, commonpb.ErrorCode_MemoryQuotaExhausted, "import") - testGetFailedResponse(&milvuspb.SearchRequest{}, internalpb.RateType_DQLSearch, commonpb.ErrorCode_DiskQuotaExhausted, "search") - testGetFailedResponse(&milvuspb.QueryRequest{}, internalpb.RateType_DQLQuery, commonpb.ErrorCode_ForceDeny, "query") - testGetFailedResponse(&milvuspb.CreateCollectionRequest{}, internalpb.RateType_DDLCollection, commonpb.ErrorCode_RateLimit, "createCollection") - 
testGetFailedResponse(&milvuspb.FlushRequest{}, internalpb.RateType_DDLFlush, commonpb.ErrorCode_RateLimit, "flush") - testGetFailedResponse(&milvuspb.ManualCompactionRequest{}, internalpb.RateType_DDLCompaction, commonpb.ErrorCode_RateLimit, "compaction") + testGetFailedResponse(&milvuspb.DeleteRequest{}, internalpb.RateType_DMLDelete, merr.ErrServiceForceDeny, "delete") + testGetFailedResponse(&milvuspb.UpsertRequest{}, internalpb.RateType_DMLUpsert, merr.ErrServiceForceDeny, "upsert") + testGetFailedResponse(&milvuspb.ImportRequest{}, internalpb.RateType_DMLBulkLoad, merr.ErrServiceMemoryLimitExceeded, "import") + testGetFailedResponse(&milvuspb.SearchRequest{}, internalpb.RateType_DQLSearch, merr.ErrServiceDiskLimitExceeded, "search") + testGetFailedResponse(&milvuspb.QueryRequest{}, internalpb.RateType_DQLQuery, merr.ErrServiceForceDeny, "query") + testGetFailedResponse(&milvuspb.CreateCollectionRequest{}, internalpb.RateType_DDLCollection, merr.ErrServiceRateLimit, "createCollection") + testGetFailedResponse(&milvuspb.FlushRequest{}, internalpb.RateType_DDLFlush, merr.ErrServiceRateLimit, "flush") + testGetFailedResponse(&milvuspb.ManualCompactionRequest{}, internalpb.RateType_DDLCompaction, merr.ErrServiceRateLimit, "compaction") // test illegal - rsp := getFailedResponse(&milvuspb.SearchResults{}, internalpb.RateType_DQLSearch, commonpb.ErrorCode_UnexpectedError, "method") + rsp := getFailedResponse(&milvuspb.SearchResults{}, internalpb.RateType_DQLSearch, merr.OldCodeToMerr(commonpb.ErrorCode_UnexpectedError), "method") assert.Nil(t, rsp) - rsp = getFailedResponse(nil, internalpb.RateType_DQLSearch, commonpb.ErrorCode_UnexpectedError, "method") + rsp = getFailedResponse(nil, internalpb.RateType_DQLSearch, merr.OldCodeToMerr(commonpb.ErrorCode_UnexpectedError), "method") assert.Nil(t, rsp) }) @@ -198,9 +199,7 @@ func TestRateLimitInterceptor(t *testing.T) { limiter := limiterMock{rate: 100} handler := func(ctx context.Context, req interface{}) (interface{}, 
error) { return &milvuspb.MutationResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, nil } serverInfo := &grpc.UnaryServerInfo{FullMethod: "MockFullMethod"} diff --git a/internal/proxy/reducer.go b/internal/proxy/reducer.go index 2e59fdf4d2f6a..e1a9697d26c22 100644 --- a/internal/proxy/reducer.go +++ b/internal/proxy/reducer.go @@ -16,11 +16,6 @@ type milvusReducer interface { func createMilvusReducer(ctx context.Context, params *queryParams, req *internalpb.RetrieveRequest, schema *schemapb.CollectionSchema, plan *planpb.PlanNode, collectionName string) milvusReducer { if plan.GetQuery().GetIsCount() { return &cntReducer{} - } else if req.GetIterationExtensionReduceRate() > 0 { - params.limit = params.limit * req.GetIterationExtensionReduceRate() - if params.limit > Params.QuotaConfig.TopKLimit.GetAsInt64() { - params.limit = Params.QuotaConfig.TopKLimit.GetAsInt64() - } } return newDefaultLimitReducer(ctx, params, req, schema, collectionName) } diff --git a/internal/proxy/reducer_test.go b/internal/proxy/reducer_test.go index 29693f50a0d21..668a94ce97e3e 100644 --- a/internal/proxy/reducer_test.go +++ b/internal/proxy/reducer_test.go @@ -1,11 +1,12 @@ package proxy import ( + "context" "testing" - "github.com/milvus-io/milvus/internal/proto/internalpb" - "github.com/milvus-io/milvus/internal/proto/planpb" "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/internal/proto/planpb" ) func Test_createMilvusReducer(t *testing.T) { @@ -17,35 +18,14 @@ func Test_createMilvusReducer(t *testing.T) { }, } var r milvusReducer + ctx := context.Background() - r = createMilvusReducer(nil, nil, nil, nil, n, "") + r = createMilvusReducer(ctx, nil, nil, nil, n, "") _, ok := r.(*defaultLimitReducer) assert.True(t, ok) n.Node.(*planpb.PlanNode_Query).Query.IsCount = true - r = createMilvusReducer(nil, nil, nil, nil, n, "") + r = createMilvusReducer(ctx, nil, nil, nil, n, "") _, ok = r.(*cntReducer) 
assert.True(t, ok) - - req := &internalpb.RetrieveRequest{ - IterationExtensionReduceRate: 100, - } - params := &queryParams{ - limit: 10, - } - r = createMilvusReducer(nil, params, req, nil, nil, "") - defaultReducer, typeOk := r.(*defaultLimitReducer) - assert.True(t, typeOk) - assert.Equal(t, int64(10*100), defaultReducer.params.limit) - - req = &internalpb.RetrieveRequest{ - IterationExtensionReduceRate: 1000, - } - params = &queryParams{ - limit: 100, - } - r = createMilvusReducer(nil, params, req, nil, nil, "") - defaultReducer, typeOk = r.(*defaultLimitReducer) - assert.True(t, typeOk) - assert.Equal(t, int64(16384), defaultReducer.params.limit) } diff --git a/internal/proxy/repack_func.go b/internal/proxy/repack_func.go index 618f6aa0bac22..5b5b9e5ab20d7 100644 --- a/internal/proxy/repack_func.go +++ b/internal/proxy/repack_func.go @@ -27,7 +27,6 @@ func insertRepackFunc( tsMsgs []msgstream.TsMsg, hashKeys [][]int32, ) (map[int32]*msgstream.MsgPack, error) { - if len(hashKeys) < len(tsMsgs) { return nil, fmt.Errorf( "the length of hash keys (%d) is less than the length of messages (%d)", @@ -59,7 +58,6 @@ func defaultInsertRepackFunc( tsMsgs []msgstream.TsMsg, hashKeys [][]int32, ) (map[int32]*msgstream.MsgPack, error) { - if len(hashKeys) < len(tsMsgs) { return nil, fmt.Errorf( "the length of hash keys (%d) is less than the length of messages (%d)", @@ -83,3 +81,14 @@ func defaultInsertRepackFunc( } return pack, nil } + +func replicatePackFunc( + tsMsgs []msgstream.TsMsg, + hashKeys [][]int32, +) (map[int32]*msgstream.MsgPack, error) { + return map[int32]*msgstream.MsgPack{ + 0: { + Msgs: tsMsgs, + }, + }, nil +} diff --git a/internal/proxy/repack_func_test.go b/internal/proxy/repack_func_test.go index e5cf96d64db75..ffc01e4b79047 100644 --- a/internal/proxy/repack_func_test.go +++ b/internal/proxy/repack_func_test.go @@ -20,9 +20,9 @@ import ( "math/rand" "testing" - "github.com/milvus-io/milvus/pkg/mq/msgstream" - "github.com/stretchr/testify/assert" + + 
"github.com/milvus-io/milvus/pkg/mq/msgstream" ) func Test_insertRepackFunc(t *testing.T) { diff --git a/internal/proxy/replicate_stream_manager.go b/internal/proxy/replicate_stream_manager.go new file mode 100644 index 0000000000000..5bf01d1f6e244 --- /dev/null +++ b/internal/proxy/replicate_stream_manager.go @@ -0,0 +1,72 @@ +package proxy + +import ( + "context" + "time" + + "go.uber.org/zap" + + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/resource" +) + +const ( + ReplicateMsgStreamTyp = "replicate_msg_stream" + ReplicateMsgStreamExpireTime = 30 * time.Second +) + +type ReplicateStreamManager struct { + ctx context.Context + factory msgstream.Factory + dispatcher msgstream.UnmarshalDispatcher + resourceManager resource.Manager +} + +func NewReplicateStreamManager(ctx context.Context, factory msgstream.Factory, resourceManager resource.Manager) *ReplicateStreamManager { + manager := &ReplicateStreamManager{ + ctx: ctx, + factory: factory, + dispatcher: (&msgstream.ProtoUDFactory{}).NewUnmarshalDispatcher(), + resourceManager: resourceManager, + } + return manager +} + +func (m *ReplicateStreamManager) newMsgStreamResource(channel string) resource.NewResourceFunc { + return func() (resource.Resource, error) { + msgStream, err := m.factory.NewMsgStream(m.ctx) + if err != nil { + log.Ctx(m.ctx).Warn("failed to create msg stream", zap.String("channel", channel), zap.Error(err)) + return nil, err + } + msgStream.SetRepackFunc(replicatePackFunc) + msgStream.AsProducer([]string{channel}) + msgStream.EnableProduce(true) + + res := resource.NewSimpleResource(msgStream, ReplicateMsgStreamTyp, channel, ReplicateMsgStreamExpireTime, func() { + msgStream.Close() + }) + + return res, nil + } +} + +func (m *ReplicateStreamManager) GetReplicateMsgStream(ctx context.Context, channel string) (msgstream.MsgStream, error) { + ctxLog := 
log.Ctx(ctx).With(zap.String("proxy_channel", channel)) + res, err := m.resourceManager.Get(ReplicateMsgStreamTyp, channel, m.newMsgStreamResource(channel)) + if err != nil { + ctxLog.Warn("failed to get replicate msg stream", zap.String("channel", channel), zap.Error(err)) + return nil, err + } + if obj, ok := res.Get().(msgstream.MsgStream); ok && obj != nil { + return obj, nil + } + ctxLog.Warn("invalid resource object", zap.Any("obj", res.Get())) + return nil, merr.ErrInvalidStreamObj +} + +func (m *ReplicateStreamManager) GetMsgDispatcher() msgstream.UnmarshalDispatcher { + return m.dispatcher +} diff --git a/internal/proxy/replicate_stream_manager_test.go b/internal/proxy/replicate_stream_manager_test.go new file mode 100644 index 0000000000000..f367750c55a10 --- /dev/null +++ b/internal/proxy/replicate_stream_manager_test.go @@ -0,0 +1,79 @@ +package proxy + +import ( + "context" + "testing" + "time" + + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/resource" +) + +func TestReplicateManager(t *testing.T) { + factory := newMockMsgStreamFactory() + resourceManager := resource.NewManager(time.Second, 2*time.Second, nil) + manager := NewReplicateStreamManager(context.Background(), factory, resourceManager) + + { + factory.f = func(ctx context.Context) (msgstream.MsgStream, error) { + return nil, errors.New("mock msgstream fail") + } + _, err := manager.GetReplicateMsgStream(context.Background(), "test") + assert.Error(t, err) + } + { + mockMsgStream := newMockMsgStream() + i := 0 + mockMsgStream.setRepack = func(repackFunc msgstream.RepackFunc) { + i++ + } + mockMsgStream.asProducer = func(producers []string) { + i++ + } + mockMsgStream.enableProduce = func(b bool) { + i++ + } + mockMsgStream.close = func() { + i++ + } + factory.f = func(ctx context.Context) (msgstream.MsgStream, error) { + return 
mockMsgStream, nil + } + _, err := manager.GetReplicateMsgStream(context.Background(), "test") + assert.NoError(t, err) + assert.Equal(t, 3, i) + time.Sleep(time.Second) + _, err = manager.GetReplicateMsgStream(context.Background(), "test") + assert.NoError(t, err) + assert.Equal(t, 3, i) + res := resourceManager.Delete(ReplicateMsgStreamTyp, "test") + assert.NotNil(t, res) + time.Sleep(2 * time.Second) + + _, err = manager.GetReplicateMsgStream(context.Background(), "test") + assert.NoError(t, err) + assert.Equal(t, 7, i) + } + { + res := resourceManager.Delete(ReplicateMsgStreamTyp, "test") + assert.NotNil(t, res) + time.Sleep(2 * time.Second) + + res, err := resourceManager.Get(ReplicateMsgStreamTyp, "test", func() (resource.Resource, error) { + return resource.NewResource(resource.WithObj("str")), nil + }) + assert.NoError(t, err) + assert.Equal(t, "str", res.Get()) + + _, err = manager.GetReplicateMsgStream(context.Background(), "test") + assert.ErrorIs(t, err, merr.ErrInvalidStreamObj) + } + + { + assert.NotNil(t, manager.GetMsgDispatcher()) + } +} diff --git a/internal/proxy/rootcoord_mock_test.go b/internal/proxy/rootcoord_mock_test.go index 44f02449df970..c2a37542ced94 100644 --- a/internal/proxy/rootcoord_mock_test.go +++ b/internal/proxy/rootcoord_mock_test.go @@ -25,6 +25,7 @@ import ( "github.com/cockroachdb/errors" "github.com/golang/protobuf/proto" + "google.golang.org/grpc" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" @@ -36,6 +37,7 @@ import ( "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/milvuserrors" "github.com/milvus-io/milvus/pkg/util/typeutil" "github.com/milvus-io/milvus/pkg/util/uniquegenerator" @@ -103,10 +105,10 @@ type RootCoordMock struct { lastTs typeutil.Timestamp lastTsMtx sync.Mutex - checkHealthFunc 
func(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) + checkHealthFunc func(ctx context.Context, req *milvuspb.CheckHealthRequest, opts ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error) } -func (coord *RootCoordMock) CreateAlias(ctx context.Context, req *milvuspb.CreateAliasRequest) (*commonpb.Status, error) { +func (coord *RootCoordMock) CreateAlias(ctx context.Context, req *milvuspb.CreateAliasRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { code := coord.state.Load().(commonpb.StateCode) if code != commonpb.StateCode_Healthy { return &commonpb.Status{ @@ -134,13 +136,10 @@ func (coord *RootCoordMock) CreateAlias(ctx context.Context, req *milvuspb.Creat } coord.collAlias2ID[req.Alias] = collID - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, nil + return merr.Success(), nil } -func (coord *RootCoordMock) DropAlias(ctx context.Context, req *milvuspb.DropAliasRequest) (*commonpb.Status, error) { +func (coord *RootCoordMock) DropAlias(ctx context.Context, req *milvuspb.DropAliasRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { code := coord.state.Load().(commonpb.StateCode) if code != commonpb.StateCode_Healthy { return &commonpb.Status{ @@ -160,13 +159,10 @@ func (coord *RootCoordMock) DropAlias(ctx context.Context, req *milvuspb.DropAli } delete(coord.collAlias2ID, req.Alias) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, nil + return merr.Success(), nil } -func (coord *RootCoordMock) AlterAlias(ctx context.Context, req *milvuspb.AlterAliasRequest) (*commonpb.Status, error) { +func (coord *RootCoordMock) AlterAlias(ctx context.Context, req *milvuspb.AlterAliasRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { code := coord.state.Load().(commonpb.StateCode) if code != commonpb.StateCode_Healthy { return &commonpb.Status{ @@ -192,10 +188,7 @@ func (coord *RootCoordMock) AlterAlias(ctx 
context.Context, req *milvuspb.AlterA }, nil } coord.collAlias2ID[req.Alias] = collID - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, nil + return merr.Success(), nil } func (coord *RootCoordMock) updateState(state commonpb.StateCode) { @@ -210,24 +203,12 @@ func (coord *RootCoordMock) healthy() bool { return coord.getState() == commonpb.StateCode_Healthy } -func (coord *RootCoordMock) Init() error { - coord.updateState(commonpb.StateCode_Initializing) - return nil -} - -func (coord *RootCoordMock) Start() error { - defer coord.updateState(commonpb.StateCode_Healthy) - - return nil -} - -func (coord *RootCoordMock) Stop() error { +func (coord *RootCoordMock) Close() error { defer coord.updateState(commonpb.StateCode_Abnormal) - return nil } -func (coord *RootCoordMock) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { +func (coord *RootCoordMock) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { return &milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{ NodeID: coord.nodeID, @@ -236,14 +217,11 @@ func (coord *RootCoordMock) GetComponentStates(ctx context.Context) (*milvuspb.C ExtraInfo: nil, }, SubcomponentStates: nil, - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), }, nil } -func (coord *RootCoordMock) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (coord *RootCoordMock) GetStatisticsChannel(ctx context.Context, req *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { code := coord.state.Load().(commonpb.StateCode) if code != commonpb.StateCode_Healthy { return &milvuspb.StringResponse{ @@ -254,11 +232,8 @@ func (coord *RootCoordMock) GetStatisticsChannel(ctx context.Context) (*milvuspb }, nil } return &milvuspb.StringResponse{ - Status: 
&commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, - Value: coord.statisticsChannel, + Status: merr.Success(), + Value: coord.statisticsChannel, }, nil } @@ -266,7 +241,7 @@ func (coord *RootCoordMock) Register() error { return nil } -func (coord *RootCoordMock) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (coord *RootCoordMock) GetTimeTickChannel(ctx context.Context, req *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { code := coord.state.Load().(commonpb.StateCode) if code != commonpb.StateCode_Healthy { return &milvuspb.StringResponse{ @@ -277,15 +252,12 @@ func (coord *RootCoordMock) GetTimeTickChannel(ctx context.Context) (*milvuspb.S }, nil } return &milvuspb.StringResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, - Value: coord.timeTickChannel, + Status: merr.Success(), + Value: coord.timeTickChannel, }, nil } -func (coord *RootCoordMock) CreateCollection(ctx context.Context, req *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) { +func (coord *RootCoordMock) CreateCollection(ctx context.Context, req *milvuspb.CreateCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { code := coord.state.Load().(commonpb.StateCode) if code != commonpb.StateCode_Healthy { return &commonpb.Status{ @@ -384,13 +356,10 @@ func (coord *RootCoordMock) CreateCollection(ctx context.Context, req *milvuspb. 
coord.collID2Partitions[collID].partitionID2Meta[id] = partitionMeta{} } - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, nil + return merr.Success(), nil } -func (coord *RootCoordMock) DropCollection(ctx context.Context, req *milvuspb.DropCollectionRequest) (*commonpb.Status, error) { +func (coord *RootCoordMock) DropCollection(ctx context.Context, req *milvuspb.DropCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { code := coord.state.Load().(commonpb.StateCode) if code != commonpb.StateCode_Healthy { return &commonpb.Status{ @@ -418,13 +387,10 @@ func (coord *RootCoordMock) DropCollection(ctx context.Context, req *milvuspb.Dr delete(coord.collID2Partitions, collID) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, nil + return merr.Success(), nil } -func (coord *RootCoordMock) HasCollection(ctx context.Context, req *milvuspb.HasCollectionRequest) (*milvuspb.BoolResponse, error) { +func (coord *RootCoordMock) HasCollection(ctx context.Context, req *milvuspb.HasCollectionRequest, opts ...grpc.CallOption) (*milvuspb.BoolResponse, error) { code := coord.state.Load().(commonpb.StateCode) if code != commonpb.StateCode_Healthy { return &milvuspb.BoolResponse{ @@ -441,11 +407,8 @@ func (coord *RootCoordMock) HasCollection(ctx context.Context, req *milvuspb.Has _, exist := coord.collName2ID[req.CollectionName] return &milvuspb.BoolResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, - Value: exist, + Status: merr.Success(), + Value: exist, }, nil } @@ -457,7 +420,7 @@ func (coord *RootCoordMock) ResetDescribeCollectionFunc() { coord.describeCollectionFunc = nil } -func (coord *RootCoordMock) DescribeCollection(ctx context.Context, req *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { +func (coord *RootCoordMock) DescribeCollection(ctx context.Context, req *milvuspb.DescribeCollectionRequest, opts 
...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) { code := coord.state.Load().(commonpb.StateCode) if code != commonpb.StateCode_Healthy { return &milvuspb.DescribeCollectionResponse{ @@ -503,10 +466,7 @@ func (coord *RootCoordMock) DescribeCollection(ctx context.Context, req *milvusp } return &milvuspb.DescribeCollectionResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), Schema: meta.schema, CollectionID: collID, ShardsNum: meta.shardsNum, @@ -517,11 +477,11 @@ func (coord *RootCoordMock) DescribeCollection(ctx context.Context, req *milvusp }, nil } -func (coord *RootCoordMock) DescribeCollectionInternal(ctx context.Context, req *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { +func (coord *RootCoordMock) DescribeCollectionInternal(ctx context.Context, req *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) { return coord.DescribeCollection(ctx, req) } -func (coord *RootCoordMock) ShowCollections(ctx context.Context, req *milvuspb.ShowCollectionsRequest) (*milvuspb.ShowCollectionsResponse, error) { +func (coord *RootCoordMock) ShowCollections(ctx context.Context, req *milvuspb.ShowCollectionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowCollectionsResponse, error) { code := coord.state.Load().(commonpb.StateCode) if code != commonpb.StateCode_Healthy { return &milvuspb.ShowCollectionsResponse{ @@ -550,10 +510,7 @@ func (coord *RootCoordMock) ShowCollections(ctx context.Context, req *milvuspb.S } return &milvuspb.ShowCollectionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), CollectionNames: names, CollectionIds: ids, CreatedTimestamps: createdTimestamps, @@ -562,7 +519,7 @@ func (coord *RootCoordMock) ShowCollections(ctx context.Context, req *milvuspb.S }, nil } -func (coord *RootCoordMock) 
CreatePartition(ctx context.Context, req *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) { +func (coord *RootCoordMock) CreatePartition(ctx context.Context, req *milvuspb.CreatePartitionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { code := coord.state.Load().(commonpb.StateCode) if code != commonpb.StateCode_Healthy { return &commonpb.Status{ @@ -602,13 +559,10 @@ func (coord *RootCoordMock) CreatePartition(ctx context.Context, req *milvuspb.C createdUtcTimestamp: ts, } - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, nil + return merr.Success(), nil } -func (coord *RootCoordMock) DropPartition(ctx context.Context, req *milvuspb.DropPartitionRequest) (*commonpb.Status, error) { +func (coord *RootCoordMock) DropPartition(ctx context.Context, req *milvuspb.DropPartitionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { code := coord.state.Load().(commonpb.StateCode) if code != commonpb.StateCode_Healthy { return &commonpb.Status{ @@ -641,13 +595,10 @@ func (coord *RootCoordMock) DropPartition(ctx context.Context, req *milvuspb.Dro delete(coord.collID2Partitions[collID].partitionName2ID, req.PartitionName) delete(coord.collID2Partitions[collID].partitionID2Name, partitionID) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, nil + return merr.Success(), nil } -func (coord *RootCoordMock) HasPartition(ctx context.Context, req *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error) { +func (coord *RootCoordMock) HasPartition(ctx context.Context, req *milvuspb.HasPartitionRequest, opts ...grpc.CallOption) (*milvuspb.BoolResponse, error) { code := coord.state.Load().(commonpb.StateCode) if code != commonpb.StateCode_Healthy { return &milvuspb.BoolResponse{ @@ -677,15 +628,12 @@ func (coord *RootCoordMock) HasPartition(ctx context.Context, req *milvuspb.HasP _, partitionExist := coord.collID2Partitions[collID].partitionName2ID[req.PartitionName] return 
&milvuspb.BoolResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, - Value: partitionExist, + Status: merr.Success(), + Value: partitionExist, }, nil } -func (coord *RootCoordMock) ShowPartitions(ctx context.Context, req *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) { +func (coord *RootCoordMock) ShowPartitions(ctx context.Context, req *milvuspb.ShowPartitionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowPartitionsResponse, error) { code := coord.state.Load().(commonpb.StateCode) if code != commonpb.StateCode_Healthy { return &milvuspb.ShowPartitionsResponse{ @@ -734,10 +682,7 @@ func (coord *RootCoordMock) ShowPartitions(ctx context.Context, req *milvuspb.Sh } return &milvuspb.ShowPartitionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), PartitionNames: names, PartitionIDs: ids, CreatedTimestamps: createdTimestamps, @@ -746,7 +691,7 @@ func (coord *RootCoordMock) ShowPartitions(ctx context.Context, req *milvuspb.Sh }, nil } -func (coord *RootCoordMock) ShowPartitionsInternal(ctx context.Context, req *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) { +func (coord *RootCoordMock) ShowPartitionsInternal(ctx context.Context, req *milvuspb.ShowPartitionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowPartitionsResponse, error) { return coord.ShowPartitions(ctx, req) } @@ -816,7 +761,7 @@ func (coord *RootCoordMock) ShowPartitionsInternal(ctx context.Context, req *mil // }, nil //} -func (coord *RootCoordMock) AllocTimestamp(ctx context.Context, req *rootcoordpb.AllocTimestampRequest) (*rootcoordpb.AllocTimestampResponse, error) { +func (coord *RootCoordMock) AllocTimestamp(ctx context.Context, req *rootcoordpb.AllocTimestampRequest, opts ...grpc.CallOption) (*rootcoordpb.AllocTimestampResponse, error) { code := coord.state.Load().(commonpb.StateCode) if code != commonpb.StateCode_Healthy { 
return &rootcoordpb.AllocTimestampResponse{ @@ -838,16 +783,13 @@ func (coord *RootCoordMock) AllocTimestamp(ctx context.Context, req *rootcoordpb coord.lastTs = ts return &rootcoordpb.AllocTimestampResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), Timestamp: ts, Count: req.Count, }, nil } -func (coord *RootCoordMock) AllocID(ctx context.Context, req *rootcoordpb.AllocIDRequest) (*rootcoordpb.AllocIDResponse, error) { +func (coord *RootCoordMock) AllocID(ctx context.Context, req *rootcoordpb.AllocIDRequest, opts ...grpc.CallOption) (*rootcoordpb.AllocIDResponse, error) { code := coord.state.Load().(commonpb.StateCode) if code != commonpb.StateCode_Healthy { return &rootcoordpb.AllocIDResponse{ @@ -861,16 +803,13 @@ func (coord *RootCoordMock) AllocID(ctx context.Context, req *rootcoordpb.AllocI } begin, _ := uniquegenerator.GetUniqueIntGeneratorIns().GetInts(int(req.Count)) return &rootcoordpb.AllocIDResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, - ID: int64(begin), - Count: req.Count, + Status: merr.Success(), + ID: int64(begin), + Count: req.Count, }, nil } -func (coord *RootCoordMock) UpdateChannelTimeTick(ctx context.Context, req *internalpb.ChannelTimeTickMsg) (*commonpb.Status, error) { +func (coord *RootCoordMock) UpdateChannelTimeTick(ctx context.Context, req *internalpb.ChannelTimeTickMsg, opts ...grpc.CallOption) (*commonpb.Status, error) { code := coord.state.Load().(commonpb.StateCode) if code != commonpb.StateCode_Healthy { return &commonpb.Status{ @@ -878,13 +817,10 @@ func (coord *RootCoordMock) UpdateChannelTimeTick(ctx context.Context, req *inte Reason: fmt.Sprintf("state code = %s", commonpb.StateCode_name[int32(code)]), }, nil } - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, nil + return merr.Success(), nil } -func (coord *RootCoordMock) DescribeSegment(ctx context.Context, req 
*milvuspb.DescribeSegmentRequest) (*milvuspb.DescribeSegmentResponse, error) { +func (coord *RootCoordMock) DescribeSegment(ctx context.Context, req *milvuspb.DescribeSegmentRequest, opts ...grpc.CallOption) (*milvuspb.DescribeSegmentResponse, error) { code := coord.state.Load().(commonpb.StateCode) if code != commonpb.StateCode_Healthy { return &milvuspb.DescribeSegmentResponse{ @@ -896,17 +832,14 @@ func (coord *RootCoordMock) DescribeSegment(ctx context.Context, req *milvuspb.D }, nil } return &milvuspb.DescribeSegmentResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), IndexID: 0, BuildID: 0, EnableIndex: false, }, nil } -func (coord *RootCoordMock) ShowSegments(ctx context.Context, req *milvuspb.ShowSegmentsRequest) (*milvuspb.ShowSegmentsResponse, error) { +func (coord *RootCoordMock) ShowSegments(ctx context.Context, req *milvuspb.ShowSegmentsRequest, opts ...grpc.CallOption) (*milvuspb.ShowSegmentsResponse, error) { code := coord.state.Load().(commonpb.StateCode) if code != commonpb.StateCode_Healthy { return &milvuspb.ShowSegmentsResponse{ @@ -918,19 +851,16 @@ func (coord *RootCoordMock) ShowSegments(ctx context.Context, req *milvuspb.Show }, nil } return &milvuspb.ShowSegmentsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), SegmentIDs: nil, }, nil } -func (coord *RootCoordMock) DescribeSegments(ctx context.Context, req *rootcoordpb.DescribeSegmentsRequest) (*rootcoordpb.DescribeSegmentsResponse, error) { +func (coord *RootCoordMock) DescribeSegments(ctx context.Context, req *rootcoordpb.DescribeSegmentsRequest, opts ...grpc.CallOption) (*rootcoordpb.DescribeSegmentsResponse, error) { panic("implement me") } -func (coord *RootCoordMock) InvalidateCollectionMetaCache(ctx context.Context, in *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { +func (coord *RootCoordMock) 
InvalidateCollectionMetaCache(ctx context.Context, in *proxypb.InvalidateCollMetaCacheRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { code := coord.state.Load().(commonpb.StateCode) if code != commonpb.StateCode_Healthy { return &commonpb.Status{ @@ -938,10 +868,7 @@ func (coord *RootCoordMock) InvalidateCollectionMetaCache(ctx context.Context, i Reason: fmt.Sprintf("state code = %s", commonpb.StateCode_name[int32(code)]), }, nil } - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, nil + return merr.Success(), nil } func (coord *RootCoordMock) SegmentFlushCompleted(ctx context.Context, in *datapb.SegmentFlushCompletedMsg) (*commonpb.Status, error) { @@ -952,13 +879,10 @@ func (coord *RootCoordMock) SegmentFlushCompleted(ctx context.Context, in *datap Reason: fmt.Sprintf("state code = %s", commonpb.StateCode_name[int32(code)]), }, nil } - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, nil + return merr.Success(), nil } -func (coord *RootCoordMock) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { +func (coord *RootCoordMock) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest, opts ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error) { if !coord.healthy() { return &internalpb.ShowConfigurationsResponse{ Status: &commonpb.Status{ @@ -980,7 +904,7 @@ func (coord *RootCoordMock) ShowConfigurations(ctx context.Context, req *interna }, nil } -func (coord *RootCoordMock) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { +func (coord *RootCoordMock) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { if !coord.healthy() { return &milvuspb.GetMetricsResponse{ Status: &commonpb.Status{ @@ -1004,7 +928,7 @@ func (coord 
*RootCoordMock) GetMetrics(ctx context.Context, req *milvuspb.GetMet }, nil } -func (coord *RootCoordMock) Import(ctx context.Context, req *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) { +func (coord *RootCoordMock) Import(ctx context.Context, req *milvuspb.ImportRequest, opts ...grpc.CallOption) (*milvuspb.ImportResponse, error) { code := coord.state.Load().(commonpb.StateCode) if code != commonpb.StateCode_Healthy { return &milvuspb.ImportResponse{ @@ -1016,15 +940,12 @@ func (coord *RootCoordMock) Import(ctx context.Context, req *milvuspb.ImportRequ }, nil } return &milvuspb.ImportResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, - Tasks: make([]int64, 3), + Status: merr.Success(), + Tasks: make([]int64, 3), }, nil } -func (coord *RootCoordMock) GetImportState(ctx context.Context, req *milvuspb.GetImportStateRequest) (*milvuspb.GetImportStateResponse, error) { +func (coord *RootCoordMock) GetImportState(ctx context.Context, req *milvuspb.GetImportStateRequest, opts ...grpc.CallOption) (*milvuspb.GetImportStateResponse, error) { code := coord.state.Load().(commonpb.StateCode) if code != commonpb.StateCode_Healthy { return &milvuspb.GetImportStateResponse{ @@ -1037,16 +958,13 @@ func (coord *RootCoordMock) GetImportState(ctx context.Context, req *milvuspb.Ge }, nil } return &milvuspb.GetImportStateResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), RowCount: 10, IdList: make([]int64, 3), }, nil } -func (coord *RootCoordMock) ListImportTasks(ctx context.Context, in *milvuspb.ListImportTasksRequest) (*milvuspb.ListImportTasksResponse, error) { +func (coord *RootCoordMock) ListImportTasks(ctx context.Context, in *milvuspb.ListImportTasksRequest, opts ...grpc.CallOption) (*milvuspb.ListImportTasksResponse, error) { code := coord.state.Load().(commonpb.StateCode) if code != commonpb.StateCode_Healthy { return 
&milvuspb.ListImportTasksResponse{ @@ -1058,15 +976,12 @@ func (coord *RootCoordMock) ListImportTasks(ctx context.Context, in *milvuspb.Li }, nil } return &milvuspb.ListImportTasksResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, - Tasks: make([]*milvuspb.GetImportStateResponse, 3), + Status: merr.Success(), + Tasks: make([]*milvuspb.GetImportStateResponse, 3), }, nil } -func (coord *RootCoordMock) ReportImport(ctx context.Context, req *rootcoordpb.ImportResult) (*commonpb.Status, error) { +func (coord *RootCoordMock) ReportImport(ctx context.Context, req *rootcoordpb.ImportResult, opts ...grpc.CallOption) (*commonpb.Status, error) { code := coord.state.Load().(commonpb.StateCode) if code != commonpb.StateCode_Healthy { return &commonpb.Status{ @@ -1074,10 +989,7 @@ func (coord *RootCoordMock) ReportImport(ctx context.Context, req *rootcoordpb.I Reason: fmt.Sprintf("state code = %s", commonpb.StateCode_name[int32(code)]), }, nil } - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, nil + return merr.Success(), nil } func NewRootCoordMock(opts ...RootCoordMockOption) *RootCoordMock { @@ -1096,104 +1008,105 @@ func NewRootCoordMock(opts ...RootCoordMockOption) *RootCoordMock { opt(rc) } + rc.updateState(commonpb.StateCode_Healthy) return rc } -func (coord *RootCoordMock) CreateCredential(ctx context.Context, req *internalpb.CredentialInfo) (*commonpb.Status, error) { +func (coord *RootCoordMock) CreateCredential(ctx context.Context, req *internalpb.CredentialInfo, opts ...grpc.CallOption) (*commonpb.Status, error) { return &commonpb.Status{}, nil } -func (coord *RootCoordMock) UpdateCredential(ctx context.Context, req *internalpb.CredentialInfo) (*commonpb.Status, error) { +func (coord *RootCoordMock) UpdateCredential(ctx context.Context, req *internalpb.CredentialInfo, opts ...grpc.CallOption) (*commonpb.Status, error) { return &commonpb.Status{}, nil } -func (coord *RootCoordMock) 
DeleteCredential(ctx context.Context, req *milvuspb.DeleteCredentialRequest) (*commonpb.Status, error) { +func (coord *RootCoordMock) DeleteCredential(ctx context.Context, req *milvuspb.DeleteCredentialRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return &commonpb.Status{}, nil } -func (coord *RootCoordMock) ListCredUsers(ctx context.Context, req *milvuspb.ListCredUsersRequest) (*milvuspb.ListCredUsersResponse, error) { +func (coord *RootCoordMock) ListCredUsers(ctx context.Context, req *milvuspb.ListCredUsersRequest, opts ...grpc.CallOption) (*milvuspb.ListCredUsersResponse, error) { return &milvuspb.ListCredUsersResponse{}, nil } -func (coord *RootCoordMock) GetCredential(ctx context.Context, req *rootcoordpb.GetCredentialRequest) (*rootcoordpb.GetCredentialResponse, error) { +func (coord *RootCoordMock) GetCredential(ctx context.Context, req *rootcoordpb.GetCredentialRequest, opts ...grpc.CallOption) (*rootcoordpb.GetCredentialResponse, error) { return &rootcoordpb.GetCredentialResponse{}, nil } -func (coord *RootCoordMock) CreateRole(ctx context.Context, req *milvuspb.CreateRoleRequest) (*commonpb.Status, error) { +func (coord *RootCoordMock) CreateRole(ctx context.Context, req *milvuspb.CreateRoleRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return &commonpb.Status{}, nil } -func (coord *RootCoordMock) DropRole(ctx context.Context, req *milvuspb.DropRoleRequest) (*commonpb.Status, error) { +func (coord *RootCoordMock) DropRole(ctx context.Context, req *milvuspb.DropRoleRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return &commonpb.Status{}, nil } -func (coord *RootCoordMock) OperateUserRole(ctx context.Context, req *milvuspb.OperateUserRoleRequest) (*commonpb.Status, error) { +func (coord *RootCoordMock) OperateUserRole(ctx context.Context, req *milvuspb.OperateUserRoleRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return &commonpb.Status{}, nil } -func (coord *RootCoordMock) SelectRole(ctx 
context.Context, req *milvuspb.SelectRoleRequest) (*milvuspb.SelectRoleResponse, error) { +func (coord *RootCoordMock) SelectRole(ctx context.Context, req *milvuspb.SelectRoleRequest, opts ...grpc.CallOption) (*milvuspb.SelectRoleResponse, error) { return &milvuspb.SelectRoleResponse{}, nil } -func (coord *RootCoordMock) SelectUser(ctx context.Context, req *milvuspb.SelectUserRequest) (*milvuspb.SelectUserResponse, error) { +func (coord *RootCoordMock) SelectUser(ctx context.Context, req *milvuspb.SelectUserRequest, opts ...grpc.CallOption) (*milvuspb.SelectUserResponse, error) { return &milvuspb.SelectUserResponse{}, nil } -func (coord *RootCoordMock) OperatePrivilege(ctx context.Context, req *milvuspb.OperatePrivilegeRequest) (*commonpb.Status, error) { +func (coord *RootCoordMock) OperatePrivilege(ctx context.Context, req *milvuspb.OperatePrivilegeRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return &commonpb.Status{}, nil } -func (coord *RootCoordMock) SelectGrant(ctx context.Context, req *milvuspb.SelectGrantRequest) (*milvuspb.SelectGrantResponse, error) { +func (coord *RootCoordMock) SelectGrant(ctx context.Context, req *milvuspb.SelectGrantRequest, opts ...grpc.CallOption) (*milvuspb.SelectGrantResponse, error) { return &milvuspb.SelectGrantResponse{}, nil } -func (coord *RootCoordMock) ListPolicy(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) { +func (coord *RootCoordMock) ListPolicy(ctx context.Context, in *internalpb.ListPolicyRequest, opts ...grpc.CallOption) (*internalpb.ListPolicyResponse, error) { return &internalpb.ListPolicyResponse{}, nil } -func (coord *RootCoordMock) AlterCollection(ctx context.Context, request *milvuspb.AlterCollectionRequest) (*commonpb.Status, error) { +func (coord *RootCoordMock) AlterCollection(ctx context.Context, request *milvuspb.AlterCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return &commonpb.Status{}, nil } -func (coord 
*RootCoordMock) CreateDatabase(ctx context.Context, in *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) { +func (coord *RootCoordMock) CreateDatabase(ctx context.Context, in *milvuspb.CreateDatabaseRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return &commonpb.Status{}, nil } -func (coord *RootCoordMock) DropDatabase(ctx context.Context, in *milvuspb.DropDatabaseRequest) (*commonpb.Status, error) { +func (coord *RootCoordMock) DropDatabase(ctx context.Context, in *milvuspb.DropDatabaseRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil } -func (coord *RootCoordMock) ListDatabases(ctx context.Context, in *milvuspb.ListDatabasesRequest) (*milvuspb.ListDatabasesResponse, error) { +func (coord *RootCoordMock) ListDatabases(ctx context.Context, in *milvuspb.ListDatabasesRequest, opts ...grpc.CallOption) (*milvuspb.ListDatabasesResponse, error) { return &milvuspb.ListDatabasesResponse{}, nil } -func (coord *RootCoordMock) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { +func (coord *RootCoordMock) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest, opts ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error) { if coord.checkHealthFunc != nil { return coord.checkHealthFunc(ctx, req) } return &milvuspb.CheckHealthResponse{IsHealthy: true}, nil } -func (coord *RootCoordMock) RenameCollection(ctx context.Context, req *milvuspb.RenameCollectionRequest) (*commonpb.Status, error) { +func (coord *RootCoordMock) RenameCollection(ctx context.Context, req *milvuspb.RenameCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return &commonpb.Status{}, nil } -type DescribeCollectionFunc func(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) +type DescribeCollectionFunc func(ctx context.Context, request 
*milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) -type ShowPartitionsFunc func(ctx context.Context, request *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) +type ShowPartitionsFunc func(ctx context.Context, request *milvuspb.ShowPartitionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowPartitionsResponse, error) -type ShowSegmentsFunc func(ctx context.Context, request *milvuspb.ShowSegmentsRequest) (*milvuspb.ShowSegmentsResponse, error) +type ShowSegmentsFunc func(ctx context.Context, request *milvuspb.ShowSegmentsRequest, opts ...grpc.CallOption) (*milvuspb.ShowSegmentsResponse, error) -type DescribeSegmentsFunc func(ctx context.Context, request *rootcoordpb.DescribeSegmentsRequest) (*rootcoordpb.DescribeSegmentsResponse, error) +type DescribeSegmentsFunc func(ctx context.Context, request *rootcoordpb.DescribeSegmentsRequest, opts ...grpc.CallOption) (*rootcoordpb.DescribeSegmentsResponse, error) -type ImportFunc func(ctx context.Context, req *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) +type ImportFunc func(ctx context.Context, req *milvuspb.ImportRequest, opts ...grpc.CallOption) (*milvuspb.ImportResponse, error) -type DropCollectionFunc func(ctx context.Context, request *milvuspb.DropCollectionRequest) (*commonpb.Status, error) +type DropCollectionFunc func(ctx context.Context, request *milvuspb.DropCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) -type GetGetCredentialFunc func(ctx context.Context, req *rootcoordpb.GetCredentialRequest) (*rootcoordpb.GetCredentialResponse, error) +type GetGetCredentialFunc func(ctx context.Context, req *rootcoordpb.GetCredentialRequest, opts ...grpc.CallOption) (*rootcoordpb.GetCredentialResponse, error) type mockRootCoord struct { - types.RootCoord + types.RootCoordClient DescribeCollectionFunc ShowPartitionsFunc ShowSegmentsFunc @@ -1203,76 +1116,75 @@ type mockRootCoord struct { GetGetCredentialFunc } 
-func (m *mockRootCoord) GetCredential(ctx context.Context, request *rootcoordpb.GetCredentialRequest) (*rootcoordpb.GetCredentialResponse, error) { +func (m *mockRootCoord) GetCredential(ctx context.Context, request *rootcoordpb.GetCredentialRequest, opts ...grpc.CallOption) (*rootcoordpb.GetCredentialResponse, error) { if m.GetGetCredentialFunc != nil { return m.GetGetCredentialFunc(ctx, request) } return nil, errors.New("mock") - } -func (m *mockRootCoord) DescribeCollection(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { +func (m *mockRootCoord) DescribeCollection(ctx context.Context, request *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) { if m.DescribeCollectionFunc != nil { return m.DescribeCollectionFunc(ctx, request) } return nil, errors.New("mock") } -func (m *mockRootCoord) DescribeCollectionInternal(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { +func (m *mockRootCoord) DescribeCollectionInternal(ctx context.Context, request *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) { return m.DescribeCollection(ctx, request) } -func (m *mockRootCoord) ShowPartitions(ctx context.Context, request *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) { +func (m *mockRootCoord) ShowPartitions(ctx context.Context, request *milvuspb.ShowPartitionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowPartitionsResponse, error) { if m.ShowPartitionsFunc != nil { return m.ShowPartitionsFunc(ctx, request) } return nil, errors.New("mock") } -func (m *mockRootCoord) ShowPartitionsInternal(ctx context.Context, request *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) { +func (m *mockRootCoord) ShowPartitionsInternal(ctx context.Context, request *milvuspb.ShowPartitionsRequest, opts 
...grpc.CallOption) (*milvuspb.ShowPartitionsResponse, error) { return m.ShowPartitions(ctx, request) } -func (m *mockRootCoord) ShowSegments(ctx context.Context, request *milvuspb.ShowSegmentsRequest) (*milvuspb.ShowSegmentsResponse, error) { +func (m *mockRootCoord) ShowSegments(ctx context.Context, request *milvuspb.ShowSegmentsRequest, opts ...grpc.CallOption) (*milvuspb.ShowSegmentsResponse, error) { if m.ShowSegmentsFunc != nil { return m.ShowSegmentsFunc(ctx, request) } return nil, errors.New("mock") } -func (m *mockRootCoord) Import(ctx context.Context, request *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) { +func (m *mockRootCoord) Import(ctx context.Context, request *milvuspb.ImportRequest, opts ...grpc.CallOption) (*milvuspb.ImportResponse, error) { if m.ImportFunc != nil { return m.ImportFunc(ctx, request) } return nil, errors.New("mock") } -func (m *mockRootCoord) DropCollection(ctx context.Context, request *milvuspb.DropCollectionRequest) (*commonpb.Status, error) { +func (m *mockRootCoord) DropCollection(ctx context.Context, request *milvuspb.DropCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { if m.DropCollectionFunc != nil { return m.DropCollectionFunc(ctx, request) } return nil, errors.New("mock") } -func (m *mockRootCoord) ListPolicy(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) { +func (m *mockRootCoord) ListPolicy(ctx context.Context, in *internalpb.ListPolicyRequest, opts ...grpc.CallOption) (*internalpb.ListPolicyResponse, error) { return &internalpb.ListPolicyResponse{}, nil } -func (m *mockRootCoord) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { +func (m *mockRootCoord) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest, opts ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error) { return &milvuspb.CheckHealthResponse{ IsHealthy: true, }, nil } -func (m *mockRootCoord) 
CreateDatabase(ctx context.Context, in *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) { +func (m *mockRootCoord) CreateDatabase(ctx context.Context, in *milvuspb.CreateDatabaseRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return &commonpb.Status{}, nil } -func (m *mockRootCoord) DropDatabase(ctx context.Context, in *milvuspb.DropDatabaseRequest) (*commonpb.Status, error) { +func (m *mockRootCoord) DropDatabase(ctx context.Context, in *milvuspb.DropDatabaseRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return &commonpb.Status{}, nil } -func (m *mockRootCoord) ListDatabases(ctx context.Context, in *milvuspb.ListDatabasesRequest) (*milvuspb.ListDatabasesResponse, error) { +func (m *mockRootCoord) ListDatabases(ctx context.Context, in *milvuspb.ListDatabasesRequest, opts ...grpc.CallOption) (*milvuspb.ListDatabasesResponse, error) { return &milvuspb.ListDatabasesResponse{}, nil } diff --git a/internal/proxy/roundrobin_balancer.go b/internal/proxy/roundrobin_balancer.go index cd0f49cbcf093..bd54f0f82ae95 100644 --- a/internal/proxy/roundrobin_balancer.go +++ b/internal/proxy/roundrobin_balancer.go @@ -18,10 +18,11 @@ package proxy import ( "context" + "go.uber.org/atomic" + "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/typeutil" - "go.uber.org/atomic" ) type RoundRobinBalancer struct { diff --git a/internal/proxy/segment.go b/internal/proxy/segment.go index cc2458f6c9ed2..4ee1bc97cfd80 100644 --- a/internal/proxy/segment.go +++ b/internal/proxy/segment.go @@ -25,6 +25,7 @@ import ( "github.com/cockroachdb/errors" "go.uber.org/zap" + "google.golang.org/grpc" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/allocator" @@ -41,7 +42,7 @@ const ( // DataCoord is a narrowed interface of DataCoordinator which only provide AssignSegmentID method type DataCoord interface { - AssignSegmentID(ctx 
context.Context, req *datapb.AssignSegmentIDRequest) (*datapb.AssignSegmentIDResponse, error) + AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest, opts ...grpc.CallOption) (*datapb.AssignSegmentIDResponse, error) } type segRequest struct { @@ -319,13 +320,12 @@ func (sa *segIDAssigner) syncSegments() (bool, error) { log.Debug("syncSegments call dataCoord.AssignSegmentID", zap.String("request", req.String())) resp, err := sa.dataCoord.AssignSegmentID(context.Background(), req) - if err != nil { return false, fmt.Errorf("syncSegmentID Failed:%w", err) } - if resp.Status.ErrorCode != commonpb.ErrorCode_Success { - return false, fmt.Errorf("syncSegmentID Failed:%s", resp.Status.Reason) + if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + return false, fmt.Errorf("syncSegmentID Failed:%s", resp.GetStatus().GetReason()) } var errMsg string @@ -333,8 +333,8 @@ func (sa *segIDAssigner) syncSegments() (bool, error) { success := true for _, segAssign := range resp.SegIDAssignments { if segAssign.Status.GetErrorCode() != commonpb.ErrorCode_Success { - log.Warn("proxy", zap.String("SyncSegment Error", segAssign.Status.Reason)) - errMsg += segAssign.Status.Reason + log.Warn("proxy", zap.String("SyncSegment Error", segAssign.GetStatus().GetReason())) + errMsg += segAssign.GetStatus().GetReason() errMsg += "\n" success = false continue diff --git a/internal/proxy/segment_test.go b/internal/proxy/segment_test.go index eb9a1516892b8..0d7d8c9485818 100644 --- a/internal/proxy/segment_test.go +++ b/internal/proxy/segment_test.go @@ -25,16 +25,18 @@ import ( "time" "github.com/stretchr/testify/assert" + "google.golang.org/grpc" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/pkg/util/merr" ) type mockDataCoord struct { expireTime Timestamp } -func (mockD *mockDataCoord) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest) 
(*datapb.AssignSegmentIDResponse, error) { +func (mockD *mockDataCoord) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest, opts ...grpc.CallOption) (*datapb.AssignSegmentIDResponse, error) { assigns := make([]*datapb.SegmentIDAssignment, 0, len(req.SegmentIDRequests)) maxPerCnt := 100 for _, r := range req.SegmentIDRequests { @@ -53,19 +55,14 @@ func (mockD *mockDataCoord) AssignSegmentID(ctx context.Context, req *datapb.Ass PartitionID: r.PartitionID, ExpireTime: mockD.expireTime, - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), } assigns = append(assigns, result) } } return &datapb.AssignSegmentIDResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), SegIDAssignments: assigns, }, nil } @@ -74,8 +71,7 @@ type mockDataCoord2 struct { expireTime Timestamp } -func (mockD *mockDataCoord2) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest) (*datapb.AssignSegmentIDResponse, error) { - +func (mockD *mockDataCoord2) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest, opts ...grpc.CallOption) (*datapb.AssignSegmentIDResponse, error) { return &datapb.AssignSegmentIDResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -120,7 +116,6 @@ func TestSegmentAllocator1(t *testing.T) { _, err = segAllocator.GetSegmentID(1, 1, "abc", 10, 1001) assert.Error(t, err) wg.Wait() - } var curLastTick2 = Timestamp(200) @@ -160,7 +155,6 @@ func TestSegmentAllocator2(t *testing.T) { _, err = segAllocator.GetSegmentID(1, 1, "abc", segCountPerRPC-10, getLastTick2()) assert.Error(t, err) wg.Wait() - } func TestSegmentAllocator3(t *testing.T) { @@ -188,7 +182,7 @@ type mockDataCoord3 struct { expireTime Timestamp } -func (mockD *mockDataCoord3) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest) (*datapb.AssignSegmentIDResponse, error) { +func (mockD 
*mockDataCoord3) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest, opts ...grpc.CallOption) (*datapb.AssignSegmentIDResponse, error) { assigns := make([]*datapb.SegmentIDAssignment, 0, len(req.SegmentIDRequests)) for i, r := range req.SegmentIDRequests { errCode := commonpb.ErrorCode_Success @@ -214,9 +208,7 @@ func (mockD *mockDataCoord3) AssignSegmentID(ctx context.Context, req *datapb.As } return &datapb.AssignSegmentIDResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), SegIDAssignments: assigns, }, nil } @@ -246,8 +238,7 @@ type mockDataCoord5 struct { expireTime Timestamp } -func (mockD *mockDataCoord5) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest) (*datapb.AssignSegmentIDResponse, error) { - +func (mockD *mockDataCoord5) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest, opts ...grpc.CallOption) (*datapb.AssignSegmentIDResponse, error) { return &datapb.AssignSegmentIDResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -320,5 +311,4 @@ func TestSegmentAllocator6(t *testing.T) { } wg.Wait() assert.True(t, success) - } diff --git a/internal/proxy/shard_client.go b/internal/proxy/shard_client.go index 8ad71bab77672..c250de1d6aab4 100644 --- a/internal/proxy/shard_client.go +++ b/internal/proxy/shard_client.go @@ -6,12 +6,14 @@ import ( "sync" "github.com/cockroachdb/errors" + "go.uber.org/zap" - qnClient "github.com/milvus-io/milvus/internal/distributed/querynode/client" + "github.com/milvus-io/milvus/internal/registry" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/log" ) -type queryNodeCreatorFunc func(ctx context.Context, addr string, nodeID int64) (types.QueryNode, error) +type queryNodeCreatorFunc func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) type nodeInfo struct { nodeID UniqueID @@ -27,12 +29,12 @@ var errClosed = 
errors.New("client is closed") type shardClient struct { sync.RWMutex info nodeInfo - client types.QueryNode + client types.QueryNodeClient isClosed bool refCnt int } -func (n *shardClient) getClient(ctx context.Context) (types.QueryNode, error) { +func (n *shardClient) getClient(ctx context.Context) (types.QueryNodeClient, error) { n.RLock() defer n.RUnlock() if n.isClosed { @@ -54,7 +56,9 @@ func (n *shardClient) close() { n.isClosed = true n.refCnt = 0 if n.client != nil { - n.client.Stop() + if err := n.client.Close(); err != nil { + log.Warn("close grpc client failed", zap.Error(err)) + } n.client = nil } } @@ -80,7 +84,7 @@ func (n *shardClient) Close() { n.close() } -func newShardClient(info *nodeInfo, client types.QueryNode) *shardClient { +func newShardClient(info *nodeInfo, client types.QueryNodeClient) *shardClient { ret := &shardClient{ info: nodeInfo{ nodeID: info.nodeID, @@ -93,7 +97,7 @@ func newShardClient(info *nodeInfo, client types.QueryNode) *shardClient { } type shardClientMgr interface { - GetClient(ctx context.Context, nodeID UniqueID) (types.QueryNode, error) + GetClient(ctx context.Context, nodeID UniqueID) (types.QueryNodeClient, error) UpdateShardLeaders(oldLeaders map[string][]nodeInfo, newLeaders map[string][]nodeInfo) error Close() SetClientCreatorFunc(creator queryNodeCreatorFunc) @@ -114,8 +118,8 @@ func withShardClientCreator(creator queryNodeCreatorFunc) shardClientMgrOpt { return func(s shardClientMgr) { s.SetClientCreatorFunc(creator) } } -func defaultQueryNodeClientCreator(ctx context.Context, addr string, nodeID int64) (types.QueryNode, error) { - return qnClient.NewClient(ctx, addr, nodeID) +func defaultQueryNodeClientCreator(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) { + return registry.GetInMemoryResolver().ResolveQueryNode(ctx, addr, nodeID) } // NewShardClientMgr creates a new shardClientMgr @@ -195,7 +199,7 @@ func (c *shardClientMgrImpl) UpdateShardLeaders(oldLeaders 
map[string][]nodeInfo return nil } -func (c *shardClientMgrImpl) GetClient(ctx context.Context, nodeID UniqueID) (types.QueryNode, error) { +func (c *shardClientMgrImpl) GetClient(ctx context.Context, nodeID UniqueID) (types.QueryNodeClient, error) { c.clients.RLock() client, ok := c.clients.data[nodeID] c.clients.RUnlock() diff --git a/internal/proxy/shard_client_test.go b/internal/proxy/shard_client_test.go index da21ed512c626..0ef6f516caf2d 100644 --- a/internal/proxy/shard_client_test.go +++ b/internal/proxy/shard_client_test.go @@ -4,9 +4,10 @@ import ( "context" "testing" - "github.com/milvus-io/milvus/internal/types" - "github.com/milvus-io/milvus/internal/util/mock" "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/types" ) func genShardLeaderInfo(channel string, leaderIDs []UniqueID) map[string][]nodeInfo { @@ -31,8 +32,8 @@ func TestShardClientMgr_UpdateShardLeaders_CreatorNil(t *testing.T) { } func TestShardClientMgr_UpdateShardLeaders_Empty(t *testing.T) { - mockCreator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNode, error) { - return &mock.QueryNodeClient{}, nil + mockCreator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) { + return &mocks.MockQueryNodeClient{}, nil } mgr := newShardClientMgr(withShardClientCreator(mockCreator)) diff --git a/internal/proxy/task.go b/internal/proxy/task.go index de243d236d483..54934cd89074d 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -42,16 +42,16 @@ import ( ) const ( - IgnoreGrowingKey = "ignore_growing" - IterationExtensionReduceRateKey = "iteration_extension_reduce_rate" - AnnsFieldKey = "anns_field" - TopKKey = "topk" - NQKey = "nq" - MetricTypeKey = common.MetricTypeKey - SearchParamsKey = "params" - RoundDecimalKey = "round_decimal" - OffsetKey = "offset" - LimitKey = "limit" + IgnoreGrowingKey = "ignore_growing" + ReduceStopForBestKey = 
"reduce_stop_for_best" + AnnsFieldKey = "anns_field" + TopKKey = "topk" + NQKey = "nq" + MetricTypeKey = common.MetricTypeKey + SearchParamsKey = "params" + RoundDecimalKey = "round_decimal" + OffsetKey = "offset" + LimitKey = "limit" InsertTaskName = "InsertTask" CreateCollectionTaskName = "CreateCollectionTask" @@ -118,56 +118,58 @@ type createCollectionTask struct { Condition *milvuspb.CreateCollectionRequest ctx context.Context - rootCoord types.RootCoord + rootCoord types.RootCoordClient result *commonpb.Status schema *schemapb.CollectionSchema } -func (cct *createCollectionTask) TraceCtx() context.Context { - return cct.ctx +func (t *createCollectionTask) TraceCtx() context.Context { + return t.ctx } -func (cct *createCollectionTask) ID() UniqueID { - return cct.Base.MsgID +func (t *createCollectionTask) ID() UniqueID { + return t.Base.MsgID } -func (cct *createCollectionTask) SetID(uid UniqueID) { - cct.Base.MsgID = uid +func (t *createCollectionTask) SetID(uid UniqueID) { + t.Base.MsgID = uid } -func (cct *createCollectionTask) Name() string { +func (t *createCollectionTask) Name() string { return CreateCollectionTaskName } -func (cct *createCollectionTask) Type() commonpb.MsgType { - return cct.Base.MsgType +func (t *createCollectionTask) Type() commonpb.MsgType { + return t.Base.MsgType } -func (cct *createCollectionTask) BeginTs() Timestamp { - return cct.Base.Timestamp +func (t *createCollectionTask) BeginTs() Timestamp { + return t.Base.Timestamp } -func (cct *createCollectionTask) EndTs() Timestamp { - return cct.Base.Timestamp +func (t *createCollectionTask) EndTs() Timestamp { + return t.Base.Timestamp } -func (cct *createCollectionTask) SetTs(ts Timestamp) { - cct.Base.Timestamp = ts +func (t *createCollectionTask) SetTs(ts Timestamp) { + t.Base.Timestamp = ts } -func (cct *createCollectionTask) OnEnqueue() error { - cct.Base = commonpbutil.NewMsgBase() - cct.Base.MsgType = commonpb.MsgType_CreateCollection - cct.Base.SourceID = 
paramtable.GetNodeID() +func (t *createCollectionTask) OnEnqueue() error { + if t.Base == nil { + t.Base = commonpbutil.NewMsgBase() + } + t.Base.MsgType = commonpb.MsgType_CreateCollection + t.Base.SourceID = paramtable.GetNodeID() return nil } -func (cct *createCollectionTask) validatePartitionKey() error { +func (t *createCollectionTask) validatePartitionKey() error { idx := -1 - for i, field := range cct.schema.Fields { + for i, field := range t.schema.Fields { if field.GetIsPartitionKey() { if idx != -1 { - return fmt.Errorf("there are more than one partition key, field name = %s, %s", cct.schema.Fields[idx].Name, field.Name) + return fmt.Errorf("there are more than one partition key, field name = %s, %s", t.schema.Fields[idx].Name, field.Name) } if field.GetIsPrimaryKey() { @@ -179,13 +181,13 @@ func (cct *createCollectionTask) validatePartitionKey() error { return errors.New("the data type of partition key should be Int64 or VarChar") } - if cct.GetNumPartitions() < 0 { + if t.GetNumPartitions() < 0 { return errors.New("the specified partitions should be greater than 0 if partition key is used") } // set default physical partitions num if enable partition key mode - if cct.GetNumPartitions() == 0 { - cct.NumPartitions = common.DefaultPartitionsWithPartitionKey + if t.GetNumPartitions() == 0 { + t.NumPartitions = common.DefaultPartitionsWithPartitionKey } idx = i @@ -193,79 +195,79 @@ func (cct *createCollectionTask) validatePartitionKey() error { } if idx == -1 { - if cct.GetNumPartitions() != 0 { + if t.GetNumPartitions() != 0 { return fmt.Errorf("num_partitions should only be specified with partition key field enabled") } } else { log.Info("create collection with partition key mode", - zap.String("collectionName", cct.CollectionName), - zap.Int64("numDefaultPartitions", cct.GetNumPartitions())) + zap.String("collectionName", t.CollectionName), + zap.Int64("numDefaultPartitions", t.GetNumPartitions())) } return nil } -func (cct *createCollectionTask) 
PreExecute(ctx context.Context) error { - cct.Base.MsgType = commonpb.MsgType_CreateCollection - cct.Base.SourceID = paramtable.GetNodeID() +func (t *createCollectionTask) PreExecute(ctx context.Context) error { + t.Base.MsgType = commonpb.MsgType_CreateCollection + t.Base.SourceID = paramtable.GetNodeID() - cct.schema = &schemapb.CollectionSchema{} - err := proto.Unmarshal(cct.Schema, cct.schema) + t.schema = &schemapb.CollectionSchema{} + err := proto.Unmarshal(t.Schema, t.schema) if err != nil { return err } - cct.schema.AutoID = false + t.schema.AutoID = false - if cct.ShardsNum > Params.ProxyCfg.MaxShardNum.GetAsInt32() { + if t.ShardsNum > Params.ProxyCfg.MaxShardNum.GetAsInt32() { return fmt.Errorf("maximum shards's number should be limited to %d", Params.ProxyCfg.MaxShardNum.GetAsInt()) } - if len(cct.schema.Fields) > Params.ProxyCfg.MaxFieldNum.GetAsInt() { + if len(t.schema.Fields) > Params.ProxyCfg.MaxFieldNum.GetAsInt() { return fmt.Errorf("maximum field's number should be limited to %d", Params.ProxyCfg.MaxFieldNum.GetAsInt()) } // validate collection name - if err := validateCollectionName(cct.schema.Name); err != nil { + if err := validateCollectionName(t.schema.Name); err != nil { return err } // validate whether field names duplicates - if err := validateDuplicatedFieldName(cct.schema.Fields); err != nil { + if err := validateDuplicatedFieldName(t.schema.Fields); err != nil { return err } // validate primary key definition - if err := validatePrimaryKey(cct.schema); err != nil { + if err := validatePrimaryKey(t.schema); err != nil { return err } // validate dynamic field - if err := validateDynamicField(cct.schema); err != nil { + if err := validateDynamicField(t.schema); err != nil { return err } // validate auto id definition - if err := ValidateFieldAutoID(cct.schema); err != nil { + if err := ValidateFieldAutoID(t.schema); err != nil { return err } // validate field type definition - if err := validateFieldType(cct.schema); err != nil { + if 
err := validateFieldType(t.schema); err != nil { return err } // validate partition key mode - if err := cct.validatePartitionKey(); err != nil { + if err := t.validatePartitionKey(); err != nil { return err } - for _, field := range cct.schema.Fields { + for _, field := range t.schema.Fields { // validate field name if err := validateFieldName(field.Name); err != nil { return err } // validate vector field type parameters - if field.DataType == schemapb.DataType_FloatVector || field.DataType == schemapb.DataType_BinaryVector { + if isVectorType(field.DataType) { err = validateDimension(field) if err != nil { return err @@ -273,19 +275,27 @@ func (cct *createCollectionTask) PreExecute(ctx context.Context) error { } // valid max length per row parameters // if max_length not specified, return error - if field.DataType == schemapb.DataType_VarChar { - err = validateMaxLengthPerRow(cct.schema.Name, field) + if field.DataType == schemapb.DataType_VarChar || + (field.GetDataType() == schemapb.DataType_Array && field.GetElementType() == schemapb.DataType_VarChar) { + err = validateMaxLengthPerRow(t.schema.Name, field) if err != nil { return err } } + // valid max capacity for array per row parameters + // if max_capacity not specified, return error + if field.DataType == schemapb.DataType_Array { + if err = validateMaxCapacityPerRow(t.schema.Name, field); err != nil { + return err + } + } } - if err := validateMultipleVectorFields(cct.schema); err != nil { + if err := validateMultipleVectorFields(t.schema); err != nil { return err } - cct.CreateCollectionRequest.Schema, err = proto.Marshal(cct.schema) + t.CreateCollectionRequest.Schema, err = proto.Marshal(t.schema) if err != nil { return err } @@ -293,13 +303,13 @@ func (cct *createCollectionTask) PreExecute(ctx context.Context) error { return nil } -func (cct *createCollectionTask) Execute(ctx context.Context) error { +func (t *createCollectionTask) Execute(ctx context.Context) error { var err error - cct.result, err = 
cct.rootCoord.CreateCollection(ctx, cct.CreateCollectionRequest) + t.result, err = t.rootCoord.CreateCollection(ctx, t.CreateCollectionRequest) return err } -func (cct *createCollectionTask) PostExecute(ctx context.Context) error { +func (t *createCollectionTask) PostExecute(ctx context.Context) error { return nil } @@ -307,66 +317,68 @@ type dropCollectionTask struct { Condition *milvuspb.DropCollectionRequest ctx context.Context - rootCoord types.RootCoord + rootCoord types.RootCoordClient result *commonpb.Status chMgr channelsMgr chTicker channelsTimeTicker } -func (dct *dropCollectionTask) TraceCtx() context.Context { - return dct.ctx +func (t *dropCollectionTask) TraceCtx() context.Context { + return t.ctx } -func (dct *dropCollectionTask) ID() UniqueID { - return dct.Base.MsgID +func (t *dropCollectionTask) ID() UniqueID { + return t.Base.MsgID } -func (dct *dropCollectionTask) SetID(uid UniqueID) { - dct.Base.MsgID = uid +func (t *dropCollectionTask) SetID(uid UniqueID) { + t.Base.MsgID = uid } -func (dct *dropCollectionTask) Name() string { +func (t *dropCollectionTask) Name() string { return DropCollectionTaskName } -func (dct *dropCollectionTask) Type() commonpb.MsgType { - return dct.Base.MsgType +func (t *dropCollectionTask) Type() commonpb.MsgType { + return t.Base.MsgType } -func (dct *dropCollectionTask) BeginTs() Timestamp { - return dct.Base.Timestamp +func (t *dropCollectionTask) BeginTs() Timestamp { + return t.Base.Timestamp } -func (dct *dropCollectionTask) EndTs() Timestamp { - return dct.Base.Timestamp +func (t *dropCollectionTask) EndTs() Timestamp { + return t.Base.Timestamp } -func (dct *dropCollectionTask) SetTs(ts Timestamp) { - dct.Base.Timestamp = ts +func (t *dropCollectionTask) SetTs(ts Timestamp) { + t.Base.Timestamp = ts } -func (dct *dropCollectionTask) OnEnqueue() error { - dct.Base = commonpbutil.NewMsgBase() +func (t *dropCollectionTask) OnEnqueue() error { + if t.Base == nil { + t.Base = commonpbutil.NewMsgBase() + } return 
nil } -func (dct *dropCollectionTask) PreExecute(ctx context.Context) error { - dct.Base.MsgType = commonpb.MsgType_DropCollection - dct.Base.SourceID = paramtable.GetNodeID() +func (t *dropCollectionTask) PreExecute(ctx context.Context) error { + t.Base.MsgType = commonpb.MsgType_DropCollection + t.Base.SourceID = paramtable.GetNodeID() - if err := validateCollectionName(dct.CollectionName); err != nil { + if err := validateCollectionName(t.CollectionName); err != nil { return err } return nil } -func (dct *dropCollectionTask) Execute(ctx context.Context) error { +func (t *dropCollectionTask) Execute(ctx context.Context) error { var err error - dct.result, err = dct.rootCoord.DropCollection(ctx, dct.DropCollectionRequest) + t.result, err = t.rootCoord.DropCollection(ctx, t.DropCollectionRequest) return err } -func (dct *dropCollectionTask) PostExecute(ctx context.Context) error { +func (t *dropCollectionTask) PostExecute(ctx context.Context) error { return nil } @@ -374,73 +386,73 @@ type hasCollectionTask struct { Condition *milvuspb.HasCollectionRequest ctx context.Context - rootCoord types.RootCoord + rootCoord types.RootCoordClient result *milvuspb.BoolResponse } -func (hct *hasCollectionTask) TraceCtx() context.Context { - return hct.ctx +func (t *hasCollectionTask) TraceCtx() context.Context { + return t.ctx } -func (hct *hasCollectionTask) ID() UniqueID { - return hct.Base.MsgID +func (t *hasCollectionTask) ID() UniqueID { + return t.Base.MsgID } -func (hct *hasCollectionTask) SetID(uid UniqueID) { - hct.Base.MsgID = uid +func (t *hasCollectionTask) SetID(uid UniqueID) { + t.Base.MsgID = uid } -func (hct *hasCollectionTask) Name() string { +func (t *hasCollectionTask) Name() string { return HasCollectionTaskName } -func (hct *hasCollectionTask) Type() commonpb.MsgType { - return hct.Base.MsgType +func (t *hasCollectionTask) Type() commonpb.MsgType { + return t.Base.MsgType } -func (hct *hasCollectionTask) BeginTs() Timestamp { - return hct.Base.Timestamp 
+func (t *hasCollectionTask) BeginTs() Timestamp { + return t.Base.Timestamp } -func (hct *hasCollectionTask) EndTs() Timestamp { - return hct.Base.Timestamp +func (t *hasCollectionTask) EndTs() Timestamp { + return t.Base.Timestamp } -func (hct *hasCollectionTask) SetTs(ts Timestamp) { - hct.Base.Timestamp = ts +func (t *hasCollectionTask) SetTs(ts Timestamp) { + t.Base.Timestamp = ts } -func (hct *hasCollectionTask) OnEnqueue() error { - hct.Base = commonpbutil.NewMsgBase() +func (t *hasCollectionTask) OnEnqueue() error { + t.Base = commonpbutil.NewMsgBase() return nil } -func (hct *hasCollectionTask) PreExecute(ctx context.Context) error { - hct.Base.MsgType = commonpb.MsgType_HasCollection - hct.Base.SourceID = paramtable.GetNodeID() +func (t *hasCollectionTask) PreExecute(ctx context.Context) error { + t.Base.MsgType = commonpb.MsgType_HasCollection + t.Base.SourceID = paramtable.GetNodeID() - if err := validateCollectionName(hct.CollectionName); err != nil { + if err := validateCollectionName(t.CollectionName); err != nil { return err } return nil } -func (hct *hasCollectionTask) Execute(ctx context.Context) error { +func (t *hasCollectionTask) Execute(ctx context.Context) error { var err error - hct.result, err = hct.rootCoord.HasCollection(ctx, hct.HasCollectionRequest) + t.result, err = t.rootCoord.HasCollection(ctx, t.HasCollectionRequest) if err != nil { return err } - if hct.result == nil { + if t.result == nil { return errors.New("has collection resp is nil") } - if hct.result.Status.ErrorCode != commonpb.ErrorCode_Success { - return errors.New(hct.result.Status.Reason) + if t.result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + return merr.Error(t.result.GetStatus()) } return nil } -func (hct *hasCollectionTask) PostExecute(ctx context.Context) error { +func (t *hasCollectionTask) PostExecute(ctx context.Context) error { return nil } @@ -448,62 +460,62 @@ type describeCollectionTask struct { Condition *milvuspb.DescribeCollectionRequest 
ctx context.Context - rootCoord types.RootCoord + rootCoord types.RootCoordClient result *milvuspb.DescribeCollectionResponse } -func (dct *describeCollectionTask) TraceCtx() context.Context { - return dct.ctx +func (t *describeCollectionTask) TraceCtx() context.Context { + return t.ctx } -func (dct *describeCollectionTask) ID() UniqueID { - return dct.Base.MsgID +func (t *describeCollectionTask) ID() UniqueID { + return t.Base.MsgID } -func (dct *describeCollectionTask) SetID(uid UniqueID) { - dct.Base.MsgID = uid +func (t *describeCollectionTask) SetID(uid UniqueID) { + t.Base.MsgID = uid } -func (dct *describeCollectionTask) Name() string { +func (t *describeCollectionTask) Name() string { return DescribeCollectionTaskName } -func (dct *describeCollectionTask) Type() commonpb.MsgType { - return dct.Base.MsgType +func (t *describeCollectionTask) Type() commonpb.MsgType { + return t.Base.MsgType } -func (dct *describeCollectionTask) BeginTs() Timestamp { - return dct.Base.Timestamp +func (t *describeCollectionTask) BeginTs() Timestamp { + return t.Base.Timestamp } -func (dct *describeCollectionTask) EndTs() Timestamp { - return dct.Base.Timestamp +func (t *describeCollectionTask) EndTs() Timestamp { + return t.Base.Timestamp } -func (dct *describeCollectionTask) SetTs(ts Timestamp) { - dct.Base.Timestamp = ts +func (t *describeCollectionTask) SetTs(ts Timestamp) { + t.Base.Timestamp = ts } -func (dct *describeCollectionTask) OnEnqueue() error { - dct.Base = commonpbutil.NewMsgBase() +func (t *describeCollectionTask) OnEnqueue() error { + t.Base = commonpbutil.NewMsgBase() return nil } -func (dct *describeCollectionTask) PreExecute(ctx context.Context) error { - dct.Base.MsgType = commonpb.MsgType_DescribeCollection - dct.Base.SourceID = paramtable.GetNodeID() +func (t *describeCollectionTask) PreExecute(ctx context.Context) error { + t.Base.MsgType = commonpb.MsgType_DescribeCollection + t.Base.SourceID = paramtable.GetNodeID() - if dct.CollectionID != 0 && 
len(dct.CollectionName) == 0 { + if t.CollectionID != 0 && len(t.CollectionName) == 0 { return nil } - return validateCollectionName(dct.CollectionName) + return validateCollectionName(t.CollectionName) } -func (dct *describeCollectionTask) Execute(ctx context.Context) error { +func (t *describeCollectionTask) Execute(ctx context.Context) error { var err error - dct.result = &milvuspb.DescribeCollectionResponse{ - Status: merr.Status(nil), + t.result = &milvuspb.DescribeCollectionResponse{ + Status: merr.Success(), Schema: &schemapb.CollectionSchema{ Name: "", Description: "", @@ -513,46 +525,48 @@ func (dct *describeCollectionTask) Execute(ctx context.Context) error { CollectionID: 0, VirtualChannelNames: nil, PhysicalChannelNames: nil, - CollectionName: dct.GetCollectionName(), - DbName: dct.GetDbName(), + CollectionName: t.GetCollectionName(), + DbName: t.GetDbName(), } - result, err := dct.rootCoord.DescribeCollection(ctx, dct.DescribeCollectionRequest) + result, err := t.rootCoord.DescribeCollection(ctx, t.DescribeCollectionRequest) if err != nil { return err } - if result.Status.ErrorCode != commonpb.ErrorCode_Success { - dct.result.Status = result.Status + if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + t.result.Status = result.Status // compatibility with PyMilvus existing implementation - err := merr.Error(dct.result.GetStatus()) + err := merr.Error(t.result.GetStatus()) if errors.Is(err, merr.ErrCollectionNotFound) { - dct.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError - dct.result.Status.Reason = "can't find collection " + dct.result.Status.Reason + // nolint + t.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError + // nolint + t.result.Status.Reason = "can't find collection " + t.result.GetStatus().GetReason() } } else { - dct.result.Schema.Name = result.Schema.Name - dct.result.Schema.Description = result.Schema.Description - dct.result.Schema.AutoID = result.Schema.AutoID - 
dct.result.Schema.EnableDynamicField = result.Schema.EnableDynamicField - dct.result.CollectionID = result.CollectionID - dct.result.VirtualChannelNames = result.VirtualChannelNames - dct.result.PhysicalChannelNames = result.PhysicalChannelNames - dct.result.CreatedTimestamp = result.CreatedTimestamp - dct.result.CreatedUtcTimestamp = result.CreatedUtcTimestamp - dct.result.ShardsNum = result.ShardsNum - dct.result.ConsistencyLevel = result.ConsistencyLevel - dct.result.Aliases = result.Aliases - dct.result.Properties = result.Properties - dct.result.DbName = result.GetDbName() - dct.result.NumPartitions = result.NumPartitions + t.result.Schema.Name = result.Schema.Name + t.result.Schema.Description = result.Schema.Description + t.result.Schema.AutoID = result.Schema.AutoID + t.result.Schema.EnableDynamicField = result.Schema.EnableDynamicField + t.result.CollectionID = result.CollectionID + t.result.VirtualChannelNames = result.VirtualChannelNames + t.result.PhysicalChannelNames = result.PhysicalChannelNames + t.result.CreatedTimestamp = result.CreatedTimestamp + t.result.CreatedUtcTimestamp = result.CreatedUtcTimestamp + t.result.ShardsNum = result.ShardsNum + t.result.ConsistencyLevel = result.ConsistencyLevel + t.result.Aliases = result.Aliases + t.result.Properties = result.Properties + t.result.DbName = result.GetDbName() + t.result.NumPartitions = result.NumPartitions for _, field := range result.Schema.Fields { if field.IsDynamic { continue } if field.FieldID >= common.StartOfUserFieldID { - dct.result.Schema.Fields = append(dct.result.Schema.Fields, &schemapb.FieldSchema{ + t.result.Schema.Fields = append(t.result.Schema.Fields, &schemapb.FieldSchema{ FieldID: field.FieldID, Name: field.Name, IsPrimaryKey: field.IsPrimaryKey, @@ -564,6 +578,7 @@ func (dct *describeCollectionTask) Execute(ctx context.Context) error { IsDynamic: field.IsDynamic, IsPartitionKey: field.IsPartitionKey, DefaultValue: field.DefaultValue, + ElementType: field.ElementType, }) } } 
@@ -571,7 +586,7 @@ func (dct *describeCollectionTask) Execute(ctx context.Context) error { return nil } -func (dct *describeCollectionTask) PostExecute(ctx context.Context) error { +func (t *describeCollectionTask) PostExecute(ctx context.Context) error { return nil } @@ -579,53 +594,53 @@ type showCollectionsTask struct { Condition *milvuspb.ShowCollectionsRequest ctx context.Context - rootCoord types.RootCoord - queryCoord types.QueryCoord + rootCoord types.RootCoordClient + queryCoord types.QueryCoordClient result *milvuspb.ShowCollectionsResponse } -func (sct *showCollectionsTask) TraceCtx() context.Context { - return sct.ctx +func (t *showCollectionsTask) TraceCtx() context.Context { + return t.ctx } -func (sct *showCollectionsTask) ID() UniqueID { - return sct.Base.MsgID +func (t *showCollectionsTask) ID() UniqueID { + return t.Base.MsgID } -func (sct *showCollectionsTask) SetID(uid UniqueID) { - sct.Base.MsgID = uid +func (t *showCollectionsTask) SetID(uid UniqueID) { + t.Base.MsgID = uid } -func (sct *showCollectionsTask) Name() string { +func (t *showCollectionsTask) Name() string { return ShowCollectionTaskName } -func (sct *showCollectionsTask) Type() commonpb.MsgType { - return sct.Base.MsgType +func (t *showCollectionsTask) Type() commonpb.MsgType { + return t.Base.MsgType } -func (sct *showCollectionsTask) BeginTs() Timestamp { - return sct.Base.Timestamp +func (t *showCollectionsTask) BeginTs() Timestamp { + return t.Base.Timestamp } -func (sct *showCollectionsTask) EndTs() Timestamp { - return sct.Base.Timestamp +func (t *showCollectionsTask) EndTs() Timestamp { + return t.Base.Timestamp } -func (sct *showCollectionsTask) SetTs(ts Timestamp) { - sct.Base.Timestamp = ts +func (t *showCollectionsTask) SetTs(ts Timestamp) { + t.Base.Timestamp = ts } -func (sct *showCollectionsTask) OnEnqueue() error { - sct.Base = commonpbutil.NewMsgBase() +func (t *showCollectionsTask) OnEnqueue() error { + t.Base = commonpbutil.NewMsgBase() return nil } -func (sct 
*showCollectionsTask) PreExecute(ctx context.Context) error { - sct.Base.MsgType = commonpb.MsgType_ShowCollections - sct.Base.SourceID = paramtable.GetNodeID() - if sct.GetType() == milvuspb.ShowType_InMemory { - for _, collectionName := range sct.CollectionNames { +func (t *showCollectionsTask) PreExecute(ctx context.Context) error { + t.Base.MsgType = commonpb.MsgType_ShowCollections + t.Base.SourceID = paramtable.GetNodeID() + if t.GetType() == milvuspb.ShowType_InMemory { + for _, collectionName := range t.CollectionNames { if err := validateCollectionName(collectionName); err != nil { return err } @@ -635,8 +650,8 @@ func (sct *showCollectionsTask) PreExecute(ctx context.Context) error { return nil } -func (sct *showCollectionsTask) Execute(ctx context.Context) error { - respFromRootCoord, err := sct.rootCoord.ShowCollections(ctx, sct.ShowCollectionsRequest) +func (t *showCollectionsTask) Execute(ctx context.Context) error { + respFromRootCoord, err := t.rootCoord.ShowCollections(ctx, t.ShowCollectionsRequest) if err != nil { return err } @@ -645,34 +660,34 @@ func (sct *showCollectionsTask) Execute(ctx context.Context) error { return errors.New("failed to show collections") } - if respFromRootCoord.Status.ErrorCode != commonpb.ErrorCode_Success { - return errors.New(respFromRootCoord.Status.Reason) + if respFromRootCoord.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + return merr.Error(respFromRootCoord.GetStatus()) } - if sct.GetType() == milvuspb.ShowType_InMemory { + if t.GetType() == milvuspb.ShowType_InMemory { IDs2Names := make(map[UniqueID]string) for offset, collectionName := range respFromRootCoord.CollectionNames { collectionID := respFromRootCoord.CollectionIds[offset] IDs2Names[collectionID] = collectionName } collectionIDs := make([]UniqueID, 0) - for _, collectionName := range sct.CollectionNames { - collectionID, err := globalMetaCache.GetCollectionID(ctx, sct.GetDbName(), collectionName) + for _, collectionName := range 
t.CollectionNames { + collectionID, err := globalMetaCache.GetCollectionID(ctx, t.GetDbName(), collectionName) if err != nil { log.Debug("Failed to get collection id.", zap.Any("collectionName", collectionName), - zap.Any("requestID", sct.Base.MsgID), zap.Any("requestType", "showCollections")) + zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "showCollections")) return err } collectionIDs = append(collectionIDs, collectionID) IDs2Names[collectionID] = collectionName } - resp, err := sct.queryCoord.ShowCollections(ctx, &querypb.ShowCollectionsRequest{ + resp, err := t.queryCoord.ShowCollections(ctx, &querypb.ShowCollectionsRequest{ Base: commonpbutil.UpdateMsgBase( - sct.Base, + t.Base, commonpbutil.WithMsgType(commonpb.MsgType_ShowCollections), ), - // DbID: sct.ShowCollectionsRequest.DbName, + // DbID: t.ShowCollectionsRequest.DbName, CollectionIDs: collectionIDs, }) if err != nil { @@ -683,16 +698,16 @@ func (sct *showCollectionsTask) Execute(ctx context.Context) error { return errors.New("failed to show collections") } - if resp.Status.ErrorCode != commonpb.ErrorCode_Success { + if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { // update collectionID to collection name, and return new error info to sdk - newErrorReason := resp.Status.Reason + newErrorReason := resp.GetStatus().GetReason() for _, collectionID := range collectionIDs { newErrorReason = ReplaceID2Name(newErrorReason, collectionID, IDs2Names[collectionID]) } return errors.New(newErrorReason) } - sct.result = &milvuspb.ShowCollectionsResponse{ + t.result = &milvuspb.ShowCollectionsResponse{ Status: resp.Status, CollectionNames: make([]string, 0, len(resp.CollectionIDs)), CollectionIds: make([]int64, 0, len(resp.CollectionIDs)), @@ -707,30 +722,30 @@ func (sct *showCollectionsTask) Execute(ctx context.Context) error { if !ok { log.Debug("Failed to get collection info. 
This collection may be not released", zap.Any("collectionID", id), - zap.Any("requestID", sct.Base.MsgID), zap.Any("requestType", "showCollections")) + zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "showCollections")) continue } - collectionInfo, err := globalMetaCache.GetCollectionInfo(ctx, sct.GetDbName(), collectionName, id) + collectionInfo, err := globalMetaCache.GetCollectionInfo(ctx, t.GetDbName(), collectionName, id) if err != nil { log.Debug("Failed to get collection info.", zap.Any("collectionName", collectionName), - zap.Any("requestID", sct.Base.MsgID), zap.Any("requestType", "showCollections")) + zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "showCollections")) return err } - sct.result.CollectionIds = append(sct.result.CollectionIds, id) - sct.result.CollectionNames = append(sct.result.CollectionNames, collectionName) - sct.result.CreatedTimestamps = append(sct.result.CreatedTimestamps, collectionInfo.createdTimestamp) - sct.result.CreatedUtcTimestamps = append(sct.result.CreatedUtcTimestamps, collectionInfo.createdUtcTimestamp) - sct.result.InMemoryPercentages = append(sct.result.InMemoryPercentages, resp.InMemoryPercentages[offset]) - sct.result.QueryServiceAvailable = append(sct.result.QueryServiceAvailable, resp.QueryServiceAvailable[offset]) + t.result.CollectionIds = append(t.result.CollectionIds, id) + t.result.CollectionNames = append(t.result.CollectionNames, collectionName) + t.result.CreatedTimestamps = append(t.result.CreatedTimestamps, collectionInfo.createdTimestamp) + t.result.CreatedUtcTimestamps = append(t.result.CreatedUtcTimestamps, collectionInfo.createdUtcTimestamp) + t.result.InMemoryPercentages = append(t.result.InMemoryPercentages, resp.InMemoryPercentages[offset]) + t.result.QueryServiceAvailable = append(t.result.QueryServiceAvailable, resp.QueryServiceAvailable[offset]) } } else { - sct.result = respFromRootCoord + t.result = respFromRootCoord } return nil } -func (sct *showCollectionsTask) 
PostExecute(ctx context.Context) error { +func (t *showCollectionsTask) PostExecute(ctx context.Context) error { return nil } @@ -738,61 +753,63 @@ type alterCollectionTask struct { Condition *milvuspb.AlterCollectionRequest ctx context.Context - rootCoord types.RootCoord + rootCoord types.RootCoordClient result *commonpb.Status } -func (act *alterCollectionTask) TraceCtx() context.Context { - return act.ctx +func (t *alterCollectionTask) TraceCtx() context.Context { + return t.ctx } -func (act *alterCollectionTask) ID() UniqueID { - return act.Base.MsgID +func (t *alterCollectionTask) ID() UniqueID { + return t.Base.MsgID } -func (act *alterCollectionTask) SetID(uid UniqueID) { - act.Base.MsgID = uid +func (t *alterCollectionTask) SetID(uid UniqueID) { + t.Base.MsgID = uid } -func (act *alterCollectionTask) Name() string { +func (t *alterCollectionTask) Name() string { return AlterCollectionTaskName } -func (act *alterCollectionTask) Type() commonpb.MsgType { - return act.Base.MsgType +func (t *alterCollectionTask) Type() commonpb.MsgType { + return t.Base.MsgType } -func (act *alterCollectionTask) BeginTs() Timestamp { - return act.Base.Timestamp +func (t *alterCollectionTask) BeginTs() Timestamp { + return t.Base.Timestamp } -func (act *alterCollectionTask) EndTs() Timestamp { - return act.Base.Timestamp +func (t *alterCollectionTask) EndTs() Timestamp { + return t.Base.Timestamp } -func (act *alterCollectionTask) SetTs(ts Timestamp) { - act.Base.Timestamp = ts +func (t *alterCollectionTask) SetTs(ts Timestamp) { + t.Base.Timestamp = ts } -func (act *alterCollectionTask) OnEnqueue() error { - act.Base = commonpbutil.NewMsgBase() +func (t *alterCollectionTask) OnEnqueue() error { + if t.Base == nil { + t.Base = commonpbutil.NewMsgBase() + } return nil } -func (act *alterCollectionTask) PreExecute(ctx context.Context) error { - act.Base.MsgType = commonpb.MsgType_AlterCollection - act.Base.SourceID = paramtable.GetNodeID() +func (t *alterCollectionTask) 
PreExecute(ctx context.Context) error { + t.Base.MsgType = commonpb.MsgType_AlterCollection + t.Base.SourceID = paramtable.GetNodeID() return nil } -func (act *alterCollectionTask) Execute(ctx context.Context) error { +func (t *alterCollectionTask) Execute(ctx context.Context) error { var err error - act.result, err = act.rootCoord.AlterCollection(ctx, act.AlterCollectionRequest) + t.result, err = t.rootCoord.AlterCollection(ctx, t.AlterCollectionRequest) return err } -func (act *alterCollectionTask) PostExecute(ctx context.Context) error { +func (t *alterCollectionTask) PostExecute(ctx context.Context) error { return nil } @@ -800,58 +817,60 @@ type createPartitionTask struct { Condition *milvuspb.CreatePartitionRequest ctx context.Context - rootCoord types.RootCoord + rootCoord types.RootCoordClient result *commonpb.Status } -func (cpt *createPartitionTask) TraceCtx() context.Context { - return cpt.ctx +func (t *createPartitionTask) TraceCtx() context.Context { + return t.ctx } -func (cpt *createPartitionTask) ID() UniqueID { - return cpt.Base.MsgID +func (t *createPartitionTask) ID() UniqueID { + return t.Base.MsgID } -func (cpt *createPartitionTask) SetID(uid UniqueID) { - cpt.Base.MsgID = uid +func (t *createPartitionTask) SetID(uid UniqueID) { + t.Base.MsgID = uid } -func (cpt *createPartitionTask) Name() string { +func (t *createPartitionTask) Name() string { return CreatePartitionTaskName } -func (cpt *createPartitionTask) Type() commonpb.MsgType { - return cpt.Base.MsgType +func (t *createPartitionTask) Type() commonpb.MsgType { + return t.Base.MsgType } -func (cpt *createPartitionTask) BeginTs() Timestamp { - return cpt.Base.Timestamp +func (t *createPartitionTask) BeginTs() Timestamp { + return t.Base.Timestamp } -func (cpt *createPartitionTask) EndTs() Timestamp { - return cpt.Base.Timestamp +func (t *createPartitionTask) EndTs() Timestamp { + return t.Base.Timestamp } -func (cpt *createPartitionTask) SetTs(ts Timestamp) { - cpt.Base.Timestamp = ts 
+func (t *createPartitionTask) SetTs(ts Timestamp) { + t.Base.Timestamp = ts } -func (cpt *createPartitionTask) OnEnqueue() error { - cpt.Base = commonpbutil.NewMsgBase() +func (t *createPartitionTask) OnEnqueue() error { + if t.Base == nil { + t.Base = commonpbutil.NewMsgBase() + } return nil } -func (cpt *createPartitionTask) PreExecute(ctx context.Context) error { - cpt.Base.MsgType = commonpb.MsgType_CreatePartition - cpt.Base.SourceID = paramtable.GetNodeID() +func (t *createPartitionTask) PreExecute(ctx context.Context) error { + t.Base.MsgType = commonpb.MsgType_CreatePartition + t.Base.SourceID = paramtable.GetNodeID() - collName, partitionTag := cpt.CollectionName, cpt.PartitionName + collName, partitionTag := t.CollectionName, t.PartitionName if err := validateCollectionName(collName); err != nil { return err } - partitionKeyMode, err := isPartitionKeyMode(ctx, cpt.GetDbName(), collName) + partitionKeyMode, err := isPartitionKeyMode(ctx, t.GetDbName(), collName) if err != nil { return err } @@ -866,18 +885,18 @@ func (cpt *createPartitionTask) PreExecute(ctx context.Context) error { return nil } -func (cpt *createPartitionTask) Execute(ctx context.Context) (err error) { - cpt.result, err = cpt.rootCoord.CreatePartition(ctx, cpt.CreatePartitionRequest) +func (t *createPartitionTask) Execute(ctx context.Context) (err error) { + t.result, err = t.rootCoord.CreatePartition(ctx, t.CreatePartitionRequest) if err != nil { return err } - if cpt.result.ErrorCode != commonpb.ErrorCode_Success { - return errors.New(cpt.result.Reason) + if t.result.ErrorCode != commonpb.ErrorCode_Success { + return errors.New(t.result.Reason) } return err } -func (cpt *createPartitionTask) PostExecute(ctx context.Context) error { +func (t *createPartitionTask) PostExecute(ctx context.Context) error { return nil } @@ -885,59 +904,61 @@ type dropPartitionTask struct { Condition *milvuspb.DropPartitionRequest ctx context.Context - rootCoord types.RootCoord - queryCoord types.QueryCoord 
+ rootCoord types.RootCoordClient + queryCoord types.QueryCoordClient result *commonpb.Status } -func (dpt *dropPartitionTask) TraceCtx() context.Context { - return dpt.ctx +func (t *dropPartitionTask) TraceCtx() context.Context { + return t.ctx } -func (dpt *dropPartitionTask) ID() UniqueID { - return dpt.Base.MsgID +func (t *dropPartitionTask) ID() UniqueID { + return t.Base.MsgID } -func (dpt *dropPartitionTask) SetID(uid UniqueID) { - dpt.Base.MsgID = uid +func (t *dropPartitionTask) SetID(uid UniqueID) { + t.Base.MsgID = uid } -func (dpt *dropPartitionTask) Name() string { +func (t *dropPartitionTask) Name() string { return DropPartitionTaskName } -func (dpt *dropPartitionTask) Type() commonpb.MsgType { - return dpt.Base.MsgType +func (t *dropPartitionTask) Type() commonpb.MsgType { + return t.Base.MsgType } -func (dpt *dropPartitionTask) BeginTs() Timestamp { - return dpt.Base.Timestamp +func (t *dropPartitionTask) BeginTs() Timestamp { + return t.Base.Timestamp } -func (dpt *dropPartitionTask) EndTs() Timestamp { - return dpt.Base.Timestamp +func (t *dropPartitionTask) EndTs() Timestamp { + return t.Base.Timestamp } -func (dpt *dropPartitionTask) SetTs(ts Timestamp) { - dpt.Base.Timestamp = ts +func (t *dropPartitionTask) SetTs(ts Timestamp) { + t.Base.Timestamp = ts } -func (dpt *dropPartitionTask) OnEnqueue() error { - dpt.Base = commonpbutil.NewMsgBase() +func (t *dropPartitionTask) OnEnqueue() error { + if t.Base == nil { + t.Base = commonpbutil.NewMsgBase() + } return nil } -func (dpt *dropPartitionTask) PreExecute(ctx context.Context) error { - dpt.Base.MsgType = commonpb.MsgType_DropPartition - dpt.Base.SourceID = paramtable.GetNodeID() +func (t *dropPartitionTask) PreExecute(ctx context.Context) error { + t.Base.MsgType = commonpb.MsgType_DropPartition + t.Base.SourceID = paramtable.GetNodeID() - collName, partitionTag := dpt.CollectionName, dpt.PartitionName + collName, partitionTag := t.CollectionName, t.PartitionName if err := 
validateCollectionName(collName); err != nil { return err } - partitionKeyMode, err := isPartitionKeyMode(ctx, dpt.GetDbName(), collName) + partitionKeyMode, err := isPartitionKeyMode(ctx, t.GetDbName(), collName) if err != nil { return err } @@ -949,11 +970,11 @@ func (dpt *dropPartitionTask) PreExecute(ctx context.Context) error { return err } - collID, err := globalMetaCache.GetCollectionID(ctx, dpt.GetDbName(), dpt.GetCollectionName()) + collID, err := globalMetaCache.GetCollectionID(ctx, t.GetDbName(), t.GetCollectionName()) if err != nil { return err } - partID, err := globalMetaCache.GetPartitionID(ctx, dpt.GetDbName(), dpt.GetCollectionName(), dpt.GetPartitionName()) + partID, err := globalMetaCache.GetPartitionID(ctx, t.GetDbName(), t.GetCollectionName(), t.GetPartitionName()) if err != nil { if errors.Is(merr.ErrPartitionNotFound, err) { return nil @@ -961,12 +982,12 @@ func (dpt *dropPartitionTask) PreExecute(ctx context.Context) error { return err } - collLoaded, err := isCollectionLoaded(ctx, dpt.queryCoord, collID) + collLoaded, err := isCollectionLoaded(ctx, t.queryCoord, collID) if err != nil { return err } if collLoaded { - loaded, err := isPartitionLoaded(ctx, dpt.queryCoord, collID, []int64{partID}) + loaded, err := isPartitionLoaded(ctx, t.queryCoord, collID, []int64{partID}) if err != nil { return err } @@ -978,18 +999,18 @@ func (dpt *dropPartitionTask) PreExecute(ctx context.Context) error { return nil } -func (dpt *dropPartitionTask) Execute(ctx context.Context) (err error) { - dpt.result, err = dpt.rootCoord.DropPartition(ctx, dpt.DropPartitionRequest) +func (t *dropPartitionTask) Execute(ctx context.Context) (err error) { + t.result, err = t.rootCoord.DropPartition(ctx, t.DropPartitionRequest) if err != nil { return err } - if dpt.result.ErrorCode != commonpb.ErrorCode_Success { - return errors.New(dpt.result.Reason) + if t.result.ErrorCode != commonpb.ErrorCode_Success { + return errors.New(t.result.Reason) } return err } -func (dpt 
*dropPartitionTask) PostExecute(ctx context.Context) error { +func (t *dropPartitionTask) PostExecute(ctx context.Context) error { return nil } @@ -997,52 +1018,52 @@ type hasPartitionTask struct { Condition *milvuspb.HasPartitionRequest ctx context.Context - rootCoord types.RootCoord + rootCoord types.RootCoordClient result *milvuspb.BoolResponse } -func (hpt *hasPartitionTask) TraceCtx() context.Context { - return hpt.ctx +func (t *hasPartitionTask) TraceCtx() context.Context { + return t.ctx } -func (hpt *hasPartitionTask) ID() UniqueID { - return hpt.Base.MsgID +func (t *hasPartitionTask) ID() UniqueID { + return t.Base.MsgID } -func (hpt *hasPartitionTask) SetID(uid UniqueID) { - hpt.Base.MsgID = uid +func (t *hasPartitionTask) SetID(uid UniqueID) { + t.Base.MsgID = uid } -func (hpt *hasPartitionTask) Name() string { +func (t *hasPartitionTask) Name() string { return HasPartitionTaskName } -func (hpt *hasPartitionTask) Type() commonpb.MsgType { - return hpt.Base.MsgType +func (t *hasPartitionTask) Type() commonpb.MsgType { + return t.Base.MsgType } -func (hpt *hasPartitionTask) BeginTs() Timestamp { - return hpt.Base.Timestamp +func (t *hasPartitionTask) BeginTs() Timestamp { + return t.Base.Timestamp } -func (hpt *hasPartitionTask) EndTs() Timestamp { - return hpt.Base.Timestamp +func (t *hasPartitionTask) EndTs() Timestamp { + return t.Base.Timestamp } -func (hpt *hasPartitionTask) SetTs(ts Timestamp) { - hpt.Base.Timestamp = ts +func (t *hasPartitionTask) SetTs(ts Timestamp) { + t.Base.Timestamp = ts } -func (hpt *hasPartitionTask) OnEnqueue() error { - hpt.Base = commonpbutil.NewMsgBase() +func (t *hasPartitionTask) OnEnqueue() error { + t.Base = commonpbutil.NewMsgBase() return nil } -func (hpt *hasPartitionTask) PreExecute(ctx context.Context) error { - hpt.Base.MsgType = commonpb.MsgType_HasPartition - hpt.Base.SourceID = paramtable.GetNodeID() +func (t *hasPartitionTask) PreExecute(ctx context.Context) error { + t.Base.MsgType = 
commonpb.MsgType_HasPartition + t.Base.SourceID = paramtable.GetNodeID() - collName, partitionTag := hpt.CollectionName, hpt.PartitionName + collName, partitionTag := t.CollectionName, t.PartitionName if err := validateCollectionName(collName); err != nil { return err @@ -1054,18 +1075,18 @@ func (hpt *hasPartitionTask) PreExecute(ctx context.Context) error { return nil } -func (hpt *hasPartitionTask) Execute(ctx context.Context) (err error) { - hpt.result, err = hpt.rootCoord.HasPartition(ctx, hpt.HasPartitionRequest) +func (t *hasPartitionTask) Execute(ctx context.Context) (err error) { + t.result, err = t.rootCoord.HasPartition(ctx, t.HasPartitionRequest) if err != nil { return err } - if hpt.result.Status.ErrorCode != commonpb.ErrorCode_Success { - return errors.New(hpt.result.Status.Reason) + if t.result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + return merr.Error(t.result.GetStatus()) } return err } -func (hpt *hasPartitionTask) PostExecute(ctx context.Context) error { +func (t *hasPartitionTask) PostExecute(ctx context.Context) error { return nil } @@ -1073,58 +1094,58 @@ type showPartitionsTask struct { Condition *milvuspb.ShowPartitionsRequest ctx context.Context - rootCoord types.RootCoord - queryCoord types.QueryCoord + rootCoord types.RootCoordClient + queryCoord types.QueryCoordClient result *milvuspb.ShowPartitionsResponse } -func (spt *showPartitionsTask) TraceCtx() context.Context { - return spt.ctx +func (t *showPartitionsTask) TraceCtx() context.Context { + return t.ctx } -func (spt *showPartitionsTask) ID() UniqueID { - return spt.Base.MsgID +func (t *showPartitionsTask) ID() UniqueID { + return t.Base.MsgID } -func (spt *showPartitionsTask) SetID(uid UniqueID) { - spt.Base.MsgID = uid +func (t *showPartitionsTask) SetID(uid UniqueID) { + t.Base.MsgID = uid } -func (spt *showPartitionsTask) Name() string { +func (t *showPartitionsTask) Name() string { return ShowPartitionTaskName } -func (spt *showPartitionsTask) Type() 
commonpb.MsgType { - return spt.Base.MsgType +func (t *showPartitionsTask) Type() commonpb.MsgType { + return t.Base.MsgType } -func (spt *showPartitionsTask) BeginTs() Timestamp { - return spt.Base.Timestamp +func (t *showPartitionsTask) BeginTs() Timestamp { + return t.Base.Timestamp } -func (spt *showPartitionsTask) EndTs() Timestamp { - return spt.Base.Timestamp +func (t *showPartitionsTask) EndTs() Timestamp { + return t.Base.Timestamp } -func (spt *showPartitionsTask) SetTs(ts Timestamp) { - spt.Base.Timestamp = ts +func (t *showPartitionsTask) SetTs(ts Timestamp) { + t.Base.Timestamp = ts } -func (spt *showPartitionsTask) OnEnqueue() error { - spt.Base = commonpbutil.NewMsgBase() +func (t *showPartitionsTask) OnEnqueue() error { + t.Base = commonpbutil.NewMsgBase() return nil } -func (spt *showPartitionsTask) PreExecute(ctx context.Context) error { - spt.Base.MsgType = commonpb.MsgType_ShowPartitions - spt.Base.SourceID = paramtable.GetNodeID() +func (t *showPartitionsTask) PreExecute(ctx context.Context) error { + t.Base.MsgType = commonpb.MsgType_ShowPartitions + t.Base.SourceID = paramtable.GetNodeID() - if err := validateCollectionName(spt.CollectionName); err != nil { + if err := validateCollectionName(t.CollectionName); err != nil { return err } - if spt.GetType() == milvuspb.ShowType_InMemory { - for _, partitionName := range spt.PartitionNames { + if t.GetType() == milvuspb.ShowType_InMemory { + for _, partitionName := range t.PartitionNames { if err := validatePartitionTag(partitionName, true); err != nil { return err } @@ -1134,8 +1155,8 @@ func (spt *showPartitionsTask) PreExecute(ctx context.Context) error { return nil } -func (spt *showPartitionsTask) Execute(ctx context.Context) error { - respFromRootCoord, err := spt.rootCoord.ShowPartitions(ctx, spt.ShowPartitionsRequest) +func (t *showPartitionsTask) Execute(ctx context.Context) error { + respFromRootCoord, err := t.rootCoord.ShowPartitions(ctx, t.ShowPartitionsRequest) if err != nil { 
return err } @@ -1144,16 +1165,16 @@ func (spt *showPartitionsTask) Execute(ctx context.Context) error { return errors.New("failed to show partitions") } - if respFromRootCoord.Status.ErrorCode != commonpb.ErrorCode_Success { - return errors.New(respFromRootCoord.Status.Reason) + if respFromRootCoord.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + return merr.Error(respFromRootCoord.GetStatus()) } - if spt.GetType() == milvuspb.ShowType_InMemory { - collectionName := spt.CollectionName - collectionID, err := globalMetaCache.GetCollectionID(ctx, spt.GetDbName(), collectionName) + if t.GetType() == milvuspb.ShowType_InMemory { + collectionName := t.CollectionName + collectionID, err := globalMetaCache.GetCollectionID(ctx, t.GetDbName(), collectionName) if err != nil { log.Debug("Failed to get collection id.", zap.Any("collectionName", collectionName), - zap.Any("requestID", spt.Base.MsgID), zap.Any("requestType", "showPartitions")) + zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "showPartitions")) return err } IDs2Names := make(map[UniqueID]string) @@ -1162,19 +1183,19 @@ func (spt *showPartitionsTask) Execute(ctx context.Context) error { IDs2Names[partitionID] = partitionName } partitionIDs := make([]UniqueID, 0) - for _, partitionName := range spt.PartitionNames { - partitionID, err := globalMetaCache.GetPartitionID(ctx, spt.GetDbName(), collectionName, partitionName) + for _, partitionName := range t.PartitionNames { + partitionID, err := globalMetaCache.GetPartitionID(ctx, t.GetDbName(), collectionName, partitionName) if err != nil { log.Debug("Failed to get partition id.", zap.Any("partitionName", partitionName), - zap.Any("requestID", spt.Base.MsgID), zap.Any("requestType", "showPartitions")) + zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "showPartitions")) return err } partitionIDs = append(partitionIDs, partitionID) IDs2Names[partitionID] = partitionName } - resp, err := spt.queryCoord.ShowPartitions(ctx, 
&querypb.ShowPartitionsRequest{ + resp, err := t.queryCoord.ShowPartitions(ctx, &querypb.ShowPartitionsRequest{ Base: commonpbutil.UpdateMsgBase( - spt.Base, + t.Base, commonpbutil.WithMsgType(commonpb.MsgType_ShowCollections), ), CollectionID: collectionID, @@ -1188,11 +1209,11 @@ func (spt *showPartitionsTask) Execute(ctx context.Context) error { return errors.New("failed to show partitions") } - if resp.Status.ErrorCode != commonpb.ErrorCode_Success { - return errors.New(resp.Status.Reason) + if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + return merr.Error(resp.GetStatus()) } - spt.result = &milvuspb.ShowPartitionsResponse{ + t.result = &milvuspb.ShowPartitionsResponse{ Status: resp.Status, PartitionNames: make([]string, 0, len(resp.PartitionIDs)), PartitionIDs: make([]int64, 0, len(resp.PartitionIDs)), @@ -1205,29 +1226,29 @@ func (spt *showPartitionsTask) Execute(ctx context.Context) error { partitionName, ok := IDs2Names[id] if !ok { log.Debug("Failed to get partition id.", zap.Any("partitionName", partitionName), - zap.Any("requestID", spt.Base.MsgID), zap.Any("requestType", "showPartitions")) + zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "showPartitions")) return errors.New("failed to show partitions") } - partitionInfo, err := globalMetaCache.GetPartitionInfo(ctx, spt.GetDbName(), collectionName, partitionName) + partitionInfo, err := globalMetaCache.GetPartitionInfo(ctx, t.GetDbName(), collectionName, partitionName) if err != nil { log.Debug("Failed to get partition id.", zap.Any("partitionName", partitionName), - zap.Any("requestID", spt.Base.MsgID), zap.Any("requestType", "showPartitions")) + zap.Any("requestID", t.Base.MsgID), zap.Any("requestType", "showPartitions")) return err } - spt.result.PartitionIDs = append(spt.result.PartitionIDs, id) - spt.result.PartitionNames = append(spt.result.PartitionNames, partitionName) - spt.result.CreatedTimestamps = append(spt.result.CreatedTimestamps, 
partitionInfo.createdTimestamp) - spt.result.CreatedUtcTimestamps = append(spt.result.CreatedUtcTimestamps, partitionInfo.createdUtcTimestamp) - spt.result.InMemoryPercentages = append(spt.result.InMemoryPercentages, resp.InMemoryPercentages[offset]) + t.result.PartitionIDs = append(t.result.PartitionIDs, id) + t.result.PartitionNames = append(t.result.PartitionNames, partitionName) + t.result.CreatedTimestamps = append(t.result.CreatedTimestamps, partitionInfo.createdTimestamp) + t.result.CreatedUtcTimestamps = append(t.result.CreatedUtcTimestamps, partitionInfo.createdUtcTimestamp) + t.result.InMemoryPercentages = append(t.result.InMemoryPercentages, resp.InMemoryPercentages[offset]) } } else { - spt.result = respFromRootCoord + t.result = respFromRootCoord } return nil } -func (spt *showPartitionsTask) PostExecute(ctx context.Context) error { +func (t *showPartitionsTask) PostExecute(ctx context.Context) error { return nil } @@ -1235,92 +1256,100 @@ type flushTask struct { Condition *milvuspb.FlushRequest ctx context.Context - dataCoord types.DataCoord + dataCoord types.DataCoordClient result *milvuspb.FlushResponse + + replicateMsgStream msgstream.MsgStream } -func (ft *flushTask) TraceCtx() context.Context { - return ft.ctx +func (t *flushTask) TraceCtx() context.Context { + return t.ctx } -func (ft *flushTask) ID() UniqueID { - return ft.Base.MsgID +func (t *flushTask) ID() UniqueID { + return t.Base.MsgID } -func (ft *flushTask) SetID(uid UniqueID) { - ft.Base.MsgID = uid +func (t *flushTask) SetID(uid UniqueID) { + t.Base.MsgID = uid } -func (ft *flushTask) Name() string { +func (t *flushTask) Name() string { return FlushTaskName } -func (ft *flushTask) Type() commonpb.MsgType { - return ft.Base.MsgType +func (t *flushTask) Type() commonpb.MsgType { + return t.Base.MsgType } -func (ft *flushTask) BeginTs() Timestamp { - return ft.Base.Timestamp +func (t *flushTask) BeginTs() Timestamp { + return t.Base.Timestamp } -func (ft *flushTask) EndTs() Timestamp { - 
return ft.Base.Timestamp +func (t *flushTask) EndTs() Timestamp { + return t.Base.Timestamp } -func (ft *flushTask) SetTs(ts Timestamp) { - ft.Base.Timestamp = ts +func (t *flushTask) SetTs(ts Timestamp) { + t.Base.Timestamp = ts } -func (ft *flushTask) OnEnqueue() error { - ft.Base = commonpbutil.NewMsgBase() +func (t *flushTask) OnEnqueue() error { + if t.Base == nil { + t.Base = commonpbutil.NewMsgBase() + } return nil } -func (ft *flushTask) PreExecute(ctx context.Context) error { - ft.Base.MsgType = commonpb.MsgType_Flush - ft.Base.SourceID = paramtable.GetNodeID() +func (t *flushTask) PreExecute(ctx context.Context) error { + t.Base.MsgType = commonpb.MsgType_Flush + t.Base.SourceID = paramtable.GetNodeID() return nil } -func (ft *flushTask) Execute(ctx context.Context) error { +func (t *flushTask) Execute(ctx context.Context) error { coll2Segments := make(map[string]*schemapb.LongArray) flushColl2Segments := make(map[string]*schemapb.LongArray) coll2SealTimes := make(map[string]int64) - for _, collName := range ft.CollectionNames { - collID, err := globalMetaCache.GetCollectionID(ctx, ft.GetDbName(), collName) + coll2FlushTs := make(map[string]Timestamp) + for _, collName := range t.CollectionNames { + collID, err := globalMetaCache.GetCollectionID(ctx, t.GetDbName(), collName) if err != nil { return err } flushReq := &datapb.FlushRequest{ Base: commonpbutil.UpdateMsgBase( - ft.Base, + t.Base, commonpbutil.WithMsgType(commonpb.MsgType_Flush), ), CollectionID: collID, IsImport: false, } - resp, err := ft.dataCoord.Flush(ctx, flushReq) + resp, err := t.dataCoord.Flush(ctx, flushReq) if err != nil { return fmt.Errorf("failed to call flush to data coordinator: %s", err.Error()) } - if resp.Status.ErrorCode != commonpb.ErrorCode_Success { - return errors.New(resp.Status.Reason) + if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + return merr.Error(resp.GetStatus()) } coll2Segments[collName] = &schemapb.LongArray{Data: resp.GetSegmentIDs()} 
flushColl2Segments[collName] = &schemapb.LongArray{Data: resp.GetFlushSegmentIDs()} coll2SealTimes[collName] = resp.GetTimeOfSeal() + coll2FlushTs[collName] = resp.GetFlushTs() } - ft.result = &milvuspb.FlushResponse{ - Status: merr.Status(nil), - DbName: ft.GetDbName(), + SendReplicateMessagePack(ctx, t.replicateMsgStream, t.FlushRequest) + t.result = &milvuspb.FlushResponse{ + Status: merr.Success(), + DbName: t.GetDbName(), CollSegIDs: coll2Segments, FlushCollSegIDs: flushColl2Segments, CollSealTimes: coll2SealTimes, + CollFlushTs: coll2FlushTs, } return nil } -func (ft *flushTask) PostExecute(ctx context.Context) error { +func (t *flushTask) PostExecute(ctx context.Context) error { return nil } @@ -1328,72 +1357,75 @@ type loadCollectionTask struct { Condition *milvuspb.LoadCollectionRequest ctx context.Context - queryCoord types.QueryCoord - datacoord types.DataCoord + queryCoord types.QueryCoordClient + datacoord types.DataCoordClient result *commonpb.Status - collectionID UniqueID + collectionID UniqueID + replicateMsgStream msgstream.MsgStream } -func (lct *loadCollectionTask) TraceCtx() context.Context { - return lct.ctx +func (t *loadCollectionTask) TraceCtx() context.Context { + return t.ctx } -func (lct *loadCollectionTask) ID() UniqueID { - return lct.Base.MsgID +func (t *loadCollectionTask) ID() UniqueID { + return t.Base.MsgID } -func (lct *loadCollectionTask) SetID(uid UniqueID) { - lct.Base.MsgID = uid +func (t *loadCollectionTask) SetID(uid UniqueID) { + t.Base.MsgID = uid } -func (lct *loadCollectionTask) Name() string { +func (t *loadCollectionTask) Name() string { return LoadCollectionTaskName } -func (lct *loadCollectionTask) Type() commonpb.MsgType { - return lct.Base.MsgType +func (t *loadCollectionTask) Type() commonpb.MsgType { + return t.Base.MsgType } -func (lct *loadCollectionTask) BeginTs() Timestamp { - return lct.Base.Timestamp +func (t *loadCollectionTask) BeginTs() Timestamp { + return t.Base.Timestamp } -func (lct 
*loadCollectionTask) EndTs() Timestamp { - return lct.Base.Timestamp +func (t *loadCollectionTask) EndTs() Timestamp { + return t.Base.Timestamp } -func (lct *loadCollectionTask) SetTs(ts Timestamp) { - lct.Base.Timestamp = ts +func (t *loadCollectionTask) SetTs(ts Timestamp) { + t.Base.Timestamp = ts } -func (lct *loadCollectionTask) OnEnqueue() error { - lct.Base = commonpbutil.NewMsgBase() +func (t *loadCollectionTask) OnEnqueue() error { + if t.Base == nil { + t.Base = commonpbutil.NewMsgBase() + } return nil } -func (lct *loadCollectionTask) PreExecute(ctx context.Context) error { +func (t *loadCollectionTask) PreExecute(ctx context.Context) error { log.Ctx(ctx).Debug("loadCollectionTask PreExecute", zap.String("role", typeutil.ProxyRole)) - lct.Base.MsgType = commonpb.MsgType_LoadCollection - lct.Base.SourceID = paramtable.GetNodeID() + t.Base.MsgType = commonpb.MsgType_LoadCollection + t.Base.SourceID = paramtable.GetNodeID() - collName := lct.CollectionName + collName := t.CollectionName if err := validateCollectionName(collName); err != nil { return err } // To compat with LoadCollcetion before Milvus@2.1 - if lct.ReplicaNumber == 0 { - lct.ReplicaNumber = 1 + if t.ReplicaNumber == 0 { + t.ReplicaNumber = 1 } return nil } -func (lct *loadCollectionTask) Execute(ctx context.Context) (err error) { - collID, err := globalMetaCache.GetCollectionID(ctx, lct.GetDbName(), lct.CollectionName) +func (t *loadCollectionTask) Execute(ctx context.Context) (err error) { + collID, err := globalMetaCache.GetCollectionID(ctx, t.GetDbName(), t.CollectionName) log := log.Ctx(ctx).With( zap.String("role", typeutil.ProxyRole), @@ -1404,62 +1436,66 @@ func (lct *loadCollectionTask) Execute(ctx context.Context) (err error) { return err } - lct.collectionID = collID - collSchema, err := globalMetaCache.GetCollectionSchema(ctx, lct.GetDbName(), lct.CollectionName) + t.collectionID = collID + collSchema, err := globalMetaCache.GetCollectionSchema(ctx, t.GetDbName(), 
t.CollectionName) if err != nil { return err } // check index - indexResponse, err := lct.datacoord.DescribeIndex(ctx, &indexpb.DescribeIndexRequest{ + indexResponse, err := t.datacoord.DescribeIndex(ctx, &indexpb.DescribeIndexRequest{ CollectionID: collID, IndexName: "", }) + if err == nil { + err = merr.Error(indexResponse.GetStatus()) + } if err != nil { + if errors.Is(err, merr.ErrIndexNotFound) { + err = merr.WrapErrIndexNotFoundForCollection(t.GetCollectionName()) + } return err } - if indexResponse.Status.ErrorCode != commonpb.ErrorCode_Success { - return errors.New(indexResponse.Status.Reason) - } hasVecIndex := false fieldIndexIDs := make(map[int64]int64) for _, index := range indexResponse.IndexInfos { fieldIndexIDs[index.FieldID] = index.IndexID for _, field := range collSchema.Fields { - if index.FieldID == field.FieldID && (field.DataType == schemapb.DataType_FloatVector || field.DataType == schemapb.DataType_BinaryVector) { + if index.FieldID == field.FieldID && (field.DataType == schemapb.DataType_FloatVector || field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_Float16Vector) { hasVecIndex = true } } } if !hasVecIndex { - errMsg := fmt.Sprintf("there is no vector index on collection: %s, please create index firstly", lct.LoadCollectionRequest.CollectionName) + errMsg := fmt.Sprintf("there is no vector index on collection: %s, please create index firstly", t.LoadCollectionRequest.CollectionName) log.Error(errMsg) return errors.New(errMsg) } request := &querypb.LoadCollectionRequest{ Base: commonpbutil.UpdateMsgBase( - lct.Base, + t.Base, commonpbutil.WithMsgType(commonpb.MsgType_LoadCollection), ), DbID: 0, CollectionID: collID, Schema: collSchema, - ReplicaNumber: lct.ReplicaNumber, + ReplicaNumber: t.ReplicaNumber, FieldIndexID: fieldIndexIDs, - Refresh: lct.Refresh, - ResourceGroups: lct.ResourceGroups, + Refresh: t.Refresh, + ResourceGroups: t.ResourceGroups, } log.Debug("send LoadCollectionRequest to query 
coordinator", zap.Any("schema", request.Schema)) - lct.result, err = lct.queryCoord.LoadCollection(ctx, request) + t.result, err = t.queryCoord.LoadCollection(ctx, request) if err != nil { return fmt.Errorf("call query coordinator LoadCollection: %s", err) } + SendReplicateMessagePack(ctx, t.replicateMsgStream, t.LoadCollectionRequest) return nil } -func (lct *loadCollectionTask) PostExecute(ctx context.Context) error { - collID, err := globalMetaCache.GetCollectionID(ctx, lct.GetDbName(), lct.CollectionName) +func (t *loadCollectionTask) PostExecute(ctx context.Context) error { + collID, err := globalMetaCache.GetCollectionID(ctx, t.GetDbName(), t.CollectionName) log.Ctx(ctx).Debug("loadCollectionTask PostExecute", zap.String("role", typeutil.ProxyRole), zap.Int64("collectionID", collID)) @@ -1473,55 +1509,57 @@ type releaseCollectionTask struct { Condition *milvuspb.ReleaseCollectionRequest ctx context.Context - queryCoord types.QueryCoord + queryCoord types.QueryCoordClient result *commonpb.Status - chMgr channelsMgr - collectionID UniqueID + collectionID UniqueID + replicateMsgStream msgstream.MsgStream } -func (rct *releaseCollectionTask) TraceCtx() context.Context { - return rct.ctx +func (t *releaseCollectionTask) TraceCtx() context.Context { + return t.ctx } -func (rct *releaseCollectionTask) ID() UniqueID { - return rct.Base.MsgID +func (t *releaseCollectionTask) ID() UniqueID { + return t.Base.MsgID } -func (rct *releaseCollectionTask) SetID(uid UniqueID) { - rct.Base.MsgID = uid +func (t *releaseCollectionTask) SetID(uid UniqueID) { + t.Base.MsgID = uid } -func (rct *releaseCollectionTask) Name() string { +func (t *releaseCollectionTask) Name() string { return ReleaseCollectionTaskName } -func (rct *releaseCollectionTask) Type() commonpb.MsgType { - return rct.Base.MsgType +func (t *releaseCollectionTask) Type() commonpb.MsgType { + return t.Base.MsgType } -func (rct *releaseCollectionTask) BeginTs() Timestamp { - return rct.Base.Timestamp +func (t 
*releaseCollectionTask) BeginTs() Timestamp { + return t.Base.Timestamp } -func (rct *releaseCollectionTask) EndTs() Timestamp { - return rct.Base.Timestamp +func (t *releaseCollectionTask) EndTs() Timestamp { + return t.Base.Timestamp } -func (rct *releaseCollectionTask) SetTs(ts Timestamp) { - rct.Base.Timestamp = ts +func (t *releaseCollectionTask) SetTs(ts Timestamp) { + t.Base.Timestamp = ts } -func (rct *releaseCollectionTask) OnEnqueue() error { - rct.Base = commonpbutil.NewMsgBase() +func (t *releaseCollectionTask) OnEnqueue() error { + if t.Base == nil { + t.Base = commonpbutil.NewMsgBase() + } return nil } -func (rct *releaseCollectionTask) PreExecute(ctx context.Context) error { - rct.Base.MsgType = commonpb.MsgType_ReleaseCollection - rct.Base.SourceID = paramtable.GetNodeID() +func (t *releaseCollectionTask) PreExecute(ctx context.Context) error { + t.Base.MsgType = commonpb.MsgType_ReleaseCollection + t.Base.SourceID = paramtable.GetNodeID() - collName := rct.CollectionName + collName := t.CollectionName if err := validateCollectionName(collName); err != nil { return err @@ -1530,30 +1568,33 @@ func (rct *releaseCollectionTask) PreExecute(ctx context.Context) error { return nil } -func (rct *releaseCollectionTask) Execute(ctx context.Context) (err error) { - collID, err := globalMetaCache.GetCollectionID(ctx, rct.GetDbName(), rct.CollectionName) +func (t *releaseCollectionTask) Execute(ctx context.Context) (err error) { + collID, err := globalMetaCache.GetCollectionID(ctx, t.GetDbName(), t.CollectionName) if err != nil { return err } - rct.collectionID = collID + t.collectionID = collID request := &querypb.ReleaseCollectionRequest{ Base: commonpbutil.UpdateMsgBase( - rct.Base, + t.Base, commonpbutil.WithMsgType(commonpb.MsgType_ReleaseCollection), ), DbID: 0, CollectionID: collID, } - rct.result, err = rct.queryCoord.ReleaseCollection(ctx, request) - - globalMetaCache.RemoveCollection(ctx, rct.GetDbName(), rct.CollectionName) + t.result, err = 
t.queryCoord.ReleaseCollection(ctx, request) + globalMetaCache.RemoveCollection(ctx, t.GetDbName(), t.CollectionName) + if err != nil { + return err + } + SendReplicateMessagePack(ctx, t.replicateMsgStream, t.ReleaseCollectionRequest) return err } -func (rct *releaseCollectionTask) PostExecute(ctx context.Context) error { - globalMetaCache.DeprecateShardCache(rct.GetDbName(), rct.CollectionName) +func (t *releaseCollectionTask) PostExecute(ctx context.Context) error { + globalMetaCache.DeprecateShardCache(t.GetDbName(), t.CollectionName) return nil } @@ -1561,61 +1602,63 @@ type loadPartitionsTask struct { Condition *milvuspb.LoadPartitionsRequest ctx context.Context - queryCoord types.QueryCoord - datacoord types.DataCoord + queryCoord types.QueryCoordClient + datacoord types.DataCoordClient result *commonpb.Status collectionID UniqueID } -func (lpt *loadPartitionsTask) TraceCtx() context.Context { - return lpt.ctx +func (t *loadPartitionsTask) TraceCtx() context.Context { + return t.ctx } -func (lpt *loadPartitionsTask) ID() UniqueID { - return lpt.Base.MsgID +func (t *loadPartitionsTask) ID() UniqueID { + return t.Base.MsgID } -func (lpt *loadPartitionsTask) SetID(uid UniqueID) { - lpt.Base.MsgID = uid +func (t *loadPartitionsTask) SetID(uid UniqueID) { + t.Base.MsgID = uid } -func (lpt *loadPartitionsTask) Name() string { +func (t *loadPartitionsTask) Name() string { return LoadPartitionTaskName } -func (lpt *loadPartitionsTask) Type() commonpb.MsgType { - return lpt.Base.MsgType +func (t *loadPartitionsTask) Type() commonpb.MsgType { + return t.Base.MsgType } -func (lpt *loadPartitionsTask) BeginTs() Timestamp { - return lpt.Base.Timestamp +func (t *loadPartitionsTask) BeginTs() Timestamp { + return t.Base.Timestamp } -func (lpt *loadPartitionsTask) EndTs() Timestamp { - return lpt.Base.Timestamp +func (t *loadPartitionsTask) EndTs() Timestamp { + return t.Base.Timestamp } -func (lpt *loadPartitionsTask) SetTs(ts Timestamp) { - lpt.Base.Timestamp = ts +func (t 
*loadPartitionsTask) SetTs(ts Timestamp) { + t.Base.Timestamp = ts } -func (lpt *loadPartitionsTask) OnEnqueue() error { - lpt.Base = commonpbutil.NewMsgBase() +func (t *loadPartitionsTask) OnEnqueue() error { + if t.Base == nil { + t.Base = commonpbutil.NewMsgBase() + } return nil } -func (lpt *loadPartitionsTask) PreExecute(ctx context.Context) error { - lpt.Base.MsgType = commonpb.MsgType_LoadPartitions - lpt.Base.SourceID = paramtable.GetNodeID() +func (t *loadPartitionsTask) PreExecute(ctx context.Context) error { + t.Base.MsgType = commonpb.MsgType_LoadPartitions + t.Base.SourceID = paramtable.GetNodeID() - collName := lpt.CollectionName + collName := t.CollectionName if err := validateCollectionName(collName); err != nil { return err } - partitionKeyMode, err := isPartitionKeyMode(ctx, lpt.GetDbName(), collName) + partitionKeyMode, err := isPartitionKeyMode(ctx, t.GetDbName(), collName) if err != nil { return err } @@ -1626,46 +1669,49 @@ func (lpt *loadPartitionsTask) PreExecute(ctx context.Context) error { return nil } -func (lpt *loadPartitionsTask) Execute(ctx context.Context) error { +func (t *loadPartitionsTask) Execute(ctx context.Context) error { var partitionIDs []int64 - collID, err := globalMetaCache.GetCollectionID(ctx, lpt.GetDbName(), lpt.CollectionName) + collID, err := globalMetaCache.GetCollectionID(ctx, t.GetDbName(), t.CollectionName) if err != nil { return err } - lpt.collectionID = collID - collSchema, err := globalMetaCache.GetCollectionSchema(ctx, lpt.GetDbName(), lpt.CollectionName) + t.collectionID = collID + collSchema, err := globalMetaCache.GetCollectionSchema(ctx, t.GetDbName(), t.CollectionName) if err != nil { return err } // check index - indexResponse, err := lpt.datacoord.DescribeIndex(ctx, &indexpb.DescribeIndexRequest{ + indexResponse, err := t.datacoord.DescribeIndex(ctx, &indexpb.DescribeIndexRequest{ CollectionID: collID, IndexName: "", }) + if err == nil { + err = merr.Error(indexResponse.GetStatus()) + } if err != nil 
{ + if errors.Is(err, merr.ErrIndexNotFound) { + err = merr.WrapErrIndexNotFoundForCollection(t.GetCollectionName()) + } return err } - if indexResponse.Status.ErrorCode != commonpb.ErrorCode_Success { - return errors.New(indexResponse.Status.Reason) - } hasVecIndex := false fieldIndexIDs := make(map[int64]int64) for _, index := range indexResponse.IndexInfos { fieldIndexIDs[index.FieldID] = index.IndexID for _, field := range collSchema.Fields { - if index.FieldID == field.FieldID && (field.DataType == schemapb.DataType_FloatVector || field.DataType == schemapb.DataType_BinaryVector) { + if index.FieldID == field.FieldID && (field.DataType == schemapb.DataType_FloatVector || field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_Float16Vector) { hasVecIndex = true } } } if !hasVecIndex { - errMsg := fmt.Sprintf("there is no vector index on collection: %s, please create index firstly", lpt.LoadPartitionsRequest.CollectionName) + errMsg := fmt.Sprintf("there is no vector index on collection: %s, please create index firstly", t.LoadPartitionsRequest.CollectionName) log.Ctx(ctx).Error(errMsg) return errors.New(errMsg) } - for _, partitionName := range lpt.PartitionNames { - partitionID, err := globalMetaCache.GetPartitionID(ctx, lpt.GetDbName(), lpt.CollectionName, partitionName) + for _, partitionName := range t.PartitionNames { + partitionID, err := globalMetaCache.GetPartitionID(ctx, t.GetDbName(), t.CollectionName, partitionName) if err != nil { return err } @@ -1676,23 +1722,23 @@ func (lpt *loadPartitionsTask) Execute(ctx context.Context) error { } request := &querypb.LoadPartitionsRequest{ Base: commonpbutil.UpdateMsgBase( - lpt.Base, + t.Base, commonpbutil.WithMsgType(commonpb.MsgType_LoadPartitions), ), DbID: 0, CollectionID: collID, PartitionIDs: partitionIDs, Schema: collSchema, - ReplicaNumber: lpt.ReplicaNumber, + ReplicaNumber: t.ReplicaNumber, FieldIndexID: fieldIndexIDs, - Refresh: lpt.Refresh, - ResourceGroups: 
lpt.ResourceGroups, + Refresh: t.Refresh, + ResourceGroups: t.ResourceGroups, } - lpt.result, err = lpt.queryCoord.LoadPartitions(ctx, request) + t.result, err = t.queryCoord.LoadPartitions(ctx, request) return err } -func (lpt *loadPartitionsTask) PostExecute(ctx context.Context) error { +func (t *loadPartitionsTask) PostExecute(ctx context.Context) error { return nil } @@ -1700,60 +1746,62 @@ type releasePartitionsTask struct { Condition *milvuspb.ReleasePartitionsRequest ctx context.Context - queryCoord types.QueryCoord + queryCoord types.QueryCoordClient result *commonpb.Status collectionID UniqueID } -func (rpt *releasePartitionsTask) TraceCtx() context.Context { - return rpt.ctx +func (t *releasePartitionsTask) TraceCtx() context.Context { + return t.ctx } -func (rpt *releasePartitionsTask) ID() UniqueID { - return rpt.Base.MsgID +func (t *releasePartitionsTask) ID() UniqueID { + return t.Base.MsgID } -func (rpt *releasePartitionsTask) SetID(uid UniqueID) { - rpt.Base.MsgID = uid +func (t *releasePartitionsTask) SetID(uid UniqueID) { + t.Base.MsgID = uid } -func (rpt *releasePartitionsTask) Type() commonpb.MsgType { - return rpt.Base.MsgType +func (t *releasePartitionsTask) Type() commonpb.MsgType { + return t.Base.MsgType } -func (rpt *releasePartitionsTask) Name() string { +func (t *releasePartitionsTask) Name() string { return ReleasePartitionTaskName } -func (rpt *releasePartitionsTask) BeginTs() Timestamp { - return rpt.Base.Timestamp +func (t *releasePartitionsTask) BeginTs() Timestamp { + return t.Base.Timestamp } -func (rpt *releasePartitionsTask) EndTs() Timestamp { - return rpt.Base.Timestamp +func (t *releasePartitionsTask) EndTs() Timestamp { + return t.Base.Timestamp } -func (rpt *releasePartitionsTask) SetTs(ts Timestamp) { - rpt.Base.Timestamp = ts +func (t *releasePartitionsTask) SetTs(ts Timestamp) { + t.Base.Timestamp = ts } -func (rpt *releasePartitionsTask) OnEnqueue() error { - rpt.Base = commonpbutil.NewMsgBase() +func (t 
*releasePartitionsTask) OnEnqueue() error { + if t.Base == nil { + t.Base = commonpbutil.NewMsgBase() + } return nil } -func (rpt *releasePartitionsTask) PreExecute(ctx context.Context) error { - rpt.Base.MsgType = commonpb.MsgType_ReleasePartitions - rpt.Base.SourceID = paramtable.GetNodeID() +func (t *releasePartitionsTask) PreExecute(ctx context.Context) error { + t.Base.MsgType = commonpb.MsgType_ReleasePartitions + t.Base.SourceID = paramtable.GetNodeID() - collName := rpt.CollectionName + collName := t.CollectionName if err := validateCollectionName(collName); err != nil { return err } - partitionKeyMode, err := isPartitionKeyMode(ctx, rpt.GetDbName(), collName) + partitionKeyMode, err := isPartitionKeyMode(ctx, t.GetDbName(), collName) if err != nil { return err } @@ -1764,15 +1812,15 @@ func (rpt *releasePartitionsTask) PreExecute(ctx context.Context) error { return nil } -func (rpt *releasePartitionsTask) Execute(ctx context.Context) (err error) { +func (t *releasePartitionsTask) Execute(ctx context.Context) (err error) { var partitionIDs []int64 - collID, err := globalMetaCache.GetCollectionID(ctx, rpt.GetDbName(), rpt.CollectionName) + collID, err := globalMetaCache.GetCollectionID(ctx, t.GetDbName(), t.CollectionName) if err != nil { return err } - rpt.collectionID = collID - for _, partitionName := range rpt.PartitionNames { - partitionID, err := globalMetaCache.GetPartitionID(ctx, rpt.GetDbName(), rpt.CollectionName, partitionName) + t.collectionID = collID + for _, partitionName := range t.PartitionNames { + partitionID, err := globalMetaCache.GetPartitionID(ctx, t.GetDbName(), t.CollectionName, partitionName) if err != nil { return err } @@ -1780,19 +1828,19 @@ func (rpt *releasePartitionsTask) Execute(ctx context.Context) (err error) { } request := &querypb.ReleasePartitionsRequest{ Base: commonpbutil.UpdateMsgBase( - rpt.Base, + t.Base, commonpbutil.WithMsgType(commonpb.MsgType_ReleasePartitions), ), DbID: 0, CollectionID: collID, PartitionIDs: 
partitionIDs, } - rpt.result, err = rpt.queryCoord.ReleasePartitions(ctx, request) + t.result, err = t.queryCoord.ReleasePartitions(ctx, request) return err } -func (rpt *releasePartitionsTask) PostExecute(ctx context.Context) error { - globalMetaCache.DeprecateShardCache(rpt.GetDbName(), rpt.CollectionName) +func (t *releasePartitionsTask) PostExecute(ctx context.Context) error { + globalMetaCache.DeprecateShardCache(t.GetDbName(), t.CollectionName) return nil } @@ -1801,83 +1849,85 @@ type CreateAliasTask struct { Condition *milvuspb.CreateAliasRequest ctx context.Context - rootCoord types.RootCoord + rootCoord types.RootCoordClient result *commonpb.Status } // TraceCtx returns the trace context of the task. -func (c *CreateAliasTask) TraceCtx() context.Context { - return c.ctx +func (t *CreateAliasTask) TraceCtx() context.Context { + return t.ctx } // ID return the id of the task -func (c *CreateAliasTask) ID() UniqueID { - return c.Base.MsgID +func (t *CreateAliasTask) ID() UniqueID { + return t.Base.MsgID } // SetID sets the id of the task -func (c *CreateAliasTask) SetID(uid UniqueID) { - c.Base.MsgID = uid +func (t *CreateAliasTask) SetID(uid UniqueID) { + t.Base.MsgID = uid } // Name returns the name of the task -func (c *CreateAliasTask) Name() string { +func (t *CreateAliasTask) Name() string { return CreateAliasTaskName } // Type returns the type of the task -func (c *CreateAliasTask) Type() commonpb.MsgType { - return c.Base.MsgType +func (t *CreateAliasTask) Type() commonpb.MsgType { + return t.Base.MsgType } // BeginTs returns the ts -func (c *CreateAliasTask) BeginTs() Timestamp { - return c.Base.Timestamp +func (t *CreateAliasTask) BeginTs() Timestamp { + return t.Base.Timestamp } // EndTs returns the ts -func (c *CreateAliasTask) EndTs() Timestamp { - return c.Base.Timestamp +func (t *CreateAliasTask) EndTs() Timestamp { + return t.Base.Timestamp } // SetTs sets the ts -func (c *CreateAliasTask) SetTs(ts Timestamp) { - c.Base.Timestamp = ts +func 
(t *CreateAliasTask) SetTs(ts Timestamp) { + t.Base.Timestamp = ts } // OnEnqueue defines the behavior task enqueued -func (c *CreateAliasTask) OnEnqueue() error { - c.Base = commonpbutil.NewMsgBase() +func (t *CreateAliasTask) OnEnqueue() error { + if t.Base == nil { + t.Base = commonpbutil.NewMsgBase() + } return nil } -// PreExecute defines the action before task execution -func (c *CreateAliasTask) PreExecute(ctx context.Context) error { - c.Base.MsgType = commonpb.MsgType_CreateAlias - c.Base.SourceID = paramtable.GetNodeID() +// PreExecute defines the action before task execution +func (t *CreateAliasTask) PreExecute(ctx context.Context) error { + t.Base.MsgType = commonpb.MsgType_CreateAlias + t.Base.SourceID = paramtable.GetNodeID() - collAlias := c.Alias + collAlias := t.Alias // collection alias uses the same format as collection name if err := ValidateCollectionAlias(collAlias); err != nil { return err } - collName := c.CollectionName + collName := t.CollectionName if err := validateCollectionName(collName); err != nil { return err } return nil } -// Execute defines the actual execution of create alias -func (c *CreateAliasTask) Execute(ctx context.Context) error { +// Execute defines the actual execution of create alias +func (t *CreateAliasTask) Execute(ctx context.Context) error { var err error - c.result, err = c.rootCoord.CreateAlias(ctx, c.CreateAliasRequest) + t.result, err = t.rootCoord.CreateAlias(ctx, t.CreateAliasRequest) return err } // PostExecute defines the post execution, do nothing for create alias -func (c *CreateAliasTask) PostExecute(ctx context.Context) error { +func (t *CreateAliasTask) PostExecute(ctx context.Context) error { return nil } @@ -1886,68 +1936,70 @@ type DropAliasTask struct { Condition *milvuspb.DropAliasRequest ctx context.Context - rootCoord types.RootCoord + rootCoord types.RootCoordClient result *commonpb.Status } // TraceCtx returns the context for trace -func (d *DropAliasTask) TraceCtx() context.Context { - return
d.ctx +func (t *DropAliasTask) TraceCtx() context.Context { + return t.ctx } // ID returns the MsgID -func (d *DropAliasTask) ID() UniqueID { - return d.Base.MsgID +func (t *DropAliasTask) ID() UniqueID { + return t.Base.MsgID } // SetID sets the MsgID -func (d *DropAliasTask) SetID(uid UniqueID) { - d.Base.MsgID = uid +func (t *DropAliasTask) SetID(uid UniqueID) { + t.Base.MsgID = uid } // Name returns the name of the task -func (d *DropAliasTask) Name() string { +func (t *DropAliasTask) Name() string { return DropAliasTaskName } -func (d *DropAliasTask) Type() commonpb.MsgType { - return d.Base.MsgType +func (t *DropAliasTask) Type() commonpb.MsgType { + return t.Base.MsgType } -func (d *DropAliasTask) BeginTs() Timestamp { - return d.Base.Timestamp +func (t *DropAliasTask) BeginTs() Timestamp { + return t.Base.Timestamp } -func (d *DropAliasTask) EndTs() Timestamp { - return d.Base.Timestamp +func (t *DropAliasTask) EndTs() Timestamp { + return t.Base.Timestamp } -func (d *DropAliasTask) SetTs(ts Timestamp) { - d.Base.Timestamp = ts +func (t *DropAliasTask) SetTs(ts Timestamp) { + t.Base.Timestamp = ts } -func (d *DropAliasTask) OnEnqueue() error { - d.Base = commonpbutil.NewMsgBase() +func (t *DropAliasTask) OnEnqueue() error { + if t.Base == nil { + t.Base = commonpbutil.NewMsgBase() + } return nil } -func (d *DropAliasTask) PreExecute(ctx context.Context) error { - d.Base.MsgType = commonpb.MsgType_DropAlias - d.Base.SourceID = paramtable.GetNodeID() - collAlias := d.Alias +func (t *DropAliasTask) PreExecute(ctx context.Context) error { + t.Base.MsgType = commonpb.MsgType_DropAlias + t.Base.SourceID = paramtable.GetNodeID() + collAlias := t.Alias if err := ValidateCollectionAlias(collAlias); err != nil { return err } return nil } -func (d *DropAliasTask) Execute(ctx context.Context) error { +func (t *DropAliasTask) Execute(ctx context.Context) error { var err error - d.result, err = d.rootCoord.DropAlias(ctx, d.DropAliasRequest) + t.result, err = 
t.rootCoord.DropAlias(ctx, t.DropAliasRequest) return err } -func (d *DropAliasTask) PostExecute(ctx context.Context) error { +func (t *DropAliasTask) PostExecute(ctx context.Context) error { return nil } @@ -1956,58 +2008,60 @@ type AlterAliasTask struct { Condition *milvuspb.AlterAliasRequest ctx context.Context - rootCoord types.RootCoord + rootCoord types.RootCoordClient result *commonpb.Status } -func (a *AlterAliasTask) TraceCtx() context.Context { - return a.ctx +func (t *AlterAliasTask) TraceCtx() context.Context { + return t.ctx } -func (a *AlterAliasTask) ID() UniqueID { - return a.Base.MsgID +func (t *AlterAliasTask) ID() UniqueID { + return t.Base.MsgID } -func (a *AlterAliasTask) SetID(uid UniqueID) { - a.Base.MsgID = uid +func (t *AlterAliasTask) SetID(uid UniqueID) { + t.Base.MsgID = uid } -func (a *AlterAliasTask) Name() string { +func (t *AlterAliasTask) Name() string { return AlterAliasTaskName } -func (a *AlterAliasTask) Type() commonpb.MsgType { - return a.Base.MsgType +func (t *AlterAliasTask) Type() commonpb.MsgType { + return t.Base.MsgType } -func (a *AlterAliasTask) BeginTs() Timestamp { - return a.Base.Timestamp +func (t *AlterAliasTask) BeginTs() Timestamp { + return t.Base.Timestamp } -func (a *AlterAliasTask) EndTs() Timestamp { - return a.Base.Timestamp +func (t *AlterAliasTask) EndTs() Timestamp { + return t.Base.Timestamp } -func (a *AlterAliasTask) SetTs(ts Timestamp) { - a.Base.Timestamp = ts +func (t *AlterAliasTask) SetTs(ts Timestamp) { + t.Base.Timestamp = ts } -func (a *AlterAliasTask) OnEnqueue() error { - a.Base = commonpbutil.NewMsgBase() +func (t *AlterAliasTask) OnEnqueue() error { + if t.Base == nil { + t.Base = commonpbutil.NewMsgBase() + } return nil } -func (a *AlterAliasTask) PreExecute(ctx context.Context) error { - a.Base.MsgType = commonpb.MsgType_AlterAlias - a.Base.SourceID = paramtable.GetNodeID() +func (t *AlterAliasTask) PreExecute(ctx context.Context) error { + t.Base.MsgType = commonpb.MsgType_AlterAlias + 
t.Base.SourceID = paramtable.GetNodeID() - collAlias := a.Alias + collAlias := t.Alias // collection alias uses the same format as collection name if err := ValidateCollectionAlias(collAlias); err != nil { return err } - collName := a.CollectionName + collName := t.CollectionName if err := validateCollectionName(collName); err != nil { return err } @@ -2015,13 +2069,13 @@ func (a *AlterAliasTask) PreExecute(ctx context.Context) error { return nil } -func (a *AlterAliasTask) Execute(ctx context.Context) error { +func (t *AlterAliasTask) Execute(ctx context.Context) error { var err error - a.result, err = a.rootCoord.AlterAlias(ctx, a.AlterAliasRequest) + t.result, err = t.rootCoord.AlterAlias(ctx, t.AlterAliasRequest) return err } -func (a *AlterAliasTask) PostExecute(ctx context.Context) error { +func (t *AlterAliasTask) PostExecute(ctx context.Context) error { return nil } @@ -2029,7 +2083,7 @@ type CreateResourceGroupTask struct { Condition *milvuspb.CreateResourceGroupRequest ctx context.Context - queryCoord types.QueryCoord + queryCoord types.QueryCoordClient result *commonpb.Status } @@ -2066,7 +2120,9 @@ func (t *CreateResourceGroupTask) SetTs(ts Timestamp) { } func (t *CreateResourceGroupTask) OnEnqueue() error { - t.Base = commonpbutil.NewMsgBase() + if t.Base == nil { + t.Base = commonpbutil.NewMsgBase() + } return nil } @@ -2091,7 +2147,7 @@ type DropResourceGroupTask struct { Condition *milvuspb.DropResourceGroupRequest ctx context.Context - queryCoord types.QueryCoord + queryCoord types.QueryCoordClient result *commonpb.Status } @@ -2128,7 +2184,9 @@ func (t *DropResourceGroupTask) SetTs(ts Timestamp) { } func (t *DropResourceGroupTask) OnEnqueue() error { - t.Base = commonpbutil.NewMsgBase() + if t.Base == nil { + t.Base = commonpbutil.NewMsgBase() + } return nil } @@ -2153,7 +2211,7 @@ type DescribeResourceGroupTask struct { Condition *milvuspb.DescribeResourceGroupRequest ctx context.Context - queryCoord types.QueryCoord + queryCoord 
types.QueryCoordClient result *milvuspb.DescribeResourceGroupResponse } @@ -2230,7 +2288,7 @@ func (t *DescribeResourceGroupTask) Execute(ctx context.Context) error { return ret, nil } - if resp.Status.ErrorCode == commonpb.ErrorCode_Success { + if resp.GetStatus().GetErrorCode() == commonpb.ErrorCode_Success { rgInfo := resp.GetResourceGroup() numLoadedReplica, err := getCollectionName(rgInfo.NumLoadedReplica) @@ -2274,7 +2332,7 @@ type TransferNodeTask struct { Condition *milvuspb.TransferNodeRequest ctx context.Context - queryCoord types.QueryCoord + queryCoord types.QueryCoordClient result *commonpb.Status } @@ -2311,7 +2369,9 @@ func (t *TransferNodeTask) SetTs(ts Timestamp) { } func (t *TransferNodeTask) OnEnqueue() error { - t.Base = commonpbutil.NewMsgBase() + if t.Base == nil { + t.Base = commonpbutil.NewMsgBase() + } return nil } @@ -2336,7 +2396,7 @@ type TransferReplicaTask struct { Condition *milvuspb.TransferReplicaRequest ctx context.Context - queryCoord types.QueryCoord + queryCoord types.QueryCoordClient result *commonpb.Status } @@ -2373,7 +2433,9 @@ func (t *TransferReplicaTask) SetTs(ts Timestamp) { } func (t *TransferReplicaTask) OnEnqueue() error { - t.Base = commonpbutil.NewMsgBase() + if t.Base == nil { + t.Base = commonpbutil.NewMsgBase() + } return nil } @@ -2407,7 +2469,7 @@ type ListResourceGroupsTask struct { Condition *milvuspb.ListResourceGroupsRequest ctx context.Context - queryCoord types.QueryCoord + queryCoord types.QueryCoordClient result *milvuspb.ListResourceGroupsResponse } diff --git a/internal/proxy/task_database.go b/internal/proxy/task_database.go index b72b7da9e79dd..fc8bb33711ff7 100644 --- a/internal/proxy/task_database.go +++ b/internal/proxy/task_database.go @@ -6,6 +6,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/mq/msgstream" 
"github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -14,8 +15,10 @@ type createDatabaseTask struct { Condition *milvuspb.CreateDatabaseRequest ctx context.Context - rootCoord types.RootCoord + rootCoord types.RootCoordClient result *commonpb.Status + + replicateMsgStream msgstream.MsgStream } func (cdt *createDatabaseTask) TraceCtx() context.Context { @@ -51,7 +54,9 @@ func (cdt *createDatabaseTask) SetTs(ts Timestamp) { } func (cdt *createDatabaseTask) OnEnqueue() error { - cdt.Base = commonpbutil.NewMsgBase() + if cdt.Base == nil { + cdt.Base = commonpbutil.NewMsgBase() + } cdt.Base.MsgType = commonpb.MsgType_CreateDatabase cdt.Base.SourceID = paramtable.GetNodeID() return nil @@ -63,8 +68,10 @@ func (cdt *createDatabaseTask) PreExecute(ctx context.Context) error { func (cdt *createDatabaseTask) Execute(ctx context.Context) error { var err error - cdt.result = &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError} cdt.result, err = cdt.rootCoord.CreateDatabase(ctx, cdt.CreateDatabaseRequest) + if cdt.result != nil && cdt.result.ErrorCode == commonpb.ErrorCode_Success { + SendReplicateMessagePack(ctx, cdt.replicateMsgStream, cdt.CreateDatabaseRequest) + } return err } @@ -76,8 +83,10 @@ type dropDatabaseTask struct { Condition *milvuspb.DropDatabaseRequest ctx context.Context - rootCoord types.RootCoord + rootCoord types.RootCoordClient result *commonpb.Status + + replicateMsgStream msgstream.MsgStream } func (ddt *dropDatabaseTask) TraceCtx() context.Context { @@ -113,7 +122,9 @@ func (ddt *dropDatabaseTask) SetTs(ts Timestamp) { } func (ddt *dropDatabaseTask) OnEnqueue() error { - ddt.Base = commonpbutil.NewMsgBase() + if ddt.Base == nil { + ddt.Base = commonpbutil.NewMsgBase() + } ddt.Base.MsgType = commonpb.MsgType_DropDatabase ddt.Base.SourceID = paramtable.GetNodeID() return nil @@ -125,11 +136,11 @@ func (ddt *dropDatabaseTask) PreExecute(ctx context.Context) error { func (ddt 
*dropDatabaseTask) Execute(ctx context.Context) error { var err error - ddt.result = &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError} ddt.result, err = ddt.rootCoord.DropDatabase(ctx, ddt.DropDatabaseRequest) if ddt.result != nil && ddt.result.ErrorCode == commonpb.ErrorCode_Success { globalMetaCache.RemoveDatabase(ctx, ddt.DbName) + SendReplicateMessagePack(ctx, ddt.replicateMsgStream, ddt.DropDatabaseRequest) } return err } @@ -142,7 +153,7 @@ type listDatabaseTask struct { Condition *milvuspb.ListDatabasesRequest ctx context.Context - rootCoord types.RootCoord + rootCoord types.RootCoordClient result *milvuspb.ListDatabasesResponse } @@ -191,11 +202,6 @@ func (ldt *listDatabaseTask) PreExecute(ctx context.Context) error { func (ldt *listDatabaseTask) Execute(ctx context.Context) error { var err error - ldt.result = &milvuspb.ListDatabasesResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - }, - } ldt.result, err = ldt.rootCoord.ListDatabases(ctx, ldt.ListDatabasesRequest) return err } diff --git a/internal/proxy/task_database_test.go b/internal/proxy/task_database_test.go index fd211ff518fa2..c65393bab2262 100644 --- a/internal/proxy/task_database_test.go +++ b/internal/proxy/task_database_test.go @@ -4,19 +4,18 @@ import ( "context" "testing" - "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) func TestCreateDatabaseTask(t *testing.T) { paramtable.Init() rc := NewRootCoordMock() - rc.Start() - defer rc.Stop() + defer rc.Close() ctx := context.Background() task := &createDatabaseTask{ @@ -46,6 +45,7 @@ func TestCreateDatabaseTask(t *testing.T) { err = task.Execute(ctx) assert.NoError(t, err) + task.Base = nil err = task.OnEnqueue() assert.NoError(t, err) assert.Equal(t, 
paramtable.GetNodeID(), task.GetBase().GetSourceID()) @@ -62,8 +62,7 @@ func TestCreateDatabaseTask(t *testing.T) { func TestDropDatabaseTask(t *testing.T) { paramtable.Init() rc := NewRootCoordMock() - rc.Start() - defer rc.Stop() + defer rc.Close() ctx := context.Background() task := &dropDatabaseTask{ @@ -100,6 +99,7 @@ func TestDropDatabaseTask(t *testing.T) { err = task.Execute(ctx) assert.NoError(t, err) + task.Base = nil err = task.OnEnqueue() assert.NoError(t, err) assert.Equal(t, paramtable.GetNodeID(), task.GetBase().GetSourceID()) @@ -116,8 +116,7 @@ func TestDropDatabaseTask(t *testing.T) { func TestListDatabaseTask(t *testing.T) { paramtable.Init() rc := NewRootCoordMock() - rc.Start() - defer rc.Stop() + defer rc.Close() ctx := context.Background() task := &listDatabaseTask{ diff --git a/internal/proxy/task_delete.go b/internal/proxy/task_delete.go index ba804a79dab41..75485c1ffde5a 100644 --- a/internal/proxy/task_delete.go +++ b/internal/proxy/task_delete.go @@ -3,9 +3,10 @@ package proxy import ( "context" "fmt" - "strconv" + "io" "github.com/cockroachdb/errors" + "github.com/golang/protobuf/proto" "go.opentelemetry.io/otel" "go.uber.org/zap" @@ -14,10 +15,13 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/allocator" + "github.com/milvus-io/milvus/internal/parser/planparserv2" + "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/planpb" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/merr" @@ -30,19 +34,28 @@ type BaseDeleteTask = msgstream.DeleteMsg type deleteTask struct { Condition 
- deleteMsg *BaseDeleteTask - ctx context.Context - deleteExpr string - //req *milvuspb.DeleteRequest - result *milvuspb.MutationResult + ctx context.Context + tr *timerecord.TimeRecorder + + req *milvuspb.DeleteRequest + result *milvuspb.MutationResult + + // channel chMgr channelsMgr chTicker channelsTimeTicker - vChannels []vChan pChannels []pChan + vChannels []vChan - idAllocator *allocator.IDAllocator - collectionID UniqueID + idAllocator *allocator.IDAllocator + lb LBPolicy + + // delete info schema *schemapb.CollectionSchema + ts Timestamp + msgID UniqueID + collectionID UniqueID + partitionID UniqueID + count int } func (dt *deleteTask) TraceCtx() context.Context { @@ -50,15 +63,15 @@ func (dt *deleteTask) TraceCtx() context.Context { } func (dt *deleteTask) ID() UniqueID { - return dt.deleteMsg.Base.MsgID + return dt.msgID } func (dt *deleteTask) SetID(uid UniqueID) { - dt.deleteMsg.Base.MsgID = uid + dt.msgID = uid } func (dt *deleteTask) Type() commonpb.MsgType { - return dt.deleteMsg.Base.MsgType + return commonpb.MsgType_Delete } func (dt *deleteTask) Name() string { @@ -66,24 +79,23 @@ func (dt *deleteTask) Name() string { } func (dt *deleteTask) BeginTs() Timestamp { - return dt.deleteMsg.Base.Timestamp + return dt.ts } func (dt *deleteTask) EndTs() Timestamp { - return dt.deleteMsg.Base.Timestamp + return dt.ts } func (dt *deleteTask) SetTs(ts Timestamp) { - dt.deleteMsg.Base.Timestamp = ts + dt.ts = ts } func (dt *deleteTask) OnEnqueue() error { - dt.deleteMsg.Base = commonpbutil.NewMsgBase() return nil } func (dt *deleteTask) setChannels() error { - collID, err := globalMetaCache.GetCollectionID(dt.ctx, dt.deleteMsg.GetDbName(), dt.deleteMsg.CollectionName) + collID, err := globalMetaCache.GetCollectionID(dt.ctx, dt.req.GetDbName(), dt.req.GetCollectionName()) if err != nil { return err } @@ -99,27 +111,20 @@ func (dt *deleteTask) getChannels() []pChan { return dt.pChannels } -func getPrimaryKeysFromExpr(schema *schemapb.CollectionSchema, expr 
string) (res *schemapb.IDs, rowNum int64, err error) { - if len(expr) == 0 { - log.Warn("empty expr") - return - } - - plan, err := createExprPlan(schema, expr) - if err != nil { - return res, 0, fmt.Errorf("failed to create expr plan, expr = %s", expr) - } - - // delete request only support expr "id in [a, b]" - termExpr, ok := plan.Node.(*planpb.PlanNode_Predicates).Predicates.Expr.(*planpb.Expr_TermExpr) +func getExpr(plan *planpb.PlanNode) (bool, *planpb.Expr_TermExpr) { + // simple delete request need expr with "pk in [a, b]" + termExpr, ok := plan.Node.(*planpb.PlanNode_Query).Query.Predicates.Expr.(*planpb.Expr_TermExpr) if !ok { - return res, 0, fmt.Errorf("invalid plan node type, only pk in [1, 2] supported") + return false, nil } if !termExpr.TermExpr.GetColumnInfo().GetIsPrimaryKey() { - return res, 0, fmt.Errorf("invalid expression, we only support to delete by pk, expr: %s", expr) + return false, nil } + return true, termExpr +} +func getPrimaryKeysFromExpr(schema *schemapb.CollectionSchema, termExpr *planpb.Expr_TermExpr) (res *schemapb.IDs, rowNum int64, err error) { res = &schemapb.IDs{} rowNum = int64(len(termExpr.TermExpr.Values)) switch termExpr.TermExpr.ColumnInfo.GetDataType() { @@ -151,11 +156,8 @@ func getPrimaryKeysFromExpr(schema *schemapb.CollectionSchema, expr string) (res } func (dt *deleteTask) PreExecute(ctx context.Context) error { - dt.deleteMsg.Base.MsgType = commonpb.MsgType_Delete - dt.deleteMsg.Base.SourceID = paramtable.GetNodeID() - dt.result = &milvuspb.MutationResult{ - Status: merr.Status(nil), + Status: merr.Success(), IDs: &schemapb.IDs{ IdField: nil, }, @@ -163,65 +165,51 @@ func (dt *deleteTask) PreExecute(ctx context.Context) error { } log := log.Ctx(ctx) - - collName := dt.deleteMsg.CollectionName + collName := dt.req.GetCollectionName() if err := validateCollectionName(collName); err != nil { return ErrWithLog(log, "Invalid collection name", err) } - collID, err := globalMetaCache.GetCollectionID(ctx, 
dt.deleteMsg.GetDbName(), collName) + collID, err := globalMetaCache.GetCollectionID(ctx, dt.req.GetDbName(), collName) if err != nil { return ErrWithLog(log, "Failed to get collection id", err) } - dt.deleteMsg.CollectionID = collID dt.collectionID = collID - partitionKeyMode, err := isPartitionKeyMode(ctx, dt.deleteMsg.GetDbName(), dt.deleteMsg.CollectionName) + partitionKeyMode, err := isPartitionKeyMode(ctx, dt.req.GetDbName(), dt.req.GetCollectionName()) if err != nil { return ErrWithLog(log, "Failed to get partition key mode", err) } - if partitionKeyMode && len(dt.deleteMsg.PartitionName) != 0 { - return ErrWithLog(log, "", errors.New("not support manually specifying the partition names if partition key mode is used")) + if partitionKeyMode && len(dt.req.PartitionName) != 0 { + return errors.New("not support manually specifying the partition names if partition key mode is used") } // If partitionName is not empty, partitionID will be set. - if len(dt.deleteMsg.PartitionName) > 0 { - partName := dt.deleteMsg.PartitionName + if len(dt.req.PartitionName) > 0 { + partName := dt.req.GetPartitionName() if err := validatePartitionTag(partName, true); err != nil { return ErrWithLog(log, "Invalid partition name", err) } - partID, err := globalMetaCache.GetPartitionID(ctx, dt.deleteMsg.GetDbName(), collName, partName) + partID, err := globalMetaCache.GetPartitionID(ctx, dt.req.GetDbName(), collName, partName) if err != nil { return ErrWithLog(log, "Failed to get partition id", err) } - dt.deleteMsg.PartitionID = partID + dt.partitionID = partID } else { - dt.deleteMsg.PartitionID = common.InvalidPartitionID + dt.partitionID = common.InvalidPartitionID } - schema, err := globalMetaCache.GetCollectionSchema(ctx, dt.deleteMsg.GetDbName(), collName) + schema, err := globalMetaCache.GetCollectionSchema(ctx, dt.req.GetDbName(), collName) if err != nil { return ErrWithLog(log, "Failed to get collection schema", err) } dt.schema = schema - // get delete.primaryKeys from 
delete expr - primaryKeys, numRow, err := getPrimaryKeysFromExpr(schema, dt.deleteExpr) + // hash primary keys to channels + channelNames, err := dt.chMgr.getVChannels(dt.collectionID) if err != nil { return ErrWithLog(log, "Failed to get primary keys from expr", err) } - - dt.deleteMsg.NumRows = numRow - dt.deleteMsg.PrimaryKeys = primaryKeys - log.Debug("get primary keys from expr", zap.Int64("len of primary keys", dt.deleteMsg.NumRows)) - - // set result - dt.result.IDs = primaryKeys - dt.result.DeleteCnt = dt.deleteMsg.NumRows - - dt.deleteMsg.Timestamps = make([]uint64, numRow) - for index := range dt.deleteMsg.Timestamps { - dt.deleteMsg.Timestamps[index] = dt.BeginTs() - } + dt.vChannels = channelNames log.Debug("pre delete done", zap.Int64("collection_id", dt.collectionID)) @@ -233,70 +221,166 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) { defer sp.End() log := log.Ctx(ctx) - tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute delete %d", dt.ID())) + if len(dt.req.GetExpr()) == 0 { + return merr.WrapErrParameterInvalid("valid expr", "empty expr", "invalid expression") + } + + dt.tr = timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute delete %d", dt.ID())) + stream, err := dt.chMgr.getOrCreateDmlStream(dt.collectionID) + if err != nil { + return err + } - collID := dt.deleteMsg.CollectionID - stream, err := dt.chMgr.getOrCreateDmlStream(collID) + plan, err := planparserv2.CreateRetrievePlan(dt.schema, dt.req.Expr) + if err != nil { + return fmt.Errorf("failed to create expr plan, expr = %s", dt.req.GetExpr()) + } + + isSimple, termExp := getExpr(plan) + if isSimple { + // if could get delete.primaryKeys from delete expr + err := dt.simpleDelete(ctx, termExp, stream) + if err != nil { + return err + } + } else { + // if get complex delete expr + // need query from querynode before delete + err = dt.complexDelete(ctx, plan, stream) + if err != nil { + log.Warn("complex delete failed,but delete some data", zap.Int("count", 
dt.count), zap.String("expr", dt.req.GetExpr())) + return err + } + } + + return nil +} + +func (dt *deleteTask) PostExecute(ctx context.Context) error { + return nil +} + +func (dt *deleteTask) getStreamingQueryAndDelteFunc(stream msgstream.MsgStream, plan *planpb.PlanNode) executeFunc { + return func(ctx context.Context, nodeID int64, qn types.QueryNodeClient, channelIDs ...string) error { + // outputField := translateOutputFields(, dt.schema, true) + + partationIDs := []int64{} + if dt.partitionID != common.InvalidFieldID { + partationIDs = append(partationIDs, dt.partitionID) + } + + // set plan + _, outputFieldIDs := translatePkOutputFields(dt.schema) + outputFieldIDs = append(outputFieldIDs, common.TimeStampField) + plan.OutputFieldIds = outputFieldIDs + + serializedPlan, err := proto.Marshal(plan) + if err != nil { + return err + } + + queryReq := &querypb.QueryRequest{ + Req: &internalpb.RetrieveRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_Retrieve), + commonpbutil.WithMsgID(dt.msgID), + commonpbutil.WithSourceID(paramtable.GetNodeID()), + commonpbutil.WithTargetID(nodeID), + ), + MvccTimestamp: dt.ts, + ReqID: paramtable.GetNodeID(), + DbID: 0, // TODO + CollectionID: dt.collectionID, + PartitionIDs: partationIDs, + SerializedExprPlan: serializedPlan, + OutputFieldsId: outputFieldIDs, + GuaranteeTimestamp: parseGuaranteeTsFromConsistency(dt.ts, dt.ts, commonpb.ConsistencyLevel_Bounded), + }, + DmlChannels: channelIDs, + Scope: querypb.DataScope_All, + } + + client, err := qn.QueryStream(ctx, queryReq) + if err != nil { + log.Warn("query for delete return error", zap.Error(err)) + return err + } + + for { + result, err := client.Recv() + if err != nil { + if err == io.EOF { + return nil + } + return err + } + + err = merr.Error(result.GetStatus()) + if err != nil { + return err + } + + err = dt.produce(ctx, stream, result.GetIds()) + if err != nil { + return err + } + } + } +} + +func (dt *deleteTask) 
complexDelete(ctx context.Context, plan *planpb.PlanNode, stream msgstream.MsgStream) error { + err := dt.lb.Execute(ctx, CollectionWorkLoad{ + db: dt.req.GetDbName(), + collectionName: dt.req.GetCollectionName(), + collectionID: dt.collectionID, + nq: 1, + exec: dt.getStreamingQueryAndDelteFunc(stream, plan), + }) if err != nil { log.Warn("fail to get or create dml stream", zap.Error(err)) return err } - // hash primary keys to channels - channelNames, err := dt.chMgr.getVChannels(collID) + return nil +} + +func (dt *deleteTask) simpleDelete(ctx context.Context, termExp *planpb.Expr_TermExpr, stream msgstream.MsgStream) error { + primaryKeys, numRow, err := getPrimaryKeysFromExpr(dt.schema, termExp) if err != nil { - log.Warn("get vChannels failed", zap.Int64("collectionID", collID), zap.Error(err)) - dt.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError - dt.result.Status.Reason = err.Error() + log.Info("Failed to get primary keys from expr", zap.Error(err)) return err } - dt.deleteMsg.HashValues = typeutil.HashPK2Channels(dt.result.IDs, channelNames) + log.Debug("get primary keys from expr", zap.Int64("len of primary keys", numRow)) + err = dt.produce(ctx, stream, primaryKeys) + if err != nil { + return err + } + return nil +} +func (dt *deleteTask) produce(ctx context.Context, stream msgstream.MsgStream, primaryKeys *schemapb.IDs) error { + hashValues := typeutil.HashPK2Channels(primaryKeys, dt.vChannels) // repack delete msg by dmChannel result := make(map[uint32]msgstream.TsMsg) - collectionName := dt.deleteMsg.CollectionName - collectionID := dt.deleteMsg.CollectionID - partitionID := dt.deleteMsg.PartitionID - partitionName := dt.deleteMsg.PartitionName - proxyID := dt.deleteMsg.Base.SourceID - for index, key := range dt.deleteMsg.HashValues { - vchannel := channelNames[key] - ts := dt.deleteMsg.Timestamps[index] + numRows := int64(0) + for index, key := range hashValues { + vchannel := dt.vChannels[key] _, ok := result[key] if !ok { - msgid, err 
:= dt.idAllocator.AllocOne() + deleteMsg, err := dt.newDeleteMsg(ctx) if err != nil { - return errors.Wrap(err, "failed to allocate MsgID of delete") - } - sliceRequest := msgpb.DeleteRequest{ - Base: commonpbutil.NewMsgBase( - commonpbutil.WithMsgType(commonpb.MsgType_Delete), - // msgid of delete msg must be set - // or it will be seen as duplicated msg in mq - commonpbutil.WithMsgID(msgid), - commonpbutil.WithTimeStamp(ts), - commonpbutil.WithSourceID(proxyID), - ), - CollectionID: collectionID, - PartitionID: partitionID, - CollectionName: collectionName, - PartitionName: partitionName, - PrimaryKeys: &schemapb.IDs{}, - } - deleteMsg := &msgstream.DeleteMsg{ - BaseMsg: msgstream.BaseMsg{ - Ctx: ctx, - }, - DeleteRequest: sliceRequest, + return err } + deleteMsg.ShardName = vchannel result[key] = deleteMsg } curMsg := result[key].(*msgstream.DeleteMsg) - curMsg.HashValues = append(curMsg.HashValues, dt.deleteMsg.HashValues[index]) - curMsg.Timestamps = append(curMsg.Timestamps, dt.deleteMsg.Timestamps[index]) - typeutil.AppendIDs(curMsg.PrimaryKeys, dt.deleteMsg.PrimaryKeys, index) + curMsg.HashValues = append(curMsg.HashValues, hashValues[index]) + curMsg.Timestamps = append(curMsg.Timestamps, dt.ts) + + typeutil.AppendIDs(curMsg.PrimaryKeys, primaryKeys, index) curMsg.NumRows++ - curMsg.ShardName = vchannel + numRows++ } // send delete request to log broker @@ -304,6 +388,7 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) { BeginTs: dt.BeginTs(), EndTs: dt.EndTs(), } + for _, msg := range result { if msg != nil { msgPack.Msgs = append(msgPack.Msgs, msg) @@ -311,24 +396,44 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) { } log.Debug("send delete request to virtual channels", - zap.String("collectionName", dt.deleteMsg.GetCollectionName()), - zap.Int64("collectionID", collID), - zap.Strings("virtual_channels", channelNames), + zap.String("collectionName", dt.req.GetCollectionName()), + zap.Int64("collectionID", 
dt.collectionID), + zap.Strings("virtual_channels", dt.vChannels), zap.Int64("taskID", dt.ID()), - zap.Duration("prepare duration", tr.RecordSpan())) + zap.Duration("prepare duration", dt.tr.RecordSpan())) - err = stream.Produce(msgPack) + err := stream.Produce(msgPack) if err != nil { - dt.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError - dt.result.Status.Reason = err.Error() return err } - sendMsgDur := tr.ElapseSpan() - metrics.ProxySendMutationReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.DeleteLabel).Observe(float64(sendMsgDur.Milliseconds())) - + dt.result.DeleteCnt += numRows return nil } -func (dt *deleteTask) PostExecute(ctx context.Context) error { - return nil +func (dt *deleteTask) newDeleteMsg(ctx context.Context) (*msgstream.DeleteMsg, error) { + msgid, err := dt.idAllocator.AllocOne() + if err != nil { + return nil, errors.Wrap(err, "failed to allocate MsgID of delete") + } + sliceRequest := msgpb.DeleteRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_Delete), + // msgid of delete msg must be set + // or it will be seen as duplicated msg in mq + commonpbutil.WithMsgID(msgid), + commonpbutil.WithTimeStamp(dt.ts), + commonpbutil.WithSourceID(paramtable.GetNodeID()), + ), + CollectionID: dt.collectionID, + PartitionID: dt.partitionID, + CollectionName: dt.req.GetCollectionName(), + PartitionName: dt.req.GetPartitionName(), + PrimaryKeys: &schemapb.IDs{}, + } + return &msgstream.DeleteMsg{ + BaseMsg: msgstream.BaseMsg{ + Ctx: ctx, + }, + DeleteRequest: sliceRequest, + }, nil } diff --git a/internal/proxy/task_delete_test.go b/internal/proxy/task_delete_test.go index c9ec7b4304fbe..e09b89b8fef37 100644 --- a/internal/proxy/task_delete_test.go +++ b/internal/proxy/task_delete_test.go @@ -2,108 +2,133 @@ package proxy import ( "context" - "fmt" "testing" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/cockroachdb/errors" - 
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "google.golang.org/grpc" - "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/allocator" + "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/parser/planparserv2" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/planpb" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/internal/util/streamrpc" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) -func Test_getPrimaryKeysFromExpr(t *testing.T) { - t.Run("delete on non-pk field", func(t *testing.T) { - schema := &schemapb.CollectionSchema{ - Name: "test_delete", - Description: "", - AutoID: false, - Fields: []*schemapb.FieldSchema{ - { - FieldID: common.StartOfUserFieldID, - Name: "pk", - IsPrimaryKey: true, - DataType: schemapb.DataType_Int64, - }, - { - FieldID: common.StartOfUserFieldID + 1, - Name: "non_pk", - IsPrimaryKey: false, - DataType: schemapb.DataType_Int64, - }, +func Test_GetExpr(t *testing.T) { + schema := &schemapb.CollectionSchema{ + Name: "test_delete", + Description: "", + AutoID: false, + Fields: []*schemapb.FieldSchema{ + { + FieldID: common.StartOfUserFieldID, + Name: "pk", + IsPrimaryKey: true, + DataType: schemapb.DataType_Int64, }, - } + { + FieldID: common.StartOfUserFieldID + 1, + Name: "non_pk", + IsPrimaryKey: false, + DataType: schemapb.DataType_Int64, + }, + }, + } + t.Run("delelte with complex pk expr", func(t *testing.T) { + expr := "pk < 4" + plan, err := planparserv2.CreateRetrievePlan(schema, expr) + assert.NoError(t, err) + isSimple, _ := 
getExpr(plan) + assert.False(t, isSimple) + }) + t.Run("delete with no-pk field expr", func(t *testing.T) { expr := "non_pk in [1, 2, 3]" + plan, err := planparserv2.CreateRetrievePlan(schema, expr) + assert.NoError(t, err) + isSimple, _ := getExpr(plan) + assert.False(t, isSimple) + }) - _, _, err := getPrimaryKeysFromExpr(schema, expr) - assert.Error(t, err) + t.Run("delete with simple expr", func(t *testing.T) { + expr := "pk in [1, 2, 3]" + plan, err := planparserv2.CreateRetrievePlan(schema, expr) + assert.NoError(t, err) + isSimple, _ := getExpr(plan) + assert.True(t, isSimple) }) } -func TestDeleteTask(t *testing.T) { - t.Run("test getChannels", func(t *testing.T) { - collectionID := UniqueID(0) - collectionName := "col-0" - channels := []pChan{"mock-chan-0", "mock-chan-1"} - cache := NewMockCache(t) - cache.On("GetCollectionID", - mock.Anything, // context.Context - mock.AnythingOfType("string"), - mock.AnythingOfType("string"), - ).Return(collectionID, nil) - globalMetaCache = cache - chMgr := newMockChannelsMgr() - chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) { - return channels, nil - } - dt := deleteTask{ - ctx: context.Background(), - deleteMsg: &msgstream.DeleteMsg{ - DeleteRequest: msgpb.DeleteRequest{ - CollectionName: collectionName, - }, - }, - chMgr: chMgr, - } - err := dt.setChannels() - assert.NoError(t, err) - resChannels := dt.getChannels() - assert.ElementsMatch(t, channels, resChannels) - assert.ElementsMatch(t, channels, dt.pChannels) +func TestDeleteTask_GetChannels(t *testing.T) { + collectionID := UniqueID(0) + collectionName := "col-0" + channels := []pChan{"mock-chan-0", "mock-chan-1"} + cache := NewMockCache(t) + cache.On("GetCollectionID", + mock.Anything, // context.Context + mock.AnythingOfType("string"), + mock.AnythingOfType("string"), + ).Return(collectionID, nil) + globalMetaCache = cache + chMgr := NewMockChannelsMgr(t) + chMgr.EXPECT().getChannels(mock.Anything).Return(channels, nil) + dt := 
deleteTask{ + ctx: context.Background(), + req: &milvuspb.DeleteRequest{ + CollectionName: collectionName, + }, + chMgr: chMgr, + } + err := dt.setChannels() + assert.NoError(t, err) + resChannels := dt.getChannels() + assert.ElementsMatch(t, channels, resChannels) + assert.ElementsMatch(t, channels, dt.pChannels) +} - chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) { - return nil, fmt.Errorf("mock err") - } - // get channels again, should return task's pChannels, so getChannelsFunc should not invoke again - resChannels = dt.getChannels() - assert.ElementsMatch(t, channels, resChannels) - }) +func TestDeleteTask_PreExecute(t *testing.T) { + schema := &schemapb.CollectionSchema{ + Name: "test_delete", + Description: "", + AutoID: false, + Fields: []*schemapb.FieldSchema{ + { + FieldID: common.StartOfUserFieldID, + Name: "pk", + IsPrimaryKey: true, + DataType: schemapb.DataType_Int64, + }, + { + FieldID: common.StartOfUserFieldID + 1, + Name: "non_pk", + IsPrimaryKey: false, + DataType: schemapb.DataType_Int64, + }, + }, + } t.Run("empty collection name", func(t *testing.T) { - dt := deleteTask{ - deleteMsg: &BaseDeleteTask{ - DeleteRequest: msgpb.DeleteRequest{ - Base: &commonpb.MsgBase{}, - }, - }, - } + dt := deleteTask{} assert.Error(t, dt.PreExecute(context.Background())) }) t.Run("fail to get collection id", func(t *testing.T) { - dt := deleteTask{deleteMsg: &BaseDeleteTask{ - DeleteRequest: msgpb.DeleteRequest{ - Base: &commonpb.MsgBase{}, + dt := deleteTask{ + req: &milvuspb.DeleteRequest{ CollectionName: "foo", }, - }} + } cache := NewMockCache(t) cache.On("GetCollectionID", mock.Anything, // context.Context @@ -115,12 +140,9 @@ func TestDeleteTask(t *testing.T) { }) t.Run("fail partition key mode", func(t *testing.T) { - dt := deleteTask{deleteMsg: &BaseDeleteTask{ - DeleteRequest: msgpb.DeleteRequest{ - Base: &commonpb.MsgBase{}, - CollectionName: "foo", - DbName: "db_1", - }, + dt := deleteTask{req: &milvuspb.DeleteRequest{ + 
CollectionName: "foo", + DbName: "db_1", }} cache := NewMockCache(t) cache.On("GetCollectionID", @@ -139,13 +161,10 @@ func TestDeleteTask(t *testing.T) { }) t.Run("invalid partition name", func(t *testing.T) { - dt := deleteTask{deleteMsg: &BaseDeleteTask{ - DeleteRequest: msgpb.DeleteRequest{ - Base: &commonpb.MsgBase{}, - CollectionName: "foo", - DbName: "db_1", - PartitionName: "aaa", - }, + dt := deleteTask{req: &milvuspb.DeleteRequest{ + CollectionName: "foo", + DbName: "db_1", + PartitionName: "aaa", }} cache := NewMockCache(t) cache.On("GetCollectionID", @@ -176,37 +195,14 @@ func TestDeleteTask(t *testing.T) { assert.Error(t, dt.PreExecute(context.Background())) }) - schema := &schemapb.CollectionSchema{ - Name: "test_delete", - Description: "", - AutoID: false, - Fields: []*schemapb.FieldSchema{ - { - FieldID: common.StartOfUserFieldID, - Name: "pk", - IsPrimaryKey: true, - DataType: schemapb.DataType_Int64, - }, - { - FieldID: common.StartOfUserFieldID + 1, - Name: "non_pk", - IsPrimaryKey: false, - DataType: schemapb.DataType_Int64, - }, - }, - } - t.Run("invalie partition", func(t *testing.T) { dt := deleteTask{ - deleteMsg: &BaseDeleteTask{ - DeleteRequest: msgpb.DeleteRequest{ - Base: &commonpb.MsgBase{}, - CollectionName: "foo", - DbName: "db_1", - PartitionName: "_aaa", - }, + req: &milvuspb.DeleteRequest{ + CollectionName: "foo", + DbName: "db_1", + PartitionName: "aaa", + Expr: "non_pk in [1, 2, 3]", }, - deleteExpr: "non_pk in [1, 2, 3]", } cache := NewMockCache(t) cache.On("GetCollectionID", @@ -229,7 +225,7 @@ func TestDeleteTask(t *testing.T) { globalMetaCache = cache assert.Error(t, dt.PreExecute(context.Background())) - dt.deleteMsg.PartitionName = "aaa" + dt.req.PartitionName = "aaa" assert.Error(t, dt.PreExecute(context.Background())) cache.On("GetPartitionID", @@ -241,3 +237,430 @@ func TestDeleteTask(t *testing.T) { assert.Error(t, dt.PreExecute(context.Background())) }) } + +func TestDeleteTask_Execute(t *testing.T) { + collectionName 
:= "test_delete" + collectionID := int64(111) + partitionName := "default" + partitionID := int64(222) + channels := []string{"test_channel"} + dbName := "test_1" + + schema := &schemapb.CollectionSchema{ + Name: collectionName, + Description: "", + AutoID: false, + Fields: []*schemapb.FieldSchema{ + { + FieldID: common.StartOfUserFieldID, + Name: "pk", + IsPrimaryKey: true, + DataType: schemapb.DataType_Int64, + }, + { + FieldID: common.StartOfUserFieldID + 1, + Name: "non_pk", + IsPrimaryKey: false, + DataType: schemapb.DataType_Int64, + }, + }, + } + t.Run("empty expr", func(t *testing.T) { + dt := deleteTask{} + assert.Error(t, dt.Execute(context.Background())) + }) + + t.Run("get channel failed", func(t *testing.T) { + mockMgr := NewMockChannelsMgr(t) + dt := deleteTask{ + chMgr: mockMgr, + req: &milvuspb.DeleteRequest{ + Expr: "pk in [1,2]", + }, + } + + mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(nil, errors.New("mock error")) + assert.Error(t, dt.Execute(context.Background())) + }) + + t.Run("create plan failed", func(t *testing.T) { + mockMgr := NewMockChannelsMgr(t) + dt := deleteTask{ + chMgr: mockMgr, + schema: schema, + req: &milvuspb.DeleteRequest{ + Expr: "????", + }, + } + stream := msgstream.NewMockMsgStream(t) + mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil) + assert.Error(t, dt.Execute(context.Background())) + }) + + t.Run("alloc failed", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + mockMgr := NewMockChannelsMgr(t) + rc := mocks.NewMockRootCoordClient(t) + allocator, err := allocator.NewIDAllocator(ctx, rc, paramtable.GetNodeID()) + assert.NoError(t, err) + allocator.Close() + + dt := deleteTask{ + chMgr: mockMgr, + schema: schema, + collectionID: collectionID, + partitionID: partitionID, + vChannels: channels, + idAllocator: allocator, + req: &milvuspb.DeleteRequest{ + CollectionName: collectionName, + PartitionName: partitionName, + DbName: dbName, 
+ Expr: "pk in [1,2]", + }, + } + stream := msgstream.NewMockMsgStream(t) + mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil) + + assert.Error(t, dt.Execute(context.Background())) + }) + + t.Run("simple delete failed", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + mockMgr := NewMockChannelsMgr(t) + rc := mocks.NewMockRootCoordClient(t) + rc.EXPECT().AllocID(mock.Anything, mock.Anything).Return( + &rootcoordpb.AllocIDResponse{ + Status: merr.Success(), + ID: 0, + Count: 1, + }, nil) + allocator, err := allocator.NewIDAllocator(ctx, rc, paramtable.GetNodeID()) + allocator.Start() + assert.NoError(t, err) + + dt := deleteTask{ + chMgr: mockMgr, + schema: schema, + collectionID: collectionID, + partitionID: partitionID, + vChannels: channels, + idAllocator: allocator, + req: &milvuspb.DeleteRequest{ + CollectionName: collectionName, + PartitionName: partitionName, + DbName: dbName, + Expr: "pk in [1,2]", + }, + } + stream := msgstream.NewMockMsgStream(t) + mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil) + stream.EXPECT().Produce(mock.Anything).Return(errors.New("mock error")) + assert.Error(t, dt.Execute(context.Background())) + }) + + t.Run("complex delete query rpc failed", func(t *testing.T) { + mockMgr := NewMockChannelsMgr(t) + qn := mocks.NewMockQueryNodeClient(t) + lb := NewMockLBPolicy(t) + + dt := deleteTask{ + chMgr: mockMgr, + schema: schema, + collectionID: collectionID, + partitionID: partitionID, + vChannels: channels, + lb: lb, + result: &milvuspb.MutationResult{ + Status: merr.Success(), + IDs: &schemapb.IDs{ + IdField: nil, + }, + }, + req: &milvuspb.DeleteRequest{ + CollectionName: collectionName, + PartitionName: partitionName, + DbName: dbName, + Expr: "pk < 3", + }, + } + stream := msgstream.NewMockMsgStream(t) + mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil) + lb.EXPECT().Execute(mock.Anything, 
mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error { + return workload.exec(ctx, 1, qn) + }) + + qn.EXPECT().QueryStream(mock.Anything, mock.Anything).Return(nil, errors.New("mock error")) + assert.Error(t, dt.Execute(context.Background())) + assert.Equal(t, int64(0), dt.result.DeleteCnt) + }) + + t.Run("complex delete query failed", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + mockMgr := NewMockChannelsMgr(t) + rc := mocks.NewMockRootCoordClient(t) + qn := mocks.NewMockQueryNodeClient(t) + lb := NewMockLBPolicy(t) + rc.EXPECT().AllocID(mock.Anything, mock.Anything).Return( + &rootcoordpb.AllocIDResponse{ + Status: merr.Success(), + ID: 0, + Count: 1, + }, nil) + allocator, err := allocator.NewIDAllocator(ctx, rc, paramtable.GetNodeID()) + allocator.Start() + assert.NoError(t, err) + + dt := deleteTask{ + chMgr: mockMgr, + schema: schema, + collectionID: collectionID, + partitionID: partitionID, + vChannels: channels, + idAllocator: allocator, + lb: lb, + result: &milvuspb.MutationResult{ + Status: merr.Success(), + IDs: &schemapb.IDs{ + IdField: nil, + }, + }, + req: &milvuspb.DeleteRequest{ + CollectionName: collectionName, + PartitionName: partitionName, + DbName: dbName, + Expr: "pk < 3", + }, + } + stream := msgstream.NewMockMsgStream(t) + mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil) + lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error { + return workload.exec(ctx, 1, qn) + }) + + qn.EXPECT().QueryStream(mock.Anything, mock.Anything).Call.Return( + func(ctx context.Context, in *querypb.QueryRequest, opts ...grpc.CallOption) querypb.QueryNode_QueryStreamClient { + client := streamrpc.NewLocalQueryClient(ctx) + server := client.CreateServer() + + server.Send(&internalpb.RetrieveResults{ + Status: merr.Success(), + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: 
&schemapb.LongArray{ + Data: []int64{0, 1, 2}, + }, + }, + }, + }) + + server.Send(&internalpb.RetrieveResults{ + Status: merr.Status(errors.New("mock error")), + }) + return client + }, nil) + stream.EXPECT().Produce(mock.Anything).Return(nil) + + assert.Error(t, dt.Execute(context.Background())) + // query failed but still delete some data before failed. + assert.Equal(t, int64(3), dt.result.DeleteCnt) + }) + + t.Run("complex delete produce failed", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + mockMgr := NewMockChannelsMgr(t) + rc := mocks.NewMockRootCoordClient(t) + qn := mocks.NewMockQueryNodeClient(t) + lb := NewMockLBPolicy(t) + rc.EXPECT().AllocID(mock.Anything, mock.Anything).Return( + &rootcoordpb.AllocIDResponse{ + Status: merr.Success(), + ID: 0, + Count: 1, + }, nil) + allocator, err := allocator.NewIDAllocator(ctx, rc, paramtable.GetNodeID()) + allocator.Start() + assert.NoError(t, err) + + dt := deleteTask{ + chMgr: mockMgr, + schema: schema, + collectionID: collectionID, + partitionID: partitionID, + vChannels: channels, + idAllocator: allocator, + lb: lb, + result: &milvuspb.MutationResult{ + Status: merr.Success(), + IDs: &schemapb.IDs{ + IdField: nil, + }, + }, + req: &milvuspb.DeleteRequest{ + CollectionName: collectionName, + PartitionName: partitionName, + DbName: dbName, + Expr: "pk < 3", + }, + } + stream := msgstream.NewMockMsgStream(t) + mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil) + lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error { + return workload.exec(ctx, 1, qn) + }) + + qn.EXPECT().QueryStream(mock.Anything, mock.Anything).Call.Return( + func(ctx context.Context, in *querypb.QueryRequest, opts ...grpc.CallOption) querypb.QueryNode_QueryStreamClient { + client := streamrpc.NewLocalQueryClient(ctx) + server := client.CreateServer() + + server.Send(&internalpb.RetrieveResults{ + 
Status: merr.Success(), + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{0, 1, 2}, + }, + }, + }, + }) + server.FinishSend(nil) + return client + }, nil) + stream.EXPECT().Produce(mock.Anything).Return(errors.New("mock error")) + + assert.Error(t, dt.Execute(context.Background())) + assert.Equal(t, int64(0), dt.result.DeleteCnt) + }) + + t.Run("complex delete success", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + mockMgr := NewMockChannelsMgr(t) + rc := mocks.NewMockRootCoordClient(t) + qn := mocks.NewMockQueryNodeClient(t) + lb := NewMockLBPolicy(t) + rc.EXPECT().AllocID(mock.Anything, mock.Anything).Return( + &rootcoordpb.AllocIDResponse{ + Status: merr.Success(), + ID: 0, + Count: 1, + }, nil) + allocator, err := allocator.NewIDAllocator(ctx, rc, paramtable.GetNodeID()) + allocator.Start() + assert.NoError(t, err) + + dt := deleteTask{ + chMgr: mockMgr, + schema: schema, + collectionID: collectionID, + partitionID: partitionID, + vChannels: channels, + idAllocator: allocator, + lb: lb, + result: &milvuspb.MutationResult{ + Status: merr.Success(), + IDs: &schemapb.IDs{ + IdField: nil, + }, + }, + req: &milvuspb.DeleteRequest{ + CollectionName: collectionName, + PartitionName: partitionName, + DbName: dbName, + Expr: "pk < 3", + }, + } + stream := msgstream.NewMockMsgStream(t) + mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything).Return(stream, nil) + lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error { + return workload.exec(ctx, 1, qn) + }) + + qn.EXPECT().QueryStream(mock.Anything, mock.Anything).Call.Return( + func(ctx context.Context, in *querypb.QueryRequest, opts ...grpc.CallOption) querypb.QueryNode_QueryStreamClient { + client := streamrpc.NewLocalQueryClient(ctx) + server := client.CreateServer() + + server.Send(&internalpb.RetrieveResults{ + Status: merr.Success(), + Ids: 
&schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{0, 1, 2}, + }, + }, + }, + }) + server.FinishSend(nil) + return client + }, nil) + stream.EXPECT().Produce(mock.Anything).Return(nil) + + assert.NoError(t, dt.Execute(context.Background())) + assert.Equal(t, int64(3), dt.result.DeleteCnt) + }) +} + +func TestDeleteTask_SimpleDelete(t *testing.T) { + collectionName := "test_delete" + collectionID := int64(111) + partitionName := "default" + partitionID := int64(222) + dbName := "test_1" + + schema := &schemapb.CollectionSchema{ + Name: collectionName, + Description: "", + AutoID: false, + Fields: []*schemapb.FieldSchema{ + { + FieldID: common.StartOfUserFieldID, + Name: "pk", + IsPrimaryKey: true, + DataType: schemapb.DataType_Int64, + }, + { + FieldID: common.StartOfUserFieldID + 1, + Name: "non_pk", + IsPrimaryKey: false, + DataType: schemapb.DataType_Int64, + }, + }, + } + + task := deleteTask{ + schema: schema, + collectionID: collectionID, + partitionID: partitionID, + req: &milvuspb.DeleteRequest{ + CollectionName: collectionName, + PartitionName: partitionName, + DbName: dbName, + }, + } + t.Run("get PK failed", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + expr := &planpb.Expr_TermExpr{ + TermExpr: &planpb.TermExpr{ + ColumnInfo: &planpb.ColumnInfo{ + DataType: schemapb.DataType_BinaryVector, + }, + }, + } + stream := msgstream.NewMockMsgStream(t) + err := task.simpleDelete(ctx, expr, stream) + assert.Error(t, err) + }) +} diff --git a/internal/proxy/task_index.go b/internal/proxy/task_index.go index 3512c79c831d5..28ff1d7fa85d3 100644 --- a/internal/proxy/task_index.go +++ b/internal/proxy/task_index.go @@ -31,6 +31,7 @@ import ( "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/commonpbutil" 
"github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/indexparamcheck" @@ -55,10 +56,12 @@ type createIndexTask struct { Condition req *milvuspb.CreateIndexRequest ctx context.Context - rootCoord types.RootCoord - datacoord types.DataCoord + rootCoord types.RootCoordClient + datacoord types.DataCoordClient result *commonpb.Status + replicateMsgStream msgstream.MsgStream + isAutoIndex bool newIndexParams []*commonpb.KeyValuePair newTypeParams []*commonpb.KeyValuePair @@ -101,7 +104,9 @@ func (cit *createIndexTask) SetTs(ts Timestamp) { } func (cit *createIndexTask) OnEnqueue() error { - cit.req.Base = commonpbutil.NewMsgBase() + if cit.req.Base == nil { + cit.req.Base = commonpbutil.NewMsgBase() + } return nil } @@ -140,18 +145,25 @@ func (cit *createIndexTask) parseIndexParams() error { if !isVecIndex { specifyIndexType, exist := indexParamsMap[common.IndexTypeKey] if cit.fieldSchema.DataType == schemapb.DataType_VarChar { - if exist && specifyIndexType != DefaultStringIndexType { + if !exist { + indexParamsMap[common.IndexTypeKey] = DefaultStringIndexType + } + + if exist && !validateStringIndexType(specifyIndexType) { return merr.WrapErrParameterInvalid(DefaultStringIndexType, specifyIndexType, "index type not match") } - indexParamsMap[common.IndexTypeKey] = DefaultStringIndexType - } else { - if cit.fieldSchema.DataType == schemapb.DataType_JSON { - return merr.WrapErrParameterInvalid("not json field", "create index on json field", "create index on json field is not supported") + } else if typeutil.IsArithmetic(cit.fieldSchema.DataType) { + if !exist { + indexParamsMap[common.IndexTypeKey] = DefaultArithmeticIndexType } - if exist && specifyIndexType != DefaultIndexType { - return merr.WrapErrParameterInvalid(DefaultStringIndexType, specifyIndexType, "index type not match") + + if exist && !validateArithmeticIndexType(specifyIndexType) { + return merr.WrapErrParameterInvalid(DefaultArithmeticIndexType, specifyIndexType, "index 
type not match") } - indexParamsMap[common.IndexTypeKey] = DefaultIndexType + } else { + return merr.WrapErrParameterInvalid("supported field", + fmt.Sprintf("create index on %s field", cit.fieldSchema.DataType.String()), + "create index on json field is not supported") } } @@ -250,7 +262,7 @@ func (cit *createIndexTask) parseIndexParams() error { } for k, v := range indexParamsMap { - //Currently, it is required that type_params and index_params do not have same keys. + // Currently, it is required that type_params and index_params do not have same keys. if k == DimKey || k == common.MaxLengthKey { delete(indexParamsMap, k) continue @@ -291,6 +303,7 @@ func fillDimension(field *schemapb.FieldSchema, indexParams map[string]string) e vecDataTypes := []schemapb.DataType{ schemapb.DataType_FloatVector, schemapb.DataType_BinaryVector, + schemapb.DataType_Float16Vector, } if !funcutil.SliceContain(vecDataTypes, field.GetDataType()) { return nil @@ -319,6 +332,7 @@ func checkTrain(field *schemapb.FieldSchema, indexParams map[string]string) erro vecDataTypes := []schemapb.DataType{ schemapb.DataType_FloatVector, schemapb.DataType_BinaryVector, + schemapb.DataType_Float16Vector, } if !funcutil.SliceContain(vecDataTypes, field.GetDataType()) { return indexparamcheck.CheckIndexValid(field.GetDataType(), indexType, indexParams) @@ -405,7 +419,8 @@ func (cit *createIndexTask) Execute(ctx context.Context) error { if cit.result.ErrorCode != commonpb.ErrorCode_Success { return errors.New(cit.result.Reason) } - return err + SendReplicateMessagePack(ctx, cit.replicateMsgStream, cit.req) + return nil } func (cit *createIndexTask) PostExecute(ctx context.Context) error { @@ -416,7 +431,7 @@ type describeIndexTask struct { Condition *milvuspb.DescribeIndexRequest ctx context.Context - datacoord types.DataCoord + datacoord types.DataCoordClient result *milvuspb.DescribeIndexResponse collectionID UniqueID @@ -488,13 +503,19 @@ func (dit *describeIndexTask) Execute(ctx context.Context) 
error { } resp, err := dit.datacoord.DescribeIndex(ctx, &indexpb.DescribeIndexRequest{CollectionID: dit.collectionID, IndexName: dit.IndexName, Timestamp: dit.Timestamp}) - if err != nil || resp == nil { + if err != nil { return err } + dit.result = &milvuspb.DescribeIndexResponse{} dit.result.Status = resp.GetStatus() - if dit.result.Status.ErrorCode != commonpb.ErrorCode_Success { - return errors.New(dit.result.Status.Reason) + err = merr.Error(resp.GetStatus()) + if err != nil { + if errors.Is(err, merr.ErrIndexNotFound) && len(dit.GetIndexName()) == 0 { + err = merr.WrapErrIndexNotFoundForCollection(dit.GetCollectionName()) + dit.result.Status = merr.Status(err) + } + return err } for _, indexInfo := range resp.IndexInfos { field, err := schemaHelper.GetFieldFromID(indexInfo.FieldID) @@ -533,7 +554,7 @@ type getIndexStatisticsTask struct { Condition *milvuspb.GetIndexStatisticsRequest ctx context.Context - datacoord types.DataCoord + datacoord types.DataCoordClient result *milvuspb.GetIndexStatisticsResponse nodeID int64 @@ -606,14 +627,15 @@ func (dit *getIndexStatisticsTask) Execute(ctx context.Context) error { } resp, err := dit.datacoord.GetIndexStatistics(ctx, &indexpb.GetIndexStatisticsRequest{ - CollectionID: dit.collectionID, IndexName: dit.IndexName}) + CollectionID: dit.collectionID, IndexName: dit.IndexName, + }) if err != nil || resp == nil { return err } dit.result = &milvuspb.GetIndexStatisticsResponse{} dit.result.Status = resp.GetStatus() - if dit.result.Status.ErrorCode != commonpb.ErrorCode_Success { - return errors.New(dit.result.Status.Reason) + if dit.result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + return merr.Error(dit.result.GetStatus()) } for _, indexInfo := range resp.IndexInfos { field, err := schemaHelper.GetFieldFromID(indexInfo.FieldID) @@ -648,11 +670,13 @@ type dropIndexTask struct { Condition ctx context.Context *milvuspb.DropIndexRequest - dataCoord types.DataCoord - queryCoord types.QueryCoord + dataCoord 
types.DataCoordClient + queryCoord types.QueryCoordClient result *commonpb.Status collectionID UniqueID + + replicateMsgStream msgstream.MsgStream } func (dit *dropIndexTask) TraceCtx() context.Context { @@ -688,7 +712,9 @@ func (dit *dropIndexTask) SetTs(ts Timestamp) { } func (dit *dropIndexTask) OnEnqueue() error { - dit.Base = commonpbutil.NewMsgBase() + if dit.Base == nil { + dit.Base = commonpbutil.NewMsgBase() + } return nil } @@ -727,6 +753,13 @@ func (dit *dropIndexTask) PreExecute(ctx context.Context) error { } func (dit *dropIndexTask) Execute(ctx context.Context) error { + ctxLog := log.Ctx(ctx) + ctxLog.Info("proxy drop index", zap.Int64("collID", dit.collectionID), + zap.String("field_name", dit.FieldName), + zap.String("index_name", dit.IndexName), + zap.String("db_name", dit.DbName), + ) + var err error dit.result, err = dit.dataCoord.DropIndex(ctx, &indexpb.DropIndexRequest{ CollectionID: dit.collectionID, @@ -734,13 +767,18 @@ func (dit *dropIndexTask) Execute(ctx context.Context) error { IndexName: dit.IndexName, DropAll: false, }) + if err != nil { + ctxLog.Warn("drop index failed", zap.Error(err)) + return err + } if dit.result == nil { return errors.New("drop index resp is nil") } if dit.result.ErrorCode != commonpb.ErrorCode_Success { return errors.New(dit.result.Reason) } - return err + SendReplicateMessagePack(ctx, dit.replicateMsgStream, dit.DropIndexRequest) + return nil } func (dit *dropIndexTask) PostExecute(ctx context.Context) error { @@ -752,8 +790,8 @@ type getIndexBuildProgressTask struct { Condition *milvuspb.GetIndexBuildProgressRequest ctx context.Context - rootCoord types.RootCoord - dataCoord types.DataCoord + rootCoord types.RootCoordClient + dataCoord types.DataCoordClient result *milvuspb.GetIndexBuildProgressResponse collectionID UniqueID @@ -845,8 +883,8 @@ type getIndexStateTask struct { Condition *milvuspb.GetIndexStateRequest ctx context.Context - dataCoord types.DataCoord - rootCoord types.RootCoord + dataCoord 
types.DataCoordClient + rootCoord types.RootCoordClient result *milvuspb.GetIndexStateResponse collectionID UniqueID @@ -915,7 +953,7 @@ func (gist *getIndexStateTask) Execute(ctx context.Context) error { } gist.result = &milvuspb.GetIndexStateResponse{ - Status: merr.Status(nil), + Status: merr.Success(), State: state.GetState(), FailReason: state.GetFailReason(), } diff --git a/internal/proxy/task_index_test.go b/internal/proxy/task_index_test.go index 8d7aa736e00b3..38c8e507940db 100644 --- a/internal/proxy/task_index_test.go +++ b/internal/proxy/task_index_test.go @@ -25,6 +25,7 @@ import ( "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "google.golang.org/grpc" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" @@ -37,6 +38,7 @@ import ( "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) func TestMain(m *testing.M) { @@ -56,9 +58,7 @@ func TestGetIndexStateTask_Execute(t *testing.T) { rootCoord := newMockRootCoord() queryCoord := getMockQueryCoord() queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), CollectionIDs: []int64{}, }, nil) datacoord := NewDataCoordMock() @@ -86,22 +86,18 @@ func TestGetIndexStateTask_Execute(t *testing.T) { _ = InitMetaCache(ctx, rootCoord, queryCoord, shardMgr) assert.Error(t, gist.Execute(ctx)) - rootCoord.DescribeCollectionFunc = func(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { + rootCoord.DescribeCollectionFunc = func(ctx context.Context, request *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) 
{ return &milvuspb.DescribeCollectionResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), Schema: newTestSchema(), CollectionID: collectionID, CollectionName: request.CollectionName, }, nil } - datacoord.GetIndexStateFunc = func(ctx context.Context, request *indexpb.GetIndexStateRequest) (*indexpb.GetIndexStateResponse, error) { + datacoord.GetIndexStateFunc = func(ctx context.Context, request *indexpb.GetIndexStateRequest, opts ...grpc.CallOption) (*indexpb.GetIndexStateResponse, error) { return &indexpb.GetIndexStateResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), State: commonpb.IndexState_Finished, FailReason: "", }, nil @@ -120,9 +116,7 @@ func TestDropIndexTask_PreExecute(t *testing.T) { paramtable.Init() qc := getMockQueryCoord() qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), CollectionIDs: []int64{}, }, nil) dc := NewDataCoordMock() @@ -183,9 +177,7 @@ func TestDropIndexTask_PreExecute(t *testing.T) { t.Run("coll has been loaded", func(t *testing.T) { qc := getMockQueryCoord() qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), CollectionIDs: []int64{collectionID}, }, nil) dit.queryCoord = qc @@ -218,8 +210,8 @@ func TestDropIndexTask_PreExecute(t *testing.T) { }) } -func getMockQueryCoord() *mocks.MockQueryCoord { - qc := &mocks.MockQueryCoord{} +func getMockQueryCoord() *mocks.MockQueryCoordClient { + qc := &mocks.MockQueryCoordClient{} successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success} qc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(successStatus, nil) qc.EXPECT().GetShardLeaders(mock.Anything, 
mock.Anything).Return(&querypb.GetShardLeadersResponse{ @@ -327,7 +319,8 @@ func Test_parseIndexParams(t *testing.T) { Key: MetricTypeKey, Value: "L2", }, - }}, + }, + }, } t.Run("parse index params", func(t *testing.T) { @@ -411,7 +404,8 @@ func Test_parseIndexParams(t *testing.T) { Key: MetricTypeKey, Value: "L2", }, - }}, + }, + }, } t.Run("parse index params 2", func(t *testing.T) { Params.Save(Params.AutoIndexConfig.Enable.Key, "true") @@ -504,6 +498,154 @@ func Test_parseIndexParams(t *testing.T) { assert.Error(t, err) }) + t.Run("create index on VarChar field", func(t *testing.T) { + cit := &createIndexTask{ + req: &milvuspb.CreateIndexRequest{ + ExtraParams: []*commonpb.KeyValuePair{ + { + Key: common.IndexTypeKey, + Value: DefaultStringIndexType, + }, + }, + IndexName: "", + }, + fieldSchema: &schemapb.FieldSchema{ + FieldID: 101, + Name: "FieldID", + IsPrimaryKey: false, + DataType: schemapb.DataType_VarChar, + }, + } + err := cit.parseIndexParams() + assert.NoError(t, err) + }) + + t.Run("create index on Arithmetic field", func(t *testing.T) { + cit := &createIndexTask{ + req: &milvuspb.CreateIndexRequest{ + ExtraParams: []*commonpb.KeyValuePair{ + { + Key: common.IndexTypeKey, + Value: DefaultArithmeticIndexType, + }, + }, + IndexName: "", + }, + fieldSchema: &schemapb.FieldSchema{ + FieldID: 101, + Name: "FieldID", + IsPrimaryKey: false, + DataType: schemapb.DataType_Int64, + }, + } + err := cit.parseIndexParams() + assert.NoError(t, err) + }) + + // Compatible with the old version <= 2.3.0 + t.Run("create marisa-trie index on VarChar field", func(t *testing.T) { + cit := &createIndexTask{ + req: &milvuspb.CreateIndexRequest{ + ExtraParams: []*commonpb.KeyValuePair{ + { + Key: common.IndexTypeKey, + Value: "marisa-trie", + }, + }, + IndexName: "", + }, + fieldSchema: &schemapb.FieldSchema{ + FieldID: 101, + Name: "FieldID", + IsPrimaryKey: false, + DataType: schemapb.DataType_VarChar, + }, + } + err := cit.parseIndexParams() + assert.NoError(t, err) + 
}) + + // Compatible with the old version <= 2.3.0 + t.Run("create Asceneding index on Arithmetic field", func(t *testing.T) { + cit := &createIndexTask{ + req: &milvuspb.CreateIndexRequest{ + ExtraParams: []*commonpb.KeyValuePair{ + { + Key: common.IndexTypeKey, + Value: "Asceneding", + }, + }, + IndexName: "", + }, + fieldSchema: &schemapb.FieldSchema{ + FieldID: 101, + Name: "FieldID", + IsPrimaryKey: false, + DataType: schemapb.DataType_Int64, + }, + } + err := cit.parseIndexParams() + assert.NoError(t, err) + }) + + t.Run("create unsupported index on Arithmetic field", func(t *testing.T) { + cit := &createIndexTask{ + req: &milvuspb.CreateIndexRequest{ + ExtraParams: []*commonpb.KeyValuePair{ + { + Key: common.IndexTypeKey, + Value: "invalid_type", + }, + }, + IndexName: "", + }, + fieldSchema: &schemapb.FieldSchema{ + FieldID: 101, + Name: "FieldID", + IsPrimaryKey: false, + DataType: schemapb.DataType_Int64, + }, + } + err := cit.parseIndexParams() + assert.Error(t, err) + }) + + t.Run("create index on array field", func(t *testing.T) { + cit3 := &createIndexTask{ + Condition: nil, + req: &milvuspb.CreateIndexRequest{ + Base: nil, + DbName: "", + CollectionName: "", + FieldName: "", + ExtraParams: []*commonpb.KeyValuePair{ + { + Key: common.IndexTypeKey, + Value: "STL_SORT", + }, + }, + IndexName: "", + }, + ctx: nil, + rootCoord: nil, + result: nil, + isAutoIndex: false, + newIndexParams: nil, + newTypeParams: nil, + collectionID: 0, + fieldSchema: &schemapb.FieldSchema{ + FieldID: 101, + Name: "FieldID", + IsPrimaryKey: false, + Description: "field no.1", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int64, + }, + } + err := cit3.parseIndexParams() + assert.Error(t, err) + }) + t.Run("pass vector index type on scalar field", func(t *testing.T) { cit4 := &createIndexTask{ Condition: nil, @@ -713,3 +855,28 @@ func Test_parseIndexParams_AutoIndex(t *testing.T) { assert.Error(t, err) }) } + +func newTestSchema() 
*schemapb.CollectionSchema { + fields := []*schemapb.FieldSchema{ + {FieldID: 0, Name: "FieldID", IsPrimaryKey: false, Description: "field no.1", DataType: schemapb.DataType_Int64}, + } + + for name, value := range schemapb.DataType_value { + dataType := schemapb.DataType(value) + if !typeutil.IsIntegerType(dataType) && !typeutil.IsFloatingType(dataType) && !typeutil.IsVectorType(dataType) && !typeutil.IsStringType(dataType) { + continue + } + newField := &schemapb.FieldSchema{ + FieldID: int64(100 + value), Name: name + "Field", IsPrimaryKey: false, Description: "", DataType: dataType, + } + fields = append(fields, newField) + } + + return &schemapb.CollectionSchema{ + Name: "test", + Description: "schema for test used", + AutoID: true, + Fields: fields, + EnableDynamicField: true, + } +} diff --git a/internal/proxy/task_insert.go b/internal/proxy/task_insert.go index b3ce3ba6b920c..aa710e3d6575c 100644 --- a/internal/proxy/task_insert.go +++ b/internal/proxy/task_insert.go @@ -98,7 +98,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error { defer sp.End() it.result = &milvuspb.MutationResult{ - Status: merr.Status(nil), + Status: merr.Success(), IDs: &schemapb.IDs{ IdField: nil, }, @@ -198,7 +198,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error { } } - if err := newValidateUtil(withNANCheck(), withOverflowCheck(), withMaxLenCheck()). + if err := newValidateUtil(withNANCheck(), withOverflowCheck(), withMaxLenCheck(), withMaxCapCheck()). 
Validate(it.insertMsg.GetFieldsData(), schema, it.insertMsg.NRows()); err != nil { return err } @@ -233,8 +233,7 @@ func (it *insertTask) Execute(ctx context.Context) error { channelNames, err := it.chMgr.getVChannels(collID) if err != nil { log.Warn("get vChannels failed", zap.Int64("collectionID", collID), zap.Error(err)) - it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError - it.result.Status.Reason = err.Error() + it.result.Status = merr.Status(err) return err } @@ -255,8 +254,7 @@ func (it *insertTask) Execute(ctx context.Context) error { } if err != nil { log.Warn("assign segmentID and repack insert data failed", zap.Error(err)) - it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError - it.result.Status.Reason = err.Error() + it.result.Status = merr.Status(err) return err } assignSegmentIDDur := tr.RecordSpan() @@ -266,8 +264,7 @@ func (it *insertTask) Execute(ctx context.Context) error { err = stream.Produce(msgPack) if err != nil { log.Warn("fail to produce insert msg", zap.Error(err)) - it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError - it.result.Status.Reason = err.Error() + it.result.Status = merr.Status(err) return err } sendMsgDur := tr.RecordSpan() diff --git a/internal/proxy/task_insert_test.go b/internal/proxy/task_insert_test.go index a5e95a71b4d3a..ddc9390ea515d 100644 --- a/internal/proxy/task_insert_test.go +++ b/internal/proxy/task_insert_test.go @@ -2,7 +2,6 @@ package proxy import ( "context" - "fmt" "testing" "github.com/stretchr/testify/assert" @@ -236,10 +235,8 @@ func TestInsertTask(t *testing.T) { mock.AnythingOfType("string"), ).Return(collectionID, nil) globalMetaCache = cache - chMgr := newMockChannelsMgr() - chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) { - return channels, nil - } + chMgr := NewMockChannelsMgr(t) + chMgr.EXPECT().getChannels(mock.Anything).Return(channels, nil) it := insertTask{ ctx: context.Background(), insertMsg: &msgstream.InsertMsg{ @@ -254,12 +251,5 @@ 
func TestInsertTask(t *testing.T) { resChannels := it.getChannels() assert.ElementsMatch(t, channels, resChannels) assert.ElementsMatch(t, channels, it.pChannels) - - chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) { - return nil, fmt.Errorf("mock err") - } - // get channels again, should return task's pChannels, so getChannelsFunc should not invoke again - resChannels = it.getChannels() - assert.ElementsMatch(t, channels, resChannels) }) } diff --git a/internal/proxy/task_policies.go b/internal/proxy/task_policies.go index c910eb51c59fb..56c662e34b1b7 100644 --- a/internal/proxy/task_policies.go +++ b/internal/proxy/task_policies.go @@ -14,13 +14,11 @@ import ( // type pickShardPolicy func(ctx context.Context, mgr shardClientMgr, query func(UniqueID, types.QueryNode) error, leaders []nodeInfo) error -type queryFunc func(context.Context, UniqueID, types.QueryNode, ...string) error +type queryFunc func(context.Context, UniqueID, types.QueryNodeClient, ...string) error type pickShardPolicy func(context.Context, shardClientMgr, queryFunc, map[string][]nodeInfo) error -var ( - errInvalidShardLeaders = errors.New("Invalid shard leader") -) +var errInvalidShardLeaders = errors.New("Invalid shard leader") // RoundRobinPolicy do the query with multiple dml channels // if request failed, it finds shard leader for failed dml channels @@ -28,8 +26,8 @@ func RoundRobinPolicy( ctx context.Context, mgr shardClientMgr, query queryFunc, - dml2leaders map[string][]nodeInfo) error { - + dml2leaders map[string][]nodeInfo, +) error { queryChannel := func(ctx context.Context, channel string) error { var combineErr error leaders := dml2leaders[channel] diff --git a/internal/proxy/task_policies_test.go b/internal/proxy/task_policies_test.go index ec0ba5fa90c2c..5c4b1732824ac 100644 --- a/internal/proxy/task_policies_test.go +++ b/internal/proxy/task_policies_test.go @@ -16,9 +16,7 @@ import ( func TestRoundRobinPolicy(t *testing.T) { var err error - var ( - ctx = 
context.TODO() - ) + ctx := context.TODO() mgr := newShardClientMgr() @@ -59,7 +57,7 @@ type mockQuery struct { failset map[UniqueID]error } -func (m *mockQuery) query(_ context.Context, nodeID UniqueID, qn types.QueryNode, chs ...string) error { +func (m *mockQuery) query(_ context.Context, nodeID UniqueID, qn types.QueryNodeClient, chs ...string) error { m.mu.Lock() defer m.mu.Unlock() if err, ok := m.failset[nodeID]; ok { diff --git a/internal/proxy/task_query.go b/internal/proxy/task_query.go index e1c642e8a48c9..ce13ee9885cb3 100644 --- a/internal/proxy/task_query.go +++ b/internal/proxy/task_query.go @@ -24,6 +24,7 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/timerecord" "github.com/milvus-io/milvus/pkg/util/tsoutil" @@ -47,7 +48,7 @@ type queryTask struct { ctx context.Context result *milvuspb.QueryResults request *milvuspb.QueryRequest - qc types.QueryCoord + qc types.QueryCoordClient ids *schemapb.IDs collectionName string queryParams *queryParams @@ -63,8 +64,9 @@ type queryTask struct { } type queryParams struct { - limit int64 - offset int64 + limit int64 + offset int64 + reduceStopForBest bool } // translateToOutputFieldIDs translates output fields name to output fields id. 
@@ -72,7 +74,7 @@ func translateToOutputFieldIDs(outputFields []string, schema *schemapb.Collectio outputFieldIDs := make([]UniqueID, 0, len(outputFields)+1) if len(outputFields) == 0 { for _, field := range schema.Fields { - if field.FieldID >= common.StartOfUserFieldID && field.DataType != schemapb.DataType_FloatVector && field.DataType != schemapb.DataType_BinaryVector { + if field.FieldID >= common.StartOfUserFieldID && field.DataType != schemapb.DataType_FloatVector && field.DataType != schemapb.DataType_BinaryVector && field.DataType != schemapb.DataType_Float16Vector { outputFieldIDs = append(outputFieldIDs, field.FieldID) } } @@ -109,7 +111,6 @@ func translateToOutputFieldIDs(outputFields []string, schema *schemapb.Collectio if !pkFound { outputFieldIDs = append(outputFieldIDs, pkFieldID) } - } return outputFieldIDs, nil } @@ -127,15 +128,25 @@ func filterSystemFields(outputFieldIDs []UniqueID) []UniqueID { // parseQueryParams get limit and offset from queryParamsPair, both are optional. 
func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, error) { var ( - limit int64 - offset int64 - err error + limit int64 + offset int64 + reduceStopForBest bool + err error ) + reduceStopForBestStr, err := funcutil.GetAttrByKeyFromRepeatedKV(ReduceStopForBestKey, queryParamsPair) + // if reduce_stop_for_best is provided + if err == nil { + reduceStopForBest, err = strconv.ParseBool(reduceStopForBestStr) + if err != nil { + return nil, merr.WrapErrParameterInvalid("true or false", reduceStopForBestStr, + "value for reduce_stop_for_best is invalid") + } + } limitStr, err := funcutil.GetAttrByKeyFromRepeatedKV(LimitKey, queryParamsPair) // if limit is not provided if err != nil { - return &queryParams{limit: typeutil.Unlimited}, nil + return &queryParams{limit: typeutil.Unlimited, reduceStopForBest: reduceStopForBest}, nil } limit, err = strconv.ParseInt(limitStr, 0, 64) if err != nil { @@ -157,8 +168,9 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e } return &queryParams{ - limit: limit, - offset: offset, + limit: limit, + offset: offset, + reduceStopForBest: reduceStopForBest, }, nil } @@ -279,24 +291,12 @@ func (t *queryTask) PreExecute(ctx context.Context) error { } t.RetrieveRequest.IgnoreGrowing = ignoreGrowing - // fetch iteration_extension_reduce_rate from query param - var iterationExtensionReduceRate int64 - for i, kv := range t.request.GetQueryParams() { - if kv.GetKey() == IterationExtensionReduceRateKey { - iterationExtensionReduceRate, err = strconv.ParseInt(kv.Value, 0, 64) - if err != nil { - return errors.New("parse query iteration_extension_reduce_rate failed") - } - t.request.QueryParams = append(t.request.GetQueryParams()[:i], t.request.GetQueryParams()[i+1:]...) 
- break - } - } - t.RetrieveRequest.IterationExtensionReduceRate = iterationExtensionReduceRate - queryParams, err := parseQueryParams(t.request.GetQueryParams()) if err != nil { return err } + t.RetrieveRequest.ReduceStopForBest = queryParams.reduceStopForBest + t.queryParams = queryParams t.RetrieveRequest.Limit = queryParams.limit + queryParams.offset @@ -464,7 +464,7 @@ func (t *queryTask) PostExecute(ctx context.Context) error { return nil } -func (t *queryTask) queryShard(ctx context.Context, nodeID int64, qn types.QueryNode, channelIDs ...string) error { +func (t *queryTask) queryShard(ctx context.Context, nodeID int64, qn types.QueryNodeClient, channelIDs ...string) error { retrieveReq := typeutil.Clone(t.RetrieveRequest) retrieveReq.GetBase().TargetID = nodeID req := &querypb.QueryRequest{ @@ -517,7 +517,7 @@ func IDs2Expr(fieldName string, ids *schemapb.IDs) string { } func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.RetrieveResults, queryParams *queryParams) (*milvuspb.QueryResults, error) { - log.Ctx(ctx).Debug("reduceInternelRetrieveResults", zap.Int("len(retrieveResults)", len(retrieveResults))) + log.Ctx(ctx).Debug("reduceInternalRetrieveResults", zap.Int("len(retrieveResults)", len(retrieveResults))) var ( ret = &milvuspb.QueryResults{} @@ -543,12 +543,15 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re idSet := make(map[interface{}]struct{}) cursors := make([]int64, len(validRetrieveResults)) + realLimit := typeutil.Unlimited if queryParams != nil && queryParams.limit != typeutil.Unlimited { - loopEnd = int(queryParams.limit) - + realLimit = queryParams.limit + if !queryParams.reduceStopForBest { + loopEnd = int(queryParams.limit) + } if queryParams.offset > 0 { for i := int64(0); i < queryParams.offset; i++ { - sel := typeutil.SelectMinPK(validRetrieveResults, cursors) + sel := typeutil.SelectMinPK(validRetrieveResults, cursors, queryParams.reduceStopForBest, realLimit) if sel == -1 { 
return ret, nil } @@ -556,16 +559,22 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re } } } + reduceStopForBest := false + if queryParams != nil { + reduceStopForBest = queryParams.reduceStopForBest + } + var retSize int64 + maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() for j := 0; j < loopEnd; j++ { - sel := typeutil.SelectMinPK(validRetrieveResults, cursors) + sel := typeutil.SelectMinPK(validRetrieveResults, cursors, reduceStopForBest, realLimit) if sel == -1 { break } pk := typeutil.GetPK(validRetrieveResults[sel].GetIds(), cursors[sel]) if _, ok := idSet[pk]; !ok { - typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].GetFieldsData(), cursors[sel]) + retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].GetFieldsData(), cursors[sel]) idSet[pk] = struct{}{} } else { // primary keys duplicate @@ -573,8 +582,8 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re } // limit retrieve result to avoid oom - if int64(proto.Size(ret)) > paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() { - return nil, fmt.Errorf("query results exceed the maxOutputSize Limit %d", paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()) + if retSize > maxOutputSize { + return nil, fmt.Errorf("query results exceed the maxOutputSize Limit %d", maxOutputSize) } cursors[sel]++ diff --git a/internal/proxy/task_query_test.go b/internal/proxy/task_query_test.go index 9625e20d8f763..7cef1d2a4e112 100644 --- a/internal/proxy/task_query_test.go +++ b/internal/proxy/task_query_test.go @@ -38,6 +38,7 @@ import ( "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -48,8 +49,8 @@ func TestQueryTask_all(t *testing.T) { ctx = context.TODO() rc = 
NewRootCoordMock() - qc = mocks.NewMockQueryCoord(t) - qn = getQueryNode() + qc = mocks.NewMockQueryCoordClient(t) + qn = getQueryNodeClient() shardsNum = common.DefaultShardsNum collectionName = t.Name() + funcutil.GenRandomStr() @@ -58,11 +59,9 @@ func TestQueryTask_all(t *testing.T) { hitNum = 10 ) - qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe() + qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything, mock.Anything).Return(nil, nil).Maybe() successStatus := commonpb.Status{ErrorCode: commonpb.ErrorCode_Success} - qc.EXPECT().Start().Return(nil) - qc.EXPECT().Stop().Return(nil) qc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(&successStatus, nil) qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{ Status: &successStatus, @@ -80,11 +79,7 @@ func TestQueryTask_all(t *testing.T) { mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe() lb := NewLBPolicyImpl(mgr) - rc.Start() - defer rc.Stop() - qc.Start() - defer qc.Stop() - + defer rc.Close() err = InitMetaCache(ctx, rc, qc, mgr) assert.NoError(t, err) @@ -146,9 +141,7 @@ func TestQueryTask_all(t *testing.T) { }, ctx: ctx, result: &milvuspb.QueryResults{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), FieldsData: []*schemapb.FieldData{}, }, request: &milvuspb.QueryRequest{ @@ -163,10 +156,6 @@ func TestQueryTask_all(t *testing.T) { Key: IgnoreGrowingKey, Value: "false", }, - { - Key: IterationExtensionReduceRateKey, - Value: "10", - }, }, }, qc: qc, @@ -185,19 +174,17 @@ func TestQueryTask_all(t *testing.T) { // after preExecute assert.Greater(t, task.TimeoutTimestamp, typeutil.ZeroTimestamp) - // check iteration extension reduce - assert.Equal(t, int64(10), task.RetrieveRequest.IterationExtensionReduceRate) + // check reduce_stop_for_best + assert.Equal(t, false, task.RetrieveRequest.GetReduceStopForBest()) task.request.QueryParams = 
append(task.request.QueryParams, &commonpb.KeyValuePair{ - Key: IterationExtensionReduceRateKey, - Value: "10XXX", + Key: ReduceStopForBestKey, + Value: "trxxxx", }) assert.Error(t, task.PreExecute(ctx)) result1 := &internalpb.RetrieveResults{ - Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_RetrieveResult}, - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_RetrieveResult}, + Status: merr.Success(), Ids: &schemapb.IDs{ IdField: &schemapb.IDs_IntId{ IntId: &schemapb.LongArray{Data: generateInt64Array(hitNum)}, @@ -217,12 +204,12 @@ func TestQueryTask_all(t *testing.T) { task.RetrieveRequest.OutputFieldsId = append(task.RetrieveRequest.OutputFieldsId, common.TimeStampField) task.ctx = ctx qn.ExpectedCalls = nil - qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe() + qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() qn.EXPECT().Query(mock.Anything, mock.Anything).Return(nil, errors.New("mock error")) assert.Error(t, task.Execute(ctx)) qn.ExpectedCalls = nil - qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe() + qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() qn.EXPECT().Query(mock.Anything, mock.Anything).Return(&internalpb.RetrieveResults{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_NotShardLeader, @@ -232,7 +219,7 @@ func TestQueryTask_all(t *testing.T) { assert.True(t, strings.Contains(err.Error(), errInvalidShardLeaders.Error())) qn.ExpectedCalls = nil - qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe() + qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() qn.EXPECT().Query(mock.Anything, mock.Anything).Return(&internalpb.RetrieveResults{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -241,7 +228,7 @@ func TestQueryTask_all(t *testing.T) { assert.Error(t, task.Execute(ctx)) 
qn.ExpectedCalls = nil - qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe() + qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() qn.EXPECT().Query(mock.Anything, mock.Anything).Return(result1, nil) assert.NoError(t, task.Execute(ctx)) @@ -431,7 +418,6 @@ func TestTaskQuery_functions(t *testing.T) { Key: test.inKey[i], Value: test.inValue[i], }) - } ret, err := parseQueryParams(inParams) if test.expectErr { @@ -526,7 +512,8 @@ func TestTaskQuery_functions(t *testing.T) { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0, - 11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0} + 11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0, + } t.Run("test limited", func(t *testing.T) { tests := []struct { @@ -548,7 +535,6 @@ func TestTaskQuery_functions(t *testing.T) { assert.NoError(t, err) }) } - }) t.Run("test unLimited and maxOutputSize", func(t *testing.T) { @@ -599,6 +585,28 @@ func TestTaskQuery_functions(t *testing.T) { }) } }) + + t.Run("test stop reduce for best for limit", func(t *testing.T) { + result, err := reduceRetrieveResults(context.Background(), + []*internalpb.RetrieveResults{r1, r2}, + &queryParams{limit: 2, reduceStopForBest: true}) + assert.NoError(t, err) + assert.Equal(t, 2, len(result.GetFieldsData())) + assert.Equal(t, []int64{11, 11, 22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data) + len := len(result.GetFieldsData()[0].GetScalars().GetLongData().Data) + assert.InDeltaSlice(t, resultFloat[0:(len)*Dim], result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10) + }) + + t.Run("test stop reduce for best for unlimited set", func(t *testing.T) { + result, err := reduceRetrieveResults(context.Background(), + []*internalpb.RetrieveResults{r1, r2}, + &queryParams{limit: typeutil.Unlimited, reduceStopForBest: true}) + assert.NoError(t, err) + assert.Equal(t, 2, len(result.GetFieldsData())) + assert.Equal(t, 
[]int64{11, 11, 22, 22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data) + len := len(result.GetFieldsData()[0].GetScalars().GetLongData().Data) + assert.InDeltaSlice(t, resultFloat[0:(len)*Dim], result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10) + }) }) }) } @@ -857,7 +865,6 @@ func Test_queryTask_createPlan(t *testing.T) { }) t.Run("query without expression", func(t *testing.T) { - tsk := &queryTask{ request: &milvuspb.QueryRequest{ OutputFields: []string{"a"}, @@ -868,7 +875,6 @@ func Test_queryTask_createPlan(t *testing.T) { }) t.Run("invalid expression", func(t *testing.T) { - schema := &schemapb.CollectionSchema{ Fields: []*schemapb.FieldSchema{ { @@ -892,7 +898,6 @@ func Test_queryTask_createPlan(t *testing.T) { }) t.Run("invalid output fields", func(t *testing.T) { - schema := &schemapb.CollectionSchema{ Fields: []*schemapb.FieldSchema{ { diff --git a/internal/proxy/task_scheduler.go b/internal/proxy/task_scheduler.go index 0467dfd4e978b..f28a3a7e81e38 100644 --- a/internal/proxy/task_scheduler.go +++ b/internal/proxy/task_scheduler.go @@ -20,13 +20,16 @@ import ( "container/list" "context" "sync" + "time" - "github.com/cockroachdb/errors" "go.opentelemetry.io/otel" "go.uber.org/zap" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus/pkg/util/conc" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -83,7 +86,7 @@ func (queue *baseTaskQueue) addUnissuedTask(t task) error { defer queue.utLock.Unlock() if queue.utFull() { - return errors.New("task queue is full") + return merr.WrapErrServiceRequestLimitExceeded(int32(queue.getMaxTaskNum())) } queue.unissuedTasks.PushBack(t) queue.utBufChan <- 1 @@ -227,7 +230,7 @@ func (queue *dmTaskQueue) Enqueue(t task) error { // 1) Protect member pChanStatisticsInfos // 2) Serialize the timestamp allocation for dml tasks - //1. 
set the current pChannels for this dmTask + // 1. set the current pChannels for this dmTask dmt := t.(dmlTask) err := dmt.setChannels() if err != nil { @@ -235,19 +238,19 @@ func (queue *dmTaskQueue) Enqueue(t task) error { return err } - //2. enqueue dml task + // 2. enqueue dml task queue.statsLock.Lock() defer queue.statsLock.Unlock() err = queue.baseTaskQueue.Enqueue(t) if err != nil { return err } - //3. commit will use pChannels got previously when preAdding and will definitely succeed + // 3. commit will use pChannels got previously when preAdding and will definitely succeed pChannels := dmt.getChannels() queue.commitPChanStats(dmt, pChannels) - //there's indeed a possibility that the collection info cache was expired after preAddPChanStats - //but considering root coord knows everything about meta modification, invalid stats appended after the meta changed - //will be discarded by root coord and will not lead to inconsistent state + // there's indeed a possibility that the collection info cache was expired after preAddPChanStats + // but considering root coord knows everything about meta modification, invalid stats appended after the meta changed + // will be discarded by root coord and will not lead to inconsistent state return nil } @@ -269,7 +272,7 @@ func (queue *dmTaskQueue) PopActiveTask(taskID UniqueID) task { } func (queue *dmTaskQueue) commitPChanStats(dmt dmlTask, pChannels []pChan) { - //1. prepare new stat for all pChannels + // 1. prepare new stat for all pChannels newStats := make(map[pChan]pChanStatistics) beginTs := dmt.BeginTs() endTs := dmt.EndTs() @@ -279,7 +282,7 @@ func (queue *dmTaskQueue) commitPChanStats(dmt dmlTask, pChannels []pChan) { maxTs: endTs, } } - //2. update stats for all pChannels + // 2. 
update stats for all pChannels for cName, newStat := range newStats { currentStat, ok := queue.pChanStatisticsInfos[cName] if !ok { @@ -325,7 +328,6 @@ func (queue *dmTaskQueue) popPChanStats(t task) { } func (queue *dmTaskQueue) getPChanStatsInfo() (map[pChan]*pChanStatistics, error) { - ret := make(map[pChan]*pChanStatistics) queue.statsLock.RLock() defer queue.statsLock.RUnlock() @@ -373,6 +375,9 @@ type taskScheduler struct { dmQueue *dmTaskQueue dqQueue *dqTaskQueue + // data control queue, use for such as flush operation, which control the data status + dcQueue *ddTaskQueue + wg sync.WaitGroup ctx context.Context cancel context.CancelFunc @@ -397,6 +402,8 @@ func newTaskScheduler(ctx context.Context, s.dmQueue = newDmTaskQueue(tsoAllocatorIns) s.dqQueue = newDqTaskQueue(tsoAllocatorIns) + s.dcQueue = newDdTaskQueue(tsoAllocatorIns) + for _, opt := range opts { opt(s) } @@ -408,6 +415,10 @@ func (sched *taskScheduler) scheduleDdTask() task { return sched.ddQueue.PopUnissuedTask() } +func (sched *taskScheduler) scheduleDcTask() task { + return sched.dcQueue.PopUnissuedTask() +} + func (sched *taskScheduler) scheduleDmTask() task { return sched.dmQueue.PopUnissuedTask() } @@ -416,19 +427,6 @@ func (sched *taskScheduler) scheduleDqTask() task { return sched.dqQueue.PopUnissuedTask() } -func (sched *taskScheduler) getTaskByReqID(reqID UniqueID) task { - if t := sched.ddQueue.getTaskByReqID(reqID); t != nil { - return t - } - if t := sched.dmQueue.getTaskByReqID(reqID); t != nil { - return t - } - if t := sched.dqQueue.getTaskByReqID(reqID); t != nil { - return t - } - return nil -} - func (sched *taskScheduler) processTask(t task, q taskQueue) { ctx, span := otel.Tracer(typeutil.ProxyRole).Start(t.TraceCtx(), t.Name()) defer span.End() @@ -487,8 +485,25 @@ func (sched *taskScheduler) definitionLoop() { } } +// controlLoop schedule the data control operation, such as flush +func (sched *taskScheduler) controlLoop() { + defer sched.wg.Done() + for { + select { + 
case <-sched.ctx.Done(): + return + case <-sched.dcQueue.utChan(): + if !sched.dcQueue.utEmpty() { + t := sched.scheduleDcTask() + sched.processTask(t, sched.dcQueue) + } + } + } +} + func (sched *taskScheduler) manipulationLoop() { defer sched.wg.Done() + for { select { case <-sched.ctx.Done(): @@ -505,6 +520,7 @@ func (sched *taskScheduler) manipulationLoop() { func (sched *taskScheduler) queryLoop() { defer sched.wg.Done() + pool := conc.NewPool[struct{}](paramtable.Get().ProxyCfg.MaxTaskNum.GetAsInt(), conc.WithExpiryDuration(time.Minute)) for { select { case <-sched.ctx.Done(): @@ -512,7 +528,10 @@ func (sched *taskScheduler) queryLoop() { case <-sched.dqQueue.utChan(): if !sched.dqQueue.utEmpty() { t := sched.scheduleDqTask() - go sched.processTask(t, sched.dqQueue) + pool.Submit(func() (struct{}, error) { + sched.processTask(t, sched.dqQueue) + return struct{}{}, nil + }) } else { log.Debug("query queue is empty ...") } @@ -524,6 +543,9 @@ func (sched *taskScheduler) Start() error { sched.wg.Add(1) go sched.definitionLoop() + sched.wg.Add(1) + go sched.controlLoop() + sched.wg.Add(1) go sched.manipulationLoop() diff --git a/internal/proxy/task_scheduler_test.go b/internal/proxy/task_scheduler_test.go index 4bba7033dfdcf..2a04ea31994c1 100644 --- a/internal/proxy/task_scheduler_test.go +++ b/internal/proxy/task_scheduler_test.go @@ -34,7 +34,6 @@ import ( ) func TestBaseTaskQueue(t *testing.T) { - var err error var unissuedTask task var activeTask task @@ -111,7 +110,6 @@ func TestBaseTaskQueue(t *testing.T) { } func TestDdTaskQueue(t *testing.T) { - var err error var unissuedTask task var activeTask task @@ -189,7 +187,6 @@ func TestDdTaskQueue(t *testing.T) { // test the logic of queue func TestDmTaskQueue_Basic(t *testing.T) { - var err error var unissuedTask task var activeTask task @@ -266,7 +263,6 @@ func TestDmTaskQueue_Basic(t *testing.T) { // test the timestamp statistics func TestDmTaskQueue_TimestampStatistics(t *testing.T) { - var err error var 
unissuedTask task @@ -394,7 +390,7 @@ func TestDmTaskQueue_TimestampStatistics2(t *testing.T) { }() } wg.Wait() - //time.Sleep(time.Millisecond*100) + // time.Sleep(time.Millisecond*100) needLoop := true for needLoop { processCountMut.RLock() @@ -413,7 +409,6 @@ func TestDmTaskQueue_TimestampStatistics2(t *testing.T) { } func TestDqTaskQueue(t *testing.T) { - var err error var unissuedTask task var activeTask task @@ -490,7 +485,6 @@ func TestDqTaskQueue(t *testing.T) { } func TestTaskScheduler(t *testing.T) { - var err error ctx := context.Background() @@ -581,10 +575,8 @@ func TestTaskScheduler_concurrentPushAndPop(t *testing.T) { run := func(wg *sync.WaitGroup) { defer wg.Done() - chMgr := newMockChannelsMgr() - chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) { - return channels, nil - } + chMgr := NewMockChannelsMgr(t) + chMgr.EXPECT().getChannels(mock.Anything).Return(channels, nil) it := &insertTask{ ctx: context.Background(), insertMsg: &msgstream.InsertMsg{ @@ -599,9 +591,7 @@ func TestTaskScheduler_concurrentPushAndPop(t *testing.T) { assert.NoError(t, err) task := scheduler.scheduleDmTask() scheduler.dmQueue.AddActiveTask(task) - chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) { - return nil, fmt.Errorf("mock err") - } + chMgr.EXPECT().getChannels(mock.Anything).Return(nil, fmt.Errorf("mock err")) scheduler.dmQueue.PopActiveTask(task.ID()) // assert no panic } diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index 1265f0fd81b76..b5f7493914f59 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -63,7 +63,7 @@ type searchTask struct { offset int64 resultBuf *typeutil.ConcurrentSet[*internalpb.SearchResults] - qc types.QueryCoord + qc types.QueryCoordClient node types.ProxyComponent lb LBPolicy } @@ -507,7 +507,7 @@ func (t *searchTask) PostExecute(ctx context.Context) error { return nil } -func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn 
types.QueryNode, channelIDs ...string) error { +func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.QueryNodeClient, channelIDs ...string) error { searchReq := typeutil.Clone(t.SearchRequest) searchReq.GetBase().TargetID = nodeID req := &querypb.SearchRequest{ @@ -639,10 +639,7 @@ func (t *searchTask) Requery() error { func (t *searchTask) fillInEmptyResult(numQueries int64) { t.result = &milvuspb.SearchResults{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "search result is empty", - }, + Status: merr.Success("search result is empty"), CollectionName: t.collectionName, Results: &schemapb.SearchResultData{ NumQueries: numQueries, @@ -770,7 +767,7 @@ func reduceSearchResultData(ctx context.Context, subSearchResultData []*schemapb zap.String("metricType", metricType)) ret := &milvuspb.SearchResults{ - Status: merr.Status(nil), + Status: merr.Success(), Results: &schemapb.SearchResultData{ NumQueries: nq, TopK: topk, @@ -799,10 +796,12 @@ func reduceSearchResultData(ctx context.Context, subSearchResultData []*schemapb } for i, sData := range subSearchResultData { + pkLength := typeutil.GetSizeOfIDs(sData.GetIds()) log.Ctx(ctx).Debug("subSearchResultData", zap.Int("result No.", i), zap.Int64("nq", sData.NumQueries), zap.Int64("topk", sData.TopK), + zap.Int("length of pks", pkLength), zap.Any("length of FieldsData", len(sData.FieldsData))) if err := checkSearchResultData(sData, nq, topk); err != nil { log.Ctx(ctx).Warn("invalid search results", zap.Error(err)) @@ -828,9 +827,10 @@ func reduceSearchResultData(ctx context.Context, subSearchResultData []*schemapb realTopK int64 = -1 ) + var retSize int64 + maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() // reducing nq * topk results for i := int64(0); i < nq; i++ { - var ( // cursor of current data of each subSearch for merging the j-th data of TopK. 
// sum(cursors) == j @@ -865,7 +865,7 @@ func reduceSearchResultData(ctx context.Context, subSearchResultData []*schemapb // remove duplicates if _, ok := idSet[id]; !ok { - typeutil.AppendFieldData(ret.Results.FieldsData, subSearchResultData[subSearchIdx].FieldsData, resultDataIdx) + retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subSearchResultData[subSearchIdx].FieldsData, resultDataIdx) typeutil.AppendPKs(ret.Results.Ids, id) ret.Results.Scores = append(ret.Results.Scores, score) idSet[id] = struct{}{} @@ -884,8 +884,8 @@ func reduceSearchResultData(ctx context.Context, subSearchResultData []*schemapb ret.Results.Topks = append(ret.Results.Topks, realTopK) // limit search result to avoid oom - if int64(proto.Size(ret)) > paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() { - return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()) + if retSize > maxOutputSize { + return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize) } } log.Ctx(ctx).Debug("skip duplicated search result", zap.Int64("count", skipDupCnt)) diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 226225d998a87..d4edc62269f06 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -38,6 +38,7 @@ import ( "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metric" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/timerecord" @@ -69,11 +70,11 @@ func TestSearchTask_PostExecute(t *testing.T) { err := qt.PostExecute(context.TODO()) assert.NoError(t, err) - assert.Equal(t, qt.result.Status.ErrorCode, commonpb.ErrorCode_Success) + assert.Equal(t, qt.result.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success) 
}) } -func createColl(t *testing.T, name string, rc types.RootCoord) { +func createColl(t *testing.T, name string, rc types.RootCoordClient) { schema := constructCollectionSchema(testInt64Field, testFloatVecField, testVecDim, name) marshaledSchema, err := proto.Marshal(schema) require.NoError(t, err) @@ -134,7 +135,8 @@ func getValidSearchParams() []*commonpb.KeyValuePair { { Key: IgnoreGrowingKey, Value: "false", - }} + }, + } } func getInvalidSearchParams(invalidName string) []*commonpb.KeyValuePair { @@ -152,12 +154,11 @@ func TestSearchTask_PreExecute(t *testing.T) { var ( rc = NewRootCoordMock() - qc = mocks.NewMockQueryCoord(t) + qc = mocks.NewMockQueryCoordClient(t) ctx = context.TODO() ) - err = rc.Start() - defer rc.Stop() + defer rc.Close() require.NoError(t, err) mgr := newShardClientMgr() err = InitMetaCache(ctx, rc, qc, mgr) @@ -265,34 +266,41 @@ func getQueryCoord() *mocks.MockQueryCoord { return qc } +func getQueryCoordClient() *mocks.MockQueryCoordClient { + qc := &mocks.MockQueryCoordClient{} + qc.EXPECT().Close().Return(nil) + return qc +} + func getQueryNode() *mocks.MockQueryNode { qn := &mocks.MockQueryNode{} return qn } -func TestSearchTaskV2_Execute(t *testing.T) { +func getQueryNodeClient() *mocks.MockQueryNodeClient { + qn := &mocks.MockQueryNodeClient{} + return qn +} + +func TestSearchTaskV2_Execute(t *testing.T) { var ( err error rc = NewRootCoordMock() - qc = getQueryCoord() + qc = getQueryCoordClient() ctx = context.TODO() collectionName = t.Name() + funcutil.GenRandomStr() ) - err = rc.Start() - require.NoError(t, err) - defer rc.Stop() + defer rc.Close() mgr := newShardClientMgr() err = InitMetaCache(ctx, rc, qc, mgr) require.NoError(t, err) - err = qc.Start() - require.NoError(t, err) - defer qc.Stop() + defer qc.Close() task := &searchTask{ ctx: ctx, @@ -1119,62 +1127,79 @@ func Test_checkSearchResultData(t *testing.T) { args args }{ - {"data.NumQueries != nq", true, + { + "data.NumQueries != nq", true, args{ data: 
&schemapb.SearchResultData{NumQueries: 100}, nq: 10, - }}, - {"data.TopK != topk", true, + }, + }, + { + "data.TopK != topk", true, args{ data: &schemapb.SearchResultData{NumQueries: 1, TopK: 1}, nq: 1, topk: 10, - }}, - {"size of IntId != NumQueries * TopK", true, + }, + }, + { + "size of IntId != NumQueries * TopK", true, args{ data: &schemapb.SearchResultData{ NumQueries: 1, TopK: 1, Ids: &schemapb.IDs{ - IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{1, 2}}}}, + IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{1, 2}}}, + }, }, nq: 1, topk: 1, - }}, - {"size of StrID != NumQueries * TopK", true, + }, + }, + { + "size of StrID != NumQueries * TopK", true, args{ data: &schemapb.SearchResultData{ NumQueries: 1, TopK: 1, Ids: &schemapb.IDs{ - IdField: &schemapb.IDs_StrId{StrId: &schemapb.StringArray{Data: []string{"1", "2"}}}}, + IdField: &schemapb.IDs_StrId{StrId: &schemapb.StringArray{Data: []string{"1", "2"}}}, + }, }, nq: 1, topk: 1, - }}, - {"size of score != nq * topK", true, + }, + }, + { + "size of score != nq * topK", true, args{ data: &schemapb.SearchResultData{ NumQueries: 1, TopK: 1, Ids: &schemapb.IDs{ - IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{1}}}}, + IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{1}}}, + }, Scores: []float32{0.99, 0.98}, }, nq: 1, topk: 1, - }}, - {"correct params", false, + }, + }, + { + "correct params", false, args{ data: &schemapb.SearchResultData{ NumQueries: 1, TopK: 1, Ids: &schemapb.IDs{ - IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{1}}}}, - Scores: []float32{0.99}}, + IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{1}}}, + }, + Scores: []float32{0.99}, + }, nq: 1, topk: 1, - }}, + }, + }, } for _, test := range tests { @@ -1411,21 +1436,31 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) { outScore []float32 outData []int64 }{ - {"offset 0, limit 5", 0, 5, + { + "offset 0, 
limit 5", 0, 5, []float32{-50, -49, -48, -47, -46, -45, -44, -43, -42, -41}, - []int64{50, 49, 48, 47, 46, 45, 44, 43, 42, 41}}, - {"offset 1, limit 4", 1, 4, + []int64{50, 49, 48, 47, 46, 45, 44, 43, 42, 41}, + }, + { + "offset 1, limit 4", 1, 4, []float32{-49, -48, -47, -46, -44, -43, -42, -41}, - []int64{49, 48, 47, 46, 44, 43, 42, 41}}, - {"offset 2, limit 3", 2, 3, + []int64{49, 48, 47, 46, 44, 43, 42, 41}, + }, + { + "offset 2, limit 3", 2, 3, []float32{-48, -47, -46, -43, -42, -41}, - []int64{48, 47, 46, 43, 42, 41}}, - {"offset 3, limit 2", 3, 2, + []int64{48, 47, 46, 43, 42, 41}, + }, + { + "offset 3, limit 2", 3, 2, []float32{-47, -46, -42, -41}, - []int64{47, 46, 42, 41}}, - {"offset 4, limit 1", 4, 1, + []int64{47, 46, 42, 41}, + }, + { + "offset 4, limit 1", 4, 1, []float32{-46, -41}, - []int64{46, 41}}, + []int64{46, 41}, + }, } var results []*schemapb.SearchResultData @@ -1459,24 +1494,36 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) { outScore []float32 outData []int64 }{ - {"offset 0, limit 6", 0, 6, 5, + { + "offset 0, limit 6", 0, 6, 5, []float32{-50, -49, -48, -47, -46, -45, -44, -43, -42, -41}, - []int64{50, 49, 48, 47, 46, 45, 44, 43, 42, 41}}, - {"offset 1, limit 5", 1, 5, 4, + []int64{50, 49, 48, 47, 46, 45, 44, 43, 42, 41}, + }, + { + "offset 1, limit 5", 1, 5, 4, []float32{-49, -48, -47, -46, -44, -43, -42, -41}, - []int64{49, 48, 47, 46, 44, 43, 42, 41}}, - {"offset 2, limit 4", 2, 4, 3, + []int64{49, 48, 47, 46, 44, 43, 42, 41}, + }, + { + "offset 2, limit 4", 2, 4, 3, []float32{-48, -47, -46, -43, -42, -41}, - []int64{48, 47, 46, 43, 42, 41}}, - {"offset 3, limit 3", 3, 3, 2, + []int64{48, 47, 46, 43, 42, 41}, + }, + { + "offset 3, limit 3", 3, 3, 2, []float32{-47, -46, -42, -41}, - []int64{47, 46, 42, 41}}, - {"offset 4, limit 2", 4, 2, 1, + []int64{47, 46, 42, 41}, + }, + { + "offset 4, limit 2", 4, 2, 1, []float32{-46, -41}, - []int64{46, 41}}, - {"offset 5, limit 1", 5, 1, 0, + []int64{46, 41}, + }, + { + "offset 5, 
limit 1", 5, 1, 0, []float32{}, - []int64{}}, + []int64{}, + }, } for _, test := range lessThanLimitTests { @@ -1543,30 +1590,26 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) { } func TestSearchTask_ErrExecute(t *testing.T) { - var ( err error ctx = context.TODO() rc = NewRootCoordMock() - qc = getQueryCoord() - qn = getQueryNode() + qc = getQueryCoordClient() + qn = getQueryNodeClient() shardsNum = int32(2) collectionName = t.Name() + funcutil.GenRandomStr() ) - qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe() + qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() mgr := NewMockShardClientManager(t) mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe() mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe() lb := NewLBPolicyImpl(mgr) - rc.Start() - defer rc.Stop() - qc.Start() - defer qc.Stop() + defer qc.Close() err = InitMetaCache(ctx, rc, qc, mgr) assert.NoError(t, err) @@ -1646,9 +1689,7 @@ func TestSearchTask_ErrExecute(t *testing.T) { }, ctx: ctx, result: &milvuspb.SearchResults{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, request: &milvuspb.SearchRequest{ Base: &commonpb.MsgBase{ @@ -1674,7 +1715,7 @@ func TestSearchTask_ErrExecute(t *testing.T) { assert.Error(t, task.Execute(ctx)) qn.ExpectedCalls = nil - qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe() + qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() qn.EXPECT().Search(mock.Anything, mock.Anything).Return(&internalpb.SearchResults{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_NotShardLeader, @@ -1684,7 +1725,7 @@ func TestSearchTask_ErrExecute(t *testing.T) { assert.True(t, strings.Contains(err.Error(), errInvalidShardLeaders.Error())) qn.ExpectedCalls = nil - qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe() + 
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() qn.EXPECT().Search(mock.Anything, mock.Anything).Return(&internalpb.SearchResults{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -1693,11 +1734,9 @@ func TestSearchTask_ErrExecute(t *testing.T) { assert.Error(t, task.Execute(ctx)) qn.ExpectedCalls = nil - qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe() + qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() qn.EXPECT().Search(mock.Anything, mock.Anything).Return(&internalpb.SearchResults{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, nil) assert.NoError(t, task.Execute(ctx)) } @@ -1751,7 +1790,8 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) { t.Run("parseSearchInfo error", func(t *testing.T) { spNoTopk := []*commonpb.KeyValuePair{{ Key: AnnsFieldKey, - Value: testFloatVecField}} + Value: testFloatVecField, + }} spInvalidTopk := append(spNoTopk, &commonpb.KeyValuePair{ Key: TopKKey, @@ -1871,19 +1911,20 @@ func TestSearchTask_Requery(t *testing.T) { node := mocks.NewMockProxy(t) node.EXPECT().Query(mock.Anything, mock.Anything). Return(&milvuspb.QueryResults{ - FieldsData: []*schemapb.FieldData{{ - Type: schemapb.DataType_Int64, - FieldName: pkField, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_LongData{ - LongData: &schemapb.LongArray{ - Data: ids, + FieldsData: []*schemapb.FieldData{ + { + Type: schemapb.DataType_Int64, + FieldName: pkField, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: ids, + }, }, }, }, }, - }, newFloatVectorFieldData(vecField, rows, dim), }, }, nil) @@ -2034,19 +2075,20 @@ func TestSearchTask_Requery(t *testing.T) { node := mocks.NewMockProxy(t) node.EXPECT().Query(mock.Anything, mock.Anything). 
Return(&milvuspb.QueryResults{ - FieldsData: []*schemapb.FieldData{{ - Type: schemapb.DataType_Int64, - FieldName: pkField, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_LongData{ - LongData: &schemapb.LongArray{ - Data: ids[:len(ids)-1], + FieldsData: []*schemapb.FieldData{ + { + Type: schemapb.DataType_Int64, + FieldName: pkField, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: ids[:len(ids)-1], + }, }, }, }, }, - }, newFloatVectorFieldData(vecField, rows, dim), }, }, nil) diff --git a/internal/proxy/task_statistic.go b/internal/proxy/task_statistic.go index efe7a42a5a28e..9a676f692e4ba 100644 --- a/internal/proxy/task_statistic.go +++ b/internal/proxy/task_statistic.go @@ -43,7 +43,7 @@ type getStatisticsTask struct { unloadedPartitionIDs []UniqueID ctx context.Context - dc types.DataCoord + dc types.DataCoordClient tr *timerecord.TimeRecorder fromDataCoord bool @@ -51,7 +51,7 @@ type getStatisticsTask struct { // if query from shard *internalpb.GetStatisticsRequest - qc types.QueryCoord + qc types.QueryCoordClient resultBuf *typeutil.ConcurrentSet[*internalpb.GetStatisticsResponse] lb LBPolicy @@ -216,7 +216,7 @@ func (g *getStatisticsTask) PostExecute(ctx context.Context) error { return err } g.result = &milvuspb.GetStatisticsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Stats: result, } @@ -241,14 +241,14 @@ func (g *getStatisticsTask) getStatisticsFromDataCoord(ctx context.Context) erro if err != nil { return err } - if result.Status.ErrorCode != commonpb.ErrorCode_Success { - return errors.New(result.Status.Reason) + if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + return merr.Error(result.GetStatus()) } if g.resultBuf == nil { g.resultBuf = typeutil.NewConcurrentSet[*internalpb.GetStatisticsResponse]() } g.resultBuf.Insert(&internalpb.GetStatisticsResponse{ - 
Status: merr.Status(nil), + Status: merr.Success(), Stats: result.Stats, }) return nil @@ -266,7 +266,6 @@ func (g *getStatisticsTask) getStatisticsFromQueryNode(ctx context.Context) erro nq: 1, exec: g.getStatisticsShard, }) - if err != nil { return errors.Wrap(err, "failed to statistic") } @@ -274,7 +273,7 @@ func (g *getStatisticsTask) getStatisticsFromQueryNode(ctx context.Context) erro return nil } -func (g *getStatisticsTask) getStatisticsShard(ctx context.Context, nodeID int64, qn types.QueryNode, channelIDs ...string) error { +func (g *getStatisticsTask) getStatisticsShard(ctx context.Context, nodeID int64, qn types.QueryNodeClient, channelIDs ...string) error { nodeReq := proto.Clone(g.GetStatisticsRequest).(*internalpb.GetStatisticsRequest) nodeReq.Base.TargetID = nodeID req := &querypb.GetStatisticsRequest{ @@ -312,7 +311,7 @@ func (g *getStatisticsTask) getStatisticsShard(ctx context.Context, nodeID int64 // checkFullLoaded check if collection / partition was fully loaded into QueryNode // return loaded partitions, unloaded partitions and error -func checkFullLoaded(ctx context.Context, qc types.QueryCoord, dbName string, collectionName string, collectionID int64, searchPartitionIDs []UniqueID) ([]UniqueID, []UniqueID, error) { +func checkFullLoaded(ctx context.Context, qc types.QueryCoordClient, dbName string, collectionName string, collectionID int64, searchPartitionIDs []UniqueID) ([]UniqueID, []UniqueID, error) { var loadedPartitionIDs []UniqueID var unloadPartitionIDs []UniqueID @@ -335,7 +334,7 @@ func checkFullLoaded(ctx context.Context, qc types.QueryCoord, dbName string, co if err != nil { return nil, nil, fmt.Errorf("showPartitions failed, collection = %d, partitionIDs = %v, err = %s", collectionID, searchPartitionIDs, err) } - if resp.Status.ErrorCode != commonpb.ErrorCode_Success { + if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { return nil, nil, fmt.Errorf("showPartitions failed, collection = %d, partitionIDs = %v, 
reason = %s", collectionID, searchPartitionIDs, resp.GetStatus().GetReason()) } @@ -360,7 +359,7 @@ func checkFullLoaded(ctx context.Context, qc types.QueryCoord, dbName string, co if err != nil { return nil, nil, fmt.Errorf("showPartitions failed, collection = %d, partitionIDs = %v, err = %s", collectionID, searchPartitionIDs, err) } - if resp.Status.ErrorCode != commonpb.ErrorCode_Success { + if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { return nil, nil, fmt.Errorf("showPartitions failed, collection = %d, partitionIDs = %v, reason = %s", collectionID, searchPartitionIDs, resp.GetStatus().GetReason()) } @@ -462,11 +461,11 @@ func reduceStatisticResponse(results []map[string]string) ([]*commonpb.KeyValueP // if err != nil { // return err // } -// if result.Status.ErrorCode != commonpb.ErrorCode_Success { -// return errors.New(result.Status.Reason) +// if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { +// return merr.Error(result.GetStatus()) // } // g.toReduceResults = append(g.toReduceResults, &internalpb.GetStatisticsResponse{ -// Status: merr.Status(nil), +// Status: merr.Success(), // Stats: result.Stats, // }) // log.Debug("get partition statistics from DataCoord execute done", zap.Int64("msgID", g.ID())) @@ -481,7 +480,7 @@ func reduceStatisticResponse(results []map[string]string) ([]*commonpb.KeyValueP // return err // } // g.result = &milvuspb.GetPartitionStatisticsResponse{ -// Status: merr.Status(nil), +// Status: merr.Success(), // Stats: g.innerResult, // } // return nil @@ -534,11 +533,11 @@ func reduceStatisticResponse(results []map[string]string) ([]*commonpb.KeyValueP // if err != nil { // return err // } -// if result.Status.ErrorCode != commonpb.ErrorCode_Success { -// return errors.New(result.Status.Reason) +// if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { +// return merr.Error(result.GetStatus()) // } // g.toReduceResults = append(g.toReduceResults, &internalpb.GetStatisticsResponse{ 
-// Status: merr.Status(nil), +// Status: merr.Success(), // Stats: result.Stats, // }) // } else { // some partitions have been loaded, get some partition statistics from datacoord @@ -557,11 +556,11 @@ func reduceStatisticResponse(results []map[string]string) ([]*commonpb.KeyValueP // if err != nil { // return err // } -// if result.Status.ErrorCode != commonpb.ErrorCode_Success { -// return errors.New(result.Status.Reason) +// if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { +// return merr.Error(result.GetStatus()) // } // g.toReduceResults = append(g.toReduceResults, &internalpb.GetStatisticsResponse{ -// Status: merr.Status(nil), +// Status: merr.Success(), // Stats: result.Stats, // }) // } @@ -577,7 +576,7 @@ func reduceStatisticResponse(results []map[string]string) ([]*commonpb.KeyValueP // return err // } // g.result = &milvuspb.GetCollectionStatisticsResponse{ -// Status: merr.Status(nil), +// Status: merr.Success(), // Stats: g.innerResult, // } // return nil @@ -589,7 +588,7 @@ type getCollectionStatisticsTask struct { Condition *milvuspb.GetCollectionStatisticsRequest ctx context.Context - dataCoord types.DataCoord + dataCoord types.DataCoordClient result *milvuspb.GetCollectionStatisticsResponse collectionID UniqueID @@ -656,11 +655,11 @@ func (g *getCollectionStatisticsTask) Execute(ctx context.Context) error { if err != nil { return err } - if result.Status.ErrorCode != commonpb.ErrorCode_Success { - return errors.New(result.Status.Reason) + if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + return merr.Error(result.GetStatus()) } g.result = &milvuspb.GetCollectionStatisticsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Stats: result.Stats, } return nil @@ -674,7 +673,7 @@ type getPartitionStatisticsTask struct { Condition *milvuspb.GetPartitionStatisticsRequest ctx context.Context - dataCoord types.DataCoord + dataCoord types.DataCoordClient result *milvuspb.GetPartitionStatisticsResponse 
collectionID UniqueID @@ -746,11 +745,11 @@ func (g *getPartitionStatisticsTask) Execute(ctx context.Context) error { if result == nil { return errors.New("get partition statistics resp is nil") } - if result.Status.ErrorCode != commonpb.ErrorCode_Success { - return errors.New(result.Status.Reason) + if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + return merr.Error(result.GetStatus()) } g.result = &milvuspb.GetPartitionStatisticsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Stats: result.Stats, } return nil diff --git a/internal/proxy/task_statistic_test.go b/internal/proxy/task_statistic_test.go index f50d8c4b146cb..2288d4c23ab79 100644 --- a/internal/proxy/task_statistic_test.go +++ b/internal/proxy/task_statistic_test.go @@ -32,15 +32,16 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" ) type StatisticTaskSuite struct { suite.Suite - rc types.RootCoord - qc types.QueryCoord - qn *mocks.MockQueryNode + rc types.RootCoordClient + qc types.QueryCoordClient + qn *mocks.MockQueryNodeClient lb LBPolicy @@ -54,7 +55,7 @@ func (s *StatisticTaskSuite) SetupSuite() { func (s *StatisticTaskSuite) SetupTest() { successStatus := commonpb.Status{ErrorCode: commonpb.ErrorCode_Success} - qc := mocks.NewMockQueryCoord(s.T()) + qc := mocks.NewMockQueryCoordClient(s.T()) qc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(&successStatus, nil) qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{ @@ -68,18 +69,15 @@ func (s *StatisticTaskSuite) SetupTest() { }, }, nil).Maybe() qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&querypb.ShowPartitionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, 
+ Status: merr.Success(), PartitionIDs: []int64{1, 2, 3}, }, nil).Maybe() s.qc = qc s.rc = NewRootCoordMock() - s.rc.Start() - s.qn = mocks.NewMockQueryNode(s.T()) + s.qn = mocks.NewMockQueryNodeClient(s.T()) - s.qn.EXPECT().GetComponentStates(mock.Anything).Return(nil, nil).Maybe() + s.qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() mgr := NewMockShardClientManager(s.T()) mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil).Maybe() mgr.EXPECT().UpdateShardLeaders(mock.Anything, mock.Anything).Return(nil).Maybe() @@ -142,7 +140,7 @@ func (s *StatisticTaskSuite) loadCollection() { } func (s *StatisticTaskSuite) TearDownSuite() { - s.rc.Stop() + s.rc.Close() } func (s *StatisticTaskSuite) TestStatisticTask_Timeout() { @@ -168,9 +166,7 @@ func (s *StatisticTaskSuite) getStatisticsTask(ctx context.Context) *getStatisti ctx: ctx, collectionName: s.collectionName, result: &milvuspb.GetStatisticsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, request: &milvuspb.GetStatisticsRequest{ Base: &commonpb.MsgBase{ diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index 355b8ef79fc48..6713dd625ebc7 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -30,6 +30,7 @@ import ( "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "google.golang.org/grpc" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" @@ -65,6 +66,7 @@ const ( testVarCharField = "varChar" testFloatVecField = "fvec" testBinaryVecField = "bvec" + testFloat16VecField = "f16vec" testVecDim = 128 testMaxVarCharLength = 100 ) @@ -74,7 +76,6 @@ func constructCollectionSchema( dim int, collectionName string, ) *schemapb.CollectionSchema { - pk := &schemapb.FieldSchema{ FieldID: 0, Name: int64Field, @@ -116,7 +117,6 @@ func 
constructCollectionSchemaEnableDynamicSchema( dim int, collectionName string, ) *schemapb.CollectionSchema { - pk := &schemapb.FieldSchema{ FieldID: 0, Name: int64Field, @@ -173,7 +173,7 @@ func constructCollectionSchemaByDataType(collectionName string, fieldName2DataTy Name: fieldName, DataType: dataType, } - if dataType == schemapb.DataType_FloatVector || dataType == schemapb.DataType_BinaryVector { + if dataType == schemapb.DataType_FloatVector || dataType == schemapb.DataType_BinaryVector || dataType == schemapb.DataType_Float16Vector { fieldSchema.TypeParams = []*commonpb.KeyValuePair{ { Key: common.DimKey, @@ -205,11 +205,10 @@ func constructCollectionSchemaByDataType(collectionName string, fieldName2DataTy func constructCollectionSchemaWithAllType( boolField, int32Field, int64Field, floatField, doubleField string, - floatVecField, binaryVecField string, + floatVecField, binaryVecField, float16VecField string, dim int, collectionName string, ) *schemapb.CollectionSchema { - b := &schemapb.FieldSchema{ FieldID: 0, Name: boolField, @@ -290,6 +289,21 @@ func constructCollectionSchemaWithAllType( IndexParams: nil, AutoID: false, } + f16Vec := &schemapb.FieldSchema{ + FieldID: 0, + Name: float16VecField, + IsPrimaryKey: false, + Description: "", + DataType: schemapb.DataType_Float16Vector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: strconv.Itoa(dim), + }, + }, + IndexParams: nil, + AutoID: false, + } if enableMultipleVectorFields { return &schemapb.CollectionSchema{ @@ -304,6 +318,7 @@ func constructCollectionSchemaWithAllType( d, fVec, bVec, + f16Vec, }, } } @@ -409,10 +424,11 @@ func constructSearchRequest( func TestTranslateOutputFields(t *testing.T) { const ( - idFieldName = "id" - tsFieldName = "timestamp" - floatVectorFieldName = "float_vector" - binaryVectorFieldName = "binary_vector" + idFieldName = "id" + tsFieldName = "timestamp" + floatVectorFieldName = "float_vector" + binaryVectorFieldName = "binary_vector" + 
float16VectorFieldName = "float16_vector" ) var outputFields []string var userOutputFields []string @@ -427,6 +443,7 @@ func TestTranslateOutputFields(t *testing.T) { {Name: tsFieldName, FieldID: 1, DataType: schemapb.DataType_Int64}, {Name: floatVectorFieldName, FieldID: 100, DataType: schemapb.DataType_FloatVector}, {Name: binaryVectorFieldName, FieldID: 101, DataType: schemapb.DataType_BinaryVector}, + {Name: float16VectorFieldName, FieldID: 102, DataType: schemapb.DataType_Float16Vector}, }, } @@ -452,23 +469,23 @@ func TestTranslateOutputFields(t *testing.T) { outputFields, userOutputFields, err = translateOutputFields([]string{"*"}, schema, false) assert.Equal(t, nil, err) - assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName}, outputFields) - assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName}, userOutputFields) + assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName}, outputFields) + assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName}, userOutputFields) outputFields, userOutputFields, err = translateOutputFields([]string{" * "}, schema, false) assert.Equal(t, nil, err) - assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName}, outputFields) - assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName}, userOutputFields) + assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName}, outputFields) + assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName}, userOutputFields) outputFields, userOutputFields, err = translateOutputFields([]string{"*", tsFieldName}, schema, false) 
assert.Equal(t, nil, err) - assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName}, outputFields) - assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName}, userOutputFields) + assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName}, outputFields) + assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName}, userOutputFields) outputFields, userOutputFields, err = translateOutputFields([]string{"*", floatVectorFieldName}, schema, false) assert.Equal(t, nil, err) - assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName}, outputFields) - assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName}, userOutputFields) + assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName}, outputFields) + assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName}, userOutputFields) //========================================================================= outputFields, userOutputFields, err = translateOutputFields([]string{}, schema, true) @@ -493,18 +510,18 @@ func TestTranslateOutputFields(t *testing.T) { outputFields, userOutputFields, err = translateOutputFields([]string{"*"}, schema, true) assert.Equal(t, nil, err) - assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName}, outputFields) - assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName}, userOutputFields) + assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName}, outputFields) + 
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName}, userOutputFields) outputFields, userOutputFields, err = translateOutputFields([]string{"*", tsFieldName}, schema, true) assert.Equal(t, nil, err) - assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName}, outputFields) - assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName}, userOutputFields) + assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName}, outputFields) + assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName}, userOutputFields) outputFields, userOutputFields, err = translateOutputFields([]string{"*", floatVectorFieldName}, schema, true) assert.Equal(t, nil, err) - assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName}, outputFields) - assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName}, userOutputFields) + assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName}, outputFields) + assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName}, userOutputFields) outputFields, userOutputFields, err = translateOutputFields([]string{"A"}, schema, true) assert.Error(t, err) @@ -553,10 +570,7 @@ func TestTranslateOutputFields(t *testing.T) { } func TestCreateCollectionTask(t *testing.T) { - rc := NewRootCoordMock() - rc.Start() - defer rc.Stop() ctx := context.Background() shardsNum := common.DefaultShardsNum prefix := "TestCreateCollectionTask" @@ -869,11 +883,10 @@ func TestCreateCollectionTask(t *testing.T) { func TestHasCollectionTask(t *testing.T) { rc := 
NewRootCoordMock() - rc.Start() - defer rc.Stop() - qc := getQueryCoord() - qc.Start() - defer qc.Stop() + + defer rc.Close() + qc := getQueryCoordClient() + ctx := context.Background() mgr := newShardClientMgr() InitMetaCache(ctx, rc, qc, mgr) @@ -902,7 +915,7 @@ func TestHasCollectionTask(t *testing.T) { ShardsNum: shardsNum, } - //CreateCollection + // CreateCollection task := &hasCollectionTask{ Condition: NewTaskCondition(ctx), HasCollectionRequest: &milvuspb.HasCollectionRequest{ @@ -949,16 +962,14 @@ func TestHasCollectionTask(t *testing.T) { assert.NoError(t, err) err = task.Execute(ctx) assert.Error(t, err) - } func TestDescribeCollectionTask(t *testing.T) { rc := NewRootCoordMock() - rc.Start() - defer rc.Stop() - qc := getQueryCoord() - qc.Start() - defer qc.Stop() + + defer rc.Close() + qc := getQueryCoordClient() + ctx := context.Background() mgr := newShardClientMgr() InitMetaCache(ctx, rc, qc, mgr) @@ -966,7 +977,7 @@ func TestDescribeCollectionTask(t *testing.T) { dbName := "" collectionName := prefix + funcutil.GenRandomStr() - //CreateCollection + // CreateCollection task := &describeCollectionTask{ Condition: NewTaskCondition(ctx), DescribeCollectionRequest: &milvuspb.DescribeCollectionRequest{ @@ -1004,23 +1015,22 @@ func TestDescribeCollectionTask(t *testing.T) { err = task.PreExecute(ctx) assert.NoError(t, err) - rc.Stop() + rc.Close() task.CollectionID = 0 task.CollectionName = collectionName err = task.PreExecute(ctx) assert.NoError(t, err) err = task.Execute(ctx) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, task.result.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, task.result.GetStatus().GetErrorCode()) } func TestDescribeCollectionTask_ShardsNum1(t *testing.T) { rc := NewRootCoordMock() - rc.Start() - defer rc.Stop() - qc := getQueryCoord() - qc.Start() - defer qc.Stop() + + defer rc.Close() + qc := getQueryCoordClient() + ctx := context.Background() mgr := newShardClientMgr() 
InitMetaCache(ctx, rc, qc, mgr) @@ -1052,7 +1062,7 @@ func TestDescribeCollectionTask_ShardsNum1(t *testing.T) { rc.CreateCollection(ctx, createColReq) globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName) - //CreateCollection + // CreateCollection task := &describeCollectionTask{ Condition: NewTaskCondition(ctx), DescribeCollectionRequest: &milvuspb.DescribeCollectionRequest{ @@ -1073,18 +1083,15 @@ func TestDescribeCollectionTask_ShardsNum1(t *testing.T) { err = task.Execute(ctx) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, task.result.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, task.result.GetStatus().GetErrorCode()) assert.Equal(t, shardsNum, task.result.ShardsNum) assert.Equal(t, collectionName, task.result.GetCollectionName()) } func TestDescribeCollectionTask_EnableDynamicSchema(t *testing.T) { rc := NewRootCoordMock() - rc.Start() - defer rc.Stop() - qc := getQueryCoord() - qc.Start() - defer qc.Stop() + defer rc.Close() + qc := getQueryCoordClient() ctx := context.Background() mgr := newShardClientMgr() InitMetaCache(ctx, rc, qc, mgr) @@ -1116,7 +1123,7 @@ func TestDescribeCollectionTask_EnableDynamicSchema(t *testing.T) { rc.CreateCollection(ctx, createColReq) globalMetaCache.GetCollectionID(ctx, dbName, collectionName) - //CreateCollection + // CreateCollection task := &describeCollectionTask{ Condition: NewTaskCondition(ctx), DescribeCollectionRequest: &milvuspb.DescribeCollectionRequest{ @@ -1137,7 +1144,7 @@ func TestDescribeCollectionTask_EnableDynamicSchema(t *testing.T) { err = task.Execute(ctx) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, task.result.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, task.result.GetStatus().GetErrorCode()) assert.Equal(t, shardsNum, task.result.ShardsNum) assert.Equal(t, collectionName, task.result.GetCollectionName()) assert.Equal(t, 2, len(task.result.Schema.Fields)) @@ -1145,11 +1152,10 @@ func 
TestDescribeCollectionTask_EnableDynamicSchema(t *testing.T) { func TestDescribeCollectionTask_ShardsNum2(t *testing.T) { rc := NewRootCoordMock() - rc.Start() - defer rc.Stop() - qc := getQueryCoord() - qc.Start() - defer qc.Stop() + + defer rc.Close() + qc := getQueryCoordClient() + ctx := context.Background() mgr := newShardClientMgr() InitMetaCache(ctx, rc, qc, mgr) @@ -1179,7 +1185,7 @@ func TestDescribeCollectionTask_ShardsNum2(t *testing.T) { rc.CreateCollection(ctx, createColReq) globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName) - //CreateCollection + // CreateCollection task := &describeCollectionTask{ Condition: NewTaskCondition(ctx), DescribeCollectionRequest: &milvuspb.DescribeCollectionRequest{ @@ -1203,16 +1209,16 @@ func TestDescribeCollectionTask_ShardsNum2(t *testing.T) { err = task.Execute(ctx) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, task.result.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, task.result.GetStatus().GetErrorCode()) assert.Equal(t, common.DefaultShardsNum, task.result.ShardsNum) assert.Equal(t, collectionName, task.result.GetCollectionName()) - rc.Stop() + rc.Close() } func TestCreatePartitionTask(t *testing.T) { rc := NewRootCoordMock() - rc.Start() - defer rc.Stop() + + defer rc.Close() ctx := context.Background() prefix := "TestCreatePartitionTask" dbName := "" @@ -1257,23 +1263,20 @@ func TestCreatePartitionTask(t *testing.T) { func TestDropPartitionTask(t *testing.T) { rc := NewRootCoordMock() - rc.Start() - defer rc.Stop() + + defer rc.Close() ctx := context.Background() prefix := "TestDropPartitionTask" dbName := "" collectionName := prefix + funcutil.GenRandomStr() partitionName := prefix + funcutil.GenRandomStr() - qc := getQueryCoord() + qc := getQueryCoordClient() qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&querypb.ShowPartitionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - 
Reason: "", - }, + Status: merr.Success(), PartitionIDs: []int64{}, }, nil) qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), }, nil) mockCache := NewMockCache(t) @@ -1406,8 +1409,8 @@ func TestDropPartitionTask(t *testing.T) { func TestHasPartitionTask(t *testing.T) { rc := NewRootCoordMock() - rc.Start() - defer rc.Stop() + + defer rc.Close() ctx := context.Background() prefix := "TestHasPartitionTask" dbName := "" @@ -1452,8 +1455,8 @@ func TestHasPartitionTask(t *testing.T) { func TestShowPartitionsTask(t *testing.T) { rc := NewRootCoordMock() - rc.Start() - defer rc.Stop() + + defer rc.Close() ctx := context.Background() prefix := "TestShowPartitionsTask" dbName := "" @@ -1502,18 +1505,15 @@ func TestShowPartitionsTask(t *testing.T) { task.ShowPartitionsRequest.Type = milvuspb.ShowType_InMemory err = task.Execute(ctx) assert.Error(t, err) - } func TestTask_Int64PrimaryKey(t *testing.T) { var err error rc := NewRootCoordMock() - rc.Start() - defer rc.Stop() - qc := getQueryCoord() - qc.Start() - defer qc.Stop() + + defer rc.Close() + qc := getQueryCoordClient() ctx := context.Background() @@ -1533,7 +1533,8 @@ func TestTask_Int64PrimaryKey(t *testing.T) { testInt64Field: schemapb.DataType_Int64, testFloatField: schemapb.DataType_Float, testDoubleField: schemapb.DataType_Double, - testFloatVecField: schemapb.DataType_FloatVector} + testFloatVecField: schemapb.DataType_FloatVector, + } if enableMultipleVectorFields { fieldName2Types[testBinaryVecField] = schemapb.DataType_BinaryVector } @@ -1631,10 +1632,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) { Condition: NewTaskCondition(ctx), ctx: ctx, result: &milvuspb.MutationResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), IDs: nil, SuccIndex: nil, ErrIndex: nil, @@ -1663,30 +1661,18 @@ func TestTask_Int64PrimaryKey(t *testing.T) { 
assert.NoError(t, task.PostExecute(ctx)) }) - t.Run("delete", func(t *testing.T) { + t.Run("simple delete", func(t *testing.T) { task := &deleteTask{ Condition: NewTaskCondition(ctx), - deleteMsg: &msgstream.DeleteMsg{ - BaseMsg: msgstream.BaseMsg{}, - DeleteRequest: msgpb.DeleteRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Delete, - MsgID: 0, - Timestamp: 0, - SourceID: paramtable.GetNodeID(), - }, - CollectionName: collectionName, - PartitionName: partitionName, - }, + req: &milvuspb.DeleteRequest{ + CollectionName: collectionName, + PartitionName: partitionName, + Expr: "int64 in [0, 1]", }, idAllocator: idAllocator, - deleteExpr: "int64 in [0, 1]", ctx: ctx, result: &milvuspb.MutationResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), IDs: nil, SuccIndex: nil, ErrIndex: nil, @@ -1706,8 +1692,6 @@ func TestTask_Int64PrimaryKey(t *testing.T) { id := UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()) task.SetID(id) assert.Equal(t, id, task.ID()) - - task.deleteMsg.Base.MsgType = commonpb.MsgType_Delete assert.Equal(t, commonpb.MsgType_Delete, task.Type()) ts := Timestamp(time.Now().UnixNano()) @@ -1718,30 +1702,22 @@ func TestTask_Int64PrimaryKey(t *testing.T) { assert.NoError(t, task.PreExecute(ctx)) assert.NoError(t, task.Execute(ctx)) assert.NoError(t, task.PostExecute(ctx)) + }) - task2 := &deleteTask{ + t.Run("complex delete", func(t *testing.T) { + lb := NewMockLBPolicy(t) + task := &deleteTask{ Condition: NewTaskCondition(ctx), - deleteMsg: &msgstream.DeleteMsg{ - BaseMsg: msgstream.BaseMsg{}, - DeleteRequest: msgpb.DeleteRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Delete, - MsgID: 0, - Timestamp: 0, - SourceID: paramtable.GetNodeID(), - }, - CollectionName: collectionName, - PartitionName: partitionName, - }, + lb: lb, + req: &milvuspb.DeleteRequest{ + CollectionName: collectionName, + PartitionName: partitionName, + Expr: "int64 < 2", }, 
idAllocator: idAllocator, - deleteExpr: "int64 not in [0, 1]", ctx: ctx, result: &milvuspb.MutationResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), IDs: nil, SuccIndex: nil, ErrIndex: nil, @@ -1754,7 +1730,23 @@ func TestTask_Int64PrimaryKey(t *testing.T) { chMgr: chMgr, chTicker: ticker, } - assert.Error(t, task2.PreExecute(ctx)) + lb.EXPECT().Execute(mock.Anything, mock.Anything).Return(nil) + assert.NoError(t, task.OnEnqueue()) + assert.NotNil(t, task.TraceCtx()) + + id := UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()) + task.SetID(id) + assert.Equal(t, id, task.ID()) + assert.Equal(t, commonpb.MsgType_Delete, task.Type()) + + ts := Timestamp(time.Now().UnixNano()) + task.SetTs(ts) + assert.Equal(t, ts, task.BeginTs()) + assert.Equal(t, ts, task.EndTs()) + + assert.NoError(t, task.PreExecute(ctx)) + assert.NoError(t, task.Execute(ctx)) + assert.NoError(t, task.PostExecute(ctx)) }) } @@ -1762,11 +1754,9 @@ func TestTask_VarCharPrimaryKey(t *testing.T) { var err error rc := NewRootCoordMock() - rc.Start() - defer rc.Stop() - qc := getQueryCoord() - qc.Start() - defer qc.Stop() + + defer rc.Close() + qc := getQueryCoordClient() ctx := context.Background() @@ -1787,7 +1777,8 @@ func TestTask_VarCharPrimaryKey(t *testing.T) { testFloatField: schemapb.DataType_Float, testDoubleField: schemapb.DataType_Double, testVarCharField: schemapb.DataType_VarChar, - testFloatVecField: schemapb.DataType_FloatVector} + testFloatVecField: schemapb.DataType_FloatVector, + } if enableMultipleVectorFields { fieldName2Types[testBinaryVecField] = schemapb.DataType_BinaryVector } @@ -1886,10 +1877,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) { Condition: NewTaskCondition(ctx), ctx: ctx, result: &milvuspb.MutationResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), IDs: nil, SuccIndex: nil, ErrIndex: nil, @@ -1974,10 +1962,7 @@ 
func TestTask_VarCharPrimaryKey(t *testing.T) { }, ctx: ctx, result: &milvuspb.MutationResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), IDs: nil, SuccIndex: nil, ErrIndex: nil, @@ -2008,30 +1993,18 @@ func TestTask_VarCharPrimaryKey(t *testing.T) { assert.NoError(t, task.PostExecute(ctx)) }) - t.Run("delete", func(t *testing.T) { + t.Run("simple delete", func(t *testing.T) { task := &deleteTask{ Condition: NewTaskCondition(ctx), - deleteMsg: &msgstream.DeleteMsg{ - BaseMsg: msgstream.BaseMsg{}, - DeleteRequest: msgpb.DeleteRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Delete, - MsgID: 0, - Timestamp: 0, - SourceID: paramtable.GetNodeID(), - }, - CollectionName: collectionName, - PartitionName: partitionName, - }, + req: &milvuspb.DeleteRequest{ + CollectionName: collectionName, + PartitionName: partitionName, + Expr: "varChar in [\"milvus\", \"test\"]", }, idAllocator: idAllocator, - deleteExpr: "varChar in [\"milvus\", \"test\"]", ctx: ctx, result: &milvuspb.MutationResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), IDs: nil, SuccIndex: nil, ErrIndex: nil, @@ -2051,8 +2024,6 @@ func TestTask_VarCharPrimaryKey(t *testing.T) { id := UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()) task.SetID(id) assert.Equal(t, id, task.ID()) - - task.deleteMsg.Base.MsgType = commonpb.MsgType_Delete assert.Equal(t, commonpb.MsgType_Delete, task.Type()) ts := Timestamp(time.Now().UnixNano()) @@ -2063,50 +2034,13 @@ func TestTask_VarCharPrimaryKey(t *testing.T) { assert.NoError(t, task.PreExecute(ctx)) assert.NoError(t, task.Execute(ctx)) assert.NoError(t, task.PostExecute(ctx)) - - task2 := &deleteTask{ - Condition: NewTaskCondition(ctx), - deleteMsg: &msgstream.DeleteMsg{ - BaseMsg: msgstream.BaseMsg{}, - DeleteRequest: msgpb.DeleteRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Delete, - MsgID: 
0, - Timestamp: 0, - SourceID: paramtable.GetNodeID(), - }, - CollectionName: collectionName, - PartitionName: partitionName, - }, - }, - idAllocator: idAllocator, - deleteExpr: "varChar not in [\"milvus\", \"test\"]", - ctx: ctx, - result: &milvuspb.MutationResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, - IDs: nil, - SuccIndex: nil, - ErrIndex: nil, - Acknowledged: false, - InsertCnt: 0, - DeleteCnt: 0, - UpsertCnt: 0, - Timestamp: 0, - }, - chMgr: chMgr, - chTicker: ticker, - } - assert.Error(t, task2.PreExecute(ctx)) }) } func TestCreateAlias_all(t *testing.T) { rc := NewRootCoordMock() - rc.Start() - defer rc.Stop() + + defer rc.Close() ctx := context.Background() prefix := "TestCreateAlias_all" collectionName := prefix + funcutil.GenRandomStr() @@ -2117,10 +2051,8 @@ func TestCreateAlias_all(t *testing.T) { CollectionName: collectionName, Alias: "alias1", }, - ctx: ctx, - result: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + ctx: ctx, + result: merr.Success(), rootCoord: rc, } @@ -2146,8 +2078,8 @@ func TestCreateAlias_all(t *testing.T) { func TestDropAlias_all(t *testing.T) { rc := NewRootCoordMock() - rc.Start() - defer rc.Stop() + + defer rc.Close() ctx := context.Background() task := &DropAliasTask{ Condition: NewTaskCondition(ctx), @@ -2155,10 +2087,8 @@ func TestDropAlias_all(t *testing.T) { Base: nil, Alias: "alias1", }, - ctx: ctx, - result: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + ctx: ctx, + result: merr.Success(), rootCoord: rc, } @@ -2179,13 +2109,12 @@ func TestDropAlias_all(t *testing.T) { assert.NoError(t, task.PreExecute(ctx)) assert.NoError(t, task.Execute(ctx)) assert.NoError(t, task.PostExecute(ctx)) - } func TestAlterAlias_all(t *testing.T) { rc := NewRootCoordMock() - rc.Start() - defer rc.Stop() + + defer rc.Close() ctx := context.Background() prefix := "TestAlterAlias_all" collectionName := prefix + funcutil.GenRandomStr() @@ -2196,10 +2125,8 @@ 
func TestAlterAlias_all(t *testing.T) { CollectionName: collectionName, Alias: "alias1", }, - ctx: ctx, - result: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + ctx: ctx, + result: merr.Success(), rootCoord: rc, } @@ -2545,11 +2472,12 @@ func Test_dropCollectionTask_PreExecute(t *testing.T) { } func Test_dropCollectionTask_Execute(t *testing.T) { - mockRC := mocks.NewRootCoord(t) + mockRC := mocks.NewMockRootCoordClient(t) mockRC.On("DropCollection", mock.Anything, // context.Context mock.Anything, // *milvuspb.DropCollectionRequest - ).Return(&commonpb.Status{}, func(ctx context.Context, request *milvuspb.DropCollectionRequest) error { + mock.Anything, + ).Return(&commonpb.Status{}, func(ctx context.Context, request *milvuspb.DropCollectionRequest, opts ...grpc.CallOption) error { switch request.GetCollectionName() { case "c1": return errors.New("error mock DropCollection") @@ -2585,22 +2513,19 @@ func Test_loadCollectionTask_Execute(t *testing.T) { rc := newMockRootCoord() dc := NewDataCoordMock() - qc := getQueryCoord() - qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&querypb.ShowPartitionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + qc := getQueryCoordClient() + qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything, mock.Anything).Return(&querypb.ShowPartitionsResponse{ + Status: merr.Success(), PartitionIDs: []int64{}, }, nil) - qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{ - Status: merr.Status(nil), + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{ + Status: merr.Success(), }, nil) dbName := funcutil.GenRandomStr() collectionName := funcutil.GenRandomStr() collectionID := UniqueID(1) - //fieldName := funcutil.GenRandomStr() + // fieldName := funcutil.GenRandomStr() indexName := funcutil.GenRandomStr() ctx := context.Background() indexID := 
int64(1000) @@ -2609,11 +2534,9 @@ func Test_loadCollectionTask_Execute(t *testing.T) { // failed to get collection id. _ = InitMetaCache(ctx, rc, qc, shardMgr) - rc.DescribeCollectionFunc = func(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { + rc.DescribeCollectionFunc = func(ctx context.Context, request *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) { return &milvuspb.DescribeCollectionResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), Schema: newTestSchema(), CollectionID: collectionID, CollectionName: request.CollectionName, @@ -2646,7 +2569,7 @@ func Test_loadCollectionTask_Execute(t *testing.T) { }) t.Run("indexcoord describe index not success", func(t *testing.T) { - dc.DescribeIndexFunc = func(ctx context.Context, request *indexpb.DescribeIndexRequest) (*indexpb.DescribeIndexResponse, error) { + dc.DescribeIndexFunc = func(ctx context.Context, request *indexpb.DescribeIndexRequest, opts ...grpc.CallOption) (*indexpb.DescribeIndexResponse, error) { return &indexpb.DescribeIndexResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -2660,11 +2583,9 @@ func Test_loadCollectionTask_Execute(t *testing.T) { }) t.Run("no vector index", func(t *testing.T) { - dc.DescribeIndexFunc = func(ctx context.Context, request *indexpb.DescribeIndexRequest) (*indexpb.DescribeIndexResponse, error) { + dc.DescribeIndexFunc = func(ctx context.Context, request *indexpb.DescribeIndexRequest, opts ...grpc.CallOption) (*indexpb.DescribeIndexResponse, error) { return &indexpb.DescribeIndexResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), IndexInfos: []*indexpb.IndexInfo{ { CollectionID: collectionID, @@ -2693,22 +2614,19 @@ func Test_loadPartitionTask_Execute(t *testing.T) { rc := newMockRootCoord() dc := 
NewDataCoordMock() - qc := getQueryCoord() + qc := getQueryCoordClient() qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&querypb.ShowPartitionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), PartitionIDs: []int64{}, }, nil) qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), }, nil) dbName := funcutil.GenRandomStr() collectionName := funcutil.GenRandomStr() collectionID := UniqueID(1) - //fieldName := funcutil.GenRandomStr() + // fieldName := funcutil.GenRandomStr() indexName := funcutil.GenRandomStr() ctx := context.Background() indexID := int64(1000) @@ -2717,11 +2635,9 @@ func Test_loadPartitionTask_Execute(t *testing.T) { // failed to get collection id. _ = InitMetaCache(ctx, rc, qc, shardMgr) - rc.DescribeCollectionFunc = func(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { + rc.DescribeCollectionFunc = func(ctx context.Context, request *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) { return &milvuspb.DescribeCollectionResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), Schema: newTestSchema(), CollectionID: collectionID, CollectionName: request.CollectionName, @@ -2754,7 +2670,7 @@ func Test_loadPartitionTask_Execute(t *testing.T) { }) t.Run("indexcoord describe index not success", func(t *testing.T) { - dc.DescribeIndexFunc = func(ctx context.Context, request *indexpb.DescribeIndexRequest) (*indexpb.DescribeIndexResponse, error) { + dc.DescribeIndexFunc = func(ctx context.Context, request *indexpb.DescribeIndexRequest, opts ...grpc.CallOption) (*indexpb.DescribeIndexResponse, error) { return &indexpb.DescribeIndexResponse{ Status: &commonpb.Status{ ErrorCode: 
commonpb.ErrorCode_UnexpectedError, @@ -2768,11 +2684,9 @@ func Test_loadPartitionTask_Execute(t *testing.T) { }) t.Run("no vector index", func(t *testing.T) { - dc.DescribeIndexFunc = func(ctx context.Context, request *indexpb.DescribeIndexRequest) (*indexpb.DescribeIndexResponse, error) { + dc.DescribeIndexFunc = func(ctx context.Context, request *indexpb.DescribeIndexRequest, opts ...grpc.CallOption) (*indexpb.DescribeIndexResponse, error) { return &indexpb.DescribeIndexResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), IndexInfos: []*indexpb.IndexInfo{ { CollectionID: collectionID, @@ -2799,12 +2713,11 @@ func Test_loadPartitionTask_Execute(t *testing.T) { func TestCreateResourceGroupTask(t *testing.T) { rc := NewRootCoordMock() - rc.Start() - defer rc.Stop() - qc := getQueryCoord() - qc.EXPECT().CreateResourceGroup(mock.Anything, mock.Anything).Return(merr.Status(nil), nil) - qc.Start() - defer qc.Stop() + + defer rc.Close() + qc := getQueryCoordClient() + qc.EXPECT().CreateResourceGroup(mock.Anything, mock.Anything, mock.Anything).Return(merr.Success(), nil) + ctx := context.Background() mgr := newShardClientMgr() InitMetaCache(ctx, rc, qc, mgr) @@ -2839,12 +2752,11 @@ func TestCreateResourceGroupTask(t *testing.T) { func TestDropResourceGroupTask(t *testing.T) { rc := NewRootCoordMock() - rc.Start() - defer rc.Stop() - qc := getQueryCoord() - qc.EXPECT().DropResourceGroup(mock.Anything, mock.Anything).Return(merr.Status(nil), nil) - qc.Start() - defer qc.Stop() + + defer rc.Close() + qc := getQueryCoordClient() + qc.EXPECT().DropResourceGroup(mock.Anything, mock.Anything).Return(merr.Success(), nil) + ctx := context.Background() mgr := newShardClientMgr() InitMetaCache(ctx, rc, qc, mgr) @@ -2879,12 +2791,11 @@ func TestDropResourceGroupTask(t *testing.T) { func TestTransferNodeTask(t *testing.T) { rc := NewRootCoordMock() - rc.Start() - defer rc.Stop() - qc := getQueryCoord() - 
qc.EXPECT().TransferNode(mock.Anything, mock.Anything).Return(merr.Status(nil), nil) - qc.Start() - defer qc.Stop() + + defer rc.Close() + qc := getQueryCoordClient() + qc.EXPECT().TransferNode(mock.Anything, mock.Anything).Return(merr.Success(), nil) + ctx := context.Background() mgr := newShardClientMgr() InitMetaCache(ctx, rc, qc, mgr) @@ -2921,10 +2832,9 @@ func TestTransferNodeTask(t *testing.T) { func TestTransferReplicaTask(t *testing.T) { rc := &MockRootCoordClientInterface{} - qc := getQueryCoord() - qc.EXPECT().TransferReplica(mock.Anything, mock.Anything).Return(merr.Status(nil), nil) - qc.Start() - defer qc.Stop() + qc := getQueryCoordClient() + qc.EXPECT().TransferReplica(mock.Anything, mock.Anything).Return(merr.Success(), nil) + ctx := context.Background() mgr := newShardClientMgr() InitMetaCache(ctx, rc, qc, mgr) @@ -2964,13 +2874,12 @@ func TestTransferReplicaTask(t *testing.T) { func TestListResourceGroupsTask(t *testing.T) { rc := &MockRootCoordClientInterface{} - qc := getQueryCoord() + qc := getQueryCoordClient() qc.EXPECT().ListResourceGroups(mock.Anything, mock.Anything).Return(&milvuspb.ListResourceGroupsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), ResourceGroups: []string{meta.DefaultResourceGroupName, "rg"}, }, nil) - qc.Start() - defer qc.Stop() + ctx := context.Background() mgr := newShardClientMgr() InitMetaCache(ctx, rc, qc, mgr) @@ -2999,7 +2908,7 @@ func TestListResourceGroupsTask(t *testing.T) { err := task.Execute(ctx) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, task.result.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, task.result.GetStatus().GetErrorCode()) groups := task.result.GetResourceGroups() assert.Contains(t, groups, meta.DefaultResourceGroupName) assert.Contains(t, groups, "rg") @@ -3007,9 +2916,9 @@ func TestListResourceGroupsTask(t *testing.T) { func TestDescribeResourceGroupTask(t *testing.T) { rc := &MockRootCoordClientInterface{} - qc := getQueryCoord() + 
qc := getQueryCoordClient() qc.EXPECT().DescribeResourceGroup(mock.Anything, mock.Anything).Return(&querypb.DescribeResourceGroupResponse{ - Status: merr.Status(nil), + Status: merr.Success(), ResourceGroup: &querypb.ResourceGroupInfo{ Name: "rg", Capacity: 2, @@ -3018,8 +2927,7 @@ func TestDescribeResourceGroupTask(t *testing.T) { NumIncomingNode: map[int64]int32{2: 2}, }, }, nil) - qc.Start() - defer qc.Stop() + ctx := context.Background() mgr := newShardClientMgr() InitMetaCache(ctx, rc, qc, mgr) @@ -3052,7 +2960,7 @@ func TestDescribeResourceGroupTask(t *testing.T) { err := task.Execute(ctx) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, task.result.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, task.result.GetStatus().GetErrorCode()) groupInfo := task.result.GetResourceGroup() outgoingNodeNum := groupInfo.GetNumOutgoingNode() incomingNodeNum := groupInfo.GetNumIncomingNode() @@ -3062,12 +2970,10 @@ func TestDescribeResourceGroupTask(t *testing.T) { func TestDescribeResourceGroupTaskFailed(t *testing.T) { rc := &MockRootCoordClientInterface{} - qc := getQueryCoord() + qc := getQueryCoordClient() qc.EXPECT().DescribeResourceGroup(mock.Anything, mock.Anything).Return(&querypb.DescribeResourceGroupResponse{ Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, }, nil) - qc.Start() - defer qc.Stop() ctx := context.Background() mgr := newShardClientMgr() InitMetaCache(ctx, rc, qc, mgr) @@ -3100,12 +3006,11 @@ func TestDescribeResourceGroupTaskFailed(t *testing.T) { err := task.Execute(ctx) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, task.result.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, task.result.GetStatus().GetErrorCode()) qc.ExpectedCalls = nil - qc.EXPECT().Stop().Return(nil) qc.EXPECT().DescribeResourceGroup(mock.Anything, mock.Anything).Return(&querypb.DescribeResourceGroupResponse{ - Status: merr.Status(nil), + Status: merr.Success(), 
ResourceGroup: &querypb.ResourceGroupInfo{ Name: "rg", Capacity: 2, @@ -3122,8 +3027,8 @@ func TestDescribeResourceGroupTaskFailed(t *testing.T) { func TestCreateCollectionTaskWithPartitionKey(t *testing.T) { rc := NewRootCoordMock() - rc.Start() - defer rc.Stop() + + defer rc.Close() ctx := context.Background() shardsNum := common.DefaultShardsNum prefix := "TestCreateCollectionTaskWithPartitionKey" @@ -3329,11 +3234,9 @@ func TestCreateCollectionTaskWithPartitionKey(t *testing.T) { func TestPartitionKey(t *testing.T) { rc := NewRootCoordMock() - rc.Start() - defer rc.Stop() - qc := getQueryCoord() - qc.Start() - defer qc.Stop() + + defer rc.Close() + qc := getQueryCoordClient() ctx := context.Background() @@ -3450,10 +3353,7 @@ func TestPartitionKey(t *testing.T) { Condition: NewTaskCondition(ctx), ctx: ctx, result: &milvuspb.MutationResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), IDs: nil, SuccIndex: nil, ErrIndex: nil, @@ -3502,9 +3402,7 @@ func TestPartitionKey(t *testing.T) { }, result: &milvuspb.MutationResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), IDs: &schemapb.IDs{ IdField: nil, }, @@ -3529,25 +3427,13 @@ func TestPartitionKey(t *testing.T) { t.Run("delete", func(t *testing.T) { dt := &deleteTask{ Condition: NewTaskCondition(ctx), - deleteMsg: &BaseDeleteTask{ - BaseMsg: msgstream.BaseMsg{}, - DeleteRequest: msgpb.DeleteRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Delete, - MsgID: 0, - Timestamp: 0, - SourceID: paramtable.GetNodeID(), - }, - CollectionName: collectionName, - }, + req: &milvuspb.DeleteRequest{ + CollectionName: collectionName, + Expr: "int64_field in [0, 1]", }, - deleteExpr: "int64_field in [0, 1]", - ctx: ctx, + ctx: ctx, result: &milvuspb.MutationResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), IDs: nil, 
SuccIndex: nil, ErrIndex: nil, @@ -3562,10 +3448,10 @@ func TestPartitionKey(t *testing.T) { chTicker: ticker, } // don't support specify partition name if use partition key - dt.deleteMsg.PartitionName = partitionNames[0] + dt.req.PartitionName = partitionNames[0] assert.Error(t, dt.PreExecute(ctx)) - dt.deleteMsg.PartitionName = "" + dt.req.PartitionName = "" assert.NoError(t, dt.PreExecute(ctx)) assert.NoError(t, dt.Execute(ctx)) assert.NoError(t, dt.PostExecute(ctx)) diff --git a/internal/proxy/task_upsert.go b/internal/proxy/task_upsert.go index 64aec5660b258..f08188116e960 100644 --- a/internal/proxy/task_upsert.go +++ b/internal/proxy/task_upsert.go @@ -276,7 +276,7 @@ func (it *upsertTask) PreExecute(ctx context.Context) error { log := log.Ctx(ctx).With(zap.String("collectionName", collectionName)) it.result = &milvuspb.MutationResult{ - Status: merr.Status(nil), + Status: merr.Success(), IDs: &schemapb.IDs{ IdField: nil, }, @@ -389,8 +389,7 @@ func (it *upsertTask) insertExecute(ctx context.Context, msgPack *msgstream.MsgP if err != nil { log.Warn("get vChannels failed when insertExecute", zap.Error(err)) - it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError - it.result.Status.Reason = err.Error() + it.result.Status = merr.Status(err) return err } @@ -413,8 +412,7 @@ func (it *upsertTask) insertExecute(ctx context.Context, msgPack *msgstream.MsgP if err != nil { log.Warn("assign segmentID and repack insert data failed when insertExecute", zap.Error(err)) - it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError - it.result.Status.Reason = err.Error() + it.result.Status = merr.Status(err) return err } assignSegmentIDDur := tr.RecordSpan() @@ -438,8 +436,7 @@ func (it *upsertTask) deleteExecute(ctx context.Context, msgPack *msgstream.MsgP channelNames, err := it.chMgr.getVChannels(collID) if err != nil { log.Warn("get vChannels failed when deleteExecute", zap.Error(err)) - it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError - 
it.result.Status.Reason = err.Error() + it.result.Status = merr.Status(err) return err } it.upsertMsg.DeleteMsg.PrimaryKeys = it.result.IDs @@ -539,8 +536,7 @@ func (it *upsertTask) Execute(ctx context.Context) (err error) { tr.RecordSpan() err = stream.Produce(msgPack) if err != nil { - it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError - it.result.Status.Reason = err.Error() + it.result.Status = merr.Status(err) return err } sendMsgDur := tr.RecordSpan() diff --git a/internal/proxy/task_upsert_test.go b/internal/proxy/task_upsert_test.go index 2d41e2c35a9c3..dd6cfda6915e6 100644 --- a/internal/proxy/task_upsert_test.go +++ b/internal/proxy/task_upsert_test.go @@ -17,7 +17,6 @@ package proxy import ( "context" - "fmt" "testing" "github.com/stretchr/testify/assert" @@ -307,10 +306,8 @@ func TestUpsertTask(t *testing.T) { ).Return(collectionID, nil) globalMetaCache = cache - chMgr := newMockChannelsMgr() - chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) { - return channels, nil - } + chMgr := NewMockChannelsMgr(t) + chMgr.EXPECT().getChannels(mock.Anything).Return(channels, nil) ut := upsertTask{ ctx: context.Background(), req: &milvuspb.UpsertRequest{ @@ -323,12 +320,5 @@ func TestUpsertTask(t *testing.T) { resChannels := ut.getChannels() assert.ElementsMatch(t, channels, resChannels) assert.ElementsMatch(t, channels, ut.pChannels) - - chMgr.getChannelsFunc = func(collectionID UniqueID) ([]pChan, error) { - return nil, fmt.Errorf("mock err") - } - // get channels again, should return task's pChannels, so getChannelsFunc should not invoke again - resChannels = ut.getChannels() - assert.ElementsMatch(t, channels, resChannels) }) } diff --git a/internal/proxy/timestamp.go b/internal/proxy/timestamp.go index b5850ab2a13d1..7fe95cd6270ad 100644 --- a/internal/proxy/timestamp.go +++ b/internal/proxy/timestamp.go @@ -66,10 +66,13 @@ func (ta *timestampAllocator) alloc(ctx context.Context, count uint32) ([]Timest if err != nil { return 
nil, fmt.Errorf("syncTimestamp Failed:%w", err) } - if resp.Status.ErrorCode != commonpb.ErrorCode_Success { - return nil, fmt.Errorf("syncTimeStamp Failed:%s", resp.Status.Reason) + if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + return nil, fmt.Errorf("syncTimeStamp Failed:%s", resp.GetStatus().GetReason()) } - start, cnt := resp.Timestamp, resp.Count + if resp == nil { + return nil, fmt.Errorf("empty AllocTimestampResponse") + } + start, cnt := resp.GetTimestamp(), resp.GetCount() ret := make([]Timestamp, cnt) for i := uint32(0); i < cnt; i++ { ret[i] = start + uint64(i) diff --git a/internal/proxy/timestamp_test.go b/internal/proxy/timestamp_test.go index ba52eb5a74ae8..635e2c86ea692 100644 --- a/internal/proxy/timestamp_test.go +++ b/internal/proxy/timestamp_test.go @@ -21,9 +21,9 @@ import ( "math/rand" "testing" - "github.com/milvus-io/milvus/pkg/util/uniquegenerator" - "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/pkg/util/uniquegenerator" ) func TestNewTimestampAllocator(t *testing.T) { diff --git a/internal/proxy/util.go b/internal/proxy/util.go index d95db4cb48660..d2285f536d14d 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -34,7 +34,6 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/parser/planparserv2" "github.com/milvus-io/milvus/internal/proto/planpb" - "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil" @@ -59,8 +58,10 @@ const ( defaultMaxVarCharLength = 65535 - // DefaultIndexType name of default index type for scalar field - DefaultIndexType = "STL_SORT" + defaultMaxArrayCapacity = 4096 + + // DefaultArithmeticIndexType name of default index type for scalar field + DefaultArithmeticIndexType = "STL_SORT" // DefaultStringIndexType name of default index type for varChar/string field DefaultStringIndexType = "Trie" @@ 
-84,6 +85,12 @@ func isNumber(c uint8) bool { return true } +func isVectorType(dataType schemapb.DataType) bool { + return dataType == schemapb.DataType_FloatVector || + dataType == schemapb.DataType_BinaryVector || + dataType == schemapb.DataType_Float16Vector +} + func validateMaxQueryResultWindow(offset int64, limit int64) error { if offset < 0 { return fmt.Errorf("%s [%d] is invalid, should be gte than 0", OffsetKey, offset) @@ -120,24 +127,24 @@ func validateCollectionNameOrAlias(entity, entityType string) error { entity = strings.TrimSpace(entity) if entity == "" { - return fmt.Errorf("collection %s should not be empty", entityType) + return merr.WrapErrParameterInvalidMsg("collection %s should not be empty", entityType) } invalidMsg := fmt.Sprintf("Invalid collection %s: %s. ", entityType, entity) if len(entity) > Params.ProxyCfg.MaxNameLength.GetAsInt() { - return fmt.Errorf("%s the length of a collection %s must be less than %s characters", invalidMsg, entityType, + return merr.WrapErrParameterInvalidMsg("%s the length of a collection %s must be less than %s characters", invalidMsg, entityType, Params.ProxyCfg.MaxNameLength.GetValue()) } firstChar := entity[0] if firstChar != '_' && !isAlpha(firstChar) { - return fmt.Errorf("%s the first character of a collection %s must be an underscore or letter", invalidMsg, entityType) + return merr.WrapErrParameterInvalidMsg("%s the first character of a collection %s must be an underscore or letter", invalidMsg, entityType) } for i := 1; i < len(entity); i++ { c := entity[i] if c != '_' && !isAlpha(c) && !isNumber(c) { - return fmt.Errorf("%s collection %s can only contain numbers, letters and underscores", invalidMsg, entityType) + return merr.WrapErrParameterInvalidMsg("%s collection %s can only contain numbers, letters and underscores", invalidMsg, entityType) } } return nil @@ -150,19 +157,19 @@ func ValidateResourceGroupName(entity string) error { invalidMsg := fmt.Sprintf("Invalid resource group name %s.", 
entity) if len(entity) > Params.ProxyCfg.MaxNameLength.GetAsInt() { - return fmt.Errorf("%s the length of a resource group name must be less than %s characters", + return merr.WrapErrParameterInvalidMsg("%s the length of a resource group name must be less than %s characters", invalidMsg, Params.ProxyCfg.MaxNameLength.GetValue()) } firstChar := entity[0] if firstChar != '_' && !isAlpha(firstChar) { - return fmt.Errorf("%s the first character of a resource group name must be an underscore or letter", invalidMsg) + return merr.WrapErrParameterInvalidMsg("%s the first character of a resource group name must be an underscore or letter", invalidMsg) } for i := 1; i < len(entity); i++ { c := entity[i] if c != '_' && !isAlpha(c) && !isNumber(c) { - return fmt.Errorf("%s resource group name can only contain numbers, letters and underscores", invalidMsg) + return merr.WrapErrParameterInvalidMsg("%s resource group name can only contain numbers, letters and underscores", invalidMsg) } } return nil @@ -170,24 +177,24 @@ func ValidateResourceGroupName(entity string) error { func ValidateDatabaseName(dbName string) error { if dbName == "" { - return merr.WrapErrInvalidedDatabaseName(dbName, "database name couldn't be empty") + return merr.WrapErrDatabaseNameInvalid(dbName, "database name couldn't be empty") } if len(dbName) > Params.ProxyCfg.MaxNameLength.GetAsInt() { - return merr.WrapErrInvalidedDatabaseName(dbName, + return merr.WrapErrDatabaseNameInvalid(dbName, fmt.Sprintf("the length of a database name must be less than %d characters", Params.ProxyCfg.MaxNameLength.GetAsInt())) } firstChar := dbName[0] if firstChar != '_' && !isAlpha(firstChar) { - return merr.WrapErrInvalidedDatabaseName(dbName, + return merr.WrapErrDatabaseNameInvalid(dbName, "the first character of a database name must be an underscore or letter") } for i := 1; i < len(dbName); i++ { c := dbName[i] if c != '_' && !isAlpha(c) && !isNumber(c) { - return merr.WrapErrInvalidedDatabaseName(dbName, + return 
merr.WrapErrDatabaseNameInvalid(dbName, "database name can only contain numbers, letters and underscores") } } @@ -237,23 +244,33 @@ func validatePartitionTag(partitionTag string, strictCheck bool) error { return nil } +func validateStringIndexType(indexType string) bool { + // compatible with the index type marisa-trie of attu versions prior to 2.3.0 + return indexType == DefaultStringIndexType || indexType == "marisa-trie" +} + +func validateArithmeticIndexType(indexType string) bool { + // compatible with the index type Asceneding of attu versions prior to 2.3.0 + return indexType == DefaultArithmeticIndexType || indexType == "Asceneding" +} + func validateFieldName(fieldName string) error { fieldName = strings.TrimSpace(fieldName) if fieldName == "" { - return errors.New("field name should not be empty") + return merr.WrapErrFieldNameInvalid(fieldName, "field name should not be empty") } invalidMsg := "Invalid field name: " + fieldName + ". " if len(fieldName) > Params.ProxyCfg.MaxNameLength.GetAsInt() { msg := invalidMsg + "The length of a field name must be less than " + Params.ProxyCfg.MaxNameLength.GetValue() + " characters." - return errors.New(msg) + return merr.WrapErrFieldNameInvalid(fieldName, msg) } firstChar := fieldName[0] if firstChar != '_' && !isAlpha(firstChar) { msg := invalidMsg + "The first character of a field name must be an underscore or letter." - return errors.New(msg) + return merr.WrapErrFieldNameInvalid(fieldName, msg) } fieldNameSize := len(fieldName) @@ -261,7 +278,7 @@ func validateFieldName(fieldName string) error { c := fieldName[i] if c != '_' && !isAlpha(c) && !isNumber(c) { msg := invalidMsg + "Field name cannot only contain numbers, letters, and underscores." 
- return errors.New(msg) + return merr.WrapErrFieldNameInvalid(fieldName, msg) } } return nil @@ -306,7 +323,7 @@ func validateMaxLengthPerRow(collectionName string, field *schemapb.FieldSchema) return err } if maxLengthPerRow > defaultMaxVarCharLength || maxLengthPerRow <= 0 { - return fmt.Errorf("the maximum length specified for a VarChar shoule be in (0, 65535]") + return merr.WrapErrParameterInvalidMsg("the maximum length specified for a VarChar should be in (0, 65535]") } exist = true } @@ -318,8 +335,32 @@ func validateMaxLengthPerRow(collectionName string, field *schemapb.FieldSchema) return nil } +func validateMaxCapacityPerRow(collectionName string, field *schemapb.FieldSchema) error { + exist := false + for _, param := range field.TypeParams { + if param.Key != common.MaxCapacityKey { + continue + } + + maxCapacityPerRow, err := strconv.ParseInt(param.Value, 10, 64) + if err != nil { + return fmt.Errorf("the value of %s must be an integer", common.MaxCapacityKey) + } + if maxCapacityPerRow > defaultMaxArrayCapacity || maxCapacityPerRow <= 0 { + return fmt.Errorf("the maximum capacity specified for a Array should be in (0, 4096]") + } + exist = true + } + // if not exist type params max_length, return error + if !exist { + return fmt.Errorf("type param(max_capacity) should be specified for array field of collection %s", collectionName) + } + + return nil +} + func validateVectorFieldMetricType(field *schemapb.FieldSchema) error { - if (field.DataType != schemapb.DataType_FloatVector) && (field.DataType != schemapb.DataType_BinaryVector) { + if !isVectorType(field.DataType) { return nil } for _, params := range field.IndexParams { @@ -342,6 +383,19 @@ func validateDuplicatedFieldName(fields []*schemapb.FieldSchema) error { return nil } +func validateElementType(dataType schemapb.DataType) error { + switch dataType { + case schemapb.DataType_Bool, schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32, + schemapb.DataType_Int64, 
schemapb.DataType_Float, schemapb.DataType_Double, schemapb.DataType_VarChar: + return nil + case schemapb.DataType_String: + return errors.New("string data type not supported yet, please use VarChar type instead") + case schemapb.DataType_None: + return errors.New("element data type None is not valid") + } + return fmt.Errorf("element type %s is not supported", dataType.String()) +} + func validateFieldType(schema *schemapb.CollectionSchema) error { for _, field := range schema.GetFields() { switch field.GetDataType() { @@ -349,6 +403,10 @@ func validateFieldType(schema *schemapb.CollectionSchema) error { return errors.New("string data type not supported yet, please use VarChar type instead") case schemapb.DataType_None: return errors.New("data type None is not valid") + case schemapb.DataType_Array: + if err := validateElementType(field.GetElementType()); err != nil { + return err + } } } return nil @@ -432,7 +490,7 @@ func isVector(dataType schemapb.DataType) (bool, error) { schemapb.DataType_Float, schemapb.DataType_Double: return false, nil - case schemapb.DataType_FloatVector, schemapb.DataType_BinaryVector: + case schemapb.DataType_FloatVector, schemapb.DataType_BinaryVector, schemapb.DataType_Float16Vector: return true, nil } @@ -443,7 +501,7 @@ func validateMetricType(dataType schemapb.DataType, metricTypeStrRaw string) err metricTypeStr := strings.ToUpper(metricTypeStrRaw) switch metricTypeStr { case metric.L2, metric.IP, metric.COSINE: - if dataType == schemapb.DataType_FloatVector { + if dataType == schemapb.DataType_FloatVector || dataType == schemapb.DataType_Float16Vector { return nil } case metric.JACCARD, metric.HAMMING, metric.SUBSTRUCTURE, metric.SUPERSTRUCTURE: @@ -474,7 +532,7 @@ func validateSchema(coll *schemapb.CollectionSchema) error { return fmt.Errorf("there are more than one primary key, field name = %s, %s", coll.Fields[primaryIdx].Name, field.Name) } if field.DataType != schemapb.DataType_Int64 { - return fmt.Errorf("type of primary 
key shoule be int64") + return fmt.Errorf("type of primary key should be int64") } primaryIdx = idx } @@ -519,10 +577,9 @@ func validateSchema(coll *schemapb.CollectionSchema) error { if err4 != nil { return err4 } - } else { - // in C++, default type will be specified - // do nothing } + // in C++, default type will be specified + // do nothing } else { if len(field.IndexParams) != 0 { return fmt.Errorf("index params is not empty for scalar field: %s(%d)", field.Name, field.FieldID) @@ -548,7 +605,7 @@ func validateMultipleVectorFields(schema *schemapb.CollectionSchema) error { for i := range schema.Fields { name := schema.Fields[i].Name dType := schema.Fields[i].DataType - isVec := dType == schemapb.DataType_BinaryVector || dType == schemapb.DataType_FloatVector + isVec := dType == schemapb.DataType_BinaryVector || dType == schemapb.DataType_FloatVector || dType == schemapb.DataType_Float16Vector if isVec && vecExist && !enableMultipleVectorFields { return fmt.Errorf( "multiple vector fields is not supported, fields name: %s, %s", @@ -675,27 +732,23 @@ func ValidateUsername(username string) error { username = strings.TrimSpace(username) if username == "" { - return errors.New("username should not be empty") + return merr.WrapErrParameterInvalidMsg("username must be not empty") } - invalidMsg := "Invalid username: " + username + ". " if len(username) > Params.ProxyCfg.MaxUsernameLength.GetAsInt() { - msg := invalidMsg + "The length of username must be less than " + Params.ProxyCfg.MaxUsernameLength.GetValue() + " characters." - return errors.New(msg) + return merr.WrapErrParameterInvalidMsg("invalid username %s with length %d, the length of username must be less than %d", username, len(username), Params.ProxyCfg.MaxUsernameLength.GetValue()) } firstChar := username[0] if !isAlpha(firstChar) { - msg := invalidMsg + "The first character of username must be a letter." 
- return errors.New(msg) + return merr.WrapErrParameterInvalidMsg("invalid user name %s, the first character must be a letter, but got %s", username, firstChar) } usernameSize := len(username) for i := 1; i < usernameSize; i++ { c := username[i] if c != '_' && !isAlpha(c) && !isNumber(c) { - msg := invalidMsg + "Username should only contain numbers, letters, and underscores." - return errors.New(msg) + return merr.WrapErrParameterInvalidMsg("invalid user name %s, username must contain only numbers, letters and underscores, but got %s", username, c) } } return nil @@ -703,9 +756,9 @@ func ValidateUsername(username string) error { func ValidatePassword(password string) error { if len(password) < Params.ProxyCfg.MinPasswordLength.GetAsInt() || len(password) > Params.ProxyCfg.MaxPasswordLength.GetAsInt() { - msg := "The length of password must be great than " + Params.ProxyCfg.MinPasswordLength.GetValue() + - " and less than " + Params.ProxyCfg.MaxPasswordLength.GetValue() + " characters." - return errors.New(msg) + return merr.WrapErrParameterInvalidRange(Params.ProxyCfg.MinPasswordLength.GetAsInt(), + Params.ProxyCfg.MaxPasswordLength.GetAsInt(), + len(password), "invalid password length") } return nil } @@ -841,6 +894,19 @@ func GetCurDBNameFromContextOrDefault(ctx context.Context) string { return dbNameData[0] } +func NewContextWithMetadata(ctx context.Context, username string, dbName string) context.Context { + originValue := fmt.Sprintf("%s%s%s", username, util.CredentialSeperator, username) + authKey := strings.ToLower(util.HeaderAuthorize) + authValue := crypto.Base64Encode(originValue) + dbKey := strings.ToLower(util.HeaderDBName) + contextMap := map[string]string{ + authKey: authValue, + dbKey: dbName, + } + md := metadata.New(contextMap) + return metadata.NewIncomingContext(ctx, md) +} + func GetRole(username string) ([]string, error) { if globalMetaCache == nil { return []string{}, merr.WrapErrServiceUnavailable("internal: Milvus Proxy is not ready yet. 
please wait") @@ -852,6 +918,18 @@ func PasswordVerify(ctx context.Context, username, rawPwd string) bool { return passwordVerify(ctx, username, rawPwd, globalMetaCache) } +func VerifyAPIKey(rawToken string) (string, error) { + if hoo == nil { + return "", merr.WrapErrServiceInternal("internal: Milvus Proxy is not ready yet. please wait") + } + user, err := hoo.VerifyAPIKey(rawToken) + if err != nil { + log.Warn("fail to verify apikey", zap.String("api_key", rawToken), zap.Error(err)) + return "", merr.WrapErrParameterInvalidMsg("invalid apikey: [%s]", rawToken) + } + return user, nil +} + // PasswordVerify verify password func passwordVerify(ctx context.Context, username, rawPwd string, globalMetaCache Cache) bool { // it represents the cache miss if Sha256Password is empty within credInfo, which shall be updated first connection. @@ -881,6 +959,18 @@ func passwordVerify(ctx context.Context, username, rawPwd string, globalMetaCach return true } +func translatePkOutputFields(schema *schemapb.CollectionSchema) ([]string, []int64) { + pkNames := []string{} + fieldIDs := []int64{} + for _, field := range schema.Fields { + if field.IsPrimaryKey { + pkNames = append(pkNames, field.GetName()) + fieldIDs = append(fieldIDs, field.GetFieldID()) + } + } + return pkNames, fieldIDs +} + // Support wildcard in output fields: // // "*" - all fields @@ -985,7 +1075,7 @@ func validateIndexName(indexName string) error { return nil } -func isCollectionLoaded(ctx context.Context, qc types.QueryCoord, collID int64) (bool, error) { +func isCollectionLoaded(ctx context.Context, qc types.QueryCoordClient, collID int64) (bool, error) { // get all loading collections resp, err := qc.ShowCollections(ctx, &querypb.ShowCollectionsRequest{ CollectionIDs: nil, @@ -993,8 +1083,8 @@ func isCollectionLoaded(ctx context.Context, qc types.QueryCoord, collID int64) if err != nil { return false, err } - if resp.Status.ErrorCode != commonpb.ErrorCode_Success { - return false, 
errors.New(resp.Status.Reason) + if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + return false, merr.Error(resp.GetStatus()) } for _, loadedCollID := range resp.GetCollectionIDs() { @@ -1005,7 +1095,7 @@ func isCollectionLoaded(ctx context.Context, qc types.QueryCoord, collID int64) return false, nil } -func isPartitionLoaded(ctx context.Context, qc types.QueryCoord, collID int64, partIDs []int64) (bool, error) { +func isPartitionLoaded(ctx context.Context, qc types.QueryCoordClient, collID int64, partIDs []int64) (bool, error) { // get all loading collections resp, err := qc.ShowPartitions(ctx, &querypb.ShowPartitionsRequest{ CollectionID: collID, @@ -1014,8 +1104,8 @@ func isPartitionLoaded(ctx context.Context, qc types.QueryCoord, collID int64, p if err != nil { return false, err } - if resp.Status.ErrorCode != commonpb.ErrorCode_Success { - return false, errors.New(resp.Status.Reason) + if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + return false, merr.Error(resp.GetStatus()) } for _, loadedPartID := range resp.GetPartitionIDs() { @@ -1133,8 +1223,9 @@ func checkPrimaryFieldData(schema *schemapb.CollectionSchema, result *milvuspb.M // upsert has not supported when autoID == true log.Info("can not upsert when auto id enabled", zap.String("primaryFieldSchemaName", primaryFieldSchema.Name)) - result.Status.ErrorCode = commonpb.ErrorCode_UpsertAutoIDTrue - return nil, fmt.Errorf("upsert can not assign primary field data when auto id enabled %v", primaryFieldSchema.Name) + err := merr.WrapErrParameterInvalidMsg(fmt.Sprintf("upsert can not assign primary field data when auto id enabled %v", primaryFieldSchema.GetName())) + result.Status = merr.Status(err) + return nil, err } primaryFieldData, err = typeutil.GetPrimaryFieldData(insertMsg.GetFieldsData(), primaryFieldSchema) if err != nil { @@ -1169,7 +1260,7 @@ func getPartitionKeyFieldData(fieldSchema *schemapb.FieldSchema, insertMsg *msgs func getCollectionProgress( ctx 
context.Context, - queryCoord types.QueryCoord, + queryCoord types.QueryCoordClient, msgBase *commonpb.MsgBase, collectionID int64, ) (loadProgress int64, refreshProgress int64, err error) { @@ -1181,32 +1272,22 @@ func getCollectionProgress( CollectionIDs: []int64{collectionID}, }) if err != nil { - log.Warn("fail to show collections", zap.Int64("collection_id", collectionID), zap.Error(err)) - return - } - - if resp.Status.ErrorCode == commonpb.ErrorCode_InsufficientMemoryToLoad { - err = ErrInsufficientMemory - log.Warn("detected insufficientMemoryError when getCollectionProgress", zap.Int64("collection_id", collectionID), zap.String("reason", resp.GetStatus().GetReason())) - return - } - - if resp.Status.ErrorCode != commonpb.ErrorCode_Success { - err = merr.Error(resp.GetStatus()) - log.Warn("fail to show collections", zap.Int64("collection_id", collectionID), - zap.String("reason", resp.Status.Reason)) + log.Warn("fail to show collections", + zap.Int64("collectionID", collectionID), + zap.Error(err), + ) return } - if len(resp.InMemoryPercentages) == 0 { - errMsg := "fail to show collections from the querycoord, no data" - err = errors.New(errMsg) - log.Warn(errMsg, zap.Int64("collection_id", collectionID)) + err = merr.Error(resp.GetStatus()) + if err != nil { + log.Warn("fail to show collections", + zap.Int64("collectionID", collectionID), + zap.Error(err)) return } loadProgress = resp.GetInMemoryPercentages()[0] - if len(resp.GetRefreshProgress()) > 0 { // Compatibility for new Proxy with old QueryCoord refreshProgress = resp.GetRefreshProgress()[0] } @@ -1216,7 +1297,7 @@ func getCollectionProgress( func getPartitionProgress( ctx context.Context, - queryCoord types.QueryCoord, + queryCoord types.QueryCoordClient, msgBase *commonpb.MsgBase, partitionNames []string, collectionName string, @@ -1251,34 +1332,17 @@ func getPartitionProgress( zap.Error(err)) return } - if resp.GetStatus().GetErrorCode() == commonpb.ErrorCode_InsufficientMemoryToLoad { - err = 
ErrInsufficientMemory - log.Warn("detected insufficientMemoryError when getPartitionProgress", - zap.Int64("collection_id", collectionID), - zap.String("collection_name", collectionName), - zap.Strings("partition_names", partitionNames), - zap.String("reason", resp.GetStatus().GetReason()), - ) - return - } - if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + err = merr.Error(resp.GetStatus()) + if err != nil { err = merr.Error(resp.GetStatus()) log.Warn("fail to show partitions", - zap.String("collection_name", collectionName), - zap.Strings("partition_names", partitionNames), - zap.String("reason", resp.Status.Reason)) + zap.String("collectionName", collectionName), + zap.Strings("partitionNames", partitionNames), + zap.Error(err)) return } - if len(resp.InMemoryPercentages) != len(partitionIDs) { - errMsg := "fail to show partitions from the querycoord, invalid data num" - err = errors.New(errMsg) - log.Warn(errMsg, zap.Int64("collection_id", collectionID), - zap.String("collection_name", collectionName), - zap.Strings("partition_names", partitionNames)) - return - } for _, p := range resp.InMemoryPercentages { loadProgress += p } @@ -1440,3 +1504,76 @@ func checkDynamicFieldData(schema *schemapb.CollectionSchema, insertMsg *msgstre insertMsg.FieldsData = append(insertMsg.FieldsData, dynamicData) return nil } + +func SendReplicateMessagePack(ctx context.Context, replicateMsgStream msgstream.MsgStream, request interface{ GetBase() *commonpb.MsgBase }) { + if replicateMsgStream == nil || request == nil { + log.Warn("replicate msg stream or request is nil", zap.Any("request", request)) + return + } + msgBase := request.GetBase() + ts := msgBase.GetTimestamp() + if msgBase.GetReplicateInfo().GetIsReplicate() { + ts = msgBase.GetReplicateInfo().GetMsgTimestamp() + } + getBaseMsg := func(ctx context.Context, ts uint64) msgstream.BaseMsg { + return msgstream.BaseMsg{ + Ctx: ctx, + HashValues: []uint32{0}, + BeginTimestamp: ts, + EndTimestamp: ts, + } 
+ } + + var tsMsg msgstream.TsMsg + switch r := request.(type) { + case *milvuspb.CreateDatabaseRequest: + tsMsg = &msgstream.CreateDatabaseMsg{ + BaseMsg: getBaseMsg(ctx, ts), + CreateDatabaseRequest: *r, + } + case *milvuspb.DropDatabaseRequest: + tsMsg = &msgstream.DropDatabaseMsg{ + BaseMsg: getBaseMsg(ctx, ts), + DropDatabaseRequest: *r, + } + case *milvuspb.FlushRequest: + tsMsg = &msgstream.FlushMsg{ + BaseMsg: getBaseMsg(ctx, ts), + FlushRequest: *r, + } + case *milvuspb.LoadCollectionRequest: + tsMsg = &msgstream.LoadCollectionMsg{ + BaseMsg: getBaseMsg(ctx, ts), + LoadCollectionRequest: *r, + } + case *milvuspb.ReleaseCollectionRequest: + tsMsg = &msgstream.ReleaseCollectionMsg{ + BaseMsg: getBaseMsg(ctx, ts), + ReleaseCollectionRequest: *r, + } + case *milvuspb.CreateIndexRequest: + tsMsg = &msgstream.CreateIndexMsg{ + BaseMsg: getBaseMsg(ctx, ts), + CreateIndexRequest: *r, + } + case *milvuspb.DropIndexRequest: + tsMsg = &msgstream.DropIndexMsg{ + BaseMsg: getBaseMsg(ctx, ts), + DropIndexRequest: *r, + } + default: + log.Warn("unknown request", zap.Any("request", request)) + return + } + msgPack := &msgstream.MsgPack{ + BeginTs: ts, + EndTs: ts, + Msgs: []msgstream.TsMsg{tsMsg}, + } + msgErr := replicateMsgStream.Produce(msgPack) + // ignore the error if the msg stream failed to produce the msg, + // because it can be manually fixed in this error + if msgErr != nil { + log.Warn("send replicate msg failed", zap.Any("pack", msgPack), zap.Error(msgErr)) + } +} diff --git a/internal/proxy/util_test.go b/internal/proxy/util_test.go index c45b5349fa396..1bbc147f4da4e 100644 --- a/internal/proxy/util_test.go +++ b/internal/proxy/util_test.go @@ -25,11 +25,10 @@ import ( "testing" "time" - "github.com/milvus-io/milvus/pkg/log" - "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "google.golang.org/grpc" "google.golang.org/grpc/metadata" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" @@ -41,6 +40,7 
@@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util" "github.com/milvus-io/milvus/pkg/util/crypto" @@ -349,6 +349,7 @@ func TestValidatePrimaryKey(t *testing.T) { func TestValidateFieldType(t *testing.T) { type testCase struct { dt schemapb.DataType + et schemapb.DataType validate bool } cases := []testCase{ @@ -396,6 +397,80 @@ func TestValidateFieldType(t *testing.T) { dt: schemapb.DataType_VarChar, validate: true, }, + { + dt: schemapb.DataType_String, + validate: false, + }, + { + dt: schemapb.DataType_Array, + et: schemapb.DataType_Bool, + validate: true, + }, + { + dt: schemapb.DataType_Array, + et: schemapb.DataType_Int8, + validate: true, + }, + { + dt: schemapb.DataType_Array, + et: schemapb.DataType_Int16, + validate: true, + }, + { + dt: schemapb.DataType_Array, + et: schemapb.DataType_Int32, + validate: true, + }, + { + dt: schemapb.DataType_Array, + et: schemapb.DataType_Int64, + validate: true, + }, + { + dt: schemapb.DataType_Array, + et: schemapb.DataType_Float, + validate: true, + }, + { + dt: schemapb.DataType_Array, + et: schemapb.DataType_Double, + validate: true, + }, + { + dt: schemapb.DataType_Array, + et: schemapb.DataType_VarChar, + validate: true, + }, + { + dt: schemapb.DataType_Array, + et: schemapb.DataType_String, + validate: false, + }, + { + dt: schemapb.DataType_Array, + et: schemapb.DataType_None, + validate: false, + }, + { + dt: schemapb.DataType_Array, + et: schemapb.DataType_JSON, + validate: false, + }, + { + dt: schemapb.DataType_Array, + et: schemapb.DataType_Array, + validate: false, + }, + { + dt: schemapb.DataType_Array, + et: schemapb.DataType_FloatVector, + validate: false, + }, + { + dt: schemapb.DataType_Array, + et: schemapb.DataType_BinaryVector, + validate: false, + }, } for _, tc := 
range cases { @@ -403,7 +478,8 @@ func TestValidateFieldType(t *testing.T) { sch := &schemapb.CollectionSchema{ Fields: []*schemapb.FieldSchema{ { - DataType: tc.dt, + DataType: tc.dt, + ElementType: tc.et, }, }, } @@ -833,7 +909,7 @@ func TestPasswordVerify(t *testing.T) { invokedCount := 0 mockedRootCoord := newMockRootCoord() - mockedRootCoord.GetGetCredentialFunc = func(ctx context.Context, req *rootcoordpb.GetCredentialRequest) (*rootcoordpb.GetCredentialResponse, error) { + mockedRootCoord.GetGetCredentialFunc = func(ctx context.Context, req *rootcoordpb.GetCredentialRequest, opts ...grpc.CallOption) (*rootcoordpb.GetCredentialResponse, error) { invokedCount++ return nil, fmt.Errorf("get cred not found credential") } @@ -872,7 +948,7 @@ func Test_isCollectionIsLoaded(t *testing.T) { ctx := context.Background() t.Run("normal", func(t *testing.T) { collID := int64(1) - qc := &mocks.MockQueryCoord{} + qc := &mocks.MockQueryCoordClient{} successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success} qc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(successStatus, nil) qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{ @@ -896,7 +972,7 @@ func Test_isCollectionIsLoaded(t *testing.T) { t.Run("error", func(t *testing.T) { collID := int64(1) - qc := &mocks.MockQueryCoord{} + qc := &mocks.MockQueryCoordClient{} successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success} qc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(successStatus, nil) qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{ @@ -920,7 +996,7 @@ func Test_isCollectionIsLoaded(t *testing.T) { t.Run("fail", func(t *testing.T) { collID := int64(1) - qc := &mocks.MockQueryCoord{} + qc := &mocks.MockQueryCoordClient{} successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success} qc.EXPECT().LoadCollection(mock.Anything, 
mock.Anything).Return(successStatus, nil) qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{ @@ -951,7 +1027,7 @@ func Test_isPartitionIsLoaded(t *testing.T) { t.Run("normal", func(t *testing.T) { collID := int64(1) partID := int64(2) - qc := &mocks.MockQueryCoord{} + qc := &mocks.MockQueryCoordClient{} successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success} qc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(successStatus, nil) qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{ @@ -965,10 +1041,7 @@ func Test_isPartitionIsLoaded(t *testing.T) { }, }, nil) qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&querypb.ShowPartitionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), PartitionIDs: []int64{partID}, }, nil) loaded, err := isPartitionLoaded(ctx, qc, collID, []int64{partID}) @@ -979,7 +1052,7 @@ func Test_isPartitionIsLoaded(t *testing.T) { t.Run("error", func(t *testing.T) { collID := int64(1) partID := int64(2) - qc := &mocks.MockQueryCoord{} + qc := &mocks.MockQueryCoordClient{} successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success} qc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(successStatus, nil) qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{ @@ -993,10 +1066,7 @@ func Test_isPartitionIsLoaded(t *testing.T) { }, }, nil) qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&querypb.ShowPartitionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, + Status: merr.Success(), PartitionIDs: []int64{partID}, }, errors.New("error")) loaded, err := isPartitionLoaded(ctx, qc, collID, []int64{partID}) @@ -1007,7 +1077,7 @@ func Test_isPartitionIsLoaded(t *testing.T) { t.Run("fail", func(t *testing.T) { collID := 
int64(1) partID := int64(2) - qc := &mocks.MockQueryCoord{} + qc := &mocks.MockQueryCoordClient{} successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success} qc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(successStatus, nil) qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{ @@ -1334,9 +1404,7 @@ func Test_InsertTaskCheckPrimaryFieldData(t *testing.T) { }, }, result: &milvuspb.MutationResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, } @@ -1378,9 +1446,7 @@ func Test_InsertTaskCheckPrimaryFieldData(t *testing.T) { }, }, result: &milvuspb.MutationResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, } _, err = checkPrimaryFieldData(case2.schema, case2.result, case2.insertMsg, true) @@ -1420,9 +1486,7 @@ func Test_InsertTaskCheckPrimaryFieldData(t *testing.T) { }, }, result: &milvuspb.MutationResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, } _, err = checkPrimaryFieldData(case3.schema, case3.result, case3.insertMsg, true) @@ -1466,9 +1530,7 @@ func Test_InsertTaskCheckPrimaryFieldData(t *testing.T) { }, }, result: &milvuspb.MutationResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, } case4.schema.Fields[0].IsPrimaryKey = true @@ -1507,9 +1569,7 @@ func Test_UpsertTaskCheckPrimaryFieldData(t *testing.T) { }, }, result: &milvuspb.MutationResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, } _, err := checkPrimaryFieldData(case1.schema, case1.result, case1.insertMsg, false) @@ -1552,9 +1612,7 @@ func Test_UpsertTaskCheckPrimaryFieldData(t *testing.T) { }, }, result: &milvuspb.MutationResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, } _, 
err = checkPrimaryFieldData(case2.schema, case2.result, case2.insertMsg, false) @@ -1594,9 +1652,7 @@ func Test_UpsertTaskCheckPrimaryFieldData(t *testing.T) { }, }, result: &milvuspb.MutationResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, } _, err = checkPrimaryFieldData(case3.schema, case3.result, case3.insertMsg, false) @@ -1639,15 +1695,13 @@ func Test_UpsertTaskCheckPrimaryFieldData(t *testing.T) { }, }, result: &milvuspb.MutationResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, } case4.schema.Fields[0].IsPrimaryKey = true case4.schema.Fields[0].AutoID = true _, err = checkPrimaryFieldData(case4.schema, case4.result, case4.insertMsg, false) - assert.Equal(t, commonpb.ErrorCode_UpsertAutoIDTrue, case4.result.Status.ErrorCode) + assert.ErrorIs(t, merr.Error(case4.result.GetStatus()), merr.ErrParameterInvalid) assert.NotEqual(t, nil, err) // primary field data is nil, GetPrimaryFieldData fail @@ -1685,9 +1739,7 @@ func Test_UpsertTaskCheckPrimaryFieldData(t *testing.T) { }, }, result: &milvuspb.MutationResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, } case5.schema.Fields[0].IsPrimaryKey = true @@ -1736,9 +1788,7 @@ func Test_UpsertTaskCheckPrimaryFieldData(t *testing.T) { }, }, result: &milvuspb.MutationResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, } case6.schema.Fields[0].IsPrimaryKey = true @@ -1805,7 +1855,7 @@ func Test_MaxQueryResultWindow(t *testing.T) { } func Test_GetPartitionProgressFailed(t *testing.T) { - qc := mocks.NewMockQueryCoord(t) + qc := mocks.NewMockQueryCoordClient(t) qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&querypb.ShowPartitionsResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -1938,3 +1988,102 @@ func Test_CheckDynamicFieldData(t 
*testing.T) { assert.NoError(t, err) }) } + +func Test_validateMaxCapacityPerRow(t *testing.T) { + t.Run("normal case", func(t *testing.T) { + arrayField := &schemapb.FieldSchema{ + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_VarChar, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxLengthKey, + Value: "100", + }, + { + Key: common.MaxCapacityKey, + Value: "10", + }, + }, + } + + err := validateMaxCapacityPerRow("collection", arrayField) + assert.NoError(t, err) + }) + + t.Run("no max capacity", func(t *testing.T) { + arrayField := &schemapb.FieldSchema{ + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int64, + } + + err := validateMaxCapacityPerRow("collection", arrayField) + assert.Error(t, err) + }) + + t.Run("max capacity not int", func(t *testing.T) { + arrayField := &schemapb.FieldSchema{ + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int64, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxCapacityKey, + Value: "six", + }, + }, + } + + err := validateMaxCapacityPerRow("collection", arrayField) + assert.Error(t, err) + }) + + t.Run("max capacity exceed max", func(t *testing.T) { + arrayField := &schemapb.FieldSchema{ + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int64, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxCapacityKey, + Value: "4097", + }, + }, + } + + err := validateMaxCapacityPerRow("collection", arrayField) + assert.Error(t, err) + }) +} + +func TestSendReplicateMessagePack(t *testing.T) { + ctx := context.Background() + mockStream := msgstream.NewMockMsgStream(t) + + t.Run("empty case", func(t *testing.T) { + SendReplicateMessagePack(ctx, nil, nil) + }) + + t.Run("produce fail", func(t *testing.T) { + mockStream.EXPECT().Produce(mock.Anything).Return(errors.New("produce error")).Once() + SendReplicateMessagePack(ctx, mockStream, &milvuspb.CreateDatabaseRequest{ + Base: &commonpb.MsgBase{ReplicateInfo: 
&commonpb.ReplicateInfo{ + IsReplicate: true, + MsgTimestamp: 100, + }}, + }) + }) + + t.Run("unknown request", func(t *testing.T) { + SendReplicateMessagePack(ctx, mockStream, &milvuspb.ListDatabasesRequest{}) + }) + + t.Run("normal case", func(t *testing.T) { + mockStream.EXPECT().Produce(mock.Anything).Return(nil) + + SendReplicateMessagePack(ctx, mockStream, &milvuspb.CreateDatabaseRequest{}) + SendReplicateMessagePack(ctx, mockStream, &milvuspb.DropDatabaseRequest{}) + SendReplicateMessagePack(ctx, mockStream, &milvuspb.FlushRequest{}) + SendReplicateMessagePack(ctx, mockStream, &milvuspb.LoadCollectionRequest{}) + SendReplicateMessagePack(ctx, mockStream, &milvuspb.ReleaseCollectionRequest{}) + SendReplicateMessagePack(ctx, mockStream, &milvuspb.CreateIndexRequest{}) + SendReplicateMessagePack(ctx, mockStream, &milvuspb.DropIndexRequest{}) + }) +} diff --git a/internal/proxy/validate_util.go b/internal/proxy/validate_util.go index 32ce11ad1cb64..6531f425cecaa 100644 --- a/internal/proxy/validate_util.go +++ b/internal/proxy/validate_util.go @@ -3,6 +3,9 @@ package proxy import ( "fmt" "math" + "reflect" + + "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/log" @@ -11,13 +14,13 @@ import ( "github.com/milvus-io/milvus/pkg/util/parameterutil.go" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" - "go.uber.org/zap" ) type validateUtil struct { checkNAN bool checkMaxLen bool checkOverflow bool + checkMaxCap bool } type validateOption func(*validateUtil) @@ -40,6 +43,12 @@ func withOverflowCheck() validateOption { } } +func withMaxCapCheck() validateOption { + return func(v *validateUtil) { + v.checkMaxCap = true + } +} + func (v *validateUtil) apply(opts ...validateOption) { for _, opt := range opts { opt(v) @@ -63,6 +72,10 @@ func (v *validateUtil) Validate(data []*schemapb.FieldData, schema *schemapb.Col if err := v.checkFloatVectorFieldData(field, 
fieldSchema); err != nil { return err } + case schemapb.DataType_Float16Vector: + if err := v.checkFloat16VectorFieldData(field, fieldSchema); err != nil { + return err + } case schemapb.DataType_BinaryVector: if err := v.checkBinaryVectorFieldData(field, fieldSchema); err != nil { return err @@ -79,6 +92,11 @@ func (v *validateUtil) Validate(data []*schemapb.FieldData, schema *schemapb.Col if err := v.checkIntegerFieldData(field, fieldSchema); err != nil { return err } + case schemapb.DataType_Array: + if err := v.checkArrayFieldData(field, fieldSchema); err != nil { + return err + } + default: } } @@ -143,6 +161,26 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil return errNumRowsMismatch(field.GetFieldName(), n, numRows) } + case schemapb.DataType_Float16Vector: + f, err := schema.GetFieldFromName(field.GetFieldName()) + if err != nil { + return err + } + + dim, err := typeutil.GetDim(f) + if err != nil { + return err + } + + n, err := funcutil.GetNumRowsOfFloat16VectorField(field.GetVectors().GetFloat16Vector(), dim) + if err != nil { + return err + } + + if n != numRows { + return errNumRowsMismatch(field.GetFieldName(), n, numRows) + } + default: // error won't happen here. 
n, err := funcutil.GetNumRowOfFieldData(field) @@ -249,6 +287,11 @@ func (v *validateUtil) checkFloatVectorFieldData(field *schemapb.FieldData, fiel return nil } +func (v *validateUtil) checkFloat16VectorFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error { + // TODO + return nil +} + func (v *validateUtil) checkBinaryVectorFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error { // TODO return nil @@ -280,7 +323,7 @@ func (v *validateUtil) checkJSONFieldData(field *schemapb.FieldData, fieldSchema jsonArray := field.GetScalars().GetJsonData().GetData() if jsonArray == nil { msg := fmt.Sprintf("json field '%v' is illegal, array type mismatch", field.GetFieldName()) - return merr.WrapErrParameterInvalid("need string array", "got nil", msg) + return merr.WrapErrParameterInvalid("need json array", "got nil", msg) } if v.checkMaxLen { @@ -322,6 +365,103 @@ func (v *validateUtil) checkIntegerFieldData(field *schemapb.FieldData, fieldSch return nil } +func (v *validateUtil) checkArrayElement(array *schemapb.ArrayArray, field *schemapb.FieldSchema) error { + switch field.GetElementType() { + case schemapb.DataType_Bool: + for _, row := range array.GetData() { + actualType := reflect.TypeOf(row.GetData()) + if actualType != reflect.TypeOf((*schemapb.ScalarField_BoolData)(nil)) { + return merr.WrapErrParameterInvalid("bool array", + fmt.Sprintf("%s array", actualType.String()), "insert data does not match") + } + } + case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32: + for _, row := range array.GetData() { + actualType := reflect.TypeOf(row.GetData()) + if actualType != reflect.TypeOf((*schemapb.ScalarField_IntData)(nil)) { + return merr.WrapErrParameterInvalid("int array", + fmt.Sprintf("%s array", actualType.String()), "insert data does not match") + } + if v.checkOverflow { + if field.GetElementType() == schemapb.DataType_Int8 { + if err := verifyOverflowByRange(row.GetIntData().GetData(), 
math.MinInt8, math.MaxInt8); err != nil { + return err + } + } + if field.GetElementType() == schemapb.DataType_Int16 { + if err := verifyOverflowByRange(row.GetIntData().GetData(), math.MinInt16, math.MaxInt16); err != nil { + return err + } + } + } + } + case schemapb.DataType_Int64: + for _, row := range array.GetData() { + actualType := reflect.TypeOf(row.GetData()) + if actualType != reflect.TypeOf((*schemapb.ScalarField_LongData)(nil)) { + return merr.WrapErrParameterInvalid("int64 array", + fmt.Sprintf("%s array", actualType.String()), "insert data does not match") + } + } + case schemapb.DataType_Float: + for _, row := range array.GetData() { + actualType := reflect.TypeOf(row.GetData()) + if actualType != reflect.TypeOf((*schemapb.ScalarField_FloatData)(nil)) { + return merr.WrapErrParameterInvalid("float array", + fmt.Sprintf("%s array", actualType.String()), "insert data does not match") + } + } + case schemapb.DataType_Double: + for _, row := range array.GetData() { + actualType := reflect.TypeOf(row.GetData()) + if actualType != reflect.TypeOf((*schemapb.ScalarField_DoubleData)(nil)) { + return merr.WrapErrParameterInvalid("double array", + fmt.Sprintf("%s array", actualType.String()), "insert data does not match") + } + } + case schemapb.DataType_VarChar, schemapb.DataType_String: + for _, row := range array.GetData() { + actualType := reflect.TypeOf(row.GetData()) + if actualType != reflect.TypeOf((*schemapb.ScalarField_StringData)(nil)) { + return merr.WrapErrParameterInvalid("string array", + fmt.Sprintf("%s array", actualType.String()), "insert data does not match") + } + } + } + return nil +} + +func (v *validateUtil) checkArrayFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error { + data := field.GetScalars().GetArrayData() + if data == nil { + elementTypeStr := fieldSchema.GetElementType().String() + msg := fmt.Sprintf("array field '%v' is illegal, array type mismatch", field.GetFieldName()) + expectStr := 
fmt.Sprintf("need %s array", elementTypeStr) + return merr.WrapErrParameterInvalid(expectStr, "got nil", msg) + } + if v.checkMaxCap { + maxCapacity, err := parameterutil.GetMaxCapacity(fieldSchema) + if err != nil { + return err + } + if err := verifyCapacityPerRow(data.GetData(), maxCapacity, fieldSchema.GetElementType()); err != nil { + return err + } + } + if typeutil.IsStringType(data.GetElementType()) && v.checkMaxLen { + maxLength, err := parameterutil.GetMaxLength(fieldSchema) + if err != nil { + return err + } + for _, row := range data.GetData() { + if err := verifyLengthPerRow(row.GetStringData().GetData(), maxLength); err != nil { + return err + } + } + } + return v.checkArrayElement(data, fieldSchema) +} + func verifyLengthPerRow[E interface{ ~string | ~[]byte }](strArr []E, maxLength int64) error { for i, s := range strArr { if int64(len(s)) > maxLength { @@ -333,6 +473,37 @@ func verifyLengthPerRow[E interface{ ~string | ~[]byte }](strArr []E, maxLength return nil } +func verifyCapacityPerRow(arrayArray []*schemapb.ScalarField, maxCapacity int64, elementType schemapb.DataType) error { + for i, array := range arrayArray { + arrayLen := 0 + switch elementType { + case schemapb.DataType_Bool: + arrayLen = len(array.GetBoolData().GetData()) + case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32: + arrayLen = len(array.GetIntData().GetData()) + case schemapb.DataType_Int64: + arrayLen = len(array.GetLongData().GetData()) + case schemapb.DataType_String, schemapb.DataType_VarChar: + arrayLen = len(array.GetStringData().GetData()) + case schemapb.DataType_Float: + arrayLen = len(array.GetFloatData().GetData()) + case schemapb.DataType_Double: + arrayLen = len(array.GetDoubleData().GetData()) + default: + msg := fmt.Sprintf("array element type: %s is not supported", elementType.String()) + return merr.WrapErrParameterInvalid("valid array element type", "array element type is not supported", msg) + } + + if int64(arrayLen) <= 
maxCapacity { + continue + } + msg := fmt.Sprintf("the length (%d) of %dth array exceeds max capacity (%d)", arrayLen, i, maxCapacity) + return merr.WrapErrParameterInvalid("valid length array", "array length exceeds max capacity", msg) + } + + return nil +} + func verifyOverflowByRange(arr []int32, lb int64, ub int64) error { for idx, e := range arr { if lb > int64(e) || ub < int64(e) { diff --git a/internal/proxy/validate_util_test.go b/internal/proxy/validate_util_test.go index 556da13b2b9c6..1bd2ed9a7dca4 100644 --- a/internal/proxy/validate_util_test.go +++ b/internal/proxy/validate_util_test.go @@ -523,6 +523,131 @@ func Test_validateUtil_checkAligned(t *testing.T) { assert.Error(t, err) }) + ////////////////////////////////////////////////////////////////////// + + t.Run("float16 vector column not found", func(t *testing.T) { + data := []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_Float16Vector, + }, + } + + schema := &schemapb.CollectionSchema{} + h, err := typeutil.CreateSchemaHelper(schema) + assert.NoError(t, err) + + v := newValidateUtil() + + err = v.checkAligned(data, h, 100) + + assert.Error(t, err) + }) + + t.Run("float16 vector column dimension not found", func(t *testing.T) { + data := []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_Float16Vector, + }, + } + + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_Float16Vector, + }, + }, + } + h, err := typeutil.CreateSchemaHelper(schema) + assert.NoError(t, err) + + v := newValidateUtil() + + err = v.checkAligned(data, h, 100) + + assert.Error(t, err) + }) + + t.Run("invalid num rows", func(t *testing.T) { + data := []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_Float16Vector, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Data: &schemapb.VectorField_Float16Vector{ + Float16Vector: []byte("not128"), + }, + }, + }, + }, + } 
+ + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_Float16Vector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "128", + }, + }, + }, + }, + } + h, err := typeutil.CreateSchemaHelper(schema) + assert.NoError(t, err) + + v := newValidateUtil() + + err = v.checkAligned(data, h, 100) + + assert.Error(t, err) + }) + + t.Run("num rows mismatch", func(t *testing.T) { + data := []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_Float16Vector, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Data: &schemapb.VectorField_Float16Vector{ + Float16Vector: []byte{'1', '2'}, + }, + }, + }, + }, + } + + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_Float16Vector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "8", + }, + }, + }, + }, + } + h, err := typeutil.CreateSchemaHelper(schema) + assert.NoError(t, err) + + v := newValidateUtil() + + err = v.checkAligned(data, h, 100) + + assert.Error(t, err) + }) + ////////////////////////////////////////////////////////////////// t.Run("mismatch", func(t *testing.T) { @@ -769,71 +894,626 @@ func Test_validateUtil_Validate(t *testing.T) { }, }, { - Name: "test2", - FieldID: 102, - DataType: schemapb.DataType_BinaryVector, + Name: "test2", + FieldID: 102, + DataType: schemapb.DataType_BinaryVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "8", + }, + }, + }, + { + Name: "test3", + FieldID: 103, + DataType: schemapb.DataType_VarChar, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxLengthKey, + Value: "8", + }, + }, + }, + }, + } + + v := newValidateUtil(withNANCheck(), withMaxLenCheck()) + + err := v.Validate(data, schema, 2) + + assert.Error(t, err) + }) + + t.Run("length exceeds", func(t *testing.T) { + data := []*schemapb.FieldData{ + { + 
FieldName: "test1", + Type: schemapb.DataType_FloatVector, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Data: &schemapb.VectorField_FloatVector{ + FloatVector: &schemapb.FloatArray{ + Data: generateFloatVectors(2, 1), + }, + }, + }, + }, + }, + { + FieldName: "test2", + Type: schemapb.DataType_BinaryVector, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Data: &schemapb.VectorField_BinaryVector{ + BinaryVector: generateBinaryVectors(2, 8), + }, + }, + }, + }, + { + FieldName: "test3", + Type: schemapb.DataType_VarChar, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: []string{"very_long", "very_very_long"}, + }, + }, + }, + }, + }, + } + + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test1", + FieldID: 101, + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "1", + }, + }, + }, + { + Name: "test2", + FieldID: 102, + DataType: schemapb.DataType_BinaryVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "8", + }, + }, + }, + { + Name: "test3", + FieldID: 103, + DataType: schemapb.DataType_VarChar, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxLengthKey, + Value: "2", + }, + }, + }, + }, + } + + v := newValidateUtil(withNANCheck(), withMaxLenCheck()) + err := v.Validate(data, schema, 2) + assert.Error(t, err) + + // Validate JSON length + longBytes := make([]byte, paramtable.Get().CommonCfg.JSONMaxLength.GetAsInt()+1) + data = []*schemapb.FieldData{ + { + FieldName: "json", + Type: schemapb.DataType_JSON, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_JsonData{ + JsonData: &schemapb.JSONArray{ + Data: [][]byte{longBytes, longBytes}, + }, + }, + }, + }, + }, + } + schema = &schemapb.CollectionSchema{ + 
Fields: []*schemapb.FieldSchema{ + { + Name: "json", + FieldID: 104, + DataType: schemapb.DataType_JSON, + }, + }, + } + err = v.Validate(data, schema, 2) + assert.Error(t, err) + }) + + t.Run("has overflow", func(t *testing.T) { + data := []*schemapb.FieldData{ + { + FieldName: "test1", + Type: schemapb.DataType_Int8, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: []int32{int32(math.MinInt8) - 1, int32(math.MaxInt8) + 1}, + }, + }, + }, + }, + }, + } + + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test1", + FieldID: 101, + DataType: schemapb.DataType_Int8, + }, + }, + } + + v := newValidateUtil(withOverflowCheck()) + + err := v.Validate(data, schema, 2) + assert.Error(t, err) + }) + + t.Run("array data nil", func(t *testing.T) { + data := []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_Array, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_ArrayData{ + ArrayData: nil, + }, + }, + }, + }, + } + + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int64, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxCapacityKey, + Value: "8", + }, + }, + }, + }, + } + + v := newValidateUtil() + + err := v.Validate(data, schema, 100) + + assert.Error(t, err) + }) + + t.Run("exceed max capacity", func(t *testing.T) { + data := []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_Array, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_ArrayData{ + ArrayData: &schemapb.ArrayArray{ + Data: []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9}, + }, + }, + }, + }, + }, + }, + }, + 
}, + }, + } + + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int64, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxCapacityKey, + Value: "2", + }, + }, + }, + }, + } + + v := newValidateUtil(withMaxCapCheck()) + + err := v.Validate(data, schema, 1) + + assert.Error(t, err) + }) + + t.Run("string element exceed max length", func(t *testing.T) { + data := []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_Array, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_ArrayData{ + ArrayData: &schemapb.ArrayArray{ + Data: []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: []string{"abcdefghijkl", "ajsgfuioabaxyaefilagskjfhgka"}, + }, + }, + }, + }, + ElementType: schemapb.DataType_VarChar, + }, + }, + }, + }, + }, + } + + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_VarChar, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxCapacityKey, + Value: "10", + }, + { + Key: common.MaxLengthKey, + Value: "5", + }, + }, + }, + }, + } + + v := newValidateUtil(withMaxCapCheck(), withMaxLenCheck()) + + err := v.Validate(data, schema, 1) + + assert.Error(t, err) + }) + + t.Run("no max capacity", func(t *testing.T) { + data := []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_Array, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_ArrayData{ + ArrayData: &schemapb.ArrayArray{ + Data: []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + schema := &schemapb.CollectionSchema{ + 
Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int64, + TypeParams: []*commonpb.KeyValuePair{}, + }, + }, + } + + v := newValidateUtil(withMaxCapCheck()) + + err := v.Validate(data, schema, 1) + + assert.Error(t, err) + }) + + t.Run("unsupported element type", func(t *testing.T) { + data := []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_Array, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_ArrayData{ + ArrayData: &schemapb.ArrayArray{ + Data: []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_JSON, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxCapacityKey, + Value: "8", + }, + }, + }, + }, + } + + v := newValidateUtil(withMaxCapCheck()) + + err := v.Validate(data, schema, 1) + + assert.Error(t, err) + }) + + t.Run("element type not match", func(t *testing.T) { + data := []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_Array, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_ArrayData{ + ArrayData: &schemapb.ArrayArray{ + Data: []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_BoolData{ + BoolData: &schemapb.BoolArray{ + Data: []bool{true, false}, + }, + }, + }, + { + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Bool, + 
TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxCapacityKey, + Value: "100", + }, + }, + }, + }, + } + + v := newValidateUtil(withMaxCapCheck()) + err := v.Validate(data, schema, 1) + assert.Error(t, err) + + data = []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_Array, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_ArrayData{ + ArrayData: &schemapb.ArrayArray{ + Data: []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: []int32{1, 2, 3}, + }, + }, + }, + { + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + schema = &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int8, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxCapacityKey, + Value: "100", + }, + }, + }, + }, + } + + err = newValidateUtil(withMaxCapCheck()).Validate(data, schema, 1) + assert.Error(t, err) + + schema = &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int16, TypeParams: []*commonpb.KeyValuePair{ { - Key: common.DimKey, - Value: "8", + Key: common.MaxCapacityKey, + Value: "100", }, }, }, + }, + } + + err = newValidateUtil(withMaxCapCheck()).Validate(data, schema, 1) + assert.Error(t, err) + + schema = &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ { - Name: "test3", - FieldID: 103, - DataType: schemapb.DataType_VarChar, + Name: "test", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int32, TypeParams: []*commonpb.KeyValuePair{ { - Key: common.MaxLengthKey, - Value: "8", + Key: common.MaxCapacityKey, + Value: "100", }, }, }, }, } - v := newValidateUtil(withNANCheck(), 
withMaxLenCheck()) - - err := v.Validate(data, schema, 2) - + err = newValidateUtil(withMaxCapCheck()).Validate(data, schema, 1) assert.Error(t, err) - }) - t.Run("length exceeds", func(t *testing.T) { - data := []*schemapb.FieldData{ + data = []*schemapb.FieldData{ { - FieldName: "test1", - Type: schemapb.DataType_FloatVector, - Field: &schemapb.FieldData_Vectors{ - Vectors: &schemapb.VectorField{ - Data: &schemapb.VectorField_FloatVector{ - FloatVector: &schemapb.FloatArray{ - Data: generateFloatVectors(2, 1), + FieldName: "test", + Type: schemapb.DataType_Array, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_ArrayData{ + ArrayData: &schemapb.ArrayArray{ + Data: []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_FloatData{ + FloatData: &schemapb.FloatArray{ + Data: []float32{1, 2, 3}, + }, + }, + }, + { + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9}, + }, + }, + }, + }, }, }, }, }, }, - { - FieldName: "test2", - Type: schemapb.DataType_BinaryVector, - Field: &schemapb.FieldData_Vectors{ - Vectors: &schemapb.VectorField{ - Data: &schemapb.VectorField_BinaryVector{ - BinaryVector: generateBinaryVectors(2, 8), + } + + schema = &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Float, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxCapacityKey, + Value: "100", }, }, }, }, + } + err = newValidateUtil(withMaxCapCheck()).Validate(data, schema, 1) + assert.Error(t, err) + + data = []*schemapb.FieldData{ { - FieldName: "test3", - Type: schemapb.DataType_VarChar, + FieldName: "test", + Type: schemapb.DataType_Array, Field: &schemapb.FieldData_Scalars{ Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_StringData{ - StringData: &schemapb.StringArray{ - Data: []string{"very_long", "very_very_long"}, + Data: 
&schemapb.ScalarField_ArrayData{ + ArrayData: &schemapb.ArrayArray{ + Data: []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_DoubleData{ + DoubleData: &schemapb.DoubleArray{ + Data: []float64{1, 2, 3}, + }, + }, + }, + { + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9}, + }, + }, + }, + }, }, }, }, @@ -841,88 +1521,144 @@ func Test_validateUtil_Validate(t *testing.T) { }, } - schema := &schemapb.CollectionSchema{ + schema = &schemapb.CollectionSchema{ Fields: []*schemapb.FieldSchema{ { - Name: "test1", - FieldID: 101, - DataType: schemapb.DataType_FloatVector, + Name: "test", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Double, TypeParams: []*commonpb.KeyValuePair{ { - Key: common.DimKey, - Value: "1", + Key: common.MaxCapacityKey, + Value: "100", }, }, }, - { - Name: "test2", - FieldID: 102, - DataType: schemapb.DataType_BinaryVector, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: common.DimKey, - Value: "8", + }, + } + + err = newValidateUtil(withMaxCapCheck()).Validate(data, schema, 1) + assert.Error(t, err) + + data = []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_Array, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_ArrayData{ + ArrayData: &schemapb.ArrayArray{ + Data: []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: []string{"a", "b", "c"}, + }, + }, + }, + { + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9}, + }, + }, + }, + }, + }, }, }, }, + }, + } + + schema = &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ { - Name: "test3", - FieldID: 103, - DataType: schemapb.DataType_VarChar, + Name: "test", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{ { - Key: 
common.MaxLengthKey, - Value: "2", + Key: common.MaxCapacityKey, + Value: "100", }, }, }, }, } - v := newValidateUtil(withNANCheck(), withMaxLenCheck()) - err := v.Validate(data, schema, 2) + err = newValidateUtil(withMaxCapCheck()).Validate(data, schema, 1) assert.Error(t, err) - // Validate JSON length - longBytes := make([]byte, paramtable.Get().CommonCfg.JSONMaxLength.GetAsInt()+1) data = []*schemapb.FieldData{ { - FieldName: "json", - Type: schemapb.DataType_JSON, + FieldName: "test", + Type: schemapb.DataType_Array, Field: &schemapb.FieldData_Scalars{ Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_JsonData{ - JsonData: &schemapb.JSONArray{ - Data: [][]byte{longBytes, longBytes}, + Data: &schemapb.ScalarField_ArrayData{ + ArrayData: &schemapb.ArrayArray{ + Data: []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9}, + }, + }, + }, + { + Data: &schemapb.ScalarField_FloatData{ + FloatData: &schemapb.FloatArray{ + Data: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9}, + }, + }, + }, + }, }, }, }, }, }, } + schema = &schemapb.CollectionSchema{ Fields: []*schemapb.FieldSchema{ { - Name: "json", - FieldID: 104, - DataType: schemapb.DataType_JSON, + Name: "test", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int64, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxCapacityKey, + Value: "100", + }, + }, }, }, } - err = v.Validate(data, schema, 2) + + err = newValidateUtil(withMaxCapCheck()).Validate(data, schema, 1) assert.Error(t, err) }) - t.Run("has overflow", func(t *testing.T) { + t.Run("array element overflow", func(t *testing.T) { data := []*schemapb.FieldData{ { - FieldName: "test1", - Type: schemapb.DataType_Int8, + FieldName: "test", + Type: schemapb.DataType_Array, Field: &schemapb.FieldData_Scalars{ Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_IntData{ - IntData: &schemapb.IntArray{ - Data: 
[]int32{int32(math.MinInt8) - 1, int32(math.MaxInt8) + 1}, + Data: &schemapb.ScalarField_ArrayData{ + ArrayData: &schemapb.ArrayArray{ + Data: []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: []int32{1, 2, 3, 1 << 9}, + }, + }, + }, + }, }, }, }, @@ -933,16 +1669,63 @@ func Test_validateUtil_Validate(t *testing.T) { schema := &schemapb.CollectionSchema{ Fields: []*schemapb.FieldSchema{ { - Name: "test1", - FieldID: 101, - DataType: schemapb.DataType_Int8, + Name: "test", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int8, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxCapacityKey, + Value: "100", + }, + }, }, }, } - v := newValidateUtil(withOverflowCheck()) + err := newValidateUtil(withMaxCapCheck(), withOverflowCheck()).Validate(data, schema, 1) + assert.Error(t, err) - err := v.Validate(data, schema, 2) + data = []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_Array, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_ArrayData{ + ArrayData: &schemapb.ArrayArray{ + Data: []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: []int32{1, 2, 3, 1 << 9, 1 << 17}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + schema = &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int16, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxCapacityKey, + Value: "100", + }, + }, + }, + }, + } + + err = newValidateUtil(withMaxCapCheck(), withOverflowCheck()).Validate(data, schema, 1) assert.Error(t, err) }) @@ -1011,6 +1794,174 @@ func Test_validateUtil_Validate(t *testing.T) { }, }, }, + { + FieldName: "bool_array", + Type: schemapb.DataType_Array, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: 
&schemapb.ScalarField_ArrayData{ + ArrayData: &schemapb.ArrayArray{ + Data: []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_BoolData{ + BoolData: &schemapb.BoolArray{ + Data: []bool{true, true}, + }, + }, + }, + { + Data: &schemapb.ScalarField_BoolData{ + BoolData: &schemapb.BoolArray{ + Data: []bool{false, false}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + { + FieldName: "int_array", + Type: schemapb.DataType_Array, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_ArrayData{ + ArrayData: &schemapb.ArrayArray{ + Data: []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: []int32{1, 2, 3}, + }, + }, + }, + { + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: []int32{4, 5, 6}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + { + FieldName: "long_array", + Type: schemapb.DataType_Array, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_ArrayData{ + ArrayData: &schemapb.ArrayArray{ + Data: []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{1, 2, 3}, + }, + }, + }, + { + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{4, 5, 6}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + { + FieldName: "string_array", + Type: schemapb.DataType_Array, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_ArrayData{ + ArrayData: &schemapb.ArrayArray{ + Data: []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: []string{"abc", "def"}, + }, + }, + }, + { + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: []string{"hij", "jkl"}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + { + FieldName: "float_array", + Type: schemapb.DataType_Array, + Field: 
&schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_ArrayData{ + ArrayData: &schemapb.ArrayArray{ + Data: []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_FloatData{ + FloatData: &schemapb.FloatArray{ + Data: []float32{1.1, 2.2, 3.3}, + }, + }, + }, + { + Data: &schemapb.ScalarField_FloatData{ + FloatData: &schemapb.FloatArray{ + Data: []float32{4.4, 5.5, 6.6}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + { + FieldName: "double_array", + Type: schemapb.DataType_Array, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_ArrayData{ + ArrayData: &schemapb.ArrayArray{ + Data: []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_DoubleData{ + DoubleData: &schemapb.DoubleArray{ + Data: []float64{1.2, 2.3, 3.4}, + }, + }, + }, + { + Data: &schemapb.ScalarField_DoubleData{ + DoubleData: &schemapb.DoubleArray{ + Data: []float64{4.5, 5.6, 6.7}, + }, + }, + }, + }, + }, + }, + }, + }, + }, } schema := &schemapb.CollectionSchema{ @@ -1058,10 +2009,86 @@ func Test_validateUtil_Validate(t *testing.T) { FieldID: 105, DataType: schemapb.DataType_Int8, }, + { + Name: "bool_array", + FieldID: 106, + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Bool, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxCapacityKey, + Value: "10", + }, + }, + }, + { + Name: "int_array", + FieldID: 107, + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int16, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxCapacityKey, + Value: "10", + }, + }, + }, + { + Name: "long_array", + FieldID: 108, + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int64, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxCapacityKey, + Value: "10", + }, + }, + }, + { + Name: "string_array", + FieldID: 109, + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_VarChar, + TypeParams: []*commonpb.KeyValuePair{ + { + 
Key: common.MaxCapacityKey, + Value: "10", + }, + { + Key: common.MaxLengthKey, + Value: "10", + }, + }, + }, + { + Name: "float_array", + FieldID: 110, + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Float, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxCapacityKey, + Value: "10", + }, + }, + }, + { + Name: "double_array", + FieldID: 111, + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Double, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxCapacityKey, + Value: "10", + }, + }, + }, }, } - v := newValidateUtil(withNANCheck(), withMaxLenCheck(), withOverflowCheck()) + v := newValidateUtil(withNANCheck(), withMaxLenCheck(), withOverflowCheck(), withMaxCapCheck()) err := v.Validate(data, schema, 2) @@ -1944,7 +2971,6 @@ func Test_validateUtil_fillWithDefaultValue(t *testing.T) { flag := checkFillWithDefaultValueData(data[0].GetScalars().GetStringData().Data, stringData[0], 1) assert.True(t, flag) }) - } func Test_verifyOverflowByRange(t *testing.T) { @@ -2018,7 +3044,6 @@ func Test_validateUtil_checkIntegerFieldData(t *testing.T) { }) t.Run("tiny int, overflow", func(t *testing.T) { - v := newValidateUtil(withOverflowCheck()) f := &schemapb.FieldSchema{ @@ -2041,7 +3066,6 @@ func Test_validateUtil_checkIntegerFieldData(t *testing.T) { }) t.Run("tiny int, normal case", func(t *testing.T) { - v := newValidateUtil(withOverflowCheck()) f := &schemapb.FieldSchema{ @@ -2064,7 +3088,6 @@ func Test_validateUtil_checkIntegerFieldData(t *testing.T) { }) t.Run("small int, overflow", func(t *testing.T) { - v := newValidateUtil(withOverflowCheck()) f := &schemapb.FieldSchema{ @@ -2087,7 +3110,6 @@ func Test_validateUtil_checkIntegerFieldData(t *testing.T) { }) t.Run("small int, normal case", func(t *testing.T) { - v := newValidateUtil(withOverflowCheck()) f := &schemapb.FieldSchema{ @@ -2108,7 +3130,6 @@ func Test_validateUtil_checkIntegerFieldData(t *testing.T) { err := v.checkIntegerFieldData(data, f) 
assert.NoError(t, err) }) - } func Test_validateUtil_checkJSONData(t *testing.T) { diff --git a/internal/querycoordv2/balance/balance.go b/internal/querycoordv2/balance/balance.go index 20a15e9c660e1..f06ff77ddc294 100644 --- a/internal/querycoordv2/balance/balance.go +++ b/internal/querycoordv2/balance/balance.go @@ -113,7 +113,7 @@ func (b *RoundRobinBalancer) AssignChannel(channels []*meta.DmChannel, nodes []i } func (b *RoundRobinBalancer) BalanceReplica(replica *meta.Replica) ([]SegmentAssignPlan, []ChannelAssignPlan) { - //TODO by chun.han + // TODO by chun.han return nil, nil } diff --git a/internal/querycoordv2/balance/balance_test.go b/internal/querycoordv2/balance/balance_test.go index f528834c4176b..4a9e8a8415cfb 100644 --- a/internal/querycoordv2/balance/balance_test.go +++ b/internal/querycoordv2/balance/balance_test.go @@ -19,11 +19,12 @@ package balance import ( "testing" + "github.com/stretchr/testify/suite" + "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/querycoordv2/meta" "github.com/milvus-io/milvus/internal/querycoordv2/session" "github.com/milvus-io/milvus/internal/querycoordv2/task" - "github.com/stretchr/testify/suite" ) type BalanceTestSuite struct { diff --git a/internal/querycoordv2/balance/rowcount_based_balancer.go b/internal/querycoordv2/balance/rowcount_based_balancer.go index 383b42856f258..171f0fe21b3b4 100644 --- a/internal/querycoordv2/balance/rowcount_based_balancer.go +++ b/internal/querycoordv2/balance/rowcount_based_balancer.go @@ -149,7 +149,6 @@ func (b *RowCountBasedBalancer) BalanceReplica(replica *meta.Replica) ([]Segment } segmentsToMove = append(segmentsToMove, s) - } if rowCount < average { item := newNodeItem(rowCount, node) diff --git a/internal/querycoordv2/balance/rowcount_based_balancer_test.go b/internal/querycoordv2/balance/rowcount_based_balancer_test.go index 207ee8153d6cc..72eacc976f169 100644 --- a/internal/querycoordv2/balance/rowcount_based_balancer_test.go +++ 
b/internal/querycoordv2/balance/rowcount_based_balancer_test.go @@ -397,7 +397,6 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalance() { suite.ElementsMatch(c.expectPlans, segmentPlans) }) } - } func (suite *RowCountBasedBalancerTestSuite) TestBalanceOnPartStopping() { @@ -596,7 +595,6 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalanceOnPartStopping() { suite.ElementsMatch(c.expectPlans, segmentPlans) }) } - } func (suite *RowCountBasedBalancerTestSuite) TestBalanceOutboundNodes() { @@ -752,11 +750,11 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalanceOnLoadingCollection() { suite.ElementsMatch(c.expectPlans, segmentPlans) }) } - } func (suite *RowCountBasedBalancerTestSuite) getCollectionBalancePlans(balancer *RowCountBasedBalancer, - collectionID int64) ([]SegmentAssignPlan, []ChannelAssignPlan) { + collectionID int64, +) ([]SegmentAssignPlan, []ChannelAssignPlan) { replicas := balancer.meta.ReplicaManager.GetByCollection(collectionID) segmentPlans, channelPlans := make([]SegmentAssignPlan, 0), make([]ChannelAssignPlan, 0) for _, replica := range replicas { diff --git a/internal/querycoordv2/balance/score_based_balancer.go b/internal/querycoordv2/balance/score_based_balancer.go index 428ae14604f3f..64bbd471620d2 100644 --- a/internal/querycoordv2/balance/score_based_balancer.go +++ b/internal/querycoordv2/balance/score_based_balancer.go @@ -39,7 +39,8 @@ func NewScoreBasedBalancer(scheduler task.Scheduler, nodeManager *session.NodeManager, dist *meta.DistributionManager, meta *meta.Meta, - targetMgr *meta.TargetManager) *ScoreBasedBalancer { + targetMgr *meta.TargetManager, +) *ScoreBasedBalancer { return &ScoreBasedBalancer{ RowCountBasedBalancer: NewRowCountBasedBalancer(scheduler, nodeManager, dist, meta, targetMgr), } @@ -162,7 +163,7 @@ func (b *ScoreBasedBalancer) BalanceReplica(replica *meta.Replica) ([]SegmentAss ) return nil, nil } - //print current distribution before generating plans + // print current distribution before 
generating plans segmentPlans, channelPlans := make([]SegmentAssignPlan, 0), make([]ChannelAssignPlan, 0) if len(stoppingNodesSegments) != 0 { log.Info("Handle stopping nodes", @@ -268,7 +269,7 @@ func (b *ScoreBasedBalancer) getNormalSegmentPlan(replica *meta.Replica, nodesSe break } if targetSegmentToMove == nil { - //the node with the highest score doesn't have any segments suitable for balancing, stop balancing this round + // the node with the highest score doesn't have any segments suitable for balancing, stop balancing this round break } @@ -277,7 +278,7 @@ func (b *ScoreBasedBalancer) getNormalSegmentPlan(replica *meta.Replica, nodesSe nextToPriority := toPriority + int(targetSegmentToMove.GetNumOfRows()) + int(float64(targetSegmentToMove.GetNumOfRows())* params.Params.QueryCoordCfg.GlobalRowCountFactor.GetAsFloat()) - //still unbalanced after this balance plan is executed + // still unbalanced after this balance plan is executed if nextToPriority <= nextFromPriority { plan := SegmentAssignPlan{ ReplicaID: replica.GetID(), @@ -287,9 +288,9 @@ func (b *ScoreBasedBalancer) getNormalSegmentPlan(replica *meta.Replica, nodesSe } segmentPlans = append(segmentPlans, plan) } else { - //if unbalance reverted after balance action, we will consider the benefit - //only trigger following balance when the generated reverted balance - //is far smaller than the original unbalance + // if unbalance reverted after balance action, we will consider the benefit + // only trigger following balance when the generated reverted balance + // is far smaller than the original unbalance nextUnbalance := nextToPriority - nextFromPriority if float64(nextUnbalance)*params.Params.QueryCoordCfg.ReverseUnbalanceTolerationFactor.GetAsFloat() < unbalance { plan := SegmentAssignPlan{ @@ -300,14 +301,14 @@ func (b *ScoreBasedBalancer) getNormalSegmentPlan(replica *meta.Replica, nodesSe } segmentPlans = append(segmentPlans, plan) } else { - //if the tiniest segment movement between the highest 
scored node and lowest scored node will - //not provide sufficient balance benefit, we will seize balancing in this round + // if the tiniest segment movement between the highest scored node and lowest scored node will + // not provide sufficient balance benefit, we will seize balancing in this round break } } havingMovedSegments.Insert(targetSegmentToMove.GetID()) - //update node priority + // update node priority toNode.setPriority(nextToPriority) fromNode.setPriority(nextFromPriority) // if toNode and fromNode can not find segment to balance, break, else try to balance the next round diff --git a/internal/querycoordv2/balance/score_based_balancer_test.go b/internal/querycoordv2/balance/score_based_balancer_test.go index c65d5dfc2cc9c..cb1367f90b6d1 100644 --- a/internal/querycoordv2/balance/score_based_balancer_test.go +++ b/internal/querycoordv2/balance/score_based_balancer_test.go @@ -106,14 +106,20 @@ func (suite *ScoreBasedBalancerTestSuite) TestAssignSegment() { segmentCnts: []int{0, 0, 0}, expectPlans: [][]SegmentAssignPlan{ { - //as assign segments is used while loading collection, - //all assignPlan should have weight equal to 1(HIGH PRIORITY) - {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 3, NumOfRows: 15, - CollectionID: 1}}, From: -1, To: 1}, - {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 2, NumOfRows: 10, - CollectionID: 1}}, From: -1, To: 3}, - {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 1, NumOfRows: 5, - CollectionID: 1}}, From: -1, To: 2}, + // as assign segments is used while loading collection, + // all assignPlan should have weight equal to 1(HIGH PRIORITY) + {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ + ID: 3, NumOfRows: 15, + CollectionID: 1, + }}, From: -1, To: 1}, + {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ + ID: 2, NumOfRows: 10, + CollectionID: 1, + }}, From: -1, To: 3}, + {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ + ID: 1, NumOfRows: 5, + 
CollectionID: 1, + }}, From: -1, To: 2}, }, }, }, @@ -125,20 +131,20 @@ func (suite *ScoreBasedBalancerTestSuite) TestAssignSegment() { 1: { {SegmentInfo: &datapb.SegmentInfo{ID: 1, NumOfRows: 10, CollectionID: 1}, Node: 1}, {SegmentInfo: &datapb.SegmentInfo{ID: 2, NumOfRows: 300, CollectionID: 2}, Node: 1}, - //base: collection1-node1-priority is 10 + 0.1 * 310 = 41 - //assign3: collection1-node1-priority is 15 + 0.1 * 315 = 46.5 + // base: collection1-node1-priority is 10 + 0.1 * 310 = 41 + // assign3: collection1-node1-priority is 15 + 0.1 * 315 = 46.5 }, 2: { {SegmentInfo: &datapb.SegmentInfo{ID: 3, NumOfRows: 20, CollectionID: 1}, Node: 2}, {SegmentInfo: &datapb.SegmentInfo{ID: 4, NumOfRows: 180, CollectionID: 2}, Node: 2}, - //base: collection1-node2-priority is 20 + 0.1 * 200 = 40 - //assign2: collection1-node2-priority is 30 + 0.1 * 210 = 51 + // base: collection1-node2-priority is 20 + 0.1 * 200 = 40 + // assign2: collection1-node2-priority is 30 + 0.1 * 210 = 51 }, 3: { {SegmentInfo: &datapb.SegmentInfo{ID: 5, NumOfRows: 30, CollectionID: 1}, Node: 3}, {SegmentInfo: &datapb.SegmentInfo{ID: 6, NumOfRows: 20, CollectionID: 2}, Node: 3}, - //base: collection1-node2-priority is 30 + 0.1 * 50 = 35 - //assign1: collection1-node2-priority is 45 + 0.1 * 65 = 51.5 + // base: collection1-node2-priority is 30 + 0.1 * 50 = 35 + // assign1: collection1-node2-priority is 45 + 0.1 * 65 = 51.5 }, }, assignments: [][]*meta.Segment{ @@ -190,10 +196,10 @@ func (suite *ScoreBasedBalancerTestSuite) TestAssignSegment() { states: []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateNormal}, segmentCnts: []int{0, 0, 0}, expectPlans: [][]SegmentAssignPlan{ - //note that these two segments plans are absolutely unbalanced globally, - //as if the assignment for collection1 could succeed, node1 and node2 will both have 70 rows - //much more than node3, but following assignment will still assign segment based on [10,20,40] - //rather than [70,70,40], this 
flaw will be mitigated by balance process and maybe fixed in the later versions + // note that these two segments plans are absolutely unbalanced globally, + // as if the assignment for collection1 could succeed, node1 and node2 will both have 70 rows + // much more than node3, but following assignment will still assign segment based on [10,20,40] + // rather than [70,70,40], this flaw will be mitigated by balance process and maybe fixed in the later versions { {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 4, NumOfRows: 60, CollectionID: 1}}, From: -1, To: 1}, {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 5, NumOfRows: 50, CollectionID: 1}}, From: -1, To: 2}, @@ -292,7 +298,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceOneRound() { defer suite.TearDownTest() balancer := suite.balancer - //1. set up target for multi collections + // 1. set up target for multi collections collection := utils.CreateTestCollection(c.collectionID, int32(c.replicaID)) suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, c.collectionID).Return( nil, c.collectionsSegments, nil) @@ -305,7 +311,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceOneRound() { balancer.targetMgr.UpdateCollectionNextTarget(c.collectionID) balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionID) - //2. set up target for distribution for multi collections + // 2. set up target for distribution for multi collections for node, s := range c.distributions { balancer.dist.SegmentDistManager.Update(node, s...) } @@ -313,7 +319,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceOneRound() { balancer.dist.ChannelDistManager.Update(node, v...) } - //3. set up nodes info and resourceManager for balancer + // 3. 
set up nodes info and resourceManager for balancer for i := range c.nodes { nodeInfo := session.NewNodeInfo(c.nodes[i], "127.0.0.1:0") nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]]))) @@ -322,7 +328,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceOneRound() { suite.balancer.meta.ResourceManager.AssignNode(meta.DefaultResourceGroupName, c.nodes[i]) } - //4. balance and verify result + // 4. balance and verify result segmentPlans, channelPlans := suite.getCollectionBalancePlans(balancer, c.collectionID) suite.ElementsMatch(c.expectChannelPlans, channelPlans) suite.ElementsMatch(c.expectPlans, segmentPlans) @@ -384,8 +390,11 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceMultiRound() { }, expectPlans: [][]SegmentAssignPlan{ { - {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 20}, - Node: 2}, From: 2, To: 3, ReplicaID: 1, + { + Segment: &meta.Segment{ + SegmentInfo: &datapb.SegmentInfo{ID: 3, CollectionID: 1, NumOfRows: 20}, + Node: 2, + }, From: 2, To: 3, ReplicaID: 1, }, }, {}, @@ -396,11 +405,9 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceMultiRound() { defer suite.TearDownTest() balancer := suite.balancer - //1. set up target for multi collections - collections := make([]*meta.Collection, 0, len(balanceCase.collectionIDs)) + // 1. set up target for multi collections for i := range balanceCase.collectionIDs { collection := utils.CreateTestCollection(balanceCase.collectionIDs[i], int32(balanceCase.replicaIDs[i])) - collections = append(collections, collection) suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, balanceCase.collectionIDs[i]).Return( nil, balanceCase.collectionsSegments[i], nil) @@ -415,12 +422,12 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceMultiRound() { balancer.targetMgr.UpdateCollectionCurrentTarget(balanceCase.collectionIDs[i]) } - //2. set up target for distribution for multi collections + // 2. 
set up target for distribution for multi collections for node, s := range balanceCase.distributions[0] { balancer.dist.SegmentDistManager.Update(node, s...) } - //3. set up nodes info and resourceManager for balancer + // 3. set up nodes info and resourceManager for balancer for i := range balanceCase.nodes { nodeInfo := session.NewNodeInfo(balanceCase.nodes[i], "127.0.0.1:0") nodeInfo.SetState(balanceCase.states[i]) @@ -428,16 +435,16 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceMultiRound() { suite.balancer.meta.ResourceManager.AssignNode(meta.DefaultResourceGroupName, balanceCase.nodes[i]) } - //4. first round balance + // 4. first round balance segmentPlans, _ := suite.getCollectionBalancePlans(balancer, balanceCase.collectionIDs[0]) suite.ElementsMatch(balanceCase.expectPlans[0], segmentPlans) - //5. update segment distribution to simulate balance effect + // 5. update segment distribution to simulate balance effect for node, s := range balanceCase.distributions[1] { balancer.dist.SegmentDistManager.Update(node, s...) } - //6. balance again + // 6. 
balance again segmentPlans, _ = suite.getCollectionBalancePlans(balancer, balanceCase.collectionIDs[1]) suite.ElementsMatch(balanceCase.expectPlans[1], segmentPlans) } @@ -477,10 +484,14 @@ func (suite *ScoreBasedBalancerTestSuite) TestStoppedBalance() { }, }, expectPlans: []SegmentAssignPlan{ - {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 20}, - Node: 1}, From: 1, To: 3, ReplicaID: 1}, - {Segment: &meta.Segment{SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 10}, - Node: 1}, From: 1, To: 3, ReplicaID: 1}, + {Segment: &meta.Segment{ + SegmentInfo: &datapb.SegmentInfo{ID: 2, CollectionID: 1, NumOfRows: 20}, + Node: 1, + }, From: 1, To: 3, ReplicaID: 1}, + {Segment: &meta.Segment{ + SegmentInfo: &datapb.SegmentInfo{ID: 1, CollectionID: 1, NumOfRows: 10}, + Node: 1, + }, From: 1, To: 3, ReplicaID: 1}, }, expectChannelPlans: []ChannelAssignPlan{}, }, @@ -538,7 +549,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestStoppedBalance() { defer suite.TearDownTest() balancer := suite.balancer - //1. set up target for multi collections + // 1. set up target for multi collections collection := utils.CreateTestCollection(c.collectionID, int32(c.replicaID)) suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, c.collectionID).Return( nil, c.collectionsSegments, nil) @@ -551,7 +562,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestStoppedBalance() { balancer.targetMgr.UpdateCollectionNextTarget(c.collectionID) balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionID) - //2. set up target for distribution for multi collections + // 2. set up target for distribution for multi collections for node, s := range c.distributions { balancer.dist.SegmentDistManager.Update(node, s...) } @@ -559,7 +570,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestStoppedBalance() { balancer.dist.ChannelDistManager.Update(node, v...) } - //3. set up nodes info and resourceManager for balancer + // 3. 
set up nodes info and resourceManager for balancer for i := range c.nodes { nodeInfo := session.NewNodeInfo(c.nodes[i], "127.0.0.1:0") nodeInfo.UpdateStats(session.WithChannelCnt(len(c.distributionChannels[c.nodes[i]]))) @@ -572,7 +583,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestStoppedBalance() { suite.balancer.meta.ResourceManager.UnassignNode(meta.DefaultResourceGroupName, c.outBoundNodes[i]) } - //4. balance and verify result + // 4. balance and verify result segmentPlans, channelPlans := suite.getCollectionBalancePlans(suite.balancer, c.collectionID) suite.ElementsMatch(c.expectChannelPlans, channelPlans) suite.ElementsMatch(c.expectPlans, segmentPlans) @@ -585,7 +596,8 @@ func TestScoreBasedBalancerSuite(t *testing.T) { } func (suite *ScoreBasedBalancerTestSuite) getCollectionBalancePlans(balancer *ScoreBasedBalancer, - collectionID int64) ([]SegmentAssignPlan, []ChannelAssignPlan) { + collectionID int64, +) ([]SegmentAssignPlan, []ChannelAssignPlan) { replicas := balancer.meta.ReplicaManager.GetByCollection(collectionID) segmentPlans, channelPlans := make([]SegmentAssignPlan, 0), make([]ChannelAssignPlan, 0) for _, replica := range replicas { diff --git a/internal/querycoordv2/balance/utils.go b/internal/querycoordv2/balance/utils.go index 185db59f4bc0a..16d5ede8d7cfb 100644 --- a/internal/querycoordv2/balance/utils.go +++ b/internal/querycoordv2/balance/utils.go @@ -78,7 +78,7 @@ func CreateSegmentTasksFromPlans(ctx context.Context, checkerID int64, timeout t // from balance checker t.SetPriority(task.TaskPriorityLow) } else { - //from segment checker + // from segment checker t.SetPriority(task.TaskPriorityNormal) } ret = append(ret, t) @@ -124,7 +124,8 @@ func CreateChannelTasksFromPlans(ctx context.Context, checkerID int64, timeout t } func PrintNewBalancePlans(collectionID int64, replicaID int64, segmentPlans []SegmentAssignPlan, - channelPlans []ChannelAssignPlan) { + channelPlans []ChannelAssignPlan, +) { balanceInfo := fmt.Sprintf("%s new 
plans:{collectionID:%d, replicaID:%d, ", PlanInfoPrefix, collectionID, replicaID) for _, segmentPlan := range segmentPlans { balanceInfo += segmentPlan.ToString() @@ -138,9 +139,10 @@ func PrintNewBalancePlans(collectionID int64, replicaID int64, segmentPlans []Se func PrintCurrentReplicaDist(replica *meta.Replica, stoppingNodesSegments map[int64][]*meta.Segment, nodeSegments map[int64][]*meta.Segment, - channelManager *meta.ChannelDistManager, segmentDistMgr *meta.SegmentDistManager) { + channelManager *meta.ChannelDistManager, segmentDistMgr *meta.SegmentDistManager, +) { distInfo := fmt.Sprintf("%s {collectionID:%d, replicaID:%d, ", DistInfoPrefix, replica.CollectionID, replica.GetID()) - //1. print stopping nodes segment distribution + // 1. print stopping nodes segment distribution distInfo += "[stoppingNodesSegmentDist:" for stoppingNodeID, stoppedSegments := range stoppingNodesSegments { distInfo += fmt.Sprintf("[nodeID:%d, ", stoppingNodeID) @@ -151,7 +153,7 @@ func PrintCurrentReplicaDist(replica *meta.Replica, distInfo += "]]" } distInfo += "]" - //2. print normal nodes segment distribution + // 2. print normal nodes segment distribution distInfo += "[normalNodesSegmentDist:" for normalNodeID, normalNodeCollectionSegments := range nodeSegments { distInfo += fmt.Sprintf("[nodeID:%d, ", normalNodeID) @@ -171,7 +173,7 @@ func PrintCurrentReplicaDist(replica *meta.Replica, } distInfo += "]" - //3. print stopping nodes channel distribution + // 3. print stopping nodes channel distribution distInfo += "[stoppingNodesChannelDist:" for stoppingNodeID := range stoppingNodesSegments { stoppingNodeChannels := channelManager.GetByCollectionAndNode(replica.GetCollectionID(), stoppingNodeID) @@ -184,7 +186,7 @@ func PrintCurrentReplicaDist(replica *meta.Replica, } distInfo += "]" - //4. print normal nodes channel distribution + // 4. 
print normal nodes channel distribution distInfo += "[normalNodesChannelDist:" for normalNodeID := range nodeSegments { normalNodeChannels := channelManager.GetByCollectionAndNode(replica.GetCollectionID(), normalNodeID) diff --git a/internal/querycoordv2/checkers/balance_checker.go b/internal/querycoordv2/checkers/balance_checker.go index eb417ea16151d..8444392fd67a3 100644 --- a/internal/querycoordv2/checkers/balance_checker.go +++ b/internal/querycoordv2/checkers/balance_checker.go @@ -21,6 +21,9 @@ import ( "sort" "time" + "github.com/samber/lo" + "go.uber.org/zap" + "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querycoordv2/balance" "github.com/milvus-io/milvus/internal/querycoordv2/meta" @@ -29,9 +32,6 @@ import ( "github.com/milvus-io/milvus/internal/querycoordv2/task" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/typeutil" - - "github.com/samber/lo" - "go.uber.org/zap" ) // BalanceChecker checks the cluster distribution and generates balance tasks. 
@@ -84,12 +84,12 @@ func (b *BalanceChecker) replicasToBalance() []int64 { } } } - //do stopping balance only in this round + // do stopping balance only in this round if len(stoppingReplicas) > 0 { return stoppingReplicas } - //no stopping balance and auto balance is disabled, return empty collections for balance + // no stopping balance and auto balance is disabled, return empty collections for balance if !Params.QueryCoordCfg.AutoBalance.GetAsBool() { return nil } @@ -98,7 +98,7 @@ func (b *BalanceChecker) replicasToBalance() []int64 { return nil } - //iterator one normal collection in one round + // iterator one normal collection in one round normalReplicasToBalance := make([]int64, 0) hasUnbalancedCollection := false for _, cid := range loadedCollections { diff --git a/internal/querycoordv2/checkers/balance_checker_test.go b/internal/querycoordv2/checkers/balance_checker_test.go index 9b0451b12000e..f15bb2b494799 100644 --- a/internal/querycoordv2/checkers/balance_checker_test.go +++ b/internal/querycoordv2/checkers/balance_checker_test.go @@ -20,6 +20,9 @@ import ( "context" "testing" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "github.com/milvus-io/milvus/internal/kv" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" "github.com/milvus-io/milvus/internal/metastore/kv/querycoord" @@ -32,9 +35,6 @@ import ( "github.com/milvus-io/milvus/internal/querycoordv2/utils" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/paramtable" - - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/suite" ) type BalanceCheckerTestSuite struct { @@ -83,7 +83,7 @@ func (suite *BalanceCheckerTestSuite) TearDownTest() { } func (suite *BalanceCheckerTestSuite) TestAutoBalanceConf() { - //set up nodes info + // set up nodes info nodeID1, nodeID2 := 1, 2 suite.nodeMgr.Add(session.NewNodeInfo(int64(nodeID1), "localhost")) suite.nodeMgr.Add(session.NewNodeInfo(int64(nodeID2), "localhost")) @@ -105,7 +105,7 
@@ func (suite *BalanceCheckerTestSuite) TestAutoBalanceConf() { suite.checker.meta.CollectionManager.PutCollection(collection2) suite.checker.meta.ReplicaManager.Put(replica2) - //test disable auto balance + // test disable auto balance paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "false") suite.scheduler.EXPECT().GetSegmentTaskNum().Maybe().Return(func() int { return 0 @@ -115,22 +115,22 @@ func (suite *BalanceCheckerTestSuite) TestAutoBalanceConf() { segPlans, _ := suite.checker.balanceReplicas(replicasToBalance) suite.Empty(segPlans) - //test enable auto balance + // test enable auto balance paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "true") idsToBalance := []int64{int64(replicaID1)} replicasToBalance = suite.checker.replicasToBalance() suite.ElementsMatch(idsToBalance, replicasToBalance) - //next round + // next round idsToBalance = []int64{int64(replicaID2)} replicasToBalance = suite.checker.replicasToBalance() suite.ElementsMatch(idsToBalance, replicasToBalance) - //final round + // final round replicasToBalance = suite.checker.replicasToBalance() suite.Empty(replicasToBalance) } func (suite *BalanceCheckerTestSuite) TestBusyScheduler() { - //set up nodes info + // set up nodes info nodeID1, nodeID2 := 1, 2 suite.nodeMgr.Add(session.NewNodeInfo(int64(nodeID1), "localhost")) suite.nodeMgr.Add(session.NewNodeInfo(int64(nodeID2), "localhost")) @@ -152,7 +152,7 @@ func (suite *BalanceCheckerTestSuite) TestBusyScheduler() { suite.checker.meta.CollectionManager.PutCollection(collection2) suite.checker.meta.ReplicaManager.Put(replica2) - //test scheduler busy + // test scheduler busy paramtable.Get().Save(Params.QueryCoordCfg.AutoBalance.Key, "true") suite.scheduler.EXPECT().GetSegmentTaskNum().Maybe().Return(func() int { return 1 @@ -164,7 +164,7 @@ func (suite *BalanceCheckerTestSuite) TestBusyScheduler() { } func (suite *BalanceCheckerTestSuite) TestStoppingBalance() { - //set up nodes info, stopping node1 + // set up nodes info, 
stopping node1 nodeID1, nodeID2 := 1, 2 suite.nodeMgr.Add(session.NewNodeInfo(int64(nodeID1), "localhost")) suite.nodeMgr.Add(session.NewNodeInfo(int64(nodeID2), "localhost")) @@ -187,12 +187,12 @@ func (suite *BalanceCheckerTestSuite) TestStoppingBalance() { suite.checker.meta.CollectionManager.PutCollection(collection2) suite.checker.meta.ReplicaManager.Put(replica2) - //test stopping balance + // test stopping balance idsToBalance := []int64{int64(replicaID1), int64(replicaID2)} replicasToBalance := suite.checker.replicasToBalance() suite.ElementsMatch(idsToBalance, replicasToBalance) - //checker check + // checker check segPlans, chanPlans := make([]balance.SegmentAssignPlan, 0), make([]balance.ChannelAssignPlan, 0) mockPlan := balance.SegmentAssignPlan{ Segment: utils.CreateTestSegment(1, 1, 1, 1, 1, "1"), diff --git a/internal/querycoordv2/checkers/channel_checker.go b/internal/querycoordv2/checkers/channel_checker.go index bcb1689693d39..046cbc45fc243 100644 --- a/internal/querycoordv2/checkers/channel_checker.go +++ b/internal/querycoordv2/checkers/channel_checker.go @@ -110,7 +110,8 @@ func (c *ChannelChecker) checkReplica(ctx context.Context, replica *meta.Replica // GetDmChannelDiff get channel diff between target and dist func (c *ChannelChecker) getDmChannelDiff(collectionID int64, - replicaID int64) (toLoad, toRelease []*meta.DmChannel) { + replicaID int64, +) (toLoad, toRelease []*meta.DmChannel) { replica := c.meta.Get(replicaID) if replica == nil { log.Info("replica does not exist, skip it") @@ -135,7 +136,7 @@ func (c *ChannelChecker) getDmChannelDiff(collectionID int64, } } - //get channels which exists on next target, but not on dist + // get channels which exists on next target, but not on dist for name, channel := range nextTargetMap { _, existOnDist := distMap[name] if !existOnDist { diff --git a/internal/querycoordv2/checkers/controller.go b/internal/querycoordv2/checkers/controller.go index 39f0c2f58543c..b238b16255dcf 100644 --- 
a/internal/querycoordv2/checkers/controller.go +++ b/internal/querycoordv2/checkers/controller.go @@ -21,28 +21,30 @@ import ( "sync" "time" + "go.uber.org/zap" + "github.com/milvus-io/milvus/internal/querycoordv2/balance" "github.com/milvus-io/milvus/internal/querycoordv2/meta" . "github.com/milvus-io/milvus/internal/querycoordv2/params" "github.com/milvus-io/milvus/internal/querycoordv2/session" "github.com/milvus-io/milvus/internal/querycoordv2/task" "github.com/milvus-io/milvus/pkg/log" - "go.uber.org/zap" ) -var ( - checkRoundTaskNumLimit = 256 +const ( + segmentChecker = "segment_checker" + channelChecker = "channel_checker" + balanceChecker = "balance_checker" + indexChecker = "index_checker" ) var ( - Segment_Checker = "segment_checker" - Channel_Checker = "channel_checker" - Balance_Checker = "balance_checker" - Index_Checker = "index_checker" + checkRoundTaskNumLimit = 256 + checkerOrder = []string{channelChecker, segmentChecker, balanceChecker, indexChecker} ) type CheckerController struct { - stopCh chan struct{} + cancel context.CancelFunc manualCheckChs map[string]chan struct{} meta *meta.Meta dist *meta.DistributionManager @@ -66,29 +68,27 @@ func NewCheckerController( scheduler task.Scheduler, broker meta.Broker, ) *CheckerController { - // CheckerController runs checkers with the order, // the former checker has higher priority checkers := map[string]Checker{ - Channel_Checker: NewChannelChecker(meta, dist, targetMgr, balancer), - Segment_Checker: NewSegmentChecker(meta, dist, targetMgr, balancer, nodeMgr), - Balance_Checker: NewBalanceChecker(meta, balancer, nodeMgr, scheduler), - Index_Checker: NewIndexChecker(meta, dist, broker), + channelChecker: NewChannelChecker(meta, dist, targetMgr, balancer), + segmentChecker: NewSegmentChecker(meta, dist, targetMgr, balancer, nodeMgr), + balanceChecker: NewBalanceChecker(meta, balancer, nodeMgr, scheduler), + indexChecker: NewIndexChecker(meta, dist, broker), } id := 0 - for _, checker := range checkers { 
- checker.SetID(int64(id + 1)) + for _, checkerName := range checkerOrder { + checkers[checkerName].SetID(int64(id + 1)) } manualCheckChs := map[string]chan struct{}{ - Channel_Checker: make(chan struct{}, 1), - Segment_Checker: make(chan struct{}, 1), - Balance_Checker: make(chan struct{}, 1), + channelChecker: make(chan struct{}, 1), + segmentChecker: make(chan struct{}, 1), + balanceChecker: make(chan struct{}, 1), } return &CheckerController{ - stopCh: make(chan struct{}), manualCheckChs: manualCheckChs, meta: meta, dist: dist, @@ -99,29 +99,31 @@ func NewCheckerController( } } -func (controller *CheckerController) Start(ctx context.Context) { +func (controller *CheckerController) Start() { + ctx, cancel := context.WithCancel(context.Background()) + controller.cancel = cancel + for checkerType := range controller.checkers { - go controller.StartChecker(ctx, checkerType) + go controller.startChecker(ctx, checkerType) } } func getCheckerInterval(checkerType string) time.Duration { switch checkerType { - case Segment_Checker: + case segmentChecker: return Params.QueryCoordCfg.SegmentCheckInterval.GetAsDuration(time.Millisecond) - case Channel_Checker: + case channelChecker: return Params.QueryCoordCfg.ChannelCheckInterval.GetAsDuration(time.Millisecond) - case Balance_Checker: + case balanceChecker: return Params.QueryCoordCfg.BalanceCheckInterval.GetAsDuration(time.Millisecond) - case Index_Checker: + case indexChecker: return Params.QueryCoordCfg.IndexCheckInterval.GetAsDuration(time.Millisecond) default: return Params.QueryCoordCfg.CheckInterval.GetAsDuration(time.Millisecond) } - } -func (controller *CheckerController) StartChecker(ctx context.Context, checkerType string) { +func (controller *CheckerController) startChecker(ctx context.Context, checkerType string) { interval := getCheckerInterval(checkerType) ticker := time.NewTicker(interval) defer ticker.Stop() @@ -129,11 +131,6 @@ func (controller *CheckerController) StartChecker(ctx context.Context, 
checkerTy for { select { case <-ctx.Done(): - log.Info("Checker stopped due to context canceled", - zap.String("type", checkerType)) - return - - case <-controller.stopCh: log.Info("Checker stopped", zap.String("type", checkerType)) return @@ -144,14 +141,16 @@ func (controller *CheckerController) StartChecker(ctx context.Context, checkerTy case <-controller.manualCheckChs[checkerType]: ticker.Stop() controller.check(ctx, checkerType) - ticker.Reset(Params.QueryCoordCfg.CheckInterval.GetAsDuration(time.Millisecond)) + ticker.Reset(interval) } } } func (controller *CheckerController) Stop() { controller.stopOnce.Do(func() { - close(controller.stopCh) + if controller.cancel != nil { + controller.cancel() + } }) } diff --git a/internal/querycoordv2/checkers/controller_test.go b/internal/querycoordv2/checkers/controller_test.go index eca3e01047541..6df196c9b8d09 100644 --- a/internal/querycoordv2/checkers/controller_test.go +++ b/internal/querycoordv2/checkers/controller_test.go @@ -17,10 +17,13 @@ package checkers import ( - "context" "testing" "time" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "go.uber.org/atomic" + "github.com/milvus-io/milvus/internal/kv" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" "github.com/milvus-io/milvus/internal/metastore/kv/querycoord" @@ -33,9 +36,6 @@ import ( "github.com/milvus-io/milvus/internal/querycoordv2/utils" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/suite" - "go.uber.org/atomic" ) type CheckerControllerSuite struct { @@ -126,8 +126,7 @@ func (suite *CheckerControllerSuite) TestBasic() { suite.balancer.EXPECT().AssignSegment(mock.Anything, mock.Anything, mock.Anything).Return(nil) suite.balancer.EXPECT().AssignChannel(mock.Anything, mock.Anything).Return(nil) - ctx := context.Background() - suite.controller.Start(ctx) + suite.controller.Start() defer 
suite.controller.Stop() suite.Eventually(func() bool { diff --git a/internal/querycoordv2/checkers/index_checker.go b/internal/querycoordv2/checkers/index_checker.go index fd935ee4ebb71..48471f928b2e4 100644 --- a/internal/querycoordv2/checkers/index_checker.go +++ b/internal/querycoordv2/checkers/index_checker.go @@ -20,14 +20,15 @@ import ( "context" "time" + "github.com/samber/lo" + "go.uber.org/zap" + "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querycoordv2/meta" "github.com/milvus-io/milvus/internal/querycoordv2/params" "github.com/milvus-io/milvus/internal/querycoordv2/task" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/typeutil" - "github.com/samber/lo" - "go.uber.org/zap" ) var _ Checker = (*IndexChecker)(nil) diff --git a/internal/querycoordv2/checkers/index_checker_test.go b/internal/querycoordv2/checkers/index_checker_test.go index f17fc453f1513..fe1b9774a8359 100644 --- a/internal/querycoordv2/checkers/index_checker_test.go +++ b/internal/querycoordv2/checkers/index_checker_test.go @@ -21,6 +21,9 @@ import ( "testing" "github.com/cockroachdb/errors" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "github.com/milvus-io/milvus/internal/kv" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" "github.com/milvus-io/milvus/internal/metastore/kv/querycoord" @@ -32,8 +35,6 @@ import ( "github.com/milvus-io/milvus/internal/querycoordv2/utils" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/suite" ) type IndexCheckerSuite struct { diff --git a/internal/querycoordv2/checkers/segment_checker.go b/internal/querycoordv2/checkers/segment_checker.go index 4360d9703cab7..56fc61e9aa692 100644 --- a/internal/querycoordv2/checkers/segment_checker.go +++ b/internal/querycoordv2/checkers/segment_checker.go @@ -18,6 +18,7 @@ package checkers import ( 
"context" + "sort" "time" "github.com/samber/lo" @@ -32,7 +33,6 @@ import ( "github.com/milvus-io/milvus/internal/querycoordv2/task" "github.com/milvus-io/milvus/internal/querycoordv2/utils" "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/util/typeutil" ) type SegmentChecker struct { @@ -73,12 +73,12 @@ func (c *SegmentChecker) readyToCheck(collectionID int64) bool { func (c *SegmentChecker) Check(ctx context.Context) []task.Task { collectionIDs := c.meta.CollectionManager.GetAll() - tasks := make([]task.Task, 0) + results := make([]task.Task, 0) for _, cid := range collectionIDs { if c.readyToCheck(cid) { replicas := c.meta.ReplicaManager.GetByCollection(cid) for _, r := range replicas { - tasks = append(tasks, c.checkReplica(ctx, r)...) + results = append(results, c.checkReplica(ctx, r)...) } } } @@ -86,9 +86,11 @@ func (c *SegmentChecker) Check(ctx context.Context) []task.Task { // find already released segments which are not contained in target segments := c.dist.SegmentDistManager.GetAll() released := utils.FilterReleased(segments, collectionIDs) - tasks = append(tasks, c.createSegmentReduceTasks(ctx, released, -1, querypb.DataScope_Historical)...) - task.SetPriority(task.TaskPriorityNormal, tasks...) - return tasks + reduceTasks := c.createSegmentReduceTasks(ctx, released, -1, querypb.DataScope_Historical) + task.SetReason("collection released", reduceTasks...) + results = append(results, reduceTasks...) + task.SetPriority(task.TaskPriorityNormal, results...) 
+ return results } func (c *SegmentChecker) checkReplica(ctx context.Context, replica *meta.Replica) []task.Task { @@ -137,7 +139,8 @@ func (c *SegmentChecker) checkReplica(ctx context.Context, replica *meta.Replica // GetStreamingSegmentDiff get streaming segment diff between leader view and target func (c *SegmentChecker) getStreamingSegmentDiff(collectionID int64, - replicaID int64) (toLoad []*datapb.SegmentInfo, toRelease []*meta.Segment) { + replicaID int64, +) (toLoad []*datapb.SegmentInfo, toRelease []*meta.Segment) { replica := c.meta.Get(replicaID) if replica == nil { log.Info("replica does not exist, skip it") @@ -196,16 +199,20 @@ func (c *SegmentChecker) getStreamingSegmentDiff(collectionID int64, // GetHistoricalSegmentDiff get historical segment diff between target and dist func (c *SegmentChecker) getHistoricalSegmentDiff( collectionID int64, - replicaID int64) (toLoad []*datapb.SegmentInfo, toRelease []*meta.Segment) { + replicaID int64, +) (toLoad []*datapb.SegmentInfo, toRelease []*meta.Segment) { replica := c.meta.Get(replicaID) if replica == nil { log.Info("replica does not exist, skip it") return } dist := c.getHistoricalSegmentsDist(replica) - distMap := typeutil.NewUniqueSet() + sort.Slice(dist, func(i, j int) bool { + return dist[i].Version < dist[j].Version + }) + distMap := make(map[int64]int64) for _, s := range dist { - distMap.Insert(s.GetID()) + distMap[s.GetID()] = s.Node } nextTargetMap := c.targetMgr.GetHistoricalSegmentsByCollection(collectionID, meta.NextTarget) @@ -213,7 +220,15 @@ func (c *SegmentChecker) getHistoricalSegmentDiff( // Segment which exist on next target, but not on dist for segmentID, segment := range nextTargetMap { - if !distMap.Contain(segmentID) { + leader := c.dist.LeaderViewManager.GetLatestLeadersByReplicaShard(replica, + segment.GetInsertChannel(), + ) + node, ok := distMap[segmentID] + if !ok || + // the L0 segments have to been in the same node as the channel watched + leader != nil && + 
segment.GetLevel() == datapb.SegmentLevel_L0 && + node != leader.ID { toLoad = append(toLoad, segment) } } @@ -228,6 +243,16 @@ func (c *SegmentChecker) getHistoricalSegmentDiff( } } + level0Segments := lo.Filter(toLoad, func(segment *datapb.SegmentInfo, _ int) bool { + return segment.GetLevel() == datapb.SegmentLevel_L0 + }) + // L0 segment found, + // QueryCoord loads the L0 segments first, + // to make sure all L0 delta logs will be delivered to the other segments. + if len(level0Segments) > 0 { + toLoad = level0Segments + } + return } @@ -308,25 +333,46 @@ func (c *SegmentChecker) createSegmentLoadTasks(ctx context.Context, segments [] if len(segments) == 0 { return nil } - packedSegments := make([]*meta.Segment, 0, len(segments)) + + isLevel0 := segments[0].GetLevel() == datapb.SegmentLevel_L0 + + shardSegments := make(map[string][]*meta.Segment) for _, s := range segments { - if len(c.dist.LeaderViewManager.GetLeadersByShard(s.GetInsertChannel())) == 0 { + if isLevel0 && + len(c.dist.LeaderViewManager.GetLeadersByShard(s.GetInsertChannel())) == 0 { continue } - packedSegments = append(packedSegments, &meta.Segment{SegmentInfo: s}) + channel := s.GetInsertChannel() + packedSegments := shardSegments[channel] + packedSegments = append(packedSegments, &meta.Segment{ + SegmentInfo: s, + }) + shardSegments[channel] = packedSegments } - outboundNodes := c.meta.ResourceManager.CheckOutboundNodes(replica) - availableNodes := lo.Filter(replica.Replica.GetNodes(), func(node int64, _ int) bool { - stop, err := c.nodeMgr.IsStoppingNode(node) - if err != nil { - return false + + plans := make([]balance.SegmentAssignPlan, 0) + for shard, segments := range shardSegments { + outboundNodes := c.meta.ResourceManager.CheckOutboundNodes(replica) + availableNodes := lo.Filter(replica.Replica.GetNodes(), func(node int64, _ int) bool { + stop, err := c.nodeMgr.IsStoppingNode(node) + if err != nil { + return false + } + + if isLevel0 { + leader := 
c.dist.LeaderViewManager.GetLatestLeadersByReplicaShard(replica, shard) + return !outboundNodes.Contain(node) && !stop && node == leader.ID + } + return !outboundNodes.Contain(node) && !stop + }) + + shardPlans := c.balancer.AssignSegment(replica.CollectionID, segments, availableNodes) + for i := range shardPlans { + shardPlans[i].ReplicaID = replica.GetID() } - return !outboundNodes.Contain(node) && !stop - }) - plans := c.balancer.AssignSegment(replica.CollectionID, packedSegments, availableNodes) - for i := range plans { - plans[i].ReplicaID = replica.GetID() + plans = append(plans, shardPlans...) } + return balance.CreateSegmentTasksFromPlans(ctx, c.ID(), Params.QueryCoordCfg.SegmentTaskTimeout.GetAsDuration(time.Millisecond), plans) } @@ -342,7 +388,6 @@ func (c *SegmentChecker) createSegmentReduceTasks(ctx context.Context, segments replicaID, action, ) - if err != nil { log.Warn("create segment reduce task failed", zap.Int64("collection", s.GetCollectionID()), diff --git a/internal/querycoordv2/dist/dist_controller.go b/internal/querycoordv2/dist/dist_controller.go index b1268b7e6732c..1c26cb62fca69 100644 --- a/internal/querycoordv2/dist/dist_controller.go +++ b/internal/querycoordv2/dist/dist_controller.go @@ -88,8 +88,9 @@ func (dc *ControllerImpl) SyncAll(ctx context.Context) { func (dc *ControllerImpl) Stop() { dc.mu.Lock() defer dc.mu.Unlock() - for _, h := range dc.handlers { + for nodeID, h := range dc.handlers { h.stop() + delete(dc.handlers, nodeID) } } diff --git a/internal/querycoordv2/dist/dist_controller_test.go b/internal/querycoordv2/dist/dist_controller_test.go index e6238ae5a10f0..21f21f024429d 100644 --- a/internal/querycoordv2/dist/dist_controller_test.go +++ b/internal/querycoordv2/dist/dist_controller_test.go @@ -87,7 +87,7 @@ func (suite *DistControllerTestSuite) TearDownSuite() { func (suite *DistControllerTestSuite) TestStart() { dispatchCalled := atomic.NewBool(false) suite.mockCluster.EXPECT().GetDataDistribution(mock.Anything, 
mock.Anything, mock.Anything).Return( - &querypb.GetDataDistributionResponse{Status: merr.Status(nil), NodeID: 1}, + &querypb.GetDataDistributionResponse{Status: merr.Success(), NodeID: 1}, nil, ) suite.mockScheduler.EXPECT().Dispatch(int64(1)).Run(func(node int64) { dispatchCalled.Store(true) }) @@ -115,7 +115,7 @@ func (suite *DistControllerTestSuite) TestStop() { suite.controller.StartDistInstance(context.TODO(), 1) called := atomic.NewBool(false) suite.mockCluster.EXPECT().GetDataDistribution(mock.Anything, mock.Anything, mock.Anything).Maybe().Return( - &querypb.GetDataDistributionResponse{Status: merr.Status(nil), NodeID: 1}, + &querypb.GetDataDistributionResponse{Status: merr.Success(), NodeID: 1}, nil, ).Run(func(args mock.Arguments) { called.Store(true) @@ -140,7 +140,7 @@ func (suite *DistControllerTestSuite) TestSyncAll() { suite.mockCluster.EXPECT().GetDataDistribution(mock.Anything, mock.Anything, mock.Anything).Call.Return( func(ctx context.Context, nodeID int64, req *querypb.GetDataDistributionRequest) *querypb.GetDataDistributionResponse { return &querypb.GetDataDistributionResponse{ - Status: merr.Status(nil), + Status: merr.Success(), NodeID: nodeID, } }, diff --git a/internal/querycoordv2/dist/dist_handler.go b/internal/querycoordv2/dist/dist_handler.go index b2f4bccd631a8..60cf6487f3202 100644 --- a/internal/querycoordv2/dist/dist_handler.go +++ b/internal/querycoordv2/dist/dist_handler.go @@ -233,7 +233,6 @@ func (dh *distHandler) getDistribution(ctx context.Context) (*querypb.GetDataDis ), Checkpoints: channels, }) - if err != nil { return nil, err } @@ -247,6 +246,10 @@ func (dh *distHandler) stop() { dh.stopOnce.Do(func() { close(dh.c) dh.wg.Wait() + + // clear dist + dh.dist.ChannelDistManager.Update(dh.nodeID) + dh.dist.SegmentDistManager.Update(dh.nodeID) }) } diff --git a/internal/querycoordv2/handlers.go b/internal/querycoordv2/handlers.go index e7d5c3577ea3c..92e90373d5a93 100644 --- a/internal/querycoordv2/handlers.go +++ 
b/internal/querycoordv2/handlers.go @@ -22,7 +22,6 @@ import ( "sync" "time" - "github.com/cockroachdb/errors" "github.com/samber/lo" "go.uber.org/zap" @@ -67,12 +66,12 @@ func (s *Server) getCollectionSegmentInfo(collection int64) []*querypb.SegmentIn infos := make(map[int64]*querypb.SegmentInfo) for _, segment := range segments { if _, existCurrentTarget := currentTargetSegmentsMap[segment.GetID()]; !existCurrentTarget { - //if one segment exists in distMap but doesn't exist in currentTargetMap - //in order to guarantee that get segment request launched by sdk could get - //consistent result, for example - //sdk insert three segments:A, B, D, then A + B----compact--> C - //In this scenario, we promise that clients see either 2 segments(C,D) or 3 segments(A, B, D) - //rather than 4 segments(A, B, C, D), in which query nodes are loading C but have completed loading process + // if one segment exists in distMap but doesn't exist in currentTargetMap + // in order to guarantee that get segment request launched by sdk could get + // consistent result, for example + // sdk insert three segments:A, B, D, then A + B----compact--> C + // In this scenario, we promise that clients see either 2 segments(C,D) or 3 segments(A, B, D) + // rather than 4 segments(A, B, C, D), in which query nodes are loading C but have completed loading process log.Info("filtered segment being in the intermediate status", zap.Int64("segmentID", segment.GetID())) continue @@ -149,7 +148,6 @@ func (s *Server) balanceSegments(ctx context.Context, req *querypb.LoadBalanceRe task.NewSegmentActionWithScope(plan.To, task.ActionTypeGrow, plan.Segment.GetInsertChannel(), plan.Segment.GetID(), querypb.DataScope_Historical), task.NewSegmentActionWithScope(srcNode, task.ActionTypeReduce, plan.Segment.GetInsertChannel(), plan.Segment.GetID(), querypb.DataScope_Historical), ) - if err != nil { log.Warn("create segment task for balance failed", zap.Int64("collection", req.GetCollectionID()), @@ -174,8 +172,8 @@ 
func (s *Server) balanceSegments(ctx context.Context, req *querypb.LoadBalanceRe // TODO(dragondriver): add more detail metrics func (s *Server) getSystemInfoMetrics( ctx context.Context, - req *milvuspb.GetMetricsRequest) (string, error) { - + req *milvuspb.GetMetricsRequest, +) (string, error) { clusterTopology := metricsinfo.QueryClusterTopology{ Self: metricsinfo.QueryCoordInfos{ BaseComponentInfos: metricsinfo.BaseComponentInfos{ @@ -333,9 +331,9 @@ func (s *Server) fillReplicaInfo(replica *meta.Replica, withShardNodes bool) (*m leaderInfo = s.nodeMgr.Get(leader) } if leaderInfo == nil { - msg := fmt.Sprintf("failed to get shard leader for shard %s, the collection not loaded or leader is offline", channel) + msg := fmt.Sprintf("failed to get shard leader for shard %s", channel) log.Warn(msg) - return nil, errors.Wrap(merr.WrapErrNodeNotFound(leader), msg) + return nil, merr.WrapErrNodeNotFound(leader, msg) } shard := &milvuspb.ShardReplica{ diff --git a/internal/querycoordv2/job/job_sync.go b/internal/querycoordv2/job/job_sync.go index 4b6c8eb435133..72a25b9a67e97 100644 --- a/internal/querycoordv2/job/job_sync.go +++ b/internal/querycoordv2/job/job_sync.go @@ -20,9 +20,9 @@ import ( "context" "time" + "github.com/cockroachdb/errors" "go.uber.org/zap" - "github.com/cockroachdb/errors" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querycoordv2/meta" "github.com/milvus-io/milvus/internal/querycoordv2/session" diff --git a/internal/querycoordv2/job/job_test.go b/internal/querycoordv2/job/job_test.go index 618747e2a0d56..145a8144441e2 100644 --- a/internal/querycoordv2/job/job_test.go +++ b/internal/querycoordv2/job/job_test.go @@ -123,7 +123,7 @@ func (suite *JobSuite) SetupSuite() { }) } } - suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collection).Return(vChannels, segmentBinlogs, nil) + suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collection).Return(vChannels, segmentBinlogs, nil).Maybe() } 
suite.broker.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything). @@ -134,10 +134,10 @@ func (suite *JobSuite) SetupSuite() { suite.cluster = session.NewMockCluster(suite.T()) suite.cluster.EXPECT(). LoadPartitions(mock.Anything, mock.Anything, mock.Anything). - Return(merr.Status(nil), nil) + Return(merr.Success(), nil) suite.cluster.EXPECT(). ReleasePartitions(mock.Anything, mock.Anything, mock.Anything). - Return(merr.Status(nil), nil) + Return(merr.Success(), nil).Maybe() } func (suite *JobSuite) SetupTest() { @@ -163,10 +163,10 @@ func (suite *JobSuite) SetupTest() { suite.dist, suite.broker, ) - suite.targetObserver.Start(context.Background()) + suite.targetObserver.Start() suite.scheduler = NewScheduler() - suite.scheduler.Start(context.Background()) + suite.scheduler.Start() meta.GlobalFailedLoadCache = meta.NewFailedLoadCache() suite.nodeMgr.Add(session.NewNodeInfo(1000, "localhost")) @@ -1339,7 +1339,7 @@ func (suite *JobSuite) TestCallReleasePartitionFailed() { return call.Method != "ReleasePartitions" }) suite.cluster.EXPECT().ReleasePartitions(mock.Anything, mock.Anything, mock.Anything). 
- Return(merr.Status(nil), nil) + Return(merr.Success(), nil) } func (suite *JobSuite) TestSyncNewCreatedPartition() { diff --git a/internal/querycoordv2/job/scheduler.go b/internal/querycoordv2/job/scheduler.go index 18e8f20a5f851..5ee70a6508c98 100644 --- a/internal/querycoordv2/job/scheduler.go +++ b/internal/querycoordv2/job/scheduler.go @@ -37,7 +37,7 @@ const ( type jobQueue chan Job type Scheduler struct { - stopCh chan struct{} + cancel context.CancelFunc wg sync.WaitGroup processors *typeutil.ConcurrentSet[int64] // Collections of having processor @@ -49,73 +49,64 @@ type Scheduler struct { func NewScheduler() *Scheduler { return &Scheduler{ - stopCh: make(chan struct{}), processors: typeutil.NewConcurrentSet[int64](), queues: make(map[int64]jobQueue), waitQueue: make(jobQueue, waitQueueCap), } } -func (scheduler *Scheduler) Start(ctx context.Context) { - scheduler.schedule(ctx) +func (scheduler *Scheduler) Start() { + ctx, cancel := context.WithCancel(context.Background()) + scheduler.cancel = cancel + + scheduler.wg.Add(1) + go func() { + defer scheduler.wg.Done() + scheduler.schedule(ctx) + }() } func (scheduler *Scheduler) Stop() { scheduler.stopOnce.Do(func() { - close(scheduler.stopCh) + if scheduler.cancel != nil { + scheduler.cancel() + } scheduler.wg.Wait() }) } func (scheduler *Scheduler) schedule(ctx context.Context) { - scheduler.wg.Add(1) - go func() { - defer scheduler.wg.Done() - ticker := time.NewTicker(500 * time.Millisecond) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - log.Info("JobManager stopped due to context canceled") - return - - case <-scheduler.stopCh: - log.Info("JobManager stopped") - for _, queue := range scheduler.queues { - close(queue) - } - return + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + log.Info("JobManager stopped") + for _, queue := range scheduler.queues { + close(queue) + } + return - case job := <-scheduler.waitQueue: - queue, 
ok := scheduler.queues[job.CollectionID()] - if !ok { - queue = make(jobQueue, collectionQueueCap) - scheduler.queues[job.CollectionID()] = queue - } - queue <- job - scheduler.startProcessor(job.CollectionID(), queue) - - case <-ticker.C: - for collection, queue := range scheduler.queues { - if len(queue) > 0 { - scheduler.startProcessor(collection, queue) - } else { - // Release resource if no job for the collection - close(queue) - delete(scheduler.queues, collection) - } + case job := <-scheduler.waitQueue: + queue, ok := scheduler.queues[job.CollectionID()] + if !ok { + queue = make(jobQueue, collectionQueueCap) + scheduler.queues[job.CollectionID()] = queue + } + queue <- job + scheduler.startProcessor(job.CollectionID(), queue) + + case <-ticker.C: + for collection, queue := range scheduler.queues { + if len(queue) > 0 { + scheduler.startProcessor(collection, queue) + } else { + // Release resource if no job for the collection + close(queue) + delete(scheduler.queues, collection) } } } - }() -} - -func (scheduler *Scheduler) isStopped() bool { - select { - case <-scheduler.stopCh: - return true - default: - return false } } @@ -124,9 +115,6 @@ func (scheduler *Scheduler) Add(job Job) { } func (scheduler *Scheduler) startProcessor(collection int64, queue jobQueue) { - if scheduler.isStopped() { - return - } if !scheduler.processors.Insert(collection) { return } diff --git a/internal/querycoordv2/job/undo.go b/internal/querycoordv2/job/undo.go index 5ea53e62ad890..64b89bb78c2d2 100644 --- a/internal/querycoordv2/job/undo.go +++ b/internal/querycoordv2/job/undo.go @@ -19,11 +19,12 @@ package job import ( "context" + "go.uber.org/zap" + "github.com/milvus-io/milvus/internal/querycoordv2/meta" "github.com/milvus-io/milvus/internal/querycoordv2/observers" "github.com/milvus-io/milvus/internal/querycoordv2/session" "github.com/milvus-io/milvus/pkg/log" - "go.uber.org/zap" ) type UndoList struct { @@ -42,7 +43,8 @@ type UndoList struct { } func NewUndoList(ctx 
context.Context, meta *meta.Meta, - cluster session.Cluster, targetMgr *meta.TargetManager, targetObserver *observers.TargetObserver) *UndoList { + cluster session.Cluster, targetMgr *meta.TargetManager, targetObserver *observers.TargetObserver, +) *UndoList { return &UndoList{ ctx: ctx, meta: meta, diff --git a/internal/querycoordv2/job/utils.go b/internal/querycoordv2/job/utils.go index 7e06d80fc6103..c6a9b26cfcc56 100644 --- a/internal/querycoordv2/job/utils.go +++ b/internal/querycoordv2/job/utils.go @@ -68,7 +68,8 @@ func loadPartitions(ctx context.Context, broker meta.Broker, withSchema bool, collection int64, - partitions ...int64) error { + partitions ...int64, +) error { var err error var schema *schemapb.CollectionSchema if withSchema { @@ -113,7 +114,8 @@ func releasePartitions(ctx context.Context, meta *meta.Meta, cluster session.Cluster, collection int64, - partitions ...int64) { + partitions ...int64, +) { log := log.Ctx(ctx).With(zap.Int64("collection", collection), zap.Int64s("partitions", partitions)) replicas := meta.ReplicaManager.GetByCollection(collection) releaseReq := &querypb.ReleasePartitionsRequest{ diff --git a/internal/querycoordv2/meta/collection_manager.go b/internal/querycoordv2/meta/collection_manager.go index c994d0df5494c..43f4dcaf3e806 100644 --- a/internal/querycoordv2/meta/collection_manager.go +++ b/internal/querycoordv2/meta/collection_manager.go @@ -19,6 +19,7 @@ package meta import ( "context" "fmt" + "strconv" "sync" "time" @@ -33,8 +34,8 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" - . 
"github.com/milvus-io/milvus/pkg/util/typeutil" ) type Collection struct { @@ -98,8 +99,8 @@ func (partition *Partition) Clone() *Partition { type CollectionManager struct { rwmutex sync.RWMutex - collections map[UniqueID]*Collection - partitions map[UniqueID]*Partition + collections map[typeutil.UniqueID]*Collection + partitions map[typeutil.UniqueID]*Partition catalog metastore.QueryCoordCatalog } @@ -123,40 +124,59 @@ func (m *CollectionManager) Recover(broker Broker) error { return err } + ctx := log.WithTraceID(context.Background(), strconv.FormatInt(time.Now().UnixNano(), 10)) + ctxLog := log.Ctx(ctx) + ctxLog.Info("recover collections and partitions from kv store") + for _, collection := range collections { // Dropped collection should be deprecated - _, err = broker.GetCollectionSchema(context.Background(), collection.GetCollectionID()) + _, err = broker.GetCollectionSchema(ctx, collection.GetCollectionID()) if errors.Is(err, merr.ErrCollectionNotFound) { - log.Info("skip dropped collection during recovery", zap.Int64("collection", collection.GetCollectionID())) + ctxLog.Info("skip dropped collection during recovery", zap.Int64("collection", collection.GetCollectionID())) m.catalog.ReleaseCollection(collection.GetCollectionID()) continue } if err != nil { + ctxLog.Warn("failed to get collection schema", zap.Error(err)) return err } - // Collections not loaded done should be deprecated - if collection.GetStatus() != querypb.LoadStatus_Loaded || collection.GetReplicaNumber() <= 0 { - log.Info("skip recovery and release collection", + + if collection.GetReplicaNumber() <= 0 { + ctxLog.Info("skip recovery and release collection due to invalid replica number", zap.Int64("collectionID", collection.GetCollectionID()), - zap.String("status", collection.GetStatus().String()), - zap.Int32("replicaNumber", collection.GetReplicaNumber()), - ) + zap.Int32("replicaNumber", collection.GetReplicaNumber())) m.catalog.ReleaseCollection(collection.GetCollectionID()) continue 
} + + if collection.GetStatus() != querypb.LoadStatus_Loaded { + if collection.RecoverTimes >= paramtable.Get().QueryCoordCfg.CollectionRecoverTimesLimit.GetAsInt32() { + m.catalog.ReleaseCollection(collection.CollectionID) + ctxLog.Info("recover loading collection times reach limit, release collection", + zap.Int64("collectionID", collection.CollectionID), + zap.Int32("recoverTimes", collection.RecoverTimes)) + break + } + // update recoverTimes meta in etcd + collection.RecoverTimes += 1 + m.putCollection(true, &Collection{CollectionLoadInfo: collection}) + continue + } + m.collections[collection.CollectionID] = &Collection{ CollectionLoadInfo: collection, } } for collection, partitions := range partitions { - existPartitions, err := broker.GetPartitions(context.Background(), collection) + existPartitions, err := broker.GetPartitions(ctx, collection) if errors.Is(err, merr.ErrCollectionNotFound) { - log.Info("skip dropped collection during recovery", zap.Int64("collection", collection)) + ctxLog.Info("skip dropped collection during recovery", zap.Int64("collection", collection)) m.catalog.ReleaseCollection(collection) continue } if err != nil { + ctxLog.Warn("failed to get partitions", zap.Error(err)) return err } omitPartitions := make([]int64, 0) @@ -168,33 +188,32 @@ func (m *CollectionManager) Recover(broker Broker) error { return true }) if len(omitPartitions) > 0 { - log.Info("skip dropped partitions during recovery", - zap.Int64("collection", collection), zap.Int64s("partitions", omitPartitions)) + ctxLog.Info("skip dropped partitions during recovery", + zap.Int64("collection", collection), + zap.Int64s("partitions", omitPartitions)) m.catalog.ReleasePartition(collection, omitPartitions...) 
} - sawLoaded := false for _, partition := range partitions { // Partitions not loaded done should be deprecated if partition.GetStatus() != querypb.LoadStatus_Loaded { - log.Info("skip recovery and release partition", - zap.Int64("collectionID", collection), - zap.Int64("partitionID", partition.GetPartitionID()), - zap.String("status", partition.GetStatus().String()), - ) - m.catalog.ReleasePartition(collection, partition.GetPartitionID()) + if partition.RecoverTimes >= paramtable.Get().QueryCoordCfg.CollectionRecoverTimesLimit.GetAsInt32() { + m.catalog.ReleaseCollection(collection) + ctxLog.Info("recover loading partition times reach limit, release collection", + zap.Int64("collectionID", collection), + zap.Int32("recoverTimes", partition.RecoverTimes)) + break + } + + partition.RecoverTimes += 1 + m.putPartition([]*Partition{{PartitionLoadInfo: partition}}, true) continue } - sawLoaded = true m.partitions[partition.PartitionID] = &Partition{ PartitionLoadInfo: partition, } } - - if !sawLoaded { - m.catalog.ReleaseCollection(collection) - } } err = m.upgradeRecover(broker) @@ -255,21 +274,21 @@ func (m *CollectionManager) upgradeRecover(broker Broker) error { return nil } -func (m *CollectionManager) GetCollection(collectionID UniqueID) *Collection { +func (m *CollectionManager) GetCollection(collectionID typeutil.UniqueID) *Collection { m.rwmutex.RLock() defer m.rwmutex.RUnlock() return m.collections[collectionID] } -func (m *CollectionManager) GetPartition(partitionID UniqueID) *Partition { +func (m *CollectionManager) GetPartition(partitionID typeutil.UniqueID) *Partition { m.rwmutex.RLock() defer m.rwmutex.RUnlock() return m.partitions[partitionID] } -func (m *CollectionManager) GetLoadType(collectionID UniqueID) querypb.LoadType { +func (m *CollectionManager) GetLoadType(collectionID typeutil.UniqueID) querypb.LoadType { m.rwmutex.RLock() defer m.rwmutex.RUnlock() @@ -280,7 +299,7 @@ func (m *CollectionManager) GetLoadType(collectionID UniqueID) 
querypb.LoadType return querypb.LoadType_UnKnownType } -func (m *CollectionManager) GetReplicaNumber(collectionID UniqueID) int32 { +func (m *CollectionManager) GetReplicaNumber(collectionID typeutil.UniqueID) int32 { m.rwmutex.RLock() defer m.rwmutex.RUnlock() @@ -292,14 +311,14 @@ func (m *CollectionManager) GetReplicaNumber(collectionID UniqueID) int32 { } // CalculateLoadPercentage checks if collection is currently fully loaded. -func (m *CollectionManager) CalculateLoadPercentage(collectionID UniqueID) int32 { +func (m *CollectionManager) CalculateLoadPercentage(collectionID typeutil.UniqueID) int32 { m.rwmutex.RLock() defer m.rwmutex.RUnlock() return m.calculateLoadPercentage(collectionID) } -func (m *CollectionManager) calculateLoadPercentage(collectionID UniqueID) int32 { +func (m *CollectionManager) calculateLoadPercentage(collectionID typeutil.UniqueID) int32 { _, ok := m.collections[collectionID] if ok { partitions := m.getPartitionsByCollection(collectionID) @@ -312,7 +331,7 @@ func (m *CollectionManager) calculateLoadPercentage(collectionID UniqueID) int32 return -1 } -func (m *CollectionManager) GetPartitionLoadPercentage(partitionID UniqueID) int32 { +func (m *CollectionManager) GetPartitionLoadPercentage(partitionID typeutil.UniqueID) int32 { m.rwmutex.RLock() defer m.rwmutex.RUnlock() @@ -323,7 +342,7 @@ func (m *CollectionManager) GetPartitionLoadPercentage(partitionID UniqueID) int return -1 } -func (m *CollectionManager) CalculateLoadStatus(collectionID UniqueID) querypb.LoadStatus { +func (m *CollectionManager) CalculateLoadStatus(collectionID typeutil.UniqueID) querypb.LoadStatus { m.rwmutex.RLock() defer m.rwmutex.RUnlock() @@ -346,7 +365,7 @@ func (m *CollectionManager) CalculateLoadStatus(collectionID UniqueID) querypb.L return querypb.LoadStatus_Invalid } -func (m *CollectionManager) GetFieldIndex(collectionID UniqueID) map[int64]int64 { +func (m *CollectionManager) GetFieldIndex(collectionID typeutil.UniqueID) map[int64]int64 { 
m.rwmutex.RLock() defer m.rwmutex.RUnlock() @@ -357,7 +376,7 @@ func (m *CollectionManager) GetFieldIndex(collectionID UniqueID) map[int64]int64 return nil } -func (m *CollectionManager) Exist(collectionID UniqueID) bool { +func (m *CollectionManager) Exist(collectionID typeutil.UniqueID) bool { m.rwmutex.RLock() defer m.rwmutex.RUnlock() @@ -391,14 +410,14 @@ func (m *CollectionManager) GetAllPartitions() []*Partition { return lo.Values(m.partitions) } -func (m *CollectionManager) GetPartitionsByCollection(collectionID UniqueID) []*Partition { +func (m *CollectionManager) GetPartitionsByCollection(collectionID typeutil.UniqueID) []*Partition { m.rwmutex.RLock() defer m.rwmutex.RUnlock() return m.getPartitionsByCollection(collectionID) } -func (m *CollectionManager) getPartitionsByCollection(collectionID UniqueID) []*Partition { +func (m *CollectionManager) getPartitionsByCollection(collectionID typeutil.UniqueID) []*Partition { partitions := make([]*Partition, 0) for _, partition := range m.partitions { if partition.CollectionID == collectionID { @@ -489,6 +508,8 @@ func (m *CollectionManager) UpdateLoadPercent(partitionID int64, loadPercent int if loadPercent == 100 { savePartition = true newPartition.Status = querypb.LoadStatus_Loaded + // if partition becomes loaded, clear it's recoverTimes in load info + newPartition.RecoverTimes = 0 elapsed := time.Since(newPartition.CreatedAt) metrics.QueryCoordLoadLatency.WithLabelValues().Observe(float64(elapsed.Milliseconds())) eventlog.Record(eventlog.NewRawEvt(eventlog.Level_Info, fmt.Sprintf("Partition %d loaded", partitionID))) @@ -510,12 +531,14 @@ func (m *CollectionManager) UpdateLoadPercent(partitionID int64, loadPercent int if collectionPercent == 100 { saveCollection = true newCollection.Status = querypb.LoadStatus_Loaded - elapsed := time.Since(newCollection.CreatedAt) + + // if collection becomes loaded, clear it's recoverTimes in load info + newCollection.RecoverTimes = 0 // TODO: what if part of the 
collection has been unloaded? Now we decrease the metric only after // `ReleaseCollection` is triggered. Maybe it's hard to make this metric really accurate. metrics.QueryCoordNumCollections.WithLabelValues().Inc() - + elapsed := time.Since(newCollection.CreatedAt) metrics.QueryCoordLoadLatency.WithLabelValues().Observe(float64(elapsed.Milliseconds())) eventlog.Record(eventlog.NewRawEvt(eventlog.Level_Info, fmt.Sprintf("Collection %d loaded", newCollection.CollectionID))) } @@ -523,7 +546,7 @@ func (m *CollectionManager) UpdateLoadPercent(partitionID int64, loadPercent int } // RemoveCollection removes collection and its partitions. -func (m *CollectionManager) RemoveCollection(collectionID UniqueID) error { +func (m *CollectionManager) RemoveCollection(collectionID typeutil.UniqueID) error { m.rwmutex.Lock() defer m.rwmutex.Unlock() @@ -543,7 +566,7 @@ func (m *CollectionManager) RemoveCollection(collectionID UniqueID) error { return nil } -func (m *CollectionManager) RemovePartition(collectionID UniqueID, partitionIDs ...UniqueID) error { +func (m *CollectionManager) RemovePartition(collectionID typeutil.UniqueID, partitionIDs ...typeutil.UniqueID) error { if len(partitionIDs) == 0 { return nil } @@ -554,7 +577,7 @@ func (m *CollectionManager) RemovePartition(collectionID UniqueID, partitionIDs return m.removePartition(collectionID, partitionIDs...) } -func (m *CollectionManager) removePartition(collectionID UniqueID, partitionIDs ...UniqueID) error { +func (m *CollectionManager) removePartition(collectionID typeutil.UniqueID, partitionIDs ...typeutil.UniqueID) error { err := m.catalog.ReleasePartition(collectionID, partitionIDs...) 
if err != nil { return err diff --git a/internal/querycoordv2/meta/collection_manager_test.go b/internal/querycoordv2/meta/collection_manager_test.go index d74e1de25dab2..311bbbe6af9ca 100644 --- a/internal/querycoordv2/meta/collection_manager_test.go +++ b/internal/querycoordv2/meta/collection_manager_test.go @@ -25,6 +25,7 @@ import ( "github.com/samber/lo" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" + "go.uber.org/zap" "github.com/milvus-io/milvus/internal/kv" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" @@ -32,6 +33,7 @@ import ( "github.com/milvus-io/milvus/internal/metastore/kv/querycoord" "github.com/milvus-io/milvus/internal/proto/querypb" . "github.com/milvus-io/milvus/internal/querycoordv2/params" + "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -328,16 +330,96 @@ func (suite *CollectionManagerSuite) TestRecover_normal() { suite.clearMemory() err := mgr.Recover(suite.broker) suite.NoError(err) + for _, collection := range suite.collections { + suite.True(mgr.Exist(collection)) + for _, partitionID := range suite.partitions[collection] { + partition := mgr.GetPartition(partitionID) + suite.NotNil(partition) + } + } +} + +func (suite *CollectionManagerSuite) TestRecoverLoadingCollection() { + mgr := suite.mgr + suite.releaseAll() + suite.broker.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything).Return(nil, nil) + // test put collection with partitions for i, collection := range suite.collections { - exist := suite.colLoadPercent[i] == 100 - suite.Equal(exist, mgr.Exist(collection)) - if !exist { - continue + suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil).Maybe() + col := &Collection{ + CollectionLoadInfo: &querypb.CollectionLoadInfo{ + CollectionID: collection, + ReplicaNumber: suite.replicaNumber[i], + Status: 
querypb.LoadStatus_Loading, + LoadType: suite.loadTypes[i], + }, + LoadPercentage: 0, + CreatedAt: time.Now(), } - for j, partitionID := range suite.partitions[collection] { + partitions := lo.Map(suite.partitions[collection], func(partition int64, j int) *Partition { + return &Partition{ + PartitionLoadInfo: &querypb.PartitionLoadInfo{ + CollectionID: collection, + PartitionID: partition, + ReplicaNumber: suite.replicaNumber[i], + Status: querypb.LoadStatus_Loading, + }, + LoadPercentage: 0, + CreatedAt: time.Now(), + } + }) + err := suite.mgr.PutCollection(col, partitions...) + suite.NoError(err) + } + + // recover for first time, expected recover success + suite.clearMemory() + err := mgr.Recover(suite.broker) + suite.NoError(err) + for _, collectionID := range suite.collections { + collection := mgr.GetCollection(collectionID) + suite.NotNil(collection) + suite.Equal(int32(1), collection.GetRecoverTimes()) + for _, partitionID := range suite.partitions[collectionID] { partition := mgr.GetPartition(partitionID) - exist = suite.parLoadPercent[collection][j] == 100 - suite.Equal(exist, partition != nil) + suite.NotNil(partition) + suite.Equal(int32(1), partition.GetRecoverTimes()) + } + } + + // update load percent, then recover for second time + for _, collectionID := range suite.collections { + for _, partitionID := range suite.partitions[collectionID] { + mgr.UpdateLoadPercent(partitionID, 10) + } + } + suite.clearMemory() + err = mgr.Recover(suite.broker) + suite.NoError(err) + for _, collectionID := range suite.collections { + collection := mgr.GetCollection(collectionID) + suite.NotNil(collection) + suite.Equal(int32(2), collection.GetRecoverTimes()) + for _, partitionID := range suite.partitions[collectionID] { + partition := mgr.GetPartition(partitionID) + suite.NotNil(partition) + suite.Equal(int32(2), partition.GetRecoverTimes()) + } + } + + // test recover loading collection reach limit + for i := 0; i < 
int(paramtable.Get().QueryCoordCfg.CollectionRecoverTimesLimit.GetAsInt32()); i++ { + log.Info("stupid", zap.Int("count", i)) + suite.clearMemory() + err = mgr.Recover(suite.broker) + suite.NoError(err) + } + for _, collectionID := range suite.collections { + collection := mgr.GetCollection(collectionID) + suite.Nil(collection) + for _, partitionID := range suite.partitions[collectionID] { + partition := mgr.GetPartition(partitionID) + suite.Nil(partition) } } } @@ -368,15 +450,15 @@ func (suite *CollectionManagerSuite) TestRecover_with_dropped() { suite.clearMemory() err := mgr.Recover(suite.broker) suite.NoError(err) - for i, collection := range suite.collections { - exist := suite.colLoadPercent[i] == 100 && collection != droppedCollection + for _, collection := range suite.collections { + exist := collection != droppedCollection suite.Equal(exist, mgr.Exist(collection)) if !exist { continue } - for j, partitionID := range suite.partitions[collection] { + for _, partitionID := range suite.partitions[collection] { partition := mgr.GetPartition(partitionID) - exist = suite.parLoadPercent[collection][j] == 100 && partitionID != droppedPartition + exist = partitionID != droppedPartition suite.Equal(exist, partition != nil) } } diff --git a/internal/querycoordv2/meta/coordinator_broker.go b/internal/querycoordv2/meta/coordinator_broker.go index ea2324eb11b06..d85b2a022420f 100644 --- a/internal/querycoordv2/meta/coordinator_broker.go +++ b/internal/querycoordv2/meta/coordinator_broker.go @@ -21,15 +21,16 @@ import ( "fmt" "time" - "github.com/cockroachdb/errors" "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/metastore/kv/datacoord" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/querypb" + 
"github.com/milvus-io/milvus/internal/querycoordv2/params" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/commonpbutil" @@ -49,13 +50,14 @@ type Broker interface { } type CoordinatorBroker struct { - dataCoord types.DataCoord - rootCoord types.RootCoord + dataCoord types.DataCoordClient + rootCoord types.RootCoordClient } func NewCoordinatorBroker( - dataCoord types.DataCoord, - rootCoord types.RootCoord) *CoordinatorBroker { + dataCoord types.DataCoordClient, + rootCoord types.RootCoordClient, +) *CoordinatorBroker { return &CoordinatorBroker{ dataCoord, rootCoord, @@ -74,13 +76,8 @@ func (broker *CoordinatorBroker) GetCollectionSchema(ctx context.Context, collec CollectionID: collectionID, } resp, err := broker.rootCoord.DescribeCollection(ctx, req) - if err != nil { - return nil, err - } - - err = merr.Error(resp.GetStatus()) - if err != nil { - log.Warn("failed to get collection schema", zap.Error(err)) + if err := merr.CheckRPCCall(resp, err); err != nil { + log.Ctx(ctx).Warn("failed to get collection schema", zap.Error(err)) return nil, err } return resp.GetSchema(), nil @@ -89,6 +86,7 @@ func (broker *CoordinatorBroker) GetCollectionSchema(ctx context.Context, collec func (broker *CoordinatorBroker) GetPartitions(ctx context.Context, collectionID UniqueID) ([]UniqueID, error) { ctx, cancel := context.WithTimeout(ctx, paramtable.Get().QueryCoordCfg.BrokerTimeout.GetAsDuration(time.Millisecond)) defer cancel() + log := log.Ctx(ctx).With(zap.Int64("collectionID", collectionID)) req := &milvuspb.ShowPartitionsRequest{ Base: commonpbutil.NewMsgBase( commonpbutil.WithMsgType(commonpb.MsgType_ShowPartitions), @@ -97,13 +95,7 @@ func (broker *CoordinatorBroker) GetPartitions(ctx context.Context, collectionID CollectionID: collectionID, } resp, err := broker.rootCoord.ShowPartitions(ctx, req) - if err != nil { - log.Warn("showPartition failed", zap.Int64("collectionID", collectionID), 
zap.Error(err)) - return nil, err - } - - err = merr.Error(resp.GetStatus()) - if err != nil { + if err := merr.CheckRPCCall(resp, err); err != nil { log.Warn("failed to get partitions", zap.Error(err)) return nil, err } @@ -114,6 +106,10 @@ func (broker *CoordinatorBroker) GetPartitions(ctx context.Context, collectionID func (broker *CoordinatorBroker) GetRecoveryInfo(ctx context.Context, collectionID UniqueID, partitionID UniqueID) ([]*datapb.VchannelInfo, []*datapb.SegmentBinlogs, error) { ctx, cancel := context.WithTimeout(ctx, paramtable.Get().QueryCoordCfg.BrokerTimeout.GetAsDuration(time.Millisecond)) defer cancel() + log := log.Ctx(ctx).With( + zap.Int64("collectionID", collectionID), + zap.Int64("partitionID", partitionID), + ) getRecoveryInfoRequest := &datapb.GetRecoveryInfoRequest{ Base: commonpbutil.NewMsgBase( @@ -123,14 +119,8 @@ func (broker *CoordinatorBroker) GetRecoveryInfo(ctx context.Context, collection PartitionID: partitionID, } recoveryInfo, err := broker.dataCoord.GetRecoveryInfo(ctx, getRecoveryInfoRequest) - if err != nil { - log.Warn("get recovery info failed", zap.Int64("collectionID", collectionID), zap.Int64("partitionID", partitionID), zap.Error(err)) - return nil, nil, err - } - - if recoveryInfo.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { - err = errors.New(recoveryInfo.GetStatus().GetReason()) - log.Warn("get recovery info failed", zap.Int64("collectionID", collectionID), zap.Int64("partitionID", partitionID), zap.Error(err)) + if err := merr.CheckRPCCall(recoveryInfo, err); err != nil { + log.Warn("get recovery info failed", zap.Error(err)) return nil, nil, err } @@ -140,6 +130,10 @@ func (broker *CoordinatorBroker) GetRecoveryInfo(ctx context.Context, collection func (broker *CoordinatorBroker) GetRecoveryInfoV2(ctx context.Context, collectionID UniqueID, partitionIDs ...UniqueID) ([]*datapb.VchannelInfo, []*datapb.SegmentInfo, error) { ctx, cancel := context.WithTimeout(ctx, 
paramtable.Get().QueryCoordCfg.BrokerTimeout.GetAsDuration(time.Millisecond)) defer cancel() + log := log.Ctx(ctx).With( + zap.Int64("collectionID", collectionID), + zap.Int64s("partitionIDs", partitionIDs), + ) getRecoveryInfoRequest := &datapb.GetRecoveryInfoRequestV2{ Base: commonpbutil.NewMsgBase( @@ -149,39 +143,39 @@ func (broker *CoordinatorBroker) GetRecoveryInfoV2(ctx context.Context, collecti PartitionIDs: partitionIDs, } recoveryInfo, err := broker.dataCoord.GetRecoveryInfoV2(ctx, getRecoveryInfoRequest) - if err != nil { - log.Warn("get recovery info failed", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs), zap.Error(err)) - return nil, nil, err - } - if recoveryInfo.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { - err = errors.New(recoveryInfo.GetStatus().GetReason()) - log.Warn("get recovery info failed", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs), zap.Error(err)) + if err := merr.CheckRPCCall(recoveryInfo, err); err != nil { + log.Warn("get recovery info failed", zap.Error(err)) return nil, nil, err } + path := params.Params.MinioCfg.RootPath.GetValue() + // refill log ID with log path + for _, segmentInfo := range recoveryInfo.Segments { + datacoord.DecompressBinLog(path, segmentInfo) + } return recoveryInfo.Channels, recoveryInfo.Segments, nil } func (broker *CoordinatorBroker) GetSegmentInfo(ctx context.Context, ids ...UniqueID) (*datapb.GetSegmentInfoResponse, error) { ctx, cancel := context.WithTimeout(ctx, paramtable.Get().QueryCoordCfg.BrokerTimeout.GetAsDuration(time.Millisecond)) defer cancel() + log := log.Ctx(ctx).With( + zap.Int64s("segments", ids), + ) req := &datapb.GetSegmentInfoRequest{ SegmentIDs: ids, IncludeUnHealthy: true, } resp, err := broker.dataCoord.GetSegmentInfo(ctx, req) - if err != nil { - log.Warn("failed to get segment info from DataCoord", - zap.Int64s("segments", ids), - zap.Error(err)) + if err := merr.CheckRPCCall(resp, err); err 
!= nil { + log.Warn("failed to get segment info from DataCoord", zap.Error(err)) return nil, err } if len(resp.Infos) == 0 { - log.Warn("No such segment in DataCoord", - zap.Int64s("segments", ids)) + log.Warn("No such segment in DataCoord") return nil, fmt.Errorf("no such segment in DataCoord") } @@ -192,36 +186,47 @@ func (broker *CoordinatorBroker) GetIndexInfo(ctx context.Context, collectionID ctx, cancel := context.WithTimeout(ctx, paramtable.Get().QueryCoordCfg.BrokerTimeout.GetAsDuration(time.Millisecond)) defer cancel() + log := log.Ctx(ctx).With( + zap.Int64("collectionID", collectionID), + zap.Int64("segmentID", segmentID), + ) + resp, err := broker.dataCoord.GetIndexInfos(ctx, &indexpb.GetIndexInfoRequest{ CollectionID: collectionID, SegmentIDs: []int64{segmentID}, }) - if err != nil || resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + + if err := merr.CheckRPCCall(resp, err); err != nil { + log.Warn("failed to get segment index info", zap.Error(err)) + return nil, err + } + + if resp.GetSegmentInfo() == nil { + err = merr.WrapErrIndexNotFoundForSegment(segmentID) log.Warn("failed to get segment index info", - zap.Int64("collection", collectionID), - zap.Int64("segment", segmentID), zap.Error(err)) return nil, err } - segmentInfo, ok := resp.SegmentInfo[segmentID] + segmentInfo, ok := resp.GetSegmentInfo()[segmentID] if !ok || len(segmentInfo.GetIndexInfos()) == 0 { - return nil, merr.WrapErrIndexNotFound() + return nil, merr.WrapErrIndexNotFoundForSegment(segmentID) } indexes := make([]*querypb.FieldIndexInfo, 0) for _, info := range segmentInfo.GetIndexInfos() { indexes = append(indexes, &querypb.FieldIndexInfo{ - FieldID: info.GetFieldID(), - EnableIndex: true, - IndexName: info.GetIndexName(), - IndexID: info.GetIndexID(), - BuildID: info.GetBuildID(), - IndexParams: info.GetIndexParams(), - IndexFilePaths: info.GetIndexFilePaths(), - IndexSize: int64(info.GetSerializedSize()), - IndexVersion: info.GetIndexVersion(), - NumRows: 
info.GetNumRows(), + FieldID: info.GetFieldID(), + EnableIndex: true, + IndexName: info.GetIndexName(), + IndexID: info.GetIndexID(), + BuildID: info.GetBuildID(), + IndexParams: info.GetIndexParams(), + IndexFilePaths: info.GetIndexFilePaths(), + IndexSize: int64(info.GetSerializedSize()), + IndexVersion: info.GetIndexVersion(), + NumRows: info.GetNumRows(), + CurrentIndexVersion: info.GetCurrentIndexVersion(), }) } @@ -236,11 +241,11 @@ func (broker *CoordinatorBroker) DescribeIndex(ctx context.Context, collectionID CollectionID: collectionID, }) - if err != nil || resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + if err := merr.CheckRPCCall(resp, err); err != nil { log.Error("failed to fetch index meta", zap.Int64("collection", collectionID), zap.Error(err)) return nil, err } - return resp.IndexInfos, nil + return resp.GetIndexInfos(), nil } diff --git a/internal/querycoordv2/meta/coordinator_broker_test.go b/internal/querycoordv2/meta/coordinator_broker_test.go index 3a42b72942fb0..a5e7f8abd5e18 100644 --- a/internal/querycoordv2/meta/coordinator_broker_test.go +++ b/internal/querycoordv2/meta/coordinator_broker_test.go @@ -21,129 +21,373 @@ import ( "testing" "github.com/cockroachdb/errors" - "github.com/stretchr/testify/assert" + "github.com/samber/lo" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/indexpb" + "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) -func TestCoordinatorBroker_GetCollectionSchema(t *testing.T) { - t.Run("got error on DescribeCollection", func(t *testing.T) { - rootCoord := 
mocks.NewRootCoord(t) - rootCoord.On("DescribeCollection", - mock.Anything, - mock.Anything, - ).Return(nil, errors.New("error mock DescribeCollection")) - ctx := context.Background() - broker := &CoordinatorBroker{rootCoord: rootCoord} - _, err := broker.GetCollectionSchema(ctx, 100) - assert.Error(t, err) - }) - - t.Run("non-success code", func(t *testing.T) { - rootCoord := mocks.NewRootCoord(t) - rootCoord.On("DescribeCollection", - mock.Anything, - mock.Anything, - ).Return(&milvuspb.DescribeCollectionResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_CollectionNotExists}, - }, nil) - ctx := context.Background() - broker := &CoordinatorBroker{rootCoord: rootCoord} - _, err := broker.GetCollectionSchema(ctx, 100) - assert.Error(t, err) - }) - - t.Run("normal case", func(t *testing.T) { - rootCoord := mocks.NewRootCoord(t) - rootCoord.On("DescribeCollection", - mock.Anything, - mock.Anything, - ).Return(&milvuspb.DescribeCollectionResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, - Schema: &schemapb.CollectionSchema{Name: "test_schema"}, - }, nil) - ctx := context.Background() - broker := &CoordinatorBroker{rootCoord: rootCoord} - schema, err := broker.GetCollectionSchema(ctx, 100) - assert.NoError(t, err) - assert.Equal(t, "test_schema", schema.GetName()) - }) +type CoordinatorBrokerRootCoordSuite struct { + suite.Suite + + rootcoord *mocks.MockRootCoordClient + broker *CoordinatorBroker } -func TestCoordinatorBroker_GetRecoveryInfo(t *testing.T) { - t.Run("normal case", func(t *testing.T) { - dc := mocks.NewMockDataCoord(t) - dc.EXPECT().GetRecoveryInfoV2(mock.Anything, mock.Anything).Return(&datapb.GetRecoveryInfoResponseV2{}, nil) +func (s *CoordinatorBrokerRootCoordSuite) SetupSuite() { + paramtable.Init() +} - ctx := context.Background() - broker := &CoordinatorBroker{dataCoord: dc} +func (s *CoordinatorBrokerRootCoordSuite) SetupTest() { + s.rootcoord = mocks.NewMockRootCoordClient(s.T()) + s.broker = 
NewCoordinatorBroker(nil, s.rootcoord) +} - _, _, err := broker.GetRecoveryInfoV2(ctx, 1) - assert.NoError(t, err) - }) +func (s *CoordinatorBrokerRootCoordSuite) resetMock() { + s.rootcoord.AssertExpectations(s.T()) + s.rootcoord.ExpectedCalls = nil +} - t.Run("get error", func(t *testing.T) { - dc := mocks.NewMockDataCoord(t) - fakeErr := errors.New("fake error") - dc.EXPECT().GetRecoveryInfoV2(mock.Anything, mock.Anything).Return(nil, fakeErr) +func (s *CoordinatorBrokerRootCoordSuite) TestGetCollectionSchema() { + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + collectionID := int64(100) - ctx := context.Background() - broker := &CoordinatorBroker{dataCoord: dc} + s.Run("normal case", func() { + s.rootcoord.EXPECT().DescribeCollection(mock.Anything, mock.Anything). + Return(&milvuspb.DescribeCollectionResponse{ + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Schema: &schemapb.CollectionSchema{Name: "test_schema"}, + }, nil) - _, _, err := broker.GetRecoveryInfoV2(ctx, 1) - assert.ErrorIs(t, err, fakeErr) + schema, err := s.broker.GetCollectionSchema(ctx, collectionID) + s.NoError(err) + s.Equal("test_schema", schema.GetName()) + s.resetMock() }) - t.Run("return non-success code", func(t *testing.T) { - dc := mocks.NewMockDataCoord(t) - dc.EXPECT().GetRecoveryInfoV2(mock.Anything, mock.Anything).Return(&datapb.GetRecoveryInfoResponseV2{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - }, - }, nil) + s.Run("rootcoord_return_error", func() { + s.rootcoord.EXPECT().DescribeCollection(mock.Anything, mock.Anything). 
+ Return(nil, errors.New("mock error")) - ctx := context.Background() - broker := &CoordinatorBroker{dataCoord: dc} + _, err := s.broker.GetCollectionSchema(ctx, collectionID) + s.Error(err) + s.resetMock() + }) - _, _, err := broker.GetRecoveryInfoV2(ctx, 1) - assert.Error(t, err) + s.Run("return_failure_status", func() { + s.rootcoord.EXPECT().DescribeCollection(mock.Anything, mock.Anything). + Return(&milvuspb.DescribeCollectionResponse{ + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_CollectionNotExists}, + }, nil) + + _, err := s.broker.GetCollectionSchema(ctx, collectionID) + s.Error(err) + s.resetMock() }) } -func TestCoordinatorBroker_GetPartitions(t *testing.T) { +func (s *CoordinatorBrokerRootCoordSuite) TestGetPartitions() { + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + defer cancel() collection := int64(100) partitions := []int64{10, 11, 12} - t.Run("normal case", func(t *testing.T) { - rc := mocks.NewRootCoord(t) - rc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&milvuspb.ShowPartitionsResponse{ - Status: &commonpb.Status{}, + s.Run("normal_case", func() { + s.rootcoord.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&milvuspb.ShowPartitionsResponse{ + Status: merr.Status(nil), PartitionIDs: partitions, }, nil) - ctx := context.Background() - broker := &CoordinatorBroker{rootCoord: rc} - - retPartitions, err := broker.GetPartitions(ctx, collection) - assert.NoError(t, err) - assert.ElementsMatch(t, partitions, retPartitions) + retPartitions, err := s.broker.GetPartitions(ctx, collection) + s.NoError(err) + s.ElementsMatch(partitions, retPartitions) + s.resetMock() }) - t.Run("collection not exist", func(t *testing.T) { - rc := mocks.NewRootCoord(t) - rc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&milvuspb.ShowPartitionsResponse{ + s.Run("collection_not_exist", func() { + s.rootcoord.EXPECT().ShowPartitions(mock.Anything, 
mock.Anything).Return(&milvuspb.ShowPartitionsResponse{ Status: merr.Status(merr.WrapErrCollectionNotFound("mock")), }, nil) - ctx := context.Background() - broker := &CoordinatorBroker{rootCoord: rc} - _, err := broker.GetPartitions(ctx, collection) - assert.ErrorIs(t, err, merr.ErrCollectionNotFound) + _, err := s.broker.GetPartitions(ctx, collection) + s.Error(err) + s.ErrorIs(err, merr.ErrCollectionNotFound) + s.resetMock() + }) +} + +type CoordinatorBrokerDataCoordSuite struct { + suite.Suite + + datacoord *mocks.MockDataCoordClient + broker *CoordinatorBroker +} + +func (s *CoordinatorBrokerDataCoordSuite) SetupSuite() { + paramtable.Init() +} + +func (s *CoordinatorBrokerDataCoordSuite) SetupTest() { + s.datacoord = mocks.NewMockDataCoordClient(s.T()) + s.broker = NewCoordinatorBroker(s.datacoord, nil) +} + +func (s *CoordinatorBrokerDataCoordSuite) resetMock() { + s.datacoord.AssertExpectations(s.T()) + s.datacoord.ExpectedCalls = nil +} + +func (s *CoordinatorBrokerDataCoordSuite) TestGetRecoveryInfo() { + collectionID := int64(100) + partitionID := int64(1000) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + s.Run("normal_case", func() { + channels := []string{"dml_0"} + segmentIDs := []int64{1, 2, 3} + s.datacoord.EXPECT().GetRecoveryInfo(mock.Anything, mock.Anything). 
+ Return(&datapb.GetRecoveryInfoResponse{ + Channels: lo.Map(channels, func(ch string, _ int) *datapb.VchannelInfo { + return &datapb.VchannelInfo{ + CollectionID: collectionID, + ChannelName: "dml_0", + } + }), + Binlogs: lo.Map(segmentIDs, func(id int64, _ int) *datapb.SegmentBinlogs { + return &datapb.SegmentBinlogs{SegmentID: id} + }), + }, nil) + + vchans, segInfos, err := s.broker.GetRecoveryInfo(ctx, collectionID, partitionID) + s.NoError(err) + s.ElementsMatch(channels, lo.Map(vchans, func(info *datapb.VchannelInfo, _ int) string { + return info.GetChannelName() + })) + s.ElementsMatch(segmentIDs, lo.Map(segInfos, func(info *datapb.SegmentBinlogs, _ int) int64 { + return info.GetSegmentID() + })) + s.resetMock() + }) + + s.Run("datacoord_return_error", func() { + s.datacoord.EXPECT().GetRecoveryInfo(mock.Anything, mock.Anything). + Return(nil, errors.New("mock")) + + _, _, err := s.broker.GetRecoveryInfo(ctx, collectionID, partitionID) + s.Error(err) + s.resetMock() + }) + + s.Run("datacoord_return_failure_status", func() { + s.datacoord.EXPECT().GetRecoveryInfo(mock.Anything, mock.Anything). + Return(&datapb.GetRecoveryInfoResponse{ + Status: merr.Status(errors.New("mocked")), + }, nil) + + _, _, err := s.broker.GetRecoveryInfo(ctx, collectionID, partitionID) + s.Error(err) + s.resetMock() + }) +} + +func (s *CoordinatorBrokerDataCoordSuite) TestGetRecoveryInfoV2() { + collectionID := int64(100) + partitionID := int64(1000) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + s.Run("normal_case", func() { + channels := []string{"dml_0"} + segmentIDs := []int64{1, 2, 3} + s.datacoord.EXPECT().GetRecoveryInfoV2(mock.Anything, mock.Anything). 
+ Return(&datapb.GetRecoveryInfoResponseV2{ + Channels: lo.Map(channels, func(ch string, _ int) *datapb.VchannelInfo { + return &datapb.VchannelInfo{ + CollectionID: collectionID, + ChannelName: "dml_0", + } + }), + Segments: lo.Map(segmentIDs, func(id int64, _ int) *datapb.SegmentInfo { + return &datapb.SegmentInfo{ID: id} + }), + }, nil) + + vchans, segInfos, err := s.broker.GetRecoveryInfoV2(ctx, collectionID, partitionID) + s.NoError(err) + s.ElementsMatch(channels, lo.Map(vchans, func(info *datapb.VchannelInfo, _ int) string { + return info.GetChannelName() + })) + s.ElementsMatch(segmentIDs, lo.Map(segInfos, func(info *datapb.SegmentInfo, _ int) int64 { + return info.GetID() + })) + s.resetMock() + }) + + s.Run("datacoord_return_error", func() { + s.datacoord.EXPECT().GetRecoveryInfoV2(mock.Anything, mock.Anything). + Return(nil, errors.New("mock")) + + _, _, err := s.broker.GetRecoveryInfoV2(ctx, collectionID, partitionID) + s.Error(err) + s.resetMock() }) + + s.Run("datacoord_return_failure_status", func() { + s.datacoord.EXPECT().GetRecoveryInfoV2(mock.Anything, mock.Anything). + Return(&datapb.GetRecoveryInfoResponseV2{ + Status: merr.Status(errors.New("mocked")), + }, nil) + + _, _, err := s.broker.GetRecoveryInfoV2(ctx, collectionID, partitionID) + s.Error(err) + s.resetMock() + }) +} + +func (s *CoordinatorBrokerDataCoordSuite) TestDescribeIndex() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + collectionID := int64(100) + + s.Run("normal_case", func() { + indexIDs := []int64{1, 2} + s.datacoord.EXPECT().DescribeIndex(mock.Anything, mock.Anything). 
+ Return(&indexpb.DescribeIndexResponse{ + Status: merr.Status(nil), + IndexInfos: lo.Map(indexIDs, func(id int64, _ int) *indexpb.IndexInfo { + return &indexpb.IndexInfo{IndexID: id} + }), + }, nil) + infos, err := s.broker.DescribeIndex(ctx, collectionID) + s.NoError(err) + s.ElementsMatch(indexIDs, lo.Map(infos, func(info *indexpb.IndexInfo, _ int) int64 { return info.GetIndexID() })) + s.resetMock() + }) + + s.Run("datacoord_return_error", func() { + s.datacoord.EXPECT().DescribeIndex(mock.Anything, mock.Anything). + Return(nil, errors.New("mock")) + + _, err := s.broker.DescribeIndex(ctx, collectionID) + s.Error(err) + s.resetMock() + }) + + s.Run("datacoord_return_failure_status", func() { + s.datacoord.EXPECT().DescribeIndex(mock.Anything, mock.Anything). + Return(&indexpb.DescribeIndexResponse{ + Status: merr.Status(errors.New("mocked")), + }, nil) + + _, err := s.broker.DescribeIndex(ctx, collectionID) + s.Error(err) + s.resetMock() + }) +} + +func (s *CoordinatorBrokerDataCoordSuite) TestSegmentInfo() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + collectionID := int64(100) + segmentIDs := []int64{10000, 10001, 10002} + + s.Run("normal_case", func() { + s.datacoord.EXPECT().GetSegmentInfo(mock.Anything, mock.Anything). + Return(&datapb.GetSegmentInfoResponse{ + Status: merr.Status(nil), + Infos: lo.Map(segmentIDs, func(id int64, _ int) *datapb.SegmentInfo { + return &datapb.SegmentInfo{ID: id, CollectionID: collectionID} + }), + }, nil) + + resp, err := s.broker.GetSegmentInfo(ctx, segmentIDs...) + s.NoError(err) + s.ElementsMatch(segmentIDs, lo.Map(resp.GetInfos(), func(info *datapb.SegmentInfo, _ int) int64 { + return info.GetID() + })) + s.resetMock() + }) + + s.Run("datacoord_return_error", func() { + s.datacoord.EXPECT().GetSegmentInfo(mock.Anything, mock.Anything). + Return(nil, errors.New("mock")) + + _, err := s.broker.GetSegmentInfo(ctx, segmentIDs...) 
+ s.Error(err) + s.resetMock() + }) + + s.Run("datacoord_return_failure_status", func() { + s.datacoord.EXPECT().GetSegmentInfo(mock.Anything, mock.Anything). + Return(&datapb.GetSegmentInfoResponse{Status: merr.Status(errors.New("mocked"))}, nil) + + _, err := s.broker.GetSegmentInfo(ctx, segmentIDs...) + s.Error(err) + s.resetMock() + }) +} + +func (s *CoordinatorBrokerDataCoordSuite) TestGetIndexInfo() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + collectionID := int64(100) + segmentID := int64(10000) + + s.Run("normal_case", func() { + indexIDs := []int64{1, 2, 3} + s.datacoord.EXPECT().GetIndexInfos(mock.Anything, mock.Anything). + Return(&indexpb.GetIndexInfoResponse{ + Status: merr.Status(nil), + SegmentInfo: map[int64]*indexpb.SegmentInfo{ + segmentID: { + SegmentID: segmentID, + IndexInfos: lo.Map(indexIDs, func(id int64, _ int) *indexpb.IndexFilePathInfo { + return &indexpb.IndexFilePathInfo{IndexID: id} + }), + }, + }, + }, nil) + + infos, err := s.broker.GetIndexInfo(ctx, collectionID, segmentID) + s.NoError(err) + s.ElementsMatch(indexIDs, lo.Map(infos, func(info *querypb.FieldIndexInfo, _ int) int64 { + return info.GetIndexID() + })) + s.resetMock() + }) + + s.Run("datacoord_return_error", func() { + s.datacoord.EXPECT().GetIndexInfos(mock.Anything, mock.Anything). + Return(nil, errors.New("mock")) + + _, err := s.broker.GetIndexInfo(ctx, collectionID, segmentID) + s.Error(err) + s.resetMock() + }) + + s.Run("datacoord_return_failure_status", func() { + s.datacoord.EXPECT().GetIndexInfos(mock.Anything, mock.Anything). 
+ Return(&indexpb.GetIndexInfoResponse{Status: merr.Status(errors.New("mock"))}, nil) + + _, err := s.broker.GetIndexInfo(ctx, collectionID, segmentID) + s.Error(err) + s.resetMock() + }) +} + +func TestCoordinatorBroker(t *testing.T) { + suite.Run(t, new(CoordinatorBrokerRootCoordSuite)) + suite.Run(t, new(CoordinatorBrokerDataCoordSuite)) } diff --git a/internal/querycoordv2/meta/leader_view_manager.go b/internal/querycoordv2/meta/leader_view_manager.go index 8d5078492b25f..2d58cce3fecc0 100644 --- a/internal/querycoordv2/meta/leader_view_manager.go +++ b/internal/querycoordv2/meta/leader_view_manager.go @@ -229,3 +229,19 @@ func (mgr *LeaderViewManager) GetLeadersByShard(shard string) map[int64]*LeaderV } return ret } + +func (mgr *LeaderViewManager) GetLatestLeadersByReplicaShard(replica *Replica, shard string) *LeaderView { + mgr.rwmutex.RLock() + defer mgr.rwmutex.RUnlock() + + var ret *LeaderView + for _, views := range mgr.views { + view, ok := views[shard] + if ok && + replica.Contains(view.ID) && + (ret == nil || ret.Version < view.Version) { + ret = view + } + } + return ret +} diff --git a/internal/querycoordv2/meta/replica_manager.go b/internal/querycoordv2/meta/replica_manager.go index 0bc7de359d31d..e788075dfeba7 100644 --- a/internal/querycoordv2/meta/replica_manager.go +++ b/internal/querycoordv2/meta/replica_manager.go @@ -28,16 +28,15 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/typeutil" - . 
"github.com/milvus-io/milvus/pkg/util/typeutil" ) type Replica struct { *querypb.Replica - nodes UniqueSet // a helper field for manipulating replica's Nodes slice field + nodes typeutil.UniqueSet // a helper field for manipulating replica's Nodes slice field rwmutex sync.RWMutex } -func NewReplica(replica *querypb.Replica, nodes UniqueSet) *Replica { +func NewReplica(replica *querypb.Replica, nodes typeutil.UniqueSet) *Replica { return &Replica{ Replica: replica, nodes: nodes, @@ -54,7 +53,7 @@ func (replica *Replica) AddNode(nodes ...int64) { func (replica *Replica) GetNodes() []int64 { replica.rwmutex.RLock() defer replica.rwmutex.RUnlock() - if replica != nil { + if replica.nodes != nil { return replica.nodes.Collect() } return nil @@ -63,7 +62,7 @@ func (replica *Replica) GetNodes() []int64 { func (replica *Replica) Len() int { replica.rwmutex.RLock() defer replica.rwmutex.RUnlock() - if replica != nil { + if replica.nodes != nil { return replica.nodes.Len() } @@ -73,7 +72,7 @@ func (replica *Replica) Len() int { func (replica *Replica) Contains(node int64) bool { replica.rwmutex.RLock() defer replica.rwmutex.RUnlock() - if replica != nil { + if replica.nodes != nil { return replica.nodes.Contain(node) } @@ -92,7 +91,7 @@ func (replica *Replica) Clone() *Replica { defer replica.rwmutex.RUnlock() return &Replica{ Replica: proto.Clone(replica.Replica).(*querypb.Replica), - nodes: NewUniqueSet(replica.Replica.Nodes...), + nodes: typeutil.NewUniqueSet(replica.Replica.Nodes...), } } @@ -100,7 +99,7 @@ type ReplicaManager struct { rwmutex sync.RWMutex idAllocator func() (int64, error) - replicas map[UniqueID]*Replica + replicas map[typeutil.UniqueID]*Replica catalog metastore.QueryCoordCatalog } @@ -128,7 +127,7 @@ func (m *ReplicaManager) Recover(collections []int64) error { if collectionSet.Contain(replica.GetCollectionID()) { m.replicas[replica.GetID()] = &Replica{ Replica: replica, - nodes: NewUniqueSet(replica.GetNodes()...), + nodes: 
typeutil.NewUniqueSet(replica.GetNodes()...), } log.Info("recover replica", zap.Int64("collectionID", replica.GetCollectionID()), @@ -150,7 +149,7 @@ func (m *ReplicaManager) Recover(collections []int64) error { return nil } -func (m *ReplicaManager) Get(id UniqueID) *Replica { +func (m *ReplicaManager) Get(id typeutil.UniqueID) *Replica { m.rwmutex.RLock() defer m.rwmutex.RUnlock() @@ -180,7 +179,7 @@ func (m *ReplicaManager) Put(replicas ...*Replica) error { return m.put(replicas...) } -func (m *ReplicaManager) spawn(collectionID UniqueID, rgName string) (*Replica, error) { +func (m *ReplicaManager) spawn(collectionID typeutil.UniqueID, rgName string) (*Replica, error) { id, err := m.idAllocator() if err != nil { return nil, err @@ -191,7 +190,7 @@ func (m *ReplicaManager) spawn(collectionID UniqueID, rgName string) (*Replica, CollectionID: collectionID, ResourceGroup: rgName, }, - nodes: make(UniqueSet), + nodes: make(typeutil.UniqueSet), }, nil } @@ -208,7 +207,7 @@ func (m *ReplicaManager) put(replicas ...*Replica) error { // RemoveCollection removes replicas of given collection, // returns error if failed to remove replica from KV -func (m *ReplicaManager) RemoveCollection(collectionID UniqueID) error { +func (m *ReplicaManager) RemoveCollection(collectionID typeutil.UniqueID) error { m.rwmutex.Lock() defer m.rwmutex.Unlock() @@ -224,7 +223,7 @@ func (m *ReplicaManager) RemoveCollection(collectionID UniqueID) error { return nil } -func (m *ReplicaManager) GetByCollection(collectionID UniqueID) []*Replica { +func (m *ReplicaManager) GetByCollection(collectionID typeutil.UniqueID) []*Replica { m.rwmutex.RLock() defer m.rwmutex.RUnlock() @@ -238,7 +237,7 @@ func (m *ReplicaManager) GetByCollection(collectionID UniqueID) []*Replica { return replicas } -func (m *ReplicaManager) GetByCollectionAndNode(collectionID, nodeID UniqueID) *Replica { +func (m *ReplicaManager) GetByCollectionAndNode(collectionID, nodeID typeutil.UniqueID) *Replica { m.rwmutex.RLock() defer 
m.rwmutex.RUnlock() @@ -279,7 +278,7 @@ func (m *ReplicaManager) GetByResourceGroup(rgName string) []*Replica { return ret } -func (m *ReplicaManager) AddNode(replicaID UniqueID, nodes ...UniqueID) error { +func (m *ReplicaManager) AddNode(replicaID typeutil.UniqueID, nodes ...typeutil.UniqueID) error { m.rwmutex.Lock() defer m.rwmutex.Unlock() @@ -293,7 +292,7 @@ func (m *ReplicaManager) AddNode(replicaID UniqueID, nodes ...UniqueID) error { return m.put(replica) } -func (m *ReplicaManager) RemoveNode(replicaID UniqueID, nodes ...UniqueID) error { +func (m *ReplicaManager) RemoveNode(replicaID typeutil.UniqueID, nodes ...typeutil.UniqueID) error { m.rwmutex.Lock() defer m.rwmutex.Unlock() @@ -307,7 +306,7 @@ func (m *ReplicaManager) RemoveNode(replicaID UniqueID, nodes ...UniqueID) error return m.put(replica) } -func (m *ReplicaManager) GetResourceGroupByCollection(collection UniqueID) typeutil.Set[string] { +func (m *ReplicaManager) GetResourceGroupByCollection(collection typeutil.UniqueID) typeutil.Set[string] { m.rwmutex.Lock() defer m.rwmutex.Unlock() diff --git a/internal/querycoordv2/meta/resource_manager.go b/internal/querycoordv2/meta/resource_manager.go index e859e62acfa3f..58ac1f2f70e97 100644 --- a/internal/querycoordv2/meta/resource_manager.go +++ b/internal/querycoordv2/meta/resource_manager.go @@ -29,7 +29,6 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/typeutil" - . 
"github.com/milvus-io/milvus/pkg/util/typeutil" ) var ( @@ -56,7 +55,7 @@ var DefaultResourceGroupName = "__default_resource_group" var DefaultResourceGroupCapacity = 1000000 type ResourceGroup struct { - nodes UniqueSet + nodes typeutil.UniqueSet capacity int } @@ -481,7 +480,6 @@ func (rm *ResourceManager) HandleNodeDown(node int64) (string, error) { rgName, err := rm.findResourceGroupByNode(node) if err != nil { return "", ErrNodeNotAssignToRG - } newNodes := []int64{} @@ -528,7 +526,7 @@ func (rm *ResourceManager) TransferNode(from string, to string, numNode int) ([] return nil, ErrNodeNotEnough } - //todo: a better way to choose a node with least balance cost + // todo: a better way to choose a node with least balance cost movedNodes, err := rm.transferNodeInStore(from, to, numNode) if err != nil { return nil, err @@ -627,7 +625,7 @@ func (rm *ResourceManager) AutoRecoverResourceGroup(rgName string) ([]int64, err lackNodesNum := rm.groups[rgName].LackOfNodes() nodesInDefault := rm.groups[DefaultResourceGroupName].GetNodes() for i := 0; i < len(nodesInDefault) && i < lackNodesNum; i++ { - //todo: a better way to choose a node with least balance cost + // todo: a better way to choose a node with least balance cost node := nodesInDefault[i] err := rm.unassignNode(DefaultResourceGroupName, node) if err != nil { diff --git a/internal/querycoordv2/meta/target_manager.go b/internal/querycoordv2/meta/target_manager.go index 2d2c9d2b3c2b8..c45c23a2f3f52 100644 --- a/internal/querycoordv2/meta/target_manager.go +++ b/internal/querycoordv2/meta/target_manager.go @@ -20,14 +20,15 @@ import ( "context" "sync" + "github.com/cockroachdb/errors" + "github.com/samber/lo" + "go.uber.org/zap" + "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/retry" "github.com/milvus-io/milvus/pkg/util/typeutil" - 
"github.com/samber/lo" - "go.uber.org/zap" - "google.golang.org/grpc/codes" ) type TargetScope = int32 @@ -191,7 +192,7 @@ func (mgr *TargetManager) PullNextTargetV2(broker Broker, collectionID int64, ch vChannelInfos, segmentInfos, err := broker.GetRecoveryInfoV2(context.TODO(), collectionID) if err != nil { // if meet rpc error, for compatibility with previous versions, try pull next target v1 - if funcutil.IsGrpcErr(err, codes.Unimplemented) { + if errors.Is(err, merr.ErrServiceUnimplemented) { segments, dmChannels, err = mgr.PullNextTargetV1(broker, collectionID, chosenPartitionIDs...) return err } @@ -324,7 +325,8 @@ func (mgr *TargetManager) getTarget(scope TargetScope) *target { } func (mgr *TargetManager) GetStreamingSegmentsByCollection(collectionID int64, - scope TargetScope) typeutil.UniqueSet { + scope TargetScope, +) typeutil.UniqueSet { mgr.rwMutex.RLock() defer mgr.rwMutex.RUnlock() @@ -345,7 +347,8 @@ func (mgr *TargetManager) GetStreamingSegmentsByCollection(collectionID int64, func (mgr *TargetManager) GetStreamingSegmentsByChannel(collectionID int64, channelName string, - scope TargetScope) typeutil.UniqueSet { + scope TargetScope, +) typeutil.UniqueSet { mgr.rwMutex.RLock() defer mgr.rwMutex.RUnlock() @@ -367,7 +370,8 @@ func (mgr *TargetManager) GetStreamingSegmentsByChannel(collectionID int64, } func (mgr *TargetManager) GetHistoricalSegmentsByCollection(collectionID int64, - scope TargetScope) map[int64]*datapb.SegmentInfo { + scope TargetScope, +) map[int64]*datapb.SegmentInfo { mgr.rwMutex.RLock() defer mgr.rwMutex.RUnlock() @@ -382,7 +386,8 @@ func (mgr *TargetManager) GetHistoricalSegmentsByCollection(collectionID int64, func (mgr *TargetManager) GetHistoricalSegmentsByChannel(collectionID int64, channelName string, - scope TargetScope) map[int64]*datapb.SegmentInfo { + scope TargetScope, +) map[int64]*datapb.SegmentInfo { mgr.rwMutex.RLock() defer mgr.rwMutex.RUnlock() @@ -405,7 +410,8 @@ func (mgr *TargetManager) 
GetHistoricalSegmentsByChannel(collectionID int64, func (mgr *TargetManager) GetDroppedSegmentsByChannel(collectionID int64, channelName string, - scope TargetScope) []int64 { + scope TargetScope, +) []int64 { mgr.rwMutex.RLock() defer mgr.rwMutex.RUnlock() @@ -425,7 +431,8 @@ func (mgr *TargetManager) GetDroppedSegmentsByChannel(collectionID int64, } func (mgr *TargetManager) GetHistoricalSegmentsByPartition(collectionID int64, - partitionID int64, scope TargetScope) map[int64]*datapb.SegmentInfo { + partitionID int64, scope TargetScope, +) map[int64]*datapb.SegmentInfo { mgr.rwMutex.RLock() defer mgr.rwMutex.RUnlock() diff --git a/internal/querycoordv2/meta/target_manager_test.go b/internal/querycoordv2/meta/target_manager_test.go index b57449c431a6d..41443d3cb64ce 100644 --- a/internal/querycoordv2/meta/target_manager_test.go +++ b/internal/querycoordv2/meta/target_manager_test.go @@ -35,6 +35,7 @@ import ( . "github.com/milvus-io/milvus/internal/querycoordv2/params" "github.com/milvus-io/milvus/internal/querycoordv2/session" "github.com/milvus-io/milvus/pkg/util/etcd" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -138,7 +139,8 @@ func (suite *TargetManagerSuite) SetupTest() { suite.meta.PutCollection(&Collection{ CollectionLoadInfo: &querypb.CollectionLoadInfo{ CollectionID: collection, - ReplicaNumber: 1}, + ReplicaNumber: 1, + }, }) for _, partition := range suite.partitions[collection] { suite.meta.PutPartition(&Partition{ @@ -183,7 +185,8 @@ func (suite *TargetManagerSuite) TestUpdateNextTarget() { suite.meta.PutCollection(&Collection{ CollectionLoadInfo: &querypb.CollectionLoadInfo{ CollectionID: collectionID, - ReplicaNumber: 1}, + ReplicaNumber: 1, + }, }) suite.meta.PutPartition(&Partition{ PartitionLoadInfo: &querypb.PartitionLoadInfo{ @@ -236,7 +239,8 @@ func (suite *TargetManagerSuite) TestUpdateNextTarget() { suite.broker.ExpectedCalls = nil // 
test getRecoveryInfoV2 failed , then back to getRecoveryInfo succeed - suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return(nil, nil, status.Errorf(codes.Unimplemented, "fake not found")) + suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return( + nil, nil, merr.WrapErrServiceUnimplemented(status.Errorf(codes.Unimplemented, "fake not found"))) suite.broker.EXPECT().GetPartitions(mock.Anything, mock.Anything).Return([]int64{1}, nil) suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, collectionID, int64(1)).Return(nextTargetChannels, nextTargetBinlogs, nil) err := suite.mgr.UpdateCollectionNextTarget(collectionID) @@ -251,7 +255,6 @@ func (suite *TargetManagerSuite) TestUpdateNextTarget() { err = suite.mgr.UpdateCollectionNextTarget(collectionID) suite.NoError(err) - } func (suite *TargetManagerSuite) TestRemovePartition() { @@ -365,7 +368,8 @@ func (suite *TargetManagerSuite) TestGetSegmentByChannel() { suite.meta.PutCollection(&Collection{ CollectionLoadInfo: &querypb.CollectionLoadInfo{ CollectionID: collectionID, - ReplicaNumber: 1}, + ReplicaNumber: 1, + }, }) suite.meta.PutPartition(&Partition{ PartitionLoadInfo: &querypb.PartitionLoadInfo{ diff --git a/internal/querycoordv2/mocks/mock_querynode.go b/internal/querycoordv2/mocks/mock_querynode.go index 18c4e6e7fc982..039d03ee6b5a4 100644 --- a/internal/querycoordv2/mocks/mock_querynode.go +++ b/internal/querycoordv2/mocks/mock_querynode.go @@ -689,6 +689,92 @@ func (_c *MockQueryNodeServer_QuerySegments_Call) RunAndReturn(run func(context. 
return _c } +// QueryStream provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryNodeServer) QueryStream(_a0 *querypb.QueryRequest, _a1 querypb.QueryNode_QueryStreamServer) error { + ret := _m.Called(_a0, _a1) + + var r0 error + if rf, ok := ret.Get(0).(func(*querypb.QueryRequest, querypb.QueryNode_QueryStreamServer) error); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockQueryNodeServer_QueryStream_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'QueryStream' +type MockQueryNodeServer_QueryStream_Call struct { + *mock.Call +} + +// QueryStream is a helper method to define mock.On call +// - _a0 *querypb.QueryRequest +// - _a1 querypb.QueryNode_QueryStreamServer +func (_e *MockQueryNodeServer_Expecter) QueryStream(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_QueryStream_Call { + return &MockQueryNodeServer_QueryStream_Call{Call: _e.mock.On("QueryStream", _a0, _a1)} +} + +func (_c *MockQueryNodeServer_QueryStream_Call) Run(run func(_a0 *querypb.QueryRequest, _a1 querypb.QueryNode_QueryStreamServer)) *MockQueryNodeServer_QueryStream_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*querypb.QueryRequest), args[1].(querypb.QueryNode_QueryStreamServer)) + }) + return _c +} + +func (_c *MockQueryNodeServer_QueryStream_Call) Return(_a0 error) *MockQueryNodeServer_QueryStream_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockQueryNodeServer_QueryStream_Call) RunAndReturn(run func(*querypb.QueryRequest, querypb.QueryNode_QueryStreamServer) error) *MockQueryNodeServer_QueryStream_Call { + _c.Call.Return(run) + return _c +} + +// QueryStreamSegments provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryNodeServer) QueryStreamSegments(_a0 *querypb.QueryRequest, _a1 querypb.QueryNode_QueryStreamSegmentsServer) error { + ret := _m.Called(_a0, _a1) + + var r0 error + if rf, ok := ret.Get(0).(func(*querypb.QueryRequest, 
querypb.QueryNode_QueryStreamSegmentsServer) error); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockQueryNodeServer_QueryStreamSegments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'QueryStreamSegments' +type MockQueryNodeServer_QueryStreamSegments_Call struct { + *mock.Call +} + +// QueryStreamSegments is a helper method to define mock.On call +// - _a0 *querypb.QueryRequest +// - _a1 querypb.QueryNode_QueryStreamSegmentsServer +func (_e *MockQueryNodeServer_Expecter) QueryStreamSegments(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_QueryStreamSegments_Call { + return &MockQueryNodeServer_QueryStreamSegments_Call{Call: _e.mock.On("QueryStreamSegments", _a0, _a1)} +} + +func (_c *MockQueryNodeServer_QueryStreamSegments_Call) Run(run func(_a0 *querypb.QueryRequest, _a1 querypb.QueryNode_QueryStreamSegmentsServer)) *MockQueryNodeServer_QueryStreamSegments_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*querypb.QueryRequest), args[1].(querypb.QueryNode_QueryStreamSegmentsServer)) + }) + return _c +} + +func (_c *MockQueryNodeServer_QueryStreamSegments_Call) Return(_a0 error) *MockQueryNodeServer_QueryStreamSegments_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockQueryNodeServer_QueryStreamSegments_Call) RunAndReturn(run func(*querypb.QueryRequest, querypb.QueryNode_QueryStreamSegmentsServer) error) *MockQueryNodeServer_QueryStreamSegments_Call { + _c.Call.Return(run) + return _c +} + // ReleaseCollection provides a mock function with given fields: _a0, _a1 func (_m *MockQueryNodeServer) ReleaseCollection(_a0 context.Context, _a1 *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) { ret := _m.Called(_a0, _a1) diff --git a/internal/querycoordv2/mocks/querynode.go b/internal/querycoordv2/mocks/querynode.go index 013b267db1cb6..0151db85978f4 100644 --- a/internal/querycoordv2/mocks/querynode.go +++ b/internal/querycoordv2/mocks/querynode.go 
@@ -82,7 +82,7 @@ func (node *MockQueryNode) Start() error { err = node.server.Serve(lis) }() - successStatus := merr.Status(nil) + successStatus := merr.Success() node.EXPECT().GetDataDistribution(mock.Anything, mock.Anything).Return(&querypb.GetDataDistributionResponse{ Status: successStatus, NodeID: node.ID, @@ -117,7 +117,9 @@ func (node *MockQueryNode) Start() error { case <-node.ctx.Done(): return nil default: - return &milvuspb.ComponentStates{} + return &milvuspb.ComponentStates{ + Status: successStatus, + } } }, func(context.Context, *milvuspb.GetComponentStatesRequest) error { select { diff --git a/internal/querycoordv2/observers/collection_observer.go b/internal/querycoordv2/observers/collection_observer.go index 0fa9a8b1a51af..685074749690d 100644 --- a/internal/querycoordv2/observers/collection_observer.go +++ b/internal/querycoordv2/observers/collection_observer.go @@ -34,7 +34,8 @@ import ( ) type CollectionObserver struct { - stopCh chan struct{} + cancel context.CancelFunc + wg sync.WaitGroup dist *meta.DistributionManager meta *meta.Meta @@ -56,7 +57,6 @@ func NewCollectionObserver( checherController *checkers.CheckerController, ) *CollectionObserver { return &CollectionObserver{ - stopCh: make(chan struct{}), dist: dist, meta: meta, targetMgr: targetMgr, @@ -67,23 +67,25 @@ func NewCollectionObserver( } } -func (ob *CollectionObserver) Start(ctx context.Context) { +func (ob *CollectionObserver) Start() { + ctx, cancel := context.WithCancel(context.Background()) + ob.cancel = cancel + const observePeriod = time.Second + ob.wg.Add(1) go func() { + defer ob.wg.Done() + ticker := time.NewTicker(observePeriod) defer ticker.Stop() for { select { case <-ctx.Done(): - log.Info("CollectionObserver stopped due to context canceled") - return - - case <-ob.stopCh: log.Info("CollectionObserver stopped") return case <-ticker.C: - ob.Observe() + ob.Observe(ctx) } } }() @@ -91,13 +93,16 @@ func (ob *CollectionObserver) Start(ctx context.Context) { func (ob 
*CollectionObserver) Stop() { ob.stopOnce.Do(func() { - close(ob.stopCh) + if ob.cancel != nil { + ob.cancel() + } + ob.wg.Wait() }) } -func (ob *CollectionObserver) Observe() { +func (ob *CollectionObserver) Observe(ctx context.Context) { ob.observeTimeout() - ob.observeLoadStatus() + ob.observeLoadStatus(ctx) } func (ob *CollectionObserver) observeTimeout() { @@ -116,11 +121,7 @@ func (ob *CollectionObserver) observeTimeout() { ob.targetMgr.RemoveCollection(collection.GetCollectionID()) } - partitions := utils.GroupPartitionsByCollection( - ob.meta.CollectionManager.GetAllPartitions()) - if len(partitions) > 0 { - log.Info("observes partitions timeout", zap.Int("partitionNum", len(partitions))) - } + partitions := utils.GroupPartitionsByCollection(ob.meta.CollectionManager.GetAllPartitions()) for collection, partitions := range partitions { for _, partition := range partitions { if partition.GetStatus() != querypb.LoadStatus_Loading || @@ -153,7 +154,7 @@ func (ob *CollectionObserver) readyToObserve(collectionID int64) bool { return metaExist && targetExist } -func (ob *CollectionObserver) observeLoadStatus() { +func (ob *CollectionObserver) observeLoadStatus(ctx context.Context) { partitions := ob.meta.CollectionManager.GetAllPartitions() if len(partitions) > 0 { log.Info("observe partitions status", zap.Int("partitionNum", len(partitions))) @@ -165,7 +166,7 @@ func (ob *CollectionObserver) observeLoadStatus() { } if ob.readyToObserve(partition.CollectionID) { replicaNum := ob.meta.GetReplicaNumber(partition.GetCollectionID()) - ob.observePartitionLoadStatus(partition, replicaNum) + ob.observePartitionLoadStatus(ctx, partition, replicaNum) loading = true } } @@ -175,7 +176,7 @@ func (ob *CollectionObserver) observeLoadStatus() { } } -func (ob *CollectionObserver) observePartitionLoadStatus(partition *meta.Partition, replicaNum int32) { +func (ob *CollectionObserver) observePartitionLoadStatus(ctx context.Context, partition *meta.Partition, replicaNum int32) { 
log := log.With( zap.Int64("collectionID", partition.GetCollectionID()), zap.Int64("partitionID", partition.GetPartitionID()), @@ -225,7 +226,7 @@ func (ob *CollectionObserver) observePartitionLoadStatus(partition *meta.Partiti } ob.partitionLoadedCount[partition.GetPartitionID()] = loadedCount - if loadPercentage == 100 && ob.targetObserver.Check(partition.GetCollectionID()) && ob.leaderObserver.CheckTargetVersion(partition.GetCollectionID()) { + if loadPercentage == 100 && ob.targetObserver.Check(ctx, partition.GetCollectionID()) && ob.leaderObserver.CheckTargetVersion(ctx, partition.GetCollectionID()) { delete(ob.partitionLoadedCount, partition.GetPartitionID()) } collectionPercentage, err := ob.meta.CollectionManager.UpdateLoadPercent(partition.PartitionID, loadPercentage) diff --git a/internal/querycoordv2/observers/collection_observer_test.go b/internal/querycoordv2/observers/collection_observer_test.go index d213b3254b901..9a7b14744d29a 100644 --- a/internal/querycoordv2/observers/collection_observer_test.go +++ b/internal/querycoordv2/observers/collection_observer_test.go @@ -17,7 +17,6 @@ package observers import ( - "context" "testing" "time" @@ -26,7 +25,6 @@ import ( "github.com/stretchr/testify/suite" clientv3 "go.etcd.io/etcd/client/v3" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/kv" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" "github.com/milvus-io/milvus/internal/metastore" @@ -39,6 +37,7 @@ import ( "github.com/milvus-io/milvus/internal/querycoordv2/session" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/etcd" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -198,9 +197,7 @@ func (suite *CollectionObserverSuite) SetupTest() { mockCluster := session.NewMockCluster(suite.T()) suite.leaderObserver = NewLeaderObserver(suite.dist, suite.meta, suite.targetMgr, suite.broker, mockCluster) - 
mockCluster.EXPECT().SyncDistribution(mock.Anything, mock.Anything, mock.Anything).Return(&commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil).Maybe() + mockCluster.EXPECT().SyncDistribution(mock.Anything, mock.Anything, mock.Anything).Return(merr.Success(), nil).Maybe() // Test object suite.ob = NewCollectionObserver( @@ -215,9 +212,9 @@ func (suite *CollectionObserverSuite) SetupTest() { for _, collection := range suite.collections { suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil).Maybe() } - suite.targetObserver.Start(context.Background()) - suite.leaderObserver.Start(context.TODO()) - suite.ob.Start(context.Background()) + suite.targetObserver.Start() + suite.leaderObserver.Start() + suite.ob.Start() suite.loadAll() } diff --git a/internal/querycoordv2/observers/leader_observer.go b/internal/querycoordv2/observers/leader_observer.go index ceb50d25e1c17..0e01477fbdfa9 100644 --- a/internal/querycoordv2/observers/leader_observer.go +++ b/internal/querycoordv2/observers/leader_observer.go @@ -21,6 +21,7 @@ import ( "sync" "time" + "github.com/samber/lo" "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" @@ -30,7 +31,7 @@ import ( "github.com/milvus-io/milvus/internal/querycoordv2/utils" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/commonpbutil" - "github.com/samber/lo" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) const ( @@ -41,7 +42,7 @@ const ( // LeaderObserver is to sync the distribution with leader type LeaderObserver struct { wg sync.WaitGroup - closeCh chan struct{} + cancel context.CancelFunc dist *meta.DistributionManager meta *meta.Meta target *meta.TargetManager @@ -52,7 +53,10 @@ type LeaderObserver struct { stopOnce sync.Once } -func (o *LeaderObserver) Start(ctx context.Context) { +func (o *LeaderObserver) Start() { + ctx, cancel := context.WithCancel(context.Background()) + o.cancel = cancel + o.wg.Add(1) go func() { 
defer o.wg.Done() @@ -60,12 +64,10 @@ func (o *LeaderObserver) Start(ctx context.Context) { defer ticker.Stop() for { select { - case <-o.closeCh: - log.Info("stop leader observer") - return case <-ctx.Done(): - log.Info("stop leader observer due to ctx done") + log.Info("stop leader observer") return + case req := <-o.manualCheck: log.Info("triggering manual check") ret := o.observeCollection(ctx, req.CollectionID) @@ -81,7 +83,9 @@ func (o *LeaderObserver) Start(ctx context.Context) { func (o *LeaderObserver) Stop() { o.stopOnce.Do(func() { - close(o.closeCh) + if o.cancel != nil { + o.cancel() + } o.wg.Wait() }) } @@ -133,13 +137,20 @@ func (o *LeaderObserver) observeCollection(ctx context.Context, collection int64 return result } -func (ob *LeaderObserver) CheckTargetVersion(collectionID int64) bool { +func (ob *LeaderObserver) CheckTargetVersion(ctx context.Context, collectionID int64) bool { notifier := make(chan bool) - ob.manualCheck <- checkRequest{ - CollectionID: collectionID, - Notifier: notifier, + select { + case ob.manualCheck <- checkRequest{CollectionID: collectionID, Notifier: notifier}: + case <-ctx.Done(): + return false + } + + select { + case result := <-notifier: + return result + case <-ctx.Done(): + return false } - return <-notifier } func (o *LeaderObserver) checkNeedUpdateTargetVersion(ctx context.Context, leaderView *meta.LeaderView) *querypb.SyncAction { @@ -278,6 +289,8 @@ func (o *LeaderObserver) sync(ctx context.Context, replicaID int64, leaderView * }, Version: time.Now().UnixNano(), } + ctx, cancel := context.WithTimeout(ctx, paramtable.Get().QueryCoordCfg.SegmentTaskTimeout.GetAsDuration(time.Millisecond)) + defer cancel() resp, err := o.cluster.SyncDistribution(ctx, leaderView.ID, req) if err != nil { log.Warn("failed to sync distribution", zap.Error(err)) @@ -300,7 +313,6 @@ func NewLeaderObserver( cluster session.Cluster, ) *LeaderObserver { return &LeaderObserver{ - closeCh: make(chan struct{}), dist: dist, meta: meta, 
target: targetMgr, diff --git a/internal/querycoordv2/observers/leader_observer_test.go b/internal/querycoordv2/observers/leader_observer_test.go index ab240a41d7b69..c2f1f771d2ece 100644 --- a/internal/querycoordv2/observers/leader_observer_test.go +++ b/internal/querycoordv2/observers/leader_observer_test.go @@ -158,7 +158,7 @@ func (suite *LeaderObserverTestSuite) TestSyncLoadedSegments() { } called := atomic.NewBool(false) - suite.mockCluster.EXPECT().SyncDistribution(context.TODO(), int64(2), + suite.mockCluster.EXPECT().SyncDistribution(mock.Anything, int64(2), mock.AnythingOfType("*querypb.SyncDistributionRequest")). Run(func(ctx context.Context, nodeID int64, req *querypb.SyncDistributionRequest) { assert.ElementsMatch(suite.T(), []*querypb.SyncDistributionRequest{req}, @@ -167,7 +167,7 @@ func (suite *LeaderObserverTestSuite) TestSyncLoadedSegments() { }). Return(&commonpb.Status{}, nil) - observer.Start(context.TODO()) + observer.Start() suite.Eventually( func() bool { @@ -249,7 +249,7 @@ func (suite *LeaderObserverTestSuite) TestIgnoreSyncLoadedSegments() { } } called := atomic.NewBool(false) - suite.mockCluster.EXPECT().SyncDistribution(context.TODO(), int64(2), mock.AnythingOfType("*querypb.SyncDistributionRequest")). + suite.mockCluster.EXPECT().SyncDistribution(mock.Anything, int64(2), mock.AnythingOfType("*querypb.SyncDistributionRequest")). Run(func(ctx context.Context, nodeID int64, req *querypb.SyncDistributionRequest) { assert.ElementsMatch(suite.T(), []*querypb.SyncDistributionRequest{req}, []*querypb.SyncDistributionRequest{expectReqeustFunc(req.GetVersion())}) @@ -257,7 +257,7 @@ func (suite *LeaderObserverTestSuite) TestIgnoreSyncLoadedSegments() { }). 
Return(&commonpb.Status{}, nil) - observer.Start(context.TODO()) + observer.Start() suite.Eventually( func() bool { @@ -303,7 +303,7 @@ func (suite *LeaderObserverTestSuite) TestIgnoreBalancedSegment() { } leaderView.TargetVersion = observer.target.GetCollectionTargetVersion(1, meta.CurrentTarget) observer.dist.LeaderViewManager.Update(2, leaderView) - observer.Start(context.TODO()) + observer.Start() // Nothing should happen time.Sleep(2 * time.Second) @@ -383,7 +383,7 @@ func (suite *LeaderObserverTestSuite) TestSyncLoadedSegmentsWithReplicas() { } } called := atomic.NewBool(false) - suite.mockCluster.EXPECT().SyncDistribution(context.TODO(), int64(2), + suite.mockCluster.EXPECT().SyncDistribution(mock.Anything, int64(2), mock.AnythingOfType("*querypb.SyncDistributionRequest")). Run(func(ctx context.Context, nodeID int64, req *querypb.SyncDistributionRequest) { assert.ElementsMatch(suite.T(), []*querypb.SyncDistributionRequest{req}, @@ -392,7 +392,7 @@ func (suite *LeaderObserverTestSuite) TestSyncLoadedSegmentsWithReplicas() { }). Return(&commonpb.Status{}, nil) - observer.Start(context.TODO()) + observer.Start() suite.Eventually( func() bool { @@ -453,7 +453,7 @@ func (suite *LeaderObserverTestSuite) TestSyncRemovedSegments() { } } ch := make(chan struct{}) - suite.mockCluster.EXPECT().SyncDistribution(context.TODO(), int64(2), + suite.mockCluster.EXPECT().SyncDistribution(mock.Anything, int64(2), mock.AnythingOfType("*querypb.SyncDistributionRequest")). Run(func(ctx context.Context, nodeID int64, req *querypb.SyncDistributionRequest) { assert.ElementsMatch(suite.T(), []*querypb.SyncDistributionRequest{req}, @@ -462,7 +462,7 @@ func (suite *LeaderObserverTestSuite) TestSyncRemovedSegments() { }). 
Return(&commonpb.Status{}, nil) - observer.Start(context.TODO()) + observer.Start() select { case <-ch: @@ -471,7 +471,6 @@ func (suite *LeaderObserverTestSuite) TestSyncRemovedSegments() { } func (suite *LeaderObserverTestSuite) TestIgnoreSyncRemovedSegments() { - observer := suite.observer observer.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1)) observer.meta.CollectionManager.PutPartition(utils.CreateTestPartition(1, 1)) @@ -523,7 +522,7 @@ func (suite *LeaderObserverTestSuite) TestIgnoreSyncRemovedSegments() { } } called := atomic.NewBool(false) - suite.mockCluster.EXPECT().SyncDistribution(context.TODO(), int64(2), mock.AnythingOfType("*querypb.SyncDistributionRequest")). + suite.mockCluster.EXPECT().SyncDistribution(mock.Anything, int64(2), mock.AnythingOfType("*querypb.SyncDistributionRequest")). Run(func(ctx context.Context, nodeID int64, req *querypb.SyncDistributionRequest) { assert.ElementsMatch(suite.T(), []*querypb.SyncDistributionRequest{req}, []*querypb.SyncDistributionRequest{expectReqeustFunc(req.GetVersion())}) @@ -531,7 +530,7 @@ func (suite *LeaderObserverTestSuite) TestIgnoreSyncRemovedSegments() { }). 
Return(&commonpb.Status{}, nil) - observer.Start(context.TODO()) + observer.Start() suite.Eventually(func() bool { return called.Load() }, @@ -592,6 +591,44 @@ func (suite *LeaderObserverTestSuite) TestSyncTargetVersion() { suite.Len(action.SealedInTarget, 1) } +func (suite *LeaderObserverTestSuite) TestCheckTargetVersion() { + collectionID := int64(1001) + observer := suite.observer + + suite.Run("check_channel_blocked", func() { + oldCh := observer.manualCheck + defer func() { + observer.manualCheck = oldCh + }() + + // zero-length channel + observer.manualCheck = make(chan checkRequest) + + ctx, cancel := context.WithCancel(context.Background()) + // cancel context, make test return fast + cancel() + + result := observer.CheckTargetVersion(ctx, collectionID) + suite.False(result) + }) + + suite.Run("check_return_ctx_timeout", func() { + oldCh := observer.manualCheck + defer func() { + observer.manualCheck = oldCh + }() + + // make channel length = 1, task received + observer.manualCheck = make(chan checkRequest, 1) + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200) + defer cancel() + + result := observer.CheckTargetVersion(ctx, collectionID) + suite.False(result) + }) +} + func TestLeaderObserverSuite(t *testing.T) { suite.Run(t, new(LeaderObserverTestSuite)) } diff --git a/internal/querycoordv2/observers/replica_observer.go b/internal/querycoordv2/observers/replica_observer.go index 4b44005fcf14b..dcd8bdd3ce604 100644 --- a/internal/querycoordv2/observers/replica_observer.go +++ b/internal/querycoordv2/observers/replica_observer.go @@ -31,7 +31,7 @@ import ( // check replica, find outbound nodes and remove it from replica if all segment/channel has been moved type ReplicaObserver struct { - c chan struct{} + cancel context.CancelFunc wg sync.WaitGroup meta *meta.Meta distMgr *meta.DistributionManager @@ -41,20 +41,24 @@ type ReplicaObserver struct { func NewReplicaObserver(meta *meta.Meta, distMgr *meta.DistributionManager) 
*ReplicaObserver { return &ReplicaObserver{ - c: make(chan struct{}), meta: meta, distMgr: distMgr, } } -func (ob *ReplicaObserver) Start(ctx context.Context) { +func (ob *ReplicaObserver) Start() { + ctx, cancel := context.WithCancel(context.Background()) + ob.cancel = cancel + ob.wg.Add(1) go ob.schedule(ctx) } func (ob *ReplicaObserver) Stop() { ob.stopOnce.Do(func() { - close(ob.c) + if ob.cancel != nil { + ob.cancel() + } ob.wg.Wait() }) } @@ -68,9 +72,6 @@ func (ob *ReplicaObserver) schedule(ctx context.Context) { for { select { case <-ctx.Done(): - log.Info("Close replica observer due to context canceled") - return - case <-ob.c: log.Info("Close replica observer") return diff --git a/internal/querycoordv2/observers/replica_observer_test.go b/internal/querycoordv2/observers/replica_observer_test.go index e6181a2346d45..1efcc0597b810 100644 --- a/internal/querycoordv2/observers/replica_observer_test.go +++ b/internal/querycoordv2/observers/replica_observer_test.go @@ -16,7 +16,6 @@ package observers import ( - "context" "testing" "time" @@ -39,7 +38,7 @@ type ReplicaObserverSuite struct { suite.Suite kv kv.MetaKv - //dependency + // dependency meta *meta.Meta distMgr *meta.DistributionManager @@ -77,7 +76,7 @@ func (suite *ReplicaObserverSuite) SetupTest() { suite.distMgr = meta.NewDistributionManager() suite.observer = NewReplicaObserver(suite.meta, suite.distMgr) - suite.observer.Start(context.TODO()) + suite.observer.Start() suite.collectionID = int64(1000) suite.partitionID = int64(100) } diff --git a/internal/querycoordv2/observers/resource_observer.go b/internal/querycoordv2/observers/resource_observer.go index 6e802289c2306..dfb23b2763481 100644 --- a/internal/querycoordv2/observers/resource_observer.go +++ b/internal/querycoordv2/observers/resource_observer.go @@ -31,28 +31,32 @@ import ( // check whether rg lack of node, try to transfer node from default rg type ResourceObserver struct { - c chan struct{} - wg sync.WaitGroup - meta *meta.Meta + cancel 
context.CancelFunc + wg sync.WaitGroup + meta *meta.Meta stopOnce sync.Once } func NewResourceObserver(meta *meta.Meta) *ResourceObserver { return &ResourceObserver{ - c: make(chan struct{}), meta: meta, } } -func (ob *ResourceObserver) Start(ctx context.Context) { +func (ob *ResourceObserver) Start() { + ctx, cancel := context.WithCancel(context.Background()) + ob.cancel = cancel + ob.wg.Add(1) go ob.schedule(ctx) } func (ob *ResourceObserver) Stop() { ob.stopOnce.Do(func() { - close(ob.c) + if ob.cancel != nil { + ob.cancel() + } ob.wg.Wait() }) } @@ -66,9 +70,6 @@ func (ob *ResourceObserver) schedule(ctx context.Context) { for { select { case <-ctx.Done(): - log.Info("Close resource group observer due to context canceled") - return - case <-ob.c: log.Info("Close resource group observer") return diff --git a/internal/querycoordv2/observers/resource_observer_test.go b/internal/querycoordv2/observers/resource_observer_test.go index 3bbe78b262d71..7565c06e2a1ec 100644 --- a/internal/querycoordv2/observers/resource_observer_test.go +++ b/internal/querycoordv2/observers/resource_observer_test.go @@ -16,7 +16,6 @@ package observers import ( - "context" "testing" "time" @@ -41,7 +40,7 @@ type ResourceObserverSuite struct { suite.Suite kv kv.MetaKv - //dependency + // dependency store *mocks.QueryCoordCatalog meta *meta.Meta observer *ResourceObserver @@ -77,7 +76,7 @@ func (suite *ResourceObserverSuite) SetupTest() { suite.meta = meta.NewMeta(idAllocator, suite.store, suite.nodeMgr) suite.observer = NewResourceObserver(suite.meta) - suite.observer.Start(context.TODO()) + suite.observer.Start() suite.store.EXPECT().SaveResourceGroup(mock.Anything).Return(nil) for i := 0; i < 10; i++ { @@ -122,7 +121,7 @@ func (suite *ResourceObserverSuite) TestCheckNodesInReplica() { suite.meta.ResourceManager.HandleNodeDown(100) suite.meta.ResourceManager.HandleNodeDown(101) - //before auto recover rg + // before auto recover rg suite.Eventually(func() bool { lackNodesNum := 
suite.meta.ResourceManager.CheckLackOfNode("rg") nodesInReplica := suite.meta.ReplicaManager.Get(2).GetNodes() @@ -189,7 +188,7 @@ func (suite *ResourceObserverSuite) TestRecoverReplicaFailed() { suite.meta.ResourceManager.HandleNodeDown(100) suite.meta.ResourceManager.HandleNodeDown(101) - //before auto recover rg + // before auto recover rg suite.Eventually(func() bool { lackNodesNum := suite.meta.ResourceManager.CheckLackOfNode("rg") nodesInReplica := suite.meta.ReplicaManager.Get(2).GetNodes() diff --git a/internal/querycoordv2/observers/target_observer.go b/internal/querycoordv2/observers/target_observer.go index 3568b3ae793ad..d1f210fed083e 100644 --- a/internal/querycoordv2/observers/target_observer.go +++ b/internal/querycoordv2/observers/target_observer.go @@ -41,14 +41,17 @@ type targetUpdateRequest struct { ReadyNotifier chan struct{} } +type initRequest struct{} + type TargetObserver struct { - c chan struct{} + cancel context.CancelFunc wg sync.WaitGroup meta *meta.Meta targetMgr *meta.TargetManager distMgr *meta.DistributionManager broker meta.Broker + initChan chan initRequest manualCheck chan checkRequest nextTargetLastUpdate map[int64]time.Time updateChan chan targetUpdateRequest @@ -60,7 +63,6 @@ type TargetObserver struct { func NewTargetObserver(meta *meta.Meta, targetMgr *meta.TargetManager, distMgr *meta.DistributionManager, broker meta.Broker) *TargetObserver { return &TargetObserver{ - c: make(chan struct{}), meta: meta, targetMgr: targetMgr, distMgr: distMgr, @@ -69,17 +71,26 @@ func NewTargetObserver(meta *meta.Meta, targetMgr *meta.TargetManager, distMgr * nextTargetLastUpdate: make(map[int64]time.Time), updateChan: make(chan targetUpdateRequest), readyNotifiers: make(map[int64][]chan struct{}), + initChan: make(chan initRequest), } } -func (ob *TargetObserver) Start(ctx context.Context) { +func (ob *TargetObserver) Start() { + ctx, cancel := context.WithCancel(context.Background()) + ob.cancel = cancel + ob.wg.Add(1) go ob.schedule(ctx) 
+ + // after target observer start, update target for all collection + ob.initChan <- initRequest{} } func (ob *TargetObserver) Stop() { ob.stopOnce.Do(func() { - close(ob.c) + if ob.cancel != nil { + ob.cancel() + } ob.wg.Wait() }) } @@ -93,15 +104,19 @@ func (ob *TargetObserver) schedule(ctx context.Context) { for { select { case <-ctx.Done(): - log.Info("Close target observer due to context canceled") - return - case <-ob.c: log.Info("Close target observer") return + case <-ob.initChan: + for _, collectionID := range ob.meta.GetAll() { + ob.init(collectionID) + } + case <-ticker.C: ob.clean() - ob.tryUpdateTarget() + for _, collectionID := range ob.meta.GetAll() { + ob.check(collectionID) + } case req := <-ob.manualCheck: ob.check(req.CollectionID) @@ -118,11 +133,6 @@ func (ob *TargetObserver) schedule(ctx context.Context) { } req.Notifier <- err - - // Manually trigger the observer, - // to avoid waiting for a long time (10s) - ob.clean() - ob.tryUpdateTarget() } } } @@ -130,13 +140,20 @@ func (ob *TargetObserver) schedule(ctx context.Context) { // Check checks whether the next target is ready, // and updates the current target if it is, // returns true if current target is not nil -func (ob *TargetObserver) Check(collectionID int64) bool { +func (ob *TargetObserver) Check(ctx context.Context, collectionID int64) bool { notifier := make(chan bool) - ob.manualCheck <- checkRequest{ - CollectionID: collectionID, - Notifier: notifier, + select { + case ob.manualCheck <- checkRequest{CollectionID: collectionID, Notifier: notifier}: + case <-ctx.Done(): + return false + } + + select { + case result := <-notifier: + return result + case <-ctx.Done(): + return false } - return <-notifier } func (ob *TargetObserver) check(collectionID int64) { @@ -158,6 +175,18 @@ func (ob *TargetObserver) check(collectionID int64) { } } +func (ob *TargetObserver) init(collectionID int64) { + // pull next target first if not exist + if !ob.targetMgr.IsNextTargetExist(collectionID) { + 
ob.updateNextTarget(collectionID) + } + + // try to update current target if all segment/channel are ready + if ob.shouldUpdateCurrentTarget(collectionID) { + ob.updateCurrentTarget(collectionID) + } +} + // UpdateNextTarget updates the next target, // returns a channel which will be closed when the next target is ready, // or returns error if failed to pull target @@ -183,28 +212,19 @@ func (ob *TargetObserver) ReleaseCollection(collectionID int64) { delete(ob.readyNotifiers, collectionID) } -func (ob *TargetObserver) tryUpdateTarget() { - collections := ob.meta.GetAll() - for _, collectionID := range collections { - ob.check(collectionID) - } - - collectionSet := typeutil.NewUniqueSet(collections...) +func (ob *TargetObserver) clean() { + collectionSet := typeutil.NewUniqueSet(ob.meta.GetAll()...) // for collection which has been removed from target, try to clear nextTargetLastUpdate for collection := range ob.nextTargetLastUpdate { if !collectionSet.Contain(collection) { delete(ob.nextTargetLastUpdate, collection) } } -} - -func (ob *TargetObserver) clean() { - collections := typeutil.NewSet(ob.meta.GetAll()...) 
ob.mut.Lock() defer ob.mut.Unlock() for collectionID, notifiers := range ob.readyNotifiers { - if !collections.Contain(collectionID) { + if !collectionSet.Contain(collectionID) { for i := range notifiers { close(notifiers[i]) } diff --git a/internal/querycoordv2/observers/target_observer_test.go b/internal/querycoordv2/observers/target_observer_test.go index 9a09f1b1fbb05..2fee0098b0bb0 100644 --- a/internal/querycoordv2/observers/target_observer_test.go +++ b/internal/querycoordv2/observers/target_observer_test.go @@ -41,7 +41,7 @@ type TargetObserverSuite struct { suite.Suite kv kv.MetaKv - //dependency + // dependency meta *meta.Meta targetMgr *meta.TargetManager distMgr *meta.DistributionManager @@ -83,7 +83,6 @@ func (suite *TargetObserverSuite) SetupTest() { suite.targetMgr = meta.NewTargetManager(suite.broker, suite.meta) suite.distMgr = meta.NewDistributionManager() suite.observer = NewTargetObserver(suite.meta, suite.targetMgr, suite.distMgr, suite.broker) - suite.observer.Start(context.TODO()) suite.collectionID = int64(1000) suite.partitionID = int64(100) @@ -122,6 +121,7 @@ func (suite *TargetObserverSuite) SetupTest() { } suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, mock.Anything).Return(suite.nextTargetChannels, suite.nextTargetSegments, nil) + suite.observer.Start() } func (suite *TargetObserverSuite) TestTriggerUpdateTarget() { @@ -216,6 +216,101 @@ func (suite *TargetObserverSuite) TearDownSuite() { suite.observer.Stop() } +type TargetObserverCheckSuite struct { + suite.Suite + + kv kv.MetaKv + // dependency + meta *meta.Meta + targetMgr *meta.TargetManager + distMgr *meta.DistributionManager + broker *meta.MockBroker + + observer *TargetObserver + + collectionID int64 + partitionID int64 +} + +func (suite *TargetObserverCheckSuite) SetupSuite() { + paramtable.Init() +} + +func (suite *TargetObserverCheckSuite) SetupTest() { + var err error + config := GenerateEtcdConfig() + cli, err := etcd.GetEtcdClient( + 
config.UseEmbedEtcd.GetAsBool(), + config.EtcdUseSSL.GetAsBool(), + config.Endpoints.GetAsStrings(), + config.EtcdTLSCert.GetValue(), + config.EtcdTLSKey.GetValue(), + config.EtcdTLSCACert.GetValue(), + config.EtcdTLSMinVersion.GetValue()) + suite.Require().NoError(err) + suite.kv = etcdkv.NewEtcdKV(cli, config.MetaRootPath.GetValue()) + + // meta + store := querycoord.NewCatalog(suite.kv) + idAllocator := RandomIncrementIDAllocator() + suite.meta = meta.NewMeta(idAllocator, store, session.NewNodeManager()) + + suite.broker = meta.NewMockBroker(suite.T()) + suite.targetMgr = meta.NewTargetManager(suite.broker, suite.meta) + suite.distMgr = meta.NewDistributionManager() + suite.observer = NewTargetObserver(suite.meta, suite.targetMgr, suite.distMgr, suite.broker) + suite.collectionID = int64(1000) + suite.partitionID = int64(100) + + err = suite.meta.CollectionManager.PutCollection(utils.CreateTestCollection(suite.collectionID, 1)) + suite.NoError(err) + err = suite.meta.CollectionManager.PutPartition(utils.CreateTestPartition(suite.collectionID, suite.partitionID)) + suite.NoError(err) + replicas, err := suite.meta.ReplicaManager.Spawn(suite.collectionID, 1, meta.DefaultResourceGroupName) + suite.NoError(err) + replicas[0].AddNode(2) + err = suite.meta.ReplicaManager.Put(replicas...) 
+ suite.NoError(err) +} + +func (suite *TargetObserverCheckSuite) TestCheckCtxDone() { + observer := suite.observer + + suite.Run("check_channel_blocked", func() { + oldCh := observer.manualCheck + defer func() { + observer.manualCheck = oldCh + }() + + // zero-length channel + observer.manualCheck = make(chan checkRequest) + + ctx, cancel := context.WithCancel(context.Background()) + // cancel context, make test return fast + cancel() + + result := observer.Check(ctx, suite.collectionID) + suite.False(result) + }) + + suite.Run("check_return_ctx_timeout", func() { + oldCh := observer.manualCheck + defer func() { + observer.manualCheck = oldCh + }() + + // make channel length = 1, task received + observer.manualCheck = make(chan checkRequest, 1) + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200) + defer cancel() + + result := observer.Check(ctx, suite.collectionID) + suite.False(result) + }) +} + func TestTargetObserver(t *testing.T) { suite.Run(t, new(TargetObserverSuite)) + suite.Run(t, new(TargetObserverCheckSuite)) } diff --git a/internal/querycoordv2/params/params.go b/internal/querycoordv2/params/params.go index e5440518733cb..1b7fe5aa9f6ff 100644 --- a/internal/querycoordv2/params/params.go +++ b/internal/querycoordv2/params/params.go @@ -29,9 +29,7 @@ import ( var Params *paramtable.ComponentParam = paramtable.Get() -var ( - ErrFailedAllocateID = errors.New("failed to allocate ID") -) +var ErrFailedAllocateID = errors.New("failed to allocate ID") // GenerateEtcdConfig returns a etcd config with a random root path, // NOTE: for test only diff --git a/internal/querycoordv2/server.go b/internal/querycoordv2/server.go index d8609c18eb884..1d48f39718ff0 100644 --- a/internal/querycoordv2/server.go +++ b/internal/querycoordv2/server.go @@ -25,6 +25,7 @@ import ( "time" "github.com/cockroachdb/errors" + "github.com/tikv/client-go/v2/txnkv" clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/atomic" "go.uber.org/zap" @@ -34,8 +35,10 @@ 
import ( "github.com/milvus-io/milvus/internal/allocator" "github.com/milvus-io/milvus/internal/kv" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" + "github.com/milvus-io/milvus/internal/kv/tikv" "github.com/milvus-io/milvus/internal/metastore" "github.com/milvus-io/milvus/internal/metastore/kv/querycoord" + "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/querycoordv2/balance" "github.com/milvus-io/milvus/internal/querycoordv2/checkers" "github.com/milvus-io/milvus/internal/querycoordv2/dist" @@ -52,6 +55,7 @@ import ( "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/util" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -59,10 +63,8 @@ import ( "github.com/milvus-io/milvus/pkg/util/typeutil" ) -var ( - // Only for re-export - Params = params.Params -) +// Only for re-export +var Params = params.Params type Server struct { ctx context.Context @@ -70,6 +72,7 @@ type Server struct { wg sync.WaitGroup status atomic.Int32 etcdCli *clientv3.Client + tikvCli *txnkv.Client address string session *sessionutil.Session kv kv.MetaKv @@ -77,8 +80,8 @@ type Server struct { metricsCacheManager *metricsinfo.MetricsCacheManager // Coordinators - dataCoord types.DataCoord - rootCoord types.RootCoord + dataCoord types.DataCoordClient + rootCoord types.RootCoordClient // Meta store metastore.QueryCoordCatalog @@ -204,13 +207,21 @@ func (s *Server) Init() error { func (s *Server) initQueryCoord() error { s.UpdateStateCode(commonpb.StateCode_Initializing) log.Info("QueryCoord", zap.Any("State", commonpb.StateCode_Initializing)) - // Init KV - etcdKV := etcdkv.NewEtcdKV(s.etcdCli, Params.EtcdCfg.MetaRootPath.GetValue()) - s.kv = etcdKV - log.Info("query coordinator try to connect etcd success") + // Init KV and ID allocator + 
metaType := Params.MetaStoreCfg.MetaStoreType.GetValue() + var idAllocatorKV kv.TxnKV + log.Info(fmt.Sprintf("query coordinator connecting to %s.", metaType)) + if metaType == util.MetaStoreTypeTiKV { + s.kv = tikv.NewTiKV(s.tikvCli, Params.TiKVCfg.MetaRootPath.GetValue()) + idAllocatorKV = tsoutil.NewTSOTiKVBase(s.tikvCli, Params.TiKVCfg.KvRootPath.GetValue(), "querycoord-id-allocator") + } else if metaType == util.MetaStoreTypeEtcd { + s.kv = etcdkv.NewEtcdKV(s.etcdCli, Params.EtcdCfg.MetaRootPath.GetValue()) + idAllocatorKV = tsoutil.NewTSOKVBase(s.etcdCli, Params.EtcdCfg.KvRootPath.GetValue(), "querycoord-id-allocator") + } else { + return fmt.Errorf("not supported meta store: %s", metaType) + } + log.Info(fmt.Sprintf("query coordinator successfully connected to %s.", metaType)) - // Init ID allocator - idAllocatorKV := tsoutil.NewTSOKVBase(s.etcdCli, Params.EtcdCfg.KvRootPath.GetValue(), "querycoord-id-allocator") idAllocator := allocator.NewGlobalIDAllocator("idTimestamp", idAllocatorKV) err := idAllocator.Initialize() if err != nil { @@ -310,7 +321,7 @@ func (s *Server) initMeta() error { log.Info("recover meta...") err := s.meta.CollectionManager.Recover(s.broker) if err != nil { - log.Warn("failed to recover collections") + log.Warn("failed to recover collections", zap.Error(err)) return err } collections := s.meta.GetAll() @@ -323,13 +334,13 @@ func (s *Server) initMeta() error { err = s.meta.ReplicaManager.Recover(collections) if err != nil { - log.Warn("failed to recover replicas") + log.Warn("failed to recover replicas", zap.Error(err)) return err } err = s.meta.ResourceManager.Recover() if err != nil { - log.Warn("failed to recover resource groups") + log.Warn("failed to recover resource groups", zap.Error(err)) return err } @@ -413,44 +424,46 @@ func (s *Server) startQueryCoord() error { s.startServerLoop() s.afterStart() s.UpdateStateCode(commonpb.StateCode_Healthy) + sessionutil.SaveServerInfo(typeutil.QueryCoordRole, s.session.ServerID) return nil 
} func (s *Server) startServerLoop() { + // start the components from inside to outside, + // to make the dependencies ready for every component log.Info("start cluster...") - s.cluster.Start(s.ctx) + s.cluster.Start() - log.Info("start job scheduler...") - s.jobScheduler.Start(s.ctx) + log.Info("start observers...") + s.collectionObserver.Start() + s.leaderObserver.Start() + s.targetObserver.Start() + s.replicaObserver.Start() + s.resourceObserver.Start() log.Info("start task scheduler...") - s.taskScheduler.Start(s.ctx) + s.taskScheduler.Start() log.Info("start checker controller...") - s.checkerController.Start(s.ctx) + s.checkerController.Start() - log.Info("start observers...") - s.collectionObserver.Start(s.ctx) - s.leaderObserver.Start(s.ctx) - s.targetObserver.Start(s.ctx) - s.replicaObserver.Start(s.ctx) - s.resourceObserver.Start(s.ctx) + log.Info("start job scheduler...") + s.jobScheduler.Start() } func (s *Server) Stop() error { + // stop the components from outside to inside, + // to make the dependencies stopped working properly, + // cancel the server context first to stop receiving requests s.cancel() - if s.session != nil { - s.session.Stop() - } - if s.cluster != nil { - log.Info("stop cluster...") - s.cluster.Stop() - } + // FOLLOW the dependence graph: + // job scheduler -> checker controller -> task scheduler -> dist controller -> cluster -> session + // observers -> dist controller - if s.distController != nil { - log.Info("stop dist controller...") - s.distController.Stop() + if s.jobScheduler != nil { + log.Info("stop job scheduler...") + s.jobScheduler.Stop() } if s.checkerController != nil { @@ -463,11 +476,6 @@ func (s *Server) Stop() error { s.taskScheduler.Stop() } - if s.jobScheduler != nil { - log.Info("stop job scheduler...") - s.jobScheduler.Stop() - } - log.Info("stop observers...") if s.collectionObserver != nil { s.collectionObserver.Stop() @@ -485,6 +493,20 @@ func (s *Server) Stop() error { s.resourceObserver.Stop() } + if 
s.distController != nil { + log.Info("stop dist controller...") + s.distController.Stop() + } + + if s.cluster != nil { + log.Info("stop cluster...") + s.cluster.Stop() + } + + if s.session != nil { + s.session.Stop() + } + s.wg.Wait() log.Info("QueryCoord stop successfully") return nil @@ -499,7 +521,7 @@ func (s *Server) State() commonpb.StateCode { return commonpb.StateCode(s.status.Load()) } -func (s *Server) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { +func (s *Server) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) { nodeID := common.NotRegisteredID if s.session != nil && s.session.Registered() { nodeID = s.session.ServerID @@ -511,21 +533,21 @@ func (s *Server) GetComponentStates(ctx context.Context) (*milvuspb.ComponentSta } return &milvuspb.ComponentStates{ - Status: merr.Status(nil), + Status: merr.Success(), State: serviceComponentInfo, - //SubcomponentStates: subComponentInfos, + // SubcomponentStates: subComponentInfos, }, nil } -func (s *Server) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (s *Server) GetStatisticsChannel(ctx context.Context, req *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error) { return &milvuspb.StringResponse{ - Status: merr.Status(nil), + Status: merr.Success(), }, nil } -func (s *Server) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (s *Server) GetTimeTickChannel(ctx context.Context, req *internalpb.GetTimeTickChannelRequest) (*milvuspb.StringResponse, error) { return &milvuspb.StringResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Value: Params.CommonCfg.QueryCoordTimeTick.GetValue(), }, nil } @@ -539,8 +561,12 @@ func (s *Server) SetEtcdClient(etcdClient *clientv3.Client) { s.etcdCli = etcdClient } +func (s *Server) SetTiKVClient(client *txnkv.Client) { + s.tikvCli = client +} + // SetRootCoord sets root 
coordinator's client -func (s *Server) SetRootCoord(rootCoord types.RootCoord) error { +func (s *Server) SetRootCoordClient(rootCoord types.RootCoordClient) error { if rootCoord == nil { return errors.New("null RootCoord interface") } @@ -550,7 +576,7 @@ func (s *Server) SetRootCoord(rootCoord types.RootCoord) error { } // SetDataCoord sets data coordinator's client -func (s *Server) SetDataCoord(dataCoord types.DataCoord) error { +func (s *Server) SetDataCoordClient(dataCoord types.DataCoordClient) error { if dataCoord == nil { return errors.New("null DataCoord interface") } @@ -559,7 +585,7 @@ func (s *Server) SetDataCoord(dataCoord types.DataCoord) error { return nil } -func (s *Server) SetQueryNodeCreator(f func(ctx context.Context, addr string, nodeID int64) (types.QueryNode, error)) { +func (s *Server) SetQueryNodeCreator(f func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error)) { s.queryNodeCreator = f } diff --git a/internal/querycoordv2/server_test.go b/internal/querycoordv2/server_test.go index 2b97d09e2e354..b3dad1843397b 100644 --- a/internal/querycoordv2/server_test.go +++ b/internal/querycoordv2/server_test.go @@ -26,6 +26,7 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" + "github.com/tikv/client-go/v2/txnkv" "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" @@ -47,6 +48,7 @@ import ( "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/tikv" ) func TestMain(m *testing.M) { @@ -81,10 +83,13 @@ type ServerSuite struct { // Mocks broker *meta.MockBroker - server *Server - nodes []*mocks.MockQueryNode + tikvCli *txnkv.Client + server *Server + nodes []*mocks.MockQueryNode } +var testMeta string + func (suite *ServerSuite) SetupSuite() { paramtable.Init() params.GenerateEtcdConfig() @@ -121,7 +126,10 @@ func (suite *ServerSuite) SetupSuite() { 
func (suite *ServerSuite) SetupTest() { var err error + paramtable.Get().Save(paramtable.Get().MetaStoreCfg.MetaStoreType.Key, testMeta) + suite.tikvCli = tikv.SetupLocalTxn() suite.server, err = suite.newQueryCoord() + suite.Require().NoError(err) suite.hackServer() err = suite.server.Start() @@ -152,6 +160,7 @@ func (suite *ServerSuite) TearDownTest() { node.Stop() } } + paramtable.Get().Reset(paramtable.Get().MetaStoreCfg.MetaStoreType.Key) } func (suite *ServerSuite) TestRecover() { @@ -171,7 +180,7 @@ func (suite *ServerSuite) TestRecover() { func (suite *ServerSuite) TestNodeUp() { node1 := mocks.NewMockQueryNode(suite.T(), suite.server.etcdCli, 100) - node1.EXPECT().GetDataDistribution(mock.Anything, mock.Anything).Return(&querypb.GetDataDistributionResponse{Status: merr.Status(nil)}, nil) + node1.EXPECT().GetDataDistribution(mock.Anything, mock.Anything).Return(&querypb.GetDataDistributionResponse{Status: merr.Success()}, nil) err := node1.Start() suite.NoError(err) defer node1.Stop() @@ -195,7 +204,7 @@ func (suite *ServerSuite) TestNodeUp() { suite.server.nodeMgr.Add(session.NewNodeInfo(1001, "localhost")) node2 := mocks.NewMockQueryNode(suite.T(), suite.server.etcdCli, 101) - node2.EXPECT().GetDataDistribution(mock.Anything, mock.Anything).Return(&querypb.GetDataDistributionResponse{Status: merr.Status(nil)}, nil).Maybe() + node2.EXPECT().GetDataDistribution(mock.Anything, mock.Anything).Return(&querypb.GetDataDistributionResponse{Status: merr.Success()}, nil).Maybe() err = node2.Start() suite.NoError(err) defer node2.Stop() @@ -282,7 +291,7 @@ func (suite *ServerSuite) TestDisableActiveStandby() { suite.NoError(err) suite.Equal(commonpb.StateCode_Healthy, suite.server.State()) - states, err := suite.server.GetComponentStates(context.Background()) + states, err := suite.server.GetComponentStates(context.Background(), nil) suite.NoError(err) suite.Equal(commonpb.StateCode_Healthy, states.GetState().GetStateCode()) } @@ -295,11 +304,11 @@ func (suite 
*ServerSuite) TestEnableActiveStandby() { suite.server, err = suite.newQueryCoord() suite.NoError(err) - mockRootCoord := coordMocks.NewRootCoord(suite.T()) - mockDataCoord := coordMocks.NewMockDataCoord(suite.T()) + mockRootCoord := coordMocks.NewMockRootCoordClient(suite.T()) + mockDataCoord := coordMocks.NewMockDataCoordClient(suite.T()) mockRootCoord.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Schema: &schemapb.CollectionSchema{}, }, nil).Maybe() for _, collection := range suite.collections { @@ -310,17 +319,17 @@ func (suite *ServerSuite) TestEnableActiveStandby() { CollectionID: collection, } mockRootCoord.EXPECT().ShowPartitions(mock.Anything, req).Return(&milvuspb.ShowPartitionsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), PartitionIDs: suite.partitions[collection], }, nil).Maybe() suite.expectGetRecoverInfoByMockDataCoord(collection, mockDataCoord) } - err = suite.server.SetRootCoord(mockRootCoord) + err = suite.server.SetRootCoordClient(mockRootCoord) suite.NoError(err) - err = suite.server.SetDataCoord(mockDataCoord) + err = suite.server.SetDataCoordClient(mockDataCoord) suite.NoError(err) - //suite.hackServer() - states1, err := suite.server.GetComponentStates(context.Background()) + // suite.hackServer() + states1, err := suite.server.GetComponentStates(context.Background(), nil) suite.NoError(err) suite.Equal(commonpb.StateCode_StandBy, states1.GetState().GetStateCode()) err = suite.server.Register() @@ -328,7 +337,7 @@ func (suite *ServerSuite) TestEnableActiveStandby() { err = suite.server.Start() suite.NoError(err) - states2, err := suite.server.GetComponentStates(context.Background()) + states2, err := suite.server.GetComponentStates(context.Background(), nil) suite.NoError(err) suite.Equal(commonpb.StateCode_Healthy, states2.GetState().GetStateCode()) @@ -399,11 +408,11 @@ func (suite *ServerSuite) 
expectGetRecoverInfo(collection int64) { } func (suite *ServerSuite) expectLoadAndReleasePartitions(querynode *mocks.MockQueryNode) { - querynode.EXPECT().LoadPartitions(mock.Anything, mock.Anything).Return(merr.Status(nil), nil).Maybe() - querynode.EXPECT().ReleasePartitions(mock.Anything, mock.Anything).Return(merr.Status(nil), nil).Maybe() + querynode.EXPECT().LoadPartitions(mock.Anything, mock.Anything).Return(merr.Success(), nil).Maybe() + querynode.EXPECT().ReleasePartitions(mock.Anything, mock.Anything).Return(merr.Success(), nil).Maybe() } -func (suite *ServerSuite) expectGetRecoverInfoByMockDataCoord(collection int64, dataCoord *coordMocks.MockDataCoord) { +func (suite *ServerSuite) expectGetRecoverInfoByMockDataCoord(collection int64, dataCoord *coordMocks.MockDataCoordClient) { var ( vChannels []*datapb.VchannelInfo segmentInfos []*datapb.SegmentInfo @@ -434,9 +443,7 @@ func (suite *ServerSuite) expectGetRecoverInfoByMockDataCoord(collection int64, } } dataCoord.EXPECT().GetRecoveryInfoV2(mock.Anything, getRecoveryInfoRequest).Return(&datapb.GetRecoveryInfoResponseV2{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), Channels: vChannels, Segments: segmentInfos, }, nil).Maybe() @@ -520,12 +527,12 @@ func (suite *ServerSuite) hackServer() { } func (suite *ServerSuite) hackBroker(server *Server) { - mockRootCoord := coordMocks.NewRootCoord(suite.T()) - mockDataCoord := coordMocks.NewMockDataCoord(suite.T()) + mockRootCoord := coordMocks.NewMockRootCoordClient(suite.T()) + mockDataCoord := coordMocks.NewMockDataCoordClient(suite.T()) for _, collection := range suite.collections { mockRootCoord.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Schema: &schemapb.CollectionSchema{}, }, nil).Maybe() req := &milvuspb.ShowPartitionsRequest{ @@ -535,13 +542,13 @@ func (suite *ServerSuite) hackBroker(server 
*Server) { CollectionID: collection, } mockRootCoord.EXPECT().ShowPartitions(mock.Anything, req).Return(&milvuspb.ShowPartitionsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), PartitionIDs: suite.partitions[collection], }, nil).Maybe() } - err := server.SetRootCoord(mockRootCoord) + err := server.SetRootCoordClient(mockRootCoord) suite.NoError(err) - err = server.SetDataCoord(mockDataCoord) + err = server.SetDataCoordClient(mockDataCoord) suite.NoError(err) } @@ -563,6 +570,8 @@ func (suite *ServerSuite) newQueryCoord() (*Server, error) { return nil, err } server.SetEtcdClient(etcdCli) + server.SetTiKVClient(suite.tikvCli) + server.SetQueryNodeCreator(session.DefaultQueryNodeCreator) suite.hackBroker(server) err = server.Init() @@ -570,5 +579,9 @@ func (suite *ServerSuite) newQueryCoord() (*Server, error) { } func TestServer(t *testing.T) { - suite.Run(t, new(ServerSuite)) + parameters := []string{"tikv", "etcd"} + for _, v := range parameters { + testMeta = v + suite.Run(t, new(ServerSuite)) + } } diff --git a/internal/querycoordv2/services.go b/internal/querycoordv2/services.go index c81dce13e124a..2e1c52389dab6 100644 --- a/internal/querycoordv2/services.go +++ b/internal/querycoordv2/services.go @@ -36,7 +36,6 @@ import ( "github.com/milvus-io/milvus/internal/querycoordv2/utils" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" - "github.com/milvus-io/milvus/pkg/util/errorutil" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -104,7 +103,7 @@ func (s *Server) ShowCollections(ctx context.Context, req *querypb.ShowCollectio }, nil } - err = fmt.Errorf("collection %d has not been loaded to memory or load failed", collectionID) + err = merr.WrapErrCollectionNotLoaded(collectionID) log.Warn("show collection failed", zap.Error(err)) return &querypb.ShowCollectionsResponse{ Status: merr.Status(err), @@ -162,10 +161,9 @@ func 
(s *Server) ShowPartitions(ctx context.Context, req *querypb.ShowPartitions } err = merr.WrapErrPartitionNotLoaded(partitionID) - msg := fmt.Sprintf("partition %d has not been loaded to memory or load failed", partitionID) - log.Warn(msg) + log.Warn("show partitions failed", zap.Error(err)) return &querypb.ShowPartitionsResponse{ - Status: merr.Status(errors.Wrap(err, msg)), + Status: merr.Status(err), }, nil } percentages = append(percentages, int64(percentage)) @@ -181,7 +179,7 @@ func (s *Server) ShowPartitions(ctx context.Context, req *querypb.ShowPartitions } return &querypb.ShowPartitionsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), PartitionIDs: partitions, InMemoryPercentages: percentages, RefreshProgress: refreshProgresses, @@ -245,7 +243,7 @@ func (s *Server) LoadCollection(ctx context.Context, req *querypb.LoadCollection } metrics.QueryCoordLoadCount.WithLabelValues(metrics.SuccessLabel).Inc() - return merr.Status(nil), nil + return merr.Success(), nil } func (s *Server) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) { @@ -286,7 +284,7 @@ func (s *Server) ReleaseCollection(ctx context.Context, req *querypb.ReleaseColl metrics.QueryCoordReleaseLatency.WithLabelValues().Observe(float64(tr.ElapseSpan().Milliseconds())) meta.GlobalFailedLoadCache.Remove(req.GetCollectionID()) - return merr.Status(nil), nil + return merr.Success(), nil } func (s *Server) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) { @@ -345,7 +343,7 @@ func (s *Server) LoadPartitions(ctx context.Context, req *querypb.LoadPartitions } metrics.QueryCoordLoadCount.WithLabelValues(metrics.SuccessLabel).Inc() - return merr.Status(nil), nil + return merr.Success(), nil } func (s *Server) checkResourceGroup(collectionID int64, resourceGroups []string) error { @@ -411,7 +409,7 @@ func (s *Server) ReleasePartitions(ctx context.Context, req *querypb.ReleasePart 
metrics.QueryCoordReleaseLatency.WithLabelValues().Observe(float64(tr.ElapseSpan().Milliseconds())) meta.GlobalFailedLoadCache.Remove(req.GetCollectionID()) - return merr.Status(nil), nil + return merr.Success(), nil } func (s *Server) GetPartitionStates(ctx context.Context, req *querypb.GetPartitionStatesRequest) (*querypb.GetPartitionStatesResponse, error) { @@ -477,7 +475,7 @@ func (s *Server) GetPartitionStates(ctx context.Context, req *querypb.GetPartiti } return &querypb.GetPartitionStatesResponse{ - Status: merr.Status(nil), + Status: merr.Success(), PartitionDescriptions: states, }, nil } @@ -518,7 +516,7 @@ func (s *Server) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfo } return &querypb.GetSegmentInfoResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Infos: infos, }, nil } @@ -545,7 +543,7 @@ func (s *Server) SyncNewCreatedPartition(ctx context.Context, req *querypb.SyncN return merr.Status(err), nil } - return merr.Status(nil), nil + return merr.Success(), nil } // refreshCollection must be called after loading a collection. 
It looks for new segments that are not loaded yet and @@ -674,7 +672,7 @@ func (s *Server) LoadBalance(ctx context.Context, req *querypb.LoadBalanceReques srcNode := req.GetSourceNodeIDs()[0] replica := s.meta.ReplicaManager.GetByCollectionAndNode(req.GetCollectionID(), srcNode) if replica == nil { - err := merr.WrapErrReplicaNotFound(-1, fmt.Sprintf("replica not found for collection %d and node %d", req.GetCollectionID(), srcNode)) + err := merr.WrapErrNodeNotFound(srcNode, fmt.Sprintf("source node not found in any replica of collection %d", req.GetCollectionID())) msg := "source node not found in any replica" log.Warn(msg) return merr.Status(err), nil @@ -685,9 +683,8 @@ func (s *Server) LoadBalance(ctx context.Context, req *querypb.LoadBalanceReques } for _, dstNode := range req.GetDstNodeIDs() { if !replica.Contains(dstNode) { - err := merr.WrapErrParameterInvalid("destination node in the same replica as source node", fmt.Sprintf("destination node %d not in replica %d", dstNode, replica.GetID())) - msg := "destination nodes have to be in the same replica of source node" - log.Warn(msg) + err := merr.WrapErrNodeNotFound(dstNode, "destination node not found in the same replica") + log.Warn("failed to balance to the destination node", zap.Error(err)) return merr.Status(err), nil } if err := s.isStoppingNode(dstNode); err != nil { @@ -702,7 +699,7 @@ func (s *Server) LoadBalance(ctx context.Context, req *querypb.LoadBalanceReques log.Warn(msg, zap.Error(err)) return merr.Status(errors.Wrap(err, msg)), nil } - return merr.Status(nil), nil + return merr.Success(), nil } func (s *Server) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { @@ -727,7 +724,7 @@ func (s *Server) ShowConfigurations(ctx context.Context, req *internalpb.ShowCon } return &internalpb.ShowConfigurationsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Configuations: configList, }, nil } @@ -738,7 
+735,7 @@ func (s *Server) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest log.RatedDebug(60, "get metrics request received", zap.String("metricType", req.GetRequest())) - if err := merr.CheckHealthy(s.State()); err != nil { + if err := merr.CheckHealthyStandby(s.State()); err != nil { msg := "failed to get metrics" log.Warn(msg, zap.Error(err)) return &milvuspb.GetMetricsResponse{ @@ -747,7 +744,7 @@ func (s *Server) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest } resp := &milvuspb.GetMetricsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), ComponentName: metricsinfo.ConstructComponentName(typeutil.QueryCoordRole, paramtable.GetNodeID()), } @@ -795,7 +792,7 @@ func (s *Server) GetReplicas(ctx context.Context, req *milvuspb.GetReplicasReque } resp := &milvuspb.GetReplicasResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Replicas: make([]*milvuspb.ReplicaInfo, 0), } @@ -847,7 +844,7 @@ func (s *Server) GetShardLeaders(ctx context.Context, req *querypb.GetShardLeade } resp := &querypb.GetShardLeadersResponse{ - Status: merr.Status(nil), + Status: merr.Success(), } percentage := s.meta.CollectionManager.CalculateLoadPercentage(req.GetCollectionID()) @@ -972,16 +969,15 @@ func (s *Server) GetShardLeaders(ctx context.Context, req *querypb.GetShardLeade func (s *Server) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { if err := merr.CheckHealthy(s.State()); err != nil { - reason := errorutil.UnHealthReason("querycoord", paramtable.GetNodeID(), "querycoord is unhealthy") - return &milvuspb.CheckHealthResponse{IsHealthy: false, Reasons: []string{reason}}, nil + return &milvuspb.CheckHealthResponse{Status: merr.Status(err), IsHealthy: false, Reasons: []string{err.Error()}}, nil } errReasons, err := s.checkNodeHealth(ctx) if err != nil || len(errReasons) != 0 { - return &milvuspb.CheckHealthResponse{IsHealthy: false, Reasons: errReasons}, nil + return 
&milvuspb.CheckHealthResponse{Status: merr.Success(), IsHealthy: false, Reasons: errReasons}, nil } - return &milvuspb.CheckHealthResponse{IsHealthy: true, Reasons: errReasons}, nil + return &milvuspb.CheckHealthResponse{Status: merr.Success(), IsHealthy: true, Reasons: errReasons}, nil } func (s *Server) checkNodeHealth(ctx context.Context) ([]string, error) { @@ -993,13 +989,17 @@ func (s *Server) checkNodeHealth(ctx context.Context) ([]string, error) { node := node group.Go(func() error { resp, err := s.cluster.GetComponentStates(ctx, node.ID()) - isHealthy, reason := errorutil.UnHealthReasonWithComponentStatesOrErr("querynode", node.ID(), resp, err) - if !isHealthy { + if err != nil { + return err + } + + err = merr.AnalyzeState("QueryNode", node.ID(), resp) + if err != nil { mu.Lock() defer mu.Unlock() - errReasons = append(errReasons, reason) + errReasons = append(errReasons, err.Error()) } - return err + return nil }) } @@ -1024,7 +1024,7 @@ func (s *Server) CreateResourceGroup(ctx context.Context, req *milvuspb.CreateRe log.Warn("failed to create resource group", zap.Error(err)) return merr.Status(err), nil } - return merr.Status(nil), nil + return merr.Success(), nil } func (s *Server) DropResourceGroup(ctx context.Context, req *milvuspb.DropResourceGroupRequest) (*commonpb.Status, error) { @@ -1050,7 +1050,7 @@ func (s *Server) DropResourceGroup(ctx context.Context, req *milvuspb.DropResour log.Warn("failed to drop resource group", zap.Error(err)) return merr.Status(err), nil } - return merr.Status(nil), nil + return merr.Success(), nil } func (s *Server) TransferNode(ctx context.Context, req *milvuspb.TransferNodeRequest) (*commonpb.Status, error) { @@ -1106,7 +1106,7 @@ func (s *Server) TransferNode(ctx context.Context, req *milvuspb.TransferNodeReq utils.AddNodesToCollectionsInRG(s.meta, req.GetTargetResourceGroup(), nodes...) 
- return merr.Status(nil), nil + return merr.Success(), nil } func (s *Server) TransferReplica(ctx context.Context, req *querypb.TransferReplicaRequest) (*commonpb.Status, error) { @@ -1166,7 +1166,7 @@ func (s *Server) TransferReplica(ctx context.Context, req *querypb.TransferRepli return merr.Status(err), nil } - return merr.Status(nil), nil + return merr.Success(), nil } func (s *Server) transferReplica(targetRG string, replicas []*meta.Replica) error { @@ -1190,7 +1190,7 @@ func (s *Server) ListResourceGroups(ctx context.Context, req *milvuspb.ListResou log.Info("list resource group request received") resp := &milvuspb.ListResourceGroupsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), } if err := merr.CheckHealthy(s.State()); err != nil { log.Warn("failed to list resource group", zap.Error(err)) @@ -1209,7 +1209,7 @@ func (s *Server) DescribeResourceGroup(ctx context.Context, req *querypb.Describ log.Info("describe resource group request received") resp := &querypb.DescribeResourceGroupResponse{ - Status: merr.Status(nil), + Status: merr.Success(), } if err := merr.CheckHealthy(s.State()); err != nil { log.Warn("failed to describe resource group", zap.Error(err)) diff --git a/internal/querycoordv2/services_test.go b/internal/querycoordv2/services_test.go index a42b7f45d4c86..ed7c18a786c93 100644 --- a/internal/querycoordv2/services_test.go +++ b/internal/querycoordv2/services_test.go @@ -19,6 +19,7 @@ package querycoordv2 import ( "context" "encoding/json" + "sort" "testing" "time" @@ -118,8 +119,10 @@ func (suite *ServiceSuite) SetupSuite() { 1000: 1, 1001: 3, } - suite.nodes = []int64{1, 2, 3, 4, 5, - 101, 102, 103, 104, 105} + suite.nodes = []int64{ + 1, 2, 3, 4, 5, + 101, 102, 103, 104, 105, + } } func (suite *ServiceSuite) SetupTest() { @@ -147,7 +150,7 @@ func (suite *ServiceSuite) SetupTest() { suite.dist, suite.broker, ) - suite.targetObserver.Start(context.Background()) + suite.targetObserver.Start() for _, node := range suite.nodes { 
suite.nodeMgr.Add(session.NewNodeInfo(node, "localhost")) err := suite.meta.ResourceManager.AssignNode(meta.DefaultResourceGroupName, node) @@ -156,7 +159,7 @@ func (suite *ServiceSuite) SetupTest() { suite.cluster = session.NewMockCluster(suite.T()) suite.jobScheduler = job.NewScheduler() suite.taskScheduler = task.NewMockScheduler(suite.T()) - suite.jobScheduler.Start(context.Background()) + suite.jobScheduler.Start() suite.balancer = balance.NewRowCountBasedBalancer( suite.taskScheduler, suite.nodeMgr, @@ -389,14 +392,16 @@ func (suite *ServiceSuite) TestResourceGroup() { ID: 1, CollectionID: 1, Nodes: []int64{1011, 1013}, - ResourceGroup: "rg11"}, + ResourceGroup: "rg11", + }, typeutil.NewUniqueSet(1011, 1013)), ) server.meta.ReplicaManager.Put(meta.NewReplica(&querypb.Replica{ ID: 2, CollectionID: 2, Nodes: []int64{1012, 1014}, - ResourceGroup: "rg12"}, + ResourceGroup: "rg12", + }, typeutil.NewUniqueSet(1012, 1014)), ) @@ -871,7 +876,7 @@ func (suite *ServiceSuite) TestReleaseCollection() { server := suite.server suite.cluster.EXPECT().ReleasePartitions(mock.Anything, mock.Anything, mock.Anything). - Return(merr.Status(nil), nil) + Return(merr.Success(), nil) // Test release all collections for _, collection := range suite.collections { @@ -911,7 +916,7 @@ func (suite *ServiceSuite) TestReleasePartition() { // Test release all partitions suite.cluster.EXPECT().ReleasePartitions(mock.Anything, mock.Anything, mock.Anything). - Return(merr.Status(nil), nil) + Return(merr.Success(), nil) for _, collection := range suite.collections { req := &querypb.ReleasePartitionsRequest{ CollectionID: collection, @@ -949,7 +954,7 @@ func (suite *ServiceSuite) TestReleasePartition() { func (suite *ServiceSuite) TestRefreshCollection() { server := suite.server - server.collectionObserver.Start(context.Background()) + server.collectionObserver.Start() // Test refresh all collections. 
for _, collection := range suite.collections { @@ -1236,7 +1241,7 @@ func (suite *ServiceSuite) TestLoadBalanceFailed() { } resp, err := server.LoadBalance(ctx, req) suite.NoError(err) - suite.ErrorIs(merr.Error(resp), merr.ErrParameterInvalid) + suite.ErrorIs(merr.Error(resp), merr.ErrNodeNotFound) } // Test balance task failed @@ -1314,7 +1319,7 @@ func (suite *ServiceSuite) TestGetMetrics() { for _, node := range suite.nodes { suite.cluster.EXPECT().GetMetrics(ctx, node, mock.Anything).Return(&milvuspb.GetMetricsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), ComponentName: "QueryNode", }, nil) } @@ -1735,7 +1740,7 @@ func (suite *ServiceSuite) expectLoadPartitions() { suite.broker.EXPECT().DescribeIndex(mock.Anything, mock.Anything). Return(nil, nil) suite.cluster.EXPECT().LoadPartitions(mock.Anything, mock.Anything, mock.Anything). - Return(merr.Status(nil), nil) + Return(merr.Success(), nil) } func (suite *ServiceSuite) getAllSegments(collection int64) []int64 { @@ -1764,7 +1769,7 @@ func (suite *ServiceSuite) updateChannelDist(collection int64) { replicas := suite.meta.ReplicaManager.GetByCollection(collection) for _, replica := range replicas { i := 0 - for _, node := range replica.GetNodes() { + for _, node := range suite.sortInt64(replica.GetNodes()) { suite.dist.ChannelDistManager.Update(node, meta.DmChannelFromVChannel(&datapb.VchannelInfo{ CollectionID: collection, ChannelName: channels[i], @@ -1788,13 +1793,20 @@ func (suite *ServiceSuite) updateChannelDist(collection int64) { } } +func (suite *ServiceSuite) sortInt64(ints []int64) []int64 { + sort.Slice(ints, func(i int, j int) bool { + return ints[i] < ints[j] + }) + return ints +} + func (suite *ServiceSuite) updateChannelDistWithoutSegment(collection int64) { channels := suite.channels[collection] replicas := suite.meta.ReplicaManager.GetByCollection(collection) for _, replica := range replicas { i := 0 - for _, node := range replica.GetNodes() { + for _, node := range 
suite.sortInt64(replica.GetNodes()) { suite.dist.ChannelDistManager.Update(node, meta.DmChannelFromVChannel(&datapb.VchannelInfo{ CollectionID: collection, ChannelName: channels[i], diff --git a/internal/querycoordv2/session/cluster.go b/internal/querycoordv2/session/cluster.go index a06569105e844..1a2852d7dcb1a 100644 --- a/internal/querycoordv2/session/cluster.go +++ b/internal/querycoordv2/session/cluster.go @@ -23,15 +23,15 @@ import ( "time" "github.com/cockroachdb/errors" - "github.com/golang/protobuf/proto" + "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" grpcquerynodeclient "github.com/milvus-io/milvus/internal/distributed/querynode/client" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/pkg/log" - "go.uber.org/zap" ) const ( @@ -40,9 +40,7 @@ const ( bufferFlushPeriod = 500 * time.Millisecond ) -var ( - ErrNodeNotFound = errors.New("NodeNotFound") -) +var ErrNodeNotFound = errors.New("NodeNotFound") func WrapErrNodeNotFound(nodeID int64) error { return fmt.Errorf("%w(%v)", ErrNodeNotFound, nodeID) @@ -59,7 +57,7 @@ type Cluster interface { GetMetrics(ctx context.Context, nodeID int64, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) SyncDistribution(ctx context.Context, nodeID int64, req *querypb.SyncDistributionRequest) (*commonpb.Status, error) GetComponentStates(ctx context.Context, nodeID int64) (*milvuspb.ComponentStates, error) - Start(ctx context.Context) + Start() Stop() } @@ -72,9 +70,9 @@ type QueryCluster struct { stopOnce sync.Once } -type QueryNodeCreator func(ctx context.Context, addr string, nodeID int64) (types.QueryNode, error) +type QueryNodeCreator func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) -func DefaultQueryNodeCreator(ctx context.Context, addr string, nodeID int64) (types.QueryNode, error) { +func 
DefaultQueryNodeCreator(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) { return grpcquerynodeclient.NewClient(ctx, addr, nodeID) } @@ -87,7 +85,7 @@ func NewCluster(nodeManager *NodeManager, queryNodeCreator QueryNodeCreator) *Qu return c } -func (c *QueryCluster) Start(ctx context.Context) { +func (c *QueryCluster) Start() { c.wg.Add(1) go c.updateLoop() } @@ -123,7 +121,7 @@ func (c *QueryCluster) updateLoop() { func (c *QueryCluster) LoadSegments(ctx context.Context, nodeID int64, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error) { var status *commonpb.Status var err error - err1 := c.send(ctx, nodeID, func(cli types.QueryNode) { + err1 := c.send(ctx, nodeID, func(cli types.QueryNodeClient) { req := proto.Clone(req).(*querypb.LoadSegmentsRequest) req.Base.TargetID = nodeID status, err = cli.LoadSegments(ctx, req) @@ -137,7 +135,7 @@ func (c *QueryCluster) LoadSegments(ctx context.Context, nodeID int64, req *quer func (c *QueryCluster) WatchDmChannels(ctx context.Context, nodeID int64, req *querypb.WatchDmChannelsRequest) (*commonpb.Status, error) { var status *commonpb.Status var err error - err1 := c.send(ctx, nodeID, func(cli types.QueryNode) { + err1 := c.send(ctx, nodeID, func(cli types.QueryNodeClient) { req := proto.Clone(req).(*querypb.WatchDmChannelsRequest) req.Base.TargetID = nodeID status, err = cli.WatchDmChannels(ctx, req) @@ -151,7 +149,7 @@ func (c *QueryCluster) WatchDmChannels(ctx context.Context, nodeID int64, req *q func (c *QueryCluster) UnsubDmChannel(ctx context.Context, nodeID int64, req *querypb.UnsubDmChannelRequest) (*commonpb.Status, error) { var status *commonpb.Status var err error - err1 := c.send(ctx, nodeID, func(cli types.QueryNode) { + err1 := c.send(ctx, nodeID, func(cli types.QueryNodeClient) { req := proto.Clone(req).(*querypb.UnsubDmChannelRequest) req.Base.TargetID = nodeID status, err = cli.UnsubDmChannel(ctx, req) @@ -165,7 +163,7 @@ func (c *QueryCluster) UnsubDmChannel(ctx 
context.Context, nodeID int64, req *qu func (c *QueryCluster) ReleaseSegments(ctx context.Context, nodeID int64, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) { var status *commonpb.Status var err error - err1 := c.send(ctx, nodeID, func(cli types.QueryNode) { + err1 := c.send(ctx, nodeID, func(cli types.QueryNodeClient) { req := proto.Clone(req).(*querypb.ReleaseSegmentsRequest) req.Base.TargetID = nodeID status, err = cli.ReleaseSegments(ctx, req) @@ -179,7 +177,7 @@ func (c *QueryCluster) ReleaseSegments(ctx context.Context, nodeID int64, req *q func (c *QueryCluster) LoadPartitions(ctx context.Context, nodeID int64, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) { var status *commonpb.Status var err error - err1 := c.send(ctx, nodeID, func(cli types.QueryNode) { + err1 := c.send(ctx, nodeID, func(cli types.QueryNodeClient) { req := proto.Clone(req).(*querypb.LoadPartitionsRequest) req.Base.TargetID = nodeID status, err = cli.LoadPartitions(ctx, req) @@ -193,7 +191,7 @@ func (c *QueryCluster) LoadPartitions(ctx context.Context, nodeID int64, req *qu func (c *QueryCluster) ReleasePartitions(ctx context.Context, nodeID int64, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) { var status *commonpb.Status var err error - err1 := c.send(ctx, nodeID, func(cli types.QueryNode) { + err1 := c.send(ctx, nodeID, func(cli types.QueryNodeClient) { req := proto.Clone(req).(*querypb.ReleasePartitionsRequest) req.Base.TargetID = nodeID status, err = cli.ReleasePartitions(ctx, req) @@ -207,7 +205,7 @@ func (c *QueryCluster) ReleasePartitions(ctx context.Context, nodeID int64, req func (c *QueryCluster) GetDataDistribution(ctx context.Context, nodeID int64, req *querypb.GetDataDistributionRequest) (*querypb.GetDataDistributionResponse, error) { var resp *querypb.GetDataDistributionResponse var err error - err1 := c.send(ctx, nodeID, func(cli types.QueryNode) { + err1 := c.send(ctx, nodeID, func(cli types.QueryNodeClient) { req := 
proto.Clone(req).(*querypb.GetDataDistributionRequest) req.Base = &commonpb.MsgBase{ TargetID: nodeID, @@ -225,7 +223,7 @@ func (c *QueryCluster) GetMetrics(ctx context.Context, nodeID int64, req *milvus resp *milvuspb.GetMetricsResponse err error ) - err1 := c.send(ctx, nodeID, func(cli types.QueryNode) { + err1 := c.send(ctx, nodeID, func(cli types.QueryNodeClient) { resp, err = cli.GetMetrics(ctx, req) }) if err1 != nil { @@ -239,7 +237,7 @@ func (c *QueryCluster) SyncDistribution(ctx context.Context, nodeID int64, req * resp *commonpb.Status err error ) - err1 := c.send(ctx, nodeID, func(cli types.QueryNode) { + err1 := c.send(ctx, nodeID, func(cli types.QueryNodeClient) { req := proto.Clone(req).(*querypb.SyncDistributionRequest) req.Base.TargetID = nodeID resp, err = cli.SyncDistribution(ctx, req) @@ -255,8 +253,8 @@ func (c *QueryCluster) GetComponentStates(ctx context.Context, nodeID int64) (*m resp *milvuspb.ComponentStates err error ) - err1 := c.send(ctx, nodeID, func(cli types.QueryNode) { - resp, err = cli.GetComponentStates(ctx) + err1 := c.send(ctx, nodeID, func(cli types.QueryNodeClient) { + resp, err = cli.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) }) if err1 != nil { return nil, err1 @@ -264,7 +262,7 @@ func (c *QueryCluster) GetComponentStates(ctx context.Context, nodeID int64) (*m return resp, err } -func (c *QueryCluster) send(ctx context.Context, nodeID int64, fn func(cli types.QueryNode)) error { +func (c *QueryCluster) send(ctx context.Context, nodeID int64, fn func(cli types.QueryNodeClient)) error { node := c.nodeManager.Get(nodeID) if node == nil { return WrapErrNodeNotFound(nodeID) @@ -281,7 +279,7 @@ func (c *QueryCluster) send(ctx context.Context, nodeID int64, fn func(cli types type clients struct { sync.RWMutex - clients map[int64]types.QueryNode // nodeID -> client + clients map[int64]types.QueryNodeClient // nodeID -> client queryNodeCreator QueryNodeCreator } @@ -296,28 +294,22 @@ func (c *clients) 
getAllNodeIDs() []int64 { return ret } -func (c *clients) getOrCreate(ctx context.Context, node *NodeInfo) (types.QueryNode, error) { +func (c *clients) getOrCreate(ctx context.Context, node *NodeInfo) (types.QueryNodeClient, error) { if cli := c.get(node.ID()); cli != nil { return cli, nil } return c.create(node) } -func createNewClient(ctx context.Context, addr string, nodeID int64, queryNodeCreator QueryNodeCreator) (types.QueryNode, error) { +func createNewClient(ctx context.Context, addr string, nodeID int64, queryNodeCreator QueryNodeCreator) (types.QueryNodeClient, error) { newCli, err := queryNodeCreator(ctx, addr, nodeID) if err != nil { return nil, err } - if err = newCli.Init(); err != nil { - return nil, err - } - if err = newCli.Start(); err != nil { - return nil, err - } return newCli, nil } -func (c *clients) create(node *NodeInfo) (types.QueryNode, error) { +func (c *clients) create(node *NodeInfo) (types.QueryNodeClient, error) { c.Lock() defer c.Unlock() if cli, ok := c.clients[node.ID()]; ok { @@ -331,7 +323,7 @@ func (c *clients) create(node *NodeInfo) (types.QueryNode, error) { return cli, nil } -func (c *clients) get(nodeID int64) types.QueryNode { +func (c *clients) get(nodeID int64) types.QueryNodeClient { c.RLock() defer c.RUnlock() return c.clients[nodeID] @@ -341,7 +333,7 @@ func (c *clients) close(nodeID int64) { c.Lock() defer c.Unlock() if cli, ok := c.clients[nodeID]; ok { - if err := cli.Stop(); err != nil { + if err := cli.Close(); err != nil { log.Warn("error occurred during stopping client", zap.Int64("nodeID", nodeID), zap.Error(err)) } delete(c.clients, nodeID) @@ -352,7 +344,7 @@ func (c *clients) closeAll() { c.Lock() defer c.Unlock() for nodeID, cli := range c.clients { - if err := cli.Stop(); err != nil { + if err := cli.Close(); err != nil { log.Warn("error occurred during stopping client", zap.Int64("nodeID", nodeID), zap.Error(err)) } } @@ -360,7 +352,7 @@ func (c *clients) closeAll() { func newClients(queryNodeCreator 
QueryNodeCreator) *clients { return &clients{ - clients: make(map[int64]types.QueryNode), + clients: make(map[int64]types.QueryNodeClient), queryNodeCreator: queryNodeCreator, } } diff --git a/internal/querycoordv2/session/cluster_test.go b/internal/querycoordv2/session/cluster_test.go index d888387e4576b..4720a2db582b0 100644 --- a/internal/querycoordv2/session/cluster_test.go +++ b/internal/querycoordv2/session/cluster_test.go @@ -19,18 +19,21 @@ package session import ( "context" "net" + "strconv" "testing" "time" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querycoordv2/mocks" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/suite" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" ) const bufSize = 1024 * 1024 @@ -45,10 +48,12 @@ type ClusterTestSuite struct { func (suite *ClusterTestSuite) SetupSuite() { paramtable.Init() + paramtable.Get().Save("grpc.client.maxMaxAttempts", "1") suite.setupServers() } func (suite *ClusterTestSuite) TearDownSuite() { + paramtable.Get().Save("grpc.client.maxMaxAttempts", strconv.FormatInt(paramtable.DefaultMaxAttempts, 10)) for _, svr := range suite.svrs { svr.GracefulStop() } @@ -61,6 +66,7 @@ func (suite *ClusterTestSuite) SetupTest() { func (suite *ClusterTestSuite) TearDownTest() { suite.cluster.Stop() } + func (suite *ClusterTestSuite) setupServers() { svrs := suite.createTestServers() for _, svr := range svrs { @@ -103,10 +109,7 @@ func (suite *ClusterTestSuite) createTestServers() []querypb.QueryNodeServer { } func (suite *ClusterTestSuite) createDefaultMockServer() 
querypb.QueryNodeServer { - succStatus := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - } + succStatus := merr.Success() svr := mocks.NewMockQueryNodeServer(suite.T()) // TODO: register more mock methods svr.EXPECT().LoadSegments( @@ -215,10 +218,7 @@ func (suite *ClusterTestSuite) TestLoadSegments() { Infos: []*querypb.SegmentLoadInfo{{}}, }) suite.NoError(err) - suite.Equal(&commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, status) + suite.Equal(merr.Success(), status) status, err = suite.cluster.LoadSegments(ctx, 1, &querypb.LoadSegmentsRequest{ Base: &commonpb.MsgBase{}, @@ -244,10 +244,7 @@ func (suite *ClusterTestSuite) TestWatchDmChannels() { Base: &commonpb.MsgBase{}, }) suite.NoError(err) - suite.Equal(&commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, status) + suite.Equal(merr.Success(), status) status, err = suite.cluster.WatchDmChannels(ctx, 1, &querypb.WatchDmChannelsRequest{ Base: &commonpb.MsgBase{}, @@ -265,10 +262,7 @@ func (suite *ClusterTestSuite) TestUnsubDmChannel() { Base: &commonpb.MsgBase{}, }) suite.NoError(err) - suite.Equal(&commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, status) + suite.Equal(merr.Success(), status) status, err = suite.cluster.UnsubDmChannel(ctx, 1, &querypb.UnsubDmChannelRequest{ Base: &commonpb.MsgBase{}, @@ -286,10 +280,7 @@ func (suite *ClusterTestSuite) TestReleaseSegments() { Base: &commonpb.MsgBase{}, }) suite.NoError(err) - suite.Equal(&commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, status) + suite.Equal(merr.Success(), status) status, err = suite.cluster.ReleaseSegments(ctx, 1, &querypb.ReleaseSegmentsRequest{ Base: &commonpb.MsgBase{}, @@ -307,10 +298,7 @@ func (suite *ClusterTestSuite) TestLoadAndReleasePartitions() { Base: &commonpb.MsgBase{}, }) suite.NoError(err) - suite.Equal(&commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, status) + 
suite.Equal(merr.Success(), status) status, err = suite.cluster.LoadPartitions(ctx, 1, &querypb.LoadPartitionsRequest{ Base: &commonpb.MsgBase{}, @@ -325,10 +313,7 @@ func (suite *ClusterTestSuite) TestLoadAndReleasePartitions() { Base: &commonpb.MsgBase{}, }) suite.NoError(err) - suite.Equal(&commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, status) + suite.Equal(merr.Success(), status) status, err = suite.cluster.ReleasePartitions(ctx, 1, &querypb.ReleasePartitionsRequest{ Base: &commonpb.MsgBase{}, @@ -346,10 +331,7 @@ func (suite *ClusterTestSuite) TestGetDataDistribution() { Base: &commonpb.MsgBase{}, }) suite.NoError(err) - suite.Equal(&commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, resp.GetStatus()) + suite.Equal(merr.Success(), resp.GetStatus()) resp, err = suite.cluster.GetDataDistribution(ctx, 1, &querypb.GetDataDistributionRequest{ Base: &commonpb.MsgBase{}, @@ -366,10 +348,7 @@ func (suite *ClusterTestSuite) TestGetMetrics() { ctx := context.TODO() resp, err := suite.cluster.GetMetrics(ctx, 0, &milvuspb.GetMetricsRequest{}) suite.NoError(err) - suite.Equal(&commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, resp.GetStatus()) + suite.Equal(merr.Success(), resp.GetStatus()) resp, err = suite.cluster.GetMetrics(ctx, 1, &milvuspb.GetMetricsRequest{}) suite.NoError(err) @@ -385,10 +364,7 @@ func (suite *ClusterTestSuite) TestSyncDistribution() { Base: &commonpb.MsgBase{}, }) suite.NoError(err) - suite.Equal(&commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, status) + suite.Equal(merr.Success(), status) status, err = suite.cluster.SyncDistribution(ctx, 1, &querypb.SyncDistributionRequest{ Base: &commonpb.MsgBase{}, diff --git a/internal/querycoordv2/session/mock_cluster.go b/internal/querycoordv2/session/mock_cluster.go index 6d9b96d0cfbb3..dbc14c720ce98 100644 --- a/internal/querycoordv2/session/mock_cluster.go +++ 
b/internal/querycoordv2/session/mock_cluster.go @@ -418,9 +418,9 @@ func (_c *MockCluster_ReleaseSegments_Call) RunAndReturn(run func(context.Contex return _c } -// Start provides a mock function with given fields: ctx -func (_m *MockCluster) Start(ctx context.Context) { - _m.Called(ctx) +// Start provides a mock function with given fields: +func (_m *MockCluster) Start() { + _m.Called() } // MockCluster_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start' @@ -429,14 +429,13 @@ type MockCluster_Start_Call struct { } // Start is a helper method to define mock.On call -// - ctx context.Context -func (_e *MockCluster_Expecter) Start(ctx interface{}) *MockCluster_Start_Call { - return &MockCluster_Start_Call{Call: _e.mock.On("Start", ctx)} +func (_e *MockCluster_Expecter) Start() *MockCluster_Start_Call { + return &MockCluster_Start_Call{Call: _e.mock.On("Start")} } -func (_c *MockCluster_Start_Call) Run(run func(ctx context.Context)) *MockCluster_Start_Call { +func (_c *MockCluster_Start_Call) Run(run func()) *MockCluster_Start_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context)) + run() }) return _c } @@ -446,7 +445,7 @@ func (_c *MockCluster_Start_Call) Return() *MockCluster_Start_Call { return _c } -func (_c *MockCluster_Start_Call) RunAndReturn(run func(context.Context)) *MockCluster_Start_Call { +func (_c *MockCluster_Start_Call) RunAndReturn(run func()) *MockCluster_Start_Call { _c.Call.Return(run) return _c } diff --git a/internal/querycoordv2/session/node_manager.go b/internal/querycoordv2/session/node_manager.go index bb0f51f8d4725..451a043f3a549 100644 --- a/internal/querycoordv2/session/node_manager.go +++ b/internal/querycoordv2/session/node_manager.go @@ -21,8 +21,9 @@ import ( "sync" "time" - "github.com/milvus-io/milvus/pkg/metrics" "go.uber.org/atomic" + + "github.com/milvus-io/milvus/pkg/metrics" ) type Manager interface { diff --git a/internal/querycoordv2/task/executor.go 
b/internal/querycoordv2/task/executor.go index 71331d55d7d75..0d1e5ccdd397c 100644 --- a/internal/querycoordv2/task/executor.go +++ b/internal/querycoordv2/task/executor.go @@ -22,8 +22,6 @@ import ( "time" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus/pkg/util/tsoutil" - "github.com/milvus-io/milvus/pkg/util/typeutil" "go.uber.org/atomic" "go.uber.org/zap" @@ -35,6 +33,8 @@ import ( "github.com/milvus-io/milvus/internal/querycoordv2/utils" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/tsoutil" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) type Executor struct { @@ -59,7 +59,8 @@ func NewExecutor(meta *meta.Meta, broker meta.Broker, targetMgr *meta.TargetManager, cluster session.Cluster, - nodeMgr *session.NodeManager) *Executor { + nodeMgr *session.NodeManager, +) *Executor { return &Executor{ doneCh: make(chan struct{}), meta: meta, @@ -223,7 +224,8 @@ func (ex *Executor) executeSegmentAction(task *SegmentTask, step int) { func (ex *Executor) loadSegment(task *SegmentTask, step int) error { action := task.Actions()[step].(*SegmentAction) defer action.rpcReturned.Store(true) - log := log.With( + ctx := task.Context() + log := log.Ctx(ctx).With( zap.Int64("taskID", task.ID()), zap.Int64("collectionID", task.CollectionID()), zap.Int64("replicaID", task.ReplicaID()), @@ -240,7 +242,6 @@ func (ex *Executor) loadSegment(task *SegmentTask, step int) error { } }() - ctx := task.Context() schema, err := ex.broker.GetCollectionSchema(ctx, task.CollectionID()) if err != nil { log.Warn("failed to get schema of collection", zap.Error(err)) @@ -285,7 +286,7 @@ func (ex *Executor) loadSegment(task *SegmentTask, step int) error { } log = log.With(zap.Int64("shardLeader", leader)) - //Get collection index info + // Get collection index info indexInfo, err := ex.broker.DescribeIndex(ctx, task.CollectionID()) if err != nil { log.Warn("fail to get index meta of collection") @@ -355,7 
+356,7 @@ func (ex *Executor) releaseSegment(task *SegmentTask, step int) { log.Warn("failed to release segment, it may be a false failure", zap.Error(err)) return } - if status.ErrorCode != commonpb.ErrorCode_Success { + if status.GetErrorCode() != commonpb.ErrorCode_Success { log.Warn("failed to release segment", zap.String("reason", status.GetReason())) return } diff --git a/internal/querycoordv2/task/mock_scheduler.go b/internal/querycoordv2/task/mock_scheduler.go index de5567b162d75..f9dd83835b512 100644 --- a/internal/querycoordv2/task/mock_scheduler.go +++ b/internal/querycoordv2/task/mock_scheduler.go @@ -2,11 +2,7 @@ package task -import ( - context "context" - - mock "github.com/stretchr/testify/mock" -) +import mock "github.com/stretchr/testify/mock" // MockScheduler is an autogenerated mock type for the Scheduler type type MockScheduler struct { @@ -361,9 +357,9 @@ func (_c *MockScheduler_RemoveExecutor_Call) RunAndReturn(run func(int64)) *Mock return _c } -// Start provides a mock function with given fields: ctx -func (_m *MockScheduler) Start(ctx context.Context) { - _m.Called(ctx) +// Start provides a mock function with given fields: +func (_m *MockScheduler) Start() { + _m.Called() } // MockScheduler_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start' @@ -372,14 +368,13 @@ type MockScheduler_Start_Call struct { } // Start is a helper method to define mock.On call -// - ctx context.Context -func (_e *MockScheduler_Expecter) Start(ctx interface{}) *MockScheduler_Start_Call { - return &MockScheduler_Start_Call{Call: _e.mock.On("Start", ctx)} +func (_e *MockScheduler_Expecter) Start() *MockScheduler_Start_Call { + return &MockScheduler_Start_Call{Call: _e.mock.On("Start")} } -func (_c *MockScheduler_Start_Call) Run(run func(ctx context.Context)) *MockScheduler_Start_Call { +func (_c *MockScheduler_Start_Call) Run(run func()) *MockScheduler_Start_Call { _c.Call.Run(func(args mock.Arguments) { - 
run(args[0].(context.Context)) + run() }) return _c } @@ -389,7 +384,7 @@ func (_c *MockScheduler_Start_Call) Return() *MockScheduler_Start_Call { return _c } -func (_c *MockScheduler_Start_Call) RunAndReturn(run func(context.Context)) *MockScheduler_Start_Call { +func (_c *MockScheduler_Start_Call) RunAndReturn(run func()) *MockScheduler_Start_Call { _c.Call.Return(run) return _c } diff --git a/internal/querycoordv2/task/scheduler.go b/internal/querycoordv2/task/scheduler.go index f9e31594f5913..412270a10f321 100644 --- a/internal/querycoordv2/task/scheduler.go +++ b/internal/querycoordv2/task/scheduler.go @@ -23,10 +23,11 @@ import ( "sync" "time" + "github.com/cockroachdb/errors" + "github.com/samber/lo" "go.uber.org/atomic" "go.uber.org/zap" - "github.com/cockroachdb/errors" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querycoordv2/meta" @@ -37,7 +38,6 @@ import ( "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/merr" . 
"github.com/milvus-io/milvus/pkg/util/typeutil" - "github.com/samber/lo" ) const ( @@ -126,7 +126,7 @@ func (queue *taskQueue) Range(fn func(task Task) bool) { } type Scheduler interface { - Start(ctx context.Context) + Start() Stop() AddExecutor(nodeID int64) RemoveExecutor(nodeID int64) @@ -165,7 +165,8 @@ func NewScheduler(ctx context.Context, targetMgr *meta.TargetManager, broker meta.Broker, cluster session.Cluster, - nodeMgr *session.NodeManager) *taskScheduler { + nodeMgr *session.NodeManager, +) *taskScheduler { id := time.Now().UnixMilli() return &taskScheduler{ ctx: ctx, @@ -190,7 +191,7 @@ func NewScheduler(ctx context.Context, } } -func (scheduler *taskScheduler) Start(ctx context.Context) {} +func (scheduler *taskScheduler) Start() {} func (scheduler *taskScheduler) Stop() { scheduler.rwmutex.Lock() @@ -200,6 +201,13 @@ func (scheduler *taskScheduler) Stop() { executor.Stop() delete(scheduler.executors, nodeID) } + + for _, task := range scheduler.segmentTasks { + scheduler.remove(task) + } + for _, task := range scheduler.channelTasks { + scheduler.remove(task) + } } func (scheduler *taskScheduler) AddExecutor(nodeID int64) { @@ -258,7 +266,7 @@ func (scheduler *taskScheduler) Add(task Task) error { } scheduler.updateTaskMetrics() - log.Info("task added", zap.String("task", task.String())) + log.Ctx(task.Context()).Info("task added", zap.String("task", task.String())) return nil } @@ -688,7 +696,7 @@ func (scheduler *taskScheduler) recordSegmentTaskError(task *SegmentTask) { } func (scheduler *taskScheduler) remove(task Task) { - log := log.With( + log := log.Ctx(task.Context()).With( zap.Int64("taskID", task.ID()), zap.Int64("collectionID", task.CollectionID()), zap.Int64("replicaID", task.ReplicaID()), diff --git a/internal/querycoordv2/task/task.go b/internal/querycoordv2/task/task.go index 8585206439b7d..9b7b507f5ae5f 100644 --- a/internal/querycoordv2/task/task.go +++ b/internal/querycoordv2/task/task.go @@ -22,15 +22,20 @@ import ( "time" 
"github.com/cockroachdb/errors" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/trace" + "go.uber.org/atomic" + "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querycoordv2/meta" "github.com/milvus-io/milvus/pkg/util/merr" . "github.com/milvus-io/milvus/pkg/util/typeutil" - "go.uber.org/atomic" ) -type Status = int32 -type Priority int32 +type ( + Status = int32 + Priority int32 +) const ( TaskStatusCreated Status = iota + 1 @@ -49,17 +54,15 @@ const ( var TaskPriorityName = map[Priority]string{ TaskPriorityLow: "Low", TaskPriorityNormal: "Normal", - TaskPriorityHigh: "Hight", + TaskPriorityHigh: "High", } func (p Priority) String() string { return TaskPriorityName[p] } -var ( - // All task priorities from low to high - TaskPriorities = []Priority{TaskPriorityLow, TaskPriorityNormal, TaskPriorityHigh} -) +// All task priorities from low to high +var TaskPriorities = []Priority{TaskPriorityLow, TaskPriorityNormal, TaskPriorityHigh} type Task interface { Context() context.Context @@ -108,10 +111,14 @@ type baseTask struct { actions []Action step int reason string + + // span for tracing + span trace.Span } func newBaseTask(ctx context.Context, sourceID, collectionID, replicaID UniqueID, shard string) *baseTask { ctx, cancel := context.WithCancel(ctx) + ctx, span := otel.Tracer("QueryCoord").Start(ctx, "QueryCoord-BaseTask") return &baseTask{ sourceID: sourceID, @@ -125,6 +132,7 @@ func newBaseTask(ctx context.Context, sourceID, collectionID, replicaID UniqueID cancel: cancel, doneCh: make(chan struct{}), canceled: atomic.NewBool(false), + span: span, } } @@ -193,6 +201,9 @@ func (task *baseTask) Cancel(err error) { } task.err = err close(task.doneCh) + if task.span != nil { + task.span.End() + } } } @@ -276,7 +287,8 @@ func NewSegmentTask(ctx context.Context, sourceID, collectionID, replicaID UniqueID, - actions ...Action) (*SegmentTask, error) { + actions ...Action, +) (*SegmentTask, error) { if len(actions) == 0 { 
return nil, errors.WithStack(merr.WrapErrParameterInvalid("non-empty actions", "no action")) } @@ -332,7 +344,8 @@ func NewChannelTask(ctx context.Context, sourceID, collectionID, replicaID UniqueID, - actions ...Action) (*ChannelTask, error) { + actions ...Action, +) (*ChannelTask, error) { if len(actions) == 0 { return nil, errors.WithStack(merr.WrapErrParameterInvalid("non-empty actions", "no action")) } diff --git a/internal/querycoordv2/task/task_test.go b/internal/querycoordv2/task/task_test.go index 6c1779867f8d2..a682ccffe97e0 100644 --- a/internal/querycoordv2/task/task_test.go +++ b/internal/querycoordv2/task/task_test.go @@ -146,7 +146,7 @@ func (suite *TaskSuite) SetupTest() { suite.cluster = session.NewMockCluster(suite.T()) suite.scheduler = suite.newScheduler() - suite.scheduler.Start(context.Background()) + suite.scheduler.Start() suite.scheduler.AddExecutor(1) suite.scheduler.AddExecutor(2) suite.scheduler.AddExecutor(3) @@ -205,13 +205,15 @@ func (suite *TaskSuite) TestSubscribeChannelTask() { }, nil) for channel, segment := range suite.growingSegments { suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment). 
- Return(&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{ - { - ID: segment, - CollectionID: suite.collection, - PartitionID: partitions[0], - InsertChannel: channel, - }}, + Return(&datapb.GetSegmentInfoResponse{ + Infos: []*datapb.SegmentInfo{ + { + ID: segment, + CollectionID: suite.collection, + PartitionID: partitions[0], + InsertChannel: channel, + }, + }, }, nil) } suite.broker.EXPECT().DescribeIndex(mock.Anything, suite.collection).Return([]*indexpb.IndexInfo{ @@ -226,7 +228,7 @@ func (suite *TaskSuite) TestSubscribeChannelTask() { }, }, }, nil) - suite.cluster.EXPECT().WatchDmChannels(mock.Anything, targetNode, mock.Anything).Return(merr.Status(nil), nil) + suite.cluster.EXPECT().WatchDmChannels(mock.Anything, targetNode, mock.Anything).Return(merr.Success(), nil) // Test subscribe channel task tasks := []Task{} @@ -321,7 +323,7 @@ func (suite *TaskSuite) TestUnsubscribeChannelTask() { targetNode := int64(1) // Expect - suite.cluster.EXPECT().UnsubDmChannel(mock.Anything, targetNode, mock.Anything).Return(merr.Status(nil), nil) + suite.cluster.EXPECT().UnsubDmChannel(mock.Anything, targetNode, mock.Anything).Return(merr.Success(), nil) // Test unsubscribe channel task tasks := []Task{} @@ -394,17 +396,19 @@ func (suite *TaskSuite) TestLoadSegmentTask() { }, }, nil) for _, segment := range suite.loadSegments { - suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{ - { - ID: segment, - CollectionID: suite.collection, - PartitionID: partition, - InsertChannel: channel.ChannelName, - }}, + suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{ + Infos: []*datapb.SegmentInfo{ + { + ID: segment, + CollectionID: suite.collection, + PartitionID: partition, + InsertChannel: channel.ChannelName, + }, + }, }, nil) suite.broker.EXPECT().GetIndexInfo(mock.Anything, suite.collection, segment).Return(nil, nil) } - 
suite.cluster.EXPECT().LoadSegments(mock.Anything, targetNode, mock.Anything).Return(merr.Status(nil), nil) + suite.cluster.EXPECT().LoadSegments(mock.Anything, targetNode, mock.Anything).Return(merr.Success(), nil) // Test load segment task suite.dist.ChannelDistManager.Update(targetNode, meta.DmChannelFromVChannel(&datapb.VchannelInfo{ @@ -488,17 +492,19 @@ func (suite *TaskSuite) TestLoadSegmentTaskNotIndex() { }, }, nil) for _, segment := range suite.loadSegments { - suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{ - { - ID: segment, - CollectionID: suite.collection, - PartitionID: partition, - InsertChannel: channel.ChannelName, - }}, + suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{ + Infos: []*datapb.SegmentInfo{ + { + ID: segment, + CollectionID: suite.collection, + PartitionID: partition, + InsertChannel: channel.ChannelName, + }, + }, }, nil) - suite.broker.EXPECT().GetIndexInfo(mock.Anything, suite.collection, segment).Return(nil, merr.WrapErrIndexNotFound()) + suite.broker.EXPECT().GetIndexInfo(mock.Anything, suite.collection, segment).Return(nil, merr.WrapErrIndexNotFoundForSegment(segment)) } - suite.cluster.EXPECT().LoadSegments(mock.Anything, targetNode, mock.Anything).Return(merr.Status(nil), nil) + suite.cluster.EXPECT().LoadSegments(mock.Anything, targetNode, mock.Anything).Return(merr.Success(), nil) // Test load segment task suite.dist.ChannelDistManager.Update(targetNode, meta.DmChannelFromVChannel(&datapb.VchannelInfo{ @@ -577,13 +583,15 @@ func (suite *TaskSuite) TestLoadSegmentTaskFailed() { }, }, nil) for _, segment := range suite.loadSegments { - suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{ - { - ID: segment, - CollectionID: suite.collection, - PartitionID: partition, - InsertChannel: channel.ChannelName, - }}, + 
suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{ + Infos: []*datapb.SegmentInfo{ + { + ID: segment, + CollectionID: suite.collection, + PartitionID: partition, + InsertChannel: channel.ChannelName, + }, + }, }, nil) suite.broker.EXPECT().GetIndexInfo(mock.Anything, suite.collection, segment).Return(nil, errors.New("index not ready")) } @@ -645,7 +653,7 @@ func (suite *TaskSuite) TestReleaseSegmentTask() { } // Expect - suite.cluster.EXPECT().ReleaseSegments(mock.Anything, targetNode, mock.Anything).Return(merr.Status(nil), nil) + suite.cluster.EXPECT().ReleaseSegments(mock.Anything, targetNode, mock.Anything).Return(merr.Success(), nil) // Test load segment task view := &meta.LeaderView{ @@ -706,7 +714,7 @@ func (suite *TaskSuite) TestReleaseGrowingSegmentTask() { targetNode := int64(3) // Expect - suite.cluster.EXPECT().ReleaseSegments(mock.Anything, targetNode, mock.Anything).Return(merr.Status(nil), nil) + suite.cluster.EXPECT().ReleaseSegments(mock.Anything, targetNode, mock.Anything).Return(merr.Success(), nil) tasks := []Task{} for _, segment := range suite.releaseSegments { @@ -778,18 +786,20 @@ func (suite *TaskSuite) TestMoveSegmentTask() { }, }, nil) for _, segment := range suite.moveSegments { - suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{ - { - ID: segment, - CollectionID: suite.collection, - PartitionID: partition, - InsertChannel: channel.ChannelName, - }}, + suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{ + Infos: []*datapb.SegmentInfo{ + { + ID: segment, + CollectionID: suite.collection, + PartitionID: partition, + InsertChannel: channel.ChannelName, + }, + }, }, nil) suite.broker.EXPECT().GetIndexInfo(mock.Anything, suite.collection, segment).Return(nil, nil) } - suite.cluster.EXPECT().LoadSegments(mock.Anything, leader, mock.Anything).Return(merr.Status(nil), nil) 
- suite.cluster.EXPECT().ReleaseSegments(mock.Anything, leader, mock.Anything).Return(merr.Status(nil), nil) + suite.cluster.EXPECT().LoadSegments(mock.Anything, leader, mock.Anything).Return(merr.Success(), nil) + suite.cluster.EXPECT().ReleaseSegments(mock.Anything, leader, mock.Anything).Return(merr.Success(), nil) vchannel := &datapb.VchannelInfo{ CollectionID: suite.collection, ChannelName: channel.ChannelName, @@ -946,17 +956,19 @@ func (suite *TaskSuite) TestTaskCanceled() { }, }, nil) for _, segment := range suite.loadSegments { - suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{ - { - ID: segment, - CollectionID: suite.collection, - PartitionID: partition, - InsertChannel: channel.ChannelName, - }}, + suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{ + Infos: []*datapb.SegmentInfo{ + { + ID: segment, + CollectionID: suite.collection, + PartitionID: partition, + InsertChannel: channel.ChannelName, + }, + }, }, nil) suite.broker.EXPECT().GetIndexInfo(mock.Anything, suite.collection, segment).Return(nil, nil) } - suite.cluster.EXPECT().LoadSegments(mock.Anything, targetNode, mock.Anything).Return(merr.Status(nil), nil) + suite.cluster.EXPECT().LoadSegments(mock.Anything, targetNode, mock.Anything).Return(merr.Success(), nil) // Test load segment task suite.dist.ChannelDistManager.Update(targetNode, meta.DmChannelFromVChannel(&datapb.VchannelInfo{ @@ -1031,17 +1043,19 @@ func (suite *TaskSuite) TestSegmentTaskStale() { }, }, nil) for _, segment := range suite.loadSegments { - suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{ - { - ID: segment, - CollectionID: suite.collection, - PartitionID: partition, - InsertChannel: channel.ChannelName, - }}, + suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{ + Infos: 
[]*datapb.SegmentInfo{ + { + ID: segment, + CollectionID: suite.collection, + PartitionID: partition, + InsertChannel: channel.ChannelName, + }, + }, }, nil) suite.broker.EXPECT().GetIndexInfo(mock.Anything, suite.collection, segment).Return(nil, nil) } - suite.cluster.EXPECT().LoadSegments(mock.Anything, targetNode, mock.Anything).Return(merr.Status(nil), nil) + suite.cluster.EXPECT().LoadSegments(mock.Anything, targetNode, mock.Anything).Return(merr.Success(), nil) // Test load segment task suite.meta.ReplicaManager.Put(createReplica(suite.collection, targetNode)) diff --git a/internal/querycoordv2/utils/types.go b/internal/querycoordv2/utils/types.go index 456bab4e371c8..2c420db196ac6 100644 --- a/internal/querycoordv2/utils/types.go +++ b/internal/querycoordv2/utils/types.go @@ -27,6 +27,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querycoordv2/meta" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/tsoutil" ) @@ -47,6 +48,7 @@ func MergeMetaSegmentIntoSegmentInfo(info *querypb.SegmentInfo, segments ...*met first := segments[0] if info.GetSegmentID() == 0 { *info = querypb.SegmentInfo{ + NodeID: paramtable.GetNodeID(), SegmentID: first.GetID(), CollectionID: first.GetCollectionID(), PartitionID: first.GetPartitionID(), @@ -54,6 +56,12 @@ func MergeMetaSegmentIntoSegmentInfo(info *querypb.SegmentInfo, segments ...*met DmChannel: first.GetInsertChannel(), NodeIds: make([]int64, 0), SegmentState: commonpb.SegmentState_Sealed, + IndexInfos: make([]*querypb.FieldIndexInfo, 0), + } + for _, indexInfo := range first.IndexInfo { + info.IndexName = indexInfo.IndexName + info.IndexID = indexInfo.IndexID + info.IndexInfos = append(info.IndexInfos, indexInfo) } } diff --git a/internal/querycoordv2/utils/types_test.go b/internal/querycoordv2/utils/types_test.go index fb55c7bcc6425..f9bc9d9489e08 100644 --- 
a/internal/querycoordv2/utils/types_test.go +++ b/internal/querycoordv2/utils/types_test.go @@ -21,11 +21,11 @@ import ( "time" "github.com/golang/protobuf/proto" - "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/pkg/util/tsoutil" ) func Test_packLoadSegmentRequest(t *testing.T) { diff --git a/internal/querynodev2/cluster/manager.go b/internal/querynodev2/cluster/manager.go index 3a4387fc23a4d..b239f24736a60 100644 --- a/internal/querynodev2/cluster/manager.go +++ b/internal/querynodev2/cluster/manager.go @@ -39,7 +39,7 @@ type WorkerBuilder func(ctx context.Context, nodeID int64) (Worker, error) type grpcWorkerManager struct { workers *typeutil.ConcurrentMap[int64, Worker] builder WorkerBuilder - sf conc.Singleflight[Worker] //singleflight.Group + sf conc.Singleflight[Worker] // singleflight.Group } // GetWorker returns worker with specified nodeID. 
diff --git a/internal/querynodev2/cluster/manager_test.go b/internal/querynodev2/cluster/manager_test.go index 3009c2677ab40..c953d5d1c5350 100644 --- a/internal/querynodev2/cluster/manager_test.go +++ b/internal/querynodev2/cluster/manager_test.go @@ -21,7 +21,6 @@ import ( "testing" "github.com/cockroachdb/errors" - "github.com/stretchr/testify/assert" ) diff --git a/internal/querynodev2/cluster/mock_worker.go b/internal/querynodev2/cluster/mock_worker.go index 59654da101ce1..ba573bf89d0b3 100644 --- a/internal/querynodev2/cluster/mock_worker.go +++ b/internal/querynodev2/cluster/mock_worker.go @@ -9,6 +9,8 @@ import ( mock "github.com/stretchr/testify/mock" querypb "github.com/milvus-io/milvus/internal/proto/querypb" + + streamrpc "github.com/milvus-io/milvus/internal/util/streamrpc" ) // MockWorker is an autogenerated mock type for the Worker type @@ -261,6 +263,50 @@ func (_c *MockWorker_QuerySegments_Call) RunAndReturn(run func(context.Context, return _c } +// QueryStreamSegments provides a mock function with given fields: ctx, req, srv +func (_m *MockWorker) QueryStreamSegments(ctx context.Context, req *querypb.QueryRequest, srv streamrpc.QueryStreamServer) error { + ret := _m.Called(ctx, req, srv) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.QueryRequest, streamrpc.QueryStreamServer) error); ok { + r0 = rf(ctx, req, srv) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockWorker_QueryStreamSegments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'QueryStreamSegments' +type MockWorker_QueryStreamSegments_Call struct { + *mock.Call +} + +// QueryStreamSegments is a helper method to define mock.On call +// - ctx context.Context +// - req *querypb.QueryRequest +// - srv streamrpc.QueryStreamServer +func (_e *MockWorker_Expecter) QueryStreamSegments(ctx interface{}, req interface{}, srv interface{}) *MockWorker_QueryStreamSegments_Call { + return 
&MockWorker_QueryStreamSegments_Call{Call: _e.mock.On("QueryStreamSegments", ctx, req, srv)} +} + +func (_c *MockWorker_QueryStreamSegments_Call) Run(run func(ctx context.Context, req *querypb.QueryRequest, srv streamrpc.QueryStreamServer)) *MockWorker_QueryStreamSegments_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*querypb.QueryRequest), args[2].(streamrpc.QueryStreamServer)) + }) + return _c +} + +func (_c *MockWorker_QueryStreamSegments_Call) Return(_a0 error) *MockWorker_QueryStreamSegments_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockWorker_QueryStreamSegments_Call) RunAndReturn(run func(context.Context, *querypb.QueryRequest, streamrpc.QueryStreamServer) error) *MockWorker_QueryStreamSegments_Call { + _c.Call.Return(run) + return _c +} + // ReleaseSegments provides a mock function with given fields: _a0, _a1 func (_m *MockWorker) ReleaseSegments(_a0 context.Context, _a1 *querypb.ReleaseSegmentsRequest) error { ret := _m.Called(_a0, _a1) diff --git a/internal/querynodev2/cluster/worker.go b/internal/querynodev2/cluster/worker.go index b0ab0b12c1488..9791de7547b0b 100644 --- a/internal/querynodev2/cluster/worker.go +++ b/internal/querynodev2/cluster/worker.go @@ -20,16 +20,17 @@ package cluster import ( "context" "fmt" + "io" + "github.com/cockroachdb/errors" "go.uber.org/zap" - "google.golang.org/grpc/codes" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/streamrpc" "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/merr" ) @@ -40,6 +41,7 @@ type Worker interface { Delete(ctx context.Context, req *querypb.DeleteRequest) error SearchSegments(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) 
QuerySegments(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) + QueryStreamSegments(ctx context.Context, req *querypb.QueryRequest, srv streamrpc.QueryStreamServer) error GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) (*internalpb.GetStatisticsResponse, error) IsHealthy() bool @@ -48,11 +50,11 @@ type Worker interface { // remoteWorker wraps grpc QueryNode client as Worker. type remoteWorker struct { - client types.QueryNode + client types.QueryNodeClient } // NewRemoteWorker creates a grpcWorker. -func NewRemoteWorker(client types.QueryNode) Worker { +func NewRemoteWorker(client types.QueryNodeClient) Worker { return &remoteWorker{ client: client, } @@ -69,7 +71,7 @@ func (w *remoteWorker) LoadSegments(ctx context.Context, req *querypb.LoadSegmen zap.Error(err), ) return err - } else if status.ErrorCode != commonpb.ErrorCode_Success { + } else if status.GetErrorCode() != commonpb.ErrorCode_Success { log.Warn("failed to call LoadSegments, worker return error", zap.String("errorCode", status.GetErrorCode().String()), zap.String("reason", status.GetReason()), @@ -89,7 +91,7 @@ func (w *remoteWorker) ReleaseSegments(ctx context.Context, req *querypb.Release zap.Error(err), ) return err - } else if status.ErrorCode != commonpb.ErrorCode_Success { + } else if status.GetErrorCode() != commonpb.ErrorCode_Success { log.Warn("failed to call ReleaseSegments, worker return error", zap.String("errorCode", status.GetErrorCode().String()), zap.String("reason", status.GetReason()), @@ -104,24 +106,20 @@ func (w *remoteWorker) Delete(ctx context.Context, req *querypb.DeleteRequest) e zap.Int64("workerID", req.GetBase().GetTargetID()), ) status, err := w.client.Delete(ctx, req) - if err != nil { - log.Warn("failed to call Delete via grpc worker", - zap.Error(err), - ) + if err := merr.CheckRPCCall(status, err); err != nil { + if errors.Is(err, merr.ErrServiceUnimplemented) { + log.Warn("invoke legacy querynode Delete 
method, ignore error", zap.Error(err)) + return nil + } + log.Warn("failed to call Delete, worker return error", zap.Error(err)) return err - } else if status.GetErrorCode() != commonpb.ErrorCode_Success { - log.Warn("failed to call Delete, worker return error", - zap.String("errorCode", status.GetErrorCode().String()), - zap.String("reason", status.GetReason()), - ) - return merr.Error(status) } return nil } func (w *remoteWorker) SearchSegments(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) { ret, err := w.client.SearchSegments(ctx, req) - if err != nil && funcutil.IsGrpcErr(err, codes.Unimplemented) { + if err != nil && errors.Is(err, merr.ErrServiceUnimplemented) { // for compatible with rolling upgrade from version before v2.2.9 return w.client.Search(ctx, req) } @@ -131,7 +129,7 @@ func (w *remoteWorker) SearchSegments(ctx context.Context, req *querypb.SearchRe func (w *remoteWorker) QuerySegments(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) { ret, err := w.client.QuerySegments(ctx, req) - if err != nil && funcutil.IsGrpcErr(err, codes.Unimplemented) { + if err != nil && errors.Is(err, merr.ErrServiceUnimplemented) { // for compatible with rolling upgrade from version before v2.2.9 return w.client.Query(ctx, req) } @@ -139,6 +137,37 @@ func (w *remoteWorker) QuerySegments(ctx context.Context, req *querypb.QueryRequ return ret, err } +func (w *remoteWorker) QueryStreamSegments(ctx context.Context, req *querypb.QueryRequest, srv streamrpc.QueryStreamServer) error { + client, err := w.client.QueryStreamSegments(ctx, req) + if err != nil { + return err + } + + for { + result, err := client.Recv() + if err != nil { + if err == io.EOF { + return nil + } + return err + } + + err = merr.Error(result.GetStatus()) + if err != nil { + return err + } + + err = srv.Send(result) + if err != nil { + log.Warn("send stream pks from remote woker failed", + zap.Int64("collectionID", 
req.Req.GetCollectionID()), + zap.Int64s("segmentIDs", req.GetSegmentIDs()), + ) + return err + } + } +} + func (w *remoteWorker) GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) (*internalpb.GetStatisticsResponse, error) { return w.client.GetStatistics(ctx, req) } @@ -148,5 +177,7 @@ func (w *remoteWorker) IsHealthy() bool { } func (w *remoteWorker) Stop() { - w.client.Stop() + if err := w.client.Close(); err != nil { + log.Warn("failed to call Close via grpc worker", zap.Error(err)) + } } diff --git a/internal/querynodev2/cluster/worker_test.go b/internal/querynodev2/cluster/worker_test.go index f31efd2a4331d..084fa3ead1f9b 100644 --- a/internal/querynodev2/cluster/worker_test.go +++ b/internal/querynodev2/cluster/worker_test.go @@ -19,31 +19,36 @@ package cluster import ( - context "context" + "context" + "io" "testing" "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - commonpb "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/mocks" - internalpb "github.com/milvus-io/milvus/internal/proto/internalpb" - querypb "github.com/milvus-io/milvus/internal/proto/querypb" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/suite" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/util/streamrpc" + "github.com/milvus-io/milvus/pkg/util/merr" ) type RemoteWorkerSuite struct { suite.Suite - mockClient *mocks.MockQueryNode + mockClient *mocks.MockQueryNodeClient worker *remoteWorker } func (s *RemoteWorkerSuite) SetupTest() { - s.mockClient = 
&mocks.MockQueryNode{} + s.mockClient = &mocks.MockQueryNodeClient{} s.worker = &remoteWorker{client: s.mockClient} } @@ -173,6 +178,19 @@ func (s *RemoteWorkerSuite) TestDelete() { s.Error(err) }) + + s.Run("legacy_querynode_unimplemented", func() { + defer func() { s.mockClient.ExpectedCalls = nil }() + + s.mockClient.EXPECT().Delete(mock.Anything, mock.AnythingOfType("*querypb.DeleteRequest")). + Return(nil, merr.WrapErrServiceUnimplemented(status.Errorf(codes.Unimplemented, "mocked grpc unimplemented"))) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + err := s.worker.Delete(ctx, &querypb.DeleteRequest{}) + + s.NoError(err) + }) } func (s *RemoteWorkerSuite) TestSearch() { @@ -240,7 +258,7 @@ func (s *RemoteWorkerSuite) TestSearch() { grpcErr := status.Error(codes.Unimplemented, "method not implemented") s.mockClient.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")). - Return(result, grpcErr) + Return(result, merr.WrapErrServiceUnimplemented(grpcErr)) s.mockClient.EXPECT().Search(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")). Return(result, err) @@ -319,7 +337,7 @@ func (s *RemoteWorkerSuite) TestQuery() { grpcErr := status.Error(codes.Unimplemented, "method not implemented") s.mockClient.EXPECT().QuerySegments(mock.Anything, mock.AnythingOfType("*querypb.QueryRequest")). - Return(result, grpcErr) + Return(result, merr.WrapErrServiceUnimplemented(grpcErr)) s.mockClient.EXPECT().Query(mock.Anything, mock.AnythingOfType("*querypb.QueryRequest")). 
Return(result, err) @@ -332,6 +350,174 @@ func (s *RemoteWorkerSuite) TestQuery() { }) } +func (s *RemoteWorkerSuite) TestQueryStream() { + s.Run("normal_run", func() { + defer func() { s.mockClient.ExpectedCalls = nil }() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + client := streamrpc.NewLocalQueryClient(ctx) + server := client.CreateServer() + + ids := []int64{10, 11, 12} + s.mockClient.EXPECT().QueryStreamSegments( + mock.Anything, + mock.AnythingOfType("*querypb.QueryRequest"), + ).RunAndReturn(func(ctx context.Context, request *querypb.QueryRequest, option ...grpc.CallOption) (querypb.QueryNode_QueryStreamSegmentsClient, error) { + client := streamrpc.NewLocalQueryClient(ctx) + server := client.CreateServer() + + for _, id := range ids { + err := server.Send(&internalpb.RetrieveResults{ + Status: merr.Success(), + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{Data: []int64{id}}, + }, + }, + }) + s.NoError(err) + } + err := server.FinishSend(nil) + s.NoError(err) + return client, nil + }) + + go func() { + err := s.worker.QueryStreamSegments(ctx, &querypb.QueryRequest{}, server) + if err != nil { + server.Send(&internalpb.RetrieveResults{ + Status: merr.Status(err), + }) + } + server.FinishSend(err) + }() + + recNum := 0 + for { + result, err := client.Recv() + if err == io.EOF { + break + } + s.NoError(err) + + err = merr.Error(result.GetStatus()) + s.NoError(err) + + s.Less(recNum, len(ids)) + s.Equal(result.Ids.GetIntId().Data[0], ids[recNum]) + recNum++ + } + }) + + s.Run("send msg failed", func() { + defer func() { s.mockClient.ExpectedCalls = nil }() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + clientCtx, clientClose := context.WithCancel(ctx) + client := streamrpc.NewLocalQueryClient(clientCtx) + server := client.CreateServer() + clientClose() + + ids := []int64{10, 11, 12} + + s.mockClient.EXPECT().QueryStreamSegments( + mock.Anything, + 
mock.AnythingOfType("*querypb.QueryRequest"), + ).RunAndReturn(func(ctx context.Context, request *querypb.QueryRequest, option ...grpc.CallOption) (querypb.QueryNode_QueryStreamSegmentsClient, error) { + for _, id := range ids { + client := streamrpc.NewLocalQueryClient(ctx) + server := client.CreateServer() + + err := server.Send(&internalpb.RetrieveResults{ + Status: merr.Success(), + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{Data: []int64{id}}, + }, + }, + }) + s.NoError(err) + } + err := server.FinishSend(nil) + s.NoError(err) + return client, nil + }) + + err := s.worker.QueryStreamSegments(ctx, &querypb.QueryRequest{}, server) + s.Error(err) + }) + + s.Run("client_return_error", func() { + defer func() { s.mockClient.ExpectedCalls = nil }() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + client := streamrpc.NewLocalQueryClient(ctx) + server := streamrpc.NewConcurrentQueryStreamServer(client.CreateServer()) + + s.mockClient.EXPECT().QueryStreamSegments( + mock.Anything, + mock.AnythingOfType("*querypb.QueryRequest"), + ).Return(nil, errors.New("mocked error")) + + go func() { + err := s.worker.QueryStreamSegments(ctx, &querypb.QueryRequest{}, server) + server.Send(&internalpb.RetrieveResults{ + Status: merr.Status(err), + }) + }() + + result, err := client.Recv() + s.NoError(err) + + err = merr.Error(result.GetStatus()) + // Check result + s.Error(err) + }) + + s.Run("client_return_fail_status", func() { + defer func() { s.mockClient.ExpectedCalls = nil }() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + client := streamrpc.NewLocalQueryClient(ctx) + server := client.CreateServer() + + s.mockClient.EXPECT().QueryStreamSegments( + mock.Anything, + mock.AnythingOfType("*querypb.QueryRequest"), + ).RunAndReturn(func(ctx context.Context, request *querypb.QueryRequest, option ...grpc.CallOption) (querypb.QueryNode_QueryStreamSegmentsClient, error) { + client := 
streamrpc.NewLocalQueryClient(ctx) + server := client.CreateServer() + + err := server.Send(&internalpb.RetrieveResults{ + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, + }) + s.NoError(err) + + err = server.FinishSend(nil) + s.NoError(err) + return client, nil + }) + + go func() { + err := s.worker.QueryStreamSegments(ctx, &querypb.QueryRequest{}, server) + server.Send(&internalpb.RetrieveResults{ + Status: merr.Status(err), + }) + }() + + result, err := client.Recv() + s.NoError(err) + + err = merr.Error(result.GetStatus()) + // Check result + s.Error(err) + }) +} + func (s *RemoteWorkerSuite) TestGetStatistics() { s.Run("normal_run", func() { defer func() { s.mockClient.ExpectedCalls = nil }() @@ -394,9 +580,9 @@ func (s *RemoteWorkerSuite) TestGetStatistics() { func (s *RemoteWorkerSuite) TestBasic() { s.True(s.worker.IsHealthy()) - s.mockClient.EXPECT().Stop().Return(nil) + s.mockClient.EXPECT().Close().Return(nil) s.worker.Stop() - s.mockClient.AssertCalled(s.T(), "Stop") + s.mockClient.AssertCalled(s.T(), "Close") } func TestRemoteWorker(t *testing.T) { @@ -404,8 +590,7 @@ func TestRemoteWorker(t *testing.T) { } func TestNewRemoteWorker(t *testing.T) { - client := &mocks.MockQueryNode{} - + client := mocks.NewMockQueryNodeClient(t) w := NewRemoteWorker(client) rw, ok := w.(*remoteWorker) diff --git a/internal/querynodev2/collector/average_test.go b/internal/querynodev2/collector/average_test.go index e1eaa1bdcaf0d..9306ec7facfad 100644 --- a/internal/querynodev2/collector/average_test.go +++ b/internal/querynodev2/collector/average_test.go @@ -34,17 +34,17 @@ func (suite *AverageCollectorTestSuite) SetupSuite() { } func (suite *AverageCollectorTestSuite) TestBasic() { - //Get average not register + // Get average not register _, err := suite.average.Average(suite.label) suite.Error(err) - //register and get + // register and get suite.average.Register(suite.label) value, err := suite.average.Average(suite.label) 
suite.Equal(float64(0), value) suite.NoError(err) - //add and get + // add and get sum := 4 for i := 0; i <= sum; i++ { suite.average.Add(suite.label, float64(i)) diff --git a/internal/querynodev2/collector/collector.go b/internal/querynodev2/collector/collector.go index c39fde724e947..797a29d319863 100644 --- a/internal/querynodev2/collector/collector.go +++ b/internal/querynodev2/collector/collector.go @@ -17,10 +17,11 @@ package collector import ( + "go.uber.org/zap" + "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/ratelimitutil" - "go.uber.org/zap" ) var Average *averageCollector @@ -65,11 +66,11 @@ func init() { Average = newAverageCollector() Counter = newCounter() - //init rate Metric + // init rate Metric for _, label := range RateMetrics() { Rate.Register(label) } - //init average metric + // init average metric for _, label := range AverageMetrics() { Average.Register(label) diff --git a/internal/querynodev2/collector/counter_test.go b/internal/querynodev2/collector/counter_test.go index 74ce05ff08471..731dd6477b98c 100644 --- a/internal/querynodev2/collector/counter_test.go +++ b/internal/querynodev2/collector/counter_test.go @@ -34,26 +34,26 @@ func (suite *CounterTestSuite) SetupSuite() { } func (suite *CounterTestSuite) TestBasic() { - //get default value(zero) + // get default value(zero) value := suite.counter.Get(suite.label) suite.Equal(int64(0), value) - //get after inc + // get after inc suite.counter.Inc(suite.label, 3) value = suite.counter.Get(suite.label) suite.Equal(int64(3), value) - //remove + // remove suite.counter.Remove(suite.label) value = suite.counter.Get(suite.label) suite.Equal(int64(0), value) - //get after dec + // get after dec suite.counter.Dec(suite.label, 3) value = suite.counter.Get(suite.label) suite.Equal(int64(-3), value) - //remove + // remove suite.counter.Remove(suite.label) value = suite.counter.Get(suite.label) suite.Equal(int64(0), 
value) diff --git a/internal/querynodev2/delegator/delegator.go b/internal/querynodev2/delegator/delegator.go index 18f9e02eebb8d..2796fe70269c8 100644 --- a/internal/querynodev2/delegator/delegator.go +++ b/internal/querynodev2/delegator/delegator.go @@ -37,42 +37,19 @@ import ( "github.com/milvus-io/milvus/internal/querynodev2/pkoracle" "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/querynodev2/tsafe" + "github.com/milvus-io/milvus/internal/util/streamrpc" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus/pkg/util/conc" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/lifetime" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/timerecord" "github.com/milvus-io/milvus/pkg/util/tsoutil" ) -type lifetime struct { - state atomic.Int32 - closeCh chan struct{} - closeOnce sync.Once -} - -func (lt *lifetime) SetState(state int32) { - lt.state.Store(state) -} - -func (lt *lifetime) GetState() int32 { - return lt.state.Load() -} - -func (lt *lifetime) Close() { - lt.closeOnce.Do(func() { - close(lt.closeCh) - }) -} - -func newLifetime() *lifetime { - return &lifetime{ - closeCh: make(chan struct{}), - } -} - // ShardDelegator is the interface definition. 
type ShardDelegator interface { Collection() int64 @@ -81,9 +58,10 @@ type ShardDelegator interface { SyncDistribution(ctx context.Context, entries ...SegmentEntry) Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error) Query(ctx context.Context, req *querypb.QueryRequest) ([]*internalpb.RetrieveResults, error) + QueryStream(ctx context.Context, req *querypb.QueryRequest, srv streamrpc.QueryStreamServer) error GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) ([]*internalpb.GetStatisticsResponse, error) - //data + // data ProcessInsert(insertRecords map[int64]*InsertData) ProcessDelete(deleteData []*DeleteData, ts uint64) LoadGrowing(ctx context.Context, infos []*querypb.SegmentLoadInfo, version int64) error @@ -100,12 +78,6 @@ type ShardDelegator interface { var _ ShardDelegator = (*shardDelegator)(nil) -const ( - initializing int32 = iota - working - stopped -) - // shardDelegator maintains the shard distribution and streaming part of the data. type shardDelegator struct { // shard information attributes @@ -118,7 +90,7 @@ type shardDelegator struct { workerManager cluster.Manager - lifetime *lifetime + lifetime lifetime.Lifetime[lifetime.State] distribution *distribution segmentManager segments.SegmentManager @@ -127,11 +99,11 @@ type shardDelegator struct { // L0 delete buffer deleteMut sync.Mutex deleteBuffer deletebuffer.DeleteBuffer[*deletebuffer.Item] - //dispatcherClient msgdispatcher.Client + // dispatcherClient msgdispatcher.Client factory msgstream.Factory + sf conc.Singleflight[struct{}] loader segments.Loader - wg sync.WaitGroup tsCond *sync.Cond latestTsafe *atomic.Uint64 } @@ -147,16 +119,16 @@ func (sd *shardDelegator) getLogger(ctx context.Context) *log.MLogger { // Serviceable returns whether delegator is serviceable now. 
func (sd *shardDelegator) Serviceable() bool { - return sd.lifetime.GetState() == working + return lifetime.IsWorking(sd.lifetime.GetState()) == nil } func (sd *shardDelegator) Stopped() bool { - return sd.lifetime.GetState() == stopped + return lifetime.NotStopped(sd.lifetime.GetState()) != nil } // Start sets delegator to working state. func (sd *shardDelegator) Start() { - sd.lifetime.SetState(working) + sd.lifetime.SetState(lifetime.Working) } // Collection returns delegator collection id. @@ -206,9 +178,10 @@ func (sd *shardDelegator) modifyQueryRequest(req *querypb.QueryRequest, scope qu // Search preforms search operation on shard. func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error) { log := sd.getLogger(ctx) - if !sd.Serviceable() { - return nil, errors.New("delegator is not serviceable") + if err := sd.lifetime.Add(lifetime.IsWorking); err != nil { + return nil, err } + defer sd.lifetime.Done() if !funcutil.SliceContain(req.GetDmlChannels(), sd.vchannelName) { log.Warn("deletgator received search request not belongs to it", @@ -268,12 +241,75 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest return results, nil } +func (sd *shardDelegator) QueryStream(ctx context.Context, req *querypb.QueryRequest, srv streamrpc.QueryStreamServer) error { + log := sd.getLogger(ctx) + if !sd.Serviceable() { + return errors.New("delegator is not serviceable") + } + + if !funcutil.SliceContain(req.GetDmlChannels(), sd.vchannelName) { + log.Warn("deletgator received query request not belongs to it", + zap.Strings("reqChannels", req.GetDmlChannels()), + ) + return fmt.Errorf("dml channel not match, delegator channel %s, search channels %v", sd.vchannelName, req.GetDmlChannels()) + } + + partitions := req.GetReq().GetPartitionIDs() + if !sd.collection.ExistPartition(partitions...) 
{ + return merr.WrapErrPartitionNotLoaded(partitions) + } + + // wait tsafe + waitTr := timerecord.NewTimeRecorder("wait tSafe") + err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp) + if err != nil { + log.Warn("delegator query failed to wait tsafe", zap.Error(err)) + return err + } + metrics.QueryNodeSQLatencyWaitTSafe.WithLabelValues( + fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel). + Observe(float64(waitTr.ElapseSpan().Milliseconds())) + + sealed, growing, version := sd.distribution.GetSegments(true, req.GetReq().GetPartitionIDs()...) + defer sd.distribution.FinishUsage(version) + existPartitions := sd.collection.GetPartitions() + growing = lo.Filter(growing, func(segment SegmentEntry, _ int) bool { + return funcutil.SliceContain(existPartitions, segment.PartitionID) + }) + if req.Req.IgnoreGrowing { + growing = []SegmentEntry{} + } + + log.Info("query segments...", + zap.Int("sealedNum", len(sealed)), + zap.Int("growingNum", len(growing)), + ) + tasks, err := organizeSubTask(ctx, req, sealed, growing, sd, sd.modifyQueryRequest) + if err != nil { + log.Warn("query organizeSubTask failed", zap.Error(err)) + return err + } + + _, err = executeSubTasks(ctx, tasks, func(ctx context.Context, req *querypb.QueryRequest, worker cluster.Worker) (*internalpb.RetrieveResults, error) { + return nil, worker.QueryStreamSegments(ctx, req, srv) + }, "Query", log) + if err != nil { + log.Warn("Delegator query failed", zap.Error(err)) + return err + } + + log.Info("Delegator Query done") + + return nil +} + // Query performs query operation on shard. 
func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest) ([]*internalpb.RetrieveResults, error) { log := sd.getLogger(ctx) - if !sd.Serviceable() { - return nil, errors.New("delegator is not serviceable") + if err := sd.lifetime.Add(lifetime.IsWorking); err != nil { + return nil, err } + defer sd.lifetime.Done() if !funcutil.SliceContain(req.GetDmlChannels(), sd.vchannelName) { log.Warn("delegator received query request not belongs to it", @@ -335,9 +371,10 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest) // GetStatistics returns statistics aggregated by delegator. func (sd *shardDelegator) GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) ([]*internalpb.GetStatisticsResponse, error) { log := sd.getLogger(ctx) - if !sd.Serviceable() { - return nil, errors.New("delegator is not serviceable") + if err := sd.lifetime.Add(lifetime.IsWorking); err != nil { + return nil, err } + defer sd.lifetime.Done() if !funcutil.SliceContain(req.GetDmlChannels(), sd.vchannelName) { log.Warn("deletgator received query request not belongs to it", @@ -510,7 +547,9 @@ func (sd *shardDelegator) waitTSafe(ctx context.Context, ts uint64) error { sd.tsCond.L.Lock() defer sd.tsCond.L.Unlock() - for sd.latestTsafe.Load() < ts && ctx.Err() == nil { + for sd.latestTsafe.Load() < ts && + ctx.Err() == nil && + sd.Serviceable() { sd.tsCond.Wait() } close(ch) @@ -524,6 +563,9 @@ func (sd *shardDelegator) waitTSafe(ctx context.Context, ts uint64) error { sd.tsCond.Broadcast() return ctx.Err() case <-ch: + if !sd.Serviceable() { + return merr.WrapErrChannelNotAvailable(sd.vchannelName, "delegator closed during wait tsafe") + } return nil } } @@ -531,7 +573,7 @@ func (sd *shardDelegator) waitTSafe(ctx context.Context, ts uint64) error { // watchTSafe is the worker function to update serviceable timestamp. 
func (sd *shardDelegator) watchTSafe() { - defer sd.wg.Done() + defer sd.lifetime.Done() listener := sd.tsafeManager.WatchChannel(sd.vchannelName) sd.updateTSafe() log := sd.getLogger(context.Background()) @@ -544,7 +586,7 @@ func (sd *shardDelegator) watchTSafe() { return } sd.updateTSafe() - case <-sd.lifetime.closeCh: + case <-sd.lifetime.CloseCh(): log.Info("updateTSafe quit") // shard delegator closed return @@ -568,15 +610,24 @@ func (sd *shardDelegator) updateTSafe() { // Close closes the delegator. func (sd *shardDelegator) Close() { - sd.lifetime.SetState(stopped) + sd.lifetime.SetState(lifetime.Stopped) sd.lifetime.Close() - sd.wg.Wait() + // broadcast to all waitTsafe goroutine to quit + sd.tsCond.Broadcast() + sd.lifetime.Wait() } // NewShardDelegator creates a new ShardDelegator instance with all fields initialized. func NewShardDelegator(collectionID UniqueID, replicaID UniqueID, channel string, version int64, workerManager cluster.Manager, manager *segments.Manager, tsafeManager tsafe.Manager, loader segments.Loader, - factory msgstream.Factory, startTs uint64) (ShardDelegator, error) { + factory msgstream.Factory, startTs uint64, +) (ShardDelegator, error) { + log := log.With(zap.Int64("collectionID", collectionID), + zap.Int64("replicaID", replicaID), + zap.String("channel", channel), + zap.Int64("version", version), + zap.Uint64("startTs", startTs), + ) collection := manager.Collection.Get(collectionID) if collection == nil { @@ -584,7 +635,7 @@ func NewShardDelegator(collectionID UniqueID, replicaID UniqueID, channel string } maxSegmentDeleteBuffer := paramtable.Get().QueryNodeCfg.MaxSegmentDeleteBuffer.GetAsInt64() - log.Info("Init delte cache", zap.Int64("maxSegmentCacheBuffer", maxSegmentDeleteBuffer), zap.Time("startTime", tsoutil.PhysicalTime(startTs))) + log.Info("Init delta cache", zap.Int64("maxSegmentCacheBuffer", maxSegmentDeleteBuffer), zap.Time("startTime", tsoutil.PhysicalTime(startTs))) sd := &shardDelegator{ collectionID: 
collectionID, @@ -594,7 +645,7 @@ func NewShardDelegator(collectionID UniqueID, replicaID UniqueID, channel string collection: collection, segmentManager: manager.Segment, workerManager: workerManager, - lifetime: newLifetime(), + lifetime: lifetime.NewLifetime(lifetime.Initializing), distribution: NewDistribution(), deleteBuffer: deletebuffer.NewDoubleCacheDeleteBuffer[*deletebuffer.Item](startTs, maxSegmentDeleteBuffer), pkOracle: pkoracle.NewPkOracle(), @@ -605,7 +656,9 @@ func NewShardDelegator(collectionID UniqueID, replicaID UniqueID, channel string } m := sync.Mutex{} sd.tsCond = sync.NewCond(&m) - sd.wg.Add(1) - go sd.watchTSafe() + if sd.lifetime.Add(lifetime.NotStopped) == nil { + go sd.watchTSafe() + } + log.Info("finish build new shardDelegator") return sd, nil } diff --git a/internal/querynodev2/delegator/delegator_data.go b/internal/querynodev2/delegator/delegator_data.go index c7fdb7fe5ac58..b538f61a9e8fb 100644 --- a/internal/querynodev2/delegator/delegator_data.go +++ b/internal/querynodev2/delegator/delegator_data.go @@ -355,13 +355,44 @@ func (sd *shardDelegator) LoadSegments(ctx context.Context, req *querypb.LoadSeg } req.Base.TargetID = req.GetDstNodeID() - log.Info("worker loads segments...") - err = worker.LoadSegments(ctx, req) + log.Debug("worker loads segments...") + + sLoad := func(ctx context.Context, req *querypb.LoadSegmentsRequest) error { + segmentID := req.GetInfos()[0].GetSegmentID() + nodeID := req.GetDstNodeID() + _, err, _ := sd.sf.Do(fmt.Sprintf("%d-%d", nodeID, segmentID), func() (struct{}, error) { + err := worker.LoadSegments(ctx, req) + return struct{}{}, err + }) + return err + } + + // separate infos into different load task + if len(req.GetInfos()) > 1 { + var reqs []*querypb.LoadSegmentsRequest + for _, info := range req.GetInfos() { + newReq := typeutil.Clone(req) + newReq.Infos = []*querypb.SegmentLoadInfo{info} + reqs = append(reqs, newReq) + } + + group, ctx := errgroup.WithContext(ctx) + for _, req := range reqs { 
+ req := req + group.Go(func() error { + return sLoad(ctx, req) + }) + } + err = group.Wait() + } else { + err = sLoad(ctx, req) + } + if err != nil { log.Warn("worker failed to load segments", zap.Error(err)) return err } - log.Info("work loads segments done") + log.Debug("work loads segments done") // load index need no stream delete and distribution change if req.GetLoadScope() == querypb.LoadScope_Index { @@ -376,7 +407,7 @@ func (sd *shardDelegator) LoadSegments(ctx context.Context, req *querypb.LoadSeg Version: req.GetVersion(), } }) - log.Info("load delete...") + log.Debug("load delete...") err = sd.loadStreamDelete(ctx, candidates, infos, req.GetDeltaPositions(), targetNodeID, worker, entries) if err != nil { log.Warn("load stream delete failed", zap.Error(err)) @@ -404,11 +435,16 @@ func (sd *shardDelegator) loadStreamDelete(ctx context.Context, defer sd.deleteMut.Unlock() // apply buffered delete for new segments // no goroutines here since qnv2 has no load merging logic - for i, info := range infos { + for _, info := range infos { candidate := idCandidates[info.GetSegmentID()] position := info.GetDeltaPosition() if position == nil { // for compatibility of rolling upgrade from 2.2.x to 2.3 - position = deltaPositions[i] + // During rolling upgrade, Querynode(2.3) may receive merged LoadSegmentRequest + // from QueryCoord(2.2); In version 2.2.x, only segments with the same dmlChannel + // can be merged, and deltaPositions will be merged into a single deltaPosition, + // so we should use `deltaPositions[0]` as the seek position for all the segments + // within the same LoadSegmentRequest. 
+ position = deltaPositions[0] } deleteData := &storage.DeleteData{} @@ -474,7 +510,6 @@ func (sd *shardDelegator) loadStreamDelete(ctx context.Context, } func (sd *shardDelegator) readDeleteFromMsgstream(ctx context.Context, position *msgpb.MsgPosition, safeTs uint64, candidate *pkoracle.BloomFilterSet) (*storage.DeleteData, error) { - log := sd.getLogger(ctx).With( zap.String("channel", position.ChannelName), zap.Int64("segmentID", candidate.ID()), @@ -493,9 +528,12 @@ func (sd *shardDelegator) readDeleteFromMsgstream(ctx context.Context, position // Random the subname in case we trying to load same delta at the same time subName := fmt.Sprintf("querynode-delta-loader-%d-%d-%d", paramtable.GetNodeID(), sd.collectionID, rand.Int()) log.Info("from dml check point load delete", zap.Any("position", position), zap.String("vChannel", vchannelName), zap.String("subName", subName), zap.Time("positionTs", ts)) - stream.AsConsumer([]string{pChannelName}, subName, mqwrapper.SubscriptionPositionUnknown) + err = stream.AsConsumer(context.TODO(), []string{pChannelName}, subName, mqwrapper.SubscriptionPositionUnknown) + if err != nil { + return nil, err + } - err = stream.Seek([]*msgpb.MsgPosition{position}) + err = stream.Seek(context.TODO(), []*msgpb.MsgPosition{position}) if err != nil { return nil, err } @@ -541,7 +579,6 @@ func (sd *shardDelegator) readDeleteFromMsgstream(ctx context.Context, position // reach safe ts if safeTs <= msgPack.EndPositions[0].GetTimestamp() { hasMore = false - break } } } @@ -624,7 +661,8 @@ func (sd *shardDelegator) ReleaseSegments(ctx context.Context, req *querypb.Rele } func (sd *shardDelegator) SyncTargetVersion(newVersion int64, growingInTarget []int64, - sealedInTarget []int64, droppedInTarget []int64) { + sealedInTarget []int64, droppedInTarget []int64, +) { growings := sd.segmentManager.GetBy( segments.WithType(segments.SegmentTypeGrowing), segments.WithChannel(sd.vchannelName), diff --git 
a/internal/querynodev2/delegator/delegator_data_test.go b/internal/querynodev2/delegator/delegator_data_test.go index 4cb92f2e4cfb8..fb6ce0ab6ef07 100644 --- a/internal/querynodev2/delegator/delegator_data_test.go +++ b/internal/querynodev2/delegator/delegator_data_test.go @@ -468,6 +468,79 @@ func (s *DelegatorDataSuite) TestLoadSegments() { }, sealed[0].Segments) }) + s.Run("load_segments_with_streaming_delete_failed", func() { + defer func() { + s.workerManager.ExpectedCalls = nil + s.loader.ExpectedCalls = nil + }() + + s.loader.EXPECT().LoadBloomFilterSet(mock.Anything, s.collectionID, mock.AnythingOfType("int64"), mock.Anything). + Call.Return(func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) []*pkoracle.BloomFilterSet { + return lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *pkoracle.BloomFilterSet { + bfs := pkoracle.NewBloomFilterSet(info.GetSegmentID(), info.GetPartitionID(), commonpb.SegmentState_Sealed) + bf := bloom.NewWithEstimates(storage.BloomFilterSize, storage.MaxBloomFalsePositive) + pks := &storage.PkStatistics{ + PkFilter: bf, + } + pks.UpdatePKRange(&storage.Int64FieldData{ + Data: []int64{10, 20, 30}, + }) + bfs.AddHistoricalStats(pks) + return bfs + }) + }, func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) error { + return nil + }) + + workers := make(map[int64]*cluster.MockWorker) + worker1 := &cluster.MockWorker{} + workers[1] = worker1 + + worker1.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")). 
+ Return(nil) + worker1.EXPECT().Delete(mock.Anything, mock.AnythingOfType("*querypb.DeleteRequest")).Return(nil) + s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker { + return workers[nodeID] + }, nil) + + s.delegator.ProcessDelete([]*DeleteData{ + { + PartitionID: 500, + PrimaryKeys: []storage.PrimaryKey{ + storage.NewInt64PrimaryKey(1), + storage.NewInt64PrimaryKey(10), + }, + Timestamps: []uint64{10, 10}, + RowCount: 2, + }, + }, 10) + + s.mq.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) + s.mq.EXPECT().Seek(mock.Anything, mock.Anything).Return(nil) + s.mq.EXPECT().Close() + ch := make(chan *msgstream.MsgPack, 10) + close(ch) + + s.mq.EXPECT().Chan().Return(ch) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + err := s.delegator.LoadSegments(ctx, &querypb.LoadSegmentsRequest{ + Base: commonpbutil.NewMsgBase(), + DstNodeID: 1, + CollectionID: s.collectionID, + Infos: []*querypb.SegmentLoadInfo{ + { + SegmentID: 300, + PartitionID: 500, + StartPosition: &msgpb.MsgPosition{Timestamp: 2}, + DeltaPosition: &msgpb.MsgPosition{Timestamp: 2}, + }, + }, + }) + s.Error(err) + }) + s.Run("get_worker_fail", func() { defer func() { s.workerManager.ExpectedCalls = nil diff --git a/internal/querynodev2/delegator/delegator_test.go b/internal/querynodev2/delegator/delegator_test.go index d3de8c54680d4..5aecf40e60033 100644 --- a/internal/querynodev2/delegator/delegator_test.go +++ b/internal/querynodev2/delegator/delegator_test.go @@ -18,6 +18,7 @@ package delegator import ( "context" + "io" "sync" "testing" "time" @@ -38,9 +39,11 @@ import ( "github.com/milvus-io/milvus/internal/querynodev2/cluster" "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/querynodev2/tsafe" + "github.com/milvus-io/milvus/internal/util/streamrpc" "github.com/milvus-io/milvus/pkg/common" 
"github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/commonpbutil" + "github.com/milvus-io/milvus/pkg/util/lifetime" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metric" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -68,7 +71,6 @@ func (s *DelegatorSuite) SetupSuite() { } func (s *DelegatorSuite) TearDownSuite() { - } func (s *DelegatorSuite) SetupTest() { @@ -204,12 +206,10 @@ func (s *DelegatorSuite) TestGetSegmentInfo() { s.Equal(0, len(growing)) } -func (s *DelegatorSuite) TestSearch() { - s.delegator.Start() - // 1 => sealed segment 1000, 1001 - // 1 => growing segment 1004 - // 2 => sealed segment 1002, 1003 - paramtable.SetNodeID(1) +// nodeID 1 => sealed segment 1000, 1001 +// nodeID 1 => growing segment 1004 +// nodeID 2 => sealed segment 1002, 1003 +func (s *DelegatorSuite) initSegments() { s.delegator.LoadGrowing(context.Background(), []*querypb.SegmentLoadInfo{ { SegmentID: 1004, @@ -244,6 +244,12 @@ func (s *DelegatorSuite) TestSearch() { }, ) s.delegator.SyncTargetVersion(2001, []int64{1004}, []int64{1000, 1001, 1002, 1003}, []int64{}) +} + +func (s *DelegatorSuite) TestSearch() { + s.delegator.Start() + paramtable.SetNodeID(1) + s.initSegments() s.Run("normal", func() { defer func() { s.workerManager.ExpectedCalls = nil @@ -296,33 +302,19 @@ func (s *DelegatorSuite) TestSearch() { defer func() { s.workerManager.ExpectedCalls = nil }() - workers := make(map[int64]*cluster.MockWorker) - worker1 := &cluster.MockWorker{} - worker2 := &cluster.MockWorker{} - - workers[1] = worker1 - workers[2] = worker2 - - worker1.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")). - Return(&internalpb.SearchResults{}, nil) - worker2.EXPECT().SearchSegments(mock.Anything, mock.AnythingOfType("*querypb.SearchRequest")). 
- Return(&internalpb.SearchResults{}, nil) - - s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker { - return workers[nodeID] - }, nil) ctx, cancel := context.WithCancel(context.Background()) defer cancel() _, err := s.delegator.Search(ctx, &querypb.SearchRequest{ Req: &internalpb.SearchRequest{ - Base: commonpbutil.NewMsgBase(), - PartitionIDs: []int64{500}, + Base: commonpbutil.NewMsgBase(), + // not load partation -1,will return error + PartitionIDs: []int64{-1}, }, DmlChannels: []string{s.vchannelName}, }) - errors.Is(err, merr.ErrPartitionNotLoaded) + s.True(errors.Is(err, merr.ErrPartitionNotLoaded)) }) s.Run("worker_return_error", func() { @@ -459,44 +451,8 @@ func (s *DelegatorSuite) TestSearch() { func (s *DelegatorSuite) TestQuery() { s.delegator.Start() - // 1 => sealed segment 1000, 1001 - // 1 => growing segment 1004 - // 2 => sealed segment 1002, 1003 paramtable.SetNodeID(1) - s.delegator.LoadGrowing(context.Background(), []*querypb.SegmentLoadInfo{ - { - SegmentID: 1004, - CollectionID: s.collectionID, - PartitionID: 500, - }, - }, 0) - s.delegator.SyncDistribution(context.Background(), - SegmentEntry{ - NodeID: 1, - SegmentID: 1000, - PartitionID: 500, - Version: 2001, - }, - SegmentEntry{ - NodeID: 1, - SegmentID: 1001, - PartitionID: 501, - Version: 2001, - }, - SegmentEntry{ - NodeID: 2, - SegmentID: 1002, - PartitionID: 500, - Version: 2001, - }, - SegmentEntry{ - NodeID: 2, - SegmentID: 1003, - PartitionID: 501, - Version: 2001, - }, - ) - s.delegator.SyncTargetVersion(2001, []int64{1004}, []int64{1000, 1001, 1002, 1003}, []int64{}) + s.initSegments() s.Run("normal", func() { defer func() { s.workerManager.ExpectedCalls = nil @@ -549,33 +505,18 @@ func (s *DelegatorSuite) TestQuery() { defer func() { s.workerManager.ExpectedCalls = nil }() - workers := make(map[int64]*cluster.MockWorker) - worker1 := &cluster.MockWorker{} - worker2 := 
&cluster.MockWorker{} - - workers[1] = worker1 - workers[2] = worker2 - - worker1.EXPECT().QuerySegments(mock.Anything, mock.AnythingOfType("*querypb.QueryRequest")). - Return(&internalpb.RetrieveResults{}, nil) - worker2.EXPECT().QuerySegments(mock.Anything, mock.AnythingOfType("*querypb.QueryRequest")). - Return(&internalpb.RetrieveResults{}, nil) - - s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker { - return workers[nodeID] - }, nil) ctx, cancel := context.WithCancel(context.Background()) defer cancel() _, err := s.delegator.Query(ctx, &querypb.QueryRequest{ Req: &internalpb.RetrieveRequest{ - Base: commonpbutil.NewMsgBase(), - PartitionIDs: []int64{500}, + Base: commonpbutil.NewMsgBase(), + // not load partation -1,will return error + PartitionIDs: []int64{-1}, }, DmlChannels: []string{s.vchannelName}, }) - - errors.Is(err, merr.ErrPartitionNotLoaded) + s.True(errors.Is(err, merr.ErrPartitionNotLoaded)) }) s.Run("worker_return_error", func() { @@ -677,47 +618,279 @@ func (s *DelegatorSuite) TestQuery() { }) } +func (s *DelegatorSuite) TestQueryStream() { + s.delegator.Start() + paramtable.SetNodeID(1) + s.initSegments() + + s.Run("normal", func() { + defer func() { + s.workerManager.AssertExpectations(s.T()) + s.workerManager.ExpectedCalls = nil + }() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + client := streamrpc.NewLocalQueryClient(ctx) + server := client.CreateServer() + + workers := make(map[int64]*cluster.MockWorker) + worker1 := &cluster.MockWorker{} + worker2 := &cluster.MockWorker{} + + workers[1] = worker1 + workers[2] = worker2 + s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker { + return workers[nodeID] + }, nil) + + worker1.EXPECT().QueryStreamSegments(mock.Anything, mock.AnythingOfType("*querypb.QueryRequest"), 
mock.Anything). + Run(func(ctx context.Context, req *querypb.QueryRequest, srv streamrpc.QueryStreamServer) { + s.EqualValues(1, req.Req.GetBase().GetTargetID()) + s.True(req.GetFromShardLeader()) + if req.GetScope() == querypb.DataScope_Streaming { + s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels()) + s.ElementsMatch([]int64{1004}, req.GetSegmentIDs()) + } + if req.GetScope() == querypb.DataScope_Historical { + s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels()) + s.ElementsMatch([]int64{1000, 1001}, req.GetSegmentIDs()) + } + + srv.Send(&internalpb.RetrieveResults{ + Status: merr.Success(), + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{Data: req.GetSegmentIDs()}, + }, + }, + }) + }).Return(nil) + + worker2.EXPECT().QueryStreamSegments(mock.Anything, mock.AnythingOfType("*querypb.QueryRequest"), mock.Anything). + Run(func(ctx context.Context, req *querypb.QueryRequest, srv streamrpc.QueryStreamServer) { + s.EqualValues(2, req.Req.GetBase().GetTargetID()) + s.True(req.GetFromShardLeader()) + s.Equal(querypb.DataScope_Historical, req.GetScope()) + s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels()) + s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs()) + srv.Send(&internalpb.RetrieveResults{ + Status: merr.Success(), + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{Data: req.GetSegmentIDs()}, + }, + }, + }) + }).Return(nil) + + // run stream function + go func() { + err := s.delegator.QueryStream(ctx, &querypb.QueryRequest{ + Req: &internalpb.RetrieveRequest{Base: commonpbutil.NewMsgBase()}, + DmlChannels: []string{s.vchannelName}, + }, server) + s.NoError(err) + server.FinishSend(err) + }() + + resultIDs := []int64{1000, 1001, 1002, 1003, 1004} + recNum := 0 + for { + result, err := client.Recv() + if err == io.EOF { + s.Equal(recNum, len(resultIDs)) + break + } + s.NoError(err) + + err = merr.Error(result.GetStatus()) + s.NoError(err) + + for _, segmentID := 
range result.Ids.GetIntId().Data { + s.Less(recNum, len(resultIDs)) + lo.Contains[int64](resultIDs, segmentID) + recNum++ + } + } + }) + + s.Run("partition_not_loaded", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + client := streamrpc.NewLocalQueryClient(ctx) + server := client.CreateServer() + + err := s.delegator.QueryStream(ctx, &querypb.QueryRequest{ + Req: &internalpb.RetrieveRequest{ + Base: commonpbutil.NewMsgBase(), + // not load partation -1,will return error + PartitionIDs: []int64{-1}, + }, + DmlChannels: []string{s.vchannelName}, + }, server) + s.True(errors.Is(err, merr.ErrPartitionNotLoaded)) + }) + + s.Run("tsafe_behind_max_lag", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + client := streamrpc.NewLocalQueryClient(ctx) + server := client.CreateServer() + + err := s.delegator.QueryStream(ctx, &querypb.QueryRequest{ + Req: &internalpb.RetrieveRequest{ + Base: commonpbutil.NewMsgBase(), + GuaranteeTimestamp: uint64(paramtable.Get().QueryNodeCfg.MaxTimestampLag.GetAsDuration(time.Second)), + }, + DmlChannels: []string{s.vchannelName}, + }, server) + s.Error(err) + }) + + s.Run("get_worker_failed", func() { + defer func() { + s.workerManager.AssertExpectations(s.T()) + s.workerManager.ExpectedCalls = nil + }() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + client := streamrpc.NewLocalQueryClient(ctx) + server := client.CreateServer() + + mockErr := errors.New("mock error") + + s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(nil, mockErr) + + // run stream function + err := s.delegator.QueryStream(ctx, &querypb.QueryRequest{ + Req: &internalpb.RetrieveRequest{Base: commonpbutil.NewMsgBase()}, + DmlChannels: []string{s.vchannelName}, + }, server) + s.True(errors.Is(err, mockErr)) + }) + + s.Run("worker_return_error", func() { + defer func() { + s.workerManager.AssertExpectations(s.T()) + 
s.workerManager.ExpectedCalls = nil + }() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + client := streamrpc.NewLocalQueryClient(ctx) + server := client.CreateServer() + + workers := make(map[int64]*cluster.MockWorker) + worker1 := &cluster.MockWorker{} + worker2 := &cluster.MockWorker{} + mockErr := errors.New("mock error") + + workers[1] = worker1 + workers[2] = worker2 + s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).Call.Return(func(_ context.Context, nodeID int64) cluster.Worker { + return workers[nodeID] + }, nil) + + worker1.EXPECT().QueryStreamSegments(mock.Anything, mock.AnythingOfType("*querypb.QueryRequest"), mock.Anything). + Return(mockErr) + + worker2.EXPECT().QueryStreamSegments(mock.Anything, mock.AnythingOfType("*querypb.QueryRequest"), mock.Anything). + Run(func(ctx context.Context, req *querypb.QueryRequest, srv streamrpc.QueryStreamServer) { + s.EqualValues(2, req.Req.GetBase().GetTargetID()) + s.True(req.GetFromShardLeader()) + s.Equal(querypb.DataScope_Historical, req.GetScope()) + s.EqualValues([]string{s.vchannelName}, req.GetDmlChannels()) + s.ElementsMatch([]int64{1002, 1003}, req.GetSegmentIDs()) + srv.Send(&internalpb.RetrieveResults{ + Status: merr.Success(), + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{Data: req.GetSegmentIDs()}, + }, + }, + }) + }).Return(nil) + + // run stream function + go func() { + err := s.delegator.QueryStream(ctx, &querypb.QueryRequest{ + Req: &internalpb.RetrieveRequest{Base: commonpbutil.NewMsgBase()}, + DmlChannels: []string{s.vchannelName}, + }, server) + server.Send(&internalpb.RetrieveResults{ + Status: merr.Status(err), + }) + }() + + resultIDs := []int64{1002, 1003} + recNum := 0 + for { + result, err := client.Recv() + s.NoError(err) + + err = merr.Error(result.GetStatus()) + if err != nil { + s.Equal(recNum, len(resultIDs)) + s.Equal(err.Error(), mockErr.Error()) + break + } + + for _, segmentID := 
range result.Ids.GetIntId().GetData() { + s.Less(recNum, len(resultIDs)) + lo.Contains[int64](resultIDs, segmentID) + recNum++ + } + } + }) + + s.Run("wrong_channel", func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + client := streamrpc.NewLocalQueryClient(ctx) + server := client.CreateServer() + + err := s.delegator.QueryStream(ctx, &querypb.QueryRequest{ + Req: &internalpb.RetrieveRequest{Base: commonpbutil.NewMsgBase()}, + DmlChannels: []string{"non_exist_channel"}, + }, server) + + s.Error(err) + }) + + s.Run("cluster_not_serviceable", func() { + s.delegator.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + client := streamrpc.NewLocalQueryClient(ctx) + server := client.CreateServer() + + err := s.delegator.QueryStream(ctx, &querypb.QueryRequest{ + Req: &internalpb.RetrieveRequest{Base: commonpbutil.NewMsgBase()}, + DmlChannels: []string{s.vchannelName}, + }, server) + + s.Error(err) + }) +} + func (s *DelegatorSuite) TestGetStats() { s.delegator.Start() // 1 => sealed segment 1000, 1001 // 1 => growing segment 1004 // 2 => sealed segment 1002, 1003 paramtable.SetNodeID(1) - s.delegator.LoadGrowing(context.Background(), []*querypb.SegmentLoadInfo{ - { - SegmentID: 1004, - CollectionID: s.collectionID, - PartitionID: 500, - }, - }, 0) - s.delegator.SyncDistribution(context.Background(), - SegmentEntry{ - NodeID: 1, - SegmentID: 1000, - PartitionID: 500, - Version: 2001, - }, - SegmentEntry{ - NodeID: 1, - SegmentID: 1001, - PartitionID: 501, - Version: 2001, - }, - SegmentEntry{ - NodeID: 2, - SegmentID: 1002, - PartitionID: 500, - Version: 2001, - }, - SegmentEntry{ - NodeID: 2, - SegmentID: 1003, - PartitionID: 501, - Version: 2001, - }, - ) + s.initSegments() - s.delegator.SyncTargetVersion(2001, []int64{1004}, []int64{1000, 1001, 1002, 1003}, []int64{}) s.Run("normal", func() { defer func() { s.workerManager.ExpectedCalls = nil @@ -877,15 +1050,16 @@ func TestDelegatorWatchTsafe(t 
*testing.T) { sd := &shardDelegator{ tsafeManager: tsafeManager, vchannelName: channelName, - lifetime: newLifetime(), + lifetime: lifetime.NewLifetime(lifetime.Initializing), latestTsafe: atomic.NewUint64(0), } defer sd.Close() m := sync.Mutex{} sd.tsCond = sync.NewCond(&m) - sd.wg.Add(1) - go sd.watchTSafe() + if sd.lifetime.Add(lifetime.NotStopped) == nil { + go sd.watchTSafe() + } err := tsafeManager.Set(channelName, 200) require.NoError(t, err) @@ -903,19 +1077,20 @@ func TestDelegatorTSafeListenerClosed(t *testing.T) { sd := &shardDelegator{ tsafeManager: tsafeManager, vchannelName: channelName, - lifetime: newLifetime(), + lifetime: lifetime.NewLifetime(lifetime.Initializing), latestTsafe: atomic.NewUint64(0), } defer sd.Close() m := sync.Mutex{} sd.tsCond = sync.NewCond(&m) - sd.wg.Add(1) signal := make(chan struct{}) - go func() { - sd.watchTSafe() - close(signal) - }() + if sd.lifetime.Add(lifetime.NotStopped) == nil { + go func() { + sd.watchTSafe() + close(signal) + }() + } select { case <-signal: diff --git a/internal/querynodev2/delegator/deletebuffer/delete_buffer.go b/internal/querynodev2/delegator/deletebuffer/delete_buffer.go index 970c9b2355a2d..c652ae1f27f57 100644 --- a/internal/querynodev2/delegator/deletebuffer/delete_buffer.go +++ b/internal/querynodev2/delegator/deletebuffer/delete_buffer.go @@ -23,9 +23,7 @@ import ( "github.com/cockroachdb/errors" ) -var ( - errBufferFull = errors.New("buffer full") -) +var errBufferFull = errors.New("buffer full") type timed interface { Timestamp() uint64 diff --git a/internal/querynodev2/delegator/deletebuffer/delete_buffer_test.go b/internal/querynodev2/delegator/deletebuffer/delete_buffer_test.go index 86b9217fd17b1..e580916ac7a9b 100644 --- a/internal/querynodev2/delegator/deletebuffer/delete_buffer_test.go +++ b/internal/querynodev2/delegator/deletebuffer/delete_buffer_test.go @@ -19,9 +19,10 @@ package deletebuffer import ( "testing" - "github.com/milvus-io/milvus/internal/storage" 
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus/internal/storage" ) func TestSkipListDeleteBuffer(t *testing.T) { diff --git a/internal/querynodev2/delegator/deletebuffer/delete_item.go b/internal/querynodev2/delegator/deletebuffer/delete_item.go index df85f7b37c58a..abc89baa0c0f9 100644 --- a/internal/querynodev2/delegator/deletebuffer/delete_item.go +++ b/internal/querynodev2/delegator/deletebuffer/delete_item.go @@ -1,8 +1,9 @@ package deletebuffer import ( - "github.com/milvus-io/milvus/internal/storage" "github.com/samber/lo" + + "github.com/milvus-io/milvus/internal/storage" ) // Item wraps cache item as `timed`. diff --git a/internal/querynodev2/delegator/deletebuffer/delete_item_test.go b/internal/querynodev2/delegator/deletebuffer/delete_item_test.go index a35cab26885ad..59bf9d979337f 100644 --- a/internal/querynodev2/delegator/deletebuffer/delete_item_test.go +++ b/internal/querynodev2/delegator/deletebuffer/delete_item_test.go @@ -3,8 +3,9 @@ package deletebuffer import ( "testing" - "github.com/milvus-io/milvus/internal/storage" "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/internal/storage" ) func TestDeleteBufferItem(t *testing.T) { diff --git a/internal/querynodev2/delegator/distribution.go b/internal/querynodev2/delegator/distribution.go index fc949b6f22e66..795ee8b39bdf0 100644 --- a/internal/querynodev2/delegator/distribution.go +++ b/internal/querynodev2/delegator/distribution.go @@ -19,12 +19,12 @@ package delegator import ( "sync" + "github.com/samber/lo" "go.uber.org/atomic" "go.uber.org/zap" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/typeutil" - "github.com/samber/lo" ) const ( @@ -277,8 +277,8 @@ func (d *distribution) SyncTargetVersion(newVersion int64, growingInTarget []int log.Info("Update readable segment version", zap.Int64("oldVersion", oldValue), zap.Int64("newVersion", newVersion), - zap.Int64s("growing", 
growingInTarget), - zap.Int64s("sealed", sealedInTarget), + zap.Int("growingSegmentNum", len(growingInTarget)), + zap.Int("sealedSegmentNum", len(sealedInTarget)), ) } diff --git a/internal/querynodev2/delegator/mock_delegator.go b/internal/querynodev2/delegator/mock_delegator.go index 3f39ba85f1238..c1f5e95e0cb84 100644 --- a/internal/querynodev2/delegator/mock_delegator.go +++ b/internal/querynodev2/delegator/mock_delegator.go @@ -9,6 +9,8 @@ import ( mock "github.com/stretchr/testify/mock" querypb "github.com/milvus-io/milvus/internal/proto/querypb" + + streamrpc "github.com/milvus-io/milvus/internal/util/streamrpc" ) // MockShardDelegator is an autogenerated mock type for the ShardDelegator type @@ -458,6 +460,50 @@ func (_c *MockShardDelegator_Query_Call) RunAndReturn(run func(context.Context, return _c } +// QueryStream provides a mock function with given fields: ctx, req, srv +func (_m *MockShardDelegator) QueryStream(ctx context.Context, req *querypb.QueryRequest, srv streamrpc.QueryStreamServer) error { + ret := _m.Called(ctx, req, srv) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.QueryRequest, streamrpc.QueryStreamServer) error); ok { + r0 = rf(ctx, req, srv) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockShardDelegator_QueryStream_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'QueryStream' +type MockShardDelegator_QueryStream_Call struct { + *mock.Call +} + +// QueryStream is a helper method to define mock.On call +// - ctx context.Context +// - req *querypb.QueryRequest +// - srv streamrpc.QueryStreamServer +func (_e *MockShardDelegator_Expecter) QueryStream(ctx interface{}, req interface{}, srv interface{}) *MockShardDelegator_QueryStream_Call { + return &MockShardDelegator_QueryStream_Call{Call: _e.mock.On("QueryStream", ctx, req, srv)} +} + +func (_c *MockShardDelegator_QueryStream_Call) Run(run func(ctx context.Context, req *querypb.QueryRequest, srv 
streamrpc.QueryStreamServer)) *MockShardDelegator_QueryStream_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*querypb.QueryRequest), args[2].(streamrpc.QueryStreamServer)) + }) + return _c +} + +func (_c *MockShardDelegator_QueryStream_Call) Return(_a0 error) *MockShardDelegator_QueryStream_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockShardDelegator_QueryStream_Call) RunAndReturn(run func(context.Context, *querypb.QueryRequest, streamrpc.QueryStreamServer) error) *MockShardDelegator_QueryStream_Call { + _c.Call.Return(run) + return _c +} + // ReleaseSegments provides a mock function with given fields: ctx, req, force func (_m *MockShardDelegator) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest, force bool) error { ret := _m.Called(ctx, req, force) diff --git a/internal/querynodev2/delegator/types.go b/internal/querynodev2/delegator/types.go index 597ee6af95ca9..b981cbbfbd63b 100644 --- a/internal/querynodev2/delegator/types.go +++ b/internal/querynodev2/delegator/types.go @@ -50,10 +50,8 @@ type TSafeUpdater interface { UnregisterChannel(string) error } -var ( - // ErrTsLagTooLarge serviceable and guarantee lag too large. - ErrTsLagTooLarge = errors.New("Timestamp lag too large") -) +// ErrTsLagTooLarge serviceable and guarantee lag too large. +var ErrTsLagTooLarge = errors.New("Timestamp lag too large") // WrapErrTsLagTooLarge wraps ErrTsLagTooLarge with lag and max value. 
func WrapErrTsLagTooLarge(duration time.Duration, maxLag time.Duration) error { diff --git a/internal/querynodev2/handlers.go b/internal/querynodev2/handlers.go index 0dab6bbcfacd4..02bb7e027d1c1 100644 --- a/internal/querynodev2/handlers.go +++ b/internal/querynodev2/handlers.go @@ -32,11 +32,11 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querynodev2/delegator" "github.com/milvus-io/milvus/internal/querynodev2/segments" - "github.com/milvus-io/milvus/internal/util" + "github.com/milvus-io/milvus/internal/querynodev2/tasks" + "github.com/milvus-io/milvus/internal/util/streamrpc" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" - "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -98,10 +98,10 @@ func (node *QueryNode) loadDeltaLogs(ctx context.Context, req *querypb.LoadSegme if finalErr != nil { log.Warn("failed to load delta logs", zap.Error(finalErr)) - return util.WrapStatus(commonpb.ErrorCode_UnexpectedError, "failed to load delta logs", finalErr) + return merr.Status(finalErr) } - return util.SuccessStatus() + return merr.Success() } func (node *QueryNode) loadIndex(ctx context.Context, req *querypb.LoadSegmentsRequest) *commonpb.Status { @@ -110,7 +110,7 @@ func (node *QueryNode) loadIndex(ctx context.Context, req *querypb.LoadSegmentsR zap.Int64s("segmentIDs", lo.Map(req.GetInfos(), func(info *querypb.SegmentLoadInfo, _ int) int64 { return info.GetSegmentID() })), ) - status := util.SuccessStatus() + status := merr.Success() log.Info("start to load index") for _, info := range req.GetInfos() { @@ -147,21 +147,14 @@ func (node *QueryNode) queryChannel(ctx context.Context, req *querypb.QueryReque zap.String("scope", req.GetScope().String()), ) - failRet := 
WrapRetrieveResult(commonpb.ErrorCode_UnexpectedError, "") + var err error metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.TotalLabel, metrics.Leader).Inc() defer func() { - if failRet.Status.ErrorCode != commonpb.ErrorCode_Success { + if err != nil { metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.FailLabel, metrics.Leader).Inc() } }() - if !node.lifetime.Add(commonpbutil.IsHealthy) { - err := merr.WrapErrServiceUnavailable(fmt.Sprintf("node id: %d is unhealthy", paramtable.GetNodeID())) - failRet.Status = merr.Status(err) - return failRet, nil - } - defer node.lifetime.Done() - log.Debug("start do query with channel", zap.Bool("fromShardLeader", req.GetFromShardLeader()), zap.Int64s("segmentIDs", req.GetSegmentIDs()), @@ -175,18 +168,16 @@ func (node *QueryNode) queryChannel(ctx context.Context, req *querypb.QueryReque // get delegator sd, ok := node.delegators.Get(channel) if !ok { - err := merr.WrapErrServiceUnavailable("failed to get shard delegator for query") + err := merr.WrapErrChannelNotFound(channel) log.Warn("Query failed, failed to get shard delegator for query", zap.Error(err)) - failRet.Status = merr.Status(err) - return failRet, nil + return nil, err } // do query results, err := sd.Query(queryCtx, req) if err != nil { log.Warn("failed to query on delegator", zap.Error(err)) - failRet.Status.Reason = err.Error() - return failRet, nil + return nil, err } // reduce result @@ -201,16 +192,14 @@ func (node *QueryNode) queryChannel(ctx context.Context, req *querypb.QueryReque if collection == nil { err := merr.WrapErrCollectionNotFound(req.Req.GetCollectionID()) log.Warn("Query failed, failed to get collection", zap.Error(err)) - failRet.Status = merr.Status(err) - return failRet, nil + return nil, err } reducer := segments.CreateInternalReducer(req, collection.Schema()) - ret, err := reducer.Reduce(ctx, results) + resp, err := 
reducer.Reduce(ctx, results) if err != nil { - failRet.Status.Reason = err.Error() - return failRet, nil + return nil, err } tr.CtxElapse(ctx, fmt.Sprintf("do query with channel done , vChannel = %s, segmentIDs = %v", @@ -218,12 +207,82 @@ func (node *QueryNode) queryChannel(ctx context.Context, req *querypb.QueryReque req.GetSegmentIDs(), )) - // - failRet.Status.ErrorCode = commonpb.ErrorCode_Success latency := tr.ElapseSpan() metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.Leader).Observe(float64(latency.Milliseconds())) metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.SuccessLabel, metrics.Leader).Inc() - return ret, nil + return resp, nil +} + +func (node *QueryNode) queryChannelStream(ctx context.Context, req *querypb.QueryRequest, channel string, srv streamrpc.QueryStreamServer) error { + metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.TotalLabel, metrics.Leader).Inc() + msgID := req.Req.Base.GetMsgID() + log := log.Ctx(ctx).With( + zap.Int64("msgID", msgID), + zap.Int64("collectionID", req.GetReq().GetCollectionID()), + zap.String("channel", channel), + zap.String("scope", req.GetScope().String()), + ) + + var err error + defer func() { + if err != nil { + metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.FailLabel, metrics.Leader).Inc() + } + }() + + log.Debug("start do streaming query with channel", + zap.Bool("fromShardLeader", req.GetFromShardLeader()), + zap.Int64s("segmentIDs", req.GetSegmentIDs()), + ) + + // add cancel when error occurs + queryCtx, cancel := context.WithCancel(ctx) + defer cancel() + + // From Proxy + tr := timerecord.NewTimeRecorder("queryDelegator") + // get delegator + sd, ok := node.delegators.Get(channel) + if !ok { + err := merr.WrapErrChannelNotFound(channel) + log.Warn("Query failed, failed to get 
query shard delegator", zap.Error(err)) + return err + } + + // do query + err = sd.QueryStream(queryCtx, req, srv) + if err != nil { + return err + } + + tr.CtxElapse(ctx, fmt.Sprintf("do query with channel done , vChannel = %s, segmentIDs = %v", + channel, + req.GetSegmentIDs(), + )) + + return nil +} + +func (node *QueryNode) queryStreamSegments(ctx context.Context, req *querypb.QueryRequest, srv streamrpc.QueryStreamServer) error { + collection := node.manager.Collection.Get(req.Req.GetCollectionID()) + if collection == nil { + return merr.WrapErrCollectionNotFound(req.Req.GetCollectionID()) + } + + // Send task to scheduler and wait until it finished. + task := tasks.NewQueryStreamTask(ctx, collection, node.manager, req, srv) + if err := node.scheduler.Add(task); err != nil { + log.Warn("failed to add query task into scheduler", zap.Error(err)) + return err + } + + err := task.Wait() + if err != nil { + log.Warn("failed to execute task by node scheduler", zap.Error(err)) + return err + } + + return nil } func (node *QueryNode) optimizeSearchParams(ctx context.Context, req *querypb.SearchRequest, deleg delegator.ShardDelegator) (*querypb.SearchRequest, error) { @@ -301,15 +360,15 @@ func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchReq ) traceID := trace.SpanFromContext(ctx).SpanContext().TraceID() - if !node.lifetime.Add(commonpbutil.IsHealthy) { - return nil, merr.WrapErrServiceNotReady(fmt.Sprintf("node id: %d is unhealthy", paramtable.GetNodeID())) + if err := node.lifetime.Add(merr.IsHealthy); err != nil { + return nil, err } defer node.lifetime.Done() - failRet := WrapSearchResult(commonpb.ErrorCode_UnexpectedError, "") + var err error metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.TotalLabel, metrics.Leader).Inc() defer func() { - if failRet.Status.ErrorCode != commonpb.ErrorCode_Success { + if err != nil { 
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.FailLabel, metrics.Leader).Inc() } }() @@ -326,23 +385,20 @@ func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchReq // get delegator sd, ok := node.delegators.Get(channel) if !ok { - err := merr.WrapErrServiceUnavailable("failed to get shard delegator for search") + err := merr.WrapErrChannelNotFound(channel) log.Warn("Query failed, failed to get shard delegator for search", zap.Error(err)) - failRet.Status.Reason = err.Error() - return failRet, err + return nil, err } - req, err := node.optimizeSearchParams(ctx, req, sd) + req, err = node.optimizeSearchParams(ctx, req, sd) if err != nil { log.Warn("failed to optimize search params", zap.Error(err)) - failRet.Status.Reason = err.Error() - return failRet, err + return nil, err } // do search results, err := sd.Search(searchCtx, req) if err != nil { log.Warn("failed to search on delegator", zap.Error(err)) - failRet.Status.Reason = err.Error() - return failRet, err + return nil, err } // reduce result @@ -353,10 +409,9 @@ func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchReq req.GetSegmentIDs(), )) - ret, err := segments.ReduceSearchResults(ctx, results, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType()) + resp, err := segments.ReduceSearchResults(ctx, results, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType()) if err != nil { - failRet.Status.Reason = err.Error() - return failRet, err + return nil, err } tr.CtxElapse(ctx, fmt.Sprintf("do search with channel done , vChannel = %s, segmentIDs = %v", @@ -365,14 +420,13 @@ func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchReq )) // update metric to prometheus - failRet.Status.ErrorCode = commonpb.ErrorCode_Success latency := tr.ElapseSpan() metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, 
metrics.Leader).Observe(float64(latency.Milliseconds())) metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.SuccessLabel, metrics.Leader).Inc() metrics.QueryNodeSearchNQ.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(req.Req.GetNq())) metrics.QueryNodeSearchTopK.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Observe(float64(req.Req.GetTopk())) - return ret, nil + return resp, nil } func (node *QueryNode) getChannelStatistics(ctx context.Context, req *querypb.GetStatisticsRequest, channel string) (*internalpb.GetStatisticsResponse, error) { @@ -381,11 +435,8 @@ func (node *QueryNode) getChannelStatistics(ctx context.Context, req *querypb.Ge zap.String("channel", channel), zap.String("scope", req.GetScope().String()), ) - failRet := &internalpb.GetStatisticsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - }, - } + + resp := &internalpb.GetStatisticsResponse{} if req.GetFromShardLeader() { var ( @@ -411,23 +462,26 @@ func (node *QueryNode) getChannelStatistics(ctx context.Context, req *querypb.Ge sd, ok := node.delegators.Get(channel) if !ok { - log.Warn("GetStatistics failed, failed to get query shard delegator") - return failRet, nil + err := merr.WrapErrChannelNotFound(channel, "failed to get channel statistics") + log.Warn("GetStatistics failed, failed to get query shard delegator", zap.Error(err)) + resp.Status = merr.Status(err) + return resp, nil } results, err := sd.GetStatistics(ctx, req) if err != nil { log.Warn("failed to get statistics from delegator", zap.Error(err)) - failRet.Status.Reason = err.Error() - return failRet, nil + resp.Status = merr.Status(err) + return resp, nil } - ret, err := reduceStatisticResponse(results) + resp, err = reduceStatisticResponse(results) if err != nil { - failRet.Status.Reason = err.Error() - return failRet, nil + log.Warn("failed to reduce channel statistics", zap.Error(err)) + resp.Status = 
merr.Status(err) + return resp, nil } - return ret, nil + return resp, nil } func segmentStatsResponse(segStats []segments.SegmentStats) *internalpb.GetStatisticsResponse { @@ -440,7 +494,7 @@ func segmentStatsResponse(segStats []segments.SegmentStats) *internalpb.GetStati resultMap["row_count"] = strconv.FormatInt(totalRowNum, 10) ret := &internalpb.GetStatisticsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Stats: funcutil.Map2KeyValuePair(resultMap), } return ret @@ -479,7 +533,7 @@ func reduceStatisticResponse(results []*internalpb.GetStatisticsResponse) (*inte } ret := &internalpb.GetStatisticsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Stats: funcutil.Map2KeyValuePair(stringMap), } return ret, nil diff --git a/internal/querynodev2/handlers_test.go b/internal/querynodev2/handlers_test.go index 774586a5565a3..0088d6f531fb0 100644 --- a/internal/querynodev2/handlers_test.go +++ b/internal/querynodev2/handlers_test.go @@ -31,6 +31,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/planpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querynodev2/delegator" + "github.com/milvus-io/milvus/internal/querynodev2/optimizers" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/pkg/common" @@ -180,7 +181,7 @@ func (suite *OptimizeSearchParamSuite) TestOptimizeSearchParam() { defer cancel() suite.Run("normal_run", func() { - mockHook := &MockQueryHook{} + mockHook := optimizers.NewMockQueryHook(suite.T()) mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) { params[common.TopKKey] = int64(50) params[common.SearchParamKey] = `{"param": 2}` @@ -237,7 +238,7 @@ func (suite *OptimizeSearchParamSuite) TestOptimizeSearchParam() { }) suite.Run("other_plannode", func() { - mockHook := &MockQueryHook{} + mockHook := optimizers.NewMockQueryHook(suite.T()) 
mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) { params[common.TopKKey] = int64(50) params[common.SearchParamKey] = `{"param": 2}` @@ -262,11 +263,7 @@ func (suite *OptimizeSearchParamSuite) TestOptimizeSearchParam() { }) suite.Run("no_serialized_plan", func() { - mockHook := &MockQueryHook{} - mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) { - params[common.TopKKey] = int64(50) - params[common.SearchParamKey] = `{"param": 2}` - }).Return(nil) + mockHook := optimizers.NewMockQueryHook(suite.T()) suite.node.queryHook = mockHook defer func() { suite.node.queryHook = nil }() @@ -278,7 +275,7 @@ func (suite *OptimizeSearchParamSuite) TestOptimizeSearchParam() { }) suite.Run("hook_run_error", func() { - mockHook := &MockQueryHook{} + mockHook := optimizers.NewMockQueryHook(suite.T()) mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) { params[common.TopKKey] = int64(50) params[common.SearchParamKey] = `{"param": 2}` diff --git a/internal/querynodev2/local_worker.go b/internal/querynodev2/local_worker.go index ad7586baf53e0..0ef6af50139ec 100644 --- a/internal/querynodev2/local_worker.go +++ b/internal/querynodev2/local_worker.go @@ -20,14 +20,16 @@ import ( "context" "fmt" + "github.com/samber/lo" + "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querynodev2/cluster" "github.com/milvus-io/milvus/internal/querynodev2/segments" + "github.com/milvus-io/milvus/internal/util/streamrpc" "github.com/milvus-io/milvus/pkg/log" - "github.com/samber/lo" - "go.uber.org/zap" ) var _ cluster.Worker = &LocalWorker{} @@ -86,7 +88,7 @@ func (w *LocalWorker) Delete(ctx context.Context, req *querypb.DeleteRequest) er zap.Int64("collectionID", req.GetCollectionId()), zap.Int64("segmentID", req.GetSegmentId()), ) - log.Info("start to process segment delete") 
+ log.Debug("start to process segment delete") status, err := w.node.Delete(ctx, req) if err != nil { return err @@ -101,6 +103,10 @@ func (w *LocalWorker) SearchSegments(ctx context.Context, req *querypb.SearchReq return w.node.SearchSegments(ctx, req) } +func (w *LocalWorker) QueryStreamSegments(ctx context.Context, req *querypb.QueryRequest, srv streamrpc.QueryStreamServer) error { + return w.node.queryStreamSegments(ctx, req, srv) +} + func (w *LocalWorker) QuerySegments(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) { return w.node.QuerySegments(ctx, req) } diff --git a/internal/querynodev2/local_worker_test.go b/internal/querynodev2/local_worker_test.go index 3a68525cfba8a..67a55c049041a 100644 --- a/internal/querynodev2/local_worker_test.go +++ b/internal/querynodev2/local_worker_test.go @@ -18,7 +18,6 @@ package querynodev2 import ( "context" - "testing" "github.com/samber/lo" diff --git a/internal/querynodev2/metrics_info.go b/internal/querynodev2/metrics_info.go index 8f7b6f2a547b0..d3bbc0527bc9f 100644 --- a/internal/querynodev2/metrics_info.go +++ b/internal/querynodev2/metrics_info.go @@ -23,7 +23,6 @@ import ( "github.com/samber/lo" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/querynodev2/collector" "github.com/milvus-io/milvus/internal/querynodev2/segments" @@ -163,10 +162,7 @@ func getSystemInfoMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest, quotaMetrics, err := getQuotaMetrics(node) if err != nil { return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), ComponentName: metricsinfo.ConstructComponentName(typeutil.DataNodeRole, paramtable.GetNodeID()), }, nil } @@ -201,17 +197,14 @@ func getSystemInfoMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest, resp, err := 
metricsinfo.MarshalComponentInfos(nodeInfos) if err != nil { return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), Response: "", ComponentName: metricsinfo.ConstructComponentName(typeutil.QueryNodeRole, paramtable.GetNodeID()), }, nil } return &milvuspb.GetMetricsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Response: resp, ComponentName: metricsinfo.ConstructComponentName(typeutil.QueryNodeRole, paramtable.GetNodeID()), }, nil diff --git a/internal/querynodev2/mock_data.go b/internal/querynodev2/mock_data.go index d0387ab6803e4..ef884a3234270 100644 --- a/internal/querynodev2/mock_data.go +++ b/internal/querynodev2/mock_data.go @@ -108,7 +108,7 @@ func genPlaceHolderGroup(nq int64) ([]byte, error) { Values: make([][]byte, 0), } for i := int64(0); i < nq; i++ { - var vec = make([]float32, defaultDim) + vec := make([]float32, defaultDim) for j := 0; j < defaultDim; j++ { vec[j] = rand.Float32() } diff --git a/internal/querynodev2/mock_query_hook.go b/internal/querynodev2/optimizers/mock_query_hook.go similarity index 94% rename from internal/querynodev2/mock_query_hook.go rename to internal/querynodev2/optimizers/mock_query_hook.go index 79c19a42c0ff0..7c9f5dab88f20 100644 --- a/internal/querynodev2/mock_query_hook.go +++ b/internal/querynodev2/optimizers/mock_query_hook.go @@ -1,7 +1,10 @@ -package querynodev2 +// Code generated by mockery v2.32.4. DO NOT EDIT. -import "github.com/stretchr/testify/mock" +package optimizers +import mock "github.com/stretchr/testify/mock" + +// MockQueryHook is an autogenerated mock type for the QueryHook type type MockQueryHook struct { mock.Mock } @@ -182,13 +185,12 @@ func (_c *MockQueryHook_Run_Call) RunAndReturn(run func(map[string]interface{}) return _c } -type mockConstructorTestingTNewMockQueryHook interface { +// NewMockQueryHook creates a new instance of MockQueryHook. 
It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockQueryHook(t interface { mock.TestingT Cleanup(func()) -} - -// NewMockQueryHook creates a new instance of MockQueryHook. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewMockQueryHook(t mockConstructorTestingTNewMockQueryHook) *MockQueryHook { +}) *MockQueryHook { mock := &MockQueryHook{} mock.Mock.Test(t) diff --git a/internal/querynodev2/optimizers/query_hook.go b/internal/querynodev2/optimizers/query_hook.go new file mode 100644 index 0000000000000..c3703feba12ff --- /dev/null +++ b/internal/querynodev2/optimizers/query_hook.go @@ -0,0 +1,9 @@ +package optimizers + +// QueryHook is the interface for search/query parameter optimizer. +type QueryHook interface { + Run(map[string]any) error + Init(string) error + InitTuningConfig(map[string]string) error + DeleteTuningConfig(string) error +} diff --git a/internal/querynodev2/pipeline/delete_node.go b/internal/querynodev2/pipeline/delete_node.go index 91c3ce9bce1b8..a408f98f9140e 100644 --- a/internal/querynodev2/pipeline/delete_node.go +++ b/internal/querynodev2/pipeline/delete_node.go @@ -74,11 +74,11 @@ func (dNode *deleteNode) Operate(in Msg) Msg { } if len(deleteDatas) > 0 { - //do Delete, use ts range max as ts + // do Delete, use ts range max as ts dNode.delegator.ProcessDelete(lo.Values(deleteDatas), nodeMsg.timeRange.timestampMax) } - //update tSafe + // update tSafe err := dNode.tSafeManager.Set(dNode.channel, nodeMsg.timeRange.timestampMax) if err != nil { // should not happen, QueryNode should addTSafe before start pipeline diff --git a/internal/querynodev2/pipeline/delete_node_test.go b/internal/querynodev2/pipeline/delete_node_test.go index 52b430ac57bfd..1c4fd45d6f2bd 100644 --- a/internal/querynodev2/pipeline/delete_node_test.go +++ 
b/internal/querynodev2/pipeline/delete_node_test.go @@ -19,18 +19,19 @@ package pipeline import ( "testing" + "github.com/samber/lo" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "github.com/milvus-io/milvus/internal/querynodev2/delegator" "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/querynodev2/tsafe" "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/samber/lo" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/suite" ) type DeleteNodeSuite struct { suite.Suite - //datas + // datas collectionID int64 collectionName string partitionIDs []int64 @@ -38,9 +39,9 @@ type DeleteNodeSuite struct { channel string timeRange TimeRange - //dependency + // dependency tSafeManager TSafeManager - //mocks + // mocks manager *segments.Manager delegator *delegator.MockShardDelegator } @@ -51,7 +52,7 @@ func (suite *DeleteNodeSuite) SetupSuite() { suite.collectionName = "test-collection" suite.partitionIDs = []int64{11, 22} suite.channel = "test-channel" - //segment own data row which‘s pk same with segment‘s ID + // segment own data row which‘s pk same with segment‘s ID suite.deletePKs = []int64{1, 2, 3, 4} suite.timeRange = TimeRange{ timestampMin: 0, @@ -74,7 +75,7 @@ func (suite *DeleteNodeSuite) buildDeleteNodeMsg() *deleteNodeMsg { } func (suite *DeleteNodeSuite) TestBasic() { - //mock + // mock mockCollectionManager := segments.NewMockCollectionManager(suite.T()) mockSegmentManager := segments.NewMockSegmentManager(suite.T()) suite.manager = &segments.Manager{ @@ -90,16 +91,16 @@ func (suite *DeleteNodeSuite) TestBasic() { } } }) - //init dependency + // init dependency suite.tSafeManager = tsafe.NewTSafeReplica() suite.tSafeManager.Add(suite.channel, 0) - //build delete node and data + // build delete node and data node := newDeleteNode(suite.collectionID, suite.channel, suite.manager, suite.tSafeManager, suite.delegator, 8) in := suite.buildDeleteNodeMsg() - 
//run + // run out := node.Operate(in) suite.Nil(out) - //check tsafe + // check tsafe tt, err := suite.tSafeManager.Get(suite.channel) suite.NoError(err) suite.Equal(suite.timeRange.timestampMax, tt) diff --git a/internal/querynodev2/pipeline/filter_node.go b/internal/querynodev2/pipeline/filter_node.go index a523e9f838068..8e4205cb66ae1 100644 --- a/internal/querynodev2/pipeline/filter_node.go +++ b/internal/querynodev2/pipeline/filter_node.go @@ -20,7 +20,6 @@ import ( "fmt" "reflect" - "github.com/golang/protobuf/proto" "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" @@ -69,7 +68,7 @@ func (fNode *filterNode) Operate(in Msg) Msg { WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.InsertLabel, fmt.Sprint(fNode.collectionID)). Set(float64(tsoutil.SubByNow(streamMsgPack.EndTs))) - //Get collection from collection manager + // Get collection from collection manager collection := fNode.manager.Collection.Get(fNode.collectionID) if collection == nil { log.Fatal("collection not found in meta", zap.Int64("collectionID", fNode.collectionID)) @@ -84,7 +83,7 @@ func (fNode *filterNode) Operate(in Msg) Msg { }, } - //add msg to out if msg pass check of filter + // add msg to out if msg pass check of filter for _, msg := range streamMsgPack.Msgs { err := fNode.filtrate(collection, msg) if err != nil { @@ -105,11 +104,10 @@ func (fNode *filterNode) Operate(in Msg) Msg { // filtrate message with filter policy func (fNode *filterNode) filtrate(c *Collection, msg msgstream.TsMsg) error { - switch msg.Type() { case commonpb.MsgType_Insert: insertMsg := msg.(*msgstream.InsertMsg) - metrics.QueryNodeConsumeCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.InsertLabel).Add(float64(proto.Size(insertMsg))) + metrics.QueryNodeConsumeCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.InsertLabel).Add(float64(insertMsg.Size())) for _, policy := range fNode.InsertMsgPolicys { err := policy(fNode, c, insertMsg) if err 
!= nil { @@ -119,7 +117,7 @@ func (fNode *filterNode) filtrate(c *Collection, msg msgstream.TsMsg) error { case commonpb.MsgType_Delete: deleteMsg := msg.(*msgstream.DeleteMsg) - metrics.QueryNodeConsumeCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.InsertLabel).Add(float64(proto.Size(deleteMsg))) + metrics.QueryNodeConsumeCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.InsertLabel).Add(float64(deleteMsg.Size())) for _, policy := range fNode.DeleteMsgPolicys { err := policy(fNode, c, deleteMsg) if err != nil { diff --git a/internal/querynodev2/pipeline/filter_node_test.go b/internal/querynodev2/pipeline/filter_node_test.go index 2bd51a83187f2..8d5eda9f332d1 100644 --- a/internal/querynodev2/pipeline/filter_node_test.go +++ b/internal/querynodev2/pipeline/filter_node_test.go @@ -34,7 +34,7 @@ import ( // test of filter node type FilterNodeSuite struct { suite.Suite - //datas + // datas collectionID int64 partitionIDs []int64 channel string @@ -44,10 +44,10 @@ type FilterNodeSuite struct { excludedSegmentIDs []int64 insertSegmentIDs []int64 deleteSegmentSum int - //segmentID of msg invalid because empty of not aligned + // segmentID of msg invalid because empty of not aligned errSegmentID int64 - //mocks + // mocks manager *segments.Manager } @@ -63,7 +63,7 @@ func (suite *FilterNodeSuite) SetupSuite() { suite.deleteSegmentSum = 4 suite.errSegmentID = 7 - //init excludedSegment + // init excludedSegment suite.excludedSegments = typeutil.NewConcurrentMap[int64, *datapb.SegmentInfo]() for _, id := range suite.excludedSegmentIDs { suite.excludedSegments.Insert(id, &datapb.SegmentInfo{ @@ -76,10 +76,10 @@ func (suite *FilterNodeSuite) SetupSuite() { // test filter node with collection load collection func (suite *FilterNodeSuite) TestWithLoadCollection() { - //data + // data suite.validSegmentIDs = []int64{2, 3, 4, 5, 6} - //mock + // mock collection := segments.NewCollectionWithoutSchema(suite.collectionID, 
querypb.LoadType_LoadCollection) for _, partitionID := range suite.partitionIDs { collection.AddPartition(partitionID) @@ -111,10 +111,10 @@ func (suite *FilterNodeSuite) TestWithLoadCollection() { // test filter node with collection load partition func (suite *FilterNodeSuite) TestWithLoadPartation() { - //data + // data suite.validSegmentIDs = []int64{2, 3, 4, 5, 6} - //mock + // mock collection := segments.NewCollectionWithoutSchema(suite.collectionID, querypb.LoadType_LoadPartition) collection.AddPartition(suite.partitionIDs[0]) @@ -149,43 +149,43 @@ func (suite *FilterNodeSuite) buildMsgPack() *msgstream.MsgPack { Msgs: []msgstream.TsMsg{}, } - //add valid insert + // add valid insert for _, id := range suite.insertSegmentIDs { insertMsg := buildInsertMsg(suite.collectionID, suite.partitionIDs[id%2], id, suite.channel, 1) msgPack.Msgs = append(msgPack.Msgs, insertMsg) } - //add valid delete + // add valid delete for i := 0; i < suite.deleteSegmentSum; i++ { deleteMsg := buildDeleteMsg(suite.collectionID, suite.partitionIDs[i%2], suite.channel, 1) msgPack.Msgs = append(msgPack.Msgs, deleteMsg) } - //add invalid msg + // add invalid msg - //segment in excludedSegments - //some one end timestamp befroe dmlPosition timestamp will be invalid + // segment in excludedSegments + // some one end timestamp befroe dmlPosition timestamp will be invalid for _, id := range suite.excludedSegmentIDs { insertMsg := buildInsertMsg(suite.collectionID, suite.partitionIDs[id%2], id, suite.channel, 1) insertMsg.EndTimestamp = uint64(id) msgPack.Msgs = append(msgPack.Msgs, insertMsg) } - //empty msg + // empty msg insertMsg := buildInsertMsg(suite.collectionID, suite.partitionIDs[0], suite.errSegmentID, suite.channel, 0) msgPack.Msgs = append(msgPack.Msgs, insertMsg) deleteMsg := buildDeleteMsg(suite.collectionID, suite.partitionIDs[0], suite.channel, 0) msgPack.Msgs = append(msgPack.Msgs, deleteMsg) - //msg not target + // msg not target insertMsg = 
buildInsertMsg(suite.collectionID+1, 1, 0, "Unknown", 1) msgPack.Msgs = append(msgPack.Msgs, insertMsg) deleteMsg = buildDeleteMsg(suite.collectionID+1, 1, "Unknown", 1) msgPack.Msgs = append(msgPack.Msgs, deleteMsg) - //msg not aligned + // msg not aligned insertMsg = buildInsertMsg(suite.collectionID, suite.partitionIDs[0], suite.errSegmentID, suite.channel, 1) insertMsg.Timestamps = []uint64{} msgPack.Msgs = append(msgPack.Msgs, insertMsg) diff --git a/internal/querynodev2/pipeline/insert_node.go b/internal/querynodev2/pipeline/insert_node.go index 8e7c060a6bce3..16c588bbec4f7 100644 --- a/internal/querynodev2/pipeline/insert_node.go +++ b/internal/querynodev2/pipeline/insert_node.go @@ -101,7 +101,7 @@ func (iNode *insertNode) Operate(in Msg) Msg { panic("insertNode with collection not exist") } - //get InsertData and merge datas of same segment + // get InsertData and merge datas of same segment for _, msg := range nodeMsg.insertMsgs { iNode.addInsertData(insertDatas, msg, collection) } diff --git a/internal/querynodev2/pipeline/insert_node_test.go b/internal/querynodev2/pipeline/insert_node_test.go index a3b58aabbc737..6d6979fa9b71c 100644 --- a/internal/querynodev2/pipeline/insert_node_test.go +++ b/internal/querynodev2/pipeline/insert_node_test.go @@ -19,26 +19,27 @@ package pipeline import ( "testing" + "github.com/samber/lo" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querynodev2/delegator" "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/samber/lo" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/suite" ) type InsertNodeSuite struct { suite.Suite - //datas + // datas collectionName string collectionID int64 partitionID int64 channel string insertSegmentIDs []int64 deleteSegmentSum int - 
//mocks + // mocks manager *segments.Manager delegator *delegator.MockShardDelegator } @@ -56,14 +57,14 @@ func (suite *InsertNodeSuite) SetupSuite() { } func (suite *InsertNodeSuite) TestBasic() { - //data + // data schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64) in := suite.buildInsertNodeMsg(schema) collection := segments.NewCollection(suite.collectionID, schema, segments.GenTestIndexMeta(suite.collectionID, schema), querypb.LoadType_LoadCollection) collection.AddPartition(suite.partitionID) - //init mock + // init mock mockCollectionManager := segments.NewMockCollectionManager(suite.T()) mockCollectionManager.EXPECT().Get(suite.collectionID).Return(collection) @@ -81,7 +82,7 @@ func (suite *InsertNodeSuite) TestBasic() { } }) - //TODO mock a delgator for test + // TODO mock a delgator for test node := newInsertNode(suite.collectionID, suite.channel, suite.manager, suite.delegator, 8) out := node.Operate(in) @@ -97,7 +98,7 @@ func (suite *InsertNodeSuite) TestDataTypeNotSupported() { collection := segments.NewCollection(suite.collectionID, schema, segments.GenTestIndexMeta(suite.collectionID, schema), querypb.LoadType_LoadCollection) collection.AddPartition(suite.partitionID) - //init mock + // init mock mockCollectionManager := segments.NewMockCollectionManager(suite.T()) mockCollectionManager.EXPECT().Get(suite.collectionID).Return(collection) @@ -116,7 +117,7 @@ func (suite *InsertNodeSuite) TestDataTypeNotSupported() { } } - //TODO mock a delgator for test + // TODO mock a delgator for test node := newInsertNode(suite.collectionID, suite.channel, suite.manager, suite.delegator, 8) suite.Panics(func() { node.Operate(in) diff --git a/internal/querynodev2/pipeline/manager.go b/internal/querynodev2/pipeline/manager.go index 1d486e9fc38a3..cf4a746d7d422 100644 --- a/internal/querynodev2/pipeline/manager.go +++ b/internal/querynodev2/pipeline/manager.go @@ -77,7 +77,7 @@ func (m *manager) Add(collectionID UniqueID, channel 
string) (Pipeline, error) { return pipeline, nil } - //get shard delegator for add growing in pipeline + // get shard delegator for add growing in pipeline delegator, ok := m.delegators.Get(channel) if !ok { return nil, merr.WrapErrChannelNotFound(channel, "delegator not found") @@ -132,7 +132,7 @@ func (m *manager) Start(channels ...string) error { m.mu.Lock() defer m.mu.Unlock() - //check pipelie all exist before start + // check pipelie all exist before start for _, channel := range channels { if _, ok := m.channel2Pipeline[channel]; !ok { reason := fmt.Sprintf("pipeline with channel %s not exist", channel) diff --git a/internal/querynodev2/pipeline/manager_test.go b/internal/querynodev2/pipeline/manager_test.go index 508ec1aec564e..849857c1fb0c7 100644 --- a/internal/querynodev2/pipeline/manager_test.go +++ b/internal/querynodev2/pipeline/manager_test.go @@ -35,14 +35,14 @@ import ( type PipelineManagerTestSuite struct { suite.Suite - //data + // data collectionID int64 channel string - //dependencies + // dependencies tSafeManager TSafeManager delegators *typeutil.ConcurrentMap[string, delegator.ShardDelegator] - //mocks + // mocks segmentManager *segments.MockSegmentManager collectionManager *segments.MockCollectionManager delegator *delegator.MockShardDelegator @@ -57,13 +57,13 @@ func (suite *PipelineManagerTestSuite) SetupSuite() { func (suite *PipelineManagerTestSuite) SetupTest() { paramtable.Init() - //init dependency + // init dependency // init tsafeManager suite.tSafeManager = tsafe.NewTSafeReplica() suite.tSafeManager.Add(suite.channel, 0) suite.delegators = typeutil.NewConcurrentMap[string, delegator.ShardDelegator]() - //init mock + // init mock // init manager suite.collectionManager = segments.NewMockCollectionManager(suite.T()) suite.segmentManager = segments.NewMockSegmentManager(suite.T()) @@ -75,14 +75,14 @@ func (suite *PipelineManagerTestSuite) SetupTest() { } func (suite *PipelineManagerTestSuite) TestBasic() { - //init mock + // init 
mock // mock collection manager suite.collectionManager.EXPECT().Get(suite.collectionID).Return(&segments.Collection{}) // mock mq factory - suite.msgDispatcher.EXPECT().Register(suite.channel, mock.Anything, mqwrapper.SubscriptionPositionUnknown).Return(suite.msgChan, nil) + suite.msgDispatcher.EXPECT().Register(mock.Anything, suite.channel, mock.Anything, mqwrapper.SubscriptionPositionUnknown).Return(suite.msgChan, nil) suite.msgDispatcher.EXPECT().Deregister(suite.channel) - //build manager + // build manager manager := &segments.Manager{ Collection: suite.collectionManager, Segment: suite.segmentManager, @@ -90,24 +90,24 @@ func (suite *PipelineManagerTestSuite) TestBasic() { pipelineManager := NewManager(manager, suite.tSafeManager, suite.msgDispatcher, suite.delegators) defer pipelineManager.Close() - //Add pipeline + // Add pipeline _, err := pipelineManager.Add(suite.collectionID, suite.channel) suite.NoError(err) suite.Equal(1, pipelineManager.Num()) - //Get pipeline + // Get pipeline pipeline := pipelineManager.Get(suite.channel) suite.NotNil(pipeline) - //Init Consumer + // Init Consumer err = pipeline.ConsumeMsgStream(&msgpb.MsgPosition{}) suite.NoError(err) - //Start pipeline + // Start pipeline err = pipelineManager.Start(suite.channel) suite.NoError(err) - //Remove pipeline + // Remove pipeline pipelineManager.Remove(suite.channel) suite.Equal(0, pipelineManager.Num()) } diff --git a/internal/querynodev2/pipeline/message.go b/internal/querynodev2/pipeline/message.go index c5d0f6781a2fa..fd5f3acda7cef 100644 --- a/internal/querynodev2/pipeline/message.go +++ b/internal/querynodev2/pipeline/message.go @@ -17,8 +17,6 @@ package pipeline import ( - "github.com/golang/protobuf/proto" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/querynodev2/collector" "github.com/milvus-io/milvus/pkg/mq/msgstream" @@ -42,11 +40,11 @@ func (msg *insertNodeMsg) append(taskMsg msgstream.TsMsg) error { case 
commonpb.MsgType_Insert: insertMsg := taskMsg.(*InsertMsg) msg.insertMsgs = append(msg.insertMsgs, insertMsg) - collector.Rate.Add(metricsinfo.InsertConsumeThroughput, float64(proto.Size(&insertMsg.InsertRequest))) + collector.Rate.Add(metricsinfo.InsertConsumeThroughput, float64(insertMsg.Size())) case commonpb.MsgType_Delete: deleteMsg := taskMsg.(*DeleteMsg) msg.deleteMsgs = append(msg.deleteMsgs, deleteMsg) - collector.Rate.Add(metricsinfo.DeleteConsumeThroughput, float64(proto.Size(&deleteMsg.DeleteRequest))) + collector.Rate.Add(metricsinfo.DeleteConsumeThroughput, float64(deleteMsg.Size())) default: return merr.WrapErrParameterInvalid("msgType is Insert or Delete", "not") } diff --git a/internal/querynodev2/pipeline/pipeline_test.go b/internal/querynodev2/pipeline/pipeline_test.go index b0027fa4f8f18..a1e13e2c77db2 100644 --- a/internal/querynodev2/pipeline/pipeline_test.go +++ b/internal/querynodev2/pipeline/pipeline_test.go @@ -37,7 +37,7 @@ import ( type PipelineTestSuite struct { suite.Suite - //datas + // datas collectionName string collectionID int64 partitionIDs []int64 @@ -45,10 +45,10 @@ type PipelineTestSuite struct { insertSegmentIDs []int64 deletePKs []int64 - //dependencies + // dependencies tSafeManager TSafeManager - //mocks + // mocks segmentManager *segments.MockSegmentManager collectionManager *segments.MockCollectionManager delegator *delegator.MockShardDelegator @@ -89,7 +89,7 @@ func (suite *PipelineTestSuite) buildMsgPack(schema *schemapb.CollectionSchema) func (suite *PipelineTestSuite) SetupTest() { paramtable.Init() - //init mock + // init mock // init manager suite.collectionManager = segments.NewMockCollectionManager(suite.T()) suite.segmentManager = segments.NewMockSegmentManager(suite.T()) @@ -98,21 +98,21 @@ func (suite *PipelineTestSuite) SetupTest() { // init mq dispatcher suite.msgDispatcher = msgdispatcher.NewMockClient(suite.T()) - //init dependency + // init dependency // init tsafeManager suite.tSafeManager = 
tsafe.NewTSafeReplica() suite.tSafeManager.Add(suite.channel, 0) } func (suite *PipelineTestSuite) TestBasic() { - //init mock + // init mock // mock collection manager schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64) collection := segments.NewCollection(suite.collectionID, schema, segments.GenTestIndexMeta(suite.collectionID, schema), querypb.LoadType_LoadCollection) suite.collectionManager.EXPECT().Get(suite.collectionID).Return(collection) // mock mq factory - suite.msgDispatcher.EXPECT().Register(suite.channel, mock.Anything, mqwrapper.SubscriptionPositionUnknown).Return(suite.msgChan, nil) + suite.msgDispatcher.EXPECT().Register(mock.Anything, suite.channel, mock.Anything, mqwrapper.SubscriptionPositionUnknown).Return(suite.msgChan, nil) suite.msgDispatcher.EXPECT().Deregister(suite.channel) // mock delegator @@ -131,7 +131,7 @@ func (suite *PipelineTestSuite) TestBasic() { } } }) - //build pipleine + // build pipleine manager := &segments.Manager{ Collection: suite.collectionManager, Segment: suite.segmentManager, @@ -139,7 +139,7 @@ func (suite *PipelineTestSuite) TestBasic() { pipeline, err := NewPipeLine(suite.collectionID, suite.channel, manager, suite.tSafeManager, suite.msgDispatcher, suite.delegator) suite.NoError(err) - //Init Consumer + // Init Consumer err = pipeline.ConsumeMsgStream(&msgpb.MsgPosition{}) suite.NoError(err) @@ -157,7 +157,7 @@ func (suite *PipelineTestSuite) TestBasic() { // wait pipeline work <-listener.On() - //check tsafe + // check tsafe tsafe, err := suite.tSafeManager.Get(suite.channel) suite.NoError(err) suite.Equal(in.EndTs, tsafe) diff --git a/internal/querynodev2/segments/bloom_filter_set_test.go b/internal/querynodev2/segments/bloom_filter_set_test.go index 25f2be58b5c3d..a427737b4ddba 100644 --- a/internal/querynodev2/segments/bloom_filter_set_test.go +++ b/internal/querynodev2/segments/bloom_filter_set_test.go @@ -19,8 +19,9 @@ package segments import ( "testing" - 
"github.com/milvus-io/milvus/internal/storage" "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus/internal/storage" ) type BloomFilterSetSuite struct { diff --git a/internal/querynodev2/segments/cgo_util.go b/internal/querynodev2/segments/cgo_util.go index 2032b1301c119..3ee10af706103 100644 --- a/internal/querynodev2/segments/cgo_util.go +++ b/internal/querynodev2/segments/cgo_util.go @@ -86,11 +86,11 @@ func GetCProtoBlob(cProto *C.CProto) []byte { func GetLocalUsedSize(path string) (int64, error) { var availableSize int64 - cSize := C.int64_t(availableSize) + cSize := (*C.int64_t)(&availableSize) cPath := C.CString(path) defer C.free(unsafe.Pointer(cPath)) - status := C.GetLocalUsedSize(cPath, &cSize) + status := C.GetLocalUsedSize(cPath, cSize) err := HandleCStatus(&status, "get local used size failed") if err != nil { return 0, err diff --git a/internal/querynodev2/segments/count_reducer.go b/internal/querynodev2/segments/count_reducer.go index 758030b34efc0..70a5f0dfb8535 100644 --- a/internal/querynodev2/segments/count_reducer.go +++ b/internal/querynodev2/segments/count_reducer.go @@ -8,8 +8,7 @@ import ( "github.com/milvus-io/milvus/internal/util/funcutil" ) -type cntReducer struct { -} +type cntReducer struct{} func (r *cntReducer) Reduce(ctx context.Context, results []*internalpb.RetrieveResults) (*internalpb.RetrieveResults, error) { cnt := int64(0) @@ -23,8 +22,7 @@ func (r *cntReducer) Reduce(ctx context.Context, results []*internalpb.RetrieveR return funcutil.WrapCntToInternalResult(cnt), nil } -type cntReducerSegCore struct { -} +type cntReducerSegCore struct{} func (r *cntReducerSegCore) Reduce(ctx context.Context, results []*segcorepb.RetrieveResults) (*segcorepb.RetrieveResults, error) { cnt := int64(0) diff --git a/internal/querynodev2/segments/count_reducer_test.go b/internal/querynodev2/segments/count_reducer_test.go index 3dd9094bc6e81..ba33c2d305984 100644 --- a/internal/querynodev2/segments/count_reducer_test.go +++ 
b/internal/querynodev2/segments/count_reducer_test.go @@ -6,11 +6,10 @@ import ( "github.com/stretchr/testify/suite" - "github.com/milvus-io/milvus/internal/proto/segcorepb" - "github.com/milvus-io/milvus/internal/util/funcutil" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/segcorepb" + "github.com/milvus-io/milvus/internal/util/funcutil" ) type InternalCntReducerSuite struct { diff --git a/internal/querynodev2/segments/default_limit_reducer.go b/internal/querynodev2/segments/default_limit_reducer.go index 6532d67ed81c5..7f7af5b2ac237 100644 --- a/internal/querynodev2/segments/default_limit_reducer.go +++ b/internal/querynodev2/segments/default_limit_reducer.go @@ -14,8 +14,26 @@ type defaultLimitReducer struct { schema *schemapb.CollectionSchema } +type mergeParam struct { + limit int64 + outputFieldsId []int64 + schema *schemapb.CollectionSchema + mergeStopForBest bool +} + +func NewMergeParam(limit int64, outputFieldsId []int64, schema *schemapb.CollectionSchema, reduceStopForBest bool) *mergeParam { + return &mergeParam{ + limit: limit, + outputFieldsId: outputFieldsId, + schema: schema, + mergeStopForBest: reduceStopForBest, + } +} + func (r *defaultLimitReducer) Reduce(ctx context.Context, results []*internalpb.RetrieveResults) (*internalpb.RetrieveResults, error) { - return mergeInternalRetrieveResultsAndFillIfEmpty(ctx, results, r.req.GetReq().GetLimit(), r.req.GetReq().GetOutputFieldsId(), r.schema) + reduceParam := NewMergeParam(r.req.GetReq().GetLimit(), r.req.GetReq().GetOutputFieldsId(), + r.schema, r.req.GetReq().GetReduceStopForBest()) + return mergeInternalRetrieveResultsAndFillIfEmpty(ctx, results, reduceParam) } func newDefaultLimitReducer(req *querypb.QueryRequest, schema *schemapb.CollectionSchema) *defaultLimitReducer { @@ -25,31 +43,14 @@ func newDefaultLimitReducer(req *querypb.QueryRequest, schema *schemapb.Collecti } } -type 
extensionLimitReducer struct { - req *querypb.QueryRequest - schema *schemapb.CollectionSchema - extendedLimit int64 -} - -func (r *extensionLimitReducer) Reduce(ctx context.Context, results []*internalpb.RetrieveResults) (*internalpb.RetrieveResults, error) { - return mergeInternalRetrieveResultsAndFillIfEmpty(ctx, results, r.extendedLimit, r.req.GetReq().GetOutputFieldsId(), r.schema) -} - -func newExtensionLimitReducer(req *querypb.QueryRequest, schema *schemapb.CollectionSchema, extLimit int64) *extensionLimitReducer { - return &extensionLimitReducer{ - req: req, - schema: schema, - extendedLimit: extLimit, - } -} - type defaultLimitReducerSegcore struct { req *querypb.QueryRequest schema *schemapb.CollectionSchema } func (r *defaultLimitReducerSegcore) Reduce(ctx context.Context, results []*segcorepb.RetrieveResults) (*segcorepb.RetrieveResults, error) { - return mergeSegcoreRetrieveResultsAndFillIfEmpty(ctx, results, r.req.GetReq().GetLimit(), r.req.GetReq().GetOutputFieldsId(), r.schema) + mergeParam := NewMergeParam(r.req.GetReq().GetLimit(), r.req.GetReq().GetOutputFieldsId(), r.schema, r.req.GetReq().GetReduceStopForBest()) + return mergeSegcoreRetrieveResultsAndFillIfEmpty(ctx, results, mergeParam) } func newDefaultLimitReducerSegcore(req *querypb.QueryRequest, schema *schemapb.CollectionSchema) *defaultLimitReducerSegcore { @@ -58,21 +59,3 @@ func newDefaultLimitReducerSegcore(req *querypb.QueryRequest, schema *schemapb.C schema: schema, } } - -type extensionLimitSegcoreReducer struct { - req *querypb.QueryRequest - schema *schemapb.CollectionSchema - extendedLimit int64 -} - -func (r *extensionLimitSegcoreReducer) Reduce(ctx context.Context, results []*segcorepb.RetrieveResults) (*segcorepb.RetrieveResults, error) { - return mergeSegcoreRetrieveResultsAndFillIfEmpty(ctx, results, r.extendedLimit, r.req.GetReq().GetOutputFieldsId(), r.schema) -} - -func newExtensionLimitSegcoreReducer(req *querypb.QueryRequest, schema *schemapb.CollectionSchema, 
extLimit int64) *extensionLimitSegcoreReducer { - return &extensionLimitSegcoreReducer{ - req: req, - schema: schema, - extendedLimit: extLimit, - } -} diff --git a/internal/querynodev2/segments/load_field_data_info.go b/internal/querynodev2/segments/load_field_data_info.go index f8bf103bc3163..ca5046b40e904 100644 --- a/internal/querynodev2/segments/load_field_data_info.go +++ b/internal/querynodev2/segments/load_field_data_info.go @@ -21,7 +21,12 @@ package segments #include "segcore/load_field_data_c.h" */ import "C" -import "unsafe" + +import ( + "unsafe" + + "github.com/milvus-io/milvus/internal/proto/datapb" +) type LoadFieldDataInfo struct { cLoadFieldDataInfo C.CLoadFieldDataInfo @@ -49,12 +54,13 @@ func (ld *LoadFieldDataInfo) appendLoadFieldInfo(fieldID int64, rowCount int64) return HandleCStatus(&status, "appendLoadFieldInfo failed") } -func (ld *LoadFieldDataInfo) appendLoadFieldDataPath(fieldID int64, file string) error { +func (ld *LoadFieldDataInfo) appendLoadFieldDataPath(fieldID int64, binlog *datapb.Binlog) error { cFieldID := C.int64_t(fieldID) - cFile := C.CString(file) + cEntriesNum := C.int64_t(binlog.GetEntriesNum()) + cFile := C.CString(binlog.GetLogPath()) defer C.free(unsafe.Pointer(cFile)) - status := C.AppendLoadFieldDataPath(ld.cLoadFieldDataInfo, cFieldID, cFile) + status := C.AppendLoadFieldDataPath(ld.cLoadFieldDataInfo, cFieldID, cEntriesNum, cFile) return HandleCStatus(&status, "appendLoadFieldDataPath failed") } diff --git a/internal/querynodev2/segments/load_index_info.go b/internal/querynodev2/segments/load_index_info.go index 0a075a405fe5c..b076fba818487 100644 --- a/internal/querynodev2/segments/load_index_info.go +++ b/internal/querynodev2/segments/load_index_info.go @@ -87,6 +87,10 @@ func (li *LoadIndexInfo) appendLoadIndexInfo(indexInfo *querypb.FieldIndexInfo, } } + if err := li.appendIndexEngineVersion(indexInfo.GetCurrentIndexVersion()); err != nil { + return err + } + err = li.appendIndexData(indexPaths) return err } 
@@ -161,3 +165,10 @@ func (li *LoadIndexInfo) appendIndexData(indexKeys []string) error { return HandleCStatus(&status, "AppendIndex failed") } + +func (li *LoadIndexInfo) appendIndexEngineVersion(indexEngineVersion int32) error { + cIndexEngineVersion := C.int32_t(indexEngineVersion) + + status := C.AppendIndexEngineVersionToLoadInfo(li.cLoadIndexInfo, cIndexEngineVersion) + return HandleCStatus(&status, "AppendIndexEngineVersion failed") +} diff --git a/internal/querynodev2/segments/manager.go b/internal/querynodev2/segments/manager.go index 778c96dca9bce..ce251198b6f6e 100644 --- a/internal/querynodev2/segments/manager.go +++ b/internal/querynodev2/segments/manager.go @@ -29,6 +29,8 @@ import ( "fmt" "sync" + "go.uber.org/zap" + "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/pkg/eventlog" "github.com/milvus-io/milvus/pkg/log" @@ -36,11 +38,16 @@ import ( "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" . "github.com/milvus-io/milvus/pkg/util/typeutil" - "go.uber.org/zap" ) type SegmentFilter func(segment Segment) bool +func WithSkipEmpty() SegmentFilter { + return func(segment Segment) bool { + return segment.InsertCount() > 0 + } +} + func WithPartition(partitionID UniqueID) SegmentFilter { return func(segment Segment) bool { return segment.Partition() == partitionID @@ -152,7 +159,7 @@ func (mgr *segmentManager) Put(segmentType SegmentType, segments ...Segment) { var replacedSegment []Segment mgr.mu.Lock() defer mgr.mu.Unlock() - targetMap := mgr.growingSegments + var targetMap map[int64]Segment switch segmentType { case SegmentTypeGrowing: targetMap = mgr.growingSegments @@ -173,9 +180,7 @@ func (mgr *segmentManager) Put(segmentType SegmentType, segments ...Segment) { zap.Int64("newVersion", segment.Version()), ) // delete redundant segment - if s, ok := segment.(*LocalSegment); ok { - DeleteSegment(s) - } + segment.Release() continue } replacedSegment = append(replacedSegment, 
oldSegment) @@ -206,7 +211,7 @@ func (mgr *segmentManager) Put(segmentType SegmentType, segments ...Segment) { if len(replacedSegment) > 0 { go func() { for _, segment := range replacedSegment { - remove(segment.(*LocalSegment)) + remove(segment) } }() } @@ -285,6 +290,8 @@ func (mgr *segmentManager) GetAndPinBy(filters ...SegmentFilter) ([]Segment, err mgr.mu.RLock() defer mgr.mu.RUnlock() + filters = append(filters, WithSkipEmpty()) + ret := make([]Segment, 0) var err error defer func() { @@ -321,6 +328,8 @@ func (mgr *segmentManager) GetAndPin(segments []int64, filters ...SegmentFilter) mgr.mu.RLock() defer mgr.mu.RUnlock() + filters = append(filters, WithSkipEmpty()) + lockedSegments := make([]Segment, 0, len(segments)) var err error defer func() { @@ -411,7 +420,7 @@ func (mgr *segmentManager) Remove(segmentID UniqueID, scope querypb.DataScope) ( mgr.mu.Lock() var removeGrowing, removeSealed int - var growing, sealed *LocalSegment + var growing, sealed Segment switch scope { case querypb.DataScope_Streaming: growing = mgr.removeSegmentWithType(SegmentTypeGrowing, segmentID) @@ -450,20 +459,20 @@ func (mgr *segmentManager) Remove(segmentID UniqueID, scope querypb.DataScope) ( return removeGrowing, removeSealed } -func (mgr *segmentManager) removeSegmentWithType(typ SegmentType, segmentID UniqueID) *LocalSegment { +func (mgr *segmentManager) removeSegmentWithType(typ SegmentType, segmentID UniqueID) Segment { switch typ { case SegmentTypeGrowing: s, ok := mgr.growingSegments[segmentID] if ok { delete(mgr.growingSegments, segmentID) - return s.(*LocalSegment) + return s } case SegmentTypeSealed: s, ok := mgr.sealedSegments[segmentID] if ok { delete(mgr.sealedSegments, segmentID) - return s.(*LocalSegment) + return s } default: return nil @@ -475,7 +484,7 @@ func (mgr *segmentManager) removeSegmentWithType(typ SegmentType, segmentID Uniq func (mgr *segmentManager) RemoveBy(filters ...SegmentFilter) (int, int) { mgr.mu.Lock() - var removeGrowing, removeSealed 
[]*LocalSegment + var removeGrowing, removeSealed []Segment for id, segment := range mgr.growingSegments { if filter(segment, filters...) { s := mgr.removeSegmentWithType(SegmentTypeGrowing, id) @@ -513,19 +522,19 @@ func (mgr *segmentManager) Clear() { for id, segment := range mgr.growingSegments { delete(mgr.growingSegments, id) - remove(segment.(*LocalSegment)) + remove(segment) } for id, segment := range mgr.sealedSegments { delete(mgr.sealedSegments, id) - remove(segment.(*LocalSegment)) + remove(segment) } mgr.updateMetric() } func (mgr *segmentManager) updateMetric() { // update collection and partiation metric - var collections, partiations = make(Set[int64]), make(Set[int64]) + collections, partiations := make(Set[int64]), make(Set[int64]) for _, seg := range mgr.growingSegments { collections.Insert(seg.Collection()) partiations.Insert(seg.Partition()) @@ -538,9 +547,9 @@ func (mgr *segmentManager) updateMetric() { metrics.QueryNodeNumPartitions.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Set(float64(partiations.Len())) } -func remove(segment *LocalSegment) bool { +func remove(segment Segment) bool { rowNum := segment.RowNum() - DeleteSegment(segment) + segment.Release() metrics.QueryNodeNumSegments.WithLabelValues( fmt.Sprint(paramtable.GetNodeID()), diff --git a/internal/querynodev2/segments/mock_data.go b/internal/querynodev2/segments/mock_data.go index 5529a2cebabf7..57d4d8f3630f3 100644 --- a/internal/querynodev2/segments/mock_data.go +++ b/internal/querynodev2/segments/mock_data.go @@ -114,6 +114,14 @@ var simpleBinVecField = vecFieldParam{ fieldName: "binVectorField", } +var simpleFloat16VecField = vecFieldParam{ + id: 112, + dim: defaultDim, + metricType: defaultMetricType, + vecType: schemapb.DataType_Float16Vector, + fieldName: "float16VectorField", +} + var simpleBoolField = constFieldParam{ id: 102, dataType: schemapb.DataType_Bool, @@ -244,6 +252,7 @@ func GenTestCollectionSchema(collectionName string, pkType schemapb.DataType) *s 
fieldDouble := genConstantFieldSchema(simpleDoubleField) // fieldArray := genConstantFieldSchema(simpleArrayField) fieldJSON := genConstantFieldSchema(simpleJSONField) + fieldArray := genConstantFieldSchema(simpleArrayField) floatVecFieldSchema := genVectorFieldSchema(simpleFloatVecField) binVecFieldSchema := genVectorFieldSchema(simpleBinVecField) var pkFieldSchema *schemapb.FieldSchema @@ -265,11 +274,11 @@ func GenTestCollectionSchema(collectionName string, pkType schemapb.DataType) *s fieldInt32, fieldFloat, fieldDouble, - // fieldArray, fieldJSON, floatVecFieldSchema, binVecFieldSchema, pkFieldSchema, + fieldArray, }, } @@ -386,6 +395,7 @@ func generateStringArray(numRows int) []string { } return ret } + func generateArrayArray(numRows int) []*schemapb.ScalarField { ret := make([]*schemapb.ScalarField, 0, numRows) for i := 0; i < numRows; i++ { @@ -399,6 +409,7 @@ func generateArrayArray(numRows int) []*schemapb.ScalarField { } return ret } + func generateJSONArray(numRows int) [][]byte { ret := make([][]byte, 0, numRows) for i := 0; i < numRows; i++ { @@ -434,6 +445,16 @@ func generateBinaryVectors(numRows, dim int) []byte { return ret } +func generateFloat16Vectors(numRows, dim int) []byte { + total := numRows * dim * 2 + ret := make([]byte, total) + _, err := rand.Read(ret) + if err != nil { + panic(err) + } + return ret +} + func GenTestScalarFieldData(dType schemapb.DataType, fieldName string, fieldID int64, numRows int) *schemapb.FieldData { ret := &schemapb.FieldData{ Type: dType, @@ -550,8 +571,10 @@ func GenTestScalarFieldData(dType schemapb.DataType, fieldName string, fieldID i Data: &schemapb.ScalarField_JsonData{ JsonData: &schemapb.JSONArray{ Data: generateJSONArray(numRows), - }}, - }} + }, + }, + }, + } default: panic("data type not supported") @@ -589,6 +612,16 @@ func GenTestVectorFiledData(dType schemapb.DataType, fieldName string, fieldID i }, }, } + case schemapb.DataType_Float16Vector: + ret.FieldId = fieldID + ret.Field = 
&schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: int64(dim), + Data: &schemapb.VectorField_Float16Vector{ + Float16Vector: generateFloat16Vectors(numRows, dim), + }, + }, + } default: panic("data type not supported") } @@ -639,7 +672,7 @@ func SaveBinLog(ctx context.Context, } k := JoinIDPath(collectionID, partitionID, segmentID, fieldID) - //key := path.Join(defaultLocalStorage, "insert-log", k) + // key := path.Join(defaultLocalStorage, "insert-log", k) key := path.Join(chunkManager.RootPath(), "insert-log", k) kvs[key] = blob.Value fieldBinlog = append(fieldBinlog, &datapb.FieldBinlog{ @@ -662,7 +695,7 @@ func SaveBinLog(ctx context.Context, } k := JoinIDPath(collectionID, partitionID, segmentID, fieldID) - //key := path.Join(defaultLocalStorage, "stats-log", k) + // key := path.Join(defaultLocalStorage, "stats-log", k) key := path.Join(chunkManager.RootPath(), "stats-log", k) kvs[key] = blob.Value[:] statsBinlog = append(statsBinlog, &datapb.FieldBinlog{ @@ -680,7 +713,8 @@ func genStorageBlob(collectionID int64, partitionID int64, segmentID int64, msgLength int, - schema *schemapb.CollectionSchema) ([]*storage.Blob, []*storage.Blob, error) { + schema *schemapb.CollectionSchema, +) ([]*storage.Blob, []*storage.Blob, error) { tmpSchema := &schemapb.CollectionSchema{ Name: schema.Name, AutoID: schema.AutoID, @@ -778,6 +812,12 @@ func genInsertData(msgLength int, schema *schemapb.CollectionSchema) (*storage.I Data: generateFloatVectors(msgLength, dim), Dim: dim, } + case schemapb.DataType_Float16Vector: + dim := simpleFloat16VecField.dim + insertData.Data[f.FieldID] = &storage.Float16VectorFieldData{ + Data: generateFloat16Vectors(msgLength, dim), + Dim: dim, + } case schemapb.DataType_BinaryVector: dim := simpleBinVecField.dim insertData.Data[f.FieldID] = &storage.BinaryVectorFieldData{ @@ -808,7 +848,6 @@ func SaveDeltaLog(collectionID int64, segmentID int64, cm storage.ChunkManager, ) ([]*datapb.FieldBinlog, error) { - binlogWriter := 
storage.NewDeleteBinlogWriter(schemapb.DataType_String, collectionID, partitionID, segmentID) eventWriter, _ := binlogWriter.NextDeleteEventWriter() dData := &storage.DeleteData{ @@ -840,12 +879,16 @@ func SaveDeltaLog(collectionID int64, fieldBinlog := make([]*datapb.FieldBinlog, 0) log.Debug("[query node unittest] save delta log", zap.Int64("fieldID", pkFieldID)) key := JoinIDPath(collectionID, partitionID, segmentID, pkFieldID) - //keyPath := path.Join(defaultLocalStorage, "delta-log", key) + // keyPath := path.Join(defaultLocalStorage, "delta-log", key) keyPath := path.Join(cm.RootPath(), "delta-log", key) kvs[keyPath] = blob.Value[:] fieldBinlog = append(fieldBinlog, &datapb.FieldBinlog{ FieldID: pkFieldID, - Binlogs: []*datapb.Binlog{{LogPath: keyPath}}, + Binlogs: []*datapb.Binlog{{ + LogPath: keyPath, + TimestampFrom: 100, + TimestampTo: 200, + }}, }) log.Debug("[query node unittest] save delta log file to MinIO/S3") @@ -892,7 +935,7 @@ func GenAndSaveIndex(collectionID, partitionID, segmentID, fieldID int64, msgLen indexPaths := make([]string, 0) for _, index := range serializedIndexBlobs { - //indexPath := filepath.Join(defaultLocalStorage, strconv.Itoa(int(segmentID)), index.Key) + // indexPath := filepath.Join(defaultLocalStorage, strconv.Itoa(int(segmentID)), index.Key) indexPath := filepath.Join(cm.RootPath(), "index_files", strconv.Itoa(int(segmentID)), index.Key) indexPaths = append(indexPaths, indexPath) @@ -901,13 +944,15 @@ func GenAndSaveIndex(collectionID, partitionID, segmentID, fieldID int64, msgLen return nil, err } } + _, cCurrentIndexVersion := getIndexEngineVersion() return &querypb.FieldIndexInfo{ - FieldID: fieldID, - EnableIndex: true, - IndexName: "querynode-test", - IndexParams: funcutil.Map2KeyValuePair(indexParams), - IndexFilePaths: indexPaths, + FieldID: fieldID, + EnableIndex: true, + IndexName: "querynode-test", + IndexParams: funcutil.Map2KeyValuePair(indexParams), + IndexFilePaths: indexPaths, + CurrentIndexVersion: 
cCurrentIndexVersion, }, nil } @@ -932,13 +977,13 @@ func genIndexParams(indexType, metricType string) (map[string]string, map[string } else if indexType == IndexHNSW { indexParams["M"] = strconv.Itoa(16) indexParams["efConstruction"] = strconv.Itoa(efConstruction) - //indexParams["ef"] = strconv.Itoa(ef) + // indexParams["ef"] = strconv.Itoa(ef) } else if indexType == IndexFaissBinIVFFlat { // binary vector indexParams["nlist"] = strconv.Itoa(nlist) indexParams["m"] = strconv.Itoa(m) indexParams["nbits"] = strconv.Itoa(nbits) } else if indexType == IndexFaissBinIDMap { - //indexParams[common.DimKey] = strconv.Itoa(defaultDim) + // indexParams[common.DimKey] = strconv.Itoa(defaultDim) } else { panic("") } @@ -1001,7 +1046,7 @@ func genPlaceHolderGroup(nq int64) ([]byte, error) { Values: make([][]byte, 0), } for i := int64(0); i < nq; i++ { - var vec = make([]float32, defaultDim) + vec := make([]float32, defaultDim) for j := 0; j < defaultDim; j++ { vec[j] = rand.Float32() } @@ -1042,7 +1087,7 @@ func genBruteForceDSL(schema *schemapb.CollectionSchema, topK int64, roundDecima roundDecimalStr := strconv.FormatInt(roundDecimal, 10) var fieldID int64 for _, f := range schema.Fields { - if f.DataType == schemapb.DataType_FloatVector { + if f.DataType == schemapb.DataType_FloatVector || f.DataType == schemapb.DataType_Float16Vector { vecFieldName = f.Name fieldID = f.FieldID for _, p := range f.IndexParams { @@ -1152,7 +1197,6 @@ func checkSearchResult(nq int64, plan *SearchPlan, searchResult *SearchResult) e } func genSearchPlanAndRequests(collection *Collection, segments []int64, indexType string, nq int64) (*SearchRequest, error) { - iReq, _ := genSearchRequest(nq, indexType, collection) queryReq := &querypb.SearchRequest{ Req: iReq, @@ -1195,6 +1239,9 @@ func genInsertMsg(collection *Collection, partitionID, segment int64, numRows in case schemapb.DataType_BinaryVector: dim := simpleBinVecField.dim // if no dim specified, use simpleFloatVecField's dim fieldsData = 
append(fieldsData, GenTestVectorFiledData(f.DataType, f.Name, f.FieldID, numRows, dim)) + case schemapb.DataType_Float16Vector: + dim := simpleFloat16VecField.dim // if no dim specified, use simpleFloatVecField's dim + fieldsData = append(fieldsData, GenTestVectorFiledData(f.DataType, f.Name, f.FieldID, numRows, dim)) default: err := errors.New("data type not supported") return nil, err @@ -1429,6 +1476,20 @@ func genFieldData(fieldName string, fieldID int64, fieldType schemapb.DataType, }, FieldId: fieldID, } + case schemapb.DataType_Float16Vector: + fieldData = &schemapb.FieldData{ + Type: schemapb.DataType_Float16Vector, + FieldName: fieldName, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: dim, + Data: &schemapb.VectorField_Float16Vector{ + Float16Vector: fieldValue.([]byte), + }, + }, + }, + FieldId: fieldID, + } case schemapb.DataType_JSON: fieldData = &schemapb.FieldData{ Type: schemapb.DataType_JSON, diff --git a/internal/querynodev2/segments/mock_segment.go b/internal/querynodev2/segments/mock_segment.go index a06d54c9f5eeb..e5377c7147963 100644 --- a/internal/querynodev2/segments/mock_segment.go +++ b/internal/querynodev2/segments/mock_segment.go @@ -3,7 +3,10 @@ package segments import ( + context "context" + commonpb "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + mock "github.com/stretchr/testify/mock" msgpb "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" @@ -680,6 +683,93 @@ func (_c *MockSegment_RUnlock_Call) RunAndReturn(run func()) *MockSegment_RUnloc return _c } +// Release provides a mock function with given fields: +func (_m *MockSegment) Release() { + _m.Called() +} + +// MockSegment_Release_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Release' +type MockSegment_Release_Call struct { + *mock.Call +} + +// Release is a helper method to define mock.On call +func (_e *MockSegment_Expecter) Release() *MockSegment_Release_Call { + return 
&MockSegment_Release_Call{Call: _e.mock.On("Release")} +} + +func (_c *MockSegment_Release_Call) Run(run func()) *MockSegment_Release_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockSegment_Release_Call) Return() *MockSegment_Release_Call { + _c.Call.Return() + return _c +} + +func (_c *MockSegment_Release_Call) RunAndReturn(run func()) *MockSegment_Release_Call { + _c.Call.Return(run) + return _c +} + +// Retrieve provides a mock function with given fields: ctx, plan +func (_m *MockSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segcorepb.RetrieveResults, error) { + ret := _m.Called(ctx, plan) + + var r0 *segcorepb.RetrieveResults + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *RetrievePlan) (*segcorepb.RetrieveResults, error)); ok { + return rf(ctx, plan) + } + if rf, ok := ret.Get(0).(func(context.Context, *RetrievePlan) *segcorepb.RetrieveResults); ok { + r0 = rf(ctx, plan) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*segcorepb.RetrieveResults) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *RetrievePlan) error); ok { + r1 = rf(ctx, plan) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockSegment_Retrieve_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Retrieve' +type MockSegment_Retrieve_Call struct { + *mock.Call +} + +// Retrieve is a helper method to define mock.On call +// - ctx context.Context +// - plan *RetrievePlan +func (_e *MockSegment_Expecter) Retrieve(ctx interface{}, plan interface{}) *MockSegment_Retrieve_Call { + return &MockSegment_Retrieve_Call{Call: _e.mock.On("Retrieve", ctx, plan)} +} + +func (_c *MockSegment_Retrieve_Call) Run(run func(ctx context.Context, plan *RetrievePlan)) *MockSegment_Retrieve_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*RetrievePlan)) + }) + return _c +} + +func (_c *MockSegment_Retrieve_Call) Return(_a0 
*segcorepb.RetrieveResults, _a1 error) *MockSegment_Retrieve_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockSegment_Retrieve_Call) RunAndReturn(run func(context.Context, *RetrievePlan) (*segcorepb.RetrieveResults, error)) *MockSegment_Retrieve_Call { + _c.Call.Return(run) + return _c +} + // RowNum provides a mock function with given fields: func (_m *MockSegment) RowNum() int64 { ret := _m.Called() @@ -721,6 +811,61 @@ func (_c *MockSegment_RowNum_Call) RunAndReturn(run func() int64) *MockSegment_R return _c } +// Search provides a mock function with given fields: ctx, searchReq +func (_m *MockSegment) Search(ctx context.Context, searchReq *SearchRequest) (*SearchResult, error) { + ret := _m.Called(ctx, searchReq) + + var r0 *SearchResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *SearchRequest) (*SearchResult, error)); ok { + return rf(ctx, searchReq) + } + if rf, ok := ret.Get(0).(func(context.Context, *SearchRequest) *SearchResult); ok { + r0 = rf(ctx, searchReq) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*SearchResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *SearchRequest) error); ok { + r1 = rf(ctx, searchReq) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockSegment_Search_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Search' +type MockSegment_Search_Call struct { + *mock.Call +} + +// Search is a helper method to define mock.On call +// - ctx context.Context +// - searchReq *SearchRequest +func (_e *MockSegment_Expecter) Search(ctx interface{}, searchReq interface{}) *MockSegment_Search_Call { + return &MockSegment_Search_Call{Call: _e.mock.On("Search", ctx, searchReq)} +} + +func (_c *MockSegment_Search_Call) Run(run func(ctx context.Context, searchReq *SearchRequest)) *MockSegment_Search_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*SearchRequest)) + }) + return _c +} + +func 
(_c *MockSegment_Search_Call) Return(_a0 *SearchResult, _a1 error) *MockSegment_Search_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockSegment_Search_Call) RunAndReturn(run func(context.Context, *SearchRequest) (*SearchResult, error)) *MockSegment_Search_Call { + _c.Call.Return(run) + return _c +} + // Shard provides a mock function with given fields: func (_m *MockSegment) Shard() string { ret := _m.Called() diff --git a/internal/querynodev2/segments/plan.go b/internal/querynodev2/segments/plan.go index 81d081eed4745..a9f07bfc78a1a 100644 --- a/internal/querynodev2/segments/plan.go +++ b/internal/querynodev2/segments/plan.go @@ -53,7 +53,7 @@ func createSearchPlanByExpr(col *Collection, expr []byte, metricType string) (*S return nil, err1 } - var newPlan = &SearchPlan{cSearchPlan: cPlan} + newPlan := &SearchPlan{cSearchPlan: cPlan} if len(metricType) != 0 { newPlan.setMetricType(metricType) } else { @@ -106,7 +106,7 @@ func NewSearchRequest(collection *Collection, req *querypb.SearchRequest, placeh return nil, errors.New("empty search request") } - var blobPtr = unsafe.Pointer(&placeholderGrp[0]) + blobPtr := unsafe.Pointer(&placeholderGrp[0]) blobSize := C.int64_t(len(placeholderGrp)) var cPlaceholderGroup C.CPlaceholderGroup status := C.ParsePlaceholderGroup(plan.cSearchPlan, blobPtr, blobSize, &cPlaceholderGroup) @@ -153,7 +153,7 @@ func parseSearchRequest(plan *SearchPlan, searchRequestBlob []byte) (*SearchRequ if len(searchRequestBlob) == 0 { return nil, fmt.Errorf("empty search request") } - var blobPtr = unsafe.Pointer(&searchRequestBlob[0]) + blobPtr := unsafe.Pointer(&searchRequestBlob[0]) blobSize := C.int64_t(len(searchRequestBlob)) var cPlaceholderGroup C.CPlaceholderGroup status := C.ParsePlaceholderGroup(plan.cSearchPlan, blobPtr, blobSize, &cPlaceholderGroup) @@ -162,7 +162,7 @@ func parseSearchRequest(plan *SearchPlan, searchRequestBlob []byte) (*SearchRequ return nil, err } - var ret = &SearchRequest{cPlaceholderGroup: 
cPlaceholderGroup, plan: plan} + ret := &SearchRequest{cPlaceholderGroup: cPlaceholderGroup, plan: plan} return ret, nil } @@ -189,7 +189,7 @@ func NewRetrievePlan(col *Collection, expr []byte, timestamp Timestamp, msgID Un return nil, err } - var newPlan = &RetrievePlan{ + newPlan := &RetrievePlan{ cRetrievePlan: cPlan, Timestamp: timestamp, msgID: msgID, diff --git a/internal/querynodev2/segments/pool.go b/internal/querynodev2/segments/pool.go index fc34ccc0e4f76..a6196da59eac6 100644 --- a/internal/querynodev2/segments/pool.go +++ b/internal/querynodev2/segments/pool.go @@ -21,9 +21,10 @@ import ( "runtime" "sync" + "go.uber.org/atomic" + "github.com/milvus-io/milvus/pkg/util/conc" "github.com/milvus-io/milvus/pkg/util/paramtable" - "go.uber.org/atomic" ) var ( diff --git a/internal/querynodev2/segments/reduce.go b/internal/querynodev2/segments/reduce.go index 4a55190950670..7e8ec94441695 100644 --- a/internal/querynodev2/segments/reduce.go +++ b/internal/querynodev2/segments/reduce.go @@ -23,6 +23,7 @@ package segments #include "segcore/reduce_c.h" */ import "C" + import ( "fmt" ) @@ -70,7 +71,8 @@ func ParseSliceInfo(originNQs []int64, originTopKs []int64, nqPerSlice int64) *S } func ReduceSearchResultsAndFillData(plan *SearchPlan, searchResults []*SearchResult, - numSegments int64, sliceNQs []int64, sliceTopKs []int64) (searchResultDataBlobs, error) { + numSegments int64, sliceNQs []int64, sliceTopKs []int64, +) (searchResultDataBlobs, error) { if plan.cSearchPlan == nil { return nil, fmt.Errorf("nil search plan") } @@ -92,9 +94,9 @@ func ReduceSearchResultsAndFillData(plan *SearchPlan, searchResults []*SearchRes } cSearchResultPtr := &cSearchResults[0] cNumSegments := C.int64_t(numSegments) - var cSliceNQSPtr = (*C.int64_t)(&sliceNQs[0]) - var cSliceTopKSPtr = (*C.int64_t)(&sliceTopKs[0]) - var cNumSlices = C.int64_t(len(sliceNQs)) + cSliceNQSPtr := (*C.int64_t)(&sliceNQs[0]) + cSliceTopKSPtr := (*C.int64_t)(&sliceTopKs[0]) + cNumSlices := 
C.int64_t(len(sliceNQs)) var cSearchResultDataBlobs searchResultDataBlobs status := C.ReduceSearchResultsAndFillData(&cSearchResultDataBlobs, plan.cSearchPlan, cSearchResultPtr, cNumSegments, cSliceNQSPtr, cSliceTopKSPtr, cNumSlices) diff --git a/internal/querynodev2/segments/reduce_test.go b/internal/querynodev2/segments/reduce_test.go index d299561e7cbed..573a3dfe6f6a4 100644 --- a/internal/querynodev2/segments/reduce_test.go +++ b/internal/querynodev2/segments/reduce_test.go @@ -99,7 +99,7 @@ func (suite *ReduceSuite) SetupTest() { } func (suite *ReduceSuite) TearDownTest() { - DeleteSegment(suite.segment) + suite.segment.Release() DeleteCollection(suite.collection) ctx := context.Background() suite.chunkManager.RemoveWithPrefix(ctx, suite.rootPath) diff --git a/internal/querynodev2/segments/reducer.go b/internal/querynodev2/segments/reducer.go index 990035e9fd200..f6e2f2b1d4613 100644 --- a/internal/querynodev2/segments/reducer.go +++ b/internal/querynodev2/segments/reducer.go @@ -3,11 +3,10 @@ package segments import ( "context" - "github.com/milvus-io/milvus/internal/proto/segcorepb" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/proto/segcorepb" ) type internalReducer interface { @@ -17,9 +16,6 @@ type internalReducer interface { func CreateInternalReducer(req *querypb.QueryRequest, schema *schemapb.CollectionSchema) internalReducer { if req.GetReq().GetIsCount() { return &cntReducer{} - } else if req.GetReq().GetIterationExtensionReduceRate() > 0 { - extendedLimit := req.GetReq().GetIterationExtensionReduceRate() * req.GetReq().Limit - return newExtensionLimitReducer(req, schema, extendedLimit) } return newDefaultLimitReducer(req, schema) } @@ -31,9 +27,6 @@ type segCoreReducer interface { func CreateSegCoreReducer(req *querypb.QueryRequest, schema *schemapb.CollectionSchema) segCoreReducer { if 
req.GetReq().GetIsCount() { return &cntReducerSegCore{} - } else if req.GetReq().GetIterationExtensionReduceRate() > 0 { - extendedLimit := req.GetReq().GetIterationExtensionReduceRate() * req.GetReq().Limit - return newExtensionLimitSegcoreReducer(req, schema, extendedLimit) } return newDefaultLimitReducerSegcore(req, schema) } diff --git a/internal/querynodev2/segments/reducer_test.go b/internal/querynodev2/segments/reducer_test.go index fa58cd04452ba..2c1940014e63a 100644 --- a/internal/querynodev2/segments/reducer_test.go +++ b/internal/querynodev2/segments/reducer_test.go @@ -40,18 +40,9 @@ func (suite *ReducerFactorySuite) TestCreateInternalReducer() { suite.ir = CreateInternalReducer(req, nil) _, suite.ok = suite.ir.(*cntReducer) suite.True(suite.ok) - - req.GetReq().IsCount = false - req.GetReq().IterationExtensionReduceRate = 10 - req.GetReq().Limit = 10 - suite.ir = CreateInternalReducer(req, nil) - extReducer, typeOk := suite.ir.(*extensionLimitReducer) - suite.True(typeOk) - suite.Equal(int64(100), extReducer.extendedLimit) } func (suite *ReducerFactorySuite) TestCreateSegCoreReducer() { - req := &querypb.QueryRequest{ Req: &internalpb.RetrieveRequest{ IsCount: false, @@ -66,12 +57,4 @@ func (suite *ReducerFactorySuite) TestCreateSegCoreReducer() { suite.sr = CreateSegCoreReducer(req, nil) _, suite.ok = suite.sr.(*cntReducerSegCore) suite.True(suite.ok) - - req.GetReq().IsCount = false - req.GetReq().IterationExtensionReduceRate = 10 - req.GetReq().Limit = 10 - suite.sr = CreateSegCoreReducer(req, nil) - extReducer, typeOk := suite.sr.(*extensionLimitSegcoreReducer) - suite.True(typeOk) - suite.Equal(int64(100), extReducer.extendedLimit) } diff --git a/internal/querynodev2/segments/result.go b/internal/querynodev2/segments/result.go index 9703ab0e03ce3..b3421e2a9a71a 100644 --- a/internal/querynodev2/segments/result.go +++ b/internal/querynodev2/segments/result.go @@ -123,10 +123,12 @@ func ReduceSearchResultData(ctx context.Context, searchResultData 
[]*schemapb.Se } var skipDupCnt int64 + var retSize int64 + maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() for i := int64(0); i < nq; i++ { offsets := make([]int64, len(searchResultData)) - var idSet = make(map[interface{}]struct{}) + idSet := make(map[interface{}]struct{}) var j int64 for j = 0; j < topk; { sel := SelectSearchResultData(searchResultData, resultOffsets, offsets, i) @@ -140,7 +142,7 @@ func ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.Se // remove duplicates if _, ok := idSet[id]; !ok { - typeutil.AppendFieldData(ret.FieldsData, searchResultData[sel].FieldsData, idx) + retSize += typeutil.AppendFieldData(ret.FieldsData, searchResultData[sel].FieldsData, idx) typeutil.AppendPKs(ret.Ids, id) ret.Scores = append(ret.Scores, score) idSet[id] = struct{}{} @@ -159,8 +161,8 @@ func ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.Se ret.Topks = append(ret.Topks, j) // limit search result to avoid oom - if int64(proto.Size(ret)) > paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() { - return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()) + if retSize > maxOutputSize { + return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize) } } log.Debug("skip duplicated search result", zap.Int64("count", skipDupCnt)) @@ -222,7 +224,7 @@ func DecodeSearchResults(searchResults []*internalpb.SearchResults) ([]*schemapb func EncodeSearchResultData(searchResultData *schemapb.SearchResultData, nq int64, topk int64, metricType string) (searchResults *internalpb.SearchResults, err error) { searchResults = &internalpb.SearchResults{ - Status: merr.Status(nil), + Status: merr.Success(), NumQueries: nq, TopK: topk, MetricType: metricType, @@ -238,14 +240,15 @@ func EncodeSearchResultData(searchResultData *schemapb.SearchResultData, nq int6 return } -func MergeInternalRetrieveResult(ctx 
context.Context, retrieveResults []*internalpb.RetrieveResults, limit int64) (*internalpb.RetrieveResults, error) { +func MergeInternalRetrieveResult(ctx context.Context, retrieveResults []*internalpb.RetrieveResults, param *mergeParam) (*internalpb.RetrieveResults, error) { log.Ctx(ctx).Debug("mergeInternelRetrieveResults", - zap.Int64("limit", limit), + zap.Int64("limit", param.limit), zap.Int("resultNum", len(retrieveResults)), ) var ( ret = &internalpb.RetrieveResults{ - Ids: &schemapb.IDs{}, + Status: merr.Success(), + Ids: &schemapb.IDs{}, } skipDupCnt int64 loopEnd int @@ -265,15 +268,18 @@ func MergeInternalRetrieveResult(ctx context.Context, retrieveResults []*interna return ret, nil } - if limit != typeutil.Unlimited { - loopEnd = int(limit) + if param.limit != typeutil.Unlimited && !param.mergeStopForBest { + loopEnd = int(param.limit) } ret.FieldsData = make([]*schemapb.FieldData, len(validRetrieveResults[0].GetFieldsData())) idTsMap := make(map[interface{}]uint64) cursors := make([]int64, len(validRetrieveResults)) - for j := 0; j < loopEnd; j++ { - sel := typeutil.SelectMinPK(validRetrieveResults, cursors) + + var retSize int64 + maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() + for j := 0; j < loopEnd; { + sel := typeutil.SelectMinPK(validRetrieveResults, cursors, param.mergeStopForBest, param.limit) if sel == -1 { break } @@ -282,21 +288,22 @@ func MergeInternalRetrieveResult(ctx context.Context, retrieveResults []*interna ts := getTS(validRetrieveResults[sel], cursors[sel]) if _, ok := idTsMap[pk]; !ok { typeutil.AppendPKs(ret.Ids, pk) - typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].GetFieldsData(), cursors[sel]) + retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].GetFieldsData(), cursors[sel]) idTsMap[pk] = ts + j++ } else { // primary keys duplicate skipDupCnt++ if ts != 0 && ts > idTsMap[pk] { idTsMap[pk] = ts typeutil.DeleteFieldData(ret.FieldsData) - 
typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].GetFieldsData(), cursors[sel]) + retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].GetFieldsData(), cursors[sel]) } } // limit retrieve result to avoid oom - if int64(proto.Size(ret)) > paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() { - return nil, fmt.Errorf("query results exceed the maxOutputSize Limit %d", paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()) + if retSize > maxOutputSize { + return nil, fmt.Errorf("query results exceed the maxOutputSize Limit %d", maxOutputSize) } cursors[sel]++ @@ -336,9 +343,9 @@ func getTS(i *internalpb.RetrieveResults, idx int64) uint64 { return 0 } -func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcorepb.RetrieveResults, limit int64) (*segcorepb.RetrieveResults, error) { +func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcorepb.RetrieveResults, param *mergeParam) (*segcorepb.RetrieveResults, error) { log.Ctx(ctx).Debug("mergeSegcoreRetrieveResults", - zap.Int64("limit", limit), + zap.Int64("limit", param.limit), zap.Int("resultNum", len(retrieveResults)), ) var ( @@ -365,15 +372,18 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore return ret, nil } - if limit != typeutil.Unlimited { - loopEnd = int(limit) + if param.limit != typeutil.Unlimited && !param.mergeStopForBest { + loopEnd = int(param.limit) } ret.FieldsData = make([]*schemapb.FieldData, len(validRetrieveResults[0].GetFieldsData())) idSet := make(map[interface{}]struct{}) cursors := make([]int64, len(validRetrieveResults)) + + var retSize int64 + maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() for j := 0; j < loopEnd; j++ { - sel := typeutil.SelectMinPK(validRetrieveResults, cursors) + sel := typeutil.SelectMinPK(validRetrieveResults, cursors, param.mergeStopForBest, param.limit) if sel == -1 { break } @@ -381,7 +391,7 @@ func 
MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore pk := typeutil.GetPK(validRetrieveResults[sel].GetIds(), cursors[sel]) if _, ok := idSet[pk]; !ok { typeutil.AppendPKs(ret.Ids, pk) - typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].GetFieldsData(), cursors[sel]) + retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].GetFieldsData(), cursors[sel]) idSet[pk] = struct{}{} } else { // primary keys duplicate @@ -389,8 +399,8 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore } // limit retrieve result to avoid oom - if int64(proto.Size(ret)) > paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() { - return nil, fmt.Errorf("query results exceed the maxOutputSize Limit %d", paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()) + if retSize > maxOutputSize { + return nil, fmt.Errorf("query results exceed the maxOutputSize Limit %d", maxOutputSize) } cursors[sel]++ @@ -406,17 +416,14 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore func mergeInternalRetrieveResultsAndFillIfEmpty( ctx context.Context, retrieveResults []*internalpb.RetrieveResults, - limit int64, - outputFieldsID []int64, - schema *schemapb.CollectionSchema, + param *mergeParam, ) (*internalpb.RetrieveResults, error) { - - mergedResult, err := MergeInternalRetrieveResult(ctx, retrieveResults, limit) + mergedResult, err := MergeInternalRetrieveResult(ctx, retrieveResults, param) if err != nil { return nil, err } - if err := typeutil2.FillRetrieveResultIfEmpty(typeutil2.NewInternalResult(mergedResult), outputFieldsID, schema); err != nil { + if err := typeutil2.FillRetrieveResultIfEmpty(typeutil2.NewInternalResult(mergedResult), param.outputFieldsId, param.schema); err != nil { return nil, fmt.Errorf("failed to fill internal retrieve results: %s", err.Error()) } @@ -426,17 +433,14 @@ func mergeInternalRetrieveResultsAndFillIfEmpty( func 
mergeSegcoreRetrieveResultsAndFillIfEmpty( ctx context.Context, retrieveResults []*segcorepb.RetrieveResults, - limit int64, - outputFieldsID []int64, - schema *schemapb.CollectionSchema, + param *mergeParam, ) (*segcorepb.RetrieveResults, error) { - - mergedResult, err := MergeSegcoreRetrieveResults(ctx, retrieveResults, limit) + mergedResult, err := MergeSegcoreRetrieveResults(ctx, retrieveResults, param) if err != nil { return nil, err } - if err := typeutil2.FillRetrieveResultIfEmpty(typeutil2.NewSegcoreResults(mergedResult), outputFieldsID, schema); err != nil { + if err := typeutil2.FillRetrieveResultIfEmpty(typeutil2.NewSegcoreResults(mergedResult), param.outputFieldsId, param.schema); err != nil { return nil, fmt.Errorf("failed to fill segcore retrieve results: %s", err.Error()) } diff --git a/internal/querynodev2/segments/result_test.go b/internal/querynodev2/segments/result_test.go index efb18914764d2..867611dc1fbc9 100644 --- a/internal/querynodev2/segments/result_test.go +++ b/internal/querynodev2/segments/result_test.go @@ -80,7 +80,8 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() { FieldsData: fieldDataArray2, } - result, err := MergeSegcoreRetrieveResults(context.Background(), []*segcorepb.RetrieveResults{result1, result2}, typeutil.Unlimited) + result, err := MergeSegcoreRetrieveResults(context.Background(), []*segcorepb.RetrieveResults{result1, result2}, + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) suite.NoError(err) suite.Equal(2, len(result.GetFieldsData())) suite.Equal([]int64{0, 1}, result.GetIds().GetIntId().GetData()) @@ -89,7 +90,8 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() { }) suite.Run("test nil results", func() { - ret, err := MergeSegcoreRetrieveResults(context.Background(), nil, typeutil.Unlimited) + ret, err := MergeSegcoreRetrieveResults(context.Background(), nil, + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) suite.NoError(err) 
suite.Empty(ret.GetIds()) suite.Empty(ret.GetFieldsData()) @@ -107,7 +109,8 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() { FieldsData: fieldDataArray1, } - ret, err := MergeSegcoreRetrieveResults(context.Background(), []*segcorepb.RetrieveResults{r}, typeutil.Unlimited) + ret, err := MergeSegcoreRetrieveResults(context.Background(), []*segcorepb.RetrieveResults{r}, + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) suite.NoError(err) suite.Empty(ret.GetIds()) suite.Empty(ret.GetFieldsData()) @@ -141,7 +144,8 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0, - 11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0} + 11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0, + } suite.Run("test limited", func() { tests := []struct { @@ -157,7 +161,8 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() { resultField0 := []int64{11, 11, 22, 22} for _, test := range tests { suite.Run(test.description, func() { - result, err := MergeSegcoreRetrieveResults(context.Background(), []*segcorepb.RetrieveResults{r1, r2}, test.limit) + result, err := MergeSegcoreRetrieveResults(context.Background(), []*segcorepb.RetrieveResults{r1, r2}, + NewMergeParam(test.limit, make([]int64, 0), nil, false)) suite.Equal(2, len(result.GetFieldsData())) suite.Equal(int(test.limit), len(result.GetIds().GetIntId().GetData())) suite.Equal(resultIDs[0:test.limit], result.GetIds().GetIntId().GetData()) @@ -192,13 +197,15 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() { FieldsData: []*schemapb.FieldData{fieldData}, } - _, err := MergeSegcoreRetrieveResults(context.Background(), []*segcorepb.RetrieveResults{result}, reqLimit) + _, err := MergeSegcoreRetrieveResults(context.Background(), []*segcorepb.RetrieveResults{result}, + NewMergeParam(reqLimit, make([]int64, 0), nil, false)) 
suite.Error(err) paramtable.Get().Save(paramtable.Get().QuotaConfig.MaxOutputSize.Key, "1104857600") }) suite.Run("test int ID", func() { - result, err := MergeSegcoreRetrieveResults(context.Background(), []*segcorepb.RetrieveResults{r1, r2}, typeutil.Unlimited) + result, err := MergeSegcoreRetrieveResults(context.Background(), []*segcorepb.RetrieveResults{r1, r2}, + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) suite.Equal(2, len(result.GetFieldsData())) suite.Equal([]int64{1, 2, 3, 4}, result.GetIds().GetIntId().GetData()) suite.Equal([]int64{11, 11, 22, 22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data) @@ -211,15 +218,20 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() { IdField: &schemapb.IDs_StrId{ StrId: &schemapb.StringArray{ Data: []string{"a", "c"}, - }}} + }, + }, + } r2.Ids = &schemapb.IDs{ IdField: &schemapb.IDs_StrId{ StrId: &schemapb.StringArray{ Data: []string{"b", "d"}, - }}} + }, + }, + } - result, err := MergeSegcoreRetrieveResults(context.Background(), []*segcorepb.RetrieveResults{r1, r2}, typeutil.Unlimited) + result, err := MergeSegcoreRetrieveResults(context.Background(), []*segcorepb.RetrieveResults{r1, r2}, + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) suite.NoError(err) suite.Equal(2, len(result.GetFieldsData())) suite.Equal([]string{"a", "b", "c", "d"}, result.GetIds().GetStrId().GetData()) @@ -227,7 +239,6 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() { suite.InDeltaSlice(resultFloat, result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10) suite.NoError(err) }) - }) } @@ -272,7 +283,8 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() { FieldsData: fieldDataArray2, } - result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result1, result2}, typeutil.Unlimited) + result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result1, 
result2}, + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) suite.NoError(err) suite.Equal(2, len(result.GetFieldsData())) suite.Equal([]int64{0, 1}, result.GetIds().GetIntId().GetData()) @@ -281,7 +293,8 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() { }) suite.Run("test nil results", func() { - ret, err := MergeInternalRetrieveResult(context.Background(), nil, typeutil.Unlimited) + ret, err := MergeInternalRetrieveResult(context.Background(), nil, + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) suite.NoError(err) suite.Empty(ret.GetIds()) suite.Empty(ret.GetFieldsData()) @@ -293,7 +306,8 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() { IdField: &schemapb.IDs_IntId{ IntId: &schemapb.LongArray{ Data: []int64{0, 1}, - }}, + }, + }, }, FieldsData: []*schemapb.FieldData{ genFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, @@ -307,7 +321,8 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() { IdField: &schemapb.IDs_IntId{ IntId: &schemapb.LongArray{ Data: []int64{0, 1}, - }}, + }, + }, }, FieldsData: []*schemapb.FieldData{ genFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, @@ -316,7 +331,8 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() { []int64{7, 8}, 1), }, } - result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{ret1, ret2}, typeutil.Unlimited) + result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{ret1, ret2}, + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) suite.NoError(err) suite.Equal(2, len(result.GetFieldsData())) suite.Equal([]int64{0, 1}, result.GetIds().GetIntId().GetData()) @@ -349,7 +365,8 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 11.0, 22.0, 33.0, 
44.0, 55.0, 66.0, 77.0, 88.0, - 11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0} + 11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0, + } suite.Run("test limited", func() { tests := []struct { @@ -365,7 +382,8 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() { resultField0 := []int64{11, 11, 22, 22} for _, test := range tests { suite.Run(test.description, func() { - result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{r1, r2}, test.limit) + result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{r1, r2}, + NewMergeParam(test.limit, make([]int64, 0), nil, false)) suite.Equal(2, len(result.GetFieldsData())) suite.Equal(int(test.limit), len(result.GetIds().GetIntId().GetData())) suite.Equal(resultIDs[0:test.limit], result.GetIds().GetIntId().GetData()) @@ -398,13 +416,15 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() { FieldsData: []*schemapb.FieldData{fieldData}, } - _, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result}, typeutil.Unlimited) + _, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result}, + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) suite.Error(err) paramtable.Get().Save(paramtable.Get().QuotaConfig.MaxOutputSize.Key, "1104857600") }) suite.Run("test int ID", func() { - result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{r1, r2}, typeutil.Unlimited) + result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{r1, r2}, + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) suite.Equal(2, len(result.GetFieldsData())) suite.Equal([]int64{1, 2, 3, 4}, result.GetIds().GetIntId().GetData()) suite.Equal([]int64{11, 11, 22, 22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data) @@ -429,7 +449,8 @@ func (suite *ResultSuite) 
TestResult_MergeInternalRetrieveResults() { }, } - result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{r1, r2}, typeutil.Unlimited) + result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{r1, r2}, + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) suite.NoError(err) suite.Equal(2, len(result.GetFieldsData())) suite.Equal([]string{"a", "b", "c", "d"}, result.GetIds().GetStrId().GetData()) @@ -437,7 +458,110 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() { suite.InDeltaSlice(resultFloat, result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10) suite.NoError(err) }) + }) +} +func (suite *ResultSuite) TestResult_MergeStopForBestResult() { + const ( + Dim = 4 + Int64FieldName = "Int64Field" + FloatVectorFieldName = "FloatVectorField" + Int64FieldID = common.StartOfUserFieldID + 1 + FloatVectorFieldID = common.StartOfUserFieldID + 2 + ) + Int64Array := []int64{11, 22, 33} + FloatVector := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 11.0, 22.0, 33.0, 44.0} + + var fieldDataArray1 []*schemapb.FieldData + fieldDataArray1 = append(fieldDataArray1, genFieldData(Int64FieldName, Int64FieldID, + schemapb.DataType_Int64, Int64Array[0:3], 1)) + fieldDataArray1 = append(fieldDataArray1, genFieldData(FloatVectorFieldName, FloatVectorFieldID, + schemapb.DataType_FloatVector, FloatVector[0:12], Dim)) + + var fieldDataArray2 []*schemapb.FieldData + fieldDataArray2 = append(fieldDataArray2, genFieldData(Int64FieldName, Int64FieldID, + schemapb.DataType_Int64, Int64Array[0:3], 1)) + fieldDataArray2 = append(fieldDataArray2, genFieldData(FloatVectorFieldName, FloatVectorFieldID, + schemapb.DataType_FloatVector, FloatVector[0:12], Dim)) + + suite.Run("test stop seg core merge for best", func() { + result1 := &segcorepb.RetrieveResults{ + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{0, 1, 4}, + }, + }, 
+ }, + Offset: []int64{0, 1, 2}, + FieldsData: fieldDataArray1, + } + result2 := &segcorepb.RetrieveResults{ + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{2, 3, 6}, + }, + }, + }, + Offset: []int64{0, 1, 2}, + FieldsData: fieldDataArray2, + } + suite.Run("merge stop finite limited", func() { + result, err := MergeSegcoreRetrieveResults(context.Background(), []*segcorepb.RetrieveResults{result1, result2}, + NewMergeParam(3, make([]int64, 0), nil, true)) + suite.NoError(err) + suite.Equal(2, len(result.GetFieldsData())) + suite.Equal([]int64{0, 1, 2, 3, 4}, result.GetIds().GetIntId().GetData()) + // here, we can only get best result from 0 to 4 without 6, because we can never know whether there is + // one potential 5 in following result1 + suite.Equal([]int64{11, 22, 11, 22, 33}, result.GetFieldsData()[0].GetScalars().GetLongData().Data) + suite.InDeltaSlice([]float32{1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 11, 22, 33, 44}, + result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10) + }) + suite.Run("merge stop unlimited", func() { + result, err := MergeSegcoreRetrieveResults(context.Background(), []*segcorepb.RetrieveResults{result1, result2}, + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, true)) + suite.NoError(err) + suite.Equal(2, len(result.GetFieldsData())) + suite.Equal([]int64{0, 1, 2, 3, 4, 6}, result.GetIds().GetIntId().GetData()) + // here, we can only get best result from 0 to 4 without 6, because we can never know whether there is + // one potential 5 in following result1 + suite.Equal([]int64{11, 22, 11, 22, 33, 33}, result.GetFieldsData()[0].GetScalars().GetLongData().Data) + suite.InDeltaSlice([]float32{1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 11, 22, 33, 44, 11, 22, 33, 44}, + result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10) + }) + }) + + suite.Run("test stop internal merge for best", func() { + result1 := &internalpb.RetrieveResults{ + Ids: 
&schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{0, 4, 7}, + }, + }, + }, + FieldsData: fieldDataArray1, + } + result2 := &internalpb.RetrieveResults{ + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{2, 6, 9}, + }, + }, + }, + FieldsData: fieldDataArray2, + } + result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result1, result2}, + NewMergeParam(3, make([]int64, 0), nil, true)) + suite.NoError(err) + suite.Equal(2, len(result.GetFieldsData())) + suite.Equal([]int64{0, 2, 4, 6, 7}, result.GetIds().GetIntId().GetData()) + suite.Equal([]int64{11, 11, 22, 22, 33}, result.GetFieldsData()[0].GetScalars().GetLongData().Data) + suite.InDeltaSlice([]float32{1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8, 11, 22, 33, 44}, + result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10) }) } @@ -595,7 +719,8 @@ func (suite *ResultSuite) TestSort() { IdField: &schemapb.IDs_IntId{ IntId: &schemapb.LongArray{ Data: []int64{5, 4, 3, 2, 9, 8, 7, 6}, - }}, + }, + }, }, Offset: []int64{5, 4, 3, 2, 9, 8, 7, 6}, FieldsData: []*schemapb.FieldData{ @@ -616,8 +741,10 @@ func (suite *ResultSuite) TestSort() { genFieldData("binary vector field", 107, schemapb.DataType_BinaryVector, []byte{5, 4, 3, 2, 9, 8, 7, 6}, 8), genFieldData("json field", 108, schemapb.DataType_JSON, - [][]byte{[]byte("{\"5\": 5}"), []byte("{\"4\": 4}"), []byte("{\"3\": 3}"), []byte("{\"2\": 2}"), - []byte("{\"9\": 9}"), []byte("{\"8\": 8}"), []byte("{\"7\": 7}"), []byte("{\"6\": 6}")}, 1), + [][]byte{ + []byte("{\"5\": 5}"), []byte("{\"4\": 4}"), []byte("{\"3\": 3}"), []byte("{\"2\": 2}"), + []byte("{\"9\": 9}"), []byte("{\"8\": 8}"), []byte("{\"7\": 7}"), []byte("{\"6\": 6}"), + }, 1), genFieldData("json field", 108, schemapb.DataType_Array, []*schemapb.ScalarField{ {Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{5, 6, 7}}}}, @@ -644,8 +771,10 @@ 
func (suite *ResultSuite) TestSort() { suite.Equal([]int32{2, 3, 4, 5, 6, 7, 8, 9}, result.FieldsData[5].GetScalars().GetIntData().Data) suite.InDeltaSlice([]float32{2, 3, 4, 5, 6, 7, 8, 9}, result.FieldsData[6].GetVectors().GetFloatVector().GetData(), 10e-10) suite.Equal([]byte{2, 3, 4, 5, 6, 7, 8, 9}, result.FieldsData[7].GetVectors().GetBinaryVector()) - suite.Equal([][]byte{[]byte("{\"2\": 2}"), []byte("{\"3\": 3}"), []byte("{\"4\": 4}"), []byte("{\"5\": 5}"), - []byte("{\"6\": 6}"), []byte("{\"7\": 7}"), []byte("{\"8\": 8}"), []byte("{\"9\": 9}")}, result.FieldsData[8].GetScalars().GetJsonData().GetData()) + suite.Equal([][]byte{ + []byte("{\"2\": 2}"), []byte("{\"3\": 3}"), []byte("{\"4\": 4}"), []byte("{\"5\": 5}"), + []byte("{\"6\": 6}"), []byte("{\"7\": 7}"), []byte("{\"8\": 8}"), []byte("{\"9\": 9}"), + }, result.FieldsData[8].GetScalars().GetJsonData().GetData()) suite.Equal([]*schemapb.ScalarField{ {Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{2, 3, 4}}}}, {Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{3, 4, 5}}}}, @@ -692,5 +821,6 @@ func TestResult_MergeRequestCost(t *testing.T) { } func TestResult(t *testing.T) { + paramtable.Init() suite.Run(t, new(ResultSuite)) } diff --git a/internal/querynodev2/segments/retrieve.go b/internal/querynodev2/segments/retrieve.go index 019e5e964580a..60b976f99b4e5 100644 --- a/internal/querynodev2/segments/retrieve.go +++ b/internal/querynodev2/segments/retrieve.go @@ -22,11 +22,14 @@ import ( "sync" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/segcorepb" + "github.com/milvus-io/milvus/internal/util/streamrpc" "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" 
"github.com/milvus-io/milvus/pkg/util/timerecord" - . "github.com/milvus-io/milvus/pkg/util/typeutil" ) // retrieveOnSegments performs retrieve on listed segments @@ -45,19 +48,14 @@ func retrieveOnSegments(ctx context.Context, segments []Segment, segType Segment for i, segment := range segments { wg.Add(1) - go func(segment Segment, i int) { + go func(seg Segment, i int) { defer wg.Done() - seg := segment.(*LocalSegment) tr := timerecord.NewTimeRecorder("retrieveOnSegments") result, err := seg.Retrieve(ctx, plan) if err != nil { errs[i] = err return } - if err = seg.ValidateIndexedFieldsData(ctx, result); err != nil { - errs[i] = err - return - } errs[i] = nil resultCh <- result metrics.QueryNodeSQSegmentLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), @@ -81,23 +79,93 @@ func retrieveOnSegments(ctx context.Context, segments []Segment, segType Segment return retrieveResults, nil } -// retrieveHistorical will retrieve all the target segments in historical -func RetrieveHistorical(ctx context.Context, manager *Manager, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID) ([]*segcorepb.RetrieveResults, []Segment, error) { - segments, err := validateOnHistorical(ctx, manager, collID, partIDs, segIDs) +func retrieveOnSegmentsWithStream(ctx context.Context, segments []Segment, segType SegmentType, plan *RetrievePlan, svr streamrpc.QueryStreamServer) error { + var ( + errs = make([]error, len(segments)) + wg sync.WaitGroup + ) + + label := metrics.SealedSegmentLabel + if segType == commonpb.SegmentState_Growing { + label = metrics.GrowingSegmentLabel + } + + for i, segment := range segments { + wg.Add(1) + go func(segment Segment, i int) { + defer wg.Done() + seg := segment.(*LocalSegment) + tr := timerecord.NewTimeRecorder("retrieveOnSegmentsWithStream") + result, err := seg.Retrieve(ctx, plan) + if err != nil { + errs[i] = err + return + } + + if err = svr.Send(&internalpb.RetrieveResults{ + Status: merr.Success(), + Ids: 
result.GetIds(), + FieldsData: result.GetFieldsData(), + }); err != nil { + errs[i] = err + } + + errs[i] = nil + metrics.QueryNodeSQSegmentLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), + metrics.QueryLabel, label).Observe(float64(tr.ElapseSpan().Milliseconds())) + }(segment, i) + } + wg.Wait() + return merr.Combine(errs...) +} + +// retrieve will retrieve all the validate target segments +func Retrieve(ctx context.Context, manager *Manager, plan *RetrievePlan, req *querypb.QueryRequest) ([]*segcorepb.RetrieveResults, []Segment, error) { + var err error + var SegType commonpb.SegmentState + var retrieveResults []*segcorepb.RetrieveResults + var retrieveSegments []Segment + + segIDs := req.GetSegmentIDs() + collID := req.Req.GetCollectionID() + + if req.GetScope() == querypb.DataScope_Historical { + SegType = SegmentTypeSealed + retrieveSegments, err = validateOnHistorical(ctx, manager, collID, nil, segIDs) + } else { + SegType = SegmentTypeGrowing + retrieveSegments, err = validateOnStream(ctx, manager, collID, nil, segIDs) + } + if err != nil { - return nil, nil, err + return retrieveResults, retrieveSegments, err } - retrieveResults, err := retrieveOnSegments(ctx, segments, SegmentTypeSealed, plan) - return retrieveResults, segments, err + retrieveResults, err = retrieveOnSegments(ctx, retrieveSegments, SegType, plan) + return retrieveResults, retrieveSegments, err } -// retrieveStreaming will retrieve all the target segments in streaming -func RetrieveStreaming(ctx context.Context, manager *Manager, plan *RetrievePlan, collID UniqueID, partIDs []UniqueID, segIDs []UniqueID) ([]*segcorepb.RetrieveResults, []Segment, error) { - segments, err := validateOnStream(ctx, manager, collID, partIDs, segIDs) +// retrieveStreaming will retrieve all the validate target segments and return by stream +func RetrieveStream(ctx context.Context, manager *Manager, plan *RetrievePlan, req *querypb.QueryRequest, srv streamrpc.QueryStreamServer) ([]Segment, error) { + 
var err error + var SegType commonpb.SegmentState + var retrieveSegments []Segment + + segIDs := req.GetSegmentIDs() + collID := req.Req.GetCollectionID() + + if req.GetScope() == querypb.DataScope_Historical { + SegType = SegmentTypeSealed + retrieveSegments, err = validateOnHistorical(ctx, manager, collID, nil, segIDs) + } else { + SegType = SegmentTypeGrowing + retrieveSegments, err = validateOnStream(ctx, manager, collID, nil, segIDs) + } + if err != nil { - return nil, nil, err + return retrieveSegments, err } - retrieveResults, err := retrieveOnSegments(ctx, segments, SegmentTypeGrowing, plan) - return retrieveResults, segments, err + + err = retrieveOnSegmentsWithStream(ctx, retrieveSegments, SegType, plan, srv) + return retrieveSegments, err } diff --git a/internal/querynodev2/segments/retrieve_test.go b/internal/querynodev2/segments/retrieve_test.go index cffc16efab770..c356948b52739 100644 --- a/internal/querynodev2/segments/retrieve_test.go +++ b/internal/querynodev2/segments/retrieve_test.go @@ -18,14 +18,17 @@ package segments import ( "context" + "io" "testing" "github.com/stretchr/testify/suite" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/initcore" + "github.com/milvus-io/milvus/internal/util/streamrpc" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -129,8 +132,8 @@ func (suite *RetrieveSuite) SetupTest() { } func (suite *RetrieveSuite) TearDownTest() { - DeleteSegment(suite.sealed) - DeleteSegment(suite.growing) + suite.sealed.Release() + suite.growing.Release() DeleteCollection(suite.collection) ctx := context.Background() suite.chunkManager.RemoveWithPrefix(ctx, suite.rootPath) @@ -140,10 +143,16 @@ func (suite *RetrieveSuite) TestRetrieveSealed() { plan, err := 
genSimpleRetrievePlan(suite.collection) suite.NoError(err) - res, segments, err := RetrieveHistorical(context.TODO(), suite.manager, plan, - suite.collectionID, - []int64{suite.partitionID}, - []int64{suite.sealed.ID()}) + req := &querypb.QueryRequest{ + Req: &internalpb.RetrieveRequest{ + CollectionID: suite.collectionID, + PartitionIDs: []int64{suite.partitionID}, + }, + SegmentIDs: []int64{suite.sealed.ID()}, + Scope: querypb.DataScope_Historical, + } + + res, segments, err := Retrieve(context.TODO(), suite.manager, plan, req) suite.NoError(err) suite.Len(res[0].Offset, 3) suite.manager.Segment.Unpin(segments) @@ -153,24 +162,80 @@ func (suite *RetrieveSuite) TestRetrieveGrowing() { plan, err := genSimpleRetrievePlan(suite.collection) suite.NoError(err) - res, segments, err := RetrieveStreaming(context.TODO(), suite.manager, plan, - suite.collectionID, - []int64{suite.partitionID}, - []int64{suite.growing.ID()}) + req := &querypb.QueryRequest{ + Req: &internalpb.RetrieveRequest{ + CollectionID: suite.collectionID, + PartitionIDs: []int64{suite.partitionID}, + }, + SegmentIDs: []int64{suite.growing.ID()}, + Scope: querypb.DataScope_Streaming, + } + + res, segments, err := Retrieve(context.TODO(), suite.manager, plan, req) suite.NoError(err) suite.Len(res[0].Offset, 3) suite.manager.Segment.Unpin(segments) } +func (suite *RetrieveSuite) TestRetrieveStreamSealed() { + plan, err := genSimpleRetrievePlan(suite.collection) + suite.NoError(err) + + req := &querypb.QueryRequest{ + Req: &internalpb.RetrieveRequest{ + CollectionID: suite.collectionID, + PartitionIDs: []int64{suite.partitionID}, + }, + SegmentIDs: []int64{suite.sealed.ID()}, + Scope: querypb.DataScope_Historical, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + client := streamrpc.NewLocalQueryClient(ctx) + server := client.CreateServer() + + go func() { + segments, err := RetrieveStream(ctx, suite.manager, plan, req, server) + suite.NoError(err) + 
suite.manager.Segment.Unpin(segments) + server.FinishSend(err) + }() + + sum := 0 + for { + result, err := client.Recv() + if err != nil { + if err == io.EOF { + suite.Equal(3, sum) + break + } + suite.Fail("Retrieve stream fetch error") + } + + err = merr.Error(result.GetStatus()) + suite.NoError(err) + + sum += len(result.Ids.GetIntId().GetData()) + } +} + func (suite *RetrieveSuite) TestRetrieveNonExistSegment() { plan, err := genSimpleRetrievePlan(suite.collection) suite.NoError(err) - res, segments, err := RetrieveHistorical(context.TODO(), suite.manager, plan, - suite.collectionID, - []int64{suite.partitionID}, - []int64{999}) - suite.ErrorIs(err, merr.ErrSegmentNotLoaded) + req := &querypb.QueryRequest{ + Req: &internalpb.RetrieveRequest{ + CollectionID: suite.collectionID, + PartitionIDs: []int64{suite.partitionID}, + }, + SegmentIDs: []int64{999}, + Scope: querypb.DataScope_Streaming, + } + + res, segments, err := Retrieve(context.TODO(), suite.manager, plan, req) + suite.Error(err) suite.Len(res, 0) suite.manager.Segment.Unpin(segments) } @@ -179,11 +244,17 @@ func (suite *RetrieveSuite) TestRetrieveNilSegment() { plan, err := genSimpleRetrievePlan(suite.collection) suite.NoError(err) - DeleteSegment(suite.sealed) - res, segments, err := RetrieveHistorical(context.TODO(), suite.manager, plan, - suite.collectionID, - []int64{suite.partitionID}, - []int64{suite.sealed.ID()}) + suite.sealed.Release() + req := &querypb.QueryRequest{ + Req: &internalpb.RetrieveRequest{ + CollectionID: suite.collectionID, + PartitionIDs: []int64{suite.partitionID}, + }, + SegmentIDs: []int64{suite.sealed.ID()}, + Scope: querypb.DataScope_Historical, + } + + res, segments, err := Retrieve(context.TODO(), suite.manager, plan, req) suite.ErrorIs(err, merr.ErrSegmentNotLoaded) suite.Len(res, 0) suite.manager.Segment.Unpin(segments) diff --git a/internal/querynodev2/segments/search.go b/internal/querynodev2/segments/search.go index af4b46097a0b4..8bc71ee61d051 100644 --- 
a/internal/querynodev2/segments/search.go +++ b/internal/querynodev2/segments/search.go @@ -52,9 +52,8 @@ func searchSegments(ctx context.Context, segments []Segment, segType SegmentType // calling segment search in goroutines for i, segment := range segments { wg.Add(1) - go func(segment Segment, i int) { + go func(seg Segment, i int) { defer wg.Done() - seg := segment.(*LocalSegment) if !seg.ExistIndex(searchReq.searchFieldID) { mu.Lock() segmentsWithoutIndex = append(segmentsWithoutIndex, seg.ID()) diff --git a/internal/querynodev2/segments/search_test.go b/internal/querynodev2/segments/search_test.go index ee80de6ccd508..9c7d257f55642 100644 --- a/internal/querynodev2/segments/search_test.go +++ b/internal/querynodev2/segments/search_test.go @@ -122,7 +122,7 @@ func (suite *SearchSuite) SetupTest() { } func (suite *SearchSuite) TearDownTest() { - DeleteSegment(suite.sealed) + suite.sealed.Release() DeleteCollection(suite.collection) ctx := context.Background() suite.chunkManager.RemoveWithPrefix(ctx, paramtable.Get().MinioCfg.RootPath.GetValue()) diff --git a/internal/querynodev2/segments/segment.go b/internal/querynodev2/segments/segment.go index 97922017643b6..1ac2cd6aa6e91 100644 --- a/internal/querynodev2/segments/segment.go +++ b/internal/querynodev2/segments/segment.go @@ -33,8 +33,6 @@ import ( "github.com/apache/arrow/go/v12/arrow/array" "github.com/milvus-io/milvus/pkg/common" - "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/merr" "github.com/cockroachdb/errors" "github.com/golang/protobuf/proto" @@ -50,10 +48,11 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/segcorepb" - pkoracle "github.com/milvus-io/milvus/internal/querynodev2/pkoracle" + "github.com/milvus-io/milvus/internal/querynodev2/pkoracle" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/pkg/log" 
"github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/timerecord" "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -66,9 +65,7 @@ const ( SegmentTypeSealed = commonpb.SegmentState_Sealed ) -var ( - ErrSegmentUnhealthy = errors.New("segment unhealthy") -) +var ErrSegmentUnhealthy = errors.New("segment unhealthy") // IndexedFieldInfo contains binlog info of vector field type IndexedFieldInfo struct { @@ -190,10 +187,10 @@ func NewSegment(collection *Collection, zap.Int64("segmentID", segmentID), zap.String("segmentType", segmentType.String())) - var segment = &LocalSegment{ + segment := &LocalSegment{ baseSegment: newBaseSegment(segmentID, partitionID, collectionID, shard, segmentType, version, startPosition), ptr: segmentPtr, - lastDeltaTimestamp: atomic.NewUint64(deltaPosition.GetTimestamp()), + lastDeltaTimestamp: atomic.NewUint64(0), fieldIndexes: typeutil.NewConcurrentMap[int64, *IndexedFieldInfo](), } @@ -360,31 +357,6 @@ func (s *LocalSegment) Type() SegmentType { return s.typ } -func DeleteSegment(segment *LocalSegment) { - /* - void - deleteSegment(CSegmentInterface segment); - */ - // wait all read ops finished - var ptr C.CSegmentInterface - - segment.ptrLock.Lock() - ptr = segment.ptr - segment.ptr = nil - segment.ptrLock.Unlock() - - if ptr == nil { - return - } - - C.DeleteSegment(ptr) - log.Info("delete segment from memory", - zap.Int64("collectionID", segment.collectionID), - zap.Int64("partitionID", segment.partitionID), - zap.Int64("segmentID", segment.ID()), - zap.String("segmentType", segment.typ.String())) -} - func (s *LocalSegment) Search(ctx context.Context, searchReq *SearchRequest) (*SearchResult, error) { /* CStatus @@ -449,7 +421,7 @@ func (s *LocalSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segco return nil, merr.WrapErrSegmentNotLoaded(s.segmentID, "segment released") } - log := log.With( + log := 
log.Ctx(ctx).With( zap.Int64("collectionID", s.Collection()), zap.Int64("partitionID", s.Partition()), zap.Int64("segmentID", s.ID()), @@ -517,35 +489,6 @@ func (s *LocalSegment) GetFieldDataPath(index *IndexedFieldInfo, offset int64) ( return dataPath, offsetInBinlog } -func (s *LocalSegment) ValidateIndexedFieldsData(ctx context.Context, result *segcorepb.RetrieveResults) error { - log := log.Ctx(ctx).With( - zap.Int64("collectionID", s.Collection()), - zap.Int64("partitionID", s.Partition()), - zap.Int64("segmentID", s.ID()), - ) - - for _, fieldData := range result.FieldsData { - if !typeutil.IsVectorType(fieldData.GetType()) { - continue - } - if !s.ExistIndex(fieldData.FieldId) { - continue - } - if !s.HasRawData(fieldData.FieldId) { - index := s.GetIndex(fieldData.FieldId) - indexType, err := funcutil.GetAttrByKeyFromRepeatedKV(common.IndexTypeKey, index.IndexInfo.GetIndexParams()) - if err != nil { - return err - } - err = fmt.Errorf("vector output fields for %s index is not allowed", indexType) - log.Warn("validate fields failed", zap.Error(err)) - return err - } - } - - return nil -} - // -------------------------------------------------------------------------------------- interfaces for growing segment func (s *LocalSegment) preInsert(numOfRecords int) (int64, error) { /* @@ -588,11 +531,11 @@ func (s *LocalSegment) Insert(rowIDs []int64, timestamps []typeutil.Timestamp, r return fmt.Errorf("failed to marshal insert record: %s", err) } - var numOfRow = len(rowIDs) - var cOffset = C.int64_t(offset) - var cNumOfRows = C.int64_t(numOfRow) - var cEntityIdsPtr = (*C.int64_t)(&(rowIDs)[0]) - var cTimestampsPtr = (*C.uint64_t)(&(timestamps)[0]) + numOfRow := len(rowIDs) + cOffset := C.int64_t(offset) + cNumOfRows := C.int64_t(numOfRow) + cEntityIdsPtr := (*C.int64_t)(&(rowIDs)[0]) + cTimestampsPtr := (*C.uint64_t)(&(timestamps)[0]) var status C.CStatus @@ -637,9 +580,9 @@ func (s *LocalSegment) Delete(primaryKeys []storage.PrimaryKey, timestamps []typ return 
merr.WrapErrSegmentNotLoaded(s.segmentID, "segment released") } - var cOffset = C.int64_t(0) // depre - var cSize = C.int64_t(len(primaryKeys)) - var cTimestampsPtr = (*C.uint64_t)(&(timestamps)[0]) + cOffset := C.int64_t(0) // depre + cSize := C.int64_t(len(primaryKeys)) + cTimestampsPtr := (*C.uint64_t)(&(timestamps)[0]) ids := &schemapb.IDs{} pkType := primaryKeys[0].Type() @@ -722,7 +665,7 @@ func (s *LocalSegment) LoadMultiFieldData(rowCount int64, fields []*datapb.Field } for _, binlog := range field.Binlogs { - err = loadFieldDataInfo.appendLoadFieldDataPath(fieldID, binlog.GetLogPath()) + err = loadFieldDataInfo.appendLoadFieldDataPath(fieldID, binlog) if err != nil { return err } @@ -783,7 +726,7 @@ func (s *LocalSegment) LoadFieldData(fieldID int64, rowCount int64, field *datap } for _, binlog := range field.Binlogs { - err = loadFieldDataInfo.appendLoadFieldDataPath(fieldID, binlog.GetLogPath()) + err = loadFieldDataInfo.appendLoadFieldDataPath(fieldID, binlog) if err != nil { return err } @@ -824,7 +767,6 @@ func (s *LocalSegment) LoadDeltaData2(schema *schemapb.CollectionSchema) error { if err != nil { return err } - ids := &schemapb.IDs{} var pkint64s []int64 var pkstrings []string @@ -897,6 +839,54 @@ func (s *LocalSegment) LoadDeltaData2(schema *schemapb.CollectionSchema) error { zap.String("segmentType", s.Type().String())) return nil } +func (s *LocalSegment) AddFieldDataInfo(rowCount int64, fields []*datapb.FieldBinlog) error { + s.ptrLock.RLock() + defer s.ptrLock.RUnlock() + + if s.ptr == nil { + return merr.WrapErrSegmentNotLoaded(s.segmentID, "segment released") + } + + log := log.With( + zap.Int64("collectionID", s.Collection()), + zap.Int64("partitionID", s.Partition()), + zap.Int64("segmentID", s.ID()), + zap.Int64("row count", rowCount), + ) + + loadFieldDataInfo, err := newLoadFieldDataInfo() + defer deleteFieldDataInfo(loadFieldDataInfo) + if err != nil { + return err + } + + for _, field := range fields { + fieldID := field.FieldID + 
err = loadFieldDataInfo.appendLoadFieldInfo(fieldID, rowCount) + if err != nil { + return err + } + + for _, binlog := range field.Binlogs { + err = loadFieldDataInfo.appendLoadFieldDataPath(fieldID, binlog) + if err != nil { + return err + } + } + } + + var status C.CStatus + GetDynamicPool().Submit(func() (any, error) { + status = C.AddFieldDataInfoForSealed(s.ptr, loadFieldDataInfo.cLoadFieldDataInfo) + return nil, nil + }).Await() + if err := HandleCStatus(&status, "AddFieldDataInfo failed"); err != nil { + return err + } + + log.Info("add field data info done") + return nil +} func (s *LocalSegment) LoadDeltaData(deltaData *storage.DeleteData) error { pks, tss := deltaData.Pks, deltaData.Tss @@ -967,6 +957,8 @@ func (s *LocalSegment) LoadDeltaData(deltaData *storage.DeleteData) error { return err } + s.lastDeltaTimestamp.Store(tss[len(tss)-1]) + log.Info("load deleted record done", zap.Int64("rowNum", rowNum), zap.String("segmentType", s.Type().String())) @@ -1049,3 +1041,29 @@ func (s *LocalSegment) UpdateFieldRawDataSize(numRows int64, fieldBinlog *datapb return nil } + +func (s *LocalSegment) Release() { + /* + void + deleteSegment(CSegmentInterface segment); + */ + // wait all read ops finished + var ptr C.CSegmentInterface + + s.ptrLock.Lock() + ptr = s.ptr + s.ptr = nil + s.ptrLock.Unlock() + + if ptr == nil { + return + } + + C.DeleteSegment(ptr) + log.Info("delete segment from memory", + zap.Int64("collectionID", s.collectionID), + zap.Int64("partitionID", s.partitionID), + zap.Int64("segmentID", s.ID()), + zap.String("segmentType", s.typ.String()), + ) +} diff --git a/internal/querynodev2/segments/segment_interface.go b/internal/querynodev2/segments/segment_interface.go index 1be88aaddeb27..71f24cc90da66 100644 --- a/internal/querynodev2/segments/segment_interface.go +++ b/internal/querynodev2/segments/segment_interface.go @@ -17,6 +17,8 @@ package segments import ( + "context" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" 
"github.com/milvus-io/milvus/internal/proto/segcorepb" storage "github.com/milvus-io/milvus/internal/storage" @@ -57,4 +59,10 @@ type Segment interface { // Bloom filter related UpdateBloomFilter(pks []storage.PrimaryKey) MayPkExist(pk storage.PrimaryKey) bool + + // Read operations + Search(ctx context.Context, searchReq *SearchRequest) (*SearchResult, error) + Retrieve(ctx context.Context, plan *RetrievePlan) (*segcorepb.RetrieveResults, error) + + Release() } diff --git a/internal/querynodev2/segments/segment_loader.go b/internal/querynodev2/segments/segment_loader.go index 920c91e9e598e..bb15b70e95b84 100644 --- a/internal/querynodev2/segments/segment_loader.go +++ b/internal/querynodev2/segments/segment_loader.go @@ -16,6 +16,17 @@ package segments +/* +#cgo pkg-config: milvus_segcore milvus_common + +#include "segcore/collection_c.h" +#include "segcore/segment_c.h" +#include "segcore/segcore_init_c.h" +#include "common/init_c.h" + +*/ +import "C" + import ( "context" "fmt" @@ -28,6 +39,7 @@ import ( "github.com/cockroachdb/errors" "github.com/samber/lo" + "go.uber.org/atomic" "go.uber.org/zap" "golang.org/x/sync/errgroup" @@ -54,9 +66,7 @@ const ( UsedDiskMemoryRatio = 4 ) -var ( - ErrReadDeltaMsgFailed = errors.New("ReadDeltaMsgFailed") -) +var ErrReadDeltaMsgFailed = errors.New("ReadDeltaMsgFailed") type Loader interface { // Load loads binlogs, and spawn segments, @@ -186,12 +196,37 @@ func NewLoader( loader := &segmentLoader{ manager: manager, cm: cm, - loadingSegments: typeutil.NewConcurrentMap[int64, chan struct{}](), + loadingSegments: typeutil.NewConcurrentMap[int64, *loadResult](), } return loader } +type loadStatus = int32 + +const ( + loading loadStatus = iota + 1 + success + failure +) + +type loadResult struct { + status *atomic.Int32 + cond *sync.Cond +} + +func newLoadResult() *loadResult { + return &loadResult{ + status: atomic.NewInt32(loading), + cond: sync.NewCond(&sync.Mutex{}), + } +} + +func (r *loadResult) SetResult(status loadStatus) { 
+ r.status.CompareAndSwap(loading, status) + r.cond.Broadcast() +} + // segmentLoader is only responsible for loading the field data from binlog type segmentLoader struct { manager *Manager @@ -199,7 +234,7 @@ type segmentLoader struct { mut sync.Mutex // The channel will be closed as the segment loaded - loadingSegments *typeutil.ConcurrentMap[int64, chan struct{}] + loadingSegments *typeutil.ConcurrentMap[int64, *loadResult] committedResource LoadResource space *milvus_storage.Space } @@ -225,23 +260,31 @@ func (loader *segmentLoader) Load(ctx context.Context, infos := loader.prepare(segmentType, version, segments...) defer loader.unregister(infos...) + log.With( + zap.Int64s("requestSegments", lo.Map(segments, func(s *querypb.SegmentLoadInfo, _ int) int64 { return s.GetSegmentID() })), + zap.Int64s("preparedSegments", lo.Map(infos, func(s *querypb.SegmentLoadInfo, _ int) int64 { return s.GetSegmentID() })), + ) + // continue to wait other task done log.Info("start loading...", zap.Int("segmentNum", len(segments)), zap.Int("afterFilter", len(infos))) // Check memory & storage limit resource, concurrencyLevel, err := loader.requestResource(ctx, infos...) 
if err != nil { + log.Error("request resource failed", zap.Error(err)) return nil, err } defer loader.freeRequest(resource) - newSegments := make(map[int64]*LocalSegment, len(infos)) - clearAll := func() { - for _, s := range newSegments { - DeleteSegment(s) - } + newSegments := typeutil.NewConcurrentMap[int64, *LocalSegment]() + loaded := typeutil.NewConcurrentMap[int64, *LocalSegment]() + defer func() { + newSegments.Range(func(_ int64, s *LocalSegment) bool { + s.Release() + return true + }) debug.FreeOSMemory() - } + }() for _, info := range infos { segmentID := info.SegmentID @@ -253,7 +296,6 @@ func (loader *segmentLoader) Load(ctx context.Context, if collection == nil { err := merr.WrapErrCollectionNotFound(collectionID) log.Warn("failed to get collection", zap.Error(err)) - clearAll() return nil, err } @@ -269,18 +311,17 @@ func (loader *segmentLoader) Load(ctx context.Context, zap.Int64("segmentID", segmentID), zap.Error(err), ) - clearAll() return nil, err } - newSegments[segmentID] = segment + newSegments.Insert(segmentID, segment) } loadSegmentFunc := func(idx int) error { loadInfo := infos[idx] partitionID := loadInfo.PartitionID segmentID := loadInfo.SegmentID - segment := newSegments[segmentID] + segment, _ := newSegments.Get(segmentID) tr := timerecord.NewTimeRecorder("loadDurationPerSegment") err := loader.loadSegment(ctx, segment, loadInfo) @@ -292,6 +333,9 @@ func (loader *segmentLoader) Load(ctx context.Context, ) return err } + loader.manager.Segment.Put(segmentType, segment) + newSegments.GetAndRemove(segmentID) + loaded.Insert(segmentID, segment) log.Info("load segment done", zap.Int64("segmentID", segmentID)) loader.notifyLoadFinish(loadInfo) @@ -307,25 +351,23 @@ func (loader *segmentLoader) Load(ctx context.Context, err = funcutil.ProcessFuncParallel(len(infos), concurrencyLevel, loadSegmentFunc, "loadSegmentFunc") if err != nil { - clearAll() log.Warn("failed to load some segments", zap.Error(err)) return nil, err } // Wait for all 
segments loaded if err := loader.waitSegmentLoadDone(ctx, segmentType, lo.Map(segments, func(info *querypb.SegmentLoadInfo, _ int) int64 { return info.GetSegmentID() })...); err != nil { - clearAll() log.Warn("failed to wait the filtered out segments load done", zap.Error(err)) return nil, err } - loaded := make([]Segment, 0, len(newSegments)) - for _, segment := range newSegments { - loaded = append(loaded, segment) - } - loader.manager.Segment.Put(segmentType, loaded...) log.Info("all segment load done") - return loaded, nil + var result []Segment + loaded.Range(func(_ int64, s *LocalSegment) bool { + result = append(result, s) + return true + }) + return result, nil } func (loader *segmentLoader) prepare(segmentType SegmentType, version int64, segments ...*querypb.SegmentLoadInfo) []*querypb.SegmentLoadInfo { @@ -339,7 +381,7 @@ func (loader *segmentLoader) prepare(segmentType SegmentType, version int64, seg if len(loader.manager.Segment.GetBy(WithType(segmentType), WithID(segment.GetSegmentID()))) == 0 && !loader.loadingSegments.Contain(segment.GetSegmentID()) { infos = append(infos, segment) - loader.loadingSegments.Insert(segment.GetSegmentID(), make(chan struct{})) + loader.loadingSegments.Insert(segment.GetSegmentID(), newLoadResult()) } else { // try to update segment version before skip load operation loader.manager.Segment.UpdateSegmentBy(IncreaseVersion(version), @@ -358,23 +400,18 @@ func (loader *segmentLoader) unregister(segments ...*querypb.SegmentLoadInfo) { loader.mut.Lock() defer loader.mut.Unlock() for i := range segments { - waitCh, ok := loader.loadingSegments.GetAndRemove(segments[i].GetSegmentID()) + result, ok := loader.loadingSegments.GetAndRemove(segments[i].GetSegmentID()) if ok { - select { - case <-waitCh: - default: // close wait channel for failed task - close(waitCh) - } + result.SetResult(failure) } } } func (loader *segmentLoader) notifyLoadFinish(segments ...*querypb.SegmentLoadInfo) { - for _, loadInfo := range segments { - 
waitCh, ok := loader.loadingSegments.Get(loadInfo.GetSegmentID()) + result, ok := loader.loadingSegments.Get(loadInfo.GetSegmentID()) if ok { - close(waitCh) + result.SetResult(success) } } } @@ -382,6 +419,14 @@ func (loader *segmentLoader) notifyLoadFinish(segments ...*querypb.SegmentLoadIn // requestResource requests memory & storage to load segments, // returns the memory usage, disk usage and concurrency with the gained memory. func (loader *segmentLoader) requestResource(ctx context.Context, infos ...*querypb.SegmentLoadInfo) (LoadResource, int, error) { + resource := LoadResource{} + // we need to deal with empty infos case separately, + // because the following judgement for requested resources are based on current status and static config + // which may block empty-load operations by accident + if len(infos) == 0 { + return resource, 0, nil + } + segmentIDs := lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) int64 { return info.GetSegmentID() }) @@ -389,8 +434,6 @@ func (loader *segmentLoader) requestResource(ctx context.Context, infos ...*quer zap.Int64s("segmentIDs", segmentIDs), ) - resource := LoadResource{} - loader.mut.Lock() defer loader.mut.Unlock() @@ -464,23 +507,48 @@ func (loader *segmentLoader) freeRequest(resource LoadResource) { } func (loader *segmentLoader) waitSegmentLoadDone(ctx context.Context, segmentType SegmentType, segmentIDs ...int64) error { + log := log.Ctx(ctx).With( + zap.String("segmentType", segmentType.String()), + zap.Int64s("segmentIDs", segmentIDs), + ) for _, segmentID := range segmentIDs { if loader.manager.Segment.GetWithType(segmentID, segmentType) != nil { continue } - waitCh, ok := loader.loadingSegments.Get(segmentID) + result, ok := loader.loadingSegments.Get(segmentID) if !ok { log.Warn("segment was removed from the loading map early", zap.Int64("segmentID", segmentID)) return errors.New("segment was removed from the loading map early") } log.Info("wait segment loaded...", zap.Int64("segmentID", 
segmentID)) - select { - case <-ctx.Done(): + + signal := make(chan struct{}) + go func() { + select { + case <-signal: + case <-ctx.Done(): + result.cond.Broadcast() + } + }() + result.cond.L.Lock() + for result.status.Load() == loading && ctx.Err() == nil { + result.cond.Wait() + } + result.cond.L.Unlock() + close(signal) + + if ctx.Err() != nil { + log.Warn("failed to wait segment loaded due to context done", zap.Int64("segmentID", segmentID)) return ctx.Err() - case <-waitCh: } + + if result.status.Load() == failure { + log.Warn("failed to wait segment loaded", zap.Int64("segmentID", segmentID)) + return merr.WrapErrSegmentLack(segmentID, "failed to wait segment loaded") + } + log.Info("segment loaded...", zap.Int64("segmentID", segmentID)) } return nil @@ -490,7 +558,7 @@ func (loader *segmentLoader) LoadBloomFilterSet(ctx context.Context, collectionI log := log.Ctx(ctx).With( zap.Int64("collectionID", collectionID), zap.Int64s("segmentIDs", lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) int64 { - return info.SegmentID + return info.GetSegmentID() })), ) @@ -605,6 +673,9 @@ func (loader *segmentLoader) loadSegment(ctx context.Context, if err := loader.loadSealedSegmentFields(ctx, segment, fieldBinlogs, loadInfo.GetNumOfRows()); err != nil { return err } + if err := segment.AddFieldDataInfo(loadInfo.GetNumOfRows(), loadInfo.GetBinlogPaths()); err != nil { + return err + } // https://github.com/milvus-io/milvus/23654 // legacy entry num = 0 if err := loader.patchEntryNumber(ctx, segment, loadInfo); err != nil { @@ -677,7 +748,8 @@ func (loader *segmentLoader) loadFieldsIndex(ctx context.Context, schema *schemapb.CollectionSchema, segment *LocalSegment, numRows int64, - vecFieldInfos map[int64]*IndexedFieldInfo) error { + vecFieldInfos map[int64]*IndexedFieldInfo, +) error { schemaHelper, _ := typeutil.CreateSchemaHelper(schema) for fieldID, fieldInfo := range vecFieldInfos { @@ -692,6 +764,7 @@ func (loader *segmentLoader) loadFieldsIndex(ctx 
context.Context, zap.Int64("segment", segment.segmentID), zap.Int64("fieldID", fieldID), zap.Any("binlog", fieldInfo.FieldBinlog.Binlogs), + zap.Int32("current_index_version", fieldInfo.IndexInfo.GetCurrentIndexVersion()), ) segment.AddIndex(fieldID, fieldInfo) @@ -732,8 +805,8 @@ func (loader *segmentLoader) loadFieldIndex(ctx context.Context, segment *LocalS } func (loader *segmentLoader) loadBloomFilter(ctx context.Context, segmentID int64, bfs *pkoracle.BloomFilterSet, - binlogPaths []string, logType storage.StatsLogType) error { - + binlogPaths []string, logType storage.StatsLogType, +) error { log := log.Ctx(ctx).With( zap.Int64("segmentID", segmentID), ) @@ -786,6 +859,11 @@ func (loader *segmentLoader) LoadDeltaLogs(ctx context.Context, segment *LocalSe var blobs []*storage.Blob for _, deltaLog := range deltaLogs { for _, bLog := range deltaLog.GetBinlogs() { + // the segment has applied the delta logs, skip it + if bLog.GetTimestampTo() > 0 && // this field may be missed in legacy versions + bLog.GetTimestampTo() < segment.LastDeltaTimestamp() { + continue + } value, err := loader.cm.Read(ctx, bLog.GetLogPath()) if err != nil { return err @@ -1119,3 +1197,8 @@ func getBinlogDataSize(fieldBinlog *datapb.FieldBinlog) int64 { return fieldSize } + +func getIndexEngineVersion() (minimal, current int32) { + cMinimal, cCurrent := C.GetMinimalIndexVersion(), C.GetCurrentIndexVersion() + return int32(cMinimal), int32(cCurrent) +} diff --git a/internal/querynodev2/segments/segment_loader_test.go b/internal/querynodev2/segments/segment_loader_test.go index 19230c20765be..31524f7651ff7 100644 --- a/internal/querynodev2/segments/segment_loader_test.go +++ b/internal/querynodev2/segments/segment_loader_test.go @@ -20,9 +20,12 @@ import ( "context" "math/rand" "testing" + "time" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" 
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/storage" @@ -65,7 +68,7 @@ func (suite *SegmentLoaderSuite) SetupTest() { suite.manager = NewManager() ctx := context.Background() // TODO:: cpp chunk manager not support local chunk manager - //suite.chunkManager = storage.NewLocalChunkManager(storage.RootPath( + // suite.chunkManager = storage.NewLocalChunkManager(storage.RootPath( // fmt.Sprintf("/tmp/milvus-ut/%d", rand.Int63()))) chunkManagerFactory := NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath) suite.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(ctx) @@ -209,7 +212,6 @@ func (suite *SegmentLoaderSuite) TestLoadMultipleSegments() { suite.True(exist) } } - } func (suite *SegmentLoaderSuite) TestLoadWithIndex() { @@ -353,6 +355,66 @@ func (suite *SegmentLoaderSuite) TestLoadDeltaLogs() { } } +func (suite *SegmentLoaderSuite) TestLoadDupDeltaLogs() { + ctx := context.Background() + loadInfos := make([]*querypb.SegmentLoadInfo, 0, suite.segmentNum) + + msgLength := 100 + // Load sealed + for i := 0; i < suite.segmentNum; i++ { + segmentID := suite.segmentID + int64(i) + binlogs, statsLogs, err := SaveBinLog(ctx, + suite.collectionID, + suite.partitionID, + segmentID, + msgLength, + suite.schema, + suite.chunkManager, + ) + suite.NoError(err) + + // Delete PKs 1, 2 + deltaLogs, err := SaveDeltaLog(suite.collectionID, + suite.partitionID, + segmentID, + suite.chunkManager, + ) + suite.NoError(err) + + loadInfos = append(loadInfos, &querypb.SegmentLoadInfo{ + SegmentID: segmentID, + PartitionID: suite.partitionID, + CollectionID: suite.collectionID, + BinlogPaths: binlogs, + Statslogs: statsLogs, + Deltalogs: deltaLogs, + NumOfRows: int64(msgLength), + }) + } + + segments, err := suite.loader.Load(ctx, suite.collectionID, SegmentTypeGrowing, 0, loadInfos...) 
+ suite.NoError(err) + + for i, segment := range segments { + suite.Equal(int64(100-2), segment.RowNum()) + for pk := 0; pk < 100; pk++ { + if pk == 1 || pk == 2 { + continue + } + exist := segment.MayPkExist(storage.NewInt64PrimaryKey(int64(pk))) + suite.Require().True(exist) + } + + seg := segment.(*LocalSegment) + // nothing would happen as the delta logs have been all applied, + // so the released segment won't cause error + seg.Release() + loadInfos[i].Deltalogs[0].Binlogs[0].TimestampTo-- + err := suite.loader.LoadDeltaLogs(ctx, seg, loadInfos[i].GetDeltalogs()) + suite.NoError(err) + } +} + func (suite *SegmentLoaderSuite) TestLoadIndex() { ctx := context.Background() segment := &LocalSegment{} @@ -369,7 +431,6 @@ func (suite *SegmentLoaderSuite) TestLoadIndex() { err := suite.loader.LoadIndex(ctx, segment, loadInfo, 0) suite.ErrorIs(err, merr.ErrIndexNotFound) - } func (suite *SegmentLoaderSuite) TestLoadWithMmap() { @@ -527,6 +588,143 @@ func (suite *SegmentLoaderSuite) TestRunOutMemory() { suite.Error(err) } +type SegmentLoaderDetailSuite struct { + suite.Suite + + loader *segmentLoader + manager *Manager + segmentManager *MockSegmentManager + collectionManager *MockCollectionManager + + rootPath string + chunkManager storage.ChunkManager + + // Data + collectionID int64 + partitionID int64 + segmentID int64 + schema *schemapb.CollectionSchema + segmentNum int +} + +func (suite *SegmentLoaderDetailSuite) SetupSuite() { + paramtable.Init() + suite.rootPath = suite.T().Name() + suite.collectionID = rand.Int63() + suite.partitionID = rand.Int63() + suite.segmentID = rand.Int63() + suite.segmentNum = 5 + suite.schema = GenTestCollectionSchema("test", schemapb.DataType_Int64) +} + +func (suite *SegmentLoaderDetailSuite) SetupTest() { + // Dependencies + suite.collectionManager = NewMockCollectionManager(suite.T()) + suite.segmentManager = NewMockSegmentManager(suite.T()) + suite.manager = &Manager{ + Segment: suite.segmentManager, + Collection: 
suite.collectionManager, + } + + ctx := context.Background() + chunkManagerFactory := NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath) + suite.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(ctx) + suite.loader = NewLoader(suite.manager, suite.chunkManager) + initcore.InitRemoteChunkManager(paramtable.Get()) + + // Data + schema := GenTestCollectionSchema("test", schemapb.DataType_Int64) + + indexMeta := GenTestIndexMeta(suite.collectionID, schema) + loadMeta := &querypb.LoadMetaInfo{ + LoadType: querypb.LoadType_LoadCollection, + CollectionID: suite.collectionID, + PartitionIDs: []int64{suite.partitionID}, + } + + collection := NewCollection(suite.collectionID, schema, indexMeta, loadMeta.GetLoadType()) + suite.collectionManager.EXPECT().Get(suite.collectionID).Return(collection).Maybe() +} + +func (suite *SegmentLoaderDetailSuite) TestWaitSegmentLoadDone() { + suite.Run("wait_success", func() { + idx := 0 + + var infos []*querypb.SegmentLoadInfo + suite.segmentManager.EXPECT().GetBy(mock.Anything, mock.Anything).Return(nil) + suite.segmentManager.EXPECT().GetWithType(suite.segmentID, SegmentTypeSealed).RunAndReturn(func(segmentID int64, segmentType commonpb.SegmentState) Segment { + defer func() { idx++ }() + if idx == 0 { + go func() { + <-time.After(time.Second) + suite.loader.notifyLoadFinish(infos...) 
+ }() + } + return nil + }) + infos = suite.loader.prepare(SegmentTypeSealed, 0, &querypb.SegmentLoadInfo{ + SegmentID: suite.segmentID, + PartitionID: suite.partitionID, + CollectionID: suite.collectionID, + NumOfRows: 100, + }) + + err := suite.loader.waitSegmentLoadDone(context.Background(), SegmentTypeSealed, suite.segmentID) + suite.NoError(err) + }) + + suite.Run("wait_failure", func() { + suite.SetupTest() + + var idx int + var infos []*querypb.SegmentLoadInfo + suite.segmentManager.EXPECT().GetBy(mock.Anything, mock.Anything).Return(nil) + suite.segmentManager.EXPECT().GetWithType(suite.segmentID, SegmentTypeSealed).RunAndReturn(func(segmentID int64, segmentType commonpb.SegmentState) Segment { + defer func() { idx++ }() + if idx == 0 { + go func() { + <-time.After(time.Second) + suite.loader.unregister(infos...) + }() + } + + return nil + }) + infos = suite.loader.prepare(SegmentTypeSealed, 0, &querypb.SegmentLoadInfo{ + SegmentID: suite.segmentID, + PartitionID: suite.partitionID, + CollectionID: suite.collectionID, + NumOfRows: 100, + }) + + err := suite.loader.waitSegmentLoadDone(context.Background(), SegmentTypeSealed, suite.segmentID) + suite.Error(err) + }) + + suite.Run("wait_timeout", func() { + suite.SetupTest() + + suite.segmentManager.EXPECT().GetBy(mock.Anything, mock.Anything).Return(nil) + suite.segmentManager.EXPECT().GetWithType(suite.segmentID, SegmentTypeSealed).RunAndReturn(func(segmentID int64, segmentType commonpb.SegmentState) Segment { + return nil + }) + suite.loader.prepare(SegmentTypeSealed, 0, &querypb.SegmentLoadInfo{ + SegmentID: suite.segmentID, + PartitionID: suite.partitionID, + CollectionID: suite.collectionID, + NumOfRows: 100, + }) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := suite.loader.waitSegmentLoadDone(ctx, SegmentTypeSealed, suite.segmentID) + suite.Error(err) + suite.True(merr.IsCanceledOrTimeout(err)) + }) +} + func TestSegmentLoader(t *testing.T) { suite.Run(t, 
&SegmentLoaderSuite{}) + suite.Run(t, &SegmentLoaderDetailSuite{}) } diff --git a/internal/querynodev2/segments/segment_test.go b/internal/querynodev2/segments/segment_test.go index 9788b00016ccd..6c8b90e3e6d40 100644 --- a/internal/querynodev2/segments/segment_test.go +++ b/internal/querynodev2/segments/segment_test.go @@ -8,11 +8,8 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/querypb" - "github.com/milvus-io/milvus/internal/proto/segcorepb" storage "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/initcore" - "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/metric" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -114,8 +111,8 @@ func (suite *SegmentSuite) SetupTest() { func (suite *SegmentSuite) TearDownTest() { ctx := context.Background() - DeleteSegment(suite.sealed) - DeleteSegment(suite.growing) + suite.sealed.Release() + suite.growing.Release() DeleteCollection(suite.collection) suite.chunkManager.RemoveWithPrefix(ctx, suite.rootPath) } @@ -148,56 +145,6 @@ func (suite *SegmentSuite) TestHasRawData() { suite.True(has) } -func (suite *SegmentSuite) TestValidateIndexedFieldsData() { - result := &segcorepb.RetrieveResults{ - Ids: &schemapb.IDs{ - IdField: &schemapb.IDs_IntId{ - IntId: &schemapb.LongArray{ - Data: []int64{5, 4, 3, 2, 9, 8, 7, 6}, - }}, - }, - Offset: []int64{5, 4, 3, 2, 9, 8, 7, 6}, - FieldsData: []*schemapb.FieldData{ - genFieldData("int64 field", 100, schemapb.DataType_Int64, - []int64{5, 4, 3, 2, 9, 8, 7, 6}, 1), - genFieldData("float vector field", 101, schemapb.DataType_FloatVector, - []float32{5, 4, 3, 2, 9, 8, 7, 6}, 1), - }, - } - - // no index - err := suite.growing.ValidateIndexedFieldsData(context.Background(), result) - suite.NoError(err) - err = suite.sealed.ValidateIndexedFieldsData(context.Background(), result) - suite.NoError(err) - - // with index and has raw data - 
suite.sealed.AddIndex(101, &IndexedFieldInfo{ - IndexInfo: &querypb.FieldIndexInfo{ - FieldID: 101, - EnableIndex: true, - }, - }) - suite.True(suite.sealed.ExistIndex(101)) - err = suite.sealed.ValidateIndexedFieldsData(context.Background(), result) - suite.NoError(err) - - // index doesn't have index type - DeleteSegment(suite.sealed) - suite.True(suite.sealed.ExistIndex(101)) - err = suite.sealed.ValidateIndexedFieldsData(context.Background(), result) - suite.Error(err) - - // with index but doesn't have raw data - index := suite.sealed.GetIndex(101) - _, indexParams := genIndexParams(IndexHNSW, metric.L2) - index.IndexInfo.IndexParams = funcutil.Map2KeyValuePair(indexParams) - DeleteSegment(suite.sealed) - suite.True(suite.sealed.ExistIndex(101)) - err = suite.sealed.ValidateIndexedFieldsData(context.Background(), result) - suite.Error(err) -} - func (suite *SegmentSuite) TestCASVersion() { segment := suite.sealed @@ -210,7 +157,7 @@ func (suite *SegmentSuite) TestCASVersion() { } func (suite *SegmentSuite) TestSegmentReleased() { - DeleteSegment(suite.sealed) + suite.sealed.Release() suite.sealed.ptrLock.RLock() suite.False(suite.sealed.isValid()) diff --git a/internal/querynodev2/segments/validate.go b/internal/querynodev2/segments/validate.go index b83ad77b76742..2003e99caba35 100644 --- a/internal/querynodev2/segments/validate.go +++ b/internal/querynodev2/segments/validate.go @@ -37,7 +37,7 @@ func validate(ctx context.Context, manager *Manager, collectionID int64, partiti return nil, merr.WrapErrCollectionNotFound(collectionID) } - //validate partition + // validate partition // no partition id specified, get all partition ids in collection if len(partitionIDs) == 0 { searchPartIDs = collection.GetPartitions() @@ -59,7 +59,7 @@ func validate(ctx context.Context, manager *Manager, collectionID int64, partiti return []Segment{}, nil } - //validate segment + // validate segment segments := make([]Segment, 0, len(segmentIDs)) var err error if len(segmentIDs) 
== 0 { diff --git a/internal/querynodev2/server.go b/internal/querynodev2/server.go index 97b8d49bb0066..acc4fdb1c07e3 100644 --- a/internal/querynodev2/server.go +++ b/internal/querynodev2/server.go @@ -41,6 +41,7 @@ import ( "time" "unsafe" + "github.com/samber/lo" clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/zap" @@ -48,10 +49,12 @@ import ( grpcquerynodeclient "github.com/milvus-io/milvus/internal/distributed/querynode/client" "github.com/milvus-io/milvus/internal/querynodev2/cluster" "github.com/milvus-io/milvus/internal/querynodev2/delegator" + "github.com/milvus-io/milvus/internal/querynodev2/optimizers" "github.com/milvus-io/milvus/internal/querynodev2/pipeline" "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/querynodev2/tasks" "github.com/milvus-io/milvus/internal/querynodev2/tsafe" + "github.com/milvus-io/milvus/internal/registry" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/dependency" @@ -66,7 +69,6 @@ import ( "github.com/milvus-io/milvus/pkg/util/lifetime" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" - "github.com/samber/lo" ) // make sure QueryNode implements types.QueryNode @@ -127,7 +129,7 @@ type QueryNode struct { knnPool *conc.Pool*/ // parameter turning hook - queryHook queryHook + queryHook optimizers.QueryHook } // NewQueryNode will return a QueryNode with abnormal state. 
@@ -145,11 +147,17 @@ func NewQueryNode(ctx context.Context, factory dependency.Factory) *QueryNode { } func (node *QueryNode) initSession() error { - node.session = sessionutil.NewSession(node.ctx, paramtable.Get().EtcdCfg.MetaRootPath.GetValue(), node.etcdCli) + minimalIndexVersion, currentIndexVersion := getIndexEngineVersion() + node.session = sessionutil.NewSession(node.ctx, + paramtable.Get().EtcdCfg.MetaRootPath.GetValue(), + node.etcdCli, + sessionutil.WithIndexEngineVersion(minimalIndexVersion, currentIndexVersion), + ) if node.session == nil { return fmt.Errorf("session is nil, the etcd client connection may have failed") } node.session.Init(typeutil.QueryNodeRole, node.address, false, true) + sessionutil.SaveServerInfo(typeutil.QueryNodeRole, node.session.ServerID) paramtable.SetNodeID(node.session.ServerID) log.Info("QueryNode init session", zap.Int64("nodeID", paramtable.GetNodeID()), zap.String("node address", node.session.Address)) return nil @@ -207,7 +215,7 @@ func (node *QueryNode) InitSegcore() error { cIndexSliceSize := C.int64_t(paramtable.Get().CommonCfg.IndexSliceSize.GetAsInt64()) C.InitIndexSliceSize(cIndexSliceSize) - //set up thread pool for different priorities + // set up thread pool for different priorities cHighPriorityThreadCoreCoefficient := C.int64_t(paramtable.Get().CommonCfg.HighPriorityThreadCoreCoefficient.GetAsInt64()) C.InitHighPriorityThreadCoreCoefficient(cHighPriorityThreadCoreCoefficient) cMiddlePriorityThreadCoreCoefficient := C.int64_t(paramtable.Get().CommonCfg.MiddlePriorityThreadCoreCoefficient.GetAsInt64()) @@ -221,8 +229,30 @@ func (node *QueryNode) InitSegcore() error { localDataRootPath := filepath.Join(paramtable.Get().LocalStorageCfg.Path.GetValue(), typeutil.QueryNodeRole) initcore.InitLocalChunkManager(localDataRootPath) + err := initcore.InitRemoteChunkManager(paramtable.Get()) + if err != nil { + return err + } + + mmapDirPath := paramtable.Get().QueryNodeCfg.MmapDirPath.GetValue() + if len(mmapDirPath) == 
0 { + mmapDirPath = paramtable.Get().LocalStorageCfg.Path.GetValue() + } + chunkCachePath := path.Join(mmapDirPath, "chunk_cache") + policy := paramtable.Get().QueryNodeCfg.ReadAheadPolicy.GetValue() + err = initcore.InitChunkCache(chunkCachePath, policy) + if err != nil { + return err + } + log.Info("InitChunkCache done", zap.String("dir", chunkCachePath), zap.String("policy", policy)) + initcore.InitTraceConfig(paramtable.Get()) - return initcore.InitRemoteChunkManager(paramtable.Get()) + return nil +} + +func getIndexEngineVersion() (minimal, current int32) { + cMinimal, cCurrent := C.GetMinimalIndexVersion(), C.GetCurrentIndexVersion() + return int32(cMinimal), int32(cCurrent) } func (node *QueryNode) CloseSegcore() { @@ -362,13 +392,15 @@ func (node *QueryNode) Init() error { // Start mainly start QueryNode's query service. func (node *QueryNode) Start() error { node.startOnce.Do(func() { - node.scheduler.Start(node.ctx) + node.scheduler.Start() paramtable.SetCreateTime(time.Now()) paramtable.SetUpdateTime(time.Now()) mmapDirPath := paramtable.Get().QueryNodeCfg.MmapDirPath.GetValue() mmapEnabled := len(mmapDirPath) > 0 node.UpdateStateCode(commonpb.StateCode_Healthy) + + registry.GetInMemoryResolver().RegisterQueryNode(paramtable.GetNodeID(), node) log.Info("query node start successfully", zap.Int64("queryNodeID", paramtable.GetNodeID()), zap.String("Address", node.address), @@ -421,12 +453,14 @@ func (node *QueryNode) Stop() error { case <-time.After(time.Second): } } - } node.UpdateStateCode(commonpb.StateCode_Abnormal) node.lifetime.Wait() node.cancel() + if node.scheduler != nil { + node.scheduler.Stop() + } if node.pipelineManager != nil { node.pipelineManager.Close() } @@ -463,13 +497,6 @@ func (node *QueryNode) SetAddress(address string) { node.address = address } -type queryHook interface { - Run(map[string]any) error - Init(string) error - InitTuningConfig(map[string]string) error - DeleteTuningConfig(string) error -} - // initHook initializes 
parameter tuning hook. func (node *QueryNode) initHook() error { path := paramtable.Get().QueryNodeCfg.SoPath.GetValue() @@ -489,7 +516,7 @@ func (node *QueryNode) initHook() error { return fmt.Errorf("fail to find the 'QueryNodePlugin' object in the plugin, error: %s", err.Error()) } - hoo, ok := h.(queryHook) + hoo, ok := h.(optimizers.QueryHook) if !ok { return fmt.Errorf("fail to convert the `Hook` interface") } diff --git a/internal/querynodev2/server_test.go b/internal/querynodev2/server_test.go index 413bb9af63a84..1e59d3274ebea 100644 --- a/internal/querynodev2/server_test.go +++ b/internal/querynodev2/server_test.go @@ -32,6 +32,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/querynodev2/optimizers" "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/dependency" @@ -56,7 +57,6 @@ type QueryNodeSuite struct { func (suite *QueryNodeSuite) SetupSuite() { suite.address = "test-address" - } func (suite *QueryNodeSuite) SetupTest() { @@ -97,14 +97,14 @@ func (suite *QueryNodeSuite) TestBasic() { err = suite.node.Init() suite.NoError(err) - // node shoule be unhealthy before node start + // node should be unhealthy before node start suite.False(suite.node.lifetime.GetState() == commonpb.StateCode_Healthy) // start node err = suite.node.Start() suite.NoError(err) - // node shoule be healthy after node start + // node should be healthy after node start suite.True(suite.node.lifetime.GetState() == commonpb.StateCode_Healthy) // register node to etcd @@ -158,7 +158,7 @@ func (suite *QueryNodeSuite) TestInit_QueryHook() { err = suite.node.Init() suite.NoError(err) - mockHook := &MockQueryHook{} + mockHook := optimizers.NewMockQueryHook(suite.T()) suite.node.queryHook = mockHook 
suite.node.handleQueryHookEvent() diff --git a/internal/querynodev2/services.go b/internal/querynodev2/services.go index fa5bf3e472e5d..bbc55ea1a4047 100644 --- a/internal/querynodev2/services.go +++ b/internal/querynodev2/services.go @@ -42,7 +42,7 @@ import ( "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/querynodev2/tasks" "github.com/milvus-io/milvus/internal/storage" - "github.com/milvus-io/milvus/internal/util" + "github.com/milvus-io/milvus/internal/util/streamrpc" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" @@ -56,9 +56,9 @@ import ( ) // GetComponentStates returns information about whether the node is healthy -func (node *QueryNode) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { +func (node *QueryNode) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) { stats := &milvuspb.ComponentStates{ - Status: merr.Status(nil), + Status: merr.Success(), } code := node.lifetime.GetState() @@ -78,18 +78,18 @@ func (node *QueryNode) GetComponentStates(ctx context.Context) (*milvuspb.Compon // GetTimeTickChannel returns the time tick channel // TimeTickChannel contains many time tick messages, which will be sent by query nodes -func (node *QueryNode) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (node *QueryNode) GetTimeTickChannel(ctx context.Context, req *internalpb.GetTimeTickChannelRequest) (*milvuspb.StringResponse, error) { return &milvuspb.StringResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Value: paramtable.Get().CommonCfg.QueryCoordTimeTick.GetValue(), }, nil } // GetStatisticsChannel returns the statistics channel // Statistics channel contains statistics infos of query nodes, such as segment infos, memory infos -func (node *QueryNode) GetStatisticsChannel(ctx context.Context) 
(*milvuspb.StringResponse, error) { +func (node *QueryNode) GetStatisticsChannel(ctx context.Context, req *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error) { return &milvuspb.StringResponse{ - Status: merr.Status(nil), + Status: merr.Success(), }, nil } @@ -102,29 +102,22 @@ func (node *QueryNode) GetStatistics(ctx context.Context, req *querypb.GetStatis zap.Uint64("guaranteeTimestamp", req.GetReq().GetGuaranteeTimestamp()), zap.Uint64("timeTravel", req.GetReq().GetTravelTimestamp())) - if !node.lifetime.Add(commonpbutil.IsHealthy) { - msg := fmt.Sprintf("query node %d is not ready", paramtable.GetNodeID()) - err := merr.WrapErrServiceNotReady(msg) + if err := node.lifetime.Add(merr.IsHealthy); err != nil { return &internalpb.GetStatisticsResponse{ Status: merr.Status(err), }, nil } defer node.lifetime.Done() - if !CheckTargetID(req.GetReq()) { - targetID := req.GetReq().GetBase().GetTargetID() - log.Warn("target ID not match", - zap.Int64("targetID", targetID), - zap.Int64("nodeID", paramtable.GetNodeID()), - ) + err := merr.CheckTargetID(req.GetReq().GetBase()) + if err != nil { + log.Warn("target ID check failed", zap.Error(err)) return &internalpb.GetStatisticsResponse{ - Status: util.WrapStatus(commonpb.ErrorCode_NodeIDNotMatch, - common.WrapNodeIDNotMatchMsg(targetID, paramtable.GetNodeID()), - ), + Status: merr.Status(err), }, nil } failRet := &internalpb.GetStatisticsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), } var toReduceResults []*internalpb.GetStatisticsResponse @@ -141,18 +134,16 @@ func (node *QueryNode) GetStatistics(ctx context.Context, req *querypb.GetStatis } runningGp.Go(func() error { ret, err := node.getChannelStatistics(runningCtx, req, ch) + if err == nil { + err = merr.Error(ret.GetStatus()) + } + mu.Lock() defer mu.Unlock() if err != nil { - failRet.Status.Reason = err.Error() - failRet.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError + failRet.Status = merr.Status(err) return err } - if 
ret.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { - failRet.Status.Reason = ret.Status.Reason - failRet.Status.ErrorCode = ret.Status.ErrorCode - return fmt.Errorf("%s", ret.Status.Reason) - } toReduceResults = append(toReduceResults, ret) return nil }) @@ -163,7 +154,7 @@ func (node *QueryNode) GetStatistics(ctx context.Context, req *querypb.GetStatis ret, err := reduceStatisticResponse(toReduceResults) if err != nil { - failRet.Status.Reason = err.Error() + failRet.Status = merr.Status(err) return failRet, nil } log.Debug("reduce statistic result done") @@ -215,20 +206,14 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm ) // check node healthy - if !node.lifetime.Add(commonpbutil.IsHealthy) { - msg := fmt.Sprintf("query node %d is not ready", paramtable.GetNodeID()) - err := merr.WrapErrServiceNotReady(msg) + if err := node.lifetime.Add(merr.IsHealthy); err != nil { return merr.Status(err), nil } defer node.lifetime.Done() // check target matches - if req.GetBase().GetTargetID() != paramtable.GetNodeID() { - status := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_NodeIDNotMatch, - Reason: common.WrapNodeIDNotMatchMsg(req.GetBase().GetTargetID(), paramtable.GetNodeID()), - } - return status, nil + if err := merr.CheckTargetID(req.GetBase()); err != nil { + return merr.Status(err), nil } // check metric type @@ -240,7 +225,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm if !node.subscribingChannels.Insert(channel.GetChannelName()) { msg := "channel subscribing..." 
log.Warn(msg) - return util.SuccessStatus(msg), nil + return merr.Success(), nil } defer node.subscribingChannels.Remove(channel.GetChannelName()) @@ -254,7 +239,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm _, exist := node.delegators.Get(channel.GetChannelName()) if exist { log.Info("channel already subscribed") - return util.SuccessStatus(), nil + return merr.Success(), nil } node.manager.Collection.PutOrRef(req.GetCollectionID(), req.GetSchema(), @@ -265,7 +250,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm node.clusterManager, node.manager, node.tSafeManager, node.loader, node.factory, channel.GetSeekPosition().GetTimestamp()) if err != nil { log.Warn("failed to create shard delegator", zap.Error(err)) - return util.WrapStatus(commonpb.ErrorCode_UnexpectedError, "failed to create shard delegator", err), nil + return merr.Status(err), nil } node.delegators.Insert(channel.GetChannelName(), delegator) defer func() { @@ -286,7 +271,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm if err != nil { msg := "failed to create pipeline" log.Warn(msg, zap.Error(err)) - return util.WrapStatus(commonpb.ErrorCode_UnexpectedError, msg, err), nil + return merr.Status(err), nil } defer func() { if err != nil { @@ -322,7 +307,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm if err != nil { msg := "failed to load growing segments" log.Warn(msg, zap.Error(err)) - return util.WrapStatus(commonpb.ErrorCode_UnexpectedError, msg, err), nil + return merr.Status(err), nil } position := &msgpb.MsgPosition{ ChannelName: channel.SeekPosition.ChannelName, @@ -344,7 +329,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm // delegator after all steps done delegator.Start() log.Info("watch dml channel success") - return util.SuccessStatus(), nil + return merr.Success(), nil } func (node *QueryNode) 
UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmChannelRequest) (*commonpb.Status, error) { @@ -357,21 +342,14 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC log.Info("received unsubscribe channel request") // check node healthy - if !node.lifetime.Add(commonpbutil.IsHealthy) { - - msg := fmt.Sprintf("query node %d is not ready", paramtable.GetNodeID()) - err := merr.WrapErrServiceNotReady(msg) + if err := node.lifetime.Add(merr.IsHealthy); err != nil { return merr.Status(err), nil } defer node.lifetime.Done() // check target matches - if req.GetBase().GetTargetID() != paramtable.GetNodeID() { - status := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_NodeIDNotMatch, - Reason: common.WrapNodeIDNotMatchMsg(req.GetBase().GetTargetID(), paramtable.GetNodeID()), - } - return status, nil + if err := merr.CheckTargetID(req.GetBase()); err != nil { + return merr.Status(err), nil } node.unsubscribingChannels.Insert(req.GetChannelName()) @@ -389,7 +367,7 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC } log.Info("unsubscribed channel") - return util.SuccessStatus(), nil + return merr.Success(), nil } func (node *QueryNode) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) { @@ -400,9 +378,7 @@ func (node *QueryNode) LoadPartitions(ctx context.Context, req *querypb.LoadPart log.Info("received load partitions request") // check node healthy - if !node.lifetime.Add(commonpbutil.IsHealthyOrStopping) { - msg := fmt.Sprintf("query node %d is not ready", paramtable.GetNodeID()) - err := merr.WrapErrServiceNotReady(msg) + if err := node.lifetime.Add(merr.IsHealthyOrStopping); err != nil { return merr.Status(err), nil } defer node.lifetime.Done() @@ -413,7 +389,7 @@ func (node *QueryNode) LoadPartitions(ctx context.Context, req *querypb.LoadPart } log.Info("load partitions done") - return merr.Status(nil), nil + return merr.Success(), nil } // 
LoadSegments load historical data into query node, historical data can be vector data or index @@ -433,17 +409,14 @@ func (node *QueryNode) LoadSegments(ctx context.Context, req *querypb.LoadSegmen zap.Bool("needTransfer", req.GetNeedTransfer()), ) // check node healthy - if !node.lifetime.Add(commonpbutil.IsHealthy) { - msg := fmt.Sprintf("query node %d is not ready", paramtable.GetNodeID()) - err := merr.WrapErrServiceNotReady(msg) + if err := node.lifetime.Add(merr.IsHealthy); err != nil { return merr.Status(err), nil } node.lifetime.Done() // check target matches - if req.GetBase().GetTargetID() != paramtable.GetNodeID() { - return util.WrapStatus(commonpb.ErrorCode_NodeIDNotMatch, - common.WrapNodeIDNotMatchMsg(req.GetBase().GetTargetID(), paramtable.GetNodeID())), nil + if err := merr.CheckTargetID(req.GetBase()); err != nil { + return merr.Status(err), nil } // Delegates request to workers @@ -452,17 +425,18 @@ func (node *QueryNode) LoadSegments(ctx context.Context, req *querypb.LoadSegmen if !ok { msg := "failed to load segments, delegator not found" log.Warn(msg) - return util.WrapStatus(commonpb.ErrorCode_UnexpectedError, msg), nil + err := merr.WrapErrChannelNotFound(segment.GetInsertChannel()) + return merr.Status(err), nil } req.NeedTransfer = false err := delegator.LoadSegments(ctx, req) if err != nil { log.Warn("delegator failed to load segments", zap.Error(err)) - return util.WrapStatus(commonpb.ErrorCode_UnexpectedError, err.Error()), nil + return merr.Status(err), nil } - return util.SuccessStatus(), nil + return merr.Success(), nil } if req.GetLoadScope() == querypb.LoadScope_Delta { @@ -485,10 +459,7 @@ func (node *QueryNode) LoadSegments(ctx context.Context, req *querypb.LoadSegmen req.GetInfos()..., ) if err != nil { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + return merr.Status(err), nil } node.manager.Collection.Ref(req.GetCollectionID(), uint32(len(loaded))) @@ -496,19 
+467,17 @@ func (node *QueryNode) LoadSegments(ctx context.Context, req *querypb.LoadSegmen log.Info("load segments done...", zap.Int64s("segments", lo.Map(loaded, func(s segments.Segment, _ int) int64 { return s.ID() }))) - return util.SuccessStatus(), nil + return merr.Success(), nil } // ReleaseCollection clears all data related to this collection on the querynode func (node *QueryNode) ReleaseCollection(ctx context.Context, in *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) { - if !node.lifetime.Add(commonpbutil.IsHealthyOrStopping) { - msg := fmt.Sprintf("query node %d is not ready", paramtable.GetNodeID()) - err := merr.WrapErrServiceNotReady(msg) + if err := node.lifetime.Add(merr.IsHealthyOrStopping); err != nil { return merr.Status(err), nil } defer node.lifetime.Done() - return util.SuccessStatus(), nil + return merr.Success(), nil } // ReleasePartitions clears all data related to this partition on the querynode @@ -521,9 +490,7 @@ func (node *QueryNode) ReleasePartitions(ctx context.Context, req *querypb.Relea log.Info("received release partitions request") // check node healthy - if !node.lifetime.Add(commonpbutil.IsHealthy) { - msg := fmt.Sprintf("query node %d is not ready", paramtable.GetNodeID()) - err := merr.WrapErrServiceNotReady(msg) + if err := node.lifetime.Add(merr.IsHealthy); err != nil { return merr.Status(err), nil } defer node.lifetime.Done() @@ -536,7 +503,7 @@ func (node *QueryNode) ReleasePartitions(ctx context.Context, req *querypb.Relea } log.Info("release partitions done") - return util.SuccessStatus(), nil + return merr.Success(), nil } // ReleaseSegments remove the specified segments from query node according segmentIDs, partitionIDs, and collectionID @@ -554,20 +521,14 @@ func (node *QueryNode) ReleaseSegments(ctx context.Context, req *querypb.Release ) // check node healthy - if !node.lifetime.Add(commonpbutil.IsHealthyOrStopping) { - msg := fmt.Sprintf("query node %d is not ready", paramtable.GetNodeID()) - err := 
merr.WrapErrServiceNotReady(msg) + if err := node.lifetime.Add(merr.IsHealthyOrStopping); err != nil { return merr.Status(err), nil } defer node.lifetime.Done() // check target matches - if req.GetBase().GetTargetID() != paramtable.GetNodeID() { - status := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_NodeIDNotMatch, - Reason: common.WrapNodeIDNotMatchMsg(req.GetBase().GetTargetID(), paramtable.GetNodeID()), - } - return status, nil + if err := merr.CheckTargetID(req.GetBase()); err != nil { + return merr.Status(err), nil } if req.GetNeedTransfer() { @@ -575,7 +536,8 @@ func (node *QueryNode) ReleaseSegments(ctx context.Context, req *querypb.Release if !ok { msg := "failed to release segment, delegator not found" log.Warn(msg) - return util.WrapStatus(commonpb.ErrorCode_UnexpectedError, msg), nil + err := merr.WrapErrChannelNotFound(req.GetShard()) + return merr.Status(err), nil } // when we try to release a segment, add it to pipeline's exclude list first @@ -597,10 +559,10 @@ func (node *QueryNode) ReleaseSegments(ctx context.Context, req *querypb.Release err := delegator.ReleaseSegments(ctx, req, false) if err != nil { log.Warn("delegator failed to release segment", zap.Error(err)) - return util.WrapStatus(commonpb.ErrorCode_UnexpectedError, err.Error()), nil + return merr.Status(err), nil } - return util.SuccessStatus(), nil + return merr.Success(), nil } log.Info("start to release segments") @@ -611,14 +573,12 @@ func (node *QueryNode) ReleaseSegments(ctx context.Context, req *querypb.Release } node.manager.Collection.Unref(req.GetCollectionID(), uint32(sealedCount)) - return util.SuccessStatus(), nil + return merr.Success(), nil } // GetSegmentInfo returns segment information of the collection on the queryNode, and the information includes memSize, numRow, indexName, indexID ... 
func (node *QueryNode) GetSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) { - if !node.lifetime.Add(commonpbutil.IsHealthy) { - msg := fmt.Sprintf("query node %d is not ready", paramtable.GetNodeID()) - err := merr.WrapErrServiceNotReady(msg) + if err := node.lifetime.Add(merr.IsHealthy); err != nil { return &querypb.GetSegmentInfoResponse{ Status: merr.Status(err), }, nil @@ -670,7 +630,7 @@ func (node *QueryNode) GetSegmentInfo(ctx context.Context, in *querypb.GetSegmen } return &querypb.GetSegmentInfoResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Infos: segmentInfos, }, nil } @@ -685,16 +645,16 @@ func (node *QueryNode) SearchSegments(ctx context.Context, req *querypb.SearchRe zap.String("scope", req.GetScope().String()), ) - failRet := WrapSearchResult(commonpb.ErrorCode_UnexpectedError, "") - if !node.lifetime.Add(commonpbutil.IsHealthy) { - failRet.Status = merr.Status(merr.WrapErrServiceNotReady(fmt.Sprintf("node id: %d is unhealthy", paramtable.GetNodeID()))) - return failRet, nil + resp := &internalpb.SearchResults{} + if err := node.lifetime.Add(merr.IsHealthy); err != nil { + resp.Status = merr.Status(err) + return resp, nil } defer node.lifetime.Done() metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.TotalLabel, metrics.FromLeader).Inc() defer func() { - if failRet.Status.ErrorCode != commonpb.ErrorCode_Success { + if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.FailLabel, metrics.FromLeader).Inc() } }() @@ -712,22 +672,22 @@ func (node *QueryNode) SearchSegments(ctx context.Context, req *querypb.SearchRe if collection == nil { err := merr.WrapErrCollectionNotLoaded(req.GetReq().GetCollectionID()) log.Warn("failed to search segments", zap.Error(err)) - failRet.Status = merr.Status(err) - return 
failRet, nil + resp.Status = merr.Status(err) + return resp, nil } task := tasks.NewSearchTask(searchCtx, collection, node.manager, req) if err := node.scheduler.Add(task); err != nil { log.Warn("failed to search channel", zap.Error(err)) - failRet.Status.Reason = err.Error() - return failRet, nil + resp.Status = merr.Status(err) + return resp, nil } err := task.Wait() if err != nil { log.Warn("failed to search segments", zap.Error(err)) - failRet.Status.Reason = err.Error() - return failRet, nil + resp.Status = merr.Status(err) + return resp, nil } tr.CtxElapse(ctx, fmt.Sprintf("search segments done, channel = %s, segmentIDs = %v", @@ -735,16 +695,14 @@ func (node *QueryNode) SearchSegments(ctx context.Context, req *querypb.SearchRe req.GetSegmentIDs(), )) - // TODO QueryNodeSQLatencyInQueue QueryNodeReduceLatency - failRet.Status.ErrorCode = commonpb.ErrorCode_Success latency := tr.ElapseSpan() metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds())) metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.SuccessLabel, metrics.FromLeader).Inc() - result := task.Result() - result.GetCostAggregation().ResponseTime = tr.ElapseSpan().Milliseconds() - result.GetCostAggregation().TotalNQ = node.scheduler.GetWaitingTaskTotalNQ() - return result, nil + resp = task.Result() + resp.GetCostAggregation().ResponseTime = tr.ElapseSpan().Milliseconds() + resp.GetCostAggregation().TotalNQ = node.scheduler.GetWaitingTaskTotalNQ() + return resp, nil } // Search performs replica search tasks. 
@@ -766,27 +724,23 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) ( tr := timerecord.NewTimeRecorderWithTrace(ctx, "SearchRequest") - if !node.lifetime.Add(commonpbutil.IsHealthy) { - msg := fmt.Sprintf("query node %d is not ready", paramtable.GetNodeID()) - err := merr.WrapErrServiceNotReady(msg) + if err := node.lifetime.Add(merr.IsHealthy); err != nil { return &internalpb.SearchResults{ Status: merr.Status(err), }, nil } defer node.lifetime.Done() - if !CheckTargetID(req.GetReq()) { - targetID := req.GetReq().GetBase().GetTargetID() - log.Warn("target ID not match", - zap.Int64("targetID", targetID), - zap.Int64("nodeID", paramtable.GetNodeID()), - ) - return WrapSearchResult(commonpb.ErrorCode_NodeIDNotMatch, - common.WrapNodeIDNotMatchMsg(targetID, paramtable.GetNodeID())), nil + err := merr.CheckTargetID(req.GetReq().GetBase()) + if err != nil { + log.Warn("target ID check failed", zap.Error(err)) + return &internalpb.SearchResults{ + Status: merr.Status(err), + }, nil } failRet := &internalpb.SearchResults{ - Status: merr.Status(nil), + Status: merr.Success(), } collection := node.manager.Collection.Get(req.GetReq().GetCollectionID()) if collection == nil { @@ -827,8 +781,7 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) ( mu.Lock() defer mu.Unlock() if err != nil { - failRet.Status.Reason = err.Error() - failRet.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError + failRet.Status = merr.Status(err) return err } if ret.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { @@ -846,8 +799,7 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) ( result, err := segments.ReduceSearchResults(ctx, toReduceResults, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType()) if err != nil { log.Warn("failed to reduce search results", zap.Error(err)) - failRet.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError - failRet.Status.Reason = err.Error() + 
failRet.Status = merr.Status(err) return failRet, nil } reduceLatency := tr.RecordSpan() @@ -867,7 +819,9 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) ( // only used for delegator query segments from worker func (node *QueryNode) QuerySegments(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) { - failRet := WrapRetrieveResult(commonpb.ErrorCode_UnexpectedError, "") + resp := &internalpb.RetrieveResults{ + Status: merr.Success(), + } msgID := req.Req.Base.GetMsgID() traceID := trace.SpanFromContext(ctx).SpanContext().TraceID() channel := req.GetDmlChannels()[0] @@ -878,16 +832,15 @@ func (node *QueryNode) QuerySegments(ctx context.Context, req *querypb.QueryRequ zap.String("scope", req.GetScope().String()), ) - if !node.lifetime.Add(commonpbutil.IsHealthy) { - err := merr.WrapErrServiceUnavailable(fmt.Sprintf("node id: %d is unhealthy", paramtable.GetNodeID())) - failRet.Status = merr.Status(err) - return failRet, nil + if err := node.lifetime.Add(merr.IsHealthy); err != nil { + resp.Status = merr.Status(err) + return resp, nil } defer node.lifetime.Done() metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.TotalLabel, metrics.FromLeader).Inc() defer func() { - if failRet.Status.ErrorCode != commonpb.ErrorCode_Success { + if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.FailLabel, metrics.FromLeader).Inc() } }() @@ -903,22 +856,22 @@ func (node *QueryNode) QuerySegments(ctx context.Context, req *querypb.QueryRequ tr := timerecord.NewTimeRecorder("querySegments") collection := node.manager.Collection.Get(req.Req.GetCollectionID()) if collection == nil { - failRet.Status = merr.Status(merr.WrapErrCollectionNotLoaded(req.Req.GetCollectionID())) - return failRet, nil + resp.Status = 
merr.Status(merr.WrapErrCollectionNotLoaded(req.Req.GetCollectionID())) + return resp, nil } // Send task to scheduler and wait until it finished. task := tasks.NewQueryTask(queryCtx, collection, node.manager, req) if err := node.scheduler.Add(task); err != nil { log.Warn("failed to add query task into scheduler", zap.Error(err)) - failRet.Status = merr.Status(err) - return failRet, nil + resp.Status = merr.Status(err) + return resp, nil } err := task.Wait() if err != nil { log.Warn("failed to query channel", zap.Error(err)) - failRet.Status = merr.Status(err) - return failRet, nil + resp.Status = merr.Status(err) + return resp, nil } tr.CtxElapse(ctx, fmt.Sprintf("do query done, traceID = %s, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v", @@ -928,7 +881,6 @@ func (node *QueryNode) QuerySegments(ctx context.Context, req *querypb.QueryRequ req.GetSegmentIDs(), )) - failRet.Status.ErrorCode = commonpb.ErrorCode_Success // TODO QueryNodeSQLatencyInQueue QueryNodeReduceLatency latency := tr.ElapseSpan() metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds())) @@ -960,23 +912,19 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i ) tr := timerecord.NewTimeRecorderWithTrace(ctx, "QueryRequest") - if !node.lifetime.Add(commonpbutil.IsHealthy) { - msg := fmt.Sprintf("query node %d is not ready", paramtable.GetNodeID()) - err := merr.WrapErrServiceNotReady(msg) + if err := node.lifetime.Add(merr.IsHealthy); err != nil { return &internalpb.RetrieveResults{ Status: merr.Status(err), }, nil } defer node.lifetime.Done() - if !CheckTargetID(req.GetReq()) { - targetID := req.GetReq().GetBase().GetTargetID() - log.Warn("target ID not match", - zap.Int64("targetID", targetID), - zap.Int64("nodeID", paramtable.GetNodeID()), - ) - return WrapRetrieveResult(commonpb.ErrorCode_NodeIDNotMatch, - common.WrapNodeIDNotMatchMsg(targetID, 
paramtable.GetNodeID())), nil + err := merr.CheckTargetID(req.GetReq().GetBase()) + if err != nil { + log.Warn("target ID check failed", zap.Error(err)) + return &internalpb.RetrieveResults{ + Status: merr.Status(err), + }, nil } toMergeResults := make([]*internalpb.RetrieveResults, len(req.GetDmlChannels())) @@ -995,25 +943,29 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i idx := i runningGp.Go(func() error { ret, err := node.queryChannel(runningCtx, req, ch) + if err == nil { + err = merr.Error(ret.GetStatus()) + } if err != nil { return err } - if ret.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { - return fmt.Errorf("%s", ret.Status.Reason) - } toMergeResults[idx] = ret return nil }) } if err := runningGp.Wait(); err != nil { - return WrapRetrieveResult(commonpb.ErrorCode_UnexpectedError, "failed to query channel", err), nil + return &internalpb.RetrieveResults{ + Status: merr.Status(err), + }, nil } tr.RecordSpan() reducer := segments.CreateInternalReducer(req, node.manager.Collection.Get(req.GetReq().GetCollectionID()).Schema()) ret, err := reducer.Reduce(ctx, toMergeResults) if err != nil { - return WrapRetrieveResult(commonpb.ErrorCode_UnexpectedError, "failed to query channel", err), nil + return &internalpb.RetrieveResults{ + Status: merr.Status(err), + }, nil } reduceLatency := tr.RecordSpan() metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.ReduceShards). 
@@ -1030,15 +982,132 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i return ret, nil } +func (node *QueryNode) QueryStream(req *querypb.QueryRequest, srv querypb.QueryNode_QueryStreamServer) error { + ctx := srv.Context() + log := log.Ctx(ctx).With( + zap.Int64("collectionID", req.GetReq().GetCollectionID()), + zap.Strings("shards", req.GetDmlChannels()), + ) + concurrentSrv := streamrpc.NewConcurrentQueryStreamServer(srv) + + log.Debug("received query stream request", + zap.Int64s("outputFields", req.GetReq().GetOutputFieldsId()), + zap.Int64s("segmentIDs", req.GetSegmentIDs()), + zap.Uint64("guaranteeTimestamp", req.GetReq().GetGuaranteeTimestamp()), + zap.Uint64("mvccTimestamp", req.GetReq().GetMvccTimestamp()), + zap.Bool("isCount", req.GetReq().GetIsCount()), + ) + + if err := node.lifetime.Add(merr.IsHealthy); err != nil { + concurrentSrv.Send(&internalpb.RetrieveResults{Status: merr.Status(err)}) + return nil + } + defer node.lifetime.Done() + + err := merr.CheckTargetID(req.GetReq().GetBase()) + if err != nil { + log.Warn("target ID check failed", zap.Error(err)) + return err + } + + runningGp, runningCtx := errgroup.WithContext(ctx) + + for _, ch := range req.GetDmlChannels() { + ch := ch + req := &querypb.QueryRequest{ + Req: req.Req, + DmlChannels: []string{ch}, + SegmentIDs: req.SegmentIDs, + FromShardLeader: req.FromShardLeader, + Scope: req.Scope, + } + + runningGp.Go(func() error { + err := node.queryChannelStream(runningCtx, req, ch, concurrentSrv) + if err != nil { + return err + } + return nil + }) + } + + if err := runningGp.Wait(); err != nil { + concurrentSrv.Send(&internalpb.RetrieveResults{ + Status: merr.Status(err), + }) + return nil + } + + collector.Rate.Add(metricsinfo.NQPerSecond, 1) + metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.QueryLabel).Add(float64(proto.Size(req))) + return nil +} + +func (node *QueryNode) QueryStreamSegments(req 
*querypb.QueryRequest, srv querypb.QueryNode_QueryStreamSegmentsServer) error { + ctx := srv.Context() + msgID := req.Req.Base.GetMsgID() + traceID := trace.SpanFromContext(ctx).SpanContext().TraceID() + channel := req.GetDmlChannels()[0] + concurrentSrv := streamrpc.NewConcurrentQueryStreamServer(srv) + + log := log.Ctx(ctx).With( + zap.Int64("msgID", msgID), + zap.Int64("collectionID", req.GetReq().GetCollectionID()), + zap.String("channel", channel), + zap.String("scope", req.GetScope().String()), + ) + + resp := &internalpb.RetrieveResults{} + metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.TotalLabel, metrics.FromLeader).Inc() + defer func() { + if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.FailLabel, metrics.FromLeader).Inc() + } + }() + + if err := node.lifetime.Add(merr.IsHealthy); err != nil { + resp.Status = merr.Status(err) + concurrentSrv.Send(resp) + return nil + } + defer node.lifetime.Done() + + log.Debug("start do query with channel", + zap.Bool("fromShardLeader", req.GetFromShardLeader()), + zap.Int64s("segmentIDs", req.GetSegmentIDs()), + ) + + tr := timerecord.NewTimeRecorder("queryChannel") + + err := node.queryStreamSegments(ctx, req, concurrentSrv) + if err != nil { + resp.Status = merr.Status(err) + concurrentSrv.Send(resp) + return nil + } + + tr.CtxElapse(ctx, fmt.Sprintf("do query done, traceID = %s, fromSharedLeader = %t, vChannel = %s, segmentIDs = %v", + traceID, + req.GetFromShardLeader(), + channel, + req.GetSegmentIDs(), + )) + + // TODO QueryNodeSQLatencyInQueue QueryNodeReduceLatency + latency := tr.ElapseSpan() + metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds())) + 
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.SuccessLabel, metrics.FromLeader).Inc() + return nil +} + // SyncReplicaSegments syncs replica node & segments states func (node *QueryNode) SyncReplicaSegments(ctx context.Context, req *querypb.SyncReplicaSegmentsRequest) (*commonpb.Status, error) { - return util.SuccessStatus(), nil + return merr.Success(), nil } // ShowConfigurations returns the configurations of queryNode matching req.Pattern func (node *QueryNode) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { - if !node.lifetime.Add(commonpbutil.IsHealthy) { - err := merr.WrapErrServiceNotReady(fmt.Sprintf("node id: %d is unhealthy", paramtable.GetNodeID())) + if err := node.lifetime.Add(merr.IsHealthy); err != nil { log.Warn("QueryNode.ShowConfigurations failed", zap.Int64("nodeId", paramtable.GetNodeID()), zap.String("req", req.Pattern), @@ -1061,15 +1130,14 @@ func (node *QueryNode) ShowConfigurations(ctx context.Context, req *internalpb.S } return &internalpb.ShowConfigurationsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Configuations: configList, }, nil } // GetMetrics return system infos of the query node, such as total memory, memory usage, cpu usage ... 
func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - if !node.lifetime.Add(commonpbutil.IsHealthy) { - err := merr.WrapErrServiceNotReady(fmt.Sprintf("node id: %d is unhealthy", paramtable.GetNodeID())) + if err := node.lifetime.Add(merr.IsHealthy); err != nil { log.Warn("QueryNode.GetMetrics failed", zap.Int64("nodeId", paramtable.GetNodeID()), zap.String("req", req.Request), @@ -1090,10 +1158,7 @@ func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsR zap.Error(err)) return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } @@ -1106,10 +1171,7 @@ func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsR zap.String("metricType", metricType), zap.Error(err)) return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), }, nil } log.RatedDebug(50, "QueryNode.GetMetrics", @@ -1127,11 +1189,7 @@ func (node *QueryNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsR zap.String("metricType", metricType)) return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: metricsinfo.MsgUnimplementedMetric, - }, - Response: "", + Status: merr.Status(merr.WrapErrMetricNotFound(metricType)), }, nil } @@ -1140,9 +1198,8 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get zap.Int64("msgID", req.GetBase().GetMsgID()), zap.Int64("nodeID", paramtable.GetNodeID()), ) - if !node.lifetime.Add(commonpbutil.IsHealthy) { - err := merr.WrapErrServiceNotReady(fmt.Sprintf("node id: %d is unhealthy", paramtable.GetNodeID())) - log.Warn("QueryNode.GetMetrics failed", + if err := node.lifetime.Add(merr.IsHealthy); err != nil { + 
log.Warn("QueryNode.GetDataDistribution failed", zap.Error(err)) return &querypb.GetDataDistributionResponse{ @@ -1152,12 +1209,10 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get defer node.lifetime.Done() // check target matches - if req.GetBase().GetTargetID() != paramtable.GetNodeID() { - status := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_NodeIDNotMatch, - Reason: common.WrapNodeIDNotMatchMsg(req.GetBase().GetTargetID(), paramtable.GetNodeID()), - } - return &querypb.GetDataDistributionResponse{Status: status}, nil + if err := merr.CheckTargetID(req.GetBase()); err != nil { + return &querypb.GetDataDistributionResponse{ + Status: merr.Status(err), + }, nil } sealedSegments := node.manager.Segment.GetBy(segments.WithType(commonpb.SegmentState_Sealed)) @@ -1179,17 +1234,17 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get channelVersionInfos := make([]*querypb.ChannelVersionInfo, 0) leaderViews := make([]*querypb.LeaderView, 0) - node.delegators.Range(func(key string, value delegator.ShardDelegator) bool { - if !value.Serviceable() { + node.delegators.Range(func(key string, delegator delegator.ShardDelegator) bool { + if !delegator.Serviceable() { return true } channelVersionInfos = append(channelVersionInfos, &querypb.ChannelVersionInfo{ Channel: key, - Collection: value.Collection(), - Version: value.Version(), + Collection: delegator.Collection(), + Version: delegator.Version(), }) - sealed, growing := value.GetSegmentInfo(false) + sealed, growing := delegator.GetSegmentInfo(false) sealedSegments := make(map[int64]*querypb.SegmentDist) for _, item := range sealed { for _, segment := range item.Segments { @@ -1212,17 +1267,17 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get } leaderViews = append(leaderViews, &querypb.LeaderView{ - Collection: value.Collection(), + Collection: delegator.Collection(), Channel: key, SegmentDist: sealedSegments, 
GrowingSegments: growingSegments, - TargetVersion: value.GetTargetVersion(), + TargetVersion: delegator.GetTargetVersion(), }) return true }) return &querypb.GetDataDistributionResponse{ - Status: merr.Status(nil), + Status: merr.Success(), NodeID: paramtable.GetNodeID(), Segments: segmentVersionInfos, Channels: channelVersionInfos, @@ -1234,31 +1289,22 @@ func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDi log := log.Ctx(ctx).With(zap.Int64("collectionID", req.GetCollectionID()), zap.String("channel", req.GetChannel()), zap.Int64("currentNodeID", paramtable.GetNodeID())) // check node healthy - if !node.lifetime.Add(commonpbutil.IsHealthy) { - msg := fmt.Sprintf("query node %d is not ready", paramtable.GetNodeID()) - err := merr.WrapErrServiceNotReady(msg) + if err := node.lifetime.Add(merr.IsHealthy); err != nil { return merr.Status(err), nil } defer node.lifetime.Done() // check target matches - if req.GetBase().GetTargetID() != paramtable.GetNodeID() { - log.Warn("failed to do match target id when sync ", zap.Int64("expect", req.GetBase().GetTargetID()), zap.Int64("actual", node.session.ServerID)) - status := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_NodeIDNotMatch, - Reason: common.WrapNodeIDNotMatchMsg(req.GetBase().GetTargetID(), paramtable.GetNodeID()), - } - return status, nil + if err := merr.CheckTargetID(req.GetBase()); err != nil { + return merr.Status(err), nil } // get shard delegator shardDelegator, ok := node.delegators.Get(req.GetChannel()) if !ok { + err := merr.WrapErrChannelNotFound(req.GetChannel()) log.Warn("failed to find shard cluster when sync") - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "shard not exist", - }, nil + return merr.Status(err), nil } // translate segment action @@ -1266,17 +1312,20 @@ func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDi addSegments := make(map[int64][]*querypb.SegmentLoadInfo) for _, action := range 
req.GetActions() { log := log.With(zap.String("Action", - action.GetType().String()), - zap.Int64("segmentID", action.SegmentID), - zap.Int64("TargetVersion", action.GetTargetVersion()), - ) - log.Info("sync action") + action.GetType().String())) switch action.GetType() { case querypb.SyncType_Remove: + log.Info("sync action", zap.Int64("segmentID", action.SegmentID)) removeActions = append(removeActions, action) case querypb.SyncType_Set: + log.Info("sync action", zap.Int64("segmentID", action.SegmentID)) + if action.GetInfo() == nil { + log.Warn("sync request from legacy querycoord without load info, skip") + continue + } addSegments[action.GetNodeID()] = append(addSegments[action.GetNodeID()], action.GetInfo()) case querypb.SyncType_UpdateVersion: + log.Info("sync action", zap.Int64("TargetVersion", action.GetTargetVersion())) pipeline := node.pipelineManager.Get(req.GetChannel()) if pipeline != nil { droppedInfos := lo.Map(action.GetDroppedInTarget(), func(id int64, _ int) *datapb.SegmentInfo { @@ -1292,10 +1341,7 @@ func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDi shardDelegator.SyncTargetVersion(action.GetTargetVersion(), action.GetGrowingInTarget(), action.GetSealedInTarget(), action.GetDroppedInTarget()) default: - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "unexpected action type", - }, nil + return merr.Status(merr.WrapErrServiceInternal("unknown action type", action.GetType().String())), nil } } @@ -1316,7 +1362,7 @@ func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDi LoadScope: querypb.LoadScope_Delta, }) if err != nil { - return util.WrapStatus(commonpb.ErrorCode_UnexpectedError, "failed to sync(load) segment", err), nil + return merr.Status(err), nil } } @@ -1329,7 +1375,7 @@ func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDi }, true) } - return merr.Status(nil), nil + return merr.Success(), nil } // Delete is used to 
forward delete message between delegator and workers. @@ -1341,20 +1387,14 @@ func (node *QueryNode) Delete(ctx context.Context, req *querypb.DeleteRequest) ( ) // check node healthy - if !node.lifetime.Add(commonpbutil.IsHealthy) { - msg := fmt.Sprintf("query node %d is not ready", paramtable.GetNodeID()) - err := merr.WrapErrServiceNotReady(msg) + if err := node.lifetime.Add(merr.IsHealthy); err != nil { return merr.Status(err), nil } defer node.lifetime.Done() // check target matches - if req.GetBase().GetTargetID() != paramtable.GetNodeID() { - status := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_NodeIDNotMatch, - Reason: common.WrapNodeIDNotMatchMsg(req.GetBase().GetTargetID(), paramtable.GetNodeID()), - } - return status, nil + if err := merr.CheckTargetID(req.GetBase()); err != nil { + return merr.Status(err), nil } log.Info("QueryNode received worker delete request") @@ -1367,11 +1407,7 @@ func (node *QueryNode) Delete(ctx context.Context, req *querypb.DeleteRequest) ( if len(segments) == 0 { err := merr.WrapErrSegmentNotFound(req.GetSegmentId()) log.Warn("segment not found for delete") - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_SegmentNotFound, - Reason: fmt.Sprintf("segment %d not found", req.GetSegmentId()), - Code: merr.Code(err), - }, nil + return merr.Status(err), nil } pks := storage.ParseIDs2PrimaryKeys(req.GetPrimaryKeys()) @@ -1379,12 +1415,9 @@ func (node *QueryNode) Delete(ctx context.Context, req *querypb.DeleteRequest) ( err := segment.Delete(pks, req.GetTimestamps()) if err != nil { log.Warn("segment delete failed", zap.Error(err)) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: fmt.Sprintf("delete on segment %d failed, %s", req.GetSegmentId(), err.Error()), - }, nil + return merr.Status(err), nil } } - return merr.Status(nil), nil + return merr.Success(), nil } diff --git a/internal/querynodev2/services_test.go b/internal/querynodev2/services_test.go index 
372d01076fd8c..c5eb1895cd017 100644 --- a/internal/querynodev2/services_test.go +++ b/internal/querynodev2/services_test.go @@ -18,7 +18,9 @@ package querynodev2 import ( "context" "encoding/json" + "io" "math/rand" + "sync" "testing" "time" @@ -41,6 +43,7 @@ import ( "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/dependency" + "github.com/milvus-io/milvus/internal/util/streamrpc" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/conc" "github.com/milvus-io/milvus/pkg/util/etcd" @@ -110,7 +113,7 @@ func (suite *ServiceSuite) SetupTest() { suite.factory = dependency.NewMockFactory(suite.T()) suite.msgStream = msgstream.NewMockMsgStream(suite.T()) // TODO:: cpp chunk manager not support local chunk manager - //suite.chunkManagerFactory = storage.NewChunkManagerFactory("local", storage.RootPath("/tmp/milvus-test")) + // suite.chunkManagerFactory = storage.NewChunkManagerFactory("local", storage.RootPath("/tmp/milvus-test")) suite.chunkManagerFactory = segments.NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath) suite.factory.EXPECT().Init(mock.Anything).Return() suite.factory.EXPECT().NewPersistentStorageChunkManager(mock.Anything).Return(suite.chunkManagerFactory.NewPersistentStorageChunkManager(ctx)) @@ -163,31 +166,31 @@ func (suite *ServiceSuite) TearDownTest() { func (suite *ServiceSuite) TestGetComponentStatesNormal() { ctx := context.Background() suite.node.session.UpdateRegistered(true) - rsp, err := suite.node.GetComponentStates(ctx) + rsp, err := suite.node.GetComponentStates(ctx, nil) suite.NoError(err) - suite.Equal(commonpb.ErrorCode_Success, rsp.Status.ErrorCode) + suite.Equal(commonpb.ErrorCode_Success, rsp.GetStatus().GetErrorCode()) suite.Equal(commonpb.StateCode_Healthy, rsp.State.StateCode) // after update suite.node.UpdateStateCode(commonpb.StateCode_Abnormal) - rsp, err = 
suite.node.GetComponentStates(ctx) + rsp, err = suite.node.GetComponentStates(ctx, nil) suite.NoError(err) - suite.Equal(commonpb.ErrorCode_Success, rsp.Status.ErrorCode) + suite.Equal(commonpb.ErrorCode_Success, rsp.GetStatus().GetErrorCode()) suite.Equal(commonpb.StateCode_Abnormal, rsp.State.StateCode) } func (suite *ServiceSuite) TestGetTimeTiclChannel_Normal() { ctx := context.Background() - rsp, err := suite.node.GetTimeTickChannel(ctx) + rsp, err := suite.node.GetTimeTickChannel(ctx, nil) suite.NoError(err) - suite.Equal(commonpb.ErrorCode_Success, rsp.Status.ErrorCode) + suite.Equal(commonpb.ErrorCode_Success, rsp.GetStatus().GetErrorCode()) } func (suite *ServiceSuite) TestGetStatisChannel_Normal() { ctx := context.Background() - rsp, err := suite.node.GetStatisticsChannel(ctx) + rsp, err := suite.node.GetStatisticsChannel(ctx, nil) suite.NoError(err) - suite.Equal(commonpb.ErrorCode_Success, rsp.Status.ErrorCode) + suite.Equal(commonpb.ErrorCode_Success, rsp.GetStatus().GetErrorCode()) } func (suite *ServiceSuite) TestGetStatistics_Normal() { @@ -211,7 +214,7 @@ func (suite *ServiceSuite) TestGetStatistics_Normal() { rsp, err := suite.node.GetStatistics(ctx, req) suite.NoError(err) - suite.Equal(commonpb.ErrorCode_Success, rsp.Status.ErrorCode) + suite.Equal(commonpb.ErrorCode_Success, rsp.GetStatus().GetErrorCode()) } func (suite *ServiceSuite) TestGetStatistics_Failed() { @@ -280,8 +283,8 @@ func (suite *ServiceSuite) TestWatchDmChannelsInt64() { // mocks suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil) - suite.msgStream.EXPECT().AsConsumer([]string{suite.pchannel}, mock.Anything, mock.Anything).Return() - suite.msgStream.EXPECT().Seek(mock.Anything).Return(nil) + suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil) + suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything).Return(nil) suite.msgStream.EXPECT().Chan().Return(suite.msgChan) 
suite.msgStream.EXPECT().Close() @@ -329,8 +332,8 @@ func (suite *ServiceSuite) TestWatchDmChannelsVarchar() { // mocks suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil) - suite.msgStream.EXPECT().AsConsumer([]string{suite.pchannel}, mock.Anything, mock.Anything).Return() - suite.msgStream.EXPECT().Seek(mock.Anything).Return(nil) + suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil) + suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything).Return(nil) suite.msgStream.EXPECT().Chan().Return(suite.msgChan) suite.msgStream.EXPECT().Close() @@ -382,9 +385,9 @@ func (suite *ServiceSuite) TestWatchDmChannels_Failed() { // init msgstream failed suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil) - suite.msgStream.EXPECT().AsConsumer([]string{suite.pchannel}, mock.Anything, mock.Anything).Return() + suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil) suite.msgStream.EXPECT().Close().Return() - suite.msgStream.EXPECT().Seek(mock.Anything).Return(errors.New("mock error")) + suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything).Return(errors.New("mock error")) status, err = suite.node.WatchDmChannels(ctx, req) suite.NoError(err) @@ -516,23 +519,26 @@ func (suite *ServiceSuite) TestLoadSegments_Int64() { suite.TestWatchDmChannelsInt64() // data schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64) - req := &querypb.LoadSegmentsRequest{ - Base: &commonpb.MsgBase{ - MsgID: rand.Int63(), - TargetID: suite.node.session.ServerID, - }, - CollectionID: suite.collectionID, - DstNodeID: suite.node.session.ServerID, - Infos: suite.genSegmentLoadInfos(schema), - Schema: schema, - DeltaPositions: []*msgpb.MsgPosition{{Timestamp: 20000}}, - NeedTransfer: true, - } + infos := suite.genSegmentLoadInfos(schema) + for _, info := range infos { + req := 
&querypb.LoadSegmentsRequest{ + Base: &commonpb.MsgBase{ + MsgID: rand.Int63(), + TargetID: suite.node.session.ServerID, + }, + CollectionID: suite.collectionID, + DstNodeID: suite.node.session.ServerID, + Infos: []*querypb.SegmentLoadInfo{info}, + Schema: schema, + DeltaPositions: []*msgpb.MsgPosition{{Timestamp: 20000}}, + NeedTransfer: true, + } - // LoadSegment - status, err := suite.node.LoadSegments(ctx, req) - suite.NoError(err) - suite.Equal(commonpb.ErrorCode_Success, status.GetErrorCode()) + // LoadSegment + status, err := suite.node.LoadSegments(ctx, req) + suite.NoError(err) + suite.Equal(commonpb.ErrorCode_Success, status.GetErrorCode()) + } } func (suite *ServiceSuite) TestLoadSegments_VarChar() { @@ -547,24 +553,28 @@ func (suite *ServiceSuite) TestLoadSegments_VarChar() { } suite.node.manager.Collection = segments.NewCollectionManager() suite.node.manager.Collection.PutOrRef(suite.collectionID, schema, nil, loadMeta) - req := &querypb.LoadSegmentsRequest{ - Base: &commonpb.MsgBase{ - MsgID: rand.Int63(), - TargetID: suite.node.session.ServerID, - }, - CollectionID: suite.collectionID, - DstNodeID: suite.node.session.ServerID, - Infos: suite.genSegmentLoadInfos(schema), - Schema: schema, - DeltaPositions: []*msgpb.MsgPosition{{Timestamp: 20000}}, - NeedTransfer: true, - LoadMeta: loadMeta, - } - // LoadSegment - status, err := suite.node.LoadSegments(ctx, req) - suite.NoError(err) - suite.Equal(commonpb.ErrorCode_Success, status.GetErrorCode()) + infos := suite.genSegmentLoadInfos(schema) + for _, info := range infos { + req := &querypb.LoadSegmentsRequest{ + Base: &commonpb.MsgBase{ + MsgID: rand.Int63(), + TargetID: suite.node.session.ServerID, + }, + CollectionID: suite.collectionID, + DstNodeID: suite.node.session.ServerID, + Infos: []*querypb.SegmentLoadInfo{info}, + Schema: schema, + DeltaPositions: []*msgpb.MsgPosition{{Timestamp: 20000}}, + NeedTransfer: true, + LoadMeta: loadMeta, + } + + // LoadSegment + status, err := 
suite.node.LoadSegments(ctx, req) + suite.NoError(err) + suite.Equal(commonpb.ErrorCode_Success, status.GetErrorCode()) + } } func (suite *ServiceSuite) TestLoadDeltaInt64() { @@ -770,20 +780,19 @@ func (suite *ServiceSuite) TestLoadSegments_Failed() { // Delegator not found status, err := suite.node.LoadSegments(ctx, req) suite.NoError(err) - suite.Equal(commonpb.ErrorCode_UnexpectedError, status.GetErrorCode()) - suite.Contains(status.GetReason(), "failed to load segments, delegator not found") + suite.ErrorIs(merr.Error(status), merr.ErrChannelNotFound) // target not match req.Base.TargetID = -1 status, err = suite.node.LoadSegments(ctx, req) suite.NoError(err) - suite.Equal(commonpb.ErrorCode_NodeIDNotMatch, status.GetErrorCode()) + suite.ErrorIs(merr.Error(status), merr.ErrNodeNotMatch) // node not healthy suite.node.UpdateStateCode(commonpb.StateCode_Abnormal) status, err = suite.node.LoadSegments(ctx, req) suite.NoError(err) - suite.Equal(commonpb.ErrorCode_NotReadyServe, status.GetErrorCode()) + suite.ErrorIs(merr.Error(status), merr.ErrServiceNotReady) } func (suite *ServiceSuite) TestLoadSegments_Transfer() { @@ -1179,8 +1188,7 @@ func (suite *ServiceSuite) TestSearch_Failed() { // Delegator not found resp, err = suite.node.Search(ctx, req) suite.NoError(err) - suite.Equal(commonpb.ErrorCode_UnexpectedError, resp.GetStatus().GetErrorCode()) - suite.Contains(resp.GetStatus().GetReason(), merr.ErrServiceUnavailable.Error()) + suite.ErrorIs(merr.Error(resp.GetStatus()), merr.ErrChannelNotFound) suite.TestWatchDmChannelsInt64() suite.TestLoadSegments_Int64() @@ -1309,7 +1317,8 @@ func (suite *ServiceSuite) TestQuery_Normal() { } func (suite *ServiceSuite) TestQuery_Failed() { - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() // data schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64) @@ -1324,8 +1333,7 @@ func (suite *ServiceSuite) TestQuery_Failed() { // Delegator 
not found resp, err := suite.node.Query(ctx, req) suite.NoError(err) - suite.Equal(commonpb.ErrorCode_UnexpectedError, resp.GetStatus().GetErrorCode()) - suite.Contains(resp.GetStatus().GetReason(), merr.ErrServiceUnavailable.Error()) + suite.ErrorIs(merr.Error(resp.GetStatus()), merr.ErrChannelNotFound) suite.TestWatchDmChannelsInt64() suite.TestLoadSegments_Int64() @@ -1372,8 +1380,145 @@ func (suite *ServiceSuite) TestQuerySegments_Failed() { suite.Equal(commonpb.ErrorCode_UnexpectedError, rsp.GetStatus().GetErrorCode()) } +func (suite *ServiceSuite) TestQueryStream_Normal() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // prepare + suite.TestWatchDmChannelsInt64() + suite.TestLoadSegments_Int64() + + // data + schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64) + creq, err := suite.genCQueryRequest(10, IndexFaissIDMap, schema) + suite.NoError(err) + req := &querypb.QueryRequest{ + Req: creq, + FromShardLeader: false, + DmlChannels: []string{suite.vchannel}, + } + + client := streamrpc.NewLocalQueryClient(ctx) + server := client.CreateServer() + + go func() { + err := suite.node.QueryStream(req, server) + suite.NoError(err) + server.FinishSend(err) + }() + + for { + result, err := client.Recv() + if err == io.EOF { + break + } + suite.NoError(err) + + err = merr.Error(result.GetStatus()) + suite.NoError(err) + } +} + +func (suite *ServiceSuite) TestQueryStream_Failed() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // data + schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64) + creq, err := suite.genCQueryRequest(10, IndexFaissIDMap, schema) + suite.NoError(err) + req := &querypb.QueryRequest{ + Req: creq, + FromShardLeader: false, + DmlChannels: []string{suite.vchannel}, + } + + queryFunc := func(wg *sync.WaitGroup, req *querypb.QueryRequest, client *streamrpc.LocalQueryClient) { + server := client.CreateServer() + + 
defer wg.Done() + err := suite.node.QueryStream(req, server) + suite.NoError(err) + server.FinishSend(err) + } + + // Delegator not found + suite.Run("delegator not found", func() { + client := streamrpc.NewLocalQueryClient(ctx) + wg := &sync.WaitGroup{} + wg.Add(1) + go queryFunc(wg, req, client) + + for { + result, err := client.Recv() + if err == io.EOF { + break + } + suite.NoError(err) + + err = merr.Error(result.GetStatus()) + // Check result + if err != nil { + suite.ErrorIs(err, merr.ErrChannelNotFound) + } + } + wg.Wait() + }) + + // prepare + suite.TestWatchDmChannelsInt64() + suite.TestLoadSegments_Int64() + + // target not match + suite.Run("target not match", func() { + client := streamrpc.NewLocalQueryClient(ctx) + wg := &sync.WaitGroup{} + wg.Add(1) + go queryFunc(wg, req, client) + + for { + result, err := client.Recv() + if err == io.EOF { + break + } + suite.NoError(err) + + err = merr.Error(result.GetStatus()) + if err != nil { + suite.Equal(commonpb.ErrorCode_NodeIDNotMatch, result.GetStatus().GetErrorCode()) + } + } + wg.Wait() + }) + + // node not healthy + suite.Run("node not healthy", func() { + suite.node.UpdateStateCode(commonpb.StateCode_Abnormal) + client := streamrpc.NewLocalQueryClient(ctx) + wg := &sync.WaitGroup{} + wg.Add(1) + go queryFunc(wg, req, client) + + for { + result, err := client.Recv() + if err == io.EOF { + break + } + suite.NoError(err) + + err = merr.Error(result.GetStatus()) + if err != nil { + suite.True(errors.Is(err, merr.ErrServiceNotReady)) + } + } + wg.Wait() + }) +} + func (suite *ServiceSuite) TestQuerySegments_Normal() { - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // pre suite.TestWatchDmChannelsInt64() suite.TestLoadSegments_Int64() @@ -1393,13 +1538,55 @@ func (suite *ServiceSuite) TestQuerySegments_Normal() { suite.Equal(commonpb.ErrorCode_Success, rsp.GetStatus().GetErrorCode()) } +func (suite *ServiceSuite) TestQueryStreamSegments_Normal() { 
+ ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // pre + suite.TestWatchDmChannelsInt64() + suite.TestLoadSegments_Int64() + + // data + schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64) + creq, err := suite.genCQueryRequest(10, IndexFaissIDMap, schema) + suite.NoError(err) + req := &querypb.QueryRequest{ + Req: creq, + FromShardLeader: true, + DmlChannels: []string{suite.vchannel}, + } + + client := streamrpc.NewLocalQueryClient(ctx) + server := client.CreateServer() + + go func() { + err := suite.node.QueryStreamSegments(req, server) + suite.NoError(err) + server.FinishSend(err) + }() + + for { + result, err := client.Recv() + if err == io.EOF { + break + } + suite.NoError(err) + + err = merr.Error(result.GetStatus()) + suite.NoError(err) + // Check result + if !errors.Is(err, nil) { + suite.NoError(err) + break + } + } +} + func (suite *ServiceSuite) TestSyncReplicaSegments_Normal() { ctx := context.Background() req := &querypb.SyncReplicaSegmentsRequest{} status, err := suite.node.SyncReplicaSegments(ctx, req) suite.NoError(err) suite.Equal(commonpb.ErrorCode_Success, status.ErrorCode) - } func (suite *ServiceSuite) TestShowConfigurations_Normal() { @@ -1414,7 +1601,7 @@ func (suite *ServiceSuite) TestShowConfigurations_Normal() { resp, err := suite.node.ShowConfigurations(ctx, req) suite.NoError(err) - suite.Equal(commonpb.ErrorCode_Success, resp.Status.ErrorCode) + suite.Equal(commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) suite.Equal(1, len(resp.Configuations)) } @@ -1452,7 +1639,7 @@ func (suite *ServiceSuite) TestGetMetric_Normal() { resp, err := suite.node.GetMetrics(ctx, req) suite.NoError(err) - suite.Equal(commonpb.ErrorCode_Success, resp.Status.ErrorCode) + suite.Equal(commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) } func (suite *ServiceSuite) TestGetMetric_Failed() { @@ -1473,20 +1660,20 @@ func (suite *ServiceSuite) TestGetMetric_Failed() { resp, err := 
suite.node.GetMetrics(ctx, req) suite.NoError(err) - suite.Equal(commonpb.ErrorCode_UnexpectedError, resp.Status.ErrorCode) - suite.Equal(metricsinfo.MsgUnimplementedMetric, resp.Status.Reason) + err = merr.Error(resp.GetStatus()) + suite.ErrorIs(err, merr.ErrMetricNotFound) // metric parse failed req.Request = "---" resp, err = suite.node.GetMetrics(ctx, req) suite.NoError(err) - suite.Equal(commonpb.ErrorCode_UnexpectedError, resp.Status.ErrorCode) + suite.Equal(commonpb.ErrorCode_UnexpectedError, resp.GetStatus().GetErrorCode()) // node unhealthy suite.node.UpdateStateCode(commonpb.StateCode_Abnormal) resp, err = suite.node.GetMetrics(ctx, req) suite.NoError(err) - suite.Equal(commonpb.ErrorCode_NotReadyServe, resp.Status.ErrorCode) + suite.Equal(commonpb.ErrorCode_NotReadyServe, resp.GetStatus().GetErrorCode()) } func (suite *ServiceSuite) TestGetDataDistribution_Normal() { @@ -1503,7 +1690,7 @@ func (suite *ServiceSuite) TestGetDataDistribution_Normal() { resp, err := suite.node.GetDataDistribution(ctx, req) suite.NoError(err) - suite.Equal(commonpb.ErrorCode_Success, resp.Status.ErrorCode) + suite.Equal(commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) } func (suite *ServiceSuite) TestGetDataDistribution_Failed() { @@ -1560,7 +1747,7 @@ func (suite *ServiceSuite) TestSyncDistribution_Normal() { req.Actions = []*querypb.SyncAction{releaseAction, setAction} status, err := suite.node.SyncDistribution(ctx, req) suite.NoError(err) - suite.Equal(commonpb.ErrorCode_UnexpectedError, status.ErrorCode) + suite.Equal(commonpb.ErrorCode_Success, status.ErrorCode) syncVersionAction := &querypb.SyncAction{ Type: querypb.SyncType_UpdateVersion, @@ -1573,6 +1760,7 @@ func (suite *ServiceSuite) TestSyncDistribution_Normal() { req.Actions = []*querypb.SyncAction{syncVersionAction} status, err = suite.node.SyncDistribution(ctx, req) suite.NoError(err) + suite.Equal(commonpb.ErrorCode_Success, status.GetErrorCode()) } func (suite *ServiceSuite) 
TestSyncDistribution_ReleaseResultCheck() { diff --git a/internal/querynodev2/tasks/concurrent_safe_scheduler.go b/internal/querynodev2/tasks/concurrent_safe_scheduler.go index fe054c02e386a..7968cd172beef 100644 --- a/internal/querynodev2/tasks/concurrent_safe_scheduler.go +++ b/internal/querynodev2/tasks/concurrent_safe_scheduler.go @@ -1,17 +1,19 @@ package tasks import ( - "context" "fmt" + "sync" + + "go.uber.org/atomic" + "go.uber.org/zap" "github.com/milvus-io/milvus/internal/querynodev2/collector" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/conc" + "github.com/milvus-io/milvus/pkg/util/lifetime" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" - "go.uber.org/atomic" - "go.uber.org/zap" ) const ( @@ -29,6 +31,7 @@ func newScheduler(policy schedulePolicy) Scheduler { execChan: make(chan Task), pool: conc.NewPool[any](maxReadConcurrency, conc.WithPreAlloc(true)), schedulerCounter: schedulerCounter{}, + lifetime: lifetime.NewLifetime(lifetime.Initializing), } } @@ -43,12 +46,23 @@ type scheduler struct { receiveChan chan addTaskReq execChan chan Task pool *conc.Pool[any] + + // wg is the waitgroup for internal worker goroutine + wg sync.WaitGroup + // lifetime controls scheduler State & make sure all requests accepted will be processed + lifetime lifetime.Lifetime[lifetime.State] + schedulerCounter } // Add a new task into scheduler, // error will be returned if scheduler reaches some limit. func (s *scheduler) Add(task Task) (err error) { + if err := s.lifetime.Add(lifetime.IsWorking); err != nil { + return err + } + defer s.lifetime.Done() + errCh := make(chan error, 1) // TODO: add operation should be fast, is UnsolveLen metric unnesscery? @@ -67,16 +81,31 @@ func (s *scheduler) Add(task Task) (err error) { // Start schedule the owned task asynchronously and continuously. // Start should be only call once. 
-func (s *scheduler) Start(ctx context.Context) { +func (s *scheduler) Start() { + s.wg.Add(2) + // Start a background task executing loop. - go s.exec(ctx) + go s.exec() // Begin to schedule tasks. - go s.schedule(ctx) + go s.schedule() + + s.lifetime.SetState(lifetime.Working) +} + +func (s *scheduler) Stop() { + s.lifetime.SetState(lifetime.Stopped) + // wait all accepted Add done + s.lifetime.Wait() + // close receiveChan start stopping process for `schedule` + close(s.receiveChan) + // wait workers quit + s.wg.Wait() } // schedule the owned task asynchronously and continuously. -func (s *scheduler) schedule(ctx context.Context) { +func (s *scheduler) schedule() { + defer s.wg.Done() var task Task for { s.setupReadyLenMetric() @@ -86,10 +115,19 @@ func (s *scheduler) schedule(ctx context.Context) { task, nq, execChan = s.setupExecListener(task) select { - case <-ctx.Done(): - log.Warn("unexpected quit of schedule loop") - return - case req := <-s.receiveChan: + case req, ok := <-s.receiveChan: + if !ok { + log.Info("receiveChan closed, processing remaining request") + // drain policy maintained task + for task != nil { + execChan <- task + s.updateWaitingTaskCounter(-1, -nq) + task = s.produceExecChan() + } + log.Info("all task put into exeChan, schedule worker exit") + close(s.execChan) + return + } // Receive add operation request and return the process result. // And consume recv chan as much as possible. s.consumeRecvChan(req, maxReceiveChanBatchConsumeNum) @@ -114,7 +152,10 @@ func (s *scheduler) consumeRecvChan(req addTaskReq, limit int) { // consume the add chan until reaching the batch operation limit for i := 1; i < limit; i++ { select { - case req := <-s.receiveChan: + case req, ok := <-s.receiveChan: + if !ok { + return + } if !s.handleAddTaskRequest(req, maxWaitTaskNum) { return } @@ -165,42 +206,42 @@ func (s *scheduler) produceExecChan() Task { } // exec exec the ready task in background continuously. 
-func (s *scheduler) exec(ctx context.Context) { +func (s *scheduler) exec() { + defer s.wg.Done() log.Info("start execute loop") for { - select { - case <-ctx.Done(): - log.Warn("unexpected quit of exec loop") + t, ok := <-s.execChan + if !ok { + log.Info("scheduler execChan closed, worker exit") return - case t := <-s.execChan: - // Skip this task if task is canceled. - if err := t.Canceled(); err != nil { - log.Warn("task canceled before executing", zap.Error(err)) - t.Done(err) - continue - } - if err := t.PreExecute(); err != nil { - log.Warn("failed to pre-execute task", zap.Error(err)) - t.Done(err) - continue - } + } + // Skip this task if task is canceled. + if err := t.Canceled(); err != nil { + log.Warn("task canceled before executing", zap.Error(err)) + t.Done(err) + continue + } + if err := t.PreExecute(); err != nil { + log.Warn("failed to pre-execute task", zap.Error(err)) + t.Done(err) + continue + } - s.pool.Submit(func() (any, error) { - // Update concurrency metric and notify task done. - metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc() - collector.Counter.Inc(metricsinfo.ExecuteQueueType, 1) + s.pool.Submit(func() (any, error) { + // Update concurrency metric and notify task done. + metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Inc() + collector.Counter.Inc(metricsinfo.ExecuteQueueType, 1) - err := t.Execute() + err := t.Execute() - // Update all metric after task finished. - metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Dec() - collector.Counter.Dec(metricsinfo.ExecuteQueueType, -1) + // Update all metric after task finished. + metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Dec() + collector.Counter.Dec(metricsinfo.ExecuteQueueType, -1) - // Notify task done. - t.Done(err) - return nil, err - }) - } + // Notify task done. 
+ t.Done(err) + return nil, err + }) } } diff --git a/internal/querynodev2/tasks/concurrent_safe_scheduler_test.go b/internal/querynodev2/tasks/concurrent_safe_scheduler_test.go index fd0ed081dffe4..69064f18fa8b6 100644 --- a/internal/querynodev2/tasks/concurrent_safe_scheduler_test.go +++ b/internal/querynodev2/tasks/concurrent_safe_scheduler_test.go @@ -7,10 +7,13 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" "go.uber.org/atomic" + "github.com/milvus-io/milvus/pkg/util/conc" + "github.com/milvus-io/milvus/pkg/util/lifetime" "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/stretchr/testify/assert" ) func TestScheduler(t *testing.T) { @@ -21,12 +24,33 @@ func TestScheduler(t *testing.T) { t.Run("fifo", func(t *testing.T) { testScheduler(t, newFIFOPolicy()) }) + t.Run("scheduler_not_working", func(t *testing.T) { + scheduler := newScheduler(newFIFOPolicy()) + + task := newMockTask(mockTaskConfig{ + nq: 1, + executeCost: 10 * time.Millisecond, + execution: func(ctx context.Context) error { + return nil + }, + }) + + err := scheduler.Add(task) + assert.Error(t, err) + + scheduler.Stop() + + err = scheduler.Add(task) + assert.Error(t, err) + }) + + suite.Run(t, new(SchedulerSuite)) } func testScheduler(t *testing.T, policy schedulePolicy) { // start a new scheduler scheduler := newScheduler(policy) - go scheduler.Start(context.Background()) + scheduler.Start() var cnt atomic.Int32 n := 100 @@ -79,3 +103,37 @@ func testScheduler(t *testing.T, policy schedulePolicy) { assert.Equal(t, 0, int(scheduler.GetWaitingTaskTotal())) assert.Equal(t, 0, int(scheduler.GetWaitingTaskTotalNQ())) } + +type SchedulerSuite struct { + suite.Suite +} + +func (s *SchedulerSuite) TestConsumeRecvChan() { + s.Run("consume_chan_closed", func() { + ch := make(chan addTaskReq, 10) + close(ch) + scheduler := &scheduler{ + policy: newFIFOPolicy(), + receiveChan: ch, + execChan: make(chan Task), + pool: conc.NewPool[any](10, 
conc.WithPreAlloc(true)), + schedulerCounter: schedulerCounter{}, + lifetime: lifetime.NewLifetime(lifetime.Initializing), + } + + task := newMockTask(mockTaskConfig{ + nq: 1, + executeCost: 10 * time.Millisecond, + execution: func(ctx context.Context) error { + return nil + }, + }) + + s.NotPanics(func() { + scheduler.consumeRecvChan(addTaskReq{ + task: task, + err: make(chan error, 1), + }, maxReceiveChanBatchConsumeNum) + }) + }) +} diff --git a/internal/querynodev2/tasks/policy_test.go b/internal/querynodev2/tasks/policy_test.go index fd4fcef719806..03ce1a811f042 100644 --- a/internal/querynodev2/tasks/policy_test.go +++ b/internal/querynodev2/tasks/policy_test.go @@ -4,8 +4,9 @@ import ( "fmt" "testing" - "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/pkg/util/paramtable" ) func TestUserTaskPollingPolicy(t *testing.T) { diff --git a/internal/querynodev2/tasks/query_stream_task.go b/internal/querynodev2/tasks/query_stream_task.go new file mode 100644 index 0000000000000..450e9e91a669c --- /dev/null +++ b/internal/querynodev2/tasks/query_stream_task.go @@ -0,0 +1,83 @@ +package tasks + +import ( + "context" + + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/querynodev2/segments" + "github.com/milvus-io/milvus/internal/util/streamrpc" +) + +var _ Task = &QueryStreamTask{} + +func NewQueryStreamTask(ctx context.Context, + collection *segments.Collection, + manager *segments.Manager, + req *querypb.QueryRequest, + srv streamrpc.QueryStreamServer, +) *QueryStreamTask { + return &QueryStreamTask{ + ctx: ctx, + collection: collection, + segmentManager: manager, + req: req, + srv: srv, + notifier: make(chan error, 1), + } +} + +type QueryStreamTask struct { + ctx context.Context + collection *segments.Collection + segmentManager *segments.Manager + req *querypb.QueryRequest + srv streamrpc.QueryStreamServer + notifier chan error +} + +// Return the 
username which task is belong to. +// Return "" if the task do not contain any user info. +func (t *QueryStreamTask) Username() string { + return t.req.Req.GetUsername() +} + +// PreExecute the task, only call once. +func (t *QueryStreamTask) PreExecute() error { + return nil +} + +func (t *QueryStreamTask) Execute() error { + retrievePlan, err := segments.NewRetrievePlan( + t.collection, + t.req.Req.GetSerializedExprPlan(), + t.req.Req.GetMvccTimestamp(), + t.req.Req.Base.GetMsgID(), + ) + if err != nil { + return err + } + defer retrievePlan.Delete() + + segments, err := segments.RetrieveStream(t.ctx, t.segmentManager, retrievePlan, t.req, t.srv) + defer t.segmentManager.Segment.Unpin(segments) + if err != nil { + return err + } + return nil +} + +func (t *QueryStreamTask) Done(err error) { + t.notifier <- err +} + +func (t *QueryStreamTask) Canceled() error { + return t.ctx.Err() +} + +func (t *QueryStreamTask) Wait() error { + return <-t.notifier +} + +func (t *QueryStreamTask) NQ() int64 { + return 1 +} diff --git a/internal/querynodev2/tasks/query_task.go b/internal/querynodev2/tasks/query_task.go index 5b1f658195340..73fff23414227 100644 --- a/internal/querynodev2/tasks/query_task.go +++ b/internal/querynodev2/tasks/query_task.go @@ -2,12 +2,13 @@ package tasks import ( "context" + "fmt" "strconv" + "time" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" - "github.com/milvus-io/milvus/internal/proto/segcorepb" "github.com/milvus-io/milvus/internal/querynodev2/collector" "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/pkg/metrics" @@ -88,31 +89,8 @@ func (t *QueryTask) Execute() error { return err } defer retrievePlan.Delete() - - var ( - results []*segcorepb.RetrieveResults - searchedSegments []segments.Segment - ) - if t.req.GetScope() == querypb.DataScope_Historical { - results, searchedSegments, err = 
segments.RetrieveHistorical( - t.ctx, - t.segmentManager, - retrievePlan, - t.req.Req.CollectionID, - nil, - t.req.GetSegmentIDs(), - ) - } else { - results, searchedSegments, err = segments.RetrieveStreaming( - t.ctx, - t.segmentManager, - retrievePlan, - t.req.Req.CollectionID, - nil, - t.req.GetSegmentIDs(), - ) - } - defer t.segmentManager.Segment.Unpin(searchedSegments) + results, querySegments, err := segments.Retrieve(t.ctx, t.segmentManager, retrievePlan, t.req) + defer t.segmentManager.Segment.Unpin(querySegments) if err != nil { return err } @@ -121,8 +99,13 @@ func (t *QueryTask) Execute() error { t.req, t.collection.Schema(), ) - + beforeReduce := time.Now() reducedResult, err := reducer.Reduce(t.ctx, results) + + metrics.QueryNodeReduceLatency.WithLabelValues( + fmt.Sprint(paramtable.GetNodeID()), + metrics.QueryLabel, + metrics.ReduceSegments).Observe(float64(time.Since(beforeReduce).Milliseconds())) if err != nil { return err } @@ -131,7 +114,7 @@ func (t *QueryTask) Execute() error { Base: &commonpb.MsgBase{ SourceID: paramtable.GetNodeID(), }, - Status: merr.Status(nil), + Status: merr.Success(), Ids: reducedResult.Ids, FieldsData: reducedResult.FieldsData, CostAggregation: &internalpb.CostAggregation{ diff --git a/internal/querynodev2/tasks/task.go b/internal/querynodev2/tasks/task.go index 96a37fe2dec12..ee4abcc23ac1a 100644 --- a/internal/querynodev2/tasks/task.go +++ b/internal/querynodev2/tasks/task.go @@ -16,7 +16,6 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querynodev2/collector" "github.com/milvus-io/milvus/internal/querynodev2/segments" - "github.com/milvus-io/milvus/internal/util" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/funcutil" @@ -167,7 +166,7 @@ func (t *SearchTask) Execute() error { Base: &commonpb.MsgBase{ SourceID: paramtable.GetNodeID(), }, - Status: merr.Status(nil), + Status: merr.Success(), 
MetricType: req.GetReq().GetMetricType(), NumQueries: t.originNqs[i], TopK: t.originTopks[i], @@ -194,8 +193,11 @@ func (t *SearchTask) Execute() error { return err } defer segments.DeleteSearchResultDataBlobs(blobs) - reduceLatency := tr.RecordSpan() - + metrics.QueryNodeReduceLatency.WithLabelValues( + fmt.Sprint(paramtable.GetNodeID()), + metrics.SearchLabel, + metrics.ReduceSegments). + Observe(float64(tr.RecordSpan().Milliseconds())) for i := range t.originNqs { blob, err := segments.GetSearchResultDataBlob(blobs, i) if err != nil { @@ -213,17 +215,11 @@ func (t *SearchTask) Execute() error { bs := make([]byte, len(blob)) copy(bs, blob) - metrics.QueryNodeReduceLatency.WithLabelValues( - fmt.Sprint(paramtable.GetNodeID()), - metrics.SearchLabel, - metrics.ReduceSegments). - Observe(float64(reduceLatency.Milliseconds())) - task.result = &internalpb.SearchResults{ Base: &commonpb.MsgBase{ SourceID: paramtable.GetNodeID(), }, - Status: util.WrapStatus(commonpb.ErrorCode_Success, ""), + Status: merr.Success(), MetricType: req.GetReq().GetMetricType(), NumQueries: t.originNqs[i], TopK: t.originTopks[i], @@ -235,6 +231,7 @@ func (t *SearchTask) Execute() error { }, } } + return nil } diff --git a/internal/querynodev2/tasks/tasks.go b/internal/querynodev2/tasks/tasks.go index 59328be571b51..6a0d55b2edeec 100644 --- a/internal/querynodev2/tasks/tasks.go +++ b/internal/querynodev2/tasks/tasks.go @@ -1,9 +1,5 @@ package tasks -import ( - "context" -) - const ( schedulePolicyNameFIFO = "fifo" schedulePolicyNameUserTaskPolling = "user-task-polling" @@ -44,9 +40,12 @@ type Scheduler interface { Add(task Task) error // Start schedule the owned task asynchronously and continuously. - // 1. Stop processing until ctx.Cancel() is called. - // 2. Only call once. 
- Start(ctx context.Context) + // Shall be called only once + Start() + + // Stop make scheduler deny all incoming tasks + // and cleans up all related resources + Stop() // GetWaitingTaskTotalNQ GetWaitingTaskTotalNQ() int64 diff --git a/internal/querynodev2/tsafe/manager.go b/internal/querynodev2/tsafe/manager.go index 6e2448f0d9d1a..c3da4e009b7d6 100644 --- a/internal/querynodev2/tsafe/manager.go +++ b/internal/querynodev2/tsafe/manager.go @@ -20,10 +20,11 @@ import ( "fmt" "sync" + "go.uber.org/zap" + "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/tsoutil" . "github.com/milvus-io/milvus/pkg/util/typeutil" - "go.uber.org/zap" ) // Manager is the interface for tsafe manager. diff --git a/internal/querynodev2/tsafe/tsafe_test.go b/internal/querynodev2/tsafe/tsafe_test.go index fa9c1c8064a3d..8edbd2de8a1f6 100644 --- a/internal/querynodev2/tsafe/tsafe_test.go +++ b/internal/querynodev2/tsafe/tsafe_test.go @@ -48,19 +48,19 @@ func (suite *TSafeTestSuite) TestBasic() { suite.NoError(err) suite.Equal(ZeroTimestamp, t) - //Add listener + // Add listener globalWatcher := suite.tSafeReplica.WatchChannel(suite.channel) channelWatcher := suite.tSafeReplica.Watch() defer globalWatcher.Close() defer channelWatcher.Close() - //Test Set tSafe + // Test Set tSafe suite.tSafeReplica.Set(suite.channel, suite.time) t, err = suite.tSafeReplica.Get(suite.channel) suite.NoError(err) suite.Equal(suite.time, t) - //Test listener + // Test listener select { case <-globalWatcher.On(): default: diff --git a/internal/querynodev2/utils.go b/internal/querynodev2/utils.go deleted file mode 100644 index d63a94a76f8b5..0000000000000 --- a/internal/querynodev2/utils.go +++ /dev/null @@ -1,27 +0,0 @@ -package querynodev2 - -import ( - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus/internal/proto/internalpb" - "github.com/milvus-io/milvus/internal/util" - "github.com/milvus-io/milvus/pkg/util/paramtable" -) - -func 
WrapRetrieveResult(code commonpb.ErrorCode, msg string, errs ...error) *internalpb.RetrieveResults { - return &internalpb.RetrieveResults{ - Status: util.WrapStatus(code, msg, errs...), - } -} - -func WrapSearchResult(code commonpb.ErrorCode, msg string, errs ...error) *internalpb.SearchResults { - return &internalpb.SearchResults{ - Status: util.WrapStatus(code, msg, errs...), - } -} - -// CheckTargetID checks whether the target ID of request is the server itself, -// returns true if matched, -// returns false otherwise -func CheckTargetID[R interface{ GetBase() *commonpb.MsgBase }](req R) bool { - return req.GetBase().GetTargetID() == paramtable.GetNodeID() -} diff --git a/internal/registry/in_mem_resolver.go b/internal/registry/in_mem_resolver.go new file mode 100644 index 0000000000000..5f690d6da066c --- /dev/null +++ b/internal/registry/in_mem_resolver.go @@ -0,0 +1,49 @@ +package registry + +import ( + "context" + "sync" + + "go.uber.org/atomic" + + qnClient "github.com/milvus-io/milvus/internal/distributed/querynode/client" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/wrappers" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +var ( + once sync.Once + + resolver atomic.Pointer[InMemResolver] +) + +func GetInMemoryResolver() *InMemResolver { + r := resolver.Load() + if r == nil { + once.Do(func() { + newResolver := &InMemResolver{ + queryNodes: typeutil.NewConcurrentMap[int64, types.QueryNode](), + } + resolver.Store(newResolver) + }) + r = resolver.Load() + } + return r +} + +type InMemResolver struct { + queryNodes *typeutil.ConcurrentMap[int64, types.QueryNode] +} + +func (r *InMemResolver) RegisterQueryNode(id int64, qn types.QueryNode) { + r.queryNodes.Insert(id, qn) +} + +func (r *InMemResolver) ResolveQueryNode(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) { + qn, ok := r.queryNodes.Get(nodeID) + if !ok { + return qnClient.NewClient(ctx, addr, nodeID) + } + return 
wrappers.WrapQueryNodeServerAsClient(qn), nil +} diff --git a/internal/rootcoord/alter_alias_task_test.go b/internal/rootcoord/alter_alias_task_test.go index c4c714c0720e2..8eefe721c7fc7 100644 --- a/internal/rootcoord/alter_alias_task_test.go +++ b/internal/rootcoord/alter_alias_task_test.go @@ -23,7 +23,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" ) diff --git a/internal/rootcoord/alter_collection_task_test.go b/internal/rootcoord/alter_collection_task_test.go index 4dd9d97a241fe..20c31cc4f663e 100644 --- a/internal/rootcoord/alter_collection_task_test.go +++ b/internal/rootcoord/alter_collection_task_test.go @@ -220,6 +220,5 @@ func Test_alterCollectionTask_Execute(t *testing.T) { Key: common.CollectionAutoCompactionKey, Value: "true", }) - }) } diff --git a/internal/rootcoord/broker.go b/internal/rootcoord/broker.go index d9f4eab464dc5..fd97429e1010f 100644 --- a/internal/rootcoord/broker.go +++ b/internal/rootcoord/broker.go @@ -32,6 +32,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/commonpbutil" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -202,7 +203,7 @@ func (b *ServerBroker) Flush(ctx context.Context, cID int64, segIDs []int64) err return errors.New("failed to call flush to data coordinator: " + err.Error()) } if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { - return errors.New(resp.Status.Reason) + return merr.Error(resp.GetStatus()) } log.Info("flush on collection succeed", zap.Int64("collectionID", cID)) return nil @@ -250,8 +251,8 @@ func (b *ServerBroker) GetSegmentIndexState(ctx context.Context, collID UniqueID if err != nil { return nil, err } - if resp.Status.ErrorCode != commonpb.ErrorCode_Success { - return nil, errors.New(resp.Status.Reason) + if 
resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + return nil, merr.Error(resp.GetStatus()) } return resp.GetStates(), nil @@ -307,7 +308,6 @@ func (b *ServerBroker) GcConfirm(ctx context.Context, collectionID, partitionID req := &datapb.GcConfirmRequest{CollectionId: collectionID, PartitionId: partitionID} resp, err := b.s.dataCoord.GcConfirm(ctx, req) - if err != nil { log.Warn("gc is not finished", zap.Error(err)) return false diff --git a/internal/rootcoord/broker_test.go b/internal/rootcoord/broker_test.go index f57a087338d5e..7d2788b3f69cc 100644 --- a/internal/rootcoord/broker_test.go +++ b/internal/rootcoord/broker_test.go @@ -21,18 +21,17 @@ import ( "testing" "github.com/cockroachdb/errors" - - "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/indexpb" mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" + "github.com/milvus-io/milvus/pkg/util/merr" ) func TestServerBroker_ReleaseCollection(t *testing.T) { @@ -230,7 +229,7 @@ func TestServerBroker_GetSegmentIndexState(t *testing.T) { c := newTestCore(withValidDataCoord()) c.dataCoord.(*mockDataCoord).GetSegmentIndexStateFunc = func(ctx context.Context, req *indexpb.GetSegmentIndexStateRequest) (*indexpb.GetSegmentIndexStateResponse, error) { return &indexpb.GetSegmentIndexStateResponse{ - Status: succStatus(), + Status: merr.Success(), States: []*indexpb.SegmentIndexState{ { SegmentID: 1, @@ -342,10 +341,11 @@ func TestServerBroker_BroadcastAlteredCollection(t *testing.T) { func TestServerBroker_GcConfirm(t *testing.T) 
{ t.Run("invalid datacoord", func(t *testing.T) { - dc := mocks.NewMockDataCoord(t) + dc := mocks.NewMockDataCoordClient(t) dc.On("GcConfirm", mock.Anything, // context.Context mock.Anything, // *datapb.GcConfirmRequest + mock.Anything, // *datapb.GcConfirmRequest ).Return(nil, errors.New("error mock GcConfirm")) c := newTestCore(withDataCoord(dc)) broker := newServerBroker(c) @@ -353,12 +353,14 @@ func TestServerBroker_GcConfirm(t *testing.T) { }) t.Run("non success", func(t *testing.T) { - dc := mocks.NewMockDataCoord(t) + dc := mocks.NewMockDataCoordClient(t) + err := errors.New("mock error") dc.On("GcConfirm", mock.Anything, // context.Context mock.Anything, // *datapb.GcConfirmRequest + mock.Anything, ).Return( - &datapb.GcConfirmResponse{Status: failStatus(commonpb.ErrorCode_UnexpectedError, "error mock GcConfirm")}, + &datapb.GcConfirmResponse{Status: merr.Status(err)}, nil) c := newTestCore(withDataCoord(dc)) broker := newServerBroker(c) @@ -366,12 +368,13 @@ func TestServerBroker_GcConfirm(t *testing.T) { }) t.Run("normal case", func(t *testing.T) { - dc := mocks.NewMockDataCoord(t) + dc := mocks.NewMockDataCoordClient(t) dc.On("GcConfirm", mock.Anything, // context.Context mock.Anything, // *datapb.GcConfirmRequest + mock.Anything, ).Return( - &datapb.GcConfirmResponse{Status: succStatus(), GcFinished: true}, + &datapb.GcConfirmResponse{Status: merr.Success(), GcFinished: true}, nil) c := newTestCore(withDataCoord(dc)) broker := newServerBroker(c) diff --git a/internal/rootcoord/create_alias_task.go b/internal/rootcoord/create_alias_task.go index 88ec00d27081e..7cd8334bd76df 100644 --- a/internal/rootcoord/create_alias_task.go +++ b/internal/rootcoord/create_alias_task.go @@ -20,7 +20,6 @@ import ( "context" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" ) diff --git a/internal/rootcoord/create_alias_task_test.go b/internal/rootcoord/create_alias_task_test.go index 
7bea4d775d701..77d8a16f748b7 100644 --- a/internal/rootcoord/create_alias_task_test.go +++ b/internal/rootcoord/create_alias_task_test.go @@ -20,9 +20,9 @@ import ( "context" "testing" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" ) diff --git a/internal/rootcoord/create_collection_task.go b/internal/rootcoord/create_collection_task.go index e61b2b93a8b19..5a8afcf093cd6 100644 --- a/internal/rootcoord/create_collection_task.go +++ b/internal/rootcoord/create_collection_task.go @@ -349,8 +349,7 @@ func (t *createCollectionTask) Prepare(ctx context.Context) error { return t.assignChannels() } -func (t *createCollectionTask) genCreateCollectionMsg(ctx context.Context) *ms.MsgPack { - ts := t.GetTs() +func (t *createCollectionTask) genCreateCollectionMsg(ctx context.Context, ts uint64) *ms.MsgPack { collectionID := t.collID partitionIDs := t.partIDs // error won't happen here. @@ -382,21 +381,36 @@ func (t *createCollectionTask) genCreateCollectionMsg(ctx context.Context) *ms.M return &msgPack } -func (t *createCollectionTask) addChannelsAndGetStartPositions(ctx context.Context) (map[string][]byte, error) { +func (t *createCollectionTask) addChannelsAndGetStartPositions(ctx context.Context, ts uint64) (map[string][]byte, error) { t.core.chanTimeTick.addDmlChannels(t.channels.physicalChannels...) 
- msg := t.genCreateCollectionMsg(ctx) + msg := t.genCreateCollectionMsg(ctx, ts) return t.core.chanTimeTick.broadcastMarkDmlChannels(t.channels.physicalChannels, msg) } +func (t *createCollectionTask) getCreateTs() (uint64, error) { + replicateInfo := t.Req.GetBase().GetReplicateInfo() + if !replicateInfo.GetIsReplicate() { + return t.GetTs(), nil + } + if replicateInfo.GetMsgTimestamp() == 0 { + log.Warn("the cdc timestamp is not set in the request for the backup instance") + return 0, merr.WrapErrParameterInvalidMsg("the cdc timestamp is not set in the request for the backup instance") + } + return replicateInfo.GetMsgTimestamp(), nil +} + func (t *createCollectionTask) Execute(ctx context.Context) error { collID := t.collID partIDs := t.partIDs - ts := t.GetTs() + ts, err := t.getCreateTs() + if err != nil { + return err + } vchanNames := t.channels.virtualChannels chanNames := t.channels.physicalChannels - startPositions, err := t.addChannelsAndGetStartPositions(ctx) + startPositions, err := t.addChannelsAndGetStartPositions(ctx, ts) if err != nil { // ugly here, since we must get start positions first. t.core.chanTimeTick.removeDmlChannels(t.channels.physicalChannels...) @@ -445,7 +459,7 @@ func (t *createCollectionTask) Execute(ctx context.Context) error { return fmt.Errorf("create duplicate collection with different parameters, collection: %s", t.Req.GetCollectionName()) } // make creating collection idempotent. 
- log.Warn("add duplicate collection", zap.String("collection", t.Req.GetCollectionName()), zap.Uint64("ts", t.GetTs())) + log.Warn("add duplicate collection", zap.String("collection", t.Req.GetCollectionName()), zap.Uint64("ts", ts)) return nil } @@ -475,6 +489,7 @@ func (t *createCollectionTask) Execute(ctx context.Context) error { baseStep: baseStep{core: t.core}, collectionID: collID, channels: t.channels, + isSkip: !Params.CommonCfg.TTMsgEnabled.GetAsBool(), }) undoTask.AddStep(&watchChannelsStep{ baseStep: baseStep{core: t.core}, diff --git a/internal/rootcoord/create_collection_task_test.go b/internal/rootcoord/create_collection_task_test.go index 808ffbececa7e..00e926c0ba798 100644 --- a/internal/rootcoord/create_collection_task_test.go +++ b/internal/rootcoord/create_collection_task_test.go @@ -51,6 +51,40 @@ func Test_createCollectionTask_validate(t *testing.T) { assert.Error(t, err) }) + t.Run("create ts", func(t *testing.T) { + task := createCollectionTask{ + Req: nil, + } + { + task.SetTs(1000) + ts, err := task.getCreateTs() + assert.NoError(t, err) + assert.EqualValues(t, 1000, ts) + } + + task.Req = &milvuspb.CreateCollectionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_CreateCollection, + ReplicateInfo: &commonpb.ReplicateInfo{ + IsReplicate: true, + }, + }, + } + { + task.SetTs(1000) + _, err := task.getCreateTs() + assert.Error(t, err) + err = task.Execute(context.Background()) + assert.Error(t, err) + } + { + task.Req.Base.ReplicateInfo.MsgTimestamp = 2000 + ts, err := task.getCreateTs() + assert.NoError(t, err) + assert.EqualValues(t, 2000, ts) + } + }) + t.Run("invalid msg type", func(t *testing.T) { task := createCollectionTask{ Req: &milvuspb.CreateCollectionRequest{ @@ -708,11 +742,11 @@ func Test_createCollectionTask_Execute(t *testing.T) { StateCode: commonpb.StateCode_Healthy, }, SubcomponentStates: nil, - Status: succStatus(), + Status: merr.Success(), }, nil } dc.WatchChannelsFunc = func(ctx context.Context, req 
*datapb.WatchChannelsRequest) (*datapb.WatchChannelsResponse, error) { - return &datapb.WatchChannelsResponse{Status: succStatus()}, nil + return &datapb.WatchChannelsResponse{Status: merr.Success()}, nil } core := newTestCore(withValidIDAllocator(), diff --git a/internal/rootcoord/ddl_ts_lock_manager_test.go b/internal/rootcoord/ddl_ts_lock_manager_test.go index b7e7f9a45bab0..7fe9ccd9b9ed5 100644 --- a/internal/rootcoord/ddl_ts_lock_manager_test.go +++ b/internal/rootcoord/ddl_ts_lock_manager_test.go @@ -20,7 +20,6 @@ import ( "testing" "github.com/cockroachdb/errors" - "github.com/stretchr/testify/assert" ) diff --git a/internal/rootcoord/dml_channels.go b/internal/rootcoord/dml_channels.go index af9df06300975..b3055c3c7b1e6 100644 --- a/internal/rootcoord/dml_channels.go +++ b/internal/rootcoord/dml_channels.go @@ -25,12 +25,10 @@ import ( "sync" "github.com/cockroachdb/errors" - - "github.com/milvus-io/milvus/pkg/metrics" - "go.uber.org/zap" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -188,10 +186,11 @@ func newDmlChannels(ctx context.Context, factory msgstream.Factory, chanNamePref if params.PreCreatedTopicEnabled.GetAsBool() { subName := fmt.Sprintf("pre-created-topic-check-%s", name) - ms.AsConsumer([]string{name}, subName, mqwrapper.SubscriptionPositionUnknown) - // check topic exist and check the existed topic whether empty or not + ms.AsConsumer(ctx, []string{name}, subName, mqwrapper.SubscriptionPositionUnknown) + // check if topic is existed // kafka and rmq will err if the topic does not yet exist, pulsar will not - // if one of the topics is not empty, panic + // allow topics is not empty, for the reason that when restart or upgrade, the topic is not empty + // if there are any message that not belong to milvus, will skip it err := ms.CheckTopicValid(name) if 
err != nil { log.Error("created topic is invaild", zap.String("name", name), zap.Error(err)) diff --git a/internal/rootcoord/dml_channels_test.go b/internal/rootcoord/dml_channels_test.go index f517ed5049b6c..db61ff1327db9 100644 --- a/internal/rootcoord/dml_channels_test.go +++ b/internal/rootcoord/dml_channels_test.go @@ -24,20 +24,18 @@ import ( "testing" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus/internal/util/dependency" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/paramtable" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestDmlMsgStream(t *testing.T) { t.Run("RefCnt", func(t *testing.T) { - dms := &dmlMsgStream{refcnt: 0} assert.Equal(t, int64(0), dms.RefCnt()) assert.Equal(t, int64(0), dms.Used()) @@ -283,7 +281,8 @@ func (ms *FailMsgStream) Close() {} func (ms *FailMsgStream) Chan() <-chan *msgstream.MsgPack { return nil } func (ms *FailMsgStream) AsProducer(channels []string) {} func (ms *FailMsgStream) AsReader(channels []string, subName string) {} -func (ms *FailMsgStream) AsConsumer(channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) { +func (ms *FailMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) error { + return nil } func (ms *FailMsgStream) SetRepackFunc(repackFunc msgstream.RepackFunc) {} func (ms *FailMsgStream) GetProduceChannels() []string { return nil } @@ -294,8 +293,8 @@ func (ms *FailMsgStream) Broadcast(*msgstream.MsgPack) (map[string][]msgstream.M } return nil, nil } -func (ms *FailMsgStream) Consume() *msgstream.MsgPack { return nil } -func (ms *FailMsgStream) Seek(offset 
[]*msgstream.MsgPosition) error { return nil } +func (ms *FailMsgStream) Consume() *msgstream.MsgPack { return nil } +func (ms *FailMsgStream) Seek(ctx context.Context, offset []*msgstream.MsgPosition) error { return nil } func (ms *FailMsgStream) GetLatestMsgID(channel string) (msgstream.MessageID, error) { return nil, nil diff --git a/internal/rootcoord/drop_alias_task_test.go b/internal/rootcoord/drop_alias_task_test.go index f8796e5bc033f..199a583107c79 100644 --- a/internal/rootcoord/drop_alias_task_test.go +++ b/internal/rootcoord/drop_alias_task_test.go @@ -54,7 +54,6 @@ func Test_dropAliasTask_Execute(t *testing.T) { task := &dropAliasTask{ baseTask: newBaseTask(context.Background(), core), Req: &milvuspb.DropAliasRequest{ - Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropAlias}, Alias: alias, }, @@ -69,7 +68,6 @@ func Test_dropAliasTask_Execute(t *testing.T) { task := &dropAliasTask{ baseTask: newBaseTask(context.Background(), core), Req: &milvuspb.DropAliasRequest{ - Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropAlias}, Alias: alias, }, @@ -92,7 +90,6 @@ func Test_dropAliasTask_Execute(t *testing.T) { task := &dropAliasTask{ baseTask: newBaseTask(context.Background(), core), Req: &milvuspb.DropAliasRequest{ - Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropAlias}, Alias: alias, }, diff --git a/internal/rootcoord/drop_collection_task.go b/internal/rootcoord/drop_collection_task.go index 5440186d3859f..f35fca1770355 100644 --- a/internal/rootcoord/drop_collection_task.go +++ b/internal/rootcoord/drop_collection_task.go @@ -20,9 +20,9 @@ import ( "context" "fmt" + "github.com/cockroachdb/errors" "go.uber.org/zap" - "github.com/cockroachdb/errors" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" pb "github.com/milvus-io/milvus/internal/proto/etcdpb" @@ -100,6 +100,7 @@ func (t *dropCollectionTask) Execute(ctx context.Context) error { 
redoTask.AddAsyncStep(&deleteCollectionDataStep{ baseStep: baseStep{core: t.core}, coll: collMeta, + isSkip: t.Req.GetBase().GetReplicateInfo().GetIsReplicate(), }) redoTask.AddAsyncStep(&removeDmlChannelsStep{ baseStep: baseStep{core: t.core}, diff --git a/internal/rootcoord/drop_partition_task.go b/internal/rootcoord/drop_partition_task.go index 4e352c0729597..1306b1ef2aadc 100644 --- a/internal/rootcoord/drop_partition_task.go +++ b/internal/rootcoord/drop_partition_task.go @@ -91,6 +91,7 @@ func (t *dropPartitionTask) Execute(ctx context.Context) error { PartitionName: t.Req.GetPartitionName(), CollectionID: t.collMeta.CollectionID, }, + isSkip: t.Req.GetBase().GetReplicateInfo().GetIsReplicate(), }) redoTask.AddAsyncStep(newConfirmGCStep(t.core, t.collMeta.CollectionID, partID)) redoTask.AddAsyncStep(&removePartitionMetaStep{ diff --git a/internal/rootcoord/expire_cache.go b/internal/rootcoord/expire_cache.go index fe6f50b4d21e6..eccfa6e7c2367 100644 --- a/internal/rootcoord/expire_cache.go +++ b/internal/rootcoord/expire_cache.go @@ -70,8 +70,8 @@ func (c *Core) ExpireMetaCache(ctx context.Context, dbName string, collNames []s for _, collName := range collNames { req := proxypb.InvalidateCollMetaCacheRequest{ Base: commonpbutil.NewMsgBase( - commonpbutil.WithMsgType(0), //TODO, msg type - commonpbutil.WithMsgID(0), //TODO, msg id + commonpbutil.WithMsgType(0), // TODO, msg type + commonpbutil.WithMsgID(0), // TODO, msg id commonpbutil.WithTimeStamp(ts), commonpbutil.WithSourceID(c.session.ServerID), ), diff --git a/internal/rootcoord/expire_cache_test.go b/internal/rootcoord/expire_cache_test.go index 61974b17dd120..82782c6753930 100644 --- a/internal/rootcoord/expire_cache_test.go +++ b/internal/rootcoord/expire_cache_test.go @@ -19,10 +19,10 @@ package rootcoord import ( "testing" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" 
"github.com/milvus-io/milvus/internal/proto/proxypb" - "github.com/stretchr/testify/assert" ) func Test_expireCacheConfig_apply(t *testing.T) { diff --git a/internal/rootcoord/garbage_collector.go b/internal/rootcoord/garbage_collector.go index 5d4436175d322..523191da2274f 100644 --- a/internal/rootcoord/garbage_collector.go +++ b/internal/rootcoord/garbage_collector.go @@ -61,6 +61,7 @@ func (c *bgGarbageCollector) ReDropCollection(collMeta *model.Collection, ts Tim redo.AddAsyncStep(&deleteCollectionDataStep{ baseStep: baseStep{core: c.s}, coll: collMeta, + isSkip: !Params.CommonCfg.TTMsgEnabled.GetAsBool(), }) redo.AddAsyncStep(&removeDmlChannelsStep{ baseStep: baseStep{core: c.s}, @@ -93,6 +94,7 @@ func (c *bgGarbageCollector) RemoveCreatingCollection(collMeta *model.Collection virtualChannels: collMeta.VirtualChannelNames, physicalChannels: collMeta.PhysicalChannelNames, }, + isSkip: !Params.CommonCfg.TTMsgEnabled.GetAsBool(), }) redo.AddAsyncStep(&removeDmlChannelsStep{ baseStep: baseStep{core: c.s}, @@ -117,6 +119,7 @@ func (c *bgGarbageCollector) ReDropPartition(dbID int64, pChannels []string, par baseStep: baseStep{core: c.s}, pchans: pChannels, partition: partition, + isSkip: !Params.CommonCfg.TTMsgEnabled.GetAsBool(), }) redo.AddAsyncStep(&removeDmlChannelsStep{ baseStep: baseStep{core: c.s}, diff --git a/internal/rootcoord/garbage_collector_test.go b/internal/rootcoord/garbage_collector_test.go index f3583299c3977..b08e4088d492c 100644 --- a/internal/rootcoord/garbage_collector_test.go +++ b/internal/rootcoord/garbage_collector_test.go @@ -24,6 +24,7 @@ import ( "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "google.golang.org/grpc" "github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/internal/mocks" @@ -460,8 +461,8 @@ func TestGarbageCollector_RemoveCreatingPartition(t *testing.T) { signal <- struct{}{} }) - qc := mocks.NewMockQueryCoord(t) - 
qc.EXPECT().ReleasePartitions(mock.Anything, mock.Anything).Return(merr.Status(nil), nil) + qc := mocks.NewMockQueryCoordClient(t) + qc.EXPECT().ReleasePartitions(mock.Anything, mock.Anything).Return(merr.Success(), nil) core := newTestCore(withTtSynchronizer(ticker), withMeta(meta), @@ -485,10 +486,10 @@ func TestGarbageCollector_RemoveCreatingPartition(t *testing.T) { signal := make(chan struct{}, 1) meta := mockrootcoord.NewIMetaTable(t) - qc := mocks.NewMockQueryCoord(t) - qc.EXPECT().ReleasePartitions(mock.Anything, mock.Anything). - Return(merr.Status(nil), fmt.Errorf("mock err")). - Run(func(ctx context.Context, req *querypb.ReleasePartitionsRequest) { + qc := mocks.NewMockQueryCoordClient(t) + qc.EXPECT().ReleasePartitions(mock.Anything, mock.Anything, mock.Anything). + Return(merr.Success(), fmt.Errorf("mock err")). + Run(func(ctx context.Context, req *querypb.ReleasePartitionsRequest, opts ...grpc.CallOption) { signal <- struct{}{} }) @@ -519,8 +520,8 @@ func TestGarbageCollector_RemoveCreatingPartition(t *testing.T) { signal <- struct{}{} }) - qc := mocks.NewMockQueryCoord(t) - qc.EXPECT().ReleasePartitions(mock.Anything, mock.Anything).Return(merr.Status(nil), nil) + qc := mocks.NewMockQueryCoordClient(t) + qc.EXPECT().ReleasePartitions(mock.Anything, mock.Anything, mock.Anything).Return(merr.Success(), nil) core := newTestCore(withTtSynchronizer(ticker), withMeta(meta), diff --git a/internal/rootcoord/has_collection_task.go b/internal/rootcoord/has_collection_task.go index 128583a5b8ea0..d9258a8f19607 100644 --- a/internal/rootcoord/has_collection_task.go +++ b/internal/rootcoord/has_collection_task.go @@ -21,6 +21,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/pkg/util/merr" ) // hasCollectionTask has collection request task @@ -39,7 +40,7 @@ func (t *hasCollectionTask) Prepare(ctx context.Context) error { // Execute task execution func (t 
*hasCollectionTask) Execute(ctx context.Context) error { - t.Rsp.Status = succStatus() + t.Rsp.Status = merr.Success() ts := getTravelTs(t.Req) // TODO: what if err != nil && common.IsCollectionNotExistError == false, should we consider this RPC as failure? _, err := t.core.meta.GetCollectionByName(ctx, t.Req.GetDbName(), t.Req.GetCollectionName(), ts) diff --git a/internal/rootcoord/has_partition_task.go b/internal/rootcoord/has_partition_task.go index ac728f682ab18..77ef717b47c84 100644 --- a/internal/rootcoord/has_partition_task.go +++ b/internal/rootcoord/has_partition_task.go @@ -21,6 +21,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -40,12 +41,12 @@ func (t *hasPartitionTask) Prepare(ctx context.Context) error { // Execute task execution func (t *hasPartitionTask) Execute(ctx context.Context) error { - t.Rsp.Status = succStatus() + t.Rsp.Status = merr.Success() t.Rsp.Value = false // TODO: why HasPartitionRequest doesn't contain Timestamp but other requests do. 
coll, err := t.core.meta.GetCollectionByName(ctx, t.Req.GetDbName(), t.Req.CollectionName, typeutil.MaxTimestamp) if err != nil { - t.Rsp.Status = failStatus(commonpb.ErrorCode_CollectionNotExists, err.Error()) + t.Rsp.Status = merr.Status(err) return err } for _, part := range coll.Partitions { diff --git a/internal/rootcoord/has_partition_task_test.go b/internal/rootcoord/has_partition_task_test.go index e6049224bac21..3ccc0935a3ba8 100644 --- a/internal/rootcoord/has_partition_task_test.go +++ b/internal/rootcoord/has_partition_task_test.go @@ -27,6 +27,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/metastore/model" mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks" + "github.com/milvus-io/milvus/pkg/util/merr" ) func Test_hasPartitionTask_Prepare(t *testing.T) { @@ -57,7 +58,9 @@ func Test_hasPartitionTask_Prepare(t *testing.T) { func Test_hasPartitionTask_Execute(t *testing.T) { t.Run("fail to get collection", func(t *testing.T) { - core := newTestCore(withInvalidMeta()) + metaTable := mockrootcoord.NewIMetaTable(t) + metaTable.EXPECT().GetCollectionByName(mock.Anything, mock.Anything, "test coll", mock.Anything).Return(nil, merr.WrapErrCollectionNotFound("test coll")) + core := newTestCore(withMeta(metaTable)) task := &hasPartitionTask{ baseTask: newBaseTask(context.Background(), core), Req: &milvuspb.HasPartitionRequest{ @@ -70,7 +73,8 @@ func Test_hasPartitionTask_Execute(t *testing.T) { } err := task.Execute(context.Background()) assert.Error(t, err) - assert.Equal(t, task.Rsp.GetStatus().GetErrorCode(), commonpb.ErrorCode_CollectionNotExists) + assert.ErrorIs(t, err, merr.ErrCollectionNotFound) + assert.ErrorIs(t, merr.Error(task.Rsp.GetStatus()), merr.ErrCollectionNotFound) assert.False(t, task.Rsp.GetValue()) }) diff --git a/internal/rootcoord/import_manager.go b/internal/rootcoord/import_manager.go index c48d4cd25af38..d358291a1ba1d 100644 --- 
a/internal/rootcoord/import_manager.go +++ b/internal/rootcoord/import_manager.go @@ -91,7 +91,8 @@ func newImportManager(ctx context.Context, client kv.TxnKV, importService func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error), getSegmentStates func(ctx context.Context, req *datapb.GetSegmentStatesRequest) (*datapb.GetSegmentStatesResponse, error), getCollectionName func(dbName string, collID, partitionID typeutil.UniqueID) (string, string, error), - unsetIsImportingState func(context.Context, *datapb.UnsetIsImportingStateRequest) (*commonpb.Status, error)) *importManager { + unsetIsImportingState func(context.Context, *datapb.UnsetIsImportingStateRequest) (*commonpb.Status, error), +) *importManager { mgr := &importManager{ ctx: ctx, taskStore: client, @@ -334,7 +335,6 @@ func (m *importManager) flipTaskFlushedState(ctx context.Context, importTask *mi log.Info("a DataNode is no longer busy after processing task", zap.Int64("dataNode ID", dataNodeID), zap.Int64("task ID", importTask.GetId())) - }() // Unset isImporting flag. if m.callUnsetIsImportingState == nil { @@ -419,25 +419,20 @@ func (m *importManager) isRowbased(files []string) (bool, error) { // importJob processes the import request, generates import tasks, sends these tasks to DataCoord, and returns // immediately. 
func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportRequest, cID int64, pID int64) *milvuspb.ImportResponse { - returnErrorFunc := func(reason string) *milvuspb.ImportResponse { + if len(req.GetFiles()) == 0 { return &milvuspb.ImportResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: reason, - }, + Status: merr.Status(merr.WrapErrParameterInvalidMsg("import request is empty")), } } - if req == nil || len(req.Files) == 0 { - return returnErrorFunc("import request is empty") - } - if m.callImportService == nil { - return returnErrorFunc("import service is not available") + return &milvuspb.ImportResponse{ + Status: merr.Status(merr.WrapErrServiceUnavailable("import service unavailable")), + } } resp := &milvuspb.ImportResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Tasks: make([]int64, 0), } @@ -553,7 +548,9 @@ func (m *importManager) importJob(ctx context.Context, req *milvuspb.ImportReque return nil }() if err != nil { - return returnErrorFunc(err.Error()) + return &milvuspb.ImportResponse{ + Status: merr.Status(err), + } } if sendOutTasksErr := m.sendOutTasks(ctx); sendOutTasksErr != nil { log.Error("fail to send out tasks", zap.Error(sendOutTasksErr)) @@ -625,7 +622,6 @@ func (m *importManager) updateTaskInfo(ir *rootcoordpb.ImportResult) (*datapb.Im return toPersistImportTaskInfo, nil }() - if err != nil { return nil, err } @@ -734,7 +730,7 @@ func (m *importManager) setCollectionPartitionName(dbName string, colID, partID } func (m *importManager) copyTaskInfo(input *datapb.ImportTaskInfo, output *milvuspb.GetImportStateResponse) { - output.Status = merr.Status(nil) + output.Status = merr.Success() output.Id = input.GetId() output.CollectionId = input.GetCollectionId() @@ -756,11 +752,8 @@ func (m *importManager) copyTaskInfo(input *datapb.ImportTaskInfo, output *milvu // getTaskState looks for task with the given ID and returns its import state. 
func (m *importManager) getTaskState(tID int64) *milvuspb.GetImportStateResponse { resp := &milvuspb.GetImportStateResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "import task id doesn't exist", - }, - Infos: make([]*commonpb.KeyValuePair, 0), + Status: merr.Success(), + Infos: make([]*commonpb.KeyValuePair, 0), } // (1) Search in pending tasks list. found := false @@ -787,24 +780,24 @@ func (m *importManager) getTaskState(tID int64) *milvuspb.GetImportStateResponse return resp } // (3) Search in Etcd. - if v, err := m.taskStore.Load(BuildImportTaskKey(tID)); err == nil && v != "" { - ti := &datapb.ImportTaskInfo{} - if err := proto.Unmarshal([]byte(v), ti); err != nil { - log.Error("failed to unmarshal proto", zap.String("taskInfo", v), zap.Error(err)) - } else { - m.copyTaskInfo(ti, resp) - found = true - } - } else { + v, err := m.taskStore.Load(BuildImportTaskKey(tID)) + if err != nil { log.Warn("failed to load task info from Etcd", zap.String("value", v), - zap.Error(err)) + zap.Error(err), + ) + resp.Status = merr.Status(err) + return resp } - if found { - log.Info("getting import task state", zap.Int64("task ID", tID), zap.Any("state", resp.State), zap.Int64s("segment", resp.SegmentIds)) + + ti := &datapb.ImportTaskInfo{} + if err := proto.Unmarshal([]byte(v), ti); err != nil { + log.Error("failed to unmarshal proto", zap.String("taskInfo", v), zap.Error(err)) + resp.Status = merr.Status(err) return resp } - log.Debug("get import task state failed", zap.Int64("taskID", tID)) + + m.copyTaskInfo(ti, resp) return resp } @@ -1070,10 +1063,9 @@ func tryUpdateErrMsg(errReason string, toPersistImportTaskInfo *datapb.ImportTas if toPersistImportTaskInfo.GetState().GetErrorMessage() == "" { toPersistImportTaskInfo.State.ErrorMessage = errReason } else { - toPersistImportTaskInfo.State.ErrorMessage = - fmt.Sprintf("%s; %s", - toPersistImportTaskInfo.GetState().GetErrorMessage(), - errReason) + 
toPersistImportTaskInfo.State.ErrorMessage = fmt.Sprintf("%s; %s", + toPersistImportTaskInfo.GetState().GetErrorMessage(), + errReason) } } } diff --git a/internal/rootcoord/import_manager_test.go b/internal/rootcoord/import_manager_test.go index 17ae70cc7135f..3b841c221d925 100644 --- a/internal/rootcoord/import_manager_test.go +++ b/internal/rootcoord/import_manager_test.go @@ -35,15 +35,16 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" importutil2 "github.com/milvus-io/milvus/internal/util/importutil" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" ) func TestImportManager_NewImportManager(t *testing.T) { var countLock sync.RWMutex - var globalCount = typeutil.UniqueID(0) + globalCount := typeutil.UniqueID(0) - var idAlloc = func(count uint32) (typeutil.UniqueID, typeutil.UniqueID, error) { + idAlloc := func(count uint32) (typeutil.UniqueID, typeutil.UniqueID, error) { countLock.Lock() defer countLock.Unlock() globalCount++ @@ -91,22 +92,16 @@ func TestImportManager_NewImportManager(t *testing.T) { callImportServiceFn := func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { if mockCallImportServiceErr { return &datapb.ImportTaskResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, errors.New("mock err") } return &datapb.ImportTaskResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, nil } callGetSegmentStates := func(ctx context.Context, req *datapb.GetSegmentStatesRequest) (*datapb.GetSegmentStatesResponse, error) { return &datapb.GetSegmentStatesResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, nil } var wg sync.WaitGroup @@ -233,9 +228,9 @@ func 
TestImportManager_NewImportManager(t *testing.T) { func TestImportManager_TestSetImportTaskState(t *testing.T) { var countLock sync.RWMutex - var globalCount = typeutil.UniqueID(0) + globalCount := typeutil.UniqueID(0) - var idAlloc = func(count uint32) (typeutil.UniqueID, typeutil.UniqueID, error) { + idAlloc := func(count uint32) (typeutil.UniqueID, typeutil.UniqueID, error) { countLock.Lock() defer countLock.Unlock() globalCount++ @@ -302,9 +297,9 @@ func TestImportManager_TestSetImportTaskState(t *testing.T) { func TestImportManager_TestEtcdCleanUp(t *testing.T) { var countLock sync.RWMutex - var globalCount = typeutil.UniqueID(0) + globalCount := typeutil.UniqueID(0) - var idAlloc = func(count uint32) (typeutil.UniqueID, typeutil.UniqueID, error) { + idAlloc := func(count uint32) (typeutil.UniqueID, typeutil.UniqueID, error) { countLock.Lock() defer countLock.Unlock() globalCount++ @@ -351,23 +346,17 @@ func TestImportManager_TestEtcdCleanUp(t *testing.T) { callImportServiceFn := func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { if mockCallImportServiceErr { return &datapb.ImportTaskResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, errors.New("mock err") } return &datapb.ImportTaskResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, nil } callGetSegmentStates := func(ctx context.Context, req *datapb.GetSegmentStatesRequest) (*datapb.GetSegmentStatesResponse, error) { return &datapb.GetSegmentStatesResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, nil } ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) @@ -392,9 +381,9 @@ func TestImportManager_TestEtcdCleanUp(t *testing.T) { func TestImportManager_TestFlipTaskStateLoop(t *testing.T) { var countLock sync.RWMutex - var globalCount = typeutil.UniqueID(0) + 
globalCount := typeutil.UniqueID(0) - var idAlloc = func(count uint32) (typeutil.UniqueID, typeutil.UniqueID, error) { + idAlloc := func(count uint32) (typeutil.UniqueID, typeutil.UniqueID, error) { countLock.Lock() defer countLock.Unlock() globalCount++ @@ -443,30 +432,22 @@ func TestImportManager_TestFlipTaskStateLoop(t *testing.T) { callImportServiceFn := func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { if mockCallImportServiceErr { return &datapb.ImportTaskResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, errors.New("mock err") } return &datapb.ImportTaskResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, nil } callGetSegmentStates := func(ctx context.Context, req *datapb.GetSegmentStatesRequest) (*datapb.GetSegmentStatesResponse, error) { return &datapb.GetSegmentStatesResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, nil } callUnsetIsImportingState := func(context.Context, *datapb.UnsetIsImportingStateRequest) (*commonpb.Status, error) { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil + return merr.Success(), nil } flipPersistedTaskInterval = 20 @@ -520,9 +501,9 @@ func TestImportManager_TestFlipTaskStateLoop(t *testing.T) { func TestImportManager_ImportJob(t *testing.T) { var countLock sync.RWMutex - var globalCount = typeutil.UniqueID(0) + globalCount := typeutil.UniqueID(0) - var idAlloc = func(count uint32) (typeutil.UniqueID, typeutil.UniqueID, error) { + idAlloc := func(count uint32) (typeutil.UniqueID, typeutil.UniqueID, error) { countLock.Lock() defer countLock.Unlock() globalCount++ @@ -536,15 +517,13 @@ func TestImportManager_ImportJob(t *testing.T) { mockKv := memkv.NewMemoryKV() callGetSegmentStates := func(ctx context.Context, req *datapb.GetSegmentStatesRequest) 
(*datapb.GetSegmentStatesResponse, error) { return &datapb.GetSegmentStatesResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, nil } // nil request mgr := newImportManager(context.TODO(), mockKv, idAlloc, nil, callGetSegmentStates, nil, nil) resp := mgr.importJob(context.TODO(), nil, colID, 0) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) rowReq := &milvuspb.ImportRequest{ CollectionName: "c1", @@ -554,11 +533,11 @@ func TestImportManager_ImportJob(t *testing.T) { // nil callImportService resp = mgr.importJob(context.TODO(), rowReq, colID, 0) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) // row-based import not allow multiple files resp = mgr.importJob(context.TODO(), rowReq, colID, 0) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) importServiceFunc := func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { return &datapb.ImportTaskResponse{ @@ -573,6 +552,7 @@ func TestImportManager_ImportJob(t *testing.T) { rowReq.Files = []string{"f1.json"} mgr = newImportManager(context.TODO(), mockKv, idAlloc, importServiceFunc, callGetSegmentStates, nil, nil) resp = mgr.importJob(context.TODO(), rowReq, colID, 0) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, len(rowReq.Files), len(mgr.pendingTasks)) assert.Equal(t, 0, len(mgr.workingTasks)) @@ -586,26 +566,27 @@ func TestImportManager_ImportJob(t *testing.T) { // since the importServiceFunc return error, tasks will be kept in pending list mgr = newImportManager(context.TODO(), mockKv, idAlloc, importServiceFunc, callGetSegmentStates, nil, nil) resp = 
mgr.importJob(context.TODO(), colReq, colID, 0) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, 1, len(mgr.pendingTasks)) assert.Equal(t, 0, len(mgr.workingTasks)) importServiceFunc = func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { return &datapb.ImportTaskResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, nil } // row-based case, since the importServiceFunc return success, tasks will be sent to working list mgr = newImportManager(context.TODO(), mockKv, idAlloc, importServiceFunc, callGetSegmentStates, nil, nil) resp = mgr.importJob(context.TODO(), rowReq, colID, 0) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, 0, len(mgr.pendingTasks)) assert.Equal(t, len(rowReq.Files), len(mgr.workingTasks)) // column-based case, since the importServiceFunc return success, tasks will be sent to working list mgr = newImportManager(context.TODO(), mockKv, idAlloc, importServiceFunc, callGetSegmentStates, nil, nil) resp = mgr.importJob(context.TODO(), colReq, colID, 0) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, 0, len(mgr.pendingTasks)) assert.Equal(t, 1, len(mgr.workingTasks)) @@ -620,9 +601,7 @@ func TestImportManager_ImportJob(t *testing.T) { } count++ return &datapb.ImportTaskResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, nil } @@ -630,9 +609,11 @@ func TestImportManager_ImportJob(t *testing.T) { // the first task is sent to working list, and 1 task left in pending list mgr = newImportManager(context.TODO(), mockKv, idAlloc, importServiceFunc, callGetSegmentStates, nil, nil) resp = mgr.importJob(context.TODO(), rowReq, colID, 0) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, 0, len(mgr.pendingTasks)) 
assert.Equal(t, 1, len(mgr.workingTasks)) resp = mgr.importJob(context.TODO(), rowReq, colID, 0) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, 1, len(mgr.pendingTasks)) assert.Equal(t, 1, len(mgr.workingTasks)) @@ -641,18 +622,18 @@ func TestImportManager_ImportJob(t *testing.T) { for i := 0; i <= Params.RootCoordCfg.ImportMaxPendingTaskCount.GetAsInt(); i++ { resp = mgr.importJob(context.TODO(), rowReq, colID, 0) if i < Params.RootCoordCfg.ImportMaxPendingTaskCount.GetAsInt()-1 { - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) } else { - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) } } } func TestImportManager_AllDataNodesBusy(t *testing.T) { var countLock sync.RWMutex - var globalCount = typeutil.UniqueID(0) + globalCount := typeutil.UniqueID(0) - var idAlloc = func(count uint32) (typeutil.UniqueID, typeutil.UniqueID, error) { + idAlloc := func(count uint32) (typeutil.UniqueID, typeutil.UniqueID, error) { countLock.Lock() defer countLock.Unlock() globalCount++ @@ -684,9 +665,7 @@ func TestImportManager_AllDataNodesBusy(t *testing.T) { if count < len(dnList) { count++ return &datapb.ImportTaskResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), DatanodeId: dnList[count-1], }, nil } @@ -699,9 +678,7 @@ func TestImportManager_AllDataNodesBusy(t *testing.T) { callGetSegmentStates := func(ctx context.Context, req *datapb.GetSegmentStatesRequest) (*datapb.GetSegmentStatesResponse, error) { return &datapb.GetSegmentStatesResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, nil } @@ -709,7 +686,7 @@ func TestImportManager_AllDataNodesBusy(t *testing.T) { mgr := newImportManager(context.TODO(), mockKv, 
idAlloc, importServiceFunc, callGetSegmentStates, nil, nil) for i := 0; i < len(dnList); i++ { resp := mgr.importJob(context.TODO(), rowReq, colID, 0) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, 0, len(mgr.pendingTasks)) assert.Equal(t, i+1, len(mgr.workingTasks)) } @@ -717,7 +694,7 @@ func TestImportManager_AllDataNodesBusy(t *testing.T) { // all data nodes are busy, new task waiting in pending list mgr = newImportManager(context.TODO(), mockKv, idAlloc, importServiceFunc, callGetSegmentStates, nil, nil) resp := mgr.importJob(context.TODO(), rowReq, colID, 0) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, len(rowReq.Files), len(mgr.pendingTasks)) assert.Equal(t, 0, len(mgr.workingTasks)) @@ -725,32 +702,32 @@ func TestImportManager_AllDataNodesBusy(t *testing.T) { count = 0 mgr = newImportManager(context.TODO(), mockKv, idAlloc, importServiceFunc, callGetSegmentStates, nil, nil) resp = mgr.importJob(context.TODO(), colReq, colID, 0) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, 0, len(mgr.pendingTasks)) assert.Equal(t, 1, len(mgr.workingTasks)) resp = mgr.importJob(context.TODO(), colReq, colID, 0) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, 0, len(mgr.pendingTasks)) assert.Equal(t, 2, len(mgr.workingTasks)) resp = mgr.importJob(context.TODO(), colReq, colID, 0) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, 0, len(mgr.pendingTasks)) assert.Equal(t, 3, len(mgr.workingTasks)) // all 
data nodes are busy now, new task is pending resp = mgr.importJob(context.TODO(), colReq, colID, 0) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, 1, len(mgr.pendingTasks)) assert.Equal(t, 3, len(mgr.workingTasks)) } func TestImportManager_TaskState(t *testing.T) { var countLock sync.RWMutex - var globalCount = typeutil.UniqueID(0) + globalCount := typeutil.UniqueID(0) - var idAlloc = func(count uint32) (typeutil.UniqueID, typeutil.UniqueID, error) { + idAlloc := func(count uint32) (typeutil.UniqueID, typeutil.UniqueID, error) { countLock.Lock() defer countLock.Unlock() globalCount++ @@ -761,9 +738,7 @@ func TestImportManager_TaskState(t *testing.T) { mockKv := memkv.NewMemoryKV() importServiceFunc := func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { return &datapb.ImportTaskResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, nil } @@ -774,9 +749,7 @@ func TestImportManager_TaskState(t *testing.T) { } callGetSegmentStates := func(ctx context.Context, req *datapb.GetSegmentStatesRequest) (*datapb.GetSegmentStatesResponse, error) { return &datapb.GetSegmentStatesResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, nil } @@ -812,9 +785,7 @@ func TestImportManager_TaskState(t *testing.T) { } mgr.callUnsetIsImportingState = func(context.Context, *datapb.UnsetIsImportingStateRequest) (*commonpb.Status, error) { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil + return merr.Success(), nil } // index doesn't exist, the persist task will be set to completed ti, err := mgr.updateTaskInfo(info) @@ -828,14 +799,14 @@ func TestImportManager_TaskState(t *testing.T) { assert.Equal(t, int64(1000), ti.GetState().GetRowCount()) resp := mgr.getTaskState(10000) - assert.Equal(t, 
commonpb.ErrorCode_UnexpectedError, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetStatus().GetErrorCode()) resp = mgr.getTaskState(2) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, commonpb.ImportState_ImportPersisted, resp.State) resp = mgr.getTaskState(1) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, commonpb.ImportState_ImportStarted, resp.State) info = &rootcoordpb.ImportResult{ @@ -863,7 +834,7 @@ func TestImportManager_TaskState(t *testing.T) { } func TestImportManager_AllocFail(t *testing.T) { - var idAlloc = func(count uint32) (typeutil.UniqueID, typeutil.UniqueID, error) { + idAlloc := func(count uint32) (typeutil.UniqueID, typeutil.UniqueID, error) { return 0, 0, errors.New("injected failure") } paramtable.Get().Save(Params.RootCoordCfg.ImportTaskSubPath.Key, "test_import_task") @@ -871,9 +842,7 @@ func TestImportManager_AllocFail(t *testing.T) { mockKv := memkv.NewMemoryKV() importServiceFunc := func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { return &datapb.ImportTaskResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, nil } @@ -885,22 +854,20 @@ func TestImportManager_AllocFail(t *testing.T) { callGetSegmentStates := func(ctx context.Context, req *datapb.GetSegmentStatesRequest) (*datapb.GetSegmentStatesResponse, error) { return &datapb.GetSegmentStatesResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, nil } mgr := newImportManager(context.TODO(), mockKv, idAlloc, importServiceFunc, callGetSegmentStates, nil, nil) resp := mgr.importJob(context.TODO(), rowReq, colID, 0) - assert.NotEqual(t, commonpb.ErrorCode_Success, 
resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, 0, len(mgr.pendingTasks)) } func TestImportManager_ListAllTasks(t *testing.T) { var countLock sync.RWMutex - var globalCount = typeutil.UniqueID(0) + globalCount := typeutil.UniqueID(0) - var idAlloc = func(count uint32) (typeutil.UniqueID, typeutil.UniqueID, error) { + idAlloc := func(count uint32) (typeutil.UniqueID, typeutil.UniqueID, error) { countLock.Lock() defer countLock.Unlock() globalCount++ @@ -920,9 +887,7 @@ func TestImportManager_ListAllTasks(t *testing.T) { callGetSegmentStates := func(ctx context.Context, req *datapb.GetSegmentStatesRequest) (*datapb.GetSegmentStatesResponse, error) { return &datapb.GetSegmentStatesResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, nil } @@ -1031,9 +996,7 @@ func TestImportManager_ListAllTasks(t *testing.T) { // accept tasks to working list mgr.callImportService = func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { return &datapb.ImportTaskResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, nil } diff --git a/internal/rootcoord/list_db_task.go b/internal/rootcoord/list_db_task.go index 3e297329eb3f3..4c34a7424c75a 100644 --- a/internal/rootcoord/list_db_task.go +++ b/internal/rootcoord/list_db_task.go @@ -19,8 +19,8 @@ package rootcoord import ( "context" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/pkg/util/merr" ) type listDatabaseTask struct { @@ -34,10 +34,10 @@ func (t *listDatabaseTask) Prepare(ctx context.Context) error { } func (t *listDatabaseTask) Execute(ctx context.Context) error { - t.Resp.Status = succStatus() + t.Resp.Status = merr.Success() ret, err := t.core.meta.ListDatabases(ctx, t.GetTs()) if err != nil { - 
t.Resp.Status = failStatus(commonpb.ErrorCode_UnexpectedError, err.Error()) + t.Resp.Status = merr.Status(err) return err } diff --git a/internal/rootcoord/list_db_task_test.go b/internal/rootcoord/list_db_task_test.go index e5213fc9eaeda..79eea20c5ee66 100644 --- a/internal/rootcoord/list_db_task_test.go +++ b/internal/rootcoord/list_db_task_test.go @@ -20,12 +20,13 @@ import ( "context" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/metastore/model" mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" ) func Test_ListDBTask(t *testing.T) { @@ -46,7 +47,7 @@ func Test_ListDBTask(t *testing.T) { err = task.Execute(context.Background()) assert.Error(t, err) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, task.Resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, task.Resp.GetStatus().GetErrorCode()) }) t.Run("ok", func(t *testing.T) { @@ -75,6 +76,6 @@ func Test_ListDBTask(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 1, len(task.Resp.GetDbNames())) assert.Equal(t, ret[0].Name, task.Resp.GetDbNames()[0]) - assert.Equal(t, commonpb.ErrorCode_Success, task.Resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, task.Resp.GetStatus().GetErrorCode()) }) } diff --git a/internal/rootcoord/meta_table.go b/internal/rootcoord/meta_table.go index aa90d14a0cfa9..0fa5de0c45e8f 100644 --- a/internal/rootcoord/meta_table.go +++ b/internal/rootcoord/meta_table.go @@ -657,7 +657,7 @@ func (mt *MetaTable) listCollectionFromCache(dbName string, onlyAvail bool) ([]* db, ok := mt.dbName2Meta[dbName] if !ok { - return nil, fmt.Errorf("database:%s not found", dbName) + return nil, merr.WrapErrDatabaseNotFound(dbName) } collectionFromCache := 
make([]*model.Collection, 0, len(mt.collID2Meta)) @@ -730,7 +730,7 @@ func (mt *MetaTable) RenameCollection(ctx context.Context, dbName string, oldNam return fmt.Errorf("target database:%s not found", newDBName) } - //old collection should not be an alias + // old collection should not be an alias _, ok = mt.aliases.get(dbName, oldName) if ok { log.Warn("unsupported use a alias to rename collection") @@ -862,7 +862,7 @@ func (mt *MetaTable) RemovePartition(ctx context.Context, dbID int64, collection if !ok { return nil } - var loc = -1 + loc := -1 for idx, part := range coll.Partitions { if part.PartitionID == partitionID { loc = idx @@ -889,7 +889,7 @@ func (mt *MetaTable) CreateAlias(ctx context.Context, dbName string, alias strin // Since cache always keep the latest version, and the ts should always be the latest. if !mt.names.exist(dbName) { - return fmt.Errorf("database %s not found", dbName) + return merr.WrapErrDatabaseNotFound(dbName) } if collID, ok := mt.names.get(dbName, alias); ok { @@ -899,14 +899,14 @@ func (mt *MetaTable) CreateAlias(ctx context.Context, dbName string, alias strin } // allow alias with dropping&dropped if coll.State != pb.CollectionState_CollectionDropping && coll.State != pb.CollectionState_CollectionDropped { - return fmt.Errorf("cannot alter alias, collection already exists with same name: %s", alias) + return merr.WrapErrAliasCollectionNameConflict(dbName, alias) } } collectionID, ok := mt.names.get(dbName, collectionName) if !ok { // you cannot alias to a non-existent collection. - return fmt.Errorf("collection not exists: %s", collectionName) + return merr.WrapErrCollectionNotFoundWithDB(dbName, collectionName) } // check if alias exists. @@ -917,14 +917,15 @@ func (mt *MetaTable) CreateAlias(ctx context.Context, dbName string, alias strin } else if ok { // TODO: better to check if aliasedCollectionID exist or is available, though not very possible. 
aliasedColl := mt.collID2Meta[aliasedCollectionID] - return fmt.Errorf("alias exists and already aliased to another collection, alias: %s, collection: %s, other collection: %s", alias, collectionName, aliasedColl.Name) + msg := fmt.Sprintf("%s is alias to another collection: %s", alias, aliasedColl.Name) + return merr.WrapErrAliasAlreadyExist(dbName, alias, msg) } // alias didn't exist. coll, ok := mt.collID2Meta[collectionID] if !ok || !coll.Available() { // you cannot alias to a non-existent collection. - return fmt.Errorf("collection not exists: %s", collectionName) + return merr.WrapErrCollectionNotFoundWithDB(dbName, collectionName) } ctx1 := contextutil.WithTenantID(ctx, Params.CommonCfg.ClusterName.GetValue()) @@ -993,34 +994,34 @@ func (mt *MetaTable) AlterAlias(ctx context.Context, dbName string, alias string // Since cache always keep the latest version, and the ts should always be the latest. if !mt.names.exist(dbName) { - return fmt.Errorf("database not found: %s", dbName) + return merr.WrapErrDatabaseNotFound(dbName) } if collID, ok := mt.names.get(dbName, alias); ok { coll := mt.collID2Meta[collID] // allow alias with dropping&dropped if coll.State != pb.CollectionState_CollectionDropping && coll.State != pb.CollectionState_CollectionDropped { - return fmt.Errorf("cannot alter alias, collection already exists with same name: %s", alias) + return merr.WrapErrAliasCollectionNameConflict(dbName, alias) } } collectionID, ok := mt.names.get(dbName, collectionName) if !ok { // you cannot alias to a non-existent collection. - return fmt.Errorf("collection not exists: %s", collectionName) + return merr.WrapErrCollectionNotFound(collectionName) } coll, ok := mt.collID2Meta[collectionID] if !ok || !coll.Available() { // you cannot alias to a non-existent collection. - return fmt.Errorf("collection not exists: %s", collectionName) + return merr.WrapErrCollectionNotFound(collectionName) } // check if alias exists. 
_, ok = mt.aliases.get(dbName, alias) if !ok { // - return fmt.Errorf("failed to alter alias, alias does not exist: %s", alias) + return merr.WrapErrAliasNotFound(dbName, alias) } ctx1 := contextutil.WithTenantID(ctx, Params.CommonCfg.ClusterName.GetValue()) @@ -1102,7 +1103,7 @@ func (mt *MetaTable) GetPartitionNameByID(collID UniqueID, partitionID UniqueID, return partition.PartitionName, nil } } - return "", fmt.Errorf("partition not exist: %d", partitionID) + return "", merr.WrapErrPartitionNotFound(partitionID) } // GetPartitionByName serve for bulk insert. @@ -1126,7 +1127,7 @@ func (mt *MetaTable) GetPartitionByName(collID UniqueID, partitionName string, t return common.InvalidPartitionID, err } if !coll.Available() { - return common.InvalidPartitionID, fmt.Errorf("collection not exist: %d", collID) + return common.InvalidPartitionID, merr.WrapErrCollectionNotFoundWithDB(coll.DBID, collID) } for _, partition := range coll.Partitions { // no need to check time travel logic again, since catalog already did. 
@@ -1225,12 +1226,18 @@ func (mt *MetaTable) CreateRole(tenant string, entity *milvuspb.RoleEntity) erro results, err := mt.catalog.ListRole(mt.ctx, tenant, nil, false) if err != nil { - log.Error("fail to list roles", zap.Error(err)) + log.Warn("fail to list roles", zap.Error(err)) return err } + for _, result := range results { + if result.GetRole().GetName() == entity.Name { + log.Info("role already exists", zap.String("role", entity.Name)) + return common.NewIgnorableError(errors.Newf("role [%s] already exists", entity)) + } + } if len(results) >= Params.ProxyCfg.MaxRoleNum.GetAsInt() { errMsg := "unable to create role because the number of roles has reached the limit" - log.Error(errMsg, zap.Int("max_role_num", Params.ProxyCfg.MaxRoleNum.GetAsInt())) + log.Warn(errMsg, zap.Int("max_role_num", Params.ProxyCfg.MaxRoleNum.GetAsInt())) return errors.New(errMsg) } diff --git a/internal/rootcoord/meta_table_test.go b/internal/rootcoord/meta_table_test.go index 0a1635797ed23..058e95432487d 100644 --- a/internal/rootcoord/meta_table_test.go +++ b/internal/rootcoord/meta_table_test.go @@ -34,6 +34,7 @@ import ( pb "github.com/milvus-io/milvus/internal/proto/etcdpb" "github.com/milvus-io/milvus/internal/proto/internalpb" mocktso "github.com/milvus-io/milvus/internal/tso/mocks" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -102,6 +103,11 @@ func TestRbacCreateRole(t *testing.T) { assert.Error(t, err) }) } + t.Run("role has existed", func(t *testing.T) { + err := mt.CreateRole(util.DefaultTenant, &milvuspb.RoleEntity{Name: "role1"}) + assert.Error(t, err) + assert.True(t, common.IsIgnorableError(err)) + }) { mockCata := mocks.NewRootCoordCatalog(t) @@ -165,7 +171,6 @@ func TestRbacOperateRole(t *testing.T) { assert.Error(t, err) }) } - } func TestRbacSelect(t *testing.T) { @@ -263,7 +268,6 @@ func TestRbacSelect(t *testing.T) { if 
test.isValid { assert.NoError(t, err) assert.Equal(t, test.expectedOutLength, len(res)) - } else { assert.Error(t, err) } @@ -283,57 +287,71 @@ func TestRbacOperatePrivilege(t *testing.T) { {"empty objectName", &milvuspb.GrantEntity{ObjectName: ""}, milvuspb.OperatePrivilegeType_Grant}, {"nil Object", &milvuspb.GrantEntity{ Object: nil, - ObjectName: "obj_name"}, milvuspb.OperatePrivilegeType_Grant}, + ObjectName: "obj_name", + }, milvuspb.OperatePrivilegeType_Grant}, {"empty Object name", &milvuspb.GrantEntity{ Object: &milvuspb.ObjectEntity{Name: ""}, - ObjectName: "obj_name"}, milvuspb.OperatePrivilegeType_Grant}, + ObjectName: "obj_name", + }, milvuspb.OperatePrivilegeType_Grant}, {"nil Role", &milvuspb.GrantEntity{ Role: nil, Object: &milvuspb.ObjectEntity{Name: "obj_name"}, - ObjectName: "obj_name"}, milvuspb.OperatePrivilegeType_Grant}, + ObjectName: "obj_name", + }, milvuspb.OperatePrivilegeType_Grant}, {"empty Role name", &milvuspb.GrantEntity{ Role: &milvuspb.RoleEntity{Name: ""}, Object: &milvuspb.ObjectEntity{Name: "obj_name"}, - ObjectName: "obj_name"}, milvuspb.OperatePrivilegeType_Grant}, + ObjectName: "obj_name", + }, milvuspb.OperatePrivilegeType_Grant}, {"nil grantor", &milvuspb.GrantEntity{ Grantor: nil, Role: &milvuspb.RoleEntity{Name: "role_name"}, Object: &milvuspb.ObjectEntity{Name: "obj_name"}, - ObjectName: "obj_name"}, milvuspb.OperatePrivilegeType_Grant}, + ObjectName: "obj_name", + }, milvuspb.OperatePrivilegeType_Grant}, {"nil grantor privilege", &milvuspb.GrantEntity{ Grantor: &milvuspb.GrantorEntity{ Privilege: nil, }, Role: &milvuspb.RoleEntity{Name: "role_name"}, Object: &milvuspb.ObjectEntity{Name: "obj_name"}, - ObjectName: "obj_name"}, milvuspb.OperatePrivilegeType_Grant}, + ObjectName: "obj_name", + }, milvuspb.OperatePrivilegeType_Grant}, {"empty grantor privilege name", &milvuspb.GrantEntity{ Grantor: &milvuspb.GrantorEntity{ - Privilege: &milvuspb.PrivilegeEntity{Name: ""}}, + Privilege: &milvuspb.PrivilegeEntity{Name: ""}, + 
}, Role: &milvuspb.RoleEntity{Name: "role_name"}, Object: &milvuspb.ObjectEntity{Name: "obj_name"}, - ObjectName: "obj_name"}, milvuspb.OperatePrivilegeType_Grant}, + ObjectName: "obj_name", + }, milvuspb.OperatePrivilegeType_Grant}, {"nil grantor user", &milvuspb.GrantEntity{ Grantor: &milvuspb.GrantorEntity{ User: nil, - Privilege: &milvuspb.PrivilegeEntity{Name: "privilege_name"}}, + Privilege: &milvuspb.PrivilegeEntity{Name: "privilege_name"}, + }, Role: &milvuspb.RoleEntity{Name: "role_name"}, Object: &milvuspb.ObjectEntity{Name: "obj_name"}, - ObjectName: "obj_name"}, milvuspb.OperatePrivilegeType_Grant}, + ObjectName: "obj_name", + }, milvuspb.OperatePrivilegeType_Grant}, {"empty grantor user name", &milvuspb.GrantEntity{ Grantor: &milvuspb.GrantorEntity{ User: &milvuspb.UserEntity{Name: ""}, - Privilege: &milvuspb.PrivilegeEntity{Name: "privilege_name"}}, + Privilege: &milvuspb.PrivilegeEntity{Name: "privilege_name"}, + }, Role: &milvuspb.RoleEntity{Name: "role_name"}, Object: &milvuspb.ObjectEntity{Name: "obj_name"}, - ObjectName: "obj_name"}, milvuspb.OperatePrivilegeType_Grant}, + ObjectName: "obj_name", + }, milvuspb.OperatePrivilegeType_Grant}, {"invalid operateType", &milvuspb.GrantEntity{ Grantor: &milvuspb.GrantorEntity{ User: &milvuspb.UserEntity{Name: "user_name"}, - Privilege: &milvuspb.PrivilegeEntity{Name: "privilege_name"}}, + Privilege: &milvuspb.PrivilegeEntity{Name: "privilege_name"}, + }, Role: &milvuspb.RoleEntity{Name: "role_name"}, Object: &milvuspb.ObjectEntity{Name: "obj_name"}, - ObjectName: "obj_name"}, milvuspb.OperatePrivilegeType(-1)}, + ObjectName: "obj_name", + }, milvuspb.OperatePrivilegeType(-1)}, } for _, test := range tests { @@ -346,10 +364,12 @@ func TestRbacOperatePrivilege(t *testing.T) { validEntity := milvuspb.GrantEntity{ Grantor: &milvuspb.GrantorEntity{ User: &milvuspb.UserEntity{Name: "user_name"}, - Privilege: &milvuspb.PrivilegeEntity{Name: "privilege_name"}}, + Privilege: &milvuspb.PrivilegeEntity{Name: 
"privilege_name"}, + }, Role: &milvuspb.RoleEntity{Name: "role_name"}, Object: &milvuspb.ObjectEntity{Name: "obj_name"}, - ObjectName: "obj_name"} + ObjectName: "obj_name", + } err := mt.OperatePrivilege(util.DefaultTenant, &validEntity, milvuspb.OperatePrivilegeType_Grant) assert.NoError(t, err) @@ -366,11 +386,14 @@ func TestRbacSelectGrant(t *testing.T) { }{ {"nil Entity", false, nil}, {"nil entity Role", false, &milvuspb.GrantEntity{ - Role: nil}}, + Role: nil, + }}, {"empty entity Role name", false, &milvuspb.GrantEntity{ - Role: &milvuspb.RoleEntity{Name: ""}}}, + Role: &milvuspb.RoleEntity{Name: ""}, + }}, {"valid", true, &milvuspb.GrantEntity{ - Role: &milvuspb.RoleEntity{Name: "role"}}}, + Role: &milvuspb.RoleEntity{Name: "role"}, + }}, } for _, test := range tests { diff --git a/internal/rootcoord/metrics_info.go b/internal/rootcoord/metrics_info.go index d32e887c57850..c54e3d9890120 100644 --- a/internal/rootcoord/metrics_info.go +++ b/internal/rootcoord/metrics_info.go @@ -21,7 +21,6 @@ import ( "go.uber.org/zap" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/hardware" @@ -69,17 +68,14 @@ func (c *Core) getSystemInfoMetrics(ctx context.Context, req *milvuspb.GetMetric zap.Error(err)) return &milvuspb.GetMetricsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, + Status: merr.Status(err), Response: "", ComponentName: metricsinfo.ConstructComponentName(typeutil.RootCoordRole, c.session.ServerID), }, nil } return &milvuspb.GetMetricsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Response: resp, ComponentName: metricsinfo.ConstructComponentName(typeutil.RootCoordRole, c.session.ServerID), }, nil diff --git a/internal/rootcoord/mock_test.go b/internal/rootcoord/mock_test.go index f1ebe477f0ff6..0922debfe44c6 100644 --- 
a/internal/rootcoord/mock_test.go +++ b/internal/rootcoord/mock_test.go @@ -22,6 +22,10 @@ import ( "os" "github.com/cockroachdb/errors" + "github.com/stretchr/testify/mock" + "go.uber.org/zap" + "google.golang.org/grpc" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/allocator" @@ -39,12 +43,11 @@ import ( "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/retry" "github.com/milvus-io/milvus/pkg/util/typeutil" - "github.com/stretchr/testify/mock" - "go.uber.org/zap" ) const ( @@ -259,7 +262,7 @@ func newMockMetaTable() *mockMetaTable { //} type mockDataCoord struct { - types.DataCoord + types.DataCoordClient GetComponentStatesFunc func(ctx context.Context) (*milvuspb.ComponentStates, error) WatchChannelsFunc func(ctx context.Context, req *datapb.WatchChannelsRequest) (*datapb.WatchChannelsResponse, error) FlushFunc func(ctx context.Context, req *datapb.FlushRequest) (*datapb.FlushResponse, error) @@ -274,60 +277,60 @@ func newMockDataCoord() *mockDataCoord { return &mockDataCoord{} } -func (m *mockDataCoord) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { +func (m *mockDataCoord) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { return m.GetComponentStatesFunc(ctx) } -func (m *mockDataCoord) WatchChannels(ctx context.Context, req *datapb.WatchChannelsRequest) (*datapb.WatchChannelsResponse, error) { +func (m *mockDataCoord) WatchChannels(ctx context.Context, req *datapb.WatchChannelsRequest, opts ...grpc.CallOption) (*datapb.WatchChannelsResponse, error) { return 
m.WatchChannelsFunc(ctx, req) } -func (m *mockDataCoord) Flush(ctx context.Context, req *datapb.FlushRequest) (*datapb.FlushResponse, error) { +func (m *mockDataCoord) Flush(ctx context.Context, req *datapb.FlushRequest, opts ...grpc.CallOption) (*datapb.FlushResponse, error) { return m.FlushFunc(ctx, req) } -func (m *mockDataCoord) Import(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { +func (m *mockDataCoord) Import(ctx context.Context, req *datapb.ImportTaskRequest, opts ...grpc.CallOption) (*datapb.ImportTaskResponse, error) { return m.ImportFunc(ctx, req) } -func (m *mockDataCoord) UnsetIsImportingState(ctx context.Context, req *datapb.UnsetIsImportingStateRequest) (*commonpb.Status, error) { +func (m *mockDataCoord) UnsetIsImportingState(ctx context.Context, req *datapb.UnsetIsImportingStateRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return m.UnsetIsImportingStateFunc(ctx, req) } -func (m *mockDataCoord) BroadcastAlteredCollection(ctx context.Context, req *datapb.AlterCollectionRequest) (*commonpb.Status, error) { +func (m *mockDataCoord) BroadcastAlteredCollection(ctx context.Context, req *datapb.AlterCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return m.broadCastAlteredCollectionFunc(ctx, req) } -func (m *mockDataCoord) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { +func (m *mockDataCoord) CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest, opts ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error) { return &milvuspb.CheckHealthResponse{ IsHealthy: true, }, nil } -func (m *mockDataCoord) GetSegmentIndexState(ctx context.Context, req *indexpb.GetSegmentIndexStateRequest) (*indexpb.GetSegmentIndexStateResponse, error) { +func (m *mockDataCoord) GetSegmentIndexState(ctx context.Context, req *indexpb.GetSegmentIndexStateRequest, opts ...grpc.CallOption) (*indexpb.GetSegmentIndexStateResponse, 
error) { return m.GetSegmentIndexStateFunc(ctx, req) } -func (m *mockDataCoord) DropIndex(ctx context.Context, req *indexpb.DropIndexRequest) (*commonpb.Status, error) { +func (m *mockDataCoord) DropIndex(ctx context.Context, req *indexpb.DropIndexRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return m.DropIndexFunc(ctx, req) } type mockQueryCoord struct { - types.QueryCoord + types.QueryCoordClient GetSegmentInfoFunc func(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) GetComponentStatesFunc func(ctx context.Context) (*milvuspb.ComponentStates, error) ReleaseCollectionFunc func(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) } -func (m mockQueryCoord) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) { +func (m mockQueryCoord) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest, opts ...grpc.CallOption) (*querypb.GetSegmentInfoResponse, error) { return m.GetSegmentInfoFunc(ctx, req) } -func (m mockQueryCoord) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { +func (m mockQueryCoord) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { return m.GetComponentStatesFunc(ctx) } -func (m mockQueryCoord) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) { +func (m mockQueryCoord) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return m.ReleaseCollectionFunc(ctx, req) } @@ -355,40 +358,40 @@ func newMockTsoAllocator() *tso.MockAllocator { } type mockProxy struct { - types.Proxy + types.ProxyClient InvalidateCollectionMetaCacheFunc func(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) 
InvalidateCredentialCacheFunc func(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) (*commonpb.Status, error) RefreshPolicyInfoCacheFunc func(ctx context.Context, request *proxypb.RefreshPolicyInfoCacheRequest) (*commonpb.Status, error) GetComponentStatesFunc func(ctx context.Context) (*milvuspb.ComponentStates, error) } -func (m mockProxy) InvalidateCollectionMetaCache(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { +func (m mockProxy) InvalidateCollectionMetaCache(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return m.InvalidateCollectionMetaCacheFunc(ctx, request) } -func (m mockProxy) InvalidateCredentialCache(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) (*commonpb.Status, error) { +func (m mockProxy) InvalidateCredentialCache(ctx context.Context, request *proxypb.InvalidateCredCacheRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return m.InvalidateCredentialCacheFunc(ctx, request) } -func (m mockProxy) RefreshPolicyInfoCache(ctx context.Context, request *proxypb.RefreshPolicyInfoCacheRequest) (*commonpb.Status, error) { +func (m mockProxy) RefreshPolicyInfoCache(ctx context.Context, request *proxypb.RefreshPolicyInfoCacheRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return m.RefreshPolicyInfoCacheFunc(ctx, request) } -func (m mockProxy) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { +func (m mockProxy) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { return m.GetComponentStatesFunc(ctx) } func newMockProxy() *mockProxy { r := &mockProxy{} r.InvalidateCollectionMetaCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { - return succStatus(), nil + return merr.Success(), nil } return r } 
func newTestCore(opts ...Opt) *Core { c := &Core{ - session: &sessionutil.Session{ServerID: TestRootCoordID}, + session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: TestRootCoordID}}, } executor := newMockStepExecutor() executor.AddStepsFunc = func(s *stepStack) { @@ -406,11 +409,11 @@ func newTestCore(opts ...Opt) *Core { func withValidProxyManager() Opt { return func(c *Core) { c.proxyClientManager = &proxyClientManager{ - proxyClient: make(map[UniqueID]types.Proxy), + proxyClient: make(map[UniqueID]types.ProxyClient), } p := newMockProxy() p.InvalidateCollectionMetaCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { - return succStatus(), nil + return merr.Success(), nil } p.GetComponentStatesFunc = func(ctx context.Context) (*milvuspb.ComponentStates, error) { return &milvuspb.ComponentStates{ @@ -425,11 +428,11 @@ func withValidProxyManager() Opt { func withInvalidProxyManager() Opt { return func(c *Core) { c.proxyClientManager = &proxyClientManager{ - proxyClient: make(map[UniqueID]types.Proxy), + proxyClient: make(map[UniqueID]types.ProxyClient), } p := newMockProxy() p.InvalidateCollectionMetaCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { - return succStatus(), errors.New("error mock InvalidateCollectionMetaCache") + return merr.Success(), errors.New("error mock InvalidateCollectionMetaCache") } p.GetComponentStatesFunc = func(ctx context.Context) (*milvuspb.ComponentStates, error) { return &milvuspb.ComponentStates{ @@ -559,29 +562,30 @@ func withInvalidIDAllocator() Opt { return withIDAllocator(idAllocator) } -func withQueryCoord(qc types.QueryCoord) Opt { +func withQueryCoord(qc types.QueryCoordClient) Opt { return func(c *Core) { c.queryCoord = qc } } func withUnhealthyQueryCoord() Opt { - qc := &mocks.MockQueryCoord{} - qc.EXPECT().GetComponentStates(mock.Anything).Return( + qc := 
&mocks.MockQueryCoordClient{} + err := errors.New("mock error") + qc.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return( &milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{StateCode: commonpb.StateCode_Abnormal}, - Status: failStatus(commonpb.ErrorCode_UnexpectedError, "error mock GetComponentStates"), + Status: merr.Status(err), }, retry.Unrecoverable(errors.New("error mock GetComponentStates")), ) return withQueryCoord(qc) } func withInvalidQueryCoord() Opt { - qc := &mocks.MockQueryCoord{} - qc.EXPECT().GetComponentStates(mock.Anything).Return( + qc := &mocks.MockQueryCoordClient{} + qc.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return( &milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{StateCode: commonpb.StateCode_Healthy}, - Status: succStatus(), + Status: merr.Success(), }, nil, ) qc.EXPECT().ReleaseCollection(mock.Anything, mock.Anything).Return( @@ -596,20 +600,21 @@ func withInvalidQueryCoord() Opt { } func withFailedQueryCoord() Opt { - qc := &mocks.MockQueryCoord{} - qc.EXPECT().GetComponentStates(mock.Anything).Return( + qc := &mocks.MockQueryCoordClient{} + qc.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return( &milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{StateCode: commonpb.StateCode_Healthy}, - Status: succStatus(), + Status: merr.Success(), }, nil, ) + err := errors.New("mock error") qc.EXPECT().ReleaseCollection(mock.Anything, mock.Anything).Return( - failStatus(commonpb.ErrorCode_UnexpectedError, "mock release collection error"), nil, + merr.Status(err), nil, ) qc.EXPECT().GetSegmentInfo(mock.Anything, mock.Anything).Return( &querypb.GetSegmentInfoResponse{ - Status: failStatus(commonpb.ErrorCode_UnexpectedError, "mock get segment info error"), + Status: merr.Status(err), }, nil, ) @@ -617,29 +622,29 @@ func withFailedQueryCoord() Opt { } func withValidQueryCoord() Opt { - qc := &mocks.MockQueryCoord{} - qc.EXPECT().GetComponentStates(mock.Anything).Return( + qc := 
&mocks.MockQueryCoordClient{} + qc.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return( &milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{StateCode: commonpb.StateCode_Healthy}, - Status: succStatus(), + Status: merr.Success(), }, nil, ) qc.EXPECT().ReleaseCollection(mock.Anything, mock.Anything).Return( - succStatus(), nil, + merr.Success(), nil, ) qc.EXPECT().ReleasePartitions(mock.Anything, mock.Anything).Return( - succStatus(), nil, + merr.Success(), nil, ) qc.EXPECT().GetSegmentInfo(mock.Anything, mock.Anything).Return( &querypb.GetSegmentInfoResponse{ - Status: succStatus(), + Status: merr.Success(), }, nil, ) qc.EXPECT().SyncNewCreatedPartition(mock.Anything, mock.Anything).Return( - succStatus(), nil, + merr.Success(), nil, ) return withQueryCoord(qc) @@ -676,7 +681,7 @@ func withRocksMqTtSynchronizer() Opt { return withTtSynchronizer(ticker) } -func withDataCoord(dc types.DataCoord) Opt { +func withDataCoord(dc types.DataCoordClient) Opt { return func(c *Core) { c.dataCoord = dc } @@ -684,10 +689,11 @@ func withDataCoord(dc types.DataCoord) Opt { func withUnhealthyDataCoord() Opt { dc := newMockDataCoord() + err := errors.New("mock error") dc.GetComponentStatesFunc = func(ctx context.Context) (*milvuspb.ComponentStates, error) { return &milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{StateCode: commonpb.StateCode_Abnormal}, - Status: failStatus(commonpb.ErrorCode_UnexpectedError, "error mock GetComponentStates"), + Status: merr.Status(err), }, retry.Unrecoverable(errors.New("error mock GetComponentStates")) } return withDataCoord(dc) @@ -698,7 +704,7 @@ func withInvalidDataCoord() Opt { dc.GetComponentStatesFunc = func(ctx context.Context) (*milvuspb.ComponentStates, error) { return &milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{StateCode: commonpb.StateCode_Healthy}, - Status: succStatus(), + Status: merr.Success(), }, nil } dc.WatchChannelsFunc = func(ctx context.Context, req *datapb.WatchChannelsRequest) 
(*datapb.WatchChannelsResponse, error) { @@ -733,22 +739,23 @@ func withFailedDataCoord() Opt { dc.GetComponentStatesFunc = func(ctx context.Context) (*milvuspb.ComponentStates, error) { return &milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{StateCode: commonpb.StateCode_Healthy}, - Status: succStatus(), + Status: merr.Success(), }, nil } + err := errors.New("mock error") dc.WatchChannelsFunc = func(ctx context.Context, req *datapb.WatchChannelsRequest) (*datapb.WatchChannelsResponse, error) { return &datapb.WatchChannelsResponse{ - Status: failStatus(commonpb.ErrorCode_UnexpectedError, "mock watch channels error"), + Status: merr.Status(err), }, nil } dc.FlushFunc = func(ctx context.Context, req *datapb.FlushRequest) (*datapb.FlushResponse, error) { return &datapb.FlushResponse{ - Status: failStatus(commonpb.ErrorCode_UnexpectedError, "mock flush error"), + Status: merr.Status(err), }, nil } dc.ImportFunc = func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { return &datapb.ImportTaskResponse{ - Status: failStatus(commonpb.ErrorCode_UnexpectedError, "mock import error"), + Status: merr.Status(err), }, nil } dc.UnsetIsImportingStateFunc = func(ctx context.Context, req *datapb.UnsetIsImportingStateRequest) (*commonpb.Status, error) { @@ -758,15 +765,15 @@ func withFailedDataCoord() Opt { }, nil } dc.broadCastAlteredCollectionFunc = func(ctx context.Context, req *datapb.AlterCollectionRequest) (*commonpb.Status, error) { - return failStatus(commonpb.ErrorCode_UnexpectedError, "mock broadcast altered collection error"), nil + return merr.Status(err), nil } dc.GetSegmentIndexStateFunc = func(ctx context.Context, req *indexpb.GetSegmentIndexStateRequest) (*indexpb.GetSegmentIndexStateResponse, error) { return &indexpb.GetSegmentIndexStateResponse{ - Status: failStatus(commonpb.ErrorCode_UnexpectedError, "mock GetSegmentIndexStateFunc fail"), + Status: merr.Status(err), }, nil } dc.DropIndexFunc = func(ctx 
context.Context, req *indexpb.DropIndexRequest) (*commonpb.Status, error) { - return failStatus(commonpb.ErrorCode_UnexpectedError, "mock DropIndexFunc fail"), nil + return merr.Status(err), nil } return withDataCoord(dc) } @@ -776,37 +783,37 @@ func withValidDataCoord() Opt { dc.GetComponentStatesFunc = func(ctx context.Context) (*milvuspb.ComponentStates, error) { return &milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{StateCode: commonpb.StateCode_Healthy}, - Status: succStatus(), + Status: merr.Success(), }, nil } dc.WatchChannelsFunc = func(ctx context.Context, req *datapb.WatchChannelsRequest) (*datapb.WatchChannelsResponse, error) { return &datapb.WatchChannelsResponse{ - Status: succStatus(), + Status: merr.Success(), }, nil } dc.FlushFunc = func(ctx context.Context, req *datapb.FlushRequest) (*datapb.FlushResponse, error) { return &datapb.FlushResponse{ - Status: succStatus(), + Status: merr.Success(), }, nil } dc.ImportFunc = func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { return &datapb.ImportTaskResponse{ - Status: succStatus(), + Status: merr.Success(), }, nil } dc.UnsetIsImportingStateFunc = func(ctx context.Context, req *datapb.UnsetIsImportingStateRequest) (*commonpb.Status, error) { - return succStatus(), nil + return merr.Success(), nil } dc.broadCastAlteredCollectionFunc = func(ctx context.Context, req *datapb.AlterCollectionRequest) (*commonpb.Status, error) { - return succStatus(), nil + return merr.Success(), nil } dc.GetSegmentIndexStateFunc = func(ctx context.Context, req *indexpb.GetSegmentIndexStateRequest) (*indexpb.GetSegmentIndexStateResponse, error) { return &indexpb.GetSegmentIndexStateResponse{ - Status: succStatus(), + Status: merr.Success(), }, nil } dc.DropIndexFunc = func(ctx context.Context, req *indexpb.DropIndexRequest) (*commonpb.Status, error) { - return succStatus(), nil + return merr.Success(), nil } return withDataCoord(dc) } @@ -1107,21 +1114,18 @@ func 
newMockStepExecutor() *mockStepExecutor { func (m mockStepExecutor) Start() { if m.StartFunc != nil { m.StartFunc() - } else { } } func (m mockStepExecutor) Stop() { if m.StopFunc != nil { m.StopFunc() - } else { } } func (m mockStepExecutor) AddSteps(s *stepStack) { if m.AddStepsFunc != nil { m.AddStepsFunc(s) - } else { } } diff --git a/internal/rootcoord/proxy_client_manager.go b/internal/rootcoord/proxy_client_manager.go index cd5073d76275d..3b5af86fc5bbd 100644 --- a/internal/rootcoord/proxy_client_manager.go +++ b/internal/rootcoord/proxy_client_manager.go @@ -21,7 +21,6 @@ import ( "fmt" "sync" - "github.com/cockroachdb/errors" "go.uber.org/zap" "golang.org/x/sync/errgroup" @@ -33,29 +32,24 @@ import ( "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" ) -type proxyCreator func(ctx context.Context, addr string, nodeID int64) (types.Proxy, error) +type proxyCreator func(ctx context.Context, addr string, nodeID int64) (types.ProxyClient, error) -func DefaultProxyCreator(ctx context.Context, addr string, nodeID int64) (types.Proxy, error) { +func DefaultProxyCreator(ctx context.Context, addr string, nodeID int64) (types.ProxyClient, error) { cli, err := grpcproxyclient.NewClient(ctx, addr, nodeID) if err != nil { return nil, err } - if err := cli.Init(); err != nil { - return nil, err - } - if err := cli.Start(); err != nil { - return nil, err - } return cli, nil } type proxyClientManager struct { creator proxyCreator lock sync.RWMutex - proxyClient map[int64]types.Proxy + proxyClient map[int64]types.ProxyClient helper proxyClientManagerHelper } @@ -70,7 +64,7 @@ var defaultClientManagerHelper = proxyClientManagerHelper{ func newProxyClientManager(creator proxyCreator) *proxyClientManager { return &proxyClientManager{ creator: creator, - proxyClient: make(map[int64]types.Proxy), + 
proxyClient: make(map[int64]types.ProxyClient), helper: defaultClientManagerHelper, } } @@ -118,7 +112,7 @@ func (p *proxyClientManager) connect(session *sessionutil.Session) { _, ok := p.proxyClient[session.ServerID] if ok { - pc.Stop() + pc.Close() return } p.proxyClient[session.ServerID] = pc @@ -132,7 +126,7 @@ func (p *proxyClientManager) DelProxyClient(s *sessionutil.Session) { cli, ok := p.proxyClient[s.ServerID] if ok { - cli.Stop() + cli.Close() } delete(p.proxyClient, s.ServerID) @@ -245,7 +239,7 @@ func (p *proxyClientManager) RefreshPolicyInfoCache(ctx context.Context, req *pr return fmt.Errorf("RefreshPolicyInfoCache failed, proxyID = %d, err = %s", k, err) } if status.GetErrorCode() != commonpb.ErrorCode_Success { - return errors.New(status.GetReason()) + return merr.Error(status) } return nil }) diff --git a/internal/rootcoord/proxy_client_manager_test.go b/internal/rootcoord/proxy_client_manager_test.go index b151f507a9576..9892776c62d69 100644 --- a/internal/rootcoord/proxy_client_manager_test.go +++ b/internal/rootcoord/proxy_client_manager_test.go @@ -30,11 +30,12 @@ import ( "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/util/etcd" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" ) type proxyMock struct { - types.Proxy + types.ProxyClient collArray []string collIDs []UniqueID mutex sync.Mutex @@ -60,9 +61,7 @@ func (p *proxyMock) InvalidateCollectionMetaCache(ctx context.Context, request * } p.collArray = append(p.collArray, request.CollectionName) p.collIDs = append(p.collIDs, request.CollectionID) - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil + return merr.Success(), nil } func (p *proxyMock) GetCollArray() []string { @@ -89,16 +88,11 @@ func (p *proxyMock) InvalidateCredentialCache(ctx context.Context, request *prox if p.returnGrpcError { return nil, fmt.Errorf("grpc error") } - 
return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, nil + return merr.Success(), nil } func (p *proxyMock) RefreshPolicyInfoCache(ctx context.Context, req *proxypb.RefreshPolicyInfoCacheRequest) (*commonpb.Status, error) { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil + return merr.Success(), nil } func TestProxyClientManager_GetProxyClients(t *testing.T) { @@ -117,15 +111,17 @@ func TestProxyClientManager_GetProxyClients(t *testing.T) { defer cli.Close() assert.NoError(t, err) core.etcdCli = cli - core.proxyCreator = func(ctx context.Context, addr string, nodeID int64) (types.Proxy, error) { + core.proxyCreator = func(ctx context.Context, addr string, nodeID int64) (types.ProxyClient, error) { return nil, errors.New("failed") } pcm := newProxyClientManager(core.proxyCreator) session := &sessionutil.Session{ - ServerID: 100, - Address: "localhost", + SessionRaw: sessionutil.SessionRaw{ + ServerID: 100, + Address: "localhost", + }, } sessions := []*sessionutil.Session{session} @@ -149,15 +145,17 @@ func TestProxyClientManager_AddProxyClient(t *testing.T) { defer cli.Close() core.etcdCli = cli - core.proxyCreator = func(ctx context.Context, addr string, nodeID int64) (types.Proxy, error) { + core.proxyCreator = func(ctx context.Context, addr string, nodeID int64) (types.ProxyClient, error) { return nil, errors.New("failed") } pcm := newProxyClientManager(core.proxyCreator) session := &sessionutil.Session{ - ServerID: 100, - Address: "localhost", + SessionRaw: sessionutil.SessionRaw{ + ServerID: 100, + Address: "localhost", + }, } pcm.AddProxyClient(session) @@ -166,7 +164,7 @@ func TestProxyClientManager_AddProxyClient(t *testing.T) { func TestProxyClientManager_InvalidateCollectionMetaCache(t *testing.T) { t.Run("empty proxy list", func(t *testing.T) { ctx := context.Background() - pcm := &proxyClientManager{proxyClient: map[int64]types.Proxy{}} + pcm := &proxyClientManager{proxyClient: 
map[int64]types.ProxyClient{}} err := pcm.InvalidateCollectionMetaCache(ctx, &proxypb.InvalidateCollMetaCacheRequest{}) assert.NoError(t, err) }) @@ -175,9 +173,9 @@ func TestProxyClientManager_InvalidateCollectionMetaCache(t *testing.T) { ctx := context.Background() p1 := newMockProxy() p1.InvalidateCollectionMetaCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { - return succStatus(), errors.New("error mock InvalidateCollectionMetaCache") + return merr.Success(), errors.New("error mock InvalidateCollectionMetaCache") } - pcm := &proxyClientManager{proxyClient: map[int64]types.Proxy{ + pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{ TestProxyID: p1, }} err := pcm.InvalidateCollectionMetaCache(ctx, &proxypb.InvalidateCollMetaCacheRequest{}) @@ -187,10 +185,11 @@ func TestProxyClientManager_InvalidateCollectionMetaCache(t *testing.T) { t.Run("mock error code", func(t *testing.T) { ctx := context.Background() p1 := newMockProxy() + mockErr := errors.New("mock error") p1.InvalidateCollectionMetaCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { - return failStatus(commonpb.ErrorCode_UnexpectedError, "error mock error code"), nil + return merr.Status(mockErr), nil } - pcm := &proxyClientManager{proxyClient: map[int64]types.Proxy{ + pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{ TestProxyID: p1, }} err := pcm.InvalidateCollectionMetaCache(ctx, &proxypb.InvalidateCollMetaCacheRequest{}) @@ -201,9 +200,9 @@ func TestProxyClientManager_InvalidateCollectionMetaCache(t *testing.T) { ctx := context.Background() p1 := newMockProxy() p1.InvalidateCollectionMetaCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { - return succStatus(), nil + return merr.Success(), nil } - pcm := &proxyClientManager{proxyClient: map[int64]types.Proxy{ + pcm := 
&proxyClientManager{proxyClient: map[int64]types.ProxyClient{ TestProxyID: p1, }} err := pcm.InvalidateCollectionMetaCache(ctx, &proxypb.InvalidateCollMetaCacheRequest{}) @@ -214,7 +213,7 @@ func TestProxyClientManager_InvalidateCollectionMetaCache(t *testing.T) { func TestProxyClientManager_InvalidateCredentialCache(t *testing.T) { t.Run("empty proxy list", func(t *testing.T) { ctx := context.Background() - pcm := &proxyClientManager{proxyClient: map[int64]types.Proxy{}} + pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{}} err := pcm.InvalidateCredentialCache(ctx, &proxypb.InvalidateCredCacheRequest{}) assert.NoError(t, err) }) @@ -223,9 +222,9 @@ func TestProxyClientManager_InvalidateCredentialCache(t *testing.T) { ctx := context.Background() p1 := newMockProxy() p1.InvalidateCredentialCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) (*commonpb.Status, error) { - return succStatus(), errors.New("error mock InvalidateCredentialCache") + return merr.Success(), errors.New("error mock InvalidateCredentialCache") } - pcm := &proxyClientManager{proxyClient: map[int64]types.Proxy{ + pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{ TestProxyID: p1, }} err := pcm.InvalidateCredentialCache(ctx, &proxypb.InvalidateCredCacheRequest{}) @@ -235,10 +234,11 @@ func TestProxyClientManager_InvalidateCredentialCache(t *testing.T) { t.Run("mock error code", func(t *testing.T) { ctx := context.Background() p1 := newMockProxy() + mockErr := errors.New("mock error") p1.InvalidateCredentialCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) (*commonpb.Status, error) { - return failStatus(commonpb.ErrorCode_UnexpectedError, "error mock error code"), nil + return merr.Status(mockErr), nil } - pcm := &proxyClientManager{proxyClient: map[int64]types.Proxy{ + pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{ TestProxyID: p1, }} err := pcm.InvalidateCredentialCache(ctx, 
&proxypb.InvalidateCredCacheRequest{}) @@ -249,9 +249,9 @@ func TestProxyClientManager_InvalidateCredentialCache(t *testing.T) { ctx := context.Background() p1 := newMockProxy() p1.InvalidateCredentialCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) (*commonpb.Status, error) { - return succStatus(), nil + return merr.Success(), nil } - pcm := &proxyClientManager{proxyClient: map[int64]types.Proxy{ + pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{ TestProxyID: p1, }} err := pcm.InvalidateCredentialCache(ctx, &proxypb.InvalidateCredCacheRequest{}) @@ -262,7 +262,7 @@ func TestProxyClientManager_InvalidateCredentialCache(t *testing.T) { func TestProxyClientManager_RefreshPolicyInfoCache(t *testing.T) { t.Run("empty proxy list", func(t *testing.T) { ctx := context.Background() - pcm := &proxyClientManager{proxyClient: map[int64]types.Proxy{}} + pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{}} err := pcm.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{}) assert.NoError(t, err) }) @@ -271,9 +271,9 @@ func TestProxyClientManager_RefreshPolicyInfoCache(t *testing.T) { ctx := context.Background() p1 := newMockProxy() p1.RefreshPolicyInfoCacheFunc = func(ctx context.Context, request *proxypb.RefreshPolicyInfoCacheRequest) (*commonpb.Status, error) { - return succStatus(), errors.New("error mock RefreshPolicyInfoCache") + return merr.Success(), errors.New("error mock RefreshPolicyInfoCache") } - pcm := &proxyClientManager{proxyClient: map[int64]types.Proxy{ + pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{ TestProxyID: p1, }} err := pcm.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{}) @@ -283,10 +283,11 @@ func TestProxyClientManager_RefreshPolicyInfoCache(t *testing.T) { t.Run("mock error code", func(t *testing.T) { ctx := context.Background() p1 := newMockProxy() + mockErr := errors.New("mock error") p1.RefreshPolicyInfoCacheFunc = 
func(ctx context.Context, request *proxypb.RefreshPolicyInfoCacheRequest) (*commonpb.Status, error) { - return failStatus(commonpb.ErrorCode_UnexpectedError, "error mock error code"), nil + return merr.Status(mockErr), nil } - pcm := &proxyClientManager{proxyClient: map[int64]types.Proxy{ + pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{ TestProxyID: p1, }} err := pcm.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{}) @@ -297,9 +298,9 @@ func TestProxyClientManager_RefreshPolicyInfoCache(t *testing.T) { ctx := context.Background() p1 := newMockProxy() p1.RefreshPolicyInfoCacheFunc = func(ctx context.Context, request *proxypb.RefreshPolicyInfoCacheRequest) (*commonpb.Status, error) { - return succStatus(), nil + return merr.Success(), nil } - pcm := &proxyClientManager{proxyClient: map[int64]types.Proxy{ + pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{ TestProxyID: p1, }} err := pcm.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{}) diff --git a/internal/rootcoord/proxy_manager.go b/internal/rootcoord/proxy_manager.go index 8ded5da9234cd..dd86d153debcf 100644 --- a/internal/rootcoord/proxy_manager.go +++ b/internal/rootcoord/proxy_manager.go @@ -125,7 +125,7 @@ func (p *proxyManager) startWatchEtcd(ctx context.Context, eventCh clientv3.Watc err2 := p.WatchProxy() if err2 != nil { log.Error("re watch proxy fails when etcd has a compaction error", - zap.String("etcd error", err.Error()), zap.Error(err2)) + zap.Error(err), zap.Error(err2)) panic("failed to handle etcd request, exit..") } return diff --git a/internal/rootcoord/proxy_manager_test.go b/internal/rootcoord/proxy_manager_test.go index 6660824964692..c60310d414cac 100644 --- a/internal/rootcoord/proxy_manager_test.go +++ b/internal/rootcoord/proxy_manager_test.go @@ -53,7 +53,7 @@ func TestProxyManager(t *testing.T) { etcdCli.Delete(ctx, sessKey, clientv3.WithPrefix()) defer etcdCli.Delete(ctx, sessKey, clientv3.WithPrefix()) s1 
:= sessionutil.Session{ - ServerID: 100, + SessionRaw: sessionutil.SessionRaw{ServerID: 100}, } b1, err := json.Marshal(&s1) assert.NoError(t, err) @@ -62,7 +62,7 @@ func TestProxyManager(t *testing.T) { assert.NoError(t, err) s0 := sessionutil.Session{ - ServerID: 99, + SessionRaw: sessionutil.SessionRaw{ServerID: 99}, } b0, err := json.Marshal(&s0) assert.NoError(t, err) @@ -94,7 +94,7 @@ func TestProxyManager(t *testing.T) { t.Log("======== start watch proxy ==========") s2 := sessionutil.Session{ - ServerID: 101, + SessionRaw: sessionutil.SessionRaw{ServerID: 101}, } b2, err := json.Marshal(&s2) assert.NoError(t, err) diff --git a/internal/rootcoord/quota_center.go b/internal/rootcoord/quota_center.go index f56ddee250efb..17c49ee616154 100644 --- a/internal/rootcoord/quota_center.go +++ b/internal/rootcoord/quota_center.go @@ -87,8 +87,8 @@ type collectionStates = map[milvuspb.QuotaState]commonpb.ErrorCode type QuotaCenter struct { // clients proxies *proxyClientManager - queryCoord types.QueryCoord - dataCoord types.DataCoord + queryCoord types.QueryCoordClient + dataCoord types.DataCoordClient meta IMetaTable // metrics @@ -113,7 +113,7 @@ type QuotaCenter struct { } // NewQuotaCenter returns a new QuotaCenter. 
-func NewQuotaCenter(proxies *proxyClientManager, queryCoord types.QueryCoord, dataCoord types.DataCoord, tsoAllocator tso.Allocator, meta IMetaTable) *QuotaCenter { +func NewQuotaCenter(proxies *proxyClientManager, queryCoord types.QueryCoordClient, dataCoord types.DataCoordClient, tsoAllocator tso.Allocator, meta IMetaTable) *QuotaCenter { return &QuotaCenter{ proxies: proxies, queryCoord: queryCoord, @@ -264,7 +264,7 @@ func (q *QuotaCenter) syncMetrics() error { if err != nil { return err } - //log.Debug("QuotaCenter sync metrics done", + // log.Debug("QuotaCenter sync metrics done", // zap.Any("dataNodeMetrics", q.dataNodeMetrics), // zap.Any("queryNodeMetrics", q.queryNodeMetrics), // zap.Any("proxyMetrics", q.proxyMetrics), @@ -857,7 +857,7 @@ func (q *QuotaCenter) recordMetrics() { func (q *QuotaCenter) diskAllowance(collection UniqueID) float64 { q.diskMu.Lock() - q.diskMu.Unlock() + defer q.diskMu.Unlock() if !Params.QuotaConfig.DiskProtectionEnabled.GetAsBool() { return math.MaxInt64 } diff --git a/internal/rootcoord/quota_center_test.go b/internal/rootcoord/quota_center_test.go index b9d65ba089dcd..01fb5d36d4bfb 100644 --- a/internal/rootcoord/quota_center_test.go +++ b/internal/rootcoord/quota_center_test.go @@ -23,8 +23,10 @@ import ( "testing" "time" + "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "google.golang.org/grpc" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" @@ -48,17 +50,18 @@ type dataCoordMockForQuota struct { retFailStatus bool } -func (d *dataCoordMockForQuota) GetMetrics(ctx context.Context, request *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { +func (d *dataCoordMockForQuota) GetMetrics(ctx context.Context, request *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { + mockErr := errors.New("mock error") if d.retErr { - return nil, fmt.Errorf("mock 
err") + return nil, mockErr } if d.retFailStatus { return &milvuspb.GetMetricsResponse{ - Status: failStatus(commonpb.ErrorCode_UnexpectedError, "mock failure status"), + Status: merr.Status(mockErr), }, nil } return &milvuspb.GetMetricsResponse{ - Status: succStatus(), + Status: merr.Success(), }, nil } @@ -73,7 +76,7 @@ func TestQuotaCenter(t *testing.T) { pcm := newProxyClientManager(core.proxyCreator) t.Run("test QuotaCenter", func(t *testing.T) { - qc := mocks.NewMockQueryCoord(t) + qc := mocks.NewMockQueryCoordClient(t) meta := mockrootcoord.NewIMetaTable(t) meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() quotaCenter := NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator, meta) @@ -83,10 +86,10 @@ func TestQuotaCenter(t *testing.T) { }) t.Run("test syncMetrics", func(t *testing.T) { - qc := mocks.NewMockQueryCoord(t) + qc := mocks.NewMockQueryCoordClient(t) meta := mockrootcoord.NewIMetaTable(t) meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() - qc.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(&milvuspb.GetMetricsResponse{Status: succStatus()}, nil) + qc.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(&milvuspb.GetMetricsResponse{Status: merr.Success()}, nil) quotaCenter := NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator, meta) err = quotaCenter.syncMetrics() assert.Error(t, err) // for empty response @@ -105,7 +108,7 @@ func TestQuotaCenter(t *testing.T) { assert.Error(t, err) qc.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(&milvuspb.GetMetricsResponse{ - Status: failStatus(commonpb.ErrorCode_UnexpectedError, "mock failure status"), + Status: merr.Status(err), }, nil) quotaCenter = NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator, meta) err = 
quotaCenter.syncMetrics() @@ -113,7 +116,7 @@ func TestQuotaCenter(t *testing.T) { }) t.Run("test forceDeny", func(t *testing.T) { - qc := mocks.NewMockQueryCoord(t) + qc := mocks.NewMockQueryCoordClient(t) meta := mockrootcoord.NewIMetaTable(t) meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() quotaCenter := NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator, meta) @@ -143,7 +146,7 @@ func TestQuotaCenter(t *testing.T) { }) t.Run("test calculateRates", func(t *testing.T) { - qc := mocks.NewMockQueryCoord(t) + qc := mocks.NewMockQueryCoordClient(t) meta := mockrootcoord.NewIMetaTable(t) meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() quotaCenter := NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator, meta) @@ -159,7 +162,7 @@ func TestQuotaCenter(t *testing.T) { }) t.Run("test getTimeTickDelayFactor factors", func(t *testing.T) { - qc := mocks.NewMockQueryCoord(t) + qc := mocks.NewMockQueryCoordClient(t) meta := mockrootcoord.NewIMetaTable(t) meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() quotaCenter := NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator, meta) @@ -207,7 +210,7 @@ func TestQuotaCenter(t *testing.T) { }) t.Run("test TimeTickDelayFactor factors", func(t *testing.T) { - qc := mocks.NewMockQueryCoord(t) + qc := mocks.NewMockQueryCoordClient(t) meta := mockrootcoord.NewIMetaTable(t) meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() quotaCenter := NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator, meta) @@ -280,7 +283,7 @@ func TestQuotaCenter(t *testing.T) { }) 
t.Run("test calculateReadRates", func(t *testing.T) { - qc := mocks.NewMockQueryCoord(t) + qc := mocks.NewMockQueryCoordClient(t) meta := mockrootcoord.NewIMetaTable(t) meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() quotaCenter := NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator, meta) @@ -289,7 +292,8 @@ func TestQuotaCenter(t *testing.T) { 1: {Rms: []metricsinfo.RateMetric{ {Label: internalpb.RateType_DQLSearch.String(), Rate: 100}, {Label: internalpb.RateType_DQLQuery.String(), Rate: 100}, - }}} + }}, + } paramtable.Get().Save(Params.QuotaConfig.ForceDenyReading.Key, "false") paramtable.Get().Save(Params.QuotaConfig.QueueProtectionEnabled.Key, "true") @@ -304,7 +308,8 @@ func TestQuotaCenter(t *testing.T) { }, Effect: metricsinfo.NodeEffect{ NodeID: 1, CollectionIDs: []int64{1, 2, 3}, - }}} + }}, + } quotaCenter.calculateReadRates() for _, collection := range quotaCenter.readableCollections { assert.Equal(t, Limit(100.0*0.9), quotaCenter.currentRates[collection][internalpb.RateType_DQLSearch]) @@ -315,7 +320,8 @@ func TestQuotaCenter(t *testing.T) { quotaCenter.queryNodeMetrics = map[UniqueID]*metricsinfo.QueryNodeQuotaMetrics{ 1: {SearchQueue: metricsinfo.ReadInfoInQueue{ UnsolvedQueue: Params.QuotaConfig.NQInQueueThreshold.GetAsInt64(), - }}} + }}, + } quotaCenter.calculateReadRates() for _, collection := range quotaCenter.readableCollections { assert.Equal(t, Limit(100.0*0.9), quotaCenter.currentRates[collection][internalpb.RateType_DQLSearch]) @@ -329,7 +335,8 @@ func TestQuotaCenter(t *testing.T) { {Label: internalpb.RateType_DQLSearch.String(), Rate: 100}, {Label: internalpb.RateType_DQLQuery.String(), Rate: 100}, {Label: metricsinfo.ReadResultThroughput, Rate: 1.2}, - }}} + }}, + } quotaCenter.queryNodeMetrics = map[UniqueID]*metricsinfo.QueryNodeQuotaMetrics{1: {SearchQueue: metricsinfo.ReadInfoInQueue{}}} 
quotaCenter.calculateReadRates() for _, collection := range quotaCenter.readableCollections { @@ -339,7 +346,7 @@ func TestQuotaCenter(t *testing.T) { }) t.Run("test calculateWriteRates", func(t *testing.T) { - qc := mocks.NewMockQueryCoord(t) + qc := mocks.NewMockQueryCoordClient(t) meta := mockrootcoord.NewIMetaTable(t) meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() quotaCenter := NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator, meta) @@ -379,7 +386,7 @@ func TestQuotaCenter(t *testing.T) { }) t.Run("test MemoryFactor factors", func(t *testing.T) { - qc := mocks.NewMockQueryCoord(t) + qc := mocks.NewMockQueryCoordClient(t) meta := mockrootcoord.NewIMetaTable(t) meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() quotaCenter := NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator, meta) @@ -434,7 +441,7 @@ func TestQuotaCenter(t *testing.T) { }) t.Run("test GrowingSegmentsSize factors", func(t *testing.T) { - qc := mocks.NewMockQueryCoord(t) + qc := mocks.NewMockQueryCoordClient(t) meta := mockrootcoord.NewIMetaTable(t) meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() quotaCenter := NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator, meta) @@ -489,7 +496,7 @@ func TestQuotaCenter(t *testing.T) { }) t.Run("test checkDiskQuota", func(t *testing.T) { - qc := mocks.NewMockQueryCoord(t) + qc := mocks.NewMockQueryCoordClient(t) meta := mockrootcoord.NewIMetaTable(t) meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() quotaCenter := NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator, meta) 
@@ -500,7 +507,8 @@ func TestQuotaCenter(t *testing.T) { paramtable.Get().Save(Params.QuotaConfig.DiskQuota.Key, "99") quotaCenter.dataCoordMetrics = &metricsinfo.DataCoordQuotaMetrics{ TotalBinlogSize: 200 * 1024 * 1024, - CollectionBinlogSize: map[int64]int64{1: 100 * 1024 * 1024}} + CollectionBinlogSize: map[int64]int64{1: 100 * 1024 * 1024}, + } quotaCenter.writableCollections = []int64{1, 2, 3} quotaCenter.resetAllCurrentRates() quotaCenter.checkDiskQuota() @@ -515,7 +523,8 @@ func TestQuotaCenter(t *testing.T) { colQuotaBackup := Params.QuotaConfig.DiskQuotaPerCollection.GetValue() paramtable.Get().Save(Params.QuotaConfig.DiskQuotaPerCollection.Key, "30") quotaCenter.dataCoordMetrics = &metricsinfo.DataCoordQuotaMetrics{CollectionBinlogSize: map[int64]int64{ - 1: 20 * 1024 * 1024, 2: 30 * 1024 * 1024, 3: 60 * 1024 * 1024}} + 1: 20 * 1024 * 1024, 2: 30 * 1024 * 1024, 3: 60 * 1024 * 1024, + }} quotaCenter.writableCollections = []int64{1, 2, 3} quotaCenter.resetAllCurrentRates() quotaCenter.checkDiskQuota() @@ -532,10 +541,10 @@ func TestQuotaCenter(t *testing.T) { }) t.Run("test setRates", func(t *testing.T) { - qc := mocks.NewMockQueryCoord(t) - p1 := mocks.NewMockProxy(t) + qc := mocks.NewMockQueryCoordClient(t) + p1 := mocks.NewMockProxyClient(t) p1.EXPECT().SetRates(mock.Anything, mock.Anything).Return(nil, nil) - pcm := &proxyClientManager{proxyClient: map[int64]types.Proxy{ + pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{ TestProxyID: p1, }} meta := mockrootcoord.NewIMetaTable(t) @@ -553,7 +562,7 @@ func TestQuotaCenter(t *testing.T) { }) t.Run("test recordMetrics", func(t *testing.T) { - qc := mocks.NewMockQueryCoord(t) + qc := mocks.NewMockQueryCoordClient(t) meta := mockrootcoord.NewIMetaTable(t) meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() quotaCenter := NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator, 
meta) @@ -565,7 +574,7 @@ func TestQuotaCenter(t *testing.T) { }) t.Run("test guaranteeMinRate", func(t *testing.T) { - qc := mocks.NewMockQueryCoord(t) + qc := mocks.NewMockQueryCoordClient(t) meta := mockrootcoord.NewIMetaTable(t) meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() quotaCenter := NewQuotaCenter(pcm, qc, &dataCoordMockForQuota{}, core.tsoAllocator, meta) diff --git a/internal/rootcoord/redo.go b/internal/rootcoord/redo.go index 8c9b412485c6d..72406ee29604d 100644 --- a/internal/rootcoord/redo.go +++ b/internal/rootcoord/redo.go @@ -19,8 +19,9 @@ package rootcoord import ( "context" - "github.com/milvus-io/milvus/pkg/log" "go.uber.org/zap" + + "github.com/milvus-io/milvus/pkg/log" ) type baseRedoTask struct { diff --git a/internal/rootcoord/redo_test.go b/internal/rootcoord/redo_test.go index 7c80558a1809d..a01e897fcd598 100644 --- a/internal/rootcoord/redo_test.go +++ b/internal/rootcoord/redo_test.go @@ -21,7 +21,6 @@ import ( "testing" "github.com/cockroachdb/errors" - "github.com/stretchr/testify/assert" ) diff --git a/internal/rootcoord/rename_collection_task_test.go b/internal/rootcoord/rename_collection_task_test.go index caac2a03a2a13..dd4be07ab95ab 100644 --- a/internal/rootcoord/rename_collection_task_test.go +++ b/internal/rootcoord/rename_collection_task_test.go @@ -21,11 +21,9 @@ import ( "testing" "github.com/cockroachdb/errors" - "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" ) diff --git a/internal/rootcoord/root_coord.go b/internal/rootcoord/root_coord.go index afc7e4af7da83..bfbba5b2a1335 100644 --- a/internal/rootcoord/root_coord.go +++ b/internal/rootcoord/root_coord.go @@ -22,13 +22,14 @@ import ( "math/rand" "os" "sync" - "sync/atomic" "syscall" "time" "github.com/cockroachdb/errors" "github.com/samber/lo" + 
"github.com/tikv/client-go/v2/txnkv" clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/atomic" "go.uber.org/zap" "golang.org/x/sync/errgroup" @@ -38,6 +39,7 @@ import ( "github.com/milvus-io/milvus/internal/allocator" "github.com/milvus-io/milvus/internal/kv" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" + "github.com/milvus-io/milvus/internal/kv/tikv" "github.com/milvus-io/milvus/internal/metastore" kvmetestore "github.com/milvus-io/milvus/internal/metastore/kv/rootcoord" "github.com/milvus-io/milvus/internal/metastore/model" @@ -57,7 +59,6 @@ import ( "github.com/milvus-io/milvus/pkg/util" "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/crypto" - "github.com/milvus-io/milvus/pkg/util/errorutil" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/logutil" "github.com/milvus-io/milvus/pkg/util/merr" @@ -81,13 +82,7 @@ var Params *paramtable.ComponentParam = paramtable.Get() type Opt func(*Core) -type metaKVCreator func(root string) (kv.MetaKv, error) - -func defaultMetaKVCreator(etcdCli *clientv3.Client) metaKVCreator { - return func(root string) (kv.MetaKv, error) { - return etcdkv.NewEtcdKV(etcdCli, root), nil - } -} +type metaKVCreator func() (kv.MetaKv, error) // Core root coordinator core type Core struct { @@ -95,6 +90,7 @@ type Core struct { cancel context.CancelFunc wg sync.WaitGroup etcdCli *clientv3.Client + tikvCli *txnkv.Client address string meta IMetaTable scheduler IScheduler @@ -116,12 +112,12 @@ type Core struct { idAllocator allocator.Interface tsoAllocator tso2.Allocator - dataCoord types.DataCoord - queryCoord types.QueryCoord + dataCoord types.DataCoordClient + queryCoord types.QueryCoordClient quotaCenter *QuotaCenter - stateCode atomic.Value + stateCode atomic.Int32 initOnce sync.Once startOnce sync.Once session *sessionutil.Session @@ -155,14 +151,12 @@ func NewCore(c context.Context, factory dependency.Factory) (*Core, error) { // UpdateStateCode update 
state code func (c *Core) UpdateStateCode(code commonpb.StateCode) { - c.stateCode.Store(code) + c.stateCode.Store(int32(code)) log.Info("update rootcoord state", zap.String("state", code.String())) } -func (c *Core) checkHealthy() (commonpb.StateCode, bool) { - code := c.stateCode.Load().(commonpb.StateCode) - ok := code == commonpb.StateCode_Healthy - return code, ok +func (c *Core) GetStateCode() commonpb.StateCode { + return commonpb.StateCode(c.stateCode.Load()) } func (c *Core) sendTimeTick(t Timestamp, reason string) error { @@ -185,7 +179,10 @@ func (c *Core) sendTimeTick(t Timestamp, reason string) error { } func (c *Core) sendMinDdlTsAsTt() { - code := c.stateCode.Load().(commonpb.StateCode) + if !paramtable.Get().CommonCfg.TTMsgEnabled.GetAsBool() { + return + } + code := c.GetStateCode() if code != commonpb.StateCode_Healthy { log.Warn("rootCoord is not healthy, skip send timetick") return @@ -249,11 +246,11 @@ func (c *Core) tsLoop() { } } -func (c *Core) SetProxyCreator(f func(ctx context.Context, addr string, nodeID int64) (types.Proxy, error)) { +func (c *Core) SetProxyCreator(f func(ctx context.Context, addr string, nodeID int64) (types.ProxyClient, error)) { c.proxyCreator = f } -func (c *Core) SetDataCoord(s types.DataCoord) error { +func (c *Core) SetDataCoordClient(s types.DataCoordClient) error { if s == nil { return errors.New("null DataCoord interface") } @@ -261,7 +258,7 @@ func (c *Core) SetDataCoord(s types.DataCoord) error { return nil } -func (c *Core) SetQueryCoord(s types.QueryCoord) error { +func (c *Core) SetQueryCoordClient(s types.QueryCoordClient) error { if s == nil { return errors.New("null QueryCoord interface") } @@ -305,6 +302,11 @@ func (c *Core) SetEtcdClient(etcdClient *clientv3.Client) { c.etcdCli = etcdClient } +// SetTiKVClient sets the tikvCli of Core +func (c *Core) SetTiKVClient(client *txnkv.Client) { + c.tikvCli = client +} + func (c *Core) initSession() error { c.session = sessionutil.NewSession(c.ctx, 
Params.EtcdCfg.MetaRootPath.GetValue(), c.etcdCli) if c.session == nil { @@ -317,7 +319,15 @@ func (c *Core) initSession() error { func (c *Core) initKVCreator() { if c.metaKVCreator == nil { - c.metaKVCreator = defaultMetaKVCreator(c.etcdCli) + if Params.MetaStoreCfg.MetaStoreType.GetValue() == util.MetaStoreTypeTiKV { + c.metaKVCreator = func() (kv.MetaKv, error) { + return tikv.NewTiKV(c.tikvCli, Params.TiKVCfg.MetaRootPath.GetValue()), nil + } + } else { + c.metaKVCreator = func() (kv.MetaKv, error) { + return etcdkv.NewEtcdKV(c.etcdCli, Params.EtcdCfg.MetaRootPath.GetValue()), nil + } + } } } @@ -328,18 +338,32 @@ func (c *Core) initMetaTable() error { switch Params.MetaStoreCfg.MetaStoreType.GetValue() { case util.MetaStoreTypeEtcd: + log.Info("Using etcd as meta storage.") var metaKV kv.MetaKv var ss *kvmetestore.SuffixSnapshot var err error - if metaKV, err = c.metaKVCreator(Params.EtcdCfg.MetaRootPath.GetValue()); err != nil { + if metaKV, err = c.metaKVCreator(); err != nil { return err } if ss, err = kvmetestore.NewSuffixSnapshot(metaKV, kvmetestore.SnapshotsSep, Params.EtcdCfg.MetaRootPath.GetValue(), kvmetestore.SnapshotPrefix); err != nil { return err } + catalog = &kvmetestore.Catalog{Txn: metaKV, Snapshot: ss} + case util.MetaStoreTypeTiKV: + log.Info("Using tikv as meta storage.") + var metaKV kv.MetaKv + var ss *kvmetestore.SuffixSnapshot + var err error + + if metaKV, err = c.metaKVCreator(); err != nil { + return err + } + if ss, err = kvmetestore.NewSuffixSnapshot(metaKV, kvmetestore.SnapshotsSep, Params.TiKVCfg.MetaRootPath.GetValue(), kvmetestore.SnapshotPrefix); err != nil { + return err + } catalog = &kvmetestore.Catalog{Txn: metaKV, Snapshot: ss} default: return retry.Unrecoverable(fmt.Errorf("not supported meta store: %s", Params.MetaStoreCfg.MetaStoreType.GetValue())) @@ -356,7 +380,15 @@ func (c *Core) initMetaTable() error { } func (c *Core) initIDAllocator() error { - tsoKV := tsoutil2.NewTSOKVBase(c.etcdCli, 
Params.EtcdCfg.KvRootPath.GetValue(), globalIDAllocatorSubPath) + var tsoKV kv.TxnKV + var kvPath string + if Params.MetaStoreCfg.MetaStoreType.GetValue() == util.MetaStoreTypeTiKV { + kvPath = Params.TiKVCfg.KvRootPath.GetValue() + tsoKV = tsoutil2.NewTSOTiKVBase(c.tikvCli, kvPath, globalIDAllocatorSubPath) + } else { + kvPath = Params.EtcdCfg.KvRootPath.GetValue() + tsoKV = tsoutil2.NewTSOKVBase(c.etcdCli, kvPath, globalIDAllocatorSubPath) + } idAllocator := allocator.NewGlobalIDAllocator(globalIDAllocatorKey, tsoKV) if err := idAllocator.Initialize(); err != nil { return err @@ -364,7 +396,7 @@ func (c *Core) initIDAllocator() error { c.idAllocator = idAllocator log.Info("id allocator initialized", - zap.String("root_path", Params.EtcdCfg.KvRootPath.GetValue()), + zap.String("root_path", kvPath), zap.String("sub_path", globalIDAllocatorSubPath), zap.String("key", globalIDAllocatorKey)) @@ -372,7 +404,15 @@ func (c *Core) initIDAllocator() error { } func (c *Core) initTSOAllocator() error { - tsoKV := tsoutil2.NewTSOKVBase(c.etcdCli, Params.EtcdCfg.KvRootPath.GetValue(), globalTSOAllocatorSubPath) + var tsoKV kv.TxnKV + var kvPath string + if Params.MetaStoreCfg.MetaStoreType.GetValue() == util.MetaStoreTypeTiKV { + kvPath = Params.TiKVCfg.KvRootPath.GetValue() + tsoKV = tsoutil2.NewTSOTiKVBase(c.tikvCli, Params.TiKVCfg.KvRootPath.GetValue(), globalIDAllocatorSubPath) + } else { + kvPath = Params.EtcdCfg.KvRootPath.GetValue() + tsoKV = tsoutil2.NewTSOKVBase(c.etcdCli, Params.EtcdCfg.KvRootPath.GetValue(), globalIDAllocatorSubPath) + } tsoAllocator := tso2.NewGlobalTSOAllocator(globalTSOAllocatorKey, tsoKV) if err := tsoAllocator.Initialize(); err != nil { return err @@ -380,7 +420,7 @@ func (c *Core) initTSOAllocator() error { c.tsoAllocator = tsoAllocator log.Info("tso allocator initialized", - zap.String("root_path", Params.EtcdCfg.KvRootPath.GetValue()), + zap.String("root_path", kvPath), zap.String("sub_path", globalIDAllocatorSubPath), zap.String("key", 
globalIDAllocatorKey)) @@ -388,7 +428,7 @@ func (c *Core) initTSOAllocator() error { } func (c *Core) initImportManager() error { - impTaskKv, err := c.metaKVCreator(Params.EtcdCfg.KvRootPath.GetValue()) + impTaskKv, err := c.metaKVCreator() if err != nil { return err } @@ -491,7 +531,7 @@ func (c *Core) Init() error { log.Error("RootCoord start failed", zap.Error(err)) } }) - log.Info("RootCoord startup success") + log.Info("RootCoord startup success", zap.String("address", c.session.Address)) return err } c.UpdateStateCode(commonpb.StateCode_StandBy) @@ -642,6 +682,7 @@ func (c *Core) startInternal() error { c.startServerLoop() c.UpdateStateCode(commonpb.StateCode_Healthy) + sessionutil.SaveServerInfo(typeutil.RootCoordRole, c.session.ServerID) logutil.Logger(c.ctx).Info("rootcoord startup successfully") return nil @@ -716,8 +757,8 @@ func (c *Core) Stop() error { } // GetComponentStates get states of components -func (c *Core) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { - code := c.stateCode.Load().(commonpb.StateCode) +func (c *Core) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest) (*milvuspb.ComponentStates, error) { + code := c.GetStateCode() nodeID := common.NotRegisteredID if c.session != nil && c.session.Registered() { @@ -732,7 +773,7 @@ func (c *Core) GetComponentStates(ctx context.Context) (*milvuspb.ComponentState StateCode: code, ExtraInfo: nil, }, - Status: merr.Status(nil), + Status: merr.Success(), SubcomponentStates: []*milvuspb.ComponentInfo{ { NodeID: nodeID, @@ -745,24 +786,24 @@ func (c *Core) GetComponentStates(ctx context.Context) (*milvuspb.ComponentState } // GetTimeTickChannel get timetick channel name -func (c *Core) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (c *Core) GetTimeTickChannel(ctx context.Context, req *internalpb.GetTimeTickChannelRequest) (*milvuspb.StringResponse, error) { return &milvuspb.StringResponse{ - Status: 
merr.Status(nil), + Status: merr.Success(), Value: Params.CommonCfg.RootCoordTimeTick.GetValue(), }, nil } // GetStatisticsChannel get statistics channel name -func (c *Core) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { +func (c *Core) GetStatisticsChannel(ctx context.Context, req *internalpb.GetStatisticsChannelRequest) (*milvuspb.StringResponse, error) { return &milvuspb.StringResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Value: Params.CommonCfg.RootCoordStatistics.GetValue(), }, nil } func (c *Core) CreateDatabase(ctx context.Context, in *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) { - if code, ok := c.checkHealthy(); !ok { - return merr.Status(merr.WrapErrServiceNotReady(code.String())), nil + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { + return merr.Status(err), nil } method := "CreateDatabase" @@ -784,7 +825,7 @@ func (c *Core) CreateDatabase(ctx context.Context, in *milvuspb.CreateDatabaseRe zap.String("dbName", in.GetDbName()), zap.Int64("msgID", in.GetBase().GetMsgID())) metrics.RootCoordDDLReqCounter.WithLabelValues(method, metrics.FailLabel).Inc() - return failStatus(commonpb.ErrorCode_UnexpectedError, err.Error()), nil + return merr.Status(err), nil } if err := t.WaitToFinish(); err != nil { @@ -795,7 +836,7 @@ func (c *Core) CreateDatabase(ctx context.Context, in *milvuspb.CreateDatabaseRe zap.Int64("msgID", in.GetBase().GetMsgID()), zap.Uint64("ts", t.GetTs())) metrics.RootCoordDDLReqCounter.WithLabelValues(method, metrics.FailLabel).Inc() - return failStatus(commonpb.ErrorCode_UnexpectedError, err.Error()), nil + return merr.Status(err), nil } metrics.RootCoordDDLReqCounter.WithLabelValues(method, metrics.SuccessLabel).Inc() @@ -804,12 +845,12 @@ func (c *Core) CreateDatabase(ctx context.Context, in *milvuspb.CreateDatabaseRe log.Ctx(ctx).Info("done to create database", zap.String("role", typeutil.RootCoordRole), zap.String("dbName", in.GetDbName()), zap.Int64("msgID", 
in.GetBase().GetMsgID()), zap.Uint64("ts", t.GetTs())) - return succStatus(), nil + return merr.Success(), nil } func (c *Core) DropDatabase(ctx context.Context, in *milvuspb.DropDatabaseRequest) (*commonpb.Status, error) { - if code, ok := c.checkHealthy(); !ok { - return merr.Status(merr.WrapErrServiceNotReady(code.String())), nil + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { + return merr.Status(err), nil } method := "DropDatabase" @@ -830,7 +871,7 @@ func (c *Core) DropDatabase(ctx context.Context, in *milvuspb.DropDatabaseReques zap.String("dbName", in.GetDbName()), zap.Int64("msgID", in.GetBase().GetMsgID())) metrics.RootCoordDDLReqCounter.WithLabelValues(method, metrics.FailLabel).Inc() - return failStatus(commonpb.ErrorCode_UnexpectedError, err.Error()), nil + return merr.Status(err), nil } if err := t.WaitToFinish(); err != nil { @@ -840,7 +881,7 @@ func (c *Core) DropDatabase(ctx context.Context, in *milvuspb.DropDatabaseReques zap.Int64("msgID", in.GetBase().GetMsgID()), zap.Uint64("ts", t.GetTs())) metrics.RootCoordDDLReqCounter.WithLabelValues(method, metrics.FailLabel).Inc() - return failStatus(commonpb.ErrorCode_UnexpectedError, err.Error()), nil + return merr.Status(err), nil } metrics.RootCoordDDLReqCounter.WithLabelValues(method, metrics.SuccessLabel).Inc() @@ -849,12 +890,12 @@ func (c *Core) DropDatabase(ctx context.Context, in *milvuspb.DropDatabaseReques log.Ctx(ctx).Info("done to drop database", zap.String("role", typeutil.RootCoordRole), zap.String("dbName", in.GetDbName()), zap.Int64("msgID", in.GetBase().GetMsgID()), zap.Uint64("ts", t.GetTs())) - return succStatus(), nil + return merr.Success(), nil } func (c *Core) ListDatabases(ctx context.Context, in *milvuspb.ListDatabasesRequest) (*milvuspb.ListDatabasesResponse, error) { - if code, ok := c.checkHealthy(); !ok { - ret := &milvuspb.ListDatabasesResponse{Status: merr.Status(merr.WrapErrServiceNotReady(code.String()))} + if err := merr.CheckHealthy(c.GetStateCode()); err 
!= nil { + ret := &milvuspb.ListDatabasesResponse{Status: merr.Status(err)} return ret, nil } method := "ListDatabases" @@ -874,7 +915,7 @@ func (c *Core) ListDatabases(ctx context.Context, in *milvuspb.ListDatabasesRequ log.Info("failed to enqueue request to list databases", zap.Error(err)) metrics.RootCoordDDLReqCounter.WithLabelValues(method, metrics.FailLabel).Inc() return &milvuspb.ListDatabasesResponse{ - Status: failStatus(commonpb.ErrorCode_UnexpectedError, "ListDatabases failed: "+err.Error()), + Status: merr.Status(err), }, nil } @@ -882,7 +923,7 @@ func (c *Core) ListDatabases(ctx context.Context, in *milvuspb.ListDatabasesRequ log.Info("failed to list databases", zap.Error(err)) metrics.RootCoordDDLReqCounter.WithLabelValues(method, metrics.FailLabel).Inc() return &milvuspb.ListDatabasesResponse{ - Status: failStatus(commonpb.ErrorCode_UnexpectedError, "ListDatabases failed: "+err.Error()), + Status: merr.Status(err), }, nil } @@ -894,8 +935,8 @@ func (c *Core) ListDatabases(ctx context.Context, in *milvuspb.ListDatabasesRequ // CreateCollection create collection func (c *Core) CreateCollection(ctx context.Context, in *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) { - if code, ok := c.checkHealthy(); !ok { - return merr.Status(merr.WrapErrServiceNotReady(code.String())), nil + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { + return merr.Status(err), nil } metrics.RootCoordDDLReqCounter.WithLabelValues("CreateCollection", metrics.TotalLabel).Inc() @@ -941,13 +982,13 @@ func (c *Core) CreateCollection(ctx context.Context, in *milvuspb.CreateCollecti zap.String("role", typeutil.RootCoordRole), zap.String("name", in.GetCollectionName()), zap.Uint64("ts", t.GetTs())) - return merr.Status(nil), nil + return merr.Success(), nil } // DropCollection drop collection func (c *Core) DropCollection(ctx context.Context, in *milvuspb.DropCollectionRequest) (*commonpb.Status, error) { - if code, ok := c.checkHealthy(); !ok { - return 
merr.Status(merr.WrapErrServiceNotReady(code.String())), nil + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { + return merr.Status(err), nil } metrics.RootCoordDDLReqCounter.WithLabelValues("DropCollection", metrics.TotalLabel).Inc() @@ -990,14 +1031,14 @@ func (c *Core) DropCollection(ctx context.Context, in *milvuspb.DropCollectionRe log.Ctx(ctx).Info("done to drop collection", zap.String("role", typeutil.RootCoordRole), zap.String("name", in.GetCollectionName()), zap.Uint64("ts", t.GetTs())) - return merr.Status(nil), nil + return merr.Success(), nil } // HasCollection check collection existence func (c *Core) HasCollection(ctx context.Context, in *milvuspb.HasCollectionRequest) (*milvuspb.BoolResponse, error) { - if code, ok := c.checkHealthy(); !ok { + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { return &milvuspb.BoolResponse{ - Status: merr.Status(merr.WrapErrServiceNotReady(code.String())), + Status: merr.Status(err), }, nil } @@ -1050,7 +1091,7 @@ func (c *Core) describeCollection(ctx context.Context, in *milvuspb.DescribeColl } func convertModelToDesc(collInfo *model.Collection, aliases []string) *milvuspb.DescribeCollectionResponse { - resp := &milvuspb.DescribeCollectionResponse{Status: merr.Status(nil)} + resp := &milvuspb.DescribeCollectionResponse{Status: merr.Success()} resp.Schema = &schemapb.CollectionSchema{ Name: collInfo.Name, @@ -1080,9 +1121,9 @@ func convertModelToDesc(collInfo *model.Collection, aliases []string) *milvuspb. 
} func (c *Core) describeCollectionImpl(ctx context.Context, in *milvuspb.DescribeCollectionRequest, allowUnavailable bool) (*milvuspb.DescribeCollectionResponse, error) { - if code, ok := c.checkHealthy(); !ok { + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { return &milvuspb.DescribeCollectionResponse{ - Status: merr.Status(merr.WrapErrServiceNotReady(code.String())), + Status: merr.Status(err), }, nil } @@ -1103,7 +1144,7 @@ func (c *Core) describeCollectionImpl(ctx context.Context, in *milvuspb.Describe t := &describeCollectionTask{ baseTask: newBaseTask(ctx, c), Req: in, - Rsp: &milvuspb.DescribeCollectionResponse{Status: merr.Status(nil)}, + Rsp: &milvuspb.DescribeCollectionResponse{Status: merr.Success()}, allowUnavailable: allowUnavailable, } @@ -1149,9 +1190,9 @@ func (c *Core) DescribeCollectionInternal(ctx context.Context, in *milvuspb.Desc // ShowCollections list all collection names func (c *Core) ShowCollections(ctx context.Context, in *milvuspb.ShowCollectionsRequest) (*milvuspb.ShowCollectionsResponse, error) { - if code, ok := c.checkHealthy(); !ok { + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { return &milvuspb.ShowCollectionsResponse{ - Status: merr.Status(merr.WrapErrServiceNotReady(code.String())), + Status: merr.Status(err), }, nil } @@ -1196,8 +1237,8 @@ func (c *Core) ShowCollections(ctx context.Context, in *milvuspb.ShowCollections } func (c *Core) AlterCollection(ctx context.Context, in *milvuspb.AlterCollectionRequest) (*commonpb.Status, error) { - if code, ok := c.checkHealthy(); !ok { - return merr.Status(merr.WrapErrServiceNotReady(code.String())), nil + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { + return merr.Status(err), nil } metrics.RootCoordDDLReqCounter.WithLabelValues("AlterCollection", metrics.TotalLabel).Inc() @@ -1242,13 +1283,13 @@ func (c *Core) AlterCollection(ctx context.Context, in *milvuspb.AlterCollection zap.String("role", typeutil.RootCoordRole), zap.String("name", 
in.GetCollectionName()), zap.Uint64("ts", t.GetTs())) - return merr.Status(nil), nil + return merr.Success(), nil } // CreatePartition create partition func (c *Core) CreatePartition(ctx context.Context, in *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) { - if code, ok := c.checkHealthy(); !ok { - return merr.Status(merr.WrapErrServiceNotReady(code.String())), nil + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { + return merr.Status(err), nil } metrics.RootCoordDDLReqCounter.WithLabelValues("CreatePartition", metrics.TotalLabel).Inc() @@ -1296,13 +1337,13 @@ func (c *Core) CreatePartition(ctx context.Context, in *milvuspb.CreatePartition zap.String("collection", in.GetCollectionName()), zap.String("partition", in.GetPartitionName()), zap.Uint64("ts", t.GetTs())) - return merr.Status(nil), nil + return merr.Success(), nil } // DropPartition drop partition func (c *Core) DropPartition(ctx context.Context, in *milvuspb.DropPartitionRequest) (*commonpb.Status, error) { - if code, ok := c.checkHealthy(); !ok { - return merr.Status(merr.WrapErrServiceNotReady(code.String())), nil + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { + return merr.Status(err), nil } metrics.RootCoordDDLReqCounter.WithLabelValues("DropPartition", metrics.TotalLabel).Inc() @@ -1349,14 +1390,14 @@ func (c *Core) DropPartition(ctx context.Context, in *milvuspb.DropPartitionRequ zap.String("collection", in.GetCollectionName()), zap.String("partition", in.GetPartitionName()), zap.Uint64("ts", t.GetTs())) - return merr.Status(nil), nil + return merr.Success(), nil } // HasPartition check partition existence func (c *Core) HasPartition(ctx context.Context, in *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error) { - if code, ok := c.checkHealthy(); !ok { + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { return &milvuspb.BoolResponse{ - Status: merr.Status(merr.WrapErrServiceNotReady(code.String())), + Status: merr.Status(err), }, nil } @@ 
-1403,9 +1444,9 @@ func (c *Core) HasPartition(ctx context.Context, in *milvuspb.HasPartitionReques } func (c *Core) showPartitionsImpl(ctx context.Context, in *milvuspb.ShowPartitionsRequest, allowUnavailable bool) (*milvuspb.ShowPartitionsResponse, error) { - if code, ok := c.checkHealthy(); !ok { + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { return &milvuspb.ShowPartitionsResponse{ - Status: merr.Status(merr.WrapErrServiceNotReady(code.String())), + Status: merr.Status(err), }, nil } @@ -1467,14 +1508,14 @@ func (c *Core) ShowPartitionsInternal(ctx context.Context, in *milvuspb.ShowPart func (c *Core) ShowSegments(ctx context.Context, in *milvuspb.ShowSegmentsRequest) (*milvuspb.ShowSegmentsResponse, error) { // ShowSegments Only used in GetPersistentSegmentInfo, it's already deprecated for a long time. // Though we continue to keep current logic, it's not right enough since RootCoord only contains indexed segments. - return &milvuspb.ShowSegmentsResponse{Status: merr.Status(nil)}, nil + return &milvuspb.ShowSegmentsResponse{Status: merr.Success()}, nil } // AllocTimestamp alloc timestamp func (c *Core) AllocTimestamp(ctx context.Context, in *rootcoordpb.AllocTimestampRequest) (*rootcoordpb.AllocTimestampResponse, error) { - if code, ok := c.checkHealthy(); !ok { + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { return &rootcoordpb.AllocTimestampResponse{ - Status: merr.Status(merr.WrapErrServiceNotReady(code.String())), + Status: merr.Status(err), }, nil } @@ -1492,7 +1533,7 @@ func (c *Core) AllocTimestamp(ctx context.Context, in *rootcoordpb.AllocTimestam ts = ts - uint64(in.GetCount()) + 1 metrics.RootCoordTimestamp.Set(float64(ts)) return &rootcoordpb.AllocTimestampResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Timestamp: ts, Count: in.GetCount(), }, nil @@ -1500,9 +1541,9 @@ func (c *Core) AllocTimestamp(ctx context.Context, in *rootcoordpb.AllocTimestam // AllocID alloc ids func (c *Core) AllocID(ctx 
context.Context, in *rootcoordpb.AllocIDRequest) (*rootcoordpb.AllocIDResponse, error) { - if code, ok := c.checkHealthy(); !ok { + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { return &rootcoordpb.AllocIDResponse{ - Status: merr.Status(merr.WrapErrServiceNotReady(code.String())), + Status: merr.Status(err), }, nil } start, _, err := c.idAllocator.Alloc(in.Count) @@ -1519,7 +1560,7 @@ func (c *Core) AllocID(ctx context.Context, in *rootcoordpb.AllocIDRequest) (*ro metrics.RootCoordIDAllocCounter.Add(float64(in.Count)) return &rootcoordpb.AllocIDResponse{ - Status: merr.Status(nil), + Status: merr.Success(), ID: start, Count: in.Count, }, nil @@ -1528,9 +1569,9 @@ func (c *Core) AllocID(ctx context.Context, in *rootcoordpb.AllocIDRequest) (*ro // UpdateChannelTimeTick used to handle ChannelTimeTickMsg func (c *Core) UpdateChannelTimeTick(ctx context.Context, in *internalpb.ChannelTimeTickMsg) (*commonpb.Status, error) { log := log.Ctx(ctx) - if code, ok := c.checkHealthy(); !ok { - log.Warn("failed to updateTimeTick because rootcoord is not healthy", zap.Any("state", code)) - return merr.Status(merr.WrapErrServiceNotReady(code.String())), nil + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { + log.Warn("failed to updateTimeTick because rootcoord is not healthy", zap.Error(err)) + return merr.Status(err), nil } if in.Base.MsgType != commonpb.MsgType_TimeTick { log.Warn("failed to updateTimeTick because base messasge is not timetick, state", zap.Any("base message type", in.Base.MsgType)) @@ -1543,26 +1584,26 @@ func (c *Core) UpdateChannelTimeTick(ctx context.Context, in *internalpb.Channel zap.Error(err)) return merr.Status(err), nil } - return merr.Status(nil), nil + return merr.Success(), nil } // InvalidateCollectionMetaCache notifies RootCoord to release the collection cache in Proxies. 
func (c *Core) InvalidateCollectionMetaCache(ctx context.Context, in *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { - if code, ok := c.checkHealthy(); !ok { - return merr.Status(merr.WrapErrServiceNotReady(code.String())), nil + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { + return merr.Status(err), nil } err := c.proxyClientManager.InvalidateCollectionMetaCache(ctx, in) if err != nil { return merr.Status(err), nil } - return merr.Status(nil), nil + return merr.Success(), nil } // ShowConfigurations returns the configurations of RootCoord matching req.Pattern func (c *Core) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { - if code, ok := c.checkHealthy(); !ok { + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { return &internalpb.ShowConfigurationsResponse{ - Status: merr.Status(merr.WrapErrServiceNotReady(code.String())), + Status: merr.Status(err), Configuations: nil, }, nil } @@ -1577,16 +1618,16 @@ func (c *Core) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfi } return &internalpb.ShowConfigurationsResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Configuations: configList, }, nil } // GetMetrics get metrics func (c *Core) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - if code, ok := c.checkHealthy(); !ok { + if err := merr.CheckHealthyStandby(c.GetStateCode()); err != nil { return &milvuspb.GetMetricsResponse{ - Status: merr.Status(merr.WrapErrServiceNotReady(code.String())), + Status: merr.Status(err), Response: "", }, nil } @@ -1633,8 +1674,8 @@ func (c *Core) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest) ( // CreateAlias create collection alias func (c *Core) CreateAlias(ctx context.Context, in *milvuspb.CreateAliasRequest) (*commonpb.Status, error) { - if code, ok := c.checkHealthy(); !ok { - return 
merr.Status(merr.WrapErrServiceNotReady(code.String())), nil + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { + return merr.Status(err), nil } metrics.RootCoordDDLReqCounter.WithLabelValues("CreateAlias", metrics.TotalLabel).Inc() @@ -1682,13 +1723,13 @@ func (c *Core) CreateAlias(ctx context.Context, in *milvuspb.CreateAliasRequest) zap.String("alias", in.GetAlias()), zap.String("collection", in.GetCollectionName()), zap.Uint64("ts", t.GetTs())) - return merr.Status(nil), nil + return merr.Success(), nil } // DropAlias drop collection alias func (c *Core) DropAlias(ctx context.Context, in *milvuspb.DropAliasRequest) (*commonpb.Status, error) { - if code, ok := c.checkHealthy(); !ok { - return merr.Status(merr.WrapErrServiceNotReady(code.String())), nil + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { + return merr.Status(err), nil } metrics.RootCoordDDLReqCounter.WithLabelValues("DropAlias", metrics.TotalLabel).Inc() @@ -1732,13 +1773,13 @@ func (c *Core) DropAlias(ctx context.Context, in *milvuspb.DropAliasRequest) (*c zap.String("role", typeutil.RootCoordRole), zap.String("alias", in.GetAlias()), zap.Uint64("ts", t.GetTs())) - return merr.Status(nil), nil + return merr.Success(), nil } // AlterAlias alter collection alias func (c *Core) AlterAlias(ctx context.Context, in *milvuspb.AlterAliasRequest) (*commonpb.Status, error) { - if code, ok := c.checkHealthy(); !ok { - return merr.Status(merr.WrapErrServiceNotReady(code.String())), nil + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { + return merr.Status(err), nil } metrics.RootCoordDDLReqCounter.WithLabelValues("DropAlias", metrics.TotalLabel).Inc() @@ -1786,14 +1827,14 @@ func (c *Core) AlterAlias(ctx context.Context, in *milvuspb.AlterAliasRequest) ( zap.String("alias", in.GetAlias()), zap.String("collection", in.GetCollectionName()), zap.Uint64("ts", t.GetTs())) - return merr.Status(nil), nil + return merr.Success(), nil } // Import imports large files (json, numpy, 
etc.) on MinIO/S3 storage into Milvus storage. func (c *Core) Import(ctx context.Context, req *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) { - if code, ok := c.checkHealthy(); !ok { + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { return &milvuspb.ImportResponse{ - Status: merr.Status(merr.WrapErrServiceNotReady(code.String())), + Status: merr.Status(err), }, nil } @@ -1807,20 +1848,7 @@ func (c *Core) Import(ctx context.Context, req *milvuspb.ImportRequest) (*milvus return nil, err } - // Backup tool call import must with a partition name, each time restore a partition isBackUp := importutil.IsBackup(req.GetOptions()) - if isBackUp { - if len(req.GetPartitionName()) == 0 { - log.Info("partition name not specified when backup recovery", - zap.String("collectionName", req.GetCollectionName())) - ret := &milvuspb.ImportResponse{ - Status: failStatus(commonpb.ErrorCode_UnexpectedError, - "partition name not specified when backup"), - } - return ret, nil - } - } - cID := colInfo.CollectionID req.ChannelNames = c.meta.GetCollectionVirtualChannels(cID) @@ -1832,32 +1860,53 @@ func (c *Core) Import(ctx context.Context, req *milvuspb.ImportRequest) (*milvus } } - // If has partition key and not backup/restore mode, don't allow user to specify partition name - if hasPartitionKey && !isBackUp && req.GetPartitionName() != "" { - msg := "not allow to set partition name for collection with partition key" - log.Warn(msg, zap.String("collection name", req.GetCollectionName())) - return nil, errors.New(msg) - } - // Get partition ID by partition name var pID UniqueID - if !hasPartitionKey { - if req.GetPartitionName() == "" { - req.PartitionName = Params.CommonCfg.DefaultPartitionName.GetValue() + if isBackUp { + // Currently, Backup tool call import must with a partition name, each time restore a partition + if req.GetPartitionName() != "" { + if pID, err = c.meta.GetPartitionByName(cID, req.GetPartitionName(), typeutil.MaxTimestamp); err != nil { + 
log.Warn("failed to get partition ID from its name", zap.String("partitionName", req.GetPartitionName()), zap.Error(err)) + return &milvuspb.ImportResponse{ + Status: merr.Status(merr.WrapErrPartitionNotFound(req.GetPartitionName())), + }, nil + } + } else { + log.Info("partition name not specified when backup recovery", + zap.String("collectionName", req.GetCollectionName())) + return &milvuspb.ImportResponse{ + Status: merr.Status(merr.WrapErrParameterInvalidMsg("partition not specified")), + }, nil } - if pID, err = c.meta.GetPartitionByName(cID, req.GetPartitionName(), typeutil.MaxTimestamp); err != nil { - log.Warn("failed to get partition ID from its name", - zap.String("partition name", req.GetPartitionName()), - zap.Error(err)) - return nil, err + } else { + if hasPartitionKey { + if req.GetPartitionName() != "" { + msg := "not allow to set partition name for collection with partition key" + log.Warn(msg, zap.String("collectionName", req.GetCollectionName())) + return &milvuspb.ImportResponse{ + Status: merr.Status(merr.WrapErrParameterInvalidMsg(msg)), + }, nil + } + } else { + if req.GetPartitionName() == "" { + req.PartitionName = Params.CommonCfg.DefaultPartitionName.GetValue() + } + if pID, err = c.meta.GetPartitionByName(cID, req.GetPartitionName(), typeutil.MaxTimestamp); err != nil { + log.Warn("failed to get partition ID from its name", + zap.String("partition name", req.GetPartitionName()), + zap.Error(err)) + return &milvuspb.ImportResponse{ + Status: merr.Status(merr.WrapErrPartitionNotFound(req.GetPartitionName())), + }, nil + } } } log.Info("RootCoord receive import request", zap.String("collectionName", req.GetCollectionName()), zap.Int64("collectionID", cID), - zap.String("partition name", req.GetPartitionName()), - zap.Strings("virtual channel names", req.GetChannelNames()), + zap.String("partitionName", req.GetPartitionName()), + zap.Strings("virtualChannelNames", req.GetChannelNames()), zap.Int64("partitionID", pID), zap.Int("# of files = 
", len(req.GetFiles())), ) @@ -1867,9 +1916,9 @@ func (c *Core) Import(ctx context.Context, req *milvuspb.ImportRequest) (*milvus // GetImportState returns the current state of an import task. func (c *Core) GetImportState(ctx context.Context, req *milvuspb.GetImportStateRequest) (*milvuspb.GetImportStateResponse, error) { - if code, ok := c.checkHealthy(); !ok { + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { return &milvuspb.GetImportStateResponse{ - Status: merr.Status(merr.WrapErrServiceNotReady(code.String())), + Status: merr.Status(err), }, nil } return c.importManager.getTaskState(req.GetTask()), nil @@ -1877,9 +1926,9 @@ func (c *Core) GetImportState(ctx context.Context, req *milvuspb.GetImportStateR // ListImportTasks returns id array of all import tasks. func (c *Core) ListImportTasks(ctx context.Context, req *milvuspb.ListImportTasksRequest) (*milvuspb.ListImportTasksResponse, error) { - if code, ok := c.checkHealthy(); !ok { + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { return &milvuspb.ListImportTasksResponse{ - Status: merr.Status(merr.WrapErrServiceNotReady(code.String())), + Status: merr.Status(err), }, nil } @@ -1893,7 +1942,6 @@ func (c *Core) ListImportTasks(ctx context.Context, req *milvuspb.ListImportTask err = fmt.Errorf("failed to find collection ID from its name: '%s', error: %w", req.GetCollectionName(), err) log.Error("ListImportTasks failed", zap.Error(err)) status := merr.Status(err) - status.ErrorCode = commonpb.ErrorCode_IllegalCollectionName return &milvuspb.ListImportTasksResponse{ Status: status, }, nil @@ -1912,7 +1960,7 @@ func (c *Core) ListImportTasks(ctx context.Context, req *milvuspb.ListImportTask } resp := &milvuspb.ListImportTasksResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Tasks: tasks, } return resp, nil @@ -1923,8 +1971,8 @@ func (c *Core) ReportImport(ctx context.Context, ir *rootcoordpb.ImportResult) ( log.Info("RootCoord receive import state report", zap.Int64("task 
ID", ir.GetTaskId()), zap.Any("import state", ir.GetState())) - if code, ok := c.checkHealthy(); !ok { - return merr.Status(merr.WrapErrServiceNotReady(code.String())), nil + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { + return merr.Status(err), nil } // This method update a busy node to idle node, and send import task to idle node @@ -1936,7 +1984,6 @@ func (c *Core) ReportImport(ctx context.Context, ir *rootcoordpb.ImportResult) ( log.Info("a DataNode is no longer busy after processing task", zap.Int64("dataNode ID", ir.GetDatanodeId()), zap.Int64("task ID", ir.GetTaskId())) - }() err := c.importManager.sendOutTasks(c.importManager.ctx) if err != nil { @@ -1951,11 +1998,7 @@ func (c *Core) ReportImport(ctx context.Context, ir *rootcoordpb.ImportResult) ( // Upon receiving ReportImport request, update the related task's state in task store. ti, err := c.importManager.updateTaskInfo(ir) if err != nil { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UpdateImportTaskFailure, - Reason: err.Error(), - Code: merr.Code(err), - }, nil + return merr.Status(err), nil } // If task failed, send task to idle datanode @@ -1981,15 +2024,15 @@ func (c *Core) ReportImport(ctx context.Context, ir *rootcoordpb.ImportResult) ( } } - return merr.Status(nil), nil + return merr.Success(), nil } // ExpireCredCache will call invalidate credential cache func (c *Core) ExpireCredCache(ctx context.Context, username string) error { req := proxypb.InvalidateCredCacheRequest{ Base: commonpbutil.NewMsgBase( - commonpbutil.WithMsgType(0), //TODO, msg type - commonpbutil.WithMsgID(0), //TODO, msg id + commonpbutil.WithMsgType(0), // TODO, msg type + commonpbutil.WithMsgID(0), // TODO, msg id commonpbutil.WithSourceID(c.session.ServerID), ), Username: username, @@ -2001,8 +2044,8 @@ func (c *Core) ExpireCredCache(ctx context.Context, username string) error { func (c *Core) UpdateCredCache(ctx context.Context, credInfo *internalpb.CredentialInfo) error { req := 
proxypb.UpdateCredCacheRequest{ Base: commonpbutil.NewMsgBase( - commonpbutil.WithMsgType(0), //TODO, msg type - commonpbutil.WithMsgID(0), //TODO, msg id + commonpbutil.WithMsgType(0), // TODO, msg type + commonpbutil.WithMsgID(0), // TODO, msg id commonpbutil.WithSourceID(c.session.ServerID), ), Username: credInfo.Username, @@ -2021,8 +2064,8 @@ func (c *Core) CreateCredential(ctx context.Context, credInfo *internalpb.Creden tr := timerecord.NewTimeRecorder(method) ctxLog := log.Ctx(ctx).With(zap.String("role", typeutil.RootCoordRole), zap.String("username", credInfo.Username)) ctxLog.Debug(method) - if code, ok := c.checkHealthy(); !ok { - return merr.Status(merr.WrapErrServiceNotReady(code.String())), nil + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { + return merr.Status(err), nil } // insert to db @@ -2044,7 +2087,7 @@ func (c *Core) CreateCredential(ctx context.Context, credInfo *internalpb.Creden metrics.RootCoordDDLReqCounter.WithLabelValues(method, metrics.SuccessLabel).Inc() metrics.RootCoordDDLReqLatency.WithLabelValues(method).Observe(float64(tr.ElapseSpan().Milliseconds())) metrics.RootCoordNumOfCredentials.Inc() - return merr.Status(nil), nil + return merr.Success(), nil } // GetCredential get credential by username @@ -2054,8 +2097,8 @@ func (c *Core) GetCredential(ctx context.Context, in *rootcoordpb.GetCredentialR tr := timerecord.NewTimeRecorder(method) ctxLog := log.Ctx(ctx).With(zap.String("role", typeutil.RootCoordRole), zap.String("username", in.Username)) ctxLog.Debug(method) - if code, ok := c.checkHealthy(); !ok { - return &rootcoordpb.GetCredentialResponse{Status: merr.Status(merr.WrapErrServiceNotReady(code.String()))}, nil + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { + return &rootcoordpb.GetCredentialResponse{Status: merr.Status(err)}, nil } credInfo, err := c.meta.GetCredential(in.Username) @@ -2071,7 +2114,7 @@ func (c *Core) GetCredential(ctx context.Context, in *rootcoordpb.GetCredentialR 
metrics.RootCoordDDLReqCounter.WithLabelValues(method, metrics.SuccessLabel).Inc() metrics.RootCoordDDLReqLatency.WithLabelValues(method).Observe(float64(tr.ElapseSpan().Milliseconds())) return &rootcoordpb.GetCredentialResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Username: credInfo.Username, Password: credInfo.EncryptedPassword, }, nil @@ -2084,8 +2127,8 @@ func (c *Core) UpdateCredential(ctx context.Context, credInfo *internalpb.Creden tr := timerecord.NewTimeRecorder(method) ctxLog := log.Ctx(ctx).With(zap.String("role", typeutil.RootCoordRole), zap.String("username", credInfo.Username)) ctxLog.Debug(method) - if code, ok := c.checkHealthy(); !ok { - return merr.Status(merr.WrapErrServiceNotReady(code.String())), nil + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { + return merr.Status(err), nil } // update data on storage err := c.meta.AlterCredential(credInfo) @@ -2105,7 +2148,7 @@ func (c *Core) UpdateCredential(ctx context.Context, credInfo *internalpb.Creden metrics.RootCoordDDLReqCounter.WithLabelValues(method, metrics.SuccessLabel).Inc() metrics.RootCoordDDLReqLatency.WithLabelValues(method).Observe(float64(tr.ElapseSpan().Milliseconds())) - return merr.Status(nil), nil + return merr.Success(), nil } // DeleteCredential delete a user @@ -2115,8 +2158,8 @@ func (c *Core) DeleteCredential(ctx context.Context, in *milvuspb.DeleteCredenti tr := timerecord.NewTimeRecorder(method) ctxLog := log.Ctx(ctx).With(zap.String("role", typeutil.RootCoordRole), zap.String("username", in.Username)) ctxLog.Debug(method) - if code, ok := c.checkHealthy(); !ok { - return merr.Status(merr.WrapErrServiceNotReady(code.String())), nil + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { + return merr.Status(err), nil } var status *commonpb.Status defer func() { @@ -2163,7 +2206,7 @@ func (c *Core) DeleteCredential(ctx context.Context, in *milvuspb.DeleteCredenti metrics.RootCoordDDLReqCounter.WithLabelValues(method, 
metrics.SuccessLabel).Inc() metrics.RootCoordDDLReqLatency.WithLabelValues(method).Observe(float64(tr.ElapseSpan().Milliseconds())) metrics.RootCoordNumOfCredentials.Dec() - status = merr.Status(nil) + status = merr.Success() return status, nil } @@ -2174,8 +2217,8 @@ func (c *Core) ListCredUsers(ctx context.Context, in *milvuspb.ListCredUsersRequ tr := timerecord.NewTimeRecorder(method) ctxLog := log.Ctx(ctx).With(zap.String("role", typeutil.RootCoordRole)) ctxLog.Debug(method) - if code, ok := c.checkHealthy(); !ok { - return &milvuspb.ListCredUsersResponse{Status: merr.Status(merr.WrapErrServiceNotReady(code.String()))}, nil + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { + return &milvuspb.ListCredUsersResponse{Status: merr.Status(err)}, nil } credInfo, err := c.meta.ListCredentialUsernames() @@ -2184,7 +2227,6 @@ func (c *Core) ListCredUsers(ctx context.Context, in *milvuspb.ListCredUsersRequ metrics.RootCoordDDLReqCounter.WithLabelValues(method, metrics.FailLabel).Inc() status := merr.Status(err) - status.ErrorCode = commonpb.ErrorCode_ListCredUsersFailure return &milvuspb.ListCredUsersResponse{Status: status}, nil } ctxLog.Debug("ListCredUsers success") @@ -2192,7 +2234,7 @@ func (c *Core) ListCredUsers(ctx context.Context, in *milvuspb.ListCredUsersRequ metrics.RootCoordDDLReqCounter.WithLabelValues(method, metrics.SuccessLabel).Inc() metrics.RootCoordDDLReqLatency.WithLabelValues(method).Observe(float64(tr.ElapseSpan().Milliseconds())) return &milvuspb.ListCredUsersResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Usernames: credInfo.Usernames, }, nil } @@ -2209,8 +2251,8 @@ func (c *Core) CreateRole(ctx context.Context, in *milvuspb.CreateRoleRequest) ( ctxLog := log.Ctx(ctx).With(zap.String("role", typeutil.RootCoordRole), zap.Any("in", in)) ctxLog.Debug(method + " begin") - if code, ok := c.checkHealthy(); !ok { - return merr.Status(merr.WrapErrServiceNotReady(code.String())), nil + if err := 
merr.CheckHealthy(c.GetStateCode()); err != nil { + return merr.Status(err), nil } entity := in.Entity @@ -2226,7 +2268,7 @@ func (c *Core) CreateRole(ctx context.Context, in *milvuspb.CreateRoleRequest) ( metrics.RootCoordDDLReqLatency.WithLabelValues(method).Observe(float64(tr.ElapseSpan().Milliseconds())) metrics.RootCoordNumOfRoles.Inc() - return merr.Status(nil), nil + return merr.Success(), nil } // DropRole drop role @@ -2243,8 +2285,8 @@ func (c *Core) DropRole(ctx context.Context, in *milvuspb.DropRoleRequest) (*com ctxLog := log.Ctx(ctx).With(zap.String("role", typeutil.RootCoordRole), zap.String("role_name", in.RoleName)) ctxLog.Debug(method) - if code, ok := c.checkHealthy(); !ok { - return merr.Status(merr.WrapErrServiceNotReady(code.String())), nil + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { + return merr.Status(err), nil } if _, err := c.meta.SelectRole(util.DefaultTenant, &milvuspb.RoleEntity{Name: in.RoleName}, false); err != nil { errMsg := "not found the role, maybe the role isn't existed or internal system error" @@ -2289,7 +2331,7 @@ func (c *Core) DropRole(ctx context.Context, in *milvuspb.DropRoleRequest) (*com metrics.RootCoordDDLReqCounter.WithLabelValues(method, metrics.SuccessLabel).Inc() metrics.RootCoordDDLReqLatency.WithLabelValues(method).Observe(float64(tr.ElapseSpan().Milliseconds())) metrics.RootCoordNumOfRoles.Dec() - return merr.Status(nil), nil + return merr.Success(), nil } // OperateUserRole operate the relationship between a user and a role @@ -2305,8 +2347,8 @@ func (c *Core) OperateUserRole(ctx context.Context, in *milvuspb.OperateUserRole ctxLog := log.Ctx(ctx).With(zap.String("role", typeutil.RootCoordRole), zap.Any("in", in)) ctxLog.Debug(method) - if code, ok := c.checkHealthy(); !ok { - return merr.Status(merr.WrapErrServiceNotReady(code.String())), nil + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { + return merr.Status(err), nil } if _, err := c.meta.SelectRole(util.DefaultTenant, 
&milvuspb.RoleEntity{Name: in.RoleName}, false); err != nil { @@ -2362,7 +2404,7 @@ func (c *Core) OperateUserRole(ctx context.Context, in *milvuspb.OperateUserRole ctxLog.Debug(method + " success") metrics.RootCoordDDLReqCounter.WithLabelValues(method, metrics.SuccessLabel).Inc() metrics.RootCoordDDLReqLatency.WithLabelValues(method).Observe(float64(tr.ElapseSpan().Milliseconds())) - return merr.Status(nil), nil + return merr.Success(), nil } // SelectRole select role @@ -2376,15 +2418,15 @@ func (c *Core) SelectRole(ctx context.Context, in *milvuspb.SelectRoleRequest) ( ctxLog := log.Ctx(ctx).With(zap.String("role", typeutil.RootCoordRole), zap.Any("in", in)) ctxLog.Debug(method) - if code, ok := c.checkHealthy(); !ok { - return &milvuspb.SelectRoleResponse{Status: merr.Status(merr.WrapErrServiceNotReady(code.String()))}, nil + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { + return &milvuspb.SelectRoleResponse{Status: merr.Status(err)}, nil } if in.Role != nil { if _, err := c.meta.SelectRole(util.DefaultTenant, &milvuspb.RoleEntity{Name: in.Role.Name}, false); err != nil { - if common.IsKeyNotExistError(err) { + if errors.Is(err, merr.ErrIoKeyNotFound) { return &milvuspb.SelectRoleResponse{ - Status: merr.Status(nil), + Status: merr.Success(), }, nil } errMsg := "fail to select the role to check the role name" @@ -2407,7 +2449,7 @@ func (c *Core) SelectRole(ctx context.Context, in *milvuspb.SelectRoleRequest) ( metrics.RootCoordDDLReqCounter.WithLabelValues(method, metrics.SuccessLabel).Inc() metrics.RootCoordDDLReqLatency.WithLabelValues(method).Observe(float64(tr.ElapseSpan().Milliseconds())) return &milvuspb.SelectRoleResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Results: roleResults, }, nil } @@ -2423,15 +2465,15 @@ func (c *Core) SelectUser(ctx context.Context, in *milvuspb.SelectUserRequest) ( ctxLog := log.Ctx(ctx).With(zap.String("role", typeutil.RootCoordRole), zap.Any("in", in)) ctxLog.Debug(method) - if code, ok := 
c.checkHealthy(); !ok { - return &milvuspb.SelectUserResponse{Status: merr.Status(merr.WrapErrServiceNotReady(code.String()))}, nil + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { + return &milvuspb.SelectUserResponse{Status: merr.Status(err)}, nil } if in.User != nil { if _, err := c.meta.SelectUser(util.DefaultTenant, &milvuspb.UserEntity{Name: in.User.Name}, false); err != nil { - if common.IsKeyNotExistError(err) { + if errors.Is(err, merr.ErrIoKeyNotFound) { return &milvuspb.SelectUserResponse{ - Status: merr.Status(nil), + Status: merr.Success(), }, nil } errMsg := "fail to select the user to check the username" @@ -2454,7 +2496,7 @@ func (c *Core) SelectUser(ctx context.Context, in *milvuspb.SelectUserRequest) ( metrics.RootCoordDDLReqCounter.WithLabelValues(method, metrics.SuccessLabel).Inc() metrics.RootCoordDDLReqLatency.WithLabelValues(method).Observe(float64(tr.ElapseSpan().Milliseconds())) return &milvuspb.SelectUserResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Results: userResults, }, nil } @@ -2532,8 +2574,8 @@ func (c *Core) OperatePrivilege(ctx context.Context, in *milvuspb.OperatePrivile ctxLog := log.Ctx(ctx).With(zap.String("role", typeutil.RootCoordRole), zap.Any("in", in)) ctxLog.Debug(method) - if code, ok := c.checkHealthy(); !ok { - return merr.Status(merr.WrapErrServiceNotReady(code.String())), nil + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { + return merr.Status(err), nil } if in.Type != milvuspb.OperatePrivilegeType_Grant && in.Type != milvuspb.OperatePrivilegeType_Revoke { errMsg := fmt.Sprintf("invalid operate privilege type, current type: %s, valid value: [%s, %s]", in.Type, milvuspb.OperatePrivilegeType_Grant, milvuspb.OperatePrivilegeType_Revoke) @@ -2607,7 +2649,7 @@ func (c *Core) OperatePrivilege(ctx context.Context, in *milvuspb.OperatePrivile ctxLog.Debug(method + " success") metrics.RootCoordDDLReqCounter.WithLabelValues(method, metrics.SuccessLabel).Inc() 
metrics.RootCoordDDLReqLatency.WithLabelValues(method).Observe(float64(tr.ElapseSpan().Milliseconds())) - return merr.Status(nil), nil + return merr.Success(), nil } // SelectGrant select grant @@ -2622,9 +2664,9 @@ func (c *Core) SelectGrant(ctx context.Context, in *milvuspb.SelectGrantRequest) ctxLog := log.Ctx(ctx).With(zap.String("role", typeutil.RootCoordRole), zap.Any("in", in)) ctxLog.Debug(method) - if code, ok := c.checkHealthy(); !ok { + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { return &milvuspb.SelectGrantResponse{ - Status: merr.Status(merr.WrapErrServiceNotReady(code.String())), + Status: merr.Status(err), }, nil } if in.Entity == nil { @@ -2650,9 +2692,9 @@ func (c *Core) SelectGrant(ctx context.Context, in *milvuspb.SelectGrantRequest) } grantEntities, err := c.meta.SelectGrant(util.DefaultTenant, in.Entity) - if common.IsKeyNotExistError(err) { + if errors.Is(err, merr.ErrIoKeyNotFound) { return &milvuspb.SelectGrantResponse{ - Status: merr.Status(nil), + Status: merr.Success(), }, nil } if err != nil { @@ -2667,7 +2709,7 @@ func (c *Core) SelectGrant(ctx context.Context, in *milvuspb.SelectGrantRequest) metrics.RootCoordDDLReqCounter.WithLabelValues(method, metrics.SuccessLabel).Inc() metrics.RootCoordDDLReqLatency.WithLabelValues(method).Observe(float64(tr.ElapseSpan().Milliseconds())) return &milvuspb.SelectGrantResponse{ - Status: merr.Status(nil), + Status: merr.Success(), Entities: grantEntities, }, nil } @@ -2679,9 +2721,9 @@ func (c *Core) ListPolicy(ctx context.Context, in *internalpb.ListPolicyRequest) ctxLog := log.Ctx(ctx).With(zap.String("role", typeutil.RootCoordRole), zap.Any("in", in)) ctxLog.Debug(method) - if code, ok := c.checkHealthy(); !ok { + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { return &internalpb.ListPolicyResponse{ - Status: merr.Status(merr.WrapErrServiceNotReady(code.String())), + Status: merr.Status(err), }, nil } @@ -2706,15 +2748,15 @@ func (c *Core) ListPolicy(ctx 
context.Context, in *internalpb.ListPolicyRequest) metrics.RootCoordDDLReqCounter.WithLabelValues(method, metrics.SuccessLabel).Inc() metrics.RootCoordDDLReqLatency.WithLabelValues(method).Observe(float64(tr.ElapseSpan().Milliseconds())) return &internalpb.ListPolicyResponse{ - Status: merr.Status(nil), + Status: merr.Success(), PolicyInfos: policies, UserRoles: userRoles, }, nil } func (c *Core) RenameCollection(ctx context.Context, req *milvuspb.RenameCollectionRequest) (*commonpb.Status, error) { - if code, ok := c.checkHealthy(); !ok { - return merr.Status(merr.WrapErrServiceNotReady(code.String())), nil + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { + return merr.Status(err), nil } log := log.Ctx(ctx).With(zap.String("oldCollectionName", req.GetOldName()), zap.String("newCollectionName", req.GetNewName())) @@ -2743,13 +2785,16 @@ func (c *Core) RenameCollection(ctx context.Context, req *milvuspb.RenameCollect metrics.RootCoordDDLReqLatency.WithLabelValues("RenameCollection").Observe(float64(tr.ElapseSpan().Milliseconds())) log.Info("done to rename collection", zap.Uint64("ts", t.GetTs())) - return merr.Status(nil), nil + return merr.Success(), nil } func (c *Core) CheckHealth(ctx context.Context, in *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) { - if _, ok := c.checkHealthy(); !ok { - reason := errorutil.UnHealthReason("rootcoord", c.session.ServerID, "rootcoord is unhealthy") - return &milvuspb.CheckHealthResponse{IsHealthy: false, Reasons: []string{reason}}, nil + if err := merr.CheckHealthy(c.GetStateCode()); err != nil { + return &milvuspb.CheckHealthResponse{ + Status: merr.Status(err), + IsHealthy: false, + Reasons: []string{fmt.Sprintf("serverID=%d: %v", c.session.ServerID, err)}, + }, nil } mu := &sync.Mutex{} @@ -2760,21 +2805,29 @@ func (c *Core) CheckHealth(ctx context.Context, in *milvuspb.CheckHealthRequest) nodeID := nodeID proxyClient := proxyClient group.Go(func() error { - sta, err := 
proxyClient.GetComponentStates(ctx) - isHealthy, reason := errorutil.UnHealthReasonWithComponentStatesOrErr("proxy", nodeID, sta, err) - if !isHealthy { + sta, err := proxyClient.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) + if err != nil { + return err + } + + err = merr.AnalyzeState("Proxy", nodeID, sta) + if err != nil { mu.Lock() defer mu.Unlock() - errReasons = append(errReasons, reason) + errReasons = append(errReasons, err.Error()) } - return err + return nil }) } err := group.Wait() if err != nil || len(errReasons) != 0 { - return &milvuspb.CheckHealthResponse{IsHealthy: false, Reasons: errReasons}, nil + return &milvuspb.CheckHealthResponse{ + Status: merr.Success(), + IsHealthy: false, + Reasons: errReasons, + }, nil } - return &milvuspb.CheckHealthResponse{IsHealthy: true, Reasons: errReasons}, nil + return &milvuspb.CheckHealthResponse{Status: merr.Success(), IsHealthy: true, Reasons: errReasons}, nil } diff --git a/internal/rootcoord/root_coord_test.go b/internal/rootcoord/root_coord_test.go index 754304fd5dac5..4acadce05a8de 100644 --- a/internal/rootcoord/root_coord_test.go +++ b/internal/rootcoord/root_coord_test.go @@ -46,17 +46,26 @@ import ( "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/importutil" "github.com/milvus-io/milvus/internal/util/sessionutil" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/tikv" "github.com/milvus-io/milvus/pkg/util/typeutil" ) func TestMain(m *testing.M) { paramtable.Init() rand.Seed(time.Now().UnixNano()) - os.Exit(m.Run()) + parameters := []string{"tikv", "etcd"} + var code int + for _, v := range parameters { + 
paramtable.Get().Save(paramtable.Get().MetaStoreCfg.MetaStoreType.Key, v) + code = m.Run() + } + os.Exit(code) } func TestRootCoord_CreateDatabase(t *testing.T) { @@ -792,7 +801,7 @@ func TestRootCoord_UpdateChannelTimeTick(t *testing.T) { defaultTs := Timestamp(101) ticker := newRocksMqTtSynchronizer() - ticker.addSession(&sessionutil.Session{ServerID: source}) + ticker.addSession(&sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: source}}) ctx := context.Background() c := newTestCore(withHealthyCode(), @@ -951,7 +960,7 @@ func TestRootCoord_GetMetrics(t *testing.T) { withMetricsCacheManager()) resp, err := c.GetMetrics(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) t.Run("get system info metrics from cache", func(t *testing.T) { @@ -962,13 +971,13 @@ func TestRootCoord_GetMetrics(t *testing.T) { c := newTestCore(withHealthyCode(), withMetricsCacheManager()) c.metricsCacheManager.UpdateSystemInfoMetrics(&milvuspb.GetMetricsResponse{ - Status: succStatus(), + Status: merr.Success(), Response: "cached response", ComponentName: "cached component", }) resp, err := c.GetMetrics(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) t.Run("get system info metrics, cache miss", func(t *testing.T) { @@ -981,7 +990,7 @@ func TestRootCoord_GetMetrics(t *testing.T) { c.metricsCacheManager.InvalidateSystemInfoMetrics() resp, err := c.GetMetrics(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) t.Run("get system info metrics", func(t *testing.T) { @@ -993,7 +1002,7 @@ func TestRootCoord_GetMetrics(t *testing.T) { withMetricsCacheManager()) resp, err := 
c.getSystemInfoMetrics(ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) } @@ -1044,10 +1053,11 @@ func TestCore_Import(t *testing.T) { meta.GetPartitionByNameFunc = func(collID UniqueID, partitionName string, ts Timestamp) (UniqueID, error) { return 0, errors.New("mock GetPartitionByNameFunc error") } - _, err := c.Import(ctx, &milvuspb.ImportRequest{ + resp, err := c.Import(ctx, &milvuspb.ImportRequest{ CollectionName: "a-good-name", }) - assert.Error(t, err) + assert.NoError(t, err) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrPartitionNotFound) }) t.Run("normal case", func(t *testing.T) { @@ -1091,7 +1101,7 @@ func TestCore_Import(t *testing.T) { }, }) assert.NotNil(t, resp) - assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetStatus().GetErrorCode()) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrParameterInvalid) }) // Remove the following case after bulkinsert can support partition key @@ -1144,11 +1154,70 @@ func TestCore_Import(t *testing.T) { meta.GetCollectionByNameFunc = func(ctx context.Context, collectionName string, ts Timestamp) (*model.Collection, error) { return coll.Clone(), nil } - _, err := c.Import(ctx, &milvuspb.ImportRequest{ + resp, err := c.Import(ctx, &milvuspb.ImportRequest{ CollectionName: "a-good-name", PartitionName: "p1", }) - assert.Error(t, err) + assert.NoError(t, err) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrParameterInvalid) + }) + + t.Run("backup should set partition name", func(t *testing.T) { + ctx := context.Background() + c := newTestCore(withHealthyCode(), + withMeta(meta)) + meta.GetCollectionIDByNameFunc = func(name string) (UniqueID, error) { + return 100, nil + } + meta.GetCollectionVirtualChannelsFunc = func(colID int64) []string { + return []string{"ch-1", "ch-2"} + } + meta.GetPartitionByNameFunc = func(collID UniqueID, partitionName 
string, ts Timestamp) (UniqueID, error) { + return 101, nil + } + coll := &model.Collection{ + CollectionID: 100, + Name: "a-good-name", + Fields: []*model.Field{ + { + FieldID: 101, + Name: "test_field_name_1", + IsPrimaryKey: false, + IsPartitionKey: true, + DataType: schemapb.DataType_Int64, + }, + }, + } + meta.GetCollectionByNameFunc = func(ctx context.Context, collectionName string, ts Timestamp) (*model.Collection, error) { + return coll.Clone(), nil + } + resp1, err := c.Import(ctx, &milvuspb.ImportRequest{ + CollectionName: "a-good-name", + Options: []*commonpb.KeyValuePair{ + { + Key: importutil.BackupFlag, + Value: "true", + }, + }, + }) + assert.NoError(t, err) + assert.ErrorIs(t, merr.Error(resp1.GetStatus()), merr.ErrParameterInvalid) + + meta.GetPartitionByNameFunc = func(collID UniqueID, partitionName string, ts Timestamp) (UniqueID, error) { + return common.InvalidPartitionID, fmt.Errorf("partition ID not found for partition name '%s'", partitionName) + } + resp2, _ := c.Import(ctx, &milvuspb.ImportRequest{ + CollectionName: "a-good-name", + PartitionName: "a-bad-name", + Options: []*commonpb.KeyValuePair{ + { + Key: importutil.BackupFlag, + Value: "true", + }, + }, + }) + assert.NoError(t, err) + assert.ErrorIs(t, merr.Error(resp2.GetStatus()), merr.ErrPartitionNotFound) }) } @@ -1268,7 +1337,7 @@ func TestCore_ListImportTasks(t *testing.T) { CollectionID: ti3.CollectionId, }, nil } - return nil, errors.New("GetCollectionByName error") + return nil, merr.WrapErrCollectionNotFound(collectionName) } ctx := context.Background() @@ -1306,7 +1375,7 @@ func TestCore_ListImportTasks(t *testing.T) { }) assert.NoError(t, err) assert.Equal(t, 0, len(resp.GetTasks())) - assert.Equal(t, commonpb.ErrorCode_IllegalCollectionName, resp.GetStatus().GetErrorCode()) + assert.ErrorIs(t, merr.Error(resp.GetStatus()), merr.ErrCollectionNotFound) // list the latest 2 tasks resp, err = c.ListImportTasks(ctx, &milvuspb.ListImportTasksRequest{ @@ -1332,8 +1401,8 @@ func 
TestCore_ListImportTasks(t *testing.T) { func TestCore_ReportImport(t *testing.T) { paramtable.Get().Save(Params.RootCoordCfg.ImportTaskSubPath.Key, "importtask") var countLock sync.RWMutex - var globalCount = typeutil.UniqueID(0) - var idAlloc = func(count uint32) (typeutil.UniqueID, typeutil.UniqueID, error) { + globalCount := typeutil.UniqueID(0) + idAlloc := func(count uint32) (typeutil.UniqueID, typeutil.UniqueID, error) { countLock.Lock() defer countLock.Unlock() globalCount++ @@ -1382,44 +1451,36 @@ func TestCore_ReportImport(t *testing.T) { StateCode: commonpb.StateCode_Healthy, }, SubcomponentStates: nil, - Status: succStatus(), + Status: merr.Success(), }, nil } dc.WatchChannelsFunc = func(ctx context.Context, req *datapb.WatchChannelsRequest) (*datapb.WatchChannelsResponse, error) { - return &datapb.WatchChannelsResponse{Status: succStatus()}, nil + return &datapb.WatchChannelsResponse{Status: merr.Success()}, nil } dc.FlushFunc = func(ctx context.Context, req *datapb.FlushRequest) (*datapb.FlushResponse, error) { - return &datapb.FlushResponse{Status: succStatus()}, nil + return &datapb.FlushResponse{Status: merr.Success()}, nil } mockCallImportServiceErr := false callImportServiceFn := func(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { if mockCallImportServiceErr { return &datapb.ImportTaskResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, errors.New("mock err") } return &datapb.ImportTaskResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, nil } callGetSegmentStates := func(ctx context.Context, req *datapb.GetSegmentStatesRequest) (*datapb.GetSegmentStatesResponse, error) { return &datapb.GetSegmentStatesResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), }, nil } callUnsetIsImportingState := func(context.Context, 
*datapb.UnsetIsImportingStateRequest) (*commonpb.Status, error) { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, nil + return merr.Success(), nil } t.Run("not healthy", func(t *testing.T) { @@ -1490,7 +1551,7 @@ func TestCore_Rbac(t *testing.T) { } // not healthy. - c.stateCode.Store(commonpb.StateCode_Abnormal) + c.UpdateStateCode(commonpb.StateCode_Abnormal) { resp, err := c.CreateCredential(ctx, &internalpb.CredentialInfo{}) @@ -1513,13 +1574,13 @@ func TestCore_Rbac(t *testing.T) { { resp, err := c.GetCredential(ctx, &rootcoordpb.GetCredentialRequest{}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_NotReadyServe, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_NotReadyServe, resp.GetStatus().GetErrorCode()) } { resp, err := c.ListCredUsers(ctx, &milvuspb.ListCredUsersRequest{}) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_NotReadyServe, resp.Status.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_NotReadyServe, resp.GetStatus().GetErrorCode()) } { @@ -1543,13 +1604,13 @@ func TestCore_Rbac(t *testing.T) { { resp, err := c.SelectRole(ctx, &milvuspb.SelectRoleRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) } { resp, err := c.SelectUser(ctx, &milvuspb.SelectUserRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) } { @@ -1561,13 +1622,13 @@ func TestCore_Rbac(t *testing.T) { { resp, err := c.SelectGrant(ctx, &milvuspb.SelectGrantRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) } { resp, err := c.ListPolicy(ctx, &internalpb.ListPolicyRequest{}) assert.NoError(t, err) - assert.NotEqual(t, 
commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) } } @@ -1586,10 +1647,15 @@ func TestCore_sendMinDdlTsAsTt(t *testing.T) { withDdlTsLockManager(ddlManager), withScheduler(sched)) - c.stateCode.Store(commonpb.StateCode_Healthy) + c.UpdateStateCode(commonpb.StateCode_Healthy) c.session.ServerID = TestRootCoordID + + _ = paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "false") + c.sendMinDdlTsAsTt() // disable ts msg + _ = paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "true") + c.sendMinDdlTsAsTt() // no session. - ticker.addSession(&sessionutil.Session{ServerID: TestRootCoordID}) + ticker.addSession(&sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: TestRootCoordID}}) c.sendMinDdlTsAsTt() sched.GetMinDdlTsFunc = func() Timestamp { return typeutil.ZeroTimestamp @@ -1606,7 +1672,7 @@ func TestCore_sendMinDdlTsAsTt(t *testing.T) { func TestCore_startTimeTickLoop(t *testing.T) { ticker := newRocksMqTtSynchronizer() - ticker.addSession(&sessionutil.Session{ServerID: TestRootCoordID}) + ticker.addSession(&sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: TestRootCoordID}}) ddlManager := newMockDdlTsLockManager() ddlManager.GetMinDdlTsFunc = func() Timestamp { return 100 @@ -1656,15 +1722,17 @@ func TestRootcoord_EnableActiveStandby(t *testing.T) { core, err := NewCore(ctx, coreFactory) core.etcdCli = etcdCli assert.NoError(t, err) + core.SetTiKVClient(tikv.SetupLocalTxn()) + err = core.Init() assert.NoError(t, err) - assert.Equal(t, commonpb.StateCode_StandBy, core.stateCode.Load().(commonpb.StateCode)) + assert.Equal(t, commonpb.StateCode_StandBy, core.GetStateCode()) err = core.Start() assert.NoError(t, err) core.session.TriggerKill = false err = core.Register() assert.NoError(t, err) - assert.Equal(t, commonpb.StateCode_Healthy, core.stateCode.Load().(commonpb.StateCode)) + assert.Equal(t, commonpb.StateCode_Healthy, 
core.GetStateCode()) resp, err := core.DescribeCollection(ctx, &milvuspb.DescribeCollectionRequest{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_DescribeCollection, @@ -1672,7 +1740,8 @@ func TestRootcoord_EnableActiveStandby(t *testing.T) { Timestamp: 0, SourceID: paramtable.GetNodeID(), }, - CollectionName: "unexist"}) + CollectionName: "unexist", + }) assert.NoError(t, err) assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) err = core.Stop() @@ -1704,15 +1773,17 @@ func TestRootcoord_DisableActiveStandby(t *testing.T) { core, err := NewCore(ctx, coreFactory) core.etcdCli = etcdCli assert.NoError(t, err) + core.SetTiKVClient(tikv.SetupLocalTxn()) + err = core.Init() assert.NoError(t, err) - assert.Equal(t, commonpb.StateCode_Initializing, core.stateCode.Load().(commonpb.StateCode)) + assert.Equal(t, commonpb.StateCode_Initializing, core.GetStateCode()) err = core.Start() assert.NoError(t, err) core.session.TriggerKill = false err = core.Register() assert.NoError(t, err) - assert.Equal(t, commonpb.StateCode_Healthy, core.stateCode.Load().(commonpb.StateCode)) + assert.Equal(t, commonpb.StateCode_Healthy, core.GetStateCode()) resp, err := core.DescribeCollection(ctx, &milvuspb.DescribeCollectionRequest{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_DescribeCollection, @@ -1720,7 +1791,8 @@ func TestRootcoord_DisableActiveStandby(t *testing.T) { Timestamp: 0, SourceID: paramtable.GetNodeID(), }, - CollectionName: "unexist"}) + CollectionName: "unexist", + }) assert.NoError(t, err) assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) err = core.Stop() @@ -1811,7 +1883,7 @@ func TestRootCoord_RBACError(t *testing.T) { t.Run("get credential failed", func(t *testing.T) { resp, err := c.GetCredential(ctx, &rootcoordpb.GetCredentialRequest{Username: "foo"}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, 
resp.GetStatus().GetErrorCode()) }) t.Run("update credential failed", func(t *testing.T) { resp, err := c.UpdateCredential(ctx, &internalpb.CredentialInfo{}) @@ -1826,7 +1898,7 @@ func TestRootCoord_RBACError(t *testing.T) { t.Run("list credential failed", func(t *testing.T) { resp, err := c.ListCredUsers(ctx, &milvuspb.ListCredUsersRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) t.Run("create role failed", func(t *testing.T) { resp, err := c.CreateRole(ctx, &milvuspb.CreateRoleRequest{Entity: &milvuspb.RoleEntity{Name: "foo"}}) @@ -1860,24 +1932,24 @@ func TestRootCoord_RBACError(t *testing.T) { { resp, err := c.SelectRole(ctx, &milvuspb.SelectRoleRequest{Role: &milvuspb.RoleEntity{Name: "foo"}}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) } { resp, err := c.SelectRole(ctx, &milvuspb.SelectRoleRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) } }) t.Run("select user failed", func(t *testing.T) { { resp, err := c.SelectUser(ctx, &milvuspb.SelectUserRequest{User: &milvuspb.UserEntity{Name: "foo"}}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) } { resp, err := c.SelectUser(ctx, &milvuspb.SelectUserRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) } }) t.Run("operate privilege failed", func(t *testing.T) { @@ -1946,12 +2018,12 @@ func TestRootCoord_RBACError(t *testing.T) { { resp, err := 
c.SelectGrant(ctx, &milvuspb.SelectGrantRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) } { resp, err := c.SelectGrant(ctx, &milvuspb.SelectGrantRequest{Entity: &milvuspb.GrantEntity{Role: &milvuspb.RoleEntity{Name: "foo"}}}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) } mockMeta := c.meta.(*mockMetaTable) mockMeta.SelectRoleFunc = func(tenant string, entity *milvuspb.RoleEntity, includeUserInfo bool) ([]*milvuspb.RoleResult, error) { @@ -1960,12 +2032,12 @@ func TestRootCoord_RBACError(t *testing.T) { { resp, err := c.SelectGrant(ctx, &milvuspb.SelectGrantRequest{Entity: &milvuspb.GrantEntity{Role: &milvuspb.RoleEntity{Name: "foo"}, Object: &milvuspb.ObjectEntity{Name: "CollectionFoo"}}}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) } { resp, err := c.SelectGrant(ctx, &milvuspb.SelectGrantRequest{Entity: &milvuspb.GrantEntity{Role: &milvuspb.RoleEntity{Name: "foo"}, Object: &milvuspb.ObjectEntity{Name: "Collection"}}}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) } mockMeta.SelectRoleFunc = func(tenant string, entity *milvuspb.RoleEntity, includeUserInfo bool) ([]*milvuspb.RoleResult, error) { return nil, errors.New("mock error") @@ -1975,7 +2047,7 @@ func TestRootCoord_RBACError(t *testing.T) { t.Run("list policy failed", func(t *testing.T) { resp, err := c.ListPolicy(ctx, &internalpb.ListPolicyRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, 
commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) mockMeta := c.meta.(*mockMetaTable) mockMeta.ListPolicyFunc = func(tenant string) ([]string, error) { @@ -1983,7 +2055,7 @@ func TestRootCoord_RBACError(t *testing.T) { } resp, err = c.ListPolicy(ctx, &internalpb.ListPolicyRequest{}) assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) mockMeta.ListPolicyFunc = func(tenant string) ([]string, error) { return []string{}, errors.New("mock error") } @@ -1995,8 +2067,7 @@ func TestCore_Stop(t *testing.T) { c := &Core{} err := c.Stop() assert.NoError(t, err) - code, ok := c.stateCode.Load().(commonpb.StateCode) - assert.True(t, ok) + code := c.GetStateCode() assert.Equal(t, commonpb.StateCode_Abnormal, code) }) @@ -2006,8 +2077,7 @@ func TestCore_Stop(t *testing.T) { c.ctx, c.cancel = context.WithCancel(context.Background()) err := c.Stop() assert.NoError(t, err) - code, ok := c.stateCode.Load().(commonpb.StateCode) - assert.True(t, ok) + code := c.GetStateCode() assert.Equal(t, commonpb.StateCode_Abnormal, code) }) } @@ -2041,7 +2111,8 @@ func (s *RootCoordSuite) TestRestore() { meta.EXPECT().ListDatabases(mock.Anything, mock.Anything). Return([]*model.Database{ {Name: "available_colls_db"}, - {Name: "not_available_colls_db"}}, nil) + {Name: "not_available_colls_db"}, + }, nil) meta.EXPECT().ListCollections(mock.Anything, "available_colls_db", mock.Anything, false). 
Return([]*model.Collection{ diff --git a/internal/rootcoord/show_collection_task.go b/internal/rootcoord/show_collection_task.go index de29ca0091ae4..31b88e2b58791 100644 --- a/internal/rootcoord/show_collection_task.go +++ b/internal/rootcoord/show_collection_task.go @@ -21,6 +21,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -41,14 +42,14 @@ func (t *showCollectionTask) Prepare(ctx context.Context) error { // Execute task execution func (t *showCollectionTask) Execute(ctx context.Context) error { - t.Rsp.Status = succStatus() + t.Rsp.Status = merr.Success() ts := t.Req.GetTimeStamp() if ts == 0 { ts = typeutil.MaxTimestamp } colls, err := t.core.meta.ListCollections(ctx, t.Req.GetDbName(), ts, true) if err != nil { - t.Rsp.Status = failStatus(commonpb.ErrorCode_UnexpectedError, err.Error()) + t.Rsp.Status = merr.Status(err) return err } for _, meta := range colls { diff --git a/internal/rootcoord/show_collection_task_test.go b/internal/rootcoord/show_collection_task_test.go index 9cf72af2f9219..3929b86d2bcd5 100644 --- a/internal/rootcoord/show_collection_task_test.go +++ b/internal/rootcoord/show_collection_task_test.go @@ -20,10 +20,11 @@ import ( "context" "testing" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/metastore/model" - "github.com/stretchr/testify/assert" ) func Test_showCollectionTask_Prepare(t *testing.T) { diff --git a/internal/rootcoord/show_partition_task.go b/internal/rootcoord/show_partition_task.go index 6d60cd429b921..5e1f8214b4824 100644 --- a/internal/rootcoord/show_partition_task.go +++ b/internal/rootcoord/show_partition_task.go @@ -22,6 +22,7 @@ import ( 
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/metastore/model" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -45,14 +46,14 @@ func (t *showPartitionTask) Prepare(ctx context.Context) error { func (t *showPartitionTask) Execute(ctx context.Context) error { var coll *model.Collection var err error - t.Rsp.Status = succStatus() + t.Rsp.Status = merr.Success() if t.Req.GetCollectionName() == "" { coll, err = t.core.meta.GetCollectionByID(ctx, t.Req.GetDbName(), t.Req.GetCollectionID(), typeutil.MaxTimestamp, t.allowUnavailable) } else { coll, err = t.core.meta.GetCollectionByName(ctx, t.Req.GetDbName(), t.Req.GetCollectionName(), typeutil.MaxTimestamp) } if err != nil { - t.Rsp.Status = failStatus(commonpb.ErrorCode_CollectionNotExists, err.Error()) + t.Rsp.Status = merr.Status(err) return err } diff --git a/internal/rootcoord/show_partition_task_test.go b/internal/rootcoord/show_partition_task_test.go index de606dd8a8628..94b03bf3e923d 100644 --- a/internal/rootcoord/show_partition_task_test.go +++ b/internal/rootcoord/show_partition_task_test.go @@ -20,11 +20,15 @@ import ( "context" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/metastore/model" + mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/typeutil" - "github.com/stretchr/testify/assert" ) func Test_showPartitionTask_Prepare(t *testing.T) { @@ -55,7 +59,9 @@ func Test_showPartitionTask_Prepare(t *testing.T) { func Test_showPartitionTask_Execute(t *testing.T) { t.Run("failed to list collections by name", func(t *testing.T) { 
- core := newTestCore(withInvalidMeta()) + metaTable := mockrootcoord.NewIMetaTable(t) + metaTable.EXPECT().GetCollectionByName(mock.Anything, mock.Anything, "test coll", mock.Anything).Return(nil, merr.WrapErrCollectionNotFound("test coll")) + core := newTestCore(withMeta(metaTable)) task := &showPartitionTask{ baseTask: newBaseTask(context.Background(), core), Req: &milvuspb.ShowPartitionsRequest{ @@ -67,12 +73,14 @@ func Test_showPartitionTask_Execute(t *testing.T) { Rsp: &milvuspb.ShowPartitionsResponse{}, } err := task.Execute(context.Background()) - assert.Error(t, err) - assert.Equal(t, task.Rsp.GetStatus().GetErrorCode(), commonpb.ErrorCode_CollectionNotExists) + assert.ErrorIs(t, err, merr.ErrCollectionNotFound) + assert.ErrorIs(t, merr.Error(task.Rsp.GetStatus()), merr.ErrCollectionNotFound) }) t.Run("failed to list collections by id", func(t *testing.T) { - core := newTestCore(withInvalidMeta()) + metaTable := mockrootcoord.NewIMetaTable(t) + metaTable.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.WrapErrCollectionNotFound(1)) + core := newTestCore(withMeta(metaTable)) task := &showPartitionTask{ baseTask: newBaseTask(context.Background(), core), Req: &milvuspb.ShowPartitionsRequest{ @@ -84,8 +92,8 @@ func Test_showPartitionTask_Execute(t *testing.T) { Rsp: &milvuspb.ShowPartitionsResponse{}, } err := task.Execute(context.Background()) - assert.Error(t, err) - assert.Equal(t, task.Rsp.GetStatus().GetErrorCode(), commonpb.ErrorCode_CollectionNotExists) + assert.ErrorIs(t, err, merr.ErrCollectionNotFound) + assert.ErrorIs(t, merr.Error(task.Rsp.GetStatus()), merr.ErrCollectionNotFound) }) t.Run("success", func(t *testing.T) { diff --git a/internal/rootcoord/step.go b/internal/rootcoord/step.go index 60ba42f737c7f..5a51996a6c05c 100644 --- a/internal/rootcoord/step.go +++ b/internal/rootcoord/step.go @@ -125,12 +125,15 @@ type unwatchChannelsStep struct { baseStep collectionID UniqueID 
channels collectionChannels + + isSkip bool } func (s *unwatchChannelsStep) Execute(ctx context.Context) ([]nestedStep, error) { unwatchByDropMsg := &deleteCollectionDataStep{ baseStep: baseStep{core: s.core}, coll: &model.Collection{CollectionID: s.collectionID, PhysicalChannelNames: s.channels.physicalChannels}, + isSkip: s.isSkip, } return unwatchByDropMsg.Execute(ctx) } @@ -183,9 +186,14 @@ func (s *expireCacheStep) Desc() string { type deleteCollectionDataStep struct { baseStep coll *model.Collection + + isSkip bool } func (s *deleteCollectionDataStep) Execute(ctx context.Context) ([]nestedStep, error) { + if s.isSkip { + return nil, nil + } ddlTs, err := s.core.garbageCollector.GcCollectionData(ctx, s.coll) if err != nil { return nil, err @@ -239,9 +247,14 @@ type deletePartitionDataStep struct { baseStep pchans []string partition *model.Partition + + isSkip bool } func (s *deletePartitionDataStep) Execute(ctx context.Context) ([]nestedStep, error) { + if s.isSkip { + return nil, nil + } _, err := s.core.garbageCollector.GcPartitionData(ctx, s.pchans, s.partition) return nil, err } @@ -382,8 +395,7 @@ func (s *removePartitionMetaStep) Weight() stepPriority { return stepPriorityNormal } -type nullStep struct { -} +type nullStep struct{} func (s *nullStep) Execute(ctx context.Context) ([]nestedStep, error) { return nil, nil diff --git a/internal/rootcoord/step_executor.go b/internal/rootcoord/step_executor.go index 03da036cec465..f28b51d2ad5e8 100644 --- a/internal/rootcoord/step_executor.go +++ b/internal/rootcoord/step_executor.go @@ -22,9 +22,10 @@ import ( "sync" "time" + "go.uber.org/zap" + "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/retry" - "go.uber.org/zap" ) const ( diff --git a/internal/rootcoord/step_executor_test.go b/internal/rootcoord/step_executor_test.go index 9d8d85c265d79..29d3d4784853a 100644 --- a/internal/rootcoord/step_executor_test.go +++ b/internal/rootcoord/step_executor_test.go @@ -23,13 +23,12 @@ import 
( "time" "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus/pkg/util/retry" - "github.com/stretchr/testify/assert" ) -type mockChildStep struct { -} +type mockChildStep struct{} func (m *mockChildStep) Execute(ctx context.Context) ([]nestedStep, error) { return nil, nil @@ -47,8 +46,7 @@ func newMockChildStep() *mockChildStep { return &mockChildStep{} } -type mockStepWithChild struct { -} +type mockStepWithChild struct{} func (m *mockStepWithChild) Execute(ctx context.Context) ([]nestedStep, error) { return []nestedStep{newMockChildStep()}, nil diff --git a/internal/rootcoord/step_test.go b/internal/rootcoord/step_test.go index 384f0196d7275..ef1315e54b0df 100644 --- a/internal/rootcoord/step_test.go +++ b/internal/rootcoord/step_test.go @@ -25,8 +25,8 @@ import ( ) func Test_waitForTsSyncedStep_Execute(t *testing.T) { - //Params.InitOnce() - //Params.ProxyCfg.TimeTickInterval = time.Millisecond + // Params.InitOnce() + // Params.ProxyCfg.TimeTickInterval = time.Millisecond ticker := newRocksMqTtSynchronizer() core := newTestCore(withTtSynchronizer(ticker)) @@ -95,3 +95,23 @@ func Test_confirmGCStep_Execute(t *testing.T) { assert.NoError(t, err) }) } + +func TestSkip(t *testing.T) { + { + s := &unwatchChannelsStep{isSkip: true} + _, err := s.Execute(context.Background()) + assert.NoError(t, err) + } + + { + s := &deleteCollectionDataStep{isSkip: true} + _, err := s.Execute(context.Background()) + assert.NoError(t, err) + } + + { + s := &deletePartitionDataStep{isSkip: true} + _, err := s.Execute(context.Background()) + assert.NoError(t, err) + } +} diff --git a/internal/rootcoord/timestamp_bench_test.go b/internal/rootcoord/timestamp_bench_test.go index 92e5232ddb87a..e8526af5de338 100644 --- a/internal/rootcoord/timestamp_bench_test.go +++ b/internal/rootcoord/timestamp_bench_test.go @@ -87,7 +87,6 @@ func Benchmark_RootCoord_AllocTimestamp(b *testing.B) { } _, err := c.AllocTimestamp(ctx, &req) assert.Nil(b, err) - 
} b.StopTimer() } diff --git a/internal/rootcoord/timeticksync_test.go b/internal/rootcoord/timeticksync_test.go index 26a4ee0ebf0f7..40b6a986db821 100644 --- a/internal/rootcoord/timeticksync_test.go +++ b/internal/rootcoord/timeticksync_test.go @@ -129,10 +129,10 @@ func TestMultiTimetickSync(t *testing.T) { defer wg.Done() // suppose this is rooit - ttSync.addSession(&sessionutil.Session{ServerID: 1}) + ttSync.addSession(&sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 1}}) // suppose this is proxy1 - ttSync.addSession(&sessionutil.Session{ServerID: 2}) + ttSync.addSession(&sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 2}}) msg := &internalpb.ChannelTimeTickMsg{ Base: &commonpb.MsgBase{ @@ -221,7 +221,6 @@ func TestTimetickSyncWithExistChannels(t *testing.T) { }) // test get new channels - } func TestTimetickSyncInvalidName(t *testing.T) { diff --git a/internal/rootcoord/undo.go b/internal/rootcoord/undo.go index 8827f9882372e..29e9b7dfac05b 100644 --- a/internal/rootcoord/undo.go +++ b/internal/rootcoord/undo.go @@ -20,8 +20,9 @@ import ( "context" "fmt" - "github.com/milvus-io/milvus/pkg/log" "go.uber.org/zap" + + "github.com/milvus-io/milvus/pkg/log" ) type baseUndoTask struct { diff --git a/internal/rootcoord/util.go b/internal/rootcoord/util.go index b1c3987669687..59f5b49a96c88 100644 --- a/internal/rootcoord/util.go +++ b/internal/rootcoord/util.go @@ -24,11 +24,9 @@ import ( "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/msgstream" - "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -53,16 +51,6 @@ func EqualKeyPairArray(p1 []*commonpb.KeyValuePair, p2 []*commonpb.KeyValuePair) return true } -// GetFieldSchemaByID return field schema by id -func GetFieldSchemaByID(coll 
*model.Collection, fieldID typeutil.UniqueID) (*model.Field, error) { - for _, f := range coll.Fields { - if f.FieldID == fieldID { - return f, nil - } - } - return nil, fmt.Errorf("field id = %d not found", fieldID) -} - // EncodeMsgPositions serialize []*MsgPosition into string func EncodeMsgPositions(msgPositions []*msgstream.MsgPosition) (string, error) { if len(msgPositions) == 0 { @@ -106,18 +94,6 @@ func CheckMsgType(got, expect commonpb.MsgType) error { return nil } -// Deprecated: use merr.StatusWithErrorCode or merr.Status instead -func failStatus(code commonpb.ErrorCode, reason string) *commonpb.Status { - return &commonpb.Status{ - ErrorCode: code, - Reason: reason, - } -} - -func succStatus() *commonpb.Status { - return merr.Status(nil) -} - type TimeTravelRequest interface { GetBase() *commonpb.MsgBase GetTimeStamp() Timestamp diff --git a/internal/rootcoord/util_test.go b/internal/rootcoord/util_test.go index 7c9ea2c472c1c..de03271400d4a 100644 --- a/internal/rootcoord/util_test.go +++ b/internal/rootcoord/util_test.go @@ -19,13 +19,13 @@ package rootcoord import ( "testing" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" - "github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/typeutil" - "github.com/stretchr/testify/assert" ) func Test_EqualKeyPairArray(t *testing.T) { @@ -61,20 +61,6 @@ func Test_EqualKeyPairArray(t *testing.T) { assert.True(t, EqualKeyPairArray(p1, p2)) } -func Test_GetFieldSchemaByID(t *testing.T) { - coll := &model.Collection{ - Fields: []*model.Field{ - { - FieldID: 1, - }, - }, - } - _, err := GetFieldSchemaByID(coll, 1) - assert.NoError(t, err) - _, err = GetFieldSchemaByID(coll, 2) - assert.Error(t, err) -} - func Test_EncodeMsgPositions(t *testing.T) { mp := &msgstream.MsgPosition{ ChannelName: 
"test", diff --git a/internal/storage/aliyun/aliyun_test.go b/internal/storage/aliyun/aliyun_test.go index 8e9660eb6e181..b05f3b645b6c7 100644 --- a/internal/storage/aliyun/aliyun_test.go +++ b/internal/storage/aliyun/aliyun_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus/internal/storage/aliyun/mocks" "github.com/minio/minio-go/v7" "github.com/minio/minio-go/v7/pkg/credentials" "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/internal/storage/aliyun/mocks" ) func TestNewMinioClient(t *testing.T) { diff --git a/internal/storage/azure_object_storage.go b/internal/storage/azure_object_storage.go new file mode 100644 index 0000000000000..3703f7d17fcc0 --- /dev/null +++ b/internal/storage/azure_object_storage.go @@ -0,0 +1,140 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package storage + +import ( + "context" + "fmt" + "io" + "os" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/service" + + "github.com/milvus-io/milvus/pkg/util/retry" +) + +type AzureObjectStorage struct { + *service.Client +} + +func newAzureObjectStorageWithConfig(ctx context.Context, c *config) (*AzureObjectStorage, error) { + var client *service.Client + var err error + if c.useIAM { + cred, credErr := azidentity.NewWorkloadIdentityCredential(&azidentity.WorkloadIdentityCredentialOptions{ + ClientID: os.Getenv("AZURE_CLIENT_ID"), + TenantID: os.Getenv("AZURE_TENANT_ID"), + TokenFilePath: os.Getenv("AZURE_FEDERATED_TOKEN_FILE"), + }) + if credErr != nil { + return nil, credErr + } + client, err = service.NewClient("https://"+c.accessKeyID+".blob."+c.address+"/", cred, &service.ClientOptions{}) + } else { + connectionString := os.Getenv("AZURE_STORAGE_CONNECTION_STRING") + if connectionString == "" { + connectionString = "DefaultEndpointsProtocol=https;AccountName=" + c.accessKeyID + + ";AccountKey=" + c.secretAccessKeyID + ";EndpointSuffix=" + c.address + } + client, err = service.NewClientFromConnectionString(connectionString, &service.ClientOptions{}) + } + if err != nil { + return nil, err + } + if c.bucketName == "" { + return nil, fmt.Errorf("invalid bucket name") + } + // check valid in first query + checkBucketFn := func() error { + _, err := client.NewContainerClient(c.bucketName).GetProperties(ctx, &container.GetPropertiesOptions{}) + if err != nil { + switch err := err.(type) { + case *azcore.ResponseError: + if c.createBucket && err.ErrorCode == string(bloberror.ContainerNotFound) 
{ + _, createErr := client.NewContainerClient(c.bucketName).Create(ctx, &azblob.CreateContainerOptions{}) + if createErr != nil { + return createErr + } + return nil + } + } + } + return err + } + err = retry.Do(ctx, checkBucketFn, retry.Attempts(CheckBucketRetryAttempts)) + if err != nil { + return nil, err + } + return &AzureObjectStorage{Client: client}, nil +} + +func (AzureObjectStorage *AzureObjectStorage) GetObject(ctx context.Context, bucketName, objectName string, offset int64, size int64) (FileReader, error) { + opts := azblob.DownloadStreamOptions{} + if offset > 0 { + opts.Range = azblob.HTTPRange{ + Offset: offset, + Count: size, + } + } + object, err := AzureObjectStorage.Client.NewContainerClient(bucketName).NewBlockBlobClient(objectName).DownloadStream(ctx, &opts) + if err != nil { + return nil, err + } + return object.Body, nil +} + +func (AzureObjectStorage *AzureObjectStorage) PutObject(ctx context.Context, bucketName, objectName string, reader io.Reader, objectSize int64) error { + _, err := AzureObjectStorage.Client.NewContainerClient(bucketName).NewBlockBlobClient(objectName).UploadStream(ctx, reader, &azblob.UploadStreamOptions{}) + return err +} + +func (AzureObjectStorage *AzureObjectStorage) StatObject(ctx context.Context, bucketName, objectName string) (int64, error) { + info, err := AzureObjectStorage.Client.NewContainerClient(bucketName).NewBlockBlobClient(objectName).GetProperties(ctx, &blob.GetPropertiesOptions{}) + if err != nil { + return 0, err + } + return *info.ContentLength, nil +} + +func (AzureObjectStorage *AzureObjectStorage) ListObjects(ctx context.Context, bucketName string, prefix string, recursive bool) (map[string]time.Time, error) { + pager := AzureObjectStorage.Client.NewContainerClient(bucketName).NewListBlobsFlatPager(&azblob.ListBlobsFlatOptions{ + Prefix: &prefix, + }) + objects := map[string]time.Time{} + if pager.More() { + pageResp, err := pager.NextPage(context.Background()) + if err != nil { + return nil, err 
+ } + for _, blob := range pageResp.Segment.BlobItems { + objects[*blob.Name] = *blob.Properties.LastModified + } + } + return objects, nil +} + +func (AzureObjectStorage *AzureObjectStorage) RemoveObject(ctx context.Context, bucketName, objectName string) error { + _, err := AzureObjectStorage.Client.NewContainerClient(bucketName).NewBlockBlobClient(objectName).Delete(ctx, &blob.DeleteOptions{}) + return err +} diff --git a/internal/storage/azure_object_storage_test.go b/internal/storage/azure_object_storage_test.go new file mode 100644 index 0000000000000..4af7eff35cdb6 --- /dev/null +++ b/internal/storage/azure_object_storage_test.go @@ -0,0 +1,165 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package storage + +import ( + "bytes" + "context" + "io" + "os" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAzureObjectStorage(t *testing.T) { + ctx := context.Background() + bucketName := Params.MinioCfg.BucketName.GetValue() + config := config{ + bucketName: bucketName, + createBucket: true, + useIAM: false, + cloudProvider: "azure", + } + + t.Run("test initialize", func(t *testing.T) { + var err error + config.bucketName = "" + _, err = newAzureObjectStorageWithConfig(ctx, &config) + assert.Error(t, err) + config.bucketName = bucketName + _, err = newAzureObjectStorageWithConfig(ctx, &config) + assert.Equal(t, err, nil) + }) + + t.Run("test load", func(t *testing.T) { + testCM, err := newAzureObjectStorageWithConfig(ctx, &config) + assert.Equal(t, err, nil) + defer testCM.DeleteContainer(ctx, config.bucketName, &azblob.DeleteContainerOptions{}) + + prepareTests := []struct { + key string + value []byte + }{ + {"abc", []byte("123")}, + {"abcd", []byte("1234")}, + {"key_1", []byte("111")}, + {"key_2", []byte("222")}, + {"key_3", []byte("333")}, + } + + for _, test := range prepareTests { + err := testCM.PutObject(ctx, config.bucketName, test.key, bytes.NewReader(test.value), int64(len(test.value))) + require.NoError(t, err) + } + + loadTests := []struct { + isvalid bool + loadKey string + expectedValue []byte + + description string + }{ + {true, "abc", []byte("123"), "load valid key abc"}, + {true, "abcd", []byte("1234"), "load valid key abcd"}, + {true, "key_1", []byte("111"), "load valid key key_1"}, + {true, "key_2", []byte("222"), "load valid key key_2"}, + {true, "key_3", []byte("333"), "load valid key key_3"}, + {false, "key_not_exist", []byte(""), "load invalid key key_not_exist"}, + {false, "/", []byte(""), "load leading slash"}, + } + + for _, test := range loadTests { + t.Run(test.description, func(t *testing.T) { + if test.isvalid 
{ + got, err := testCM.GetObject(ctx, config.bucketName, test.loadKey, 0, 1024) + assert.NoError(t, err) + contentData, err := io.ReadAll(got) + assert.NoError(t, err) + assert.Equal(t, len(contentData), len(test.expectedValue)) + assert.Equal(t, test.expectedValue, contentData) + statSize, err := testCM.StatObject(ctx, config.bucketName, test.loadKey) + assert.NoError(t, err) + assert.Equal(t, statSize, int64(len(contentData))) + _, err = testCM.GetObject(ctx, config.bucketName, test.loadKey, 1, 1023) + assert.NoError(t, err) + } else { + if test.loadKey == "/" { + got, err := testCM.GetObject(ctx, config.bucketName, test.loadKey, 0, 1024) + assert.Error(t, err) + assert.Empty(t, got) + return + } + got, err := testCM.GetObject(ctx, config.bucketName, test.loadKey, 0, 1024) + assert.Error(t, err) + assert.Empty(t, got) + } + }) + } + + loadWithPrefixTests := []struct { + isvalid bool + prefix string + expectedValue [][]byte + + description string + }{ + {true, "abc", [][]byte{[]byte("123"), []byte("1234")}, "load with valid prefix abc"}, + {true, "key_", [][]byte{[]byte("111"), []byte("222"), []byte("333")}, "load with valid prefix key_"}, + {true, "prefix", [][]byte{}, "load with valid but not exist prefix prefix"}, + } + + for _, test := range loadWithPrefixTests { + t.Run(test.description, func(t *testing.T) { + gotk, err := testCM.ListObjects(ctx, config.bucketName, test.prefix, false) + assert.NoError(t, err) + assert.Equal(t, len(test.expectedValue), len(gotk)) + for key := range gotk { + err := testCM.RemoveObject(ctx, config.bucketName, key) + assert.NoError(t, err) + } + }) + } + }) + + t.Run("test useIAM", func(t *testing.T) { + var err error + config.useIAM = true + _, err = newAzureObjectStorageWithConfig(ctx, &config) + assert.Error(t, err) + os.Setenv("AZURE_CLIENT_ID", "00000000-0000-0000-0000-00000000000") + os.Setenv("AZURE_TENANT_ID", "00000000-0000-0000-0000-00000000000") + os.Setenv("AZURE_FEDERATED_TOKEN_FILE", 
"/var/run/secrets/tokens/azure-identity-token") + _, err = newAzureObjectStorageWithConfig(ctx, &config) + assert.Error(t, err) + config.useIAM = false + }) + + t.Run("test key secret", func(t *testing.T) { + var err error + connectionString := os.Getenv("AZURE_STORAGE_CONNECTION_STRING") + os.Setenv("AZURE_STORAGE_CONNECTION_STRING", "") + config.accessKeyID = "devstoreaccount1" + config.secretAccessKeyID = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==" + config.address = "core.windows.net" + _, err = newAzureObjectStorageWithConfig(ctx, &config) + assert.Error(t, err) + os.Setenv("AZURE_STORAGE_CONNECTION_STRING", connectionString) + }) +} diff --git a/internal/storage/binlog_iterator.go b/internal/storage/binlog_iterator.go index c3445d3a86893..fad450b8ad7be 100644 --- a/internal/storage/binlog_iterator.go +++ b/internal/storage/binlog_iterator.go @@ -66,7 +66,6 @@ func NewInsertBinlogIterator(blobs []*Blob, PKfieldID UniqueID, pkType schemapb. 
reader := NewInsertCodecWithSchema(nil) _, _, serData, err := reader.Deserialize(blobs) - if err != nil { return nil, err } diff --git a/internal/storage/binlog_iterator_test.go b/internal/storage/binlog_iterator_test.go index 7a91316af4801..d62218aec851f 100644 --- a/internal/storage/binlog_iterator_test.go +++ b/internal/storage/binlog_iterator_test.go @@ -19,11 +19,11 @@ package storage import ( "testing" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/etcdpb" "github.com/milvus-io/milvus/pkg/common" - - "github.com/stretchr/testify/assert" ) func generateTestData(t *testing.T, num int) []*Blob { diff --git a/internal/storage/binlog_test.go b/internal/storage/binlog_test.go index a55a83c0b732f..15454bfb71e74 100644 --- a/internal/storage/binlog_test.go +++ b/internal/storage/binlog_test.go @@ -25,14 +25,14 @@ import ( "time" "unsafe" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/milvus-io/milvus/pkg/util/typeutil" "github.com/milvus-io/milvus/pkg/util/uniquegenerator" - - "github.com/stretchr/testify/assert" ) /* #nosec G103 */ @@ -89,63 +89,63 @@ func TestInsertBinlog(t *testing.T) { assert.LessOrEqual(t, diffts, maxdiff) pos += int(unsafe.Sizeof(ts)) - //descriptor header, type code + // descriptor header, type code tc := UnsafeReadInt8(buf, pos) assert.Equal(t, EventTypeCode(tc), DescriptorEventType) pos += int(unsafe.Sizeof(tc)) - //descriptor header, event length + // descriptor header, event length descEventLen := UnsafeReadInt32(buf, pos) pos += int(unsafe.Sizeof(descEventLen)) - //descriptor header, next position + // descriptor header, next position descNxtPos := UnsafeReadInt32(buf, pos) assert.Equal(t, descEventLen+int32(unsafe.Sizeof(MagicNumber)), descNxtPos) pos 
+= int(unsafe.Sizeof(descNxtPos)) - //descriptor data fix, collection id + // descriptor data fix, collection id collID := UnsafeReadInt64(buf, pos) assert.Equal(t, collID, int64(10)) pos += int(unsafe.Sizeof(collID)) - //descriptor data fix, partition id + // descriptor data fix, partition id partID := UnsafeReadInt64(buf, pos) assert.Equal(t, partID, int64(20)) pos += int(unsafe.Sizeof(partID)) - //descriptor data fix, segment id + // descriptor data fix, segment id segID := UnsafeReadInt64(buf, pos) assert.Equal(t, segID, int64(30)) pos += int(unsafe.Sizeof(segID)) - //descriptor data fix, field id + // descriptor data fix, field id fieldID := UnsafeReadInt64(buf, pos) assert.Equal(t, fieldID, int64(40)) pos += int(unsafe.Sizeof(fieldID)) - //descriptor data fix, start time stamp + // descriptor data fix, start time stamp startts := UnsafeReadInt64(buf, pos) assert.Equal(t, startts, int64(1000)) pos += int(unsafe.Sizeof(startts)) - //descriptor data fix, end time stamp + // descriptor data fix, end time stamp endts := UnsafeReadInt64(buf, pos) assert.Equal(t, endts, int64(2000)) pos += int(unsafe.Sizeof(endts)) - //descriptor data fix, payload type + // descriptor data fix, payload type colType := UnsafeReadInt32(buf, pos) assert.Equal(t, schemapb.DataType(colType), schemapb.DataType_Int64) pos += int(unsafe.Sizeof(colType)) - //descriptor data, post header lengths + // descriptor data, post header lengths for i := DescriptorEventType; i < EventTypeEnd; i++ { size := getEventFixPartSize(i) assert.Equal(t, uint8(size), buf[pos]) pos++ } - //descriptor data, extra length + // descriptor data, extra length extraLength := UnsafeReadInt32(buf, pos) assert.Equal(t, extraLength, w.baseBinlogWriter.descriptorEventData.ExtraLength) pos += int(unsafe.Sizeof(extraLength)) @@ -166,40 +166,40 @@ func TestInsertBinlog(t *testing.T) { assert.True(t, ok) assert.Equal(t, fmt.Sprintf("%v", sizeTotal), fmt.Sprintf("%v", size)) - //start of e1 + // start of e1 assert.Equal(t, pos, 
int(descNxtPos)) - //insert e1 header, Timestamp + // insert e1 header, Timestamp e1ts := UnsafeReadInt64(buf, pos) diffts = curts - e1ts assert.LessOrEqual(t, diffts, maxdiff) pos += int(unsafe.Sizeof(e1ts)) - //insert e1 header, type code + // insert e1 header, type code e1tc := UnsafeReadInt8(buf, pos) assert.Equal(t, EventTypeCode(e1tc), InsertEventType) pos += int(unsafe.Sizeof(e1tc)) - //insert e1 header, event length + // insert e1 header, event length e1EventLen := UnsafeReadInt32(buf, pos) pos += int(unsafe.Sizeof(e1EventLen)) - //insert e1 header, next position + // insert e1 header, next position e1NxtPos := UnsafeReadInt32(buf, pos) assert.Equal(t, descNxtPos+e1EventLen, e1NxtPos) pos += int(unsafe.Sizeof(descNxtPos)) - //insert e1 data, start time stamp + // insert e1 data, start time stamp e1st := UnsafeReadInt64(buf, pos) assert.Equal(t, e1st, int64(100)) pos += int(unsafe.Sizeof(e1st)) - //insert e1 data, end time stamp + // insert e1 data, end time stamp e1et := UnsafeReadInt64(buf, pos) assert.Equal(t, e1et, int64(200)) pos += int(unsafe.Sizeof(e1et)) - //insert e1, payload + // insert e1, payload e1Payload := buf[pos:e1NxtPos] e1r, err := NewPayloadReader(schemapb.DataType_Int64, e1Payload) assert.NoError(t, err) @@ -208,40 +208,40 @@ func TestInsertBinlog(t *testing.T) { assert.Equal(t, e1a, []int64{1, 2, 3, 4, 5, 6}) e1r.Close() - //start of e2 + // start of e2 pos = int(e1NxtPos) - //insert e2 header, Timestamp + // insert e2 header, Timestamp e2ts := UnsafeReadInt64(buf, pos) diffts = curts - e2ts assert.LessOrEqual(t, diffts, maxdiff) pos += int(unsafe.Sizeof(e2ts)) - //insert e2 header, type code + // insert e2 header, type code e2tc := UnsafeReadInt8(buf, pos) assert.Equal(t, EventTypeCode(e2tc), InsertEventType) pos += int(unsafe.Sizeof(e2tc)) - //insert e2 header, event length + // insert e2 header, event length e2EventLen := UnsafeReadInt32(buf, pos) pos += int(unsafe.Sizeof(e2EventLen)) - //insert e2 header, next position + // insert 
e2 header, next position e2NxtPos := UnsafeReadInt32(buf, pos) assert.Equal(t, e1NxtPos+e2EventLen, e2NxtPos) pos += int(unsafe.Sizeof(descNxtPos)) - //insert e2 data, start time stamp + // insert e2 data, start time stamp e2st := UnsafeReadInt64(buf, pos) assert.Equal(t, e2st, int64(300)) pos += int(unsafe.Sizeof(e2st)) - //insert e2 data, end time stamp + // insert e2 data, end time stamp e2et := UnsafeReadInt64(buf, pos) assert.Equal(t, e2et, int64(400)) pos += int(unsafe.Sizeof(e2et)) - //insert e2, payload + // insert e2, payload e2Payload := buf[pos:] e2r, err := NewPayloadReader(schemapb.DataType_Int64, e2Payload) assert.NoError(t, err) @@ -252,7 +252,7 @@ func TestInsertBinlog(t *testing.T) { assert.Equal(t, int(e2NxtPos), len(buf)) - //read binlog + // read binlog r, err := NewBinlogReader(buf) assert.NoError(t, err) event1, err := r.NextEventReader() @@ -321,12 +321,12 @@ func TestDeleteBinlog(t *testing.T) { w.Close() - //magic number + // magic number magicNum := UnsafeReadInt32(buf, 0) assert.Equal(t, magicNum, MagicNumber) pos := int(unsafe.Sizeof(MagicNumber)) - //descriptor header, timestamp + // descriptor header, timestamp ts := UnsafeReadInt64(buf, pos) assert.Greater(t, ts, int64(0)) curts := time.Now().UnixNano() / int64(time.Millisecond) @@ -336,63 +336,63 @@ func TestDeleteBinlog(t *testing.T) { assert.LessOrEqual(t, diffts, maxdiff) pos += int(unsafe.Sizeof(ts)) - //descriptor header, type code + // descriptor header, type code tc := UnsafeReadInt8(buf, pos) assert.Equal(t, EventTypeCode(tc), DescriptorEventType) pos += int(unsafe.Sizeof(tc)) - //descriptor header, event length + // descriptor header, event length descEventLen := UnsafeReadInt32(buf, pos) pos += int(unsafe.Sizeof(descEventLen)) - //descriptor header, next position + // descriptor header, next position descNxtPos := UnsafeReadInt32(buf, pos) assert.Equal(t, descEventLen+int32(unsafe.Sizeof(MagicNumber)), descNxtPos) pos += int(unsafe.Sizeof(descNxtPos)) - //descriptor data 
fix, collection id + // descriptor data fix, collection id collID := UnsafeReadInt64(buf, pos) assert.Equal(t, collID, int64(50)) pos += int(unsafe.Sizeof(collID)) - //descriptor data fix, partition id + // descriptor data fix, partition id partID := UnsafeReadInt64(buf, pos) assert.Equal(t, partID, int64(1)) pos += int(unsafe.Sizeof(partID)) - //descriptor data fix, segment id + // descriptor data fix, segment id segID := UnsafeReadInt64(buf, pos) assert.Equal(t, segID, int64(1)) pos += int(unsafe.Sizeof(segID)) - //descriptor data fix, field id + // descriptor data fix, field id fieldID := UnsafeReadInt64(buf, pos) assert.Equal(t, fieldID, int64(-1)) pos += int(unsafe.Sizeof(fieldID)) - //descriptor data fix, start time stamp + // descriptor data fix, start time stamp startts := UnsafeReadInt64(buf, pos) assert.Equal(t, startts, int64(1000)) pos += int(unsafe.Sizeof(startts)) - //descriptor data fix, end time stamp + // descriptor data fix, end time stamp endts := UnsafeReadInt64(buf, pos) assert.Equal(t, endts, int64(2000)) pos += int(unsafe.Sizeof(endts)) - //descriptor data fix, payload type + // descriptor data fix, payload type colType := UnsafeReadInt32(buf, pos) assert.Equal(t, schemapb.DataType(colType), schemapb.DataType_Int64) pos += int(unsafe.Sizeof(colType)) - //descriptor data, post header lengths + // descriptor data, post header lengths for i := DescriptorEventType; i < EventTypeEnd; i++ { size := getEventFixPartSize(i) assert.Equal(t, uint8(size), buf[pos]) pos++ } - //descriptor data, extra length + // descriptor data, extra length extraLength := UnsafeReadInt32(buf, pos) assert.Equal(t, extraLength, w.baseBinlogWriter.descriptorEventData.ExtraLength) pos += int(unsafe.Sizeof(extraLength)) @@ -413,40 +413,40 @@ func TestDeleteBinlog(t *testing.T) { assert.True(t, ok) assert.Equal(t, fmt.Sprintf("%v", sizeTotal), fmt.Sprintf("%v", size)) - //start of e1 + // start of e1 assert.Equal(t, pos, int(descNxtPos)) - //insert e1 header, Timestamp + // 
insert e1 header, Timestamp e1ts := UnsafeReadInt64(buf, pos) diffts = curts - e1ts assert.LessOrEqual(t, diffts, maxdiff) pos += int(unsafe.Sizeof(e1ts)) - //insert e1 header, type code + // insert e1 header, type code e1tc := UnsafeReadInt8(buf, pos) assert.Equal(t, EventTypeCode(e1tc), DeleteEventType) pos += int(unsafe.Sizeof(e1tc)) - //insert e1 header, event length + // insert e1 header, event length e1EventLen := UnsafeReadInt32(buf, pos) pos += int(unsafe.Sizeof(e1EventLen)) - //insert e1 header, next position + // insert e1 header, next position e1NxtPos := UnsafeReadInt32(buf, pos) assert.Equal(t, descNxtPos+e1EventLen, e1NxtPos) pos += int(unsafe.Sizeof(descNxtPos)) - //insert e1 data, start time stamp + // insert e1 data, start time stamp e1st := UnsafeReadInt64(buf, pos) assert.Equal(t, e1st, int64(100)) pos += int(unsafe.Sizeof(e1st)) - //insert e1 data, end time stamp + // insert e1 data, end time stamp e1et := UnsafeReadInt64(buf, pos) assert.Equal(t, e1et, int64(200)) pos += int(unsafe.Sizeof(e1et)) - //insert e1, payload + // insert e1, payload e1Payload := buf[pos:e1NxtPos] e1r, err := NewPayloadReader(schemapb.DataType_Int64, e1Payload) assert.NoError(t, err) @@ -455,40 +455,40 @@ func TestDeleteBinlog(t *testing.T) { assert.Equal(t, e1a, []int64{1, 2, 3, 4, 5, 6}) e1r.Close() - //start of e2 + // start of e2 pos = int(e1NxtPos) - //insert e2 header, Timestamp + // insert e2 header, Timestamp e2ts := UnsafeReadInt64(buf, pos) diffts = curts - e2ts assert.LessOrEqual(t, diffts, maxdiff) pos += int(unsafe.Sizeof(e2ts)) - //insert e2 header, type code + // insert e2 header, type code e2tc := UnsafeReadInt8(buf, pos) assert.Equal(t, EventTypeCode(e2tc), DeleteEventType) pos += int(unsafe.Sizeof(e2tc)) - //insert e2 header, event length + // insert e2 header, event length e2EventLen := UnsafeReadInt32(buf, pos) pos += int(unsafe.Sizeof(e2EventLen)) - //insert e2 header, next position + // insert e2 header, next position e2NxtPos := 
UnsafeReadInt32(buf, pos) assert.Equal(t, e1NxtPos+e2EventLen, e2NxtPos) pos += int(unsafe.Sizeof(descNxtPos)) - //insert e2 data, start time stamp + // insert e2 data, start time stamp e2st := UnsafeReadInt64(buf, pos) assert.Equal(t, e2st, int64(300)) pos += int(unsafe.Sizeof(e2st)) - //insert e2 data, end time stamp + // insert e2 data, end time stamp e2et := UnsafeReadInt64(buf, pos) assert.Equal(t, e2et, int64(400)) pos += int(unsafe.Sizeof(e2et)) - //insert e2, payload + // insert e2, payload e2Payload := buf[pos:] e2r, err := NewPayloadReader(schemapb.DataType_Int64, e2Payload) assert.NoError(t, err) @@ -499,7 +499,7 @@ func TestDeleteBinlog(t *testing.T) { assert.Equal(t, int(e2NxtPos), len(buf)) - //read binlog + // read binlog r, err := NewBinlogReader(buf) assert.NoError(t, err) event1, err := r.NextEventReader() @@ -568,12 +568,12 @@ func TestDDLBinlog1(t *testing.T) { w.Close() - //magic number + // magic number magicNum := UnsafeReadInt32(buf, 0) assert.Equal(t, magicNum, MagicNumber) pos := int(unsafe.Sizeof(MagicNumber)) - //descriptor header, timestamp + // descriptor header, timestamp ts := UnsafeReadInt64(buf, pos) assert.Greater(t, ts, int64(0)) curts := time.Now().UnixNano() / int64(time.Millisecond) @@ -583,63 +583,63 @@ func TestDDLBinlog1(t *testing.T) { assert.LessOrEqual(t, diffts, maxdiff) pos += int(unsafe.Sizeof(ts)) - //descriptor header, type code + // descriptor header, type code tc := UnsafeReadInt8(buf, pos) assert.Equal(t, EventTypeCode(tc), DescriptorEventType) pos += int(unsafe.Sizeof(tc)) - //descriptor header, event length + // descriptor header, event length descEventLen := UnsafeReadInt32(buf, pos) pos += int(unsafe.Sizeof(descEventLen)) - //descriptor header, next position + // descriptor header, next position descNxtPos := UnsafeReadInt32(buf, pos) assert.Equal(t, descEventLen+int32(unsafe.Sizeof(MagicNumber)), descNxtPos) pos += int(unsafe.Sizeof(descNxtPos)) - //descriptor data fix, collection id + // descriptor data 
fix, collection id collID := UnsafeReadInt64(buf, pos) assert.Equal(t, collID, int64(50)) pos += int(unsafe.Sizeof(collID)) - //descriptor data fix, partition id + // descriptor data fix, partition id partID := UnsafeReadInt64(buf, pos) assert.Equal(t, partID, int64(-1)) pos += int(unsafe.Sizeof(partID)) - //descriptor data fix, segment id + // descriptor data fix, segment id segID := UnsafeReadInt64(buf, pos) assert.Equal(t, segID, int64(-1)) pos += int(unsafe.Sizeof(segID)) - //descriptor data fix, field id + // descriptor data fix, field id fieldID := UnsafeReadInt64(buf, pos) assert.Equal(t, fieldID, int64(-1)) pos += int(unsafe.Sizeof(fieldID)) - //descriptor data fix, start time stamp + // descriptor data fix, start time stamp startts := UnsafeReadInt64(buf, pos) assert.Equal(t, startts, int64(1000)) pos += int(unsafe.Sizeof(startts)) - //descriptor data fix, end time stamp + // descriptor data fix, end time stamp endts := UnsafeReadInt64(buf, pos) assert.Equal(t, endts, int64(2000)) pos += int(unsafe.Sizeof(endts)) - //descriptor data fix, payload type + // descriptor data fix, payload type colType := UnsafeReadInt32(buf, pos) assert.Equal(t, schemapb.DataType(colType), schemapb.DataType_Int64) pos += int(unsafe.Sizeof(colType)) - //descriptor data, post header lengths + // descriptor data, post header lengths for i := DescriptorEventType; i < EventTypeEnd; i++ { size := getEventFixPartSize(i) assert.Equal(t, uint8(size), buf[pos]) pos++ } - //descriptor data, extra length + // descriptor data, extra length extraLength := UnsafeReadInt32(buf, pos) assert.Equal(t, extraLength, w.baseBinlogWriter.descriptorEventData.ExtraLength) pos += int(unsafe.Sizeof(extraLength)) @@ -660,40 +660,40 @@ func TestDDLBinlog1(t *testing.T) { assert.True(t, ok) assert.Equal(t, fmt.Sprintf("%v", sizeTotal), fmt.Sprintf("%v", size)) - //start of e1 + // start of e1 assert.Equal(t, pos, int(descNxtPos)) - //insert e1 header, Timestamp + // insert e1 header, Timestamp e1ts := 
UnsafeReadInt64(buf, pos) diffts = curts - e1ts assert.LessOrEqual(t, diffts, maxdiff) pos += int(unsafe.Sizeof(e1ts)) - //insert e1 header, type code + // insert e1 header, type code e1tc := UnsafeReadInt8(buf, pos) assert.Equal(t, EventTypeCode(e1tc), CreateCollectionEventType) pos += int(unsafe.Sizeof(e1tc)) - //insert e1 header, event length + // insert e1 header, event length e1EventLen := UnsafeReadInt32(buf, pos) pos += int(unsafe.Sizeof(e1EventLen)) - //insert e1 header, next position + // insert e1 header, next position e1NxtPos := UnsafeReadInt32(buf, pos) assert.Equal(t, descNxtPos+e1EventLen, e1NxtPos) pos += int(unsafe.Sizeof(descNxtPos)) - //insert e1 data, start time stamp + // insert e1 data, start time stamp e1st := UnsafeReadInt64(buf, pos) assert.Equal(t, e1st, int64(100)) pos += int(unsafe.Sizeof(e1st)) - //insert e1 data, end time stamp + // insert e1 data, end time stamp e1et := UnsafeReadInt64(buf, pos) assert.Equal(t, e1et, int64(200)) pos += int(unsafe.Sizeof(e1et)) - //insert e1, payload + // insert e1, payload e1Payload := buf[pos:e1NxtPos] e1r, err := NewPayloadReader(schemapb.DataType_Int64, e1Payload) assert.NoError(t, err) @@ -702,40 +702,40 @@ func TestDDLBinlog1(t *testing.T) { assert.Equal(t, e1a, []int64{1, 2, 3, 4, 5, 6}) e1r.Close() - //start of e2 + // start of e2 pos = int(e1NxtPos) - //insert e2 header, Timestamp + // insert e2 header, Timestamp e2ts := UnsafeReadInt64(buf, pos) diffts = curts - e2ts assert.LessOrEqual(t, diffts, maxdiff) pos += int(unsafe.Sizeof(e2ts)) - //insert e2 header, type code + // insert e2 header, type code e2tc := UnsafeReadInt8(buf, pos) assert.Equal(t, EventTypeCode(e2tc), DropCollectionEventType) pos += int(unsafe.Sizeof(e2tc)) - //insert e2 header, event length + // insert e2 header, event length e2EventLen := UnsafeReadInt32(buf, pos) pos += int(unsafe.Sizeof(e2EventLen)) - //insert e2 header, next position + // insert e2 header, next position e2NxtPos := UnsafeReadInt32(buf, pos) 
assert.Equal(t, e1NxtPos+e2EventLen, e2NxtPos) pos += int(unsafe.Sizeof(descNxtPos)) - //insert e2 data, start time stamp + // insert e2 data, start time stamp e2st := UnsafeReadInt64(buf, pos) assert.Equal(t, e2st, int64(300)) pos += int(unsafe.Sizeof(e2st)) - //insert e2 data, end time stamp + // insert e2 data, end time stamp e2et := UnsafeReadInt64(buf, pos) assert.Equal(t, e2et, int64(400)) pos += int(unsafe.Sizeof(e2et)) - //insert e2, payload + // insert e2, payload e2Payload := buf[pos:] e2r, err := NewPayloadReader(schemapb.DataType_Int64, e2Payload) assert.NoError(t, err) @@ -746,7 +746,7 @@ func TestDDLBinlog1(t *testing.T) { assert.Equal(t, int(e2NxtPos), len(buf)) - //read binlog + // read binlog r, err := NewBinlogReader(buf) assert.NoError(t, err) event1, err := r.NextEventReader() @@ -814,12 +814,12 @@ func TestDDLBinlog2(t *testing.T) { assert.NoError(t, err) w.Close() - //magic number + // magic number magicNum := UnsafeReadInt32(buf, 0) assert.Equal(t, magicNum, MagicNumber) pos := int(unsafe.Sizeof(MagicNumber)) - //descriptor header, timestamp + // descriptor header, timestamp ts := UnsafeReadInt64(buf, pos) assert.Greater(t, ts, int64(0)) curts := time.Now().UnixNano() / int64(time.Millisecond) @@ -829,63 +829,63 @@ func TestDDLBinlog2(t *testing.T) { assert.LessOrEqual(t, diffts, maxdiff) pos += int(unsafe.Sizeof(ts)) - //descriptor header, type code + // descriptor header, type code tc := UnsafeReadInt8(buf, pos) assert.Equal(t, EventTypeCode(tc), DescriptorEventType) pos += int(unsafe.Sizeof(tc)) - //descriptor header, event length + // descriptor header, event length descEventLen := UnsafeReadInt32(buf, pos) pos += int(unsafe.Sizeof(descEventLen)) - //descriptor header, next position + // descriptor header, next position descNxtPos := UnsafeReadInt32(buf, pos) assert.Equal(t, descEventLen+int32(unsafe.Sizeof(MagicNumber)), descNxtPos) pos += int(unsafe.Sizeof(descNxtPos)) - //descriptor data fix, collection id + // descriptor data fix, 
collection id collID := UnsafeReadInt64(buf, pos) assert.Equal(t, collID, int64(50)) pos += int(unsafe.Sizeof(collID)) - //descriptor data fix, partition id + // descriptor data fix, partition id partID := UnsafeReadInt64(buf, pos) assert.Equal(t, partID, int64(-1)) pos += int(unsafe.Sizeof(partID)) - //descriptor data fix, segment id + // descriptor data fix, segment id segID := UnsafeReadInt64(buf, pos) assert.Equal(t, segID, int64(-1)) pos += int(unsafe.Sizeof(segID)) - //descriptor data fix, field id + // descriptor data fix, field id fieldID := UnsafeReadInt64(buf, pos) assert.Equal(t, fieldID, int64(-1)) pos += int(unsafe.Sizeof(fieldID)) - //descriptor data fix, start time stamp + // descriptor data fix, start time stamp startts := UnsafeReadInt64(buf, pos) assert.Equal(t, startts, int64(1000)) pos += int(unsafe.Sizeof(startts)) - //descriptor data fix, end time stamp + // descriptor data fix, end time stamp endts := UnsafeReadInt64(buf, pos) assert.Equal(t, endts, int64(2000)) pos += int(unsafe.Sizeof(endts)) - //descriptor data fix, payload type + // descriptor data fix, payload type colType := UnsafeReadInt32(buf, pos) assert.Equal(t, schemapb.DataType(colType), schemapb.DataType_Int64) pos += int(unsafe.Sizeof(colType)) - //descriptor data, post header lengths + // descriptor data, post header lengths for i := DescriptorEventType; i < EventTypeEnd; i++ { size := getEventFixPartSize(i) assert.Equal(t, uint8(size), buf[pos]) pos++ } - //descriptor data, extra length + // descriptor data, extra length extraLength := UnsafeReadInt32(buf, pos) assert.Equal(t, extraLength, w.baseBinlogWriter.descriptorEventData.ExtraLength) pos += int(unsafe.Sizeof(extraLength)) @@ -906,40 +906,40 @@ func TestDDLBinlog2(t *testing.T) { assert.True(t, ok) assert.Equal(t, fmt.Sprintf("%v", sizeTotal), fmt.Sprintf("%v", size)) - //start of e1 + // start of e1 assert.Equal(t, pos, int(descNxtPos)) - //insert e1 header, Timestamp + // insert e1 header, Timestamp e1ts := 
UnsafeReadInt64(buf, pos) diffts = curts - e1ts assert.LessOrEqual(t, diffts, maxdiff) pos += int(unsafe.Sizeof(e1ts)) - //insert e1 header, type code + // insert e1 header, type code e1tc := UnsafeReadInt8(buf, pos) assert.Equal(t, EventTypeCode(e1tc), CreatePartitionEventType) pos += int(unsafe.Sizeof(e1tc)) - //insert e1 header, event length + // insert e1 header, event length e1EventLen := UnsafeReadInt32(buf, pos) pos += int(unsafe.Sizeof(e1EventLen)) - //insert e1 header, next position + // insert e1 header, next position e1NxtPos := UnsafeReadInt32(buf, pos) assert.Equal(t, descNxtPos+e1EventLen, e1NxtPos) pos += int(unsafe.Sizeof(descNxtPos)) - //insert e1 data, start time stamp + // insert e1 data, start time stamp e1st := UnsafeReadInt64(buf, pos) assert.Equal(t, e1st, int64(100)) pos += int(unsafe.Sizeof(e1st)) - //insert e1 data, end time stamp + // insert e1 data, end time stamp e1et := UnsafeReadInt64(buf, pos) assert.Equal(t, e1et, int64(200)) pos += int(unsafe.Sizeof(e1et)) - //insert e1, payload + // insert e1, payload e1Payload := buf[pos:e1NxtPos] e1r, err := NewPayloadReader(schemapb.DataType_Int64, e1Payload) assert.NoError(t, err) @@ -948,40 +948,40 @@ func TestDDLBinlog2(t *testing.T) { assert.Equal(t, e1a, []int64{1, 2, 3, 4, 5, 6}) e1r.Close() - //start of e2 + // start of e2 pos = int(e1NxtPos) - //insert e2 header, Timestamp + // insert e2 header, Timestamp e2ts := UnsafeReadInt64(buf, pos) diffts = curts - e2ts assert.LessOrEqual(t, diffts, maxdiff) pos += int(unsafe.Sizeof(e2ts)) - //insert e2 header, type code + // insert e2 header, type code e2tc := UnsafeReadInt8(buf, pos) assert.Equal(t, EventTypeCode(e2tc), DropPartitionEventType) pos += int(unsafe.Sizeof(e2tc)) - //insert e2 header, event length + // insert e2 header, event length e2EventLen := UnsafeReadInt32(buf, pos) pos += int(unsafe.Sizeof(e2EventLen)) - //insert e2 header, next position + // insert e2 header, next position e2NxtPos := UnsafeReadInt32(buf, pos) 
assert.Equal(t, e1NxtPos+e2EventLen, e2NxtPos) pos += int(unsafe.Sizeof(descNxtPos)) - //insert e2 data, start time stamp + // insert e2 data, start time stamp e2st := UnsafeReadInt64(buf, pos) assert.Equal(t, e2st, int64(300)) pos += int(unsafe.Sizeof(e2st)) - //insert e2 data, end time stamp + // insert e2 data, end time stamp e2et := UnsafeReadInt64(buf, pos) assert.Equal(t, e2et, int64(400)) pos += int(unsafe.Sizeof(e2et)) - //insert e2, payload + // insert e2, payload e2Payload := buf[pos:] e2r, err := NewPayloadReader(schemapb.DataType_Int64, e2Payload) assert.NoError(t, err) @@ -992,7 +992,7 @@ func TestDDLBinlog2(t *testing.T) { assert.Equal(t, int(e2NxtPos), len(buf)) - //read binlog + // read binlog r, err := NewBinlogReader(buf) assert.NoError(t, err) event1, err := r.NextEventReader() @@ -1060,73 +1060,73 @@ func TestIndexFileBinlog(t *testing.T) { w.Close() - //magic number + // magic number magicNum := UnsafeReadInt32(buf, 0) assert.Equal(t, magicNum, MagicNumber) pos := int(unsafe.Sizeof(MagicNumber)) - //descriptor header, timestamp + // descriptor header, timestamp ts := UnsafeReadInt64(buf, pos) assert.Greater(t, ts, int64(0)) pos += int(unsafe.Sizeof(ts)) - //descriptor header, type code + // descriptor header, type code tc := UnsafeReadInt8(buf, pos) assert.Equal(t, EventTypeCode(tc), DescriptorEventType) pos += int(unsafe.Sizeof(tc)) - //descriptor header, event length + // descriptor header, event length descEventLen := UnsafeReadInt32(buf, pos) pos += int(unsafe.Sizeof(descEventLen)) - //descriptor header, next position + // descriptor header, next position descNxtPos := UnsafeReadInt32(buf, pos) assert.Equal(t, descEventLen+int32(unsafe.Sizeof(MagicNumber)), descNxtPos) pos += int(unsafe.Sizeof(descNxtPos)) - //descriptor data fix, collection id + // descriptor data fix, collection id collID := UnsafeReadInt64(buf, pos) assert.Equal(t, collID, collectionID) pos += int(unsafe.Sizeof(collID)) - //descriptor data fix, partition id + // 
descriptor data fix, partition id partID := UnsafeReadInt64(buf, pos) assert.Equal(t, partID, partitionID) pos += int(unsafe.Sizeof(partID)) - //descriptor data fix, segment id + // descriptor data fix, segment id segID := UnsafeReadInt64(buf, pos) assert.Equal(t, segID, segmentID) pos += int(unsafe.Sizeof(segID)) - //descriptor data fix, field id + // descriptor data fix, field id fID := UnsafeReadInt64(buf, pos) - assert.Equal(t, fieldID, fieldID) + assert.Equal(t, fieldID, fID) pos += int(unsafe.Sizeof(fID)) - //descriptor data fix, start time stamp + // descriptor data fix, start time stamp startts := UnsafeReadInt64(buf, pos) assert.Equal(t, startts, int64(timestamp)) pos += int(unsafe.Sizeof(startts)) - //descriptor data fix, end time stamp + // descriptor data fix, end time stamp endts := UnsafeReadInt64(buf, pos) assert.Equal(t, endts, int64(timestamp)) pos += int(unsafe.Sizeof(endts)) - //descriptor data fix, payload type + // descriptor data fix, payload type colType := UnsafeReadInt32(buf, pos) assert.Equal(t, schemapb.DataType(colType), schemapb.DataType_Int8) pos += int(unsafe.Sizeof(colType)) - //descriptor data, post header lengths + // descriptor data, post header lengths for i := DescriptorEventType; i < EventTypeEnd; i++ { size := getEventFixPartSize(i) assert.Equal(t, uint8(size), buf[pos]) pos++ } - //descriptor data, extra length + // descriptor data, extra length extraLength := UnsafeReadInt32(buf, pos) assert.Equal(t, extraLength, w.baseBinlogWriter.descriptorEventData.ExtraLength) pos += int(unsafe.Sizeof(extraLength)) @@ -1189,73 +1189,73 @@ func TestIndexFileBinlogV2(t *testing.T) { w.Close() - //magic number + // magic number magicNum := UnsafeReadInt32(buf, 0) assert.Equal(t, magicNum, MagicNumber) pos := int(unsafe.Sizeof(MagicNumber)) - //descriptor header, timestamp + // descriptor header, timestamp ts := UnsafeReadInt64(buf, pos) assert.Greater(t, ts, int64(0)) pos += int(unsafe.Sizeof(ts)) - //descriptor header, type code + // 
descriptor header, type code tc := UnsafeReadInt8(buf, pos) assert.Equal(t, EventTypeCode(tc), DescriptorEventType) pos += int(unsafe.Sizeof(tc)) - //descriptor header, event length + // descriptor header, event length descEventLen := UnsafeReadInt32(buf, pos) pos += int(unsafe.Sizeof(descEventLen)) - //descriptor header, next position + // descriptor header, next position descNxtPos := UnsafeReadInt32(buf, pos) assert.Equal(t, descEventLen+int32(unsafe.Sizeof(MagicNumber)), descNxtPos) pos += int(unsafe.Sizeof(descNxtPos)) - //descriptor data fix, collection id + // descriptor data fix, collection id collID := UnsafeReadInt64(buf, pos) assert.Equal(t, collID, collectionID) pos += int(unsafe.Sizeof(collID)) - //descriptor data fix, partition id + // descriptor data fix, partition id partID := UnsafeReadInt64(buf, pos) assert.Equal(t, partID, partitionID) pos += int(unsafe.Sizeof(partID)) - //descriptor data fix, segment id + // descriptor data fix, segment id segID := UnsafeReadInt64(buf, pos) assert.Equal(t, segID, segmentID) pos += int(unsafe.Sizeof(segID)) - //descriptor data fix, field id + // descriptor data fix, field id fID := UnsafeReadInt64(buf, pos) - assert.Equal(t, fieldID, fieldID) + assert.Equal(t, fieldID, fID) pos += int(unsafe.Sizeof(fID)) - //descriptor data fix, start time stamp + // descriptor data fix, start time stamp startts := UnsafeReadInt64(buf, pos) assert.Equal(t, startts, int64(timestamp)) pos += int(unsafe.Sizeof(startts)) - //descriptor data fix, end time stamp + // descriptor data fix, end time stamp endts := UnsafeReadInt64(buf, pos) assert.Equal(t, endts, int64(timestamp)) pos += int(unsafe.Sizeof(endts)) - //descriptor data fix, payload type + // descriptor data fix, payload type colType := UnsafeReadInt32(buf, pos) assert.Equal(t, schemapb.DataType(colType), schemapb.DataType_String) pos += int(unsafe.Sizeof(colType)) - //descriptor data, post header lengths + // descriptor data, post header lengths for i := DescriptorEventType; 
i < EventTypeEnd; i++ { size := getEventFixPartSize(i) assert.Equal(t, uint8(size), buf[pos]) pos++ } - //descriptor data, extra length + // descriptor data, extra length extraLength := UnsafeReadInt32(buf, pos) assert.Equal(t, extraLength, w.baseBinlogWriter.descriptorEventData.ExtraLength) pos += int(unsafe.Sizeof(extraLength)) @@ -1482,6 +1482,7 @@ func (e *testEvent) GetMemoryUsageInBytes() (int32, error) { } return 0, nil } + func (e *testEvent) GetPayloadLengthFromWriter() (int, error) { if e.getPayloadLengthError { return -1, fmt.Errorf("getPayloadLength error") @@ -1493,7 +1494,6 @@ func (e *testEvent) ReleasePayloadWriter() { } func (e *testEvent) SetOffset(offset int32) { - } var _ EventWriter = (*testEvent)(nil) diff --git a/internal/storage/binlog_util_test.go b/internal/storage/binlog_util_test.go index f1c4c9e2ed629..de3c1de5d3357 100644 --- a/internal/storage/binlog_util_test.go +++ b/internal/storage/binlog_util_test.go @@ -7,7 +7,6 @@ import ( ) func TestParseSegmentIDByBinlog(t *testing.T) { - type testCase struct { name string input string diff --git a/internal/storage/binlog_writer.go b/internal/storage/binlog_writer.go index dffed7cc7f0db..0a38c68854564 100644 --- a/internal/storage/binlog_writer.go +++ b/internal/storage/binlog_writer.go @@ -41,6 +41,7 @@ const ( // StatsBinlog BinlogType for stats data StatsBinlog ) + const ( // MagicNumber used in binlog MagicNumber int32 = 0xfffabc diff --git a/internal/storage/binlog_writer_test.go b/internal/storage/binlog_writer_test.go index 69c473357b305..8bc80f66586bc 100644 --- a/internal/storage/binlog_writer_test.go +++ b/internal/storage/binlog_writer_test.go @@ -20,9 +20,9 @@ import ( "fmt" "testing" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) func TestBinlogWriterReader(t *testing.T) { diff --git a/internal/storage/data_codec.go b/internal/storage/data_codec.go index 
9599b253a0b68..f6c558745b5db 100644 --- a/internal/storage/data_codec.go +++ b/internal/storage/data_codec.go @@ -102,200 +102,6 @@ func (b Blob) GetValue() []byte { return b.Value } -// FieldData defines field data interface -type FieldData interface { - GetMemorySize() int - RowNum() int - GetRow(i int) interface{} -} - -type BoolFieldData struct { - Data []bool -} -type Int8FieldData struct { - Data []int8 -} -type Int16FieldData struct { - Data []int16 -} -type Int32FieldData struct { - Data []int32 -} -type Int64FieldData struct { - Data []int64 -} -type FloatFieldData struct { - Data []float32 -} -type DoubleFieldData struct { - Data []float64 -} -type StringFieldData struct { - Data []string -} -type ArrayFieldData struct { - ElementType schemapb.DataType - Data []*schemapb.ScalarField -} -type JSONFieldData struct { - Data [][]byte -} -type BinaryVectorFieldData struct { - Data []byte - Dim int -} -type FloatVectorFieldData struct { - Data []float32 - Dim int -} - -// RowNum implements FieldData.RowNum -func (data *BoolFieldData) RowNum() int { return len(data.Data) } -func (data *Int8FieldData) RowNum() int { return len(data.Data) } -func (data *Int16FieldData) RowNum() int { return len(data.Data) } -func (data *Int32FieldData) RowNum() int { return len(data.Data) } -func (data *Int64FieldData) RowNum() int { return len(data.Data) } -func (data *FloatFieldData) RowNum() int { return len(data.Data) } -func (data *DoubleFieldData) RowNum() int { return len(data.Data) } -func (data *StringFieldData) RowNum() int { return len(data.Data) } -func (data *BinaryVectorFieldData) RowNum() int { return len(data.Data) * 8 / data.Dim } -func (data *FloatVectorFieldData) RowNum() int { return len(data.Data) / data.Dim } -func (data *ArrayFieldData) RowNum() int { return len(data.Data) } -func (data *JSONFieldData) RowNum() int { return len(data.Data) } - -// GetRow implements FieldData.GetRow -func (data *BoolFieldData) GetRow(i int) any { return data.Data[i] } -func 
(data *Int8FieldData) GetRow(i int) any { return data.Data[i] } -func (data *Int16FieldData) GetRow(i int) any { return data.Data[i] } -func (data *Int32FieldData) GetRow(i int) any { return data.Data[i] } -func (data *Int64FieldData) GetRow(i int) any { return data.Data[i] } -func (data *FloatFieldData) GetRow(i int) any { return data.Data[i] } -func (data *DoubleFieldData) GetRow(i int) any { return data.Data[i] } -func (data *StringFieldData) GetRow(i int) any { return data.Data[i] } -func (data *ArrayFieldData) GetRow(i int) any { return data.Data[i] } -func (data *JSONFieldData) GetRow(i int) any { return data.Data[i] } -func (data *BinaryVectorFieldData) GetRow(i int) any { - return data.Data[i*data.Dim/8 : (i+1)*data.Dim/8] -} -func (data *FloatVectorFieldData) GetRow(i int) any { - return data.Data[i*data.Dim : (i+1)*data.Dim] -} - -// why not binary.Size(data) directly? binary.Size(data) return -1 -// binary.Size returns how many bytes Write would generate to encode the value v, which -// must be a fixed-size value or a slice of fixed-size values, or a pointer to such data. -// If v is neither of these, binary.Size returns -1. 
- -// GetMemorySize implements FieldData.GetMemorySize -func (data *BoolFieldData) GetMemorySize() int { - return binary.Size(data.Data) -} - -// GetMemorySize implements FieldData.GetMemorySize -func (data *Int8FieldData) GetMemorySize() int { - return binary.Size(data.Data) -} - -// GetMemorySize implements FieldData.GetMemorySize -func (data *Int16FieldData) GetMemorySize() int { - return binary.Size(data.Data) -} - -// GetMemorySize implements FieldData.GetMemorySize -func (data *Int32FieldData) GetMemorySize() int { - return binary.Size(data.Data) -} - -// GetMemorySize implements FieldData.GetMemorySize -func (data *Int64FieldData) GetMemorySize() int { - return binary.Size(data.Data) -} - -func (data *FloatFieldData) GetMemorySize() int { - return binary.Size(data.Data) -} - -func (data *DoubleFieldData) GetMemorySize() int { - return binary.Size(data.Data) -} - -func (data *StringFieldData) GetMemorySize() int { - var size int - for _, val := range data.Data { - size += len(val) + 16 - } - return size -} - -func (data *ArrayFieldData) GetMemorySize() int { - var size int - for _, val := range data.Data { - switch data.ElementType { - case schemapb.DataType_Bool: - size += binary.Size(val.GetBoolData().GetData()) - case schemapb.DataType_Int8: - size += binary.Size(val.GetIntData().GetData()) / 4 - case schemapb.DataType_Int16: - size += binary.Size(val.GetIntData().GetData()) / 2 - case schemapb.DataType_Int32: - size += binary.Size(val.GetIntData().GetData()) - case schemapb.DataType_Float: - size += binary.Size(val.GetFloatData().GetData()) - case schemapb.DataType_Double: - size += binary.Size(val.GetDoubleData().GetData()) - case schemapb.DataType_String, schemapb.DataType_VarChar: - size += (&StringFieldData{Data: val.GetStringData().GetData()}).GetMemorySize() - } - } - return size -} - -func (data *JSONFieldData) GetMemorySize() int { - var size int - for _, val := range data.Data { - size += len(val) + 16 - } - return size -} - -func (data 
*BinaryVectorFieldData) GetMemorySize() int { - return binary.Size(data.Data) + 4 -} - -func (data *FloatVectorFieldData) GetMemorySize() int { - return binary.Size(data.Data) + 4 -} - -// system field id: -// 0: unique row id -// 1: timestamp -// 100: first user field id -// 101: second user field id -// 102: ... - -// TODO: fill it -// info for each blob -type BlobInfo struct { - Length int -} - -// InsertData example row_schema: {float_field, int_field, float_vector_field, string_field} -// Data {<0, row_id>, <1, timestamp>, <100, float_field>, <101, int_field>, <102, float_vector_field>, <103, string_field>} -type InsertData struct { - // Todo, data should be zero copy by passing data directly to event reader or change Data to map[FieldID]FieldDataArray - Data map[FieldID]FieldData // field id to field data - Infos []BlobInfo -} - -func (iData *InsertData) IsEmpty() bool { - if iData == nil { - return true - } - - timeFieldData, ok := iData.Data[common.TimeStampField] - return (!ok) || (timeFieldData.RowNum() <= 0) -} - // InsertCodec serializes and deserializes the insert data // Blob key example: // ${tenant}/insert_log/${collection_id}/${partition_id}/${segment_id}/${field_id}/${log_idx} @@ -319,7 +125,7 @@ func (insertCodec *InsertCodec) SerializePkStats(stats *PrimaryKeyStats, rowNum return nil, fmt.Errorf("sericalize empty pk stats") } - //Serialize by pk stats + // Serialize by pk stats blobKey := fmt.Sprintf("%d", stats.FieldID) statsWriter := &StatsWriter{} err := statsWriter.Generate(stats) @@ -438,6 +244,8 @@ func (insertCodec *InsertCodec) Serialize(partitionID UniqueID, segmentID Unique eventWriter, err = writer.NextInsertEventWriter(singleData.(*FloatVectorFieldData).Dim) case schemapb.DataType_BinaryVector: eventWriter, err = writer.NextInsertEventWriter(singleData.(*BinaryVectorFieldData).Dim) + case schemapb.DataType_Float16Vector: + eventWriter, err = writer.NextInsertEventWriter(singleData.(*Float16VectorFieldData).Dim) default: return nil, 
fmt.Errorf("undefined data type %d", field.DataType) } @@ -553,6 +361,14 @@ func (insertCodec *InsertCodec) Serialize(partitionID UniqueID, segmentID Unique return nil, err } writer.AddExtra(originalSizeKey, fmt.Sprintf("%v", singleData.(*FloatVectorFieldData).GetMemorySize())) + case schemapb.DataType_Float16Vector: + err = eventWriter.AddFloat16VectorToPayload(singleData.(*Float16VectorFieldData).Data, singleData.(*Float16VectorFieldData).Dim) + if err != nil { + eventWriter.Close() + writer.Close() + return nil, err + } + writer.AddExtra(originalSizeKey, fmt.Sprintf("%v", singleData.(*Float16VectorFieldData).GetMemorySize())) default: return nil, fmt.Errorf("undefined data type %d", field.DataType) } @@ -857,6 +673,33 @@ func (insertCodec *InsertCodec) DeserializeInto(fieldBinlogs []*Blob, rowNum int binaryVectorFieldData.Dim = dim insertData.Data[fieldID] = binaryVectorFieldData + case schemapb.DataType_Float16Vector: + var singleData []byte + singleData, dim, err = eventReader.GetFloat16VectorFromPayload() + if err != nil { + eventReader.Close() + binlogReader.Close() + return InvalidUniqueID, InvalidUniqueID, InvalidUniqueID, err + } + + if insertData.Data[fieldID] == nil { + insertData.Data[fieldID] = &Float16VectorFieldData{ + Data: make([]byte, 0, rowNum*dim), + } + } + float16VectorFieldData := insertData.Data[fieldID].(*Float16VectorFieldData) + + float16VectorFieldData.Data = append(float16VectorFieldData.Data, singleData...) 
+ length, err := eventReader.GetPayloadLengthFromReader() + if err != nil { + eventReader.Close() + binlogReader.Close() + return InvalidUniqueID, InvalidUniqueID, InvalidUniqueID, err + } + totalLength += length + float16VectorFieldData.Dim = dim + insertData.Data[fieldID] = float16VectorFieldData + case schemapb.DataType_FloatVector: var singleData []float32 singleData, dim, err = eventReader.GetFloatVectorFromPayload() @@ -1006,8 +849,7 @@ func (data *DeleteData) Append(pk PrimaryKey, ts Timestamp) { } // DeleteCodec serializes and deserializes the delete data -type DeleteCodec struct { -} +type DeleteCodec struct{} // NewDeleteCodec returns a DeleteCodec func NewDeleteCodec() *DeleteCodec { @@ -1138,7 +980,6 @@ func (deleteCodec *DeleteCodec) Deserialize(blobs []*Blob) (partitionID UniqueID } eventReader.Close() binlogReader.Close() - } result.RowCount = int64(len(result.Pks)) @@ -1328,7 +1169,6 @@ func (dataDefinitionCodec *DataDefinitionCodec) Deserialize(blobs []*Blob) (ts [ eventReader.Close() } binlogReader.Close() - } return resultTs, requestsStrings, nil diff --git a/internal/storage/data_codec_test.go b/internal/storage/data_codec_test.go index 395a001d15c03..d05068255be61 100644 --- a/internal/storage/data_codec_test.go +++ b/internal/storage/data_codec_test.go @@ -21,36 +21,39 @@ import ( "fmt" "testing" + "github.com/stretchr/testify/assert" "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/etcdpb" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" - "github.com/stretchr/testify/assert" ) const ( - CollectionID = 1 - PartitionID = 1 - SegmentID = 1 - RowIDField = 0 - TimestampField = 1 - BoolField = 100 - Int8Field = 101 - Int16Field = 102 - Int32Field = 103 - Int64Field = 104 - FloatField = 105 - DoubleField = 106 - StringField = 107 - BinaryVectorField = 108 - FloatVectorField = 109 - ArrayField = 
110 - JSONField = 111 + CollectionID = 1 + PartitionID = 1 + SegmentID = 1 + RowIDField = 0 + TimestampField = 1 + BoolField = 100 + Int8Field = 101 + Int16Field = 102 + Int32Field = 103 + Int64Field = 104 + FloatField = 105 + DoubleField = 106 + StringField = 107 + BinaryVectorField = 108 + FloatVectorField = 109 + ArrayField = 110 + JSONField = 111 + Float16VectorField = 112 ) -func TestInsertCodec(t *testing.T) { - schema := &etcdpb.CollectionMeta{ +func genTestCollectionMeta() *etcdpb.CollectionMeta { + return &etcdpb.CollectionMeta{ ID: CollectionID, CreateTime: 1, SegmentIDs: []int64{SegmentID}, @@ -61,46 +64,40 @@ func TestInsertCodec(t *testing.T) { AutoID: true, Fields: []*schemapb.FieldSchema{ { - FieldID: RowIDField, - Name: "row_id", - IsPrimaryKey: false, - Description: "row_id", - DataType: schemapb.DataType_Int64, + FieldID: RowIDField, + Name: "row_id", + Description: "row_id", + DataType: schemapb.DataType_Int64, }, { - FieldID: TimestampField, - Name: "Timestamp", - IsPrimaryKey: false, - Description: "Timestamp", - DataType: schemapb.DataType_Int64, + FieldID: TimestampField, + Name: "Timestamp", + Description: "Timestamp", + DataType: schemapb.DataType_Int64, }, { - FieldID: BoolField, - Name: "field_bool", - IsPrimaryKey: false, - Description: "bool", - DataType: schemapb.DataType_Bool, + FieldID: BoolField, + Name: "field_bool", + Description: "bool", + DataType: schemapb.DataType_Bool, }, { - FieldID: Int8Field, - Name: "field_int8", - IsPrimaryKey: false, - Description: "int8", - DataType: schemapb.DataType_Int8, + FieldID: Int8Field, + Name: "field_int8", + Description: "int8", + DataType: schemapb.DataType_Int8, }, { - FieldID: Int16Field, - Name: "field_int16", - IsPrimaryKey: false, - Description: "int16", - DataType: schemapb.DataType_Int16, + FieldID: Int16Field, + Name: "field_int16", + Description: "int16", + DataType: schemapb.DataType_Int16, }, { - FieldID: Int32Field, - Name: "field_int32", - IsPrimaryKey: false, - Description: 
"int32", - DataType: schemapb.DataType_Int32, + FieldID: Int32Field, + Name: "field_int32", + Description: "int32", + DataType: schemapb.DataType_Int32, }, { FieldID: Int64Field, @@ -110,25 +107,22 @@ func TestInsertCodec(t *testing.T) { DataType: schemapb.DataType_Int64, }, { - FieldID: FloatField, - Name: "field_float", - IsPrimaryKey: false, - Description: "float", - DataType: schemapb.DataType_Float, + FieldID: FloatField, + Name: "field_float", + Description: "float", + DataType: schemapb.DataType_Float, }, { - FieldID: DoubleField, - Name: "field_double", - IsPrimaryKey: false, - Description: "double", - DataType: schemapb.DataType_Double, + FieldID: DoubleField, + Name: "field_double", + Description: "double", + DataType: schemapb.DataType_Double, }, { - FieldID: StringField, - Name: "field_string", - IsPrimaryKey: false, - Description: "string", - DataType: schemapb.DataType_String, + FieldID: StringField, + Name: "field_string", + Description: "string", + DataType: schemapb.DataType_String, }, { FieldID: ArrayField, @@ -144,22 +138,48 @@ func TestInsertCodec(t *testing.T) { DataType: schemapb.DataType_JSON, }, { - FieldID: BinaryVectorField, - Name: "field_binary_vector", - IsPrimaryKey: false, - Description: "binary_vector", - DataType: schemapb.DataType_BinaryVector, + FieldID: BinaryVectorField, + Name: "field_binary_vector", + Description: "binary_vector", + DataType: schemapb.DataType_BinaryVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "8", + }, + }, }, { - FieldID: FloatVectorField, - Name: "field_float_vector", - IsPrimaryKey: false, - Description: "float_vector", - DataType: schemapb.DataType_FloatVector, + FieldID: FloatVectorField, + Name: "field_float_vector", + Description: "float_vector", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "4", + }, + }, + }, + { + FieldID: Float16VectorField, + Name: "field_float16_vector", + Description: 
"float16_vector", + DataType: schemapb.DataType_Float16Vector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "4", + }, + }, }, }, }, } +} + +func TestInsertCodec(t *testing.T) { + schema := genTestCollectionMeta() insertCodec := NewInsertCodecWithSchema(schema) insertData1 := &InsertData{ Data: map[int64]FieldData{ @@ -222,6 +242,11 @@ func TestInsertCodec(t *testing.T) { []byte(`{"key":"world"}`), }, }, + Float16VectorField: &Float16VectorFieldData{ + // length = 2 * Dim * numRows(2) = 16 + Data: []byte{0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255}, + Dim: 4, + }, }, } @@ -265,6 +290,11 @@ func TestInsertCodec(t *testing.T) { Data: []float32{0, 1, 2, 3, 0, 1, 2, 3}, Dim: 4, }, + Float16VectorField: &Float16VectorFieldData{ + // length = 2 * Dim * numRows(2) = 16 + Data: []byte{0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255}, + Dim: 4, + }, ArrayField: &ArrayFieldData{ ElementType: schemapb.DataType_Int32, Data: []*schemapb.ScalarField{ @@ -291,20 +321,21 @@ func TestInsertCodec(t *testing.T) { insertDataEmpty := &InsertData{ Data: map[int64]FieldData{ - RowIDField: &Int64FieldData{[]int64{}}, - TimestampField: &Int64FieldData{[]int64{}}, - BoolField: &BoolFieldData{[]bool{}}, - Int8Field: &Int8FieldData{[]int8{}}, - Int16Field: &Int16FieldData{[]int16{}}, - Int32Field: &Int32FieldData{[]int32{}}, - Int64Field: &Int64FieldData{[]int64{}}, - FloatField: &FloatFieldData{[]float32{}}, - DoubleField: &DoubleFieldData{[]float64{}}, - StringField: &StringFieldData{[]string{}}, - BinaryVectorField: &BinaryVectorFieldData{[]byte{}, 8}, - FloatVectorField: &FloatVectorFieldData{[]float32{}, 4}, - ArrayField: &ArrayFieldData{schemapb.DataType_Int32, []*schemapb.ScalarField{}}, - JSONField: &JSONFieldData{[][]byte{}}, + RowIDField: &Int64FieldData{[]int64{}}, + TimestampField: &Int64FieldData{[]int64{}}, + BoolField: &BoolFieldData{[]bool{}}, + Int8Field: &Int8FieldData{[]int8{}}, + Int16Field: 
&Int16FieldData{[]int16{}}, + Int32Field: &Int32FieldData{[]int32{}}, + Int64Field: &Int64FieldData{[]int64{}}, + FloatField: &FloatFieldData{[]float32{}}, + DoubleField: &DoubleFieldData{[]float64{}}, + StringField: &StringFieldData{[]string{}}, + BinaryVectorField: &BinaryVectorFieldData{[]byte{}, 8}, + FloatVectorField: &FloatVectorFieldData{[]float32{}, 4}, + Float16VectorField: &Float16VectorFieldData{[]byte{}, 4}, + ArrayField: &ArrayFieldData{schemapb.DataType_Int32, []*schemapb.ScalarField{}}, + JSONField: &JSONFieldData{[][]byte{}}, }, } b, err := insertCodec.Serialize(PartitionID, SegmentID, insertDataEmpty) @@ -345,6 +376,12 @@ func TestInsertCodec(t *testing.T) { assert.Equal(t, []string{"1", "2", "3", "4"}, resultData.Data[StringField].(*StringFieldData).Data) assert.Equal(t, []byte{0, 255, 0, 255}, resultData.Data[BinaryVectorField].(*BinaryVectorFieldData).Data) assert.Equal(t, []float32{0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 6, 7, 4, 5, 6, 7}, resultData.Data[FloatVectorField].(*FloatVectorFieldData).Data) + assert.Equal(t, []byte{ + 0, 255, 0, 255, 0, 255, 0, 255, + 0, 255, 0, 255, 0, 255, 0, 255, + 0, 255, 0, 255, 0, 255, 0, 255, + 0, 255, 0, 255, 0, 255, 0, 255, + }, resultData.Data[Float16VectorField].(*Float16VectorFieldData).Data) int32ArrayList := [][]int32{{1, 2, 3}, {4, 5, 6}, {3, 2, 1}, {6, 5, 4}} resultArrayList := [][]int32{} @@ -665,5 +702,4 @@ func TestMemorySize(t *testing.T) { assert.Equal(t, insertDataEmpty.Data[StringField].GetMemorySize(), 0) assert.Equal(t, insertDataEmpty.Data[BinaryVectorField].GetMemorySize(), 4) assert.Equal(t, insertDataEmpty.Data[FloatVectorField].GetMemorySize(), 4) - } diff --git a/internal/storage/data_sorter.go b/internal/storage/data_sorter.go index 9b2c5115f6743..21e3e5e7ffda8 100644 --- a/internal/storage/data_sorter.go +++ b/internal/storage/data_sorter.go @@ -94,6 +94,13 @@ func (ds *DataSorter) Swap(i, j int) { for idx := 0; idx < dim; idx++ { data[i*dim+idx], data[j*dim+idx] = data[j*dim+idx], 
data[i*dim+idx] } + case schemapb.DataType_Float16Vector: + data := singleData.(*Float16VectorFieldData).Data + dim := singleData.(*Float16VectorFieldData).Dim + steps := dim * 2 + for idx := 0; idx < steps; idx++ { + data[i*steps+idx], data[j*steps+idx] = data[j*steps+idx], data[i*steps+idx] + } case schemapb.DataType_Array: data := singleData.(*ArrayFieldData).Data data[i], data[j] = data[j], data[i] diff --git a/internal/storage/data_sorter_test.go b/internal/storage/data_sorter_test.go index f1480444f11b7..8a9ed44b85081 100644 --- a/internal/storage/data_sorter_test.go +++ b/internal/storage/data_sorter_test.go @@ -20,9 +20,10 @@ import ( "sort" "testing" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/etcdpb" - "github.com/stretchr/testify/assert" ) func TestDataSorter(t *testing.T) { @@ -120,6 +121,13 @@ func TestDataSorter(t *testing.T) { Description: "description_11", DataType: schemapb.DataType_FloatVector, }, + { + FieldID: 110, + Name: "field_float16_vector", + IsPrimaryKey: false, + Description: "description_12", + DataType: schemapb.DataType_Float16Vector, + }, }, }, } @@ -165,6 +173,10 @@ func TestDataSorter(t *testing.T) { Data: []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, Dim: 8, }, + 110: &Float16VectorFieldData{ + Data: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, + Dim: 4, + }, }, } @@ -226,6 +238,7 @@ func TestDataSorter(t *testing.T) { assert.Equal(t, []string{"5", "3", "4"}, dataSorter.InsertData.Data[107].(*StringFieldData).Data) assert.Equal(t, []byte{128, 0, 255}, dataSorter.InsertData.Data[108].(*BinaryVectorFieldData).Data) assert.Equal(t, []float32{16, 17, 18, 19, 20, 21, 22, 23, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, dataSorter.InsertData.Data[109].(*FloatVectorFieldData).Data) + assert.Equal(t, []byte{16, 17, 18, 19, 20, 21, 22, 
23, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, dataSorter.InsertData.Data[110].(*Float16VectorFieldData).Data) } func TestDataSorter_Len(t *testing.T) { diff --git a/internal/storage/event_data.go b/internal/storage/event_data.go index a4a28d0b8a491..2b0c9baa6f621 100644 --- a/internal/storage/event_data.go +++ b/internal/storage/event_data.go @@ -65,7 +65,6 @@ func (data *descriptorEventData) GetEventDataFixPartSize() int32 { // GetMemoryUsageInBytes returns the memory size of DescriptorEventDataFixPart. func (data *descriptorEventData) GetMemoryUsageInBytes() int32 { return data.GetEventDataFixPartSize() + int32(binary.Size(data.PostHeaderLengths)) + int32(binary.Size(data.ExtraLength)) + data.ExtraLength - } // AddExtra add extra params to description event. @@ -368,36 +367,42 @@ func newInsertEventData() *insertEventData { EndTimestamp: 0, } } + func newDeleteEventData() *deleteEventData { return &deleteEventData{ StartTimestamp: 0, EndTimestamp: 0, } } + func newCreateCollectionEventData() *createCollectionEventData { return &createCollectionEventData{ StartTimestamp: 0, EndTimestamp: 0, } } + func newDropCollectionEventData() *dropCollectionEventData { return &dropCollectionEventData{ StartTimestamp: 0, EndTimestamp: 0, } } + func newCreatePartitionEventData() *createPartitionEventData { return &createPartitionEventData{ StartTimestamp: 0, EndTimestamp: 0, } } + func newDropPartitionEventData() *dropPartitionEventData { return &dropPartitionEventData{ StartTimestamp: 0, EndTimestamp: 0, } } + func newIndexFileEventData() *indexFileEventData { return &indexFileEventData{ StartTimestamp: 0, diff --git a/internal/storage/event_test.go b/internal/storage/event_test.go index 848188a6e6211..e432e3a829816 100644 --- a/internal/storage/event_test.go +++ b/internal/storage/event_test.go @@ -24,12 +24,13 @@ import ( "time" "unsafe" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" 
"github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/milvus-io/milvus/pkg/util/typeutil" - "github.com/stretchr/testify/assert" ) /* #nosec G103 */ @@ -832,7 +833,6 @@ func TestDropPartitionEvent(t *testing.T) { r.Close() }) - } /* #nosec G103 */ @@ -1081,7 +1081,6 @@ func TestEventReaderError(t *testing.T) { r, err = newEventReader(schemapb.DataType_Int64, buf) assert.Nil(t, r) assert.Error(t, err) - } func TestEventClose(t *testing.T) { diff --git a/internal/storage/event_writer_test.go b/internal/storage/event_writer_test.go index 2c1d2a3d7242f..a6b6456159432 100644 --- a/internal/storage/event_writer_test.go +++ b/internal/storage/event_writer_test.go @@ -21,9 +21,10 @@ import ( "encoding/binary" "testing" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/common" - "github.com/stretchr/testify/assert" ) func TestEventTypeCode_String(t *testing.T) { diff --git a/internal/storage/factory.go b/internal/storage/factory.go index 293e32fe89646..24a71cbf2985d 100644 --- a/internal/storage/factory.go +++ b/internal/storage/factory.go @@ -17,7 +17,7 @@ func NewChunkManagerFactoryWithParam(params *paramtable.ComponentParam) *ChunkMa if params.CommonCfg.StorageType.GetValue() == "local" { return NewChunkManagerFactory("local", RootPath(params.LocalStorageCfg.Path.GetValue())) } - return NewChunkManagerFactory("minio", + return NewChunkManagerFactory(params.CommonCfg.StorageType.GetValue(), RootPath(params.MinioCfg.RootPath.GetValue()), Address(params.MinioCfg.Address.GetValue()), AccessKeyID(params.MinioCfg.AccessKeyID.GetValue()), @@ -29,6 +29,7 @@ func NewChunkManagerFactoryWithParam(params *paramtable.ComponentParam) *ChunkMa IAMEndpoint(params.MinioCfg.IAMEndpoint.GetValue()), UseVirtualHost(params.MinioCfg.UseVirtualHost.GetAsBool()), Region(params.MinioCfg.Region.GetValue()), + 
RequestTimeout(params.MinioCfg.RequestTimeoutMs.GetAsInt64()), CreateBucket(true)) } @@ -49,6 +50,8 @@ func (f *ChunkManagerFactory) newChunkManager(ctx context.Context, engine string return NewLocalChunkManager(RootPath(f.config.rootPath)), nil case "minio": return newMinioChunkManagerWithConfig(ctx, f.config) + case "remote": + return NewRemoteChunkManager(ctx, f.config) default: return nil, errors.New("no chunk manager implemented with engine: " + engine) } diff --git a/internal/storage/gcp/gcp_test.go b/internal/storage/gcp/gcp_test.go index 9695316a3d82d..990166f0be037 100644 --- a/internal/storage/gcp/gcp_test.go +++ b/internal/storage/gcp/gcp_test.go @@ -5,7 +5,6 @@ import ( "testing" "github.com/cockroachdb/errors" - "github.com/minio/minio-go/v7" "github.com/minio/minio-go/v7/pkg/credentials" "github.com/stretchr/testify/assert" @@ -100,5 +99,4 @@ func TestGCPWrappedHTTPTransport_RoundTrip(t *testing.T) { _, err = ts.RoundTrip(req) assert.Error(t, err) }) - } diff --git a/internal/storage/index_data_codec.go b/internal/storage/index_data_codec.go index 55343eb2b9a7c..0e928c8223736 100644 --- a/internal/storage/index_data_codec.go +++ b/internal/storage/index_data_codec.go @@ -23,7 +23,6 @@ import ( "time" "github.com/cockroachdb/errors" - "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" @@ -31,8 +30,7 @@ import ( "github.com/milvus-io/milvus/pkg/util/typeutil" ) -type IndexFileBinlogCodec struct { -} +type IndexFileBinlogCodec struct{} // NewIndexFileBinlogCodec is constructor for IndexFileBinlogCodec func NewIndexFileBinlogCodec() *IndexFileBinlogCodec { @@ -85,7 +83,7 @@ func (codec *IndexFileBinlogCodec) serializeImpl( return &Blob{ Key: key, - //Key: strconv.Itoa(len(datas)), + // Key: strconv.Itoa(len(datas)), Value: buffer, }, nil } @@ -100,7 +98,8 @@ func (codec *IndexFileBinlogCodec) SerializeIndexParams( fieldID UniqueID, indexParams map[string]string, indexName string, - indexID UniqueID) (*Blob, error) { + indexID 
UniqueID, +) (*Blob, error) { ts := Timestamp(time.Now().UnixNano()) // save index params. @@ -126,7 +125,6 @@ func (codec *IndexFileBinlogCodec) Serialize( indexID UniqueID, datas []*Blob, ) ([]*Blob, error) { - var err error var blobs []*Blob @@ -268,7 +266,6 @@ func (codec *IndexFileBinlogCodec) DeserializeImpl(blobs []*Blob) ( eventReader.Close() } binlogReader.Close() - } return indexBuildID, version, collectionID, partitionID, segmentID, fieldID, indexParams, indexName, indexID, datas, nil @@ -286,8 +283,7 @@ func (codec *IndexFileBinlogCodec) Deserialize(blobs []*Blob) ( } // IndexCodec can serialize and deserialize index -type IndexCodec struct { -} +type IndexCodec struct{} // NewIndexCodec creates IndexCodec func NewIndexCodec() *IndexCodec { diff --git a/internal/storage/insert_data.go b/internal/storage/insert_data.go new file mode 100644 index 0000000000000..7b568b9254891 --- /dev/null +++ b/internal/storage/insert_data.go @@ -0,0 +1,463 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package storage + +import ( + "encoding/binary" + "fmt" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/merr" +) + +// TODO: fill it +// info for each blob +type BlobInfo struct { + Length int +} + +// InsertData example row_schema: {float_field, int_field, float_vector_field, string_field} +// Data {<0, row_id>, <1, timestamp>, <100, float_field>, <101, int_field>, <102, float_vector_field>, <103, string_field>} +// +// system filed id: +// 0: unique row id +// 1: timestamp +// 100: first user field id +// 101: second user field id +// 102: ... +type InsertData struct { + // TODO, data should be zero copy by passing data directly to event reader or change Data to map[FieldID]FieldDataArray + Data map[FieldID]FieldData // field id to field data + Infos []BlobInfo +} + +func NewInsertData(schema *schemapb.CollectionSchema) (*InsertData, error) { + if schema == nil { + return nil, fmt.Errorf("Nil input schema") + } + + idata := &InsertData{ + Data: make(map[FieldID]FieldData), + } + + for _, fSchema := range schema.Fields { + fieldData, err := NewFieldData(fSchema.DataType, fSchema) + if err != nil { + return nil, err + } + idata.Data[fSchema.FieldID] = fieldData + } + return idata, nil +} + +func (iData *InsertData) IsEmpty() bool { + if iData == nil { + return true + } + + timeFieldData, ok := iData.Data[common.TimeStampField] + return (!ok) || (timeFieldData.RowNum() <= 0) +} + +func (i *InsertData) GetRowNum() int { + if i.Data == nil || len(i.Data) == 0 { + return 0 + } + + data, ok := i.Data[common.RowIDField] + if !ok { + return 0 + } + + return data.RowNum() +} + +func (i *InsertData) GetMemorySize() int { + var size int + if i.Data == nil || len(i.Data) == 0 { + return size + } + + for _, data := range i.Data { + size += data.GetMemorySize() + } + + return size +} + +func (i *InsertData) Append(row map[FieldID]interface{}) error { + for fID, v := range row { + 
field, ok := i.Data[fID] + if !ok { + return fmt.Errorf("Missing field when appending row, got %d", fID) + } + + if err := field.AppendRow(v); err != nil { + return err + } + } + + return nil +} + +// FieldData defines field data interface +type FieldData interface { + GetMemorySize() int + RowNum() int + GetRow(i int) any + AppendRow(row interface{}) error +} + +func NewFieldData(dataType schemapb.DataType, fieldSchema *schemapb.FieldSchema) (FieldData, error) { + typeParams := fieldSchema.GetTypeParams() + switch dataType { + case schemapb.DataType_Float16Vector: + dim, err := GetDimFromParams(typeParams) + if err != nil { + return nil, err + } + return &Float16VectorFieldData{ + Data: make([]byte, 0), + Dim: dim, + }, nil + case schemapb.DataType_FloatVector: + dim, err := GetDimFromParams(typeParams) + if err != nil { + return nil, err + } + return &FloatVectorFieldData{ + Data: make([]float32, 0), + Dim: dim, + }, nil + case schemapb.DataType_BinaryVector: + dim, err := GetDimFromParams(typeParams) + if err != nil { + return nil, err + } + return &BinaryVectorFieldData{ + Data: make([]byte, 0), + Dim: dim, + }, nil + + case schemapb.DataType_Bool: + return &BoolFieldData{ + Data: make([]bool, 0), + }, nil + + case schemapb.DataType_Int8: + return &Int8FieldData{ + Data: make([]int8, 0), + }, nil + + case schemapb.DataType_Int16: + return &Int16FieldData{ + Data: make([]int16, 0), + }, nil + + case schemapb.DataType_Int32: + return &Int32FieldData{ + Data: make([]int32, 0), + }, nil + + case schemapb.DataType_Int64: + return &Int64FieldData{ + Data: make([]int64, 0), + }, nil + case schemapb.DataType_Float: + return &FloatFieldData{ + Data: make([]float32, 0), + }, nil + + case schemapb.DataType_Double: + return &DoubleFieldData{ + Data: make([]float64, 0), + }, nil + case schemapb.DataType_JSON: + return &JSONFieldData{ + Data: make([][]byte, 0), + }, nil + case schemapb.DataType_Array: + return &ArrayFieldData{ + Data: make([]*schemapb.ScalarField, 0), + 
ElementType: fieldSchema.GetElementType(), + }, nil + case schemapb.DataType_String, schemapb.DataType_VarChar: + return &StringFieldData{ + Data: make([]string, 0), + }, nil + default: + return nil, fmt.Errorf("Unexpected schema data type: %d", dataType) + } +} + +type BoolFieldData struct { + Data []bool +} +type Int8FieldData struct { + Data []int8 +} +type Int16FieldData struct { + Data []int16 +} +type Int32FieldData struct { + Data []int32 +} +type Int64FieldData struct { + Data []int64 +} +type FloatFieldData struct { + Data []float32 +} +type DoubleFieldData struct { + Data []float64 +} +type StringFieldData struct { + Data []string +} +type ArrayFieldData struct { + ElementType schemapb.DataType + Data []*schemapb.ScalarField +} +type JSONFieldData struct { + Data [][]byte +} +type BinaryVectorFieldData struct { + Data []byte + Dim int +} +type FloatVectorFieldData struct { + Data []float32 + Dim int +} +type Float16VectorFieldData struct { + Data []byte + Dim int +} + +// RowNum implements FieldData.RowNum +func (data *BoolFieldData) RowNum() int { return len(data.Data) } +func (data *Int8FieldData) RowNum() int { return len(data.Data) } +func (data *Int16FieldData) RowNum() int { return len(data.Data) } +func (data *Int32FieldData) RowNum() int { return len(data.Data) } +func (data *Int64FieldData) RowNum() int { return len(data.Data) } +func (data *FloatFieldData) RowNum() int { return len(data.Data) } +func (data *DoubleFieldData) RowNum() int { return len(data.Data) } +func (data *StringFieldData) RowNum() int { return len(data.Data) } +func (data *ArrayFieldData) RowNum() int { return len(data.Data) } +func (data *JSONFieldData) RowNum() int { return len(data.Data) } +func (data *BinaryVectorFieldData) RowNum() int { return len(data.Data) * 8 / data.Dim } +func (data *FloatVectorFieldData) RowNum() int { return len(data.Data) / data.Dim } +func (data *Float16VectorFieldData) RowNum() int { return len(data.Data) / 2 / data.Dim } + +// GetRow 
implements FieldData.GetRow +func (data *BoolFieldData) GetRow(i int) any { return data.Data[i] } +func (data *Int8FieldData) GetRow(i int) any { return data.Data[i] } +func (data *Int16FieldData) GetRow(i int) any { return data.Data[i] } +func (data *Int32FieldData) GetRow(i int) any { return data.Data[i] } +func (data *Int64FieldData) GetRow(i int) any { return data.Data[i] } +func (data *FloatFieldData) GetRow(i int) any { return data.Data[i] } +func (data *DoubleFieldData) GetRow(i int) any { return data.Data[i] } +func (data *StringFieldData) GetRow(i int) any { return data.Data[i] } +func (data *ArrayFieldData) GetRow(i int) any { return data.Data[i] } +func (data *JSONFieldData) GetRow(i int) any { return data.Data[i] } +func (data *BinaryVectorFieldData) GetRow(i int) interface{} { + return data.Data[i*data.Dim/8 : (i+1)*data.Dim/8] +} + +func (data *FloatVectorFieldData) GetRow(i int) interface{} { + return data.Data[i*data.Dim : (i+1)*data.Dim] +} + +func (data *Float16VectorFieldData) GetRow(i int) interface{} { + return data.Data[i*data.Dim*2 : (i+1)*data.Dim*2] +} + +// AppendRow implements FieldData.AppendRow +func (data *BoolFieldData) AppendRow(row interface{}) error { + v, ok := row.(bool) + if !ok { + return merr.WrapErrParameterInvalid("bool", row, "Wrong row type") + } + data.Data = append(data.Data, v) + return nil +} + +func (data *Int8FieldData) AppendRow(row interface{}) error { + v, ok := row.(int8) + if !ok { + return merr.WrapErrParameterInvalid("int8", row, "Wrong row type") + } + data.Data = append(data.Data, v) + return nil +} + +func (data *Int16FieldData) AppendRow(row interface{}) error { + v, ok := row.(int16) + if !ok { + return merr.WrapErrParameterInvalid("int16", row, "Wrong row type") + } + data.Data = append(data.Data, v) + return nil +} + +func (data *Int32FieldData) AppendRow(row interface{}) error { + v, ok := row.(int32) + if !ok { + return merr.WrapErrParameterInvalid("int32", row, "Wrong row type") + } + data.Data = 
append(data.Data, v) + return nil +} + +func (data *Int64FieldData) AppendRow(row interface{}) error { + v, ok := row.(int64) + if !ok { + return merr.WrapErrParameterInvalid("int64", row, "Wrong row type") + } + data.Data = append(data.Data, v) + return nil +} + +func (data *FloatFieldData) AppendRow(row interface{}) error { + v, ok := row.(float32) + if !ok { + return merr.WrapErrParameterInvalid("float32", row, "Wrong row type") + } + data.Data = append(data.Data, v) + return nil +} + +func (data *DoubleFieldData) AppendRow(row interface{}) error { + v, ok := row.(float64) + if !ok { + return merr.WrapErrParameterInvalid("float64", row, "Wrong row type") + } + data.Data = append(data.Data, v) + return nil +} + +func (data *StringFieldData) AppendRow(row interface{}) error { + v, ok := row.(string) + if !ok { + return merr.WrapErrParameterInvalid("string", row, "Wrong row type") + } + data.Data = append(data.Data, v) + return nil +} + +func (data *ArrayFieldData) AppendRow(row interface{}) error { + v, ok := row.(*schemapb.ScalarField) + if !ok { + return merr.WrapErrParameterInvalid("*schemapb.ScalarField", row, "Wrong row type") + } + data.Data = append(data.Data, v) + return nil +} + +func (data *JSONFieldData) AppendRow(row interface{}) error { + v, ok := row.([]byte) + if !ok { + return merr.WrapErrParameterInvalid("[]byte", row, "Wrong row type") + } + data.Data = append(data.Data, v) + return nil +} + +func (data *BinaryVectorFieldData) AppendRow(row interface{}) error { + v, ok := row.([]byte) + if !ok || len(v) != data.Dim/8 { + return merr.WrapErrParameterInvalid("[]byte", row, "Wrong row type") + } + data.Data = append(data.Data, v...) + return nil +} + +func (data *FloatVectorFieldData) AppendRow(row interface{}) error { + v, ok := row.([]float32) + if !ok || len(v) != data.Dim { + return merr.WrapErrParameterInvalid("[]float32", row, "Wrong row type") + } + data.Data = append(data.Data, v...) 
+ return nil +} + +func (data *Float16VectorFieldData) AppendRow(row interface{}) error { + v, ok := row.([]byte) + if !ok || len(v) != data.Dim*2 { + return merr.WrapErrParameterInvalid("[]byte", row, "Wrong row type") + } + data.Data = append(data.Data, v...) + return nil +} + +// GetMemorySize implements FieldData.GetMemorySize +func (data *BoolFieldData) GetMemorySize() int { return binary.Size(data.Data) } +func (data *Int8FieldData) GetMemorySize() int { return binary.Size(data.Data) } +func (data *Int16FieldData) GetMemorySize() int { return binary.Size(data.Data) } +func (data *Int32FieldData) GetMemorySize() int { return binary.Size(data.Data) } +func (data *Int64FieldData) GetMemorySize() int { return binary.Size(data.Data) } +func (data *FloatFieldData) GetMemorySize() int { return binary.Size(data.Data) } +func (data *DoubleFieldData) GetMemorySize() int { return binary.Size(data.Data) } +func (data *BinaryVectorFieldData) GetMemorySize() int { return binary.Size(data.Data) + 4 } +func (data *FloatVectorFieldData) GetMemorySize() int { return binary.Size(data.Data) + 4 } +func (data *Float16VectorFieldData) GetMemorySize() int { return binary.Size(data.Data) + 4 } + +// why not binary.Size(data) directly? binary.Size(data) return -1 +// binary.Size returns how many bytes Write would generate to encode the value v, which +// must be a fixed-size value or a slice of fixed-size values, or a pointer to such data. +// If v is neither of these, binary.Size returns -1. 
+func (data *StringFieldData) GetMemorySize() int { + var size int + for _, val := range data.Data { + size += len(val) + 16 + } + return size +} + +func (data *ArrayFieldData) GetMemorySize() int { + var size int + for _, val := range data.Data { + switch data.ElementType { + case schemapb.DataType_Bool: + size += binary.Size(val.GetBoolData().GetData()) + case schemapb.DataType_Int8: + size += binary.Size(val.GetIntData().GetData()) / 4 + case schemapb.DataType_Int16: + size += binary.Size(val.GetIntData().GetData()) / 2 + case schemapb.DataType_Int32: + size += binary.Size(val.GetIntData().GetData()) + case schemapb.DataType_Int64: + size += binary.Size(val.GetLongData().GetData()) + case schemapb.DataType_Float: + size += binary.Size(val.GetFloatData().GetData()) + case schemapb.DataType_Double: + size += binary.Size(val.GetDoubleData().GetData()) + case schemapb.DataType_String, schemapb.DataType_VarChar: + size += (&StringFieldData{Data: val.GetStringData().GetData()}).GetMemorySize() + } + } + return size +} + +func (data *JSONFieldData) GetMemorySize() int { + var size int + for _, val := range data.Data { + size += len(val) + 16 + } + return size +} diff --git a/internal/storage/insert_data_test.go b/internal/storage/insert_data_test.go new file mode 100644 index 0000000000000..49c8781a453d8 --- /dev/null +++ b/internal/storage/insert_data_test.go @@ -0,0 +1,319 @@ +package storage + +import ( + "testing" + + "github.com/stretchr/testify/suite" + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/merr" +) + +func TestInsertDataSuite(t *testing.T) { + suite.Run(t, new(InsertDataSuite)) +} + +func TestArrayFieldDataSuite(t *testing.T) { + suite.Run(t, new(ArrayFieldDataSuite)) +} + +type InsertDataSuite struct { + suite.Suite + + schema *schemapb.CollectionSchema + + iDataOneRow *InsertData + iDataTwoRows *InsertData + iDataEmpty *InsertData +} + +func 
(s *InsertDataSuite) SetupSuite() { + s.schema = genTestCollectionMeta().Schema +} + +func (s *InsertDataSuite) TestInsertData() { + s.Run("nil schema", func() { + idata, err := NewInsertData(nil) + s.Error(err) + s.Nil(idata) + }) + + s.Run("invalid schema", func() { + tests := []struct { + description string + invalidType schemapb.DataType + }{ + {"binary vector without dim", schemapb.DataType_BinaryVector}, + {"float vector without dim", schemapb.DataType_FloatVector}, + {"float16 vector without dim", schemapb.DataType_Float16Vector}, + } + + for _, test := range tests { + s.Run(test.description, func() { + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + DataType: test.invalidType, + }, + }, + } + idata, err := NewInsertData(schema) + s.Error(err) + s.Nil(idata) + }) + } + }) + + s.Run("empty iData", func() { + idata := &InsertData{} + s.True(idata.IsEmpty()) + s.Equal(0, idata.GetRowNum()) + s.Equal(0, idata.GetMemorySize()) + + err := idata.Append(map[FieldID]interface{}{1: struct{}{}}) + s.Error(err) + }) + + s.Run("init by New", func() { + s.True(s.iDataEmpty.IsEmpty()) + s.Equal(0, s.iDataEmpty.GetRowNum()) + s.Equal(12, s.iDataEmpty.GetMemorySize()) + + s.False(s.iDataOneRow.IsEmpty()) + s.Equal(1, s.iDataOneRow.GetRowNum()) + s.Equal(139, s.iDataOneRow.GetMemorySize()) + + s.False(s.iDataTwoRows.IsEmpty()) + s.Equal(2, s.iDataTwoRows.GetRowNum()) + s.Equal(266, s.iDataTwoRows.GetMemorySize()) + + for _, field := range s.iDataTwoRows.Data { + s.Equal(2, field.RowNum()) + + err := field.AppendRow(struct{}{}) + log.Warn("error", zap.Error(err)) + s.ErrorIs(err, merr.ErrParameterInvalid) + } + }) +} + +func (s *InsertDataSuite) TestMemorySize() { + s.Equal(s.iDataEmpty.Data[RowIDField].GetMemorySize(), 0) + s.Equal(s.iDataEmpty.Data[TimestampField].GetMemorySize(), 0) + s.Equal(s.iDataEmpty.Data[BoolField].GetMemorySize(), 0) + s.Equal(s.iDataEmpty.Data[Int8Field].GetMemorySize(), 0) + 
s.Equal(s.iDataEmpty.Data[Int16Field].GetMemorySize(), 0) + s.Equal(s.iDataEmpty.Data[Int32Field].GetMemorySize(), 0) + s.Equal(s.iDataEmpty.Data[Int64Field].GetMemorySize(), 0) + s.Equal(s.iDataEmpty.Data[FloatField].GetMemorySize(), 0) + s.Equal(s.iDataEmpty.Data[DoubleField].GetMemorySize(), 0) + s.Equal(s.iDataEmpty.Data[StringField].GetMemorySize(), 0) + s.Equal(s.iDataEmpty.Data[ArrayField].GetMemorySize(), 0) + s.Equal(s.iDataEmpty.Data[BinaryVectorField].GetMemorySize(), 4) + s.Equal(s.iDataEmpty.Data[FloatVectorField].GetMemorySize(), 4) + s.Equal(s.iDataEmpty.Data[Float16VectorField].GetMemorySize(), 4) + + s.Equal(s.iDataOneRow.Data[RowIDField].GetMemorySize(), 8) + s.Equal(s.iDataOneRow.Data[TimestampField].GetMemorySize(), 8) + s.Equal(s.iDataOneRow.Data[BoolField].GetMemorySize(), 1) + s.Equal(s.iDataOneRow.Data[Int8Field].GetMemorySize(), 1) + s.Equal(s.iDataOneRow.Data[Int16Field].GetMemorySize(), 2) + s.Equal(s.iDataOneRow.Data[Int32Field].GetMemorySize(), 4) + s.Equal(s.iDataOneRow.Data[Int64Field].GetMemorySize(), 8) + s.Equal(s.iDataOneRow.Data[FloatField].GetMemorySize(), 4) + s.Equal(s.iDataOneRow.Data[DoubleField].GetMemorySize(), 8) + s.Equal(s.iDataOneRow.Data[StringField].GetMemorySize(), 19) + s.Equal(s.iDataOneRow.Data[JSONField].GetMemorySize(), len([]byte(`{"batch":1}`))+16) + s.Equal(s.iDataOneRow.Data[ArrayField].GetMemorySize(), 3*4) + s.Equal(s.iDataOneRow.Data[BinaryVectorField].GetMemorySize(), 5) + s.Equal(s.iDataOneRow.Data[FloatVectorField].GetMemorySize(), 20) + s.Equal(s.iDataOneRow.Data[Float16VectorField].GetMemorySize(), 12) + + s.Equal(s.iDataTwoRows.Data[RowIDField].GetMemorySize(), 16) + s.Equal(s.iDataTwoRows.Data[TimestampField].GetMemorySize(), 16) + s.Equal(s.iDataTwoRows.Data[BoolField].GetMemorySize(), 2) + s.Equal(s.iDataTwoRows.Data[Int8Field].GetMemorySize(), 2) + s.Equal(s.iDataTwoRows.Data[Int16Field].GetMemorySize(), 4) + s.Equal(s.iDataTwoRows.Data[Int32Field].GetMemorySize(), 8) + 
s.Equal(s.iDataTwoRows.Data[Int64Field].GetMemorySize(), 16) + s.Equal(s.iDataTwoRows.Data[FloatField].GetMemorySize(), 8) + s.Equal(s.iDataTwoRows.Data[DoubleField].GetMemorySize(), 16) + s.Equal(s.iDataTwoRows.Data[StringField].GetMemorySize(), 38) + s.Equal(s.iDataTwoRows.Data[ArrayField].GetMemorySize(), 24) + s.Equal(s.iDataTwoRows.Data[BinaryVectorField].GetMemorySize(), 6) + s.Equal(s.iDataTwoRows.Data[FloatVectorField].GetMemorySize(), 36) + s.Equal(s.iDataTwoRows.Data[Float16VectorField].GetMemorySize(), 20) +} + +func (s *InsertDataSuite) SetupTest() { + var err error + s.iDataEmpty, err = NewInsertData(s.schema) + s.Require().NoError(err) + s.True(s.iDataEmpty.IsEmpty()) + s.Equal(0, s.iDataEmpty.GetRowNum()) + s.Equal(12, s.iDataEmpty.GetMemorySize()) + + row1 := map[FieldID]interface{}{ + RowIDField: int64(3), + TimestampField: int64(3), + BoolField: true, + Int8Field: int8(3), + Int16Field: int16(3), + Int32Field: int32(3), + Int64Field: int64(3), + FloatField: float32(3), + DoubleField: float64(3), + StringField: "str", + BinaryVectorField: []byte{0}, + FloatVectorField: []float32{4, 5, 6, 7}, + Float16VectorField: []byte{0, 0, 0, 0, 255, 255, 255, 255}, + ArrayField: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{Data: []int32{1, 2, 3}}, + }, + }, + JSONField: []byte(`{"batch":3}`), + } + + s.iDataOneRow, err = NewInsertData(s.schema) + s.Require().NoError(err) + err = s.iDataOneRow.Append(row1) + s.Require().NoError(err) + + for fID, field := range s.iDataOneRow.Data { + s.Equal(row1[fID], field.GetRow(0)) + } + + row2 := map[FieldID]interface{}{ + RowIDField: int64(1), + TimestampField: int64(1), + BoolField: false, + Int8Field: int8(1), + Int16Field: int16(1), + Int32Field: int32(1), + Int64Field: int64(1), + FloatField: float32(1), + DoubleField: float64(1), + StringField: string("str"), + BinaryVectorField: []byte{0}, + FloatVectorField: []float32{4, 5, 6, 7}, + Float16VectorField: []byte{1, 2, 3, 4, 
5, 6, 7, 8}, + ArrayField: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{Data: []int32{1, 2, 3}}, + }, + }, + JSONField: []byte(`{"batch":1}`), + } + + s.iDataTwoRows, err = NewInsertData(s.schema) + s.Require().NoError(err) + err = s.iDataTwoRows.Append(row1) + s.Require().NoError(err) + err = s.iDataTwoRows.Append(row2) + s.Require().NoError(err) +} + +type ArrayFieldDataSuite struct { + suite.Suite +} + +func (s *ArrayFieldDataSuite) TestArrayFieldData() { + fieldID2Type := map[int64]schemapb.DataType{ + ArrayField + 1: schemapb.DataType_Bool, + ArrayField + 2: schemapb.DataType_Int8, + ArrayField + 3: schemapb.DataType_Int16, + ArrayField + 4: schemapb.DataType_Int32, + ArrayField + 5: schemapb.DataType_Int64, + ArrayField + 6: schemapb.DataType_Float, + ArrayField + 7: schemapb.DataType_Double, + ArrayField + 8: schemapb.DataType_VarChar, + } + + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + FieldID: RowIDField, + DataType: schemapb.DataType_Int64, + }, + { + FieldID: TimestampField, + DataType: schemapb.DataType_Int64, + }, + { + FieldID: Int64Field, + IsPrimaryKey: true, + DataType: schemapb.DataType_Int64, + }, + }, + } + + for fieldID, elementType := range fieldID2Type { + schema.Fields = append(schema.Fields, &schemapb.FieldSchema{ + FieldID: fieldID, + DataType: schemapb.DataType_Array, + ElementType: elementType, + }) + } + + insertData, err := NewInsertData(schema) + s.NoError(err) + + s.Equal(0, insertData.GetRowNum()) + s.Equal(0, insertData.GetMemorySize()) + s.True(insertData.IsEmpty()) + + fieldIDToData := map[int64]interface{}{ + RowIDField: int64(1), + TimestampField: int64(2), + Int64Field: int64(3), + ArrayField + 1: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_BoolData{ + BoolData: &schemapb.BoolArray{Data: []bool{true, false}}, + }, + }, + ArrayField + 2: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{Data: 
[]int32{0, 0}}, + }, + }, + ArrayField + 3: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{Data: []int32{1, 1}}, + }, + }, + ArrayField + 4: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{Data: []int32{2, 2}}, + }, + }, + ArrayField + 5: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{Data: []int64{3, 3}}, + }, + }, + ArrayField + 6: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_FloatData{ + FloatData: &schemapb.FloatArray{Data: []float32{4, 4}}, + }, + }, + ArrayField + 7: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_DoubleData{ + DoubleData: &schemapb.DoubleArray{Data: []float64{5, 5}}, + }, + }, + ArrayField + 8: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{Data: []string{"6", "6"}}, + }, + }, + } + + err = insertData.Append(fieldIDToData) + s.NoError(err) + s.Equal(1, insertData.GetRowNum()) + s.Equal(114, insertData.GetMemorySize()) + s.False(insertData.IsEmpty()) +} diff --git a/internal/storage/local_chunk_manager.go b/internal/storage/local_chunk_manager.go index b8861c1d95735..f22c56ce150b0 100644 --- a/internal/storage/local_chunk_manager.go +++ b/internal/storage/local_chunk_manager.go @@ -20,7 +20,6 @@ import ( "context" "fmt" "io" - "io/ioutil" "os" "path" "path/filepath" @@ -97,7 +96,7 @@ func (lcm *LocalChunkManager) Write(ctx context.Context, filePath string, conten return err } } - return ioutil.WriteFile(filePath, content, os.ModePerm) + return os.WriteFile(filePath, content, os.ModePerm) } // MultiWrite writes the data to local storage. @@ -134,7 +133,7 @@ func (lcm *LocalChunkManager) Read(ctx context.Context, filePath string) ([]byte return nil, fmt.Errorf("file not exist: %s", filePath) } - return ioutil.ReadFile(filePath) + return os.ReadFile(filePath) } // MultiRead reads the local storage data if exists. 
diff --git a/internal/storage/minio_chunk_manager.go b/internal/storage/minio_chunk_manager.go index 635d6fee9770a..d5c660d968f44 100644 --- a/internal/storage/minio_chunk_manager.go +++ b/internal/storage/minio_chunk_manager.go @@ -26,6 +26,12 @@ import ( "time" "github.com/cockroachdb/errors" + minio "github.com/minio/minio-go/v7" + "github.com/minio/minio-go/v7/pkg/credentials" + "go.uber.org/zap" + "golang.org/x/exp/mmap" + "golang.org/x/sync/errgroup" + "github.com/milvus-io/milvus/internal/storage/aliyun" "github.com/milvus-io/milvus/internal/storage/gcp" "github.com/milvus-io/milvus/pkg/log" @@ -33,27 +39,20 @@ import ( "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/retry" "github.com/milvus-io/milvus/pkg/util/timerecord" - minio "github.com/minio/minio-go/v7" - "github.com/minio/minio-go/v7/pkg/credentials" - "go.uber.org/zap" - "golang.org/x/exp/mmap" - "golang.org/x/sync/errgroup" ) -var ( - ErrNoSuchKey = errors.New("NoSuchKey") -) +const NoSuchKey = "NoSuchKey" -const ( - CloudProviderGCP = "gcp" - CloudProviderAWS = "aws" - CloudProviderAliyun = "aliyun" -) +var ErrNoSuchKey = errors.New(NoSuchKey) func WrapErrNoSuchKey(key string) error { return fmt.Errorf("%w(key=%s)", ErrNoSuchKey, key) } +func IsErrNoSuchKey(err error) bool { + return strings.HasPrefix(err.Error(), NoSuchKey) +} + var CheckBucketRetryAttempts uint = 20 // MinioChunkManager is responsible for read and write data stored in minio. 
@@ -80,8 +79,8 @@ func NewMinioChunkManager(ctx context.Context, opts ...Option) (*MinioChunkManag func newMinioChunkManagerWithConfig(ctx context.Context, c *config) (*MinioChunkManager, error) { var creds *credentials.Credentials - var newMinioFn = minio.New - var bucketLookupType = minio.BucketLookupAuto + newMinioFn := minio.New + bucketLookupType := minio.BucketLookupAuto if c.useVirtualHost { bucketLookupType = minio.BucketLookupDNS @@ -208,7 +207,6 @@ func (mcm *MinioChunkManager) Size(ctx context.Context, filePath string) (int64, // Write writes the data to minio storage. func (mcm *MinioChunkManager) Write(ctx context.Context, filePath string, content []byte) error { _, err := mcm.putMinioObject(ctx, mcm.bucketName, filePath, bytes.NewReader(content), int64(len(content)), minio.PutObjectOptions{}) - if err != nil { log.Warn("failed to put object", zap.String("bucket", mcm.bucketName), zap.String("path", filePath), zap.Error(err)) return err @@ -414,7 +412,6 @@ func (mcm *MinioChunkManager) RemoveWithPrefix(ctx context.Context, prefix strin // calling `ListWithPrefix` with `prefix` = a && `recursive` = false will only returns [a, ab] // If caller needs all objects without level limitation, `recursive` shall be true. func (mcm *MinioChunkManager) ListWithPrefix(ctx context.Context, prefix string, recursive bool) ([]string, []time.Time, error) { - // cannot use ListObjects(ctx, bucketName, Opt{Prefix:prefix, Recursive:true}) // if minio has lots of objects under the provided path // recursive = true may timeout during the recursive browsing the objects. 
@@ -475,7 +472,8 @@ func Read(r io.Reader, size int64) ([]byte, error) { } func (mcm *MinioChunkManager) getMinioObject(ctx context.Context, bucketName, objectName string, - opts minio.GetObjectOptions) (*minio.Object, error) { + opts minio.GetObjectOptions, +) (*minio.Object, error) { start := timerecord.NewTimeRecorder("getMinioObject") reader, err := mcm.Client.GetObject(ctx, bucketName, objectName, opts) @@ -491,7 +489,8 @@ func (mcm *MinioChunkManager) getMinioObject(ctx context.Context, bucketName, ob } func (mcm *MinioChunkManager) putMinioObject(ctx context.Context, bucketName, objectName string, reader io.Reader, objectSize int64, - opts minio.PutObjectOptions) (minio.UploadInfo, error) { + opts minio.PutObjectOptions, +) (minio.UploadInfo, error) { start := timerecord.NewTimeRecorder("putMinioObject") info, err := mcm.Client.PutObject(ctx, bucketName, objectName, reader, objectSize, opts) @@ -507,7 +506,8 @@ func (mcm *MinioChunkManager) putMinioObject(ctx context.Context, bucketName, ob } func (mcm *MinioChunkManager) statMinioObject(ctx context.Context, bucketName, objectName string, - opts minio.StatObjectOptions) (minio.ObjectInfo, error) { + opts minio.StatObjectOptions, +) (minio.ObjectInfo, error) { start := timerecord.NewTimeRecorder("statMinioObject") info, err := mcm.Client.StatObject(ctx, bucketName, objectName, opts) @@ -523,7 +523,8 @@ func (mcm *MinioChunkManager) statMinioObject(ctx context.Context, bucketName, o } func (mcm *MinioChunkManager) listMinioObjects(ctx context.Context, bucketName string, - opts minio.ListObjectsOptions) <-chan minio.ObjectInfo { + opts minio.ListObjectsOptions, +) <-chan minio.ObjectInfo { start := timerecord.NewTimeRecorder("listMinioObjects") res := mcm.Client.ListObjects(ctx, bucketName, opts) @@ -535,7 +536,8 @@ func (mcm *MinioChunkManager) listMinioObjects(ctx context.Context, bucketName s } func (mcm *MinioChunkManager) removeMinioObject(ctx context.Context, bucketName, objectName string, - opts 
minio.RemoveObjectOptions) error { + opts minio.RemoveObjectOptions, +) error { start := timerecord.NewTimeRecorder("removeMinioObject") err := mcm.Client.RemoveObject(ctx, bucketName, objectName, opts) diff --git a/internal/storage/minio_chunk_manager_test.go b/internal/storage/minio_chunk_manager_test.go index 54feec3e1f75b..5aa4eaa533841 100644 --- a/internal/storage/minio_chunk_manager_test.go +++ b/internal/storage/minio_chunk_manager_test.go @@ -76,7 +76,6 @@ func TestMinIOCMFail(t *testing.T) { ) assert.Error(t, err) assert.Nil(t, client) - } func TestMinIOCM(t *testing.T) { @@ -440,7 +439,6 @@ func TestMinIOCM(t *testing.T) { r, err := testCM.Mmap(ctx, key) assert.Error(t, err) assert.Nil(t, r) - }) t.Run("test Prefix", func(t *testing.T) { diff --git a/internal/storage/minio_object_storage.go b/internal/storage/minio_object_storage.go new file mode 100644 index 0000000000000..76a14a492f325 --- /dev/null +++ b/internal/storage/minio_object_storage.go @@ -0,0 +1,149 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package storage + +import ( + "context" + "fmt" + "io" + "time" + + minio "github.com/minio/minio-go/v7" + "github.com/minio/minio-go/v7/pkg/credentials" + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/storage/aliyun" + "github.com/milvus-io/milvus/internal/storage/gcp" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/retry" +) + +type MinioObjectStorage struct { + *minio.Client +} + +func newMinioObjectStorageWithConfig(ctx context.Context, c *config) (*MinioObjectStorage, error) { + var creds *credentials.Credentials + newMinioFn := minio.New + bucketLookupType := minio.BucketLookupAuto + + switch c.cloudProvider { + case CloudProviderAliyun: + // auto doesn't work for aliyun, so we set to dns deliberately + bucketLookupType = minio.BucketLookupDNS + if c.useIAM { + newMinioFn = aliyun.NewMinioClient + } else { + creds = credentials.NewStaticV4(c.accessKeyID, c.secretAccessKeyID, "") + } + case CloudProviderGCP: + newMinioFn = gcp.NewMinioClient + if !c.useIAM { + creds = credentials.NewStaticV2(c.accessKeyID, c.secretAccessKeyID, "") + } + default: // aws, minio + if c.useIAM { + creds = credentials.NewIAM("") + } else { + creds = credentials.NewStaticV4(c.accessKeyID, c.secretAccessKeyID, "") + } + } + minioOpts := &minio.Options{ + BucketLookup: bucketLookupType, + Creds: creds, + Secure: c.useSSL, + } + minIOClient, err := newMinioFn(c.address, minioOpts) + // options nil or invalid formatted endpoint, don't need to retry + if err != nil { + return nil, err + } + var bucketExists bool + // check valid in first query + checkBucketFn := func() error { + bucketExists, err = minIOClient.BucketExists(ctx, c.bucketName) + if err != nil { + log.Warn("failed to check blob bucket exist", zap.String("bucket", c.bucketName), zap.Error(err)) + return err + } + if !bucketExists { + if c.createBucket { + log.Info("blob bucket not exist, create bucket.", zap.Any("bucket name", c.bucketName)) + err := 
minIOClient.MakeBucket(ctx, c.bucketName, minio.MakeBucketOptions{}) + if err != nil { + log.Warn("failed to create blob bucket", zap.String("bucket", c.bucketName), zap.Error(err)) + return err + } + } else { + return fmt.Errorf("bucket %s not Existed", c.bucketName) + } + } + return nil + } + err = retry.Do(ctx, checkBucketFn, retry.Attempts(CheckBucketRetryAttempts)) + if err != nil { + return nil, err + } + + return &MinioObjectStorage{minIOClient}, nil +} + +func (minioObjectStorage *MinioObjectStorage) GetObject(ctx context.Context, bucketName, objectName string, offset int64, size int64) (FileReader, error) { + opts := minio.GetObjectOptions{} + if offset > 0 { + err := opts.SetRange(offset, offset+size-1) + if err != nil { + log.Warn("failed to set range", zap.String("bucket", bucketName), zap.String("path", objectName), zap.Error(err)) + return nil, err + } + } + object, err := minioObjectStorage.Client.GetObject(ctx, bucketName, objectName, opts) + if err != nil { + return nil, err + } + return object, nil +} + +func (minioObjectStorage *MinioObjectStorage) PutObject(ctx context.Context, bucketName, objectName string, reader io.Reader, objectSize int64) error { + _, err := minioObjectStorage.Client.PutObject(ctx, bucketName, objectName, reader, objectSize, minio.PutObjectOptions{}) + return err +} + +func (minioObjectStorage *MinioObjectStorage) StatObject(ctx context.Context, bucketName, objectName string) (int64, error) { + info, err := minioObjectStorage.Client.StatObject(ctx, bucketName, objectName, minio.StatObjectOptions{}) + return info.Size, err +} + +func (minioObjectStorage *MinioObjectStorage) ListObjects(ctx context.Context, bucketName string, prefix string, recursive bool) (map[string]time.Time, error) { + res := minioObjectStorage.Client.ListObjects(ctx, bucketName, minio.ListObjectsOptions{ + Prefix: prefix, + Recursive: recursive, + }) + + objects := map[string]time.Time{} + for object := range res { + if !recursive && object.Err != nil { 
+ return map[string]time.Time{}, object.Err + } + objects[object.Key] = object.LastModified + } + return objects, nil +} + +func (minioObjectStorage *MinioObjectStorage) RemoveObject(ctx context.Context, bucketName, objectName string) error { + return minioObjectStorage.Client.RemoveObject(ctx, bucketName, objectName, minio.RemoveObjectOptions{}) +} diff --git a/internal/storage/minio_object_storage_test.go b/internal/storage/minio_object_storage_test.go new file mode 100644 index 0000000000000..5e7968ddc3381 --- /dev/null +++ b/internal/storage/minio_object_storage_test.go @@ -0,0 +1,168 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package storage + +import ( + "bytes" + "context" + "io" + "testing" + + "github.com/minio/minio-go/v7" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMinioObjectStorage(t *testing.T) { + ctx := context.Background() + config := config{ + address: Params.MinioCfg.Address.GetValue(), + accessKeyID: Params.MinioCfg.AccessKeyID.GetValue(), + secretAccessKeyID: Params.MinioCfg.SecretAccessKey.GetValue(), + rootPath: Params.MinioCfg.RootPath.GetValue(), + + bucketName: Params.MinioCfg.BucketName.GetValue(), + createBucket: true, + useIAM: false, + cloudProvider: "minio", + } + + t.Run("test initialize", func(t *testing.T) { + var err error + bucketName := config.bucketName + config.bucketName = "" + _, err = newMinioObjectStorageWithConfig(ctx, &config) + assert.Error(t, err) + config.bucketName = bucketName + _, err = newMinioObjectStorageWithConfig(ctx, &config) + assert.Equal(t, err, nil) + }) + + t.Run("test load", func(t *testing.T) { + testCM, err := newMinioObjectStorageWithConfig(ctx, &config) + assert.Equal(t, err, nil) + defer testCM.RemoveBucket(ctx, config.bucketName) + + prepareTests := []struct { + key string + value []byte + }{ + {"abc", []byte("123")}, + {"abcd", []byte("1234")}, + {"key_1", []byte("111")}, + {"key_2", []byte("222")}, + {"key_3", []byte("333")}, + } + + for _, test := range prepareTests { + err := testCM.PutObject(ctx, config.bucketName, test.key, bytes.NewReader(test.value), int64(len(test.value))) + require.NoError(t, err) + } + + loadTests := []struct { + isvalid bool + loadKey string + expectedValue []byte + + description string + }{ + {true, "abc", []byte("123"), "load valid key abc"}, + {true, "abcd", []byte("1234"), "load valid key abcd"}, + {true, "key_1", []byte("111"), "load valid key key_1"}, + {true, "key_2", []byte("222"), "load valid key key_2"}, + {true, "key_3", []byte("333"), "load valid key key_3"}, + {false, "key_not_exist", []byte(""), "load invalid key key_not_exist"}, 
+ {false, "/", []byte(""), "load leading slash"}, + } + + for _, test := range loadTests { + t.Run(test.description, func(t *testing.T) { + if test.isvalid { + got, err := testCM.GetObject(ctx, config.bucketName, test.loadKey, 0, 1024) + assert.NoError(t, err) + contentData, err := io.ReadAll(got) + assert.NoError(t, err) + assert.Equal(t, len(contentData), len(test.expectedValue)) + assert.Equal(t, test.expectedValue, contentData) + statSize, err := testCM.StatObject(ctx, config.bucketName, test.loadKey) + assert.NoError(t, err) + assert.Equal(t, statSize, int64(len(contentData))) + _, err = testCM.GetObject(ctx, config.bucketName, test.loadKey, 1, 1023) + assert.NoError(t, err) + } else { + got, err := testCM.GetObject(ctx, config.bucketName, test.loadKey, 0, 1024) + assert.NoError(t, err) + _, err = io.ReadAll(got) + errResponse := minio.ToErrorResponse(err) + if test.loadKey == "/" { + assert.Equal(t, errResponse.Code, "XMinioInvalidObjectName") + } else { + assert.Equal(t, errResponse.Code, "NoSuchKey") + } + } + }) + } + + loadWithPrefixTests := []struct { + isvalid bool + prefix string + expectedValue [][]byte + + description string + }{ + {true, "abc", [][]byte{[]byte("123"), []byte("1234")}, "load with valid prefix abc"}, + {true, "key_", [][]byte{[]byte("111"), []byte("222"), []byte("333")}, "load with valid prefix key_"}, + {true, "prefix", [][]byte{}, "load with valid but not exist prefix prefix"}, + } + + for _, test := range loadWithPrefixTests { + t.Run(test.description, func(t *testing.T) { + gotk, err := testCM.ListObjects(ctx, config.bucketName, test.prefix, false) + assert.NoError(t, err) + assert.Equal(t, len(test.expectedValue), len(gotk)) + for key := range gotk { + err := testCM.RemoveObject(ctx, config.bucketName, key) + assert.NoError(t, err) + } + }) + } + }) + + t.Run("test useIAM", func(t *testing.T) { + var err error + config.useIAM = true + _, err = newMinioObjectStorageWithConfig(ctx, &config) + assert.Error(t, err) + config.useIAM = 
false + }) + + t.Run("test cloud provider", func(t *testing.T) { + var err error + cloudProvider := config.cloudProvider + config.cloudProvider = "aliyun" + config.useIAM = true + _, err = newMinioObjectStorageWithConfig(ctx, &config) + assert.Error(t, err) + config.useIAM = false + _, err = newMinioObjectStorageWithConfig(ctx, &config) + assert.Error(t, err) + config.cloudProvider = "gcp" + _, err = newMinioObjectStorageWithConfig(ctx, &config) + assert.NoError(t, err) + config.cloudProvider = cloudProvider + }) +} diff --git a/internal/storage/options.go b/internal/storage/options.go index 83fe9f8848fd7..b0efedaca4019 100644 --- a/internal/storage/options.go +++ b/internal/storage/options.go @@ -14,6 +14,7 @@ type config struct { iamEndpoint string useVirtualHost bool region string + requestTimeoutMs int64 } func newDefaultConfig() *config { @@ -40,6 +41,7 @@ func AccessKeyID(accessKeyID string) Option { c.accessKeyID = accessKeyID } } + func SecretAccessKeyID(secretAccessKeyID string) Option { return func(c *config) { c.secretAccessKeyID = secretAccessKeyID @@ -93,3 +95,9 @@ func Region(region string) Option { c.region = region } } + +func RequestTimeout(requestTimeoutMs int64) Option { + return func(c *config) { + c.requestTimeoutMs = requestTimeoutMs + } +} diff --git a/internal/storage/payload.go b/internal/storage/payload.go index ab1c3e8b1a5fb..b316c9d93ca89 100644 --- a/internal/storage/payload.go +++ b/internal/storage/payload.go @@ -36,6 +36,7 @@ type PayloadWriterInterface interface { AddOneJSONToPayload(msg []byte) error AddBinaryVectorToPayload(binVec []byte, dim int) error AddFloatVectorToPayload(binVec []float32, dim int) error + AddFloat16VectorToPayload(binVec []byte, dim int) error FinishPayloadWriter() error GetPayloadBufferFromWriter() ([]byte, error) GetPayloadLengthFromWriter() (int, error) @@ -58,6 +59,7 @@ type PayloadReaderInterface interface { GetArrayFromPayload() ([]*schemapb.ScalarField, error) GetJSONFromPayload() ([][]byte, error) 
GetBinaryVectorFromPayload() ([]byte, int, error) + GetFloat16VectorFromPayload() ([]byte, int, error) GetFloatVectorFromPayload() ([]float32, int, error) GetPayloadLengthFromReader() (int, error) ReleasePayloadReader() error diff --git a/internal/storage/payload_reader.go b/internal/storage/payload_reader.go index c1586bcdd0572..4ee24a629f1cb 100644 --- a/internal/storage/payload_reader.go +++ b/internal/storage/payload_reader.go @@ -67,6 +67,8 @@ func (r *PayloadReader) GetDataFromPayload() (interface{}, int, error) { return r.GetBinaryVectorFromPayload() case schemapb.DataType_FloatVector: return r.GetFloatVectorFromPayload() + case schemapb.DataType_Float16Vector: + return r.GetFloat16VectorFromPayload() case schemapb.DataType_String, schemapb.DataType_VarChar: val, err := r.GetStringFromPayload() return val, 0, err @@ -319,6 +321,29 @@ func (r *PayloadReader) GetBinaryVectorFromPayload() ([]byte, int, error) { return ret, dim * 8, nil } +// GetFloat16VectorFromPayload returns vector, dimension, error +func (r *PayloadReader) GetFloat16VectorFromPayload() ([]byte, int, error) { + if r.colType != schemapb.DataType_Float16Vector { + return nil, -1, fmt.Errorf("failed to get float vector from datatype %v", r.colType.String()) + } + dim := r.reader.RowGroup(0).Column(0).Descriptor().TypeLength() / 2 + values := make([]parquet.FixedLenByteArray, r.numRows) + valuesRead, err := ReadDataFromAllRowGroups[parquet.FixedLenByteArray, *file.FixedLenByteArrayColumnChunkReader](r.reader, values, 0, r.numRows) + if err != nil { + return nil, -1, err + } + + if valuesRead != r.numRows { + return nil, -1, fmt.Errorf("expect %d rows, but got valuesRead = %d", r.numRows, valuesRead) + } + + ret := make([]byte, int64(dim*2)*r.numRows) + for i := 0; i < int(r.numRows); i++ { + copy(ret[i*dim*2:(i+1)*dim*2], values[i]) + } + return ret, dim, nil +} + // GetFloatVectorFromPayload returns vector, dimension, error func (r *PayloadReader) GetFloatVectorFromPayload() ([]float32, int, 
error) { if r.colType != schemapb.DataType_FloatVector { diff --git a/internal/storage/payload_reader_test.go b/internal/storage/payload_reader_test.go index 13e17c8a07da9..3933a55fea169 100644 --- a/internal/storage/payload_reader_test.go +++ b/internal/storage/payload_reader_test.go @@ -7,6 +7,7 @@ import ( "github.com/apache/arrow/go/v12/parquet/file" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) type ReadDataFromAllRowGroupsSuite struct { diff --git a/internal/storage/payload_test.go b/internal/storage/payload_test.go index a976cfeb199b6..a9fe6177c65b5 100644 --- a/internal/storage/payload_test.go +++ b/internal/storage/payload_test.go @@ -26,7 +26,6 @@ import ( ) func TestPayload_ReaderAndWriter(t *testing.T) { - t.Run("TestBool", func(t *testing.T) { w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) @@ -60,7 +59,6 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.NoError(t, err) assert.ElementsMatch(t, []bool{false, false, false, false, false, false, false, false}, bools) defer r.ReleasePayloadReader() - }) t.Run("TestInt8", func(t *testing.T) { @@ -536,6 +534,39 @@ func TestPayload_ReaderAndWriter(t *testing.T) { defer r.ReleasePayloadReader() }) + t.Run("TestFloat16Vector", func(t *testing.T) { + w, err := NewPayloadWriter(schemapb.DataType_Float16Vector, 1) + require.Nil(t, err) + require.NotNil(t, w) + + err = w.AddFloat16VectorToPayload([]byte{1, 2}, 1) + assert.NoError(t, err) + err = w.AddDataToPayload([]byte{3, 4}, 1) + assert.NoError(t, err) + err = w.FinishPayloadWriter() + assert.NoError(t, err) + + length, err := w.GetPayloadLengthFromWriter() + assert.NoError(t, err) + assert.Equal(t, 2, length) + defer w.ReleasePayloadWriter() + + buffer, err := w.GetPayloadBufferFromWriter() + assert.NoError(t, err) + + r, err := NewPayloadReader(schemapb.DataType_Float16Vector, buffer) + require.Nil(t, err) + length, err =
r.GetPayloadLengthFromReader() + assert.NoError(t, err) + assert.Equal(t, length, 2) + + float16Vecs, dim, err := r.GetFloat16VectorFromPayload() + assert.NoError(t, err) + assert.Equal(t, 1, dim) + assert.Equal(t, 4, len(float16Vecs)) + assert.ElementsMatch(t, []byte{1, 2, 3, 4}, float16Vecs) + }) + // t.Run("TestAddDataToPayload", func(t *testing.T) { // w, err := NewPayloadWriter(schemapb.DataType_Bool) // w.colType = 999 diff --git a/internal/storage/payload_writer.go b/internal/storage/payload_writer.go index fd688b3153e84..d63a2fac77505 100644 --- a/internal/storage/payload_writer.go +++ b/internal/storage/payload_writer.go @@ -30,6 +30,7 @@ import ( "github.com/apache/arrow/go/v12/parquet/pqarrow" "github.com/cockroachdb/errors" "github.com/golang/protobuf/proto" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -151,6 +152,12 @@ func (w *NativePayloadWriter) AddDataToPayload(data interface{}, dim ...int) err return errors.New("incorrect data type") } return w.AddFloatVectorToPayload(val, dim[0]) + case schemapb.DataType_Float16Vector: + val, ok := data.([]byte) + if !ok { + return errors.New("incorrect data type") + } + return w.AddFloat16VectorToPayload(val, dim[0]) default: return errors.New("incorrect datatype") } @@ -412,6 +419,31 @@ func (w *NativePayloadWriter) AddFloatVectorToPayload(data []float32, dim int) e return nil } +func (w *NativePayloadWriter) AddFloat16VectorToPayload(data []byte, dim int) error { + if w.finished { + return errors.New("can't append data to finished writer") + } + + if len(data) == 0 { + return errors.New("can't add empty msgs into payload") + } + + builder, ok := w.builder.(*array.FixedSizeBinaryBuilder) + if !ok { + return errors.New("failed to cast ArrayBuilder") + } + + byteLength := dim * 2 + length := len(data) / byteLength + + builder.Reserve(length) + for i := 0; i < length; i++ { + builder.Append(data[i*byteLength : 
(i+1)*byteLength]) + } + + return nil +} + func (w *NativePayloadWriter) FinishPayloadWriter() error { if w.finished { return errors.New("can't reuse a finished writer") @@ -503,6 +535,10 @@ func milvusDataTypeToArrowType(dataType schemapb.DataType, dim int) arrow.DataTy return &arrow.FixedSizeBinaryType{ ByteWidth: dim / 8, } + case schemapb.DataType_Float16Vector: + return &arrow.FixedSizeBinaryType{ + ByteWidth: dim * 2, + } default: panic("unsupported data type") } diff --git a/internal/storage/pk_statistics.go b/internal/storage/pk_statistics.go index 0f62ed90a84e2..2278ee22b749c 100644 --- a/internal/storage/pk_statistics.go +++ b/internal/storage/pk_statistics.go @@ -19,9 +19,9 @@ package storage import ( "fmt" + "github.com/bits-and-blooms/bloom/v3" "github.com/cockroachdb/errors" - "github.com/bits-and-blooms/bloom/v3" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/common" ) @@ -102,7 +102,7 @@ func (st *PkStatistics) PkExist(pk PrimaryKey) bool { varCharPk := pk.(*VarCharPrimaryKey) return st.PkFilter.TestString(varCharPk.Value) default: - //TODO:: + // TODO:: } // no idea, just make it as false positive return true diff --git a/internal/storage/primary_key.go b/internal/storage/primary_key.go index afeb4d1f8945c..80f33bad89817 100644 --- a/internal/storage/primary_key.go +++ b/internal/storage/primary_key.go @@ -372,7 +372,7 @@ func ParseIDs2PrimaryKeys(ids *schemapb.IDs) []PrimaryKey { ret = append(ret, pk) } default: - //TODO:: + // TODO:: } return ret @@ -405,7 +405,7 @@ func ParsePrimaryKeys2IDs(pks []PrimaryKey) *schemapb.IDs { }, } default: - //TODO:: + // TODO:: } return ret diff --git a/internal/storage/print_binlog.go b/internal/storage/print_binlog.go index 36f3f85b63420..da3eafc968b18 100644 --- a/internal/storage/print_binlog.go +++ b/internal/storage/print_binlog.go @@ -23,11 +23,11 @@ import ( "github.com/cockroachdb/errors" "github.com/golang/protobuf/proto" - 
"github.com/milvus-io/milvus/pkg/util/tsoutil" "golang.org/x/exp/mmap" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/util/tsoutil" ) // PrintBinlogFiles call printBinlogFile in turn for the file list specified by parameter fileList. @@ -43,7 +43,7 @@ func PrintBinlogFiles(fileList []string) error { // nolint func printBinlogFile(filename string) error { - fd, err := os.OpenFile(filename, os.O_RDONLY, 0400) + fd, err := os.OpenFile(filename, os.O_RDONLY, 0o400) if err != nil { return err } diff --git a/internal/storage/print_binlog_test.go b/internal/storage/print_binlog_test.go index d3d84648b2373..090a90a6fc38c 100644 --- a/internal/storage/print_binlog_test.go +++ b/internal/storage/print_binlog_test.go @@ -18,21 +18,20 @@ package storage import ( "fmt" - "io/ioutil" "os" "testing" "time" "github.com/golang/protobuf/proto" - "github.com/milvus-io/milvus/pkg/common" - "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/etcdpb" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/milvus-io/milvus/pkg/util/uniquegenerator" ) @@ -73,7 +72,7 @@ func TestPrintBinlogFilesInt64(t *testing.T) { assert.NoError(t, err) w.Close() - fd, err := ioutil.TempFile("", "binlog_int64.db") + fd, err := os.CreateTemp("", "binlog_int64.db") defer os.RemoveAll(fd.Name()) assert.NoError(t, err) num, err := fd.Write(buf) @@ -81,7 +80,6 @@ func TestPrintBinlogFilesInt64(t *testing.T) { assert.Equal(t, num, len(buf)) err = fd.Close() assert.NoError(t, err) - } func TestPrintBinlogFiles(t *testing.T) { diff --git a/internal/storage/remote_chunk_manager.go 
b/internal/storage/remote_chunk_manager.go new file mode 100644 index 0000000000000..6ba546a57e8c3 --- /dev/null +++ b/internal/storage/remote_chunk_manager.go @@ -0,0 +1,455 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "bytes" + "container/list" + "context" + "io" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror" + "github.com/cockroachdb/errors" + minio "github.com/minio/minio-go/v7" + "go.uber.org/zap" + "golang.org/x/exp/mmap" + "golang.org/x/sync/errgroup" + + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/timerecord" +) + +const ( + CloudProviderGCP = "gcp" + CloudProviderAWS = "aws" + CloudProviderAliyun = "aliyun" + + CloudProviderAzure = "azure" +) + +type ObjectStorage interface { + GetObject(ctx context.Context, bucketName, objectName string, offset int64, size int64) (FileReader, error) + PutObject(ctx context.Context, bucketName, objectName string, reader io.Reader, objectSize int64) error + StatObject(ctx context.Context, bucketName, objectName string) (int64, error) + 
ListObjects(ctx context.Context, bucketName string, prefix string, recursive bool) (map[string]time.Time, error) + RemoveObject(ctx context.Context, bucketName, objectName string) error +} + +// RemoteChunkManager is responsible for read and write data stored in minio. +type RemoteChunkManager struct { + client ObjectStorage + + // ctx context.Context + bucketName string + rootPath string +} + +var _ ChunkManager = (*RemoteChunkManager)(nil) + +func NewRemoteChunkManager(ctx context.Context, c *config) (*RemoteChunkManager, error) { + var client ObjectStorage + var err error + if c.cloudProvider == CloudProviderAzure { + client, err = newAzureObjectStorageWithConfig(ctx, c) + } else { + client, err = newMinioObjectStorageWithConfig(ctx, c) + } + if err != nil { + return nil, err + } + mcm := &RemoteChunkManager{ + client: client, + bucketName: c.bucketName, + rootPath: strings.TrimLeft(c.rootPath, "/"), + } + log.Info("remote chunk manager init success.", zap.String("remote", c.cloudProvider), zap.String("bucketname", c.bucketName), zap.String("root", mcm.RootPath())) + return mcm, nil +} + +// RootPath returns minio root path. +func (mcm *RemoteChunkManager) RootPath() string { + return mcm.rootPath +} + +// Path returns the path of minio data if exists. +func (mcm *RemoteChunkManager) Path(ctx context.Context, filePath string) (string, error) { + exist, err := mcm.Exist(ctx, filePath) + if err != nil { + return "", err + } + if !exist { + return "", errors.New("minio file manage cannot be found with filePath:" + filePath) + } + return filePath, nil +} + +// Reader returns the path of minio data if exists. 
+func (mcm *RemoteChunkManager) Reader(ctx context.Context, filePath string) (FileReader, error) { + reader, err := mcm.getObject(ctx, mcm.bucketName, filePath, int64(0), int64(0)) + if err != nil { + log.Warn("failed to get object", zap.String("bucket", mcm.bucketName), zap.String("path", filePath), zap.Error(err)) + return nil, err + } + return reader, nil +} + +func (mcm *RemoteChunkManager) Size(ctx context.Context, filePath string) (int64, error) { + objectInfo, err := mcm.getObjectSize(ctx, mcm.bucketName, filePath) + if err != nil { + log.Warn("failed to stat object", zap.String("bucket", mcm.bucketName), zap.String("path", filePath), zap.Error(err)) + return 0, err + } + + return objectInfo, nil +} + +// Write writes the data to minio storage. +func (mcm *RemoteChunkManager) Write(ctx context.Context, filePath string, content []byte) error { + err := mcm.putObject(ctx, mcm.bucketName, filePath, bytes.NewReader(content), int64(len(content))) + if err != nil { + log.Warn("failed to put object", zap.String("bucket", mcm.bucketName), zap.String("path", filePath), zap.Error(err)) + return err + } + + metrics.PersistentDataKvSize.WithLabelValues(metrics.DataPutLabel).Observe(float64(len(content))) + return nil +} + +// MultiWrite saves multiple objects, the path is the key of @kvs. +// The object value is the value of @kvs. +func (mcm *RemoteChunkManager) MultiWrite(ctx context.Context, kvs map[string][]byte) error { + var el error + for key, value := range kvs { + err := mcm.Write(ctx, key, value) + if err != nil { + el = merr.Combine(el, errors.Wrapf(err, "failed to write %s", key)) + } + } + return el +} + +// Exist checks whether chunk is saved to minio storage. 
+func (mcm *RemoteChunkManager) Exist(ctx context.Context, filePath string) (bool, error) { + _, err := mcm.getObjectSize(ctx, mcm.bucketName, filePath) + if err != nil { + if IsErrNoSuchKey(err) { + return false, nil + } + log.Warn("failed to stat object", zap.String("bucket", mcm.bucketName), zap.String("path", filePath), zap.Error(err)) + return false, err + } + return true, nil +} + +// Read reads the minio storage data if exists. +func (mcm *RemoteChunkManager) Read(ctx context.Context, filePath string) ([]byte, error) { + object, err := mcm.getObject(ctx, mcm.bucketName, filePath, int64(0), int64(0)) + if err != nil { + log.Warn("failed to get object", zap.String("bucket", mcm.bucketName), zap.String("path", filePath), zap.Error(err)) + return nil, err + } + defer object.Close() + + // Prefetch object data + var empty []byte + _, err = object.Read(empty) + if err != nil { + errResponse := minio.ToErrorResponse(err) + if errResponse.Code == "NoSuchKey" { + return nil, WrapErrNoSuchKey(filePath) + } + log.Warn("failed to read object", zap.String("path", filePath), zap.Error(err)) + return nil, err + } + size, err := mcm.getObjectSize(ctx, mcm.bucketName, filePath) + if err != nil { + log.Warn("failed to stat object", zap.String("bucket", mcm.bucketName), zap.String("path", filePath), zap.Error(err)) + return nil, err + } + data, err := Read(object, size) + if err != nil { + errResponse := minio.ToErrorResponse(err) + if errResponse.Code == "NoSuchKey" { + return nil, WrapErrNoSuchKey(filePath) + } + log.Warn("failed to read object", zap.String("bucket", mcm.bucketName), zap.String("path", filePath), zap.Error(err)) + return nil, err + } + metrics.PersistentDataKvSize.WithLabelValues(metrics.DataGetLabel).Observe(float64(size)) + return data, nil +} + +func (mcm *RemoteChunkManager) MultiRead(ctx context.Context, keys []string) ([][]byte, error) { + var el error + var objectsValues [][]byte + for _, key := range keys { + objectValue, err := mcm.Read(ctx, key) + 
if err != nil { + el = merr.Combine(el, errors.Wrapf(err, "failed to read %s", key)) + } + objectsValues = append(objectsValues, objectValue) + } + + return objectsValues, el +} + +func (mcm *RemoteChunkManager) ReadWithPrefix(ctx context.Context, prefix string) ([]string, [][]byte, error) { + objectsKeys, _, err := mcm.ListWithPrefix(ctx, prefix, true) + if err != nil { + return nil, nil, err + } + objectsValues, err := mcm.MultiRead(ctx, objectsKeys) + if err != nil { + return nil, nil, err + } + + return objectsKeys, objectsValues, nil +} + +func (mcm *RemoteChunkManager) Mmap(ctx context.Context, filePath string) (*mmap.ReaderAt, error) { + return nil, errors.New("this method has not been implemented") +} + +// ReadAt reads specific position data of minio storage if exists. +func (mcm *RemoteChunkManager) ReadAt(ctx context.Context, filePath string, off int64, length int64) ([]byte, error) { + if off < 0 || length < 0 { + return nil, io.EOF + } + + object, err := mcm.getObject(ctx, mcm.bucketName, filePath, off, length) + if err != nil { + log.Warn("failed to get object", zap.String("bucket", mcm.bucketName), zap.String("path", filePath), zap.Error(err)) + return nil, err + } + defer object.Close() + + data, err := Read(object, length) + if err != nil { + errResponse := minio.ToErrorResponse(err) + if errResponse.Code == "NoSuchKey" { + return nil, WrapErrNoSuchKey(filePath) + } + log.Warn("failed to read object", zap.String("bucket", mcm.bucketName), zap.String("path", filePath), zap.Error(err)) + return nil, err + } + metrics.PersistentDataKvSize.WithLabelValues(metrics.DataGetLabel).Observe(float64(length)) + return data, nil +} + +// Remove deletes an object with @key. 
+func (mcm *RemoteChunkManager) Remove(ctx context.Context, filePath string) error { + err := mcm.removeObject(ctx, mcm.bucketName, filePath) + if err != nil { + log.Warn("failed to remove object", zap.String("bucket", mcm.bucketName), zap.String("path", filePath), zap.Error(err)) + return err + } + return nil +} + +// MultiRemove deletes a objects with @keys. +func (mcm *RemoteChunkManager) MultiRemove(ctx context.Context, keys []string) error { + var el error + for _, key := range keys { + err := mcm.Remove(ctx, key) + if err != nil { + el = merr.Combine(el, errors.Wrapf(err, "failed to remove %s", key)) + } + } + return el +} + +// RemoveWithPrefix removes all objects with the same prefix @prefix from minio. +func (mcm *RemoteChunkManager) RemoveWithPrefix(ctx context.Context, prefix string) error { + objects, err := mcm.listObjects(ctx, mcm.bucketName, prefix, true) + if err != nil { + return err + } + removeKeys := make([]string, 0) + for key := range objects { + removeKeys = append(removeKeys, key) + } + i := 0 + maxGoroutine := 10 + for i < len(removeKeys) { + runningGroup, groupCtx := errgroup.WithContext(ctx) + for j := 0; j < maxGoroutine && i < len(removeKeys); j++ { + key := removeKeys[i] + runningGroup.Go(func() error { + err := mcm.removeObject(groupCtx, mcm.bucketName, key) + if err != nil { + log.Warn("failed to remove object", zap.String("path", key), zap.Error(err)) + return err + } + return nil + }) + i++ + } + if err := runningGroup.Wait(); err != nil { + return err + } + } + return nil +} + +// ListWithPrefix returns objects with provided prefix. +// by default, if `recursive`=false, list object with return object with path under save level +// say minio has followinng objects: [a, ab, a/b, ab/c] +// calling `ListWithPrefix` with `prefix` = a && `recursive` = false will only returns [a, ab] +// If caller needs all objects without level limitation, `recursive` shall be true. 
+func (mcm *RemoteChunkManager) ListWithPrefix(ctx context.Context, prefix string, recursive bool) ([]string, []time.Time, error) { + // cannot use ListObjects(ctx, bucketName, Opt{Prefix:prefix, Recursive:true}) + // if minio has lots of objects under the provided path + // recursive = true may timeout during the recursive browsing the objects. + // See also: https://github.com/milvus-io/milvus/issues/19095 + + var objectsKeys []string + var modTimes []time.Time + + tasks := list.New() + tasks.PushBack(prefix) + for tasks.Len() > 0 { + e := tasks.Front() + pre := e.Value.(string) + tasks.Remove(e) + + // TODO add concurrent call if performance matters + // only return current level per call + objects, err := mcm.listObjects(ctx, mcm.bucketName, pre, false) + if err != nil { + return nil, nil, err + } + + for object, lastModified := range objects { + // with tailing "/", object is a "directory" + if strings.HasSuffix(object, "/") && recursive { + // enqueue when recursive is true + if object != pre { + tasks.PushBack(object) + } + continue + } + objectsKeys = append(objectsKeys, object) + modTimes = append(modTimes, lastModified) + } + } + + return objectsKeys, modTimes, nil +} + +func (mcm *RemoteChunkManager) getObject(ctx context.Context, bucketName, objectName string, + offset int64, size int64, +) (FileReader, error) { + start := timerecord.NewTimeRecorder("getObject") + + reader, err := mcm.client.GetObject(ctx, bucketName, objectName, offset, size) + metrics.PersistentDataOpCounter.WithLabelValues(metrics.DataGetLabel, metrics.TotalLabel).Inc() + if err == nil && reader != nil { + metrics.PersistentDataRequestLatency.WithLabelValues(metrics.DataGetLabel).Observe(float64(start.ElapseSpan().Milliseconds())) + metrics.PersistentDataOpCounter.WithLabelValues(metrics.DataGetLabel, metrics.SuccessLabel).Inc() + } else { + metrics.PersistentDataOpCounter.WithLabelValues(metrics.DataGetLabel, metrics.FailLabel).Inc() + } + + switch err := err.(type) { + case 
*azcore.ResponseError: + if err.ErrorCode == string(bloberror.BlobNotFound) { + return nil, WrapErrNoSuchKey(objectName) + } + case minio.ErrorResponse: + if err.Code == "NoSuchKey" { + return nil, WrapErrNoSuchKey(objectName) + } + } + + return reader, err +} + +func (mcm *RemoteChunkManager) putObject(ctx context.Context, bucketName, objectName string, reader io.Reader, objectSize int64) error { + start := timerecord.NewTimeRecorder("putObject") + + err := mcm.client.PutObject(ctx, bucketName, objectName, reader, objectSize) + metrics.PersistentDataOpCounter.WithLabelValues(metrics.DataPutLabel, metrics.TotalLabel).Inc() + if err == nil { + metrics.PersistentDataRequestLatency.WithLabelValues(metrics.DataPutLabel).Observe(float64(start.ElapseSpan().Milliseconds())) + metrics.PersistentDataOpCounter.WithLabelValues(metrics.DataPutLabel, metrics.SuccessLabel).Inc() + } else { + metrics.PersistentDataOpCounter.WithLabelValues(metrics.DataPutLabel, metrics.FailLabel).Inc() + } + + return err +} + +func (mcm *RemoteChunkManager) getObjectSize(ctx context.Context, bucketName, objectName string) (int64, error) { + start := timerecord.NewTimeRecorder("getObjectSize") + + info, err := mcm.client.StatObject(ctx, bucketName, objectName) + metrics.PersistentDataOpCounter.WithLabelValues(metrics.DataStatLabel, metrics.TotalLabel).Inc() + if err == nil { + metrics.PersistentDataRequestLatency.WithLabelValues(metrics.DataStatLabel).Observe(float64(start.ElapseSpan().Milliseconds())) + metrics.PersistentDataOpCounter.WithLabelValues(metrics.DataStatLabel, metrics.SuccessLabel).Inc() + } else { + metrics.PersistentDataOpCounter.WithLabelValues(metrics.DataStatLabel, metrics.FailLabel).Inc() + } + + switch err := err.(type) { + case *azcore.ResponseError: + if err.ErrorCode == string(bloberror.BlobNotFound) { + return info, WrapErrNoSuchKey(objectName) + } + case minio.ErrorResponse: + if err.Code == "NoSuchKey" { + return info, WrapErrNoSuchKey(objectName) + } + } + + return
info, err +} + +func (mcm *RemoteChunkManager) listObjects(ctx context.Context, bucketName string, prefix string, recursive bool) (map[string]time.Time, error) { + start := timerecord.NewTimeRecorder("listObjects") + + res, err := mcm.client.ListObjects(ctx, bucketName, prefix, recursive) + metrics.PersistentDataOpCounter.WithLabelValues(metrics.DataListLabel, metrics.TotalLabel).Inc() + if err == nil { + metrics.PersistentDataRequestLatency.WithLabelValues(metrics.DataListLabel).Observe(float64(start.ElapseSpan().Milliseconds())) + metrics.PersistentDataOpCounter.WithLabelValues(metrics.DataListLabel, metrics.SuccessLabel).Inc() + } else { + log.Warn("failed to list with prefix", zap.String("bucket", mcm.bucketName), zap.String("prefix", prefix), zap.Error(err)) + metrics.PersistentDataOpCounter.WithLabelValues(metrics.DataListLabel, metrics.FailLabel).Inc() + } + return res, err +} + +func (mcm *RemoteChunkManager) removeObject(ctx context.Context, bucketName, objectName string) error { + start := timerecord.NewTimeRecorder("removeObject") + + err := mcm.client.RemoveObject(ctx, bucketName, objectName) + metrics.PersistentDataOpCounter.WithLabelValues(metrics.DataRemoveLabel, metrics.TotalLabel).Inc() + if err == nil { + metrics.PersistentDataRequestLatency.WithLabelValues(metrics.DataRemoveLabel).Observe(float64(start.ElapseSpan().Milliseconds())) + metrics.PersistentDataOpCounter.WithLabelValues(metrics.DataRemoveLabel, metrics.SuccessLabel).Inc() + } else { + metrics.PersistentDataOpCounter.WithLabelValues(metrics.DataRemoveLabel, metrics.FailLabel).Inc() + } + + return err +} diff --git a/internal/storage/remote_chunk_manager_test.go b/internal/storage/remote_chunk_manager_test.go new file mode 100644 index 0000000000000..527fe4225bbbc --- /dev/null +++ b/internal/storage/remote_chunk_manager_test.go @@ -0,0 +1,971 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "context" + "path" + "testing" + + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TODO: NewRemoteChunkManager is deprecated. Rewrite this unittest. +func newMinioChunkManager(ctx context.Context, bucketName string, rootPath string) (ChunkManager, error) { + return newRemoteChunkManager(ctx, "minio", bucketName, rootPath) +} + +func newAzureChunkManager(ctx context.Context, bucketName string, rootPath string) (ChunkManager, error) { + return newRemoteChunkManager(ctx, "azure", bucketName, rootPath) +} + +func newRemoteChunkManager(ctx context.Context, cloudProvider string, bucketName string, rootPath string) (ChunkManager, error) { + factory := NewChunkManagerFactory("remote", + RootPath(rootPath), + Address(Params.MinioCfg.Address.GetValue()), + AccessKeyID(Params.MinioCfg.AccessKeyID.GetValue()), + SecretAccessKeyID(Params.MinioCfg.SecretAccessKey.GetValue()), + UseSSL(Params.MinioCfg.UseSSL.GetAsBool()), + BucketName(bucketName), + UseIAM(Params.MinioCfg.UseIAM.GetAsBool()), + CloudProvider(cloudProvider), + IAMEndpoint(Params.MinioCfg.IAMEndpoint.GetValue()), + CreateBucket(true)) + return factory.NewPersistentStorageChunkManager(ctx) +} + +func 
TestInitRemoteChunkManager(t *testing.T) { + ctx := context.Background() + client, err := NewRemoteChunkManager(ctx, &config{ + bucketName: Params.MinioCfg.BucketName.GetValue(), + createBucket: true, + useIAM: false, + cloudProvider: "azure", + }) + assert.NoError(t, err) + assert.NotNil(t, client) +} + +func TestMinioChunkManager(t *testing.T) { + testBucket := Params.MinioCfg.BucketName.GetValue() + + configRoot := Params.MinioCfg.RootPath.GetValue() + + testMinIOKVRoot := path.Join(configRoot, "milvus-minio-ut-root") + + t.Run("test load", func(t *testing.T) { + testLoadRoot := path.Join(testMinIOKVRoot, "test_load") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + testCM, err := newMinioChunkManager(ctx, testBucket, testLoadRoot) + require.NoError(t, err) + defer testCM.RemoveWithPrefix(ctx, testLoadRoot) + + assert.Equal(t, testLoadRoot, testCM.RootPath()) + + prepareTests := []struct { + key string + value []byte + }{ + {"abc", []byte("123")}, + {"abcd", []byte("1234")}, + {"key_1", []byte("111")}, + {"key_2", []byte("222")}, + {"key_3", []byte("333")}, + } + + for _, test := range prepareTests { + err = testCM.Write(ctx, path.Join(testLoadRoot, test.key), test.value) + require.NoError(t, err) + } + + loadTests := []struct { + isvalid bool + loadKey string + expectedValue []byte + + description string + }{ + {true, "abc", []byte("123"), "load valid key abc"}, + {true, "abcd", []byte("1234"), "load valid key abcd"}, + {true, "key_1", []byte("111"), "load valid key key_1"}, + {true, "key_2", []byte("222"), "load valid key key_2"}, + {true, "key_3", []byte("333"), "load valid key key_3"}, + {false, "key_not_exist", []byte(""), "load invalid key key_not_exist"}, + {false, "/", []byte(""), "load leading slash"}, + } + + for _, test := range loadTests { + t.Run(test.description, func(t *testing.T) { + if test.isvalid { + got, err := testCM.Read(ctx, path.Join(testLoadRoot, test.loadKey)) + assert.NoError(t, err) + assert.Equal(t, 
test.expectedValue, got) + } else { + if test.loadKey == "/" { + got, err := testCM.Read(ctx, test.loadKey) + assert.Error(t, err) + assert.Empty(t, got) + return + } + got, err := testCM.Read(ctx, path.Join(testLoadRoot, test.loadKey)) + assert.Error(t, err) + assert.Empty(t, got) + } + }) + } + + loadWithPrefixTests := []struct { + isvalid bool + prefix string + expectedValue [][]byte + + description string + }{ + {true, "abc", [][]byte{[]byte("123"), []byte("1234")}, "load with valid prefix abc"}, + {true, "key_", [][]byte{[]byte("111"), []byte("222"), []byte("333")}, "load with valid prefix key_"}, + {true, "prefix", [][]byte{}, "load with valid but not exist prefix prefix"}, + } + + for _, test := range loadWithPrefixTests { + t.Run(test.description, func(t *testing.T) { + gotk, gotv, err := testCM.ReadWithPrefix(ctx, path.Join(testLoadRoot, test.prefix)) + assert.NoError(t, err) + assert.Equal(t, len(test.expectedValue), len(gotk)) + assert.Equal(t, len(test.expectedValue), len(gotv)) + assert.ElementsMatch(t, test.expectedValue, gotv) + }) + } + + multiLoadTests := []struct { + isvalid bool + multiKeys []string + + expectedValue [][]byte + description string + }{ + {false, []string{"key_1", "key_not_exist"}, [][]byte{[]byte("111"), nil}, "multiload 1 exist 1 not"}, + {true, []string{"abc", "key_3"}, [][]byte{[]byte("123"), []byte("333")}, "multiload 2 exist"}, + } + + for _, test := range multiLoadTests { + t.Run(test.description, func(t *testing.T) { + for i := range test.multiKeys { + test.multiKeys[i] = path.Join(testLoadRoot, test.multiKeys[i]) + } + if test.isvalid { + got, err := testCM.MultiRead(ctx, test.multiKeys) + assert.NoError(t, err) + assert.Equal(t, test.expectedValue, got) + } else { + got, err := testCM.MultiRead(ctx, test.multiKeys) + assert.Error(t, err) + assert.Equal(t, test.expectedValue, got) + } + }) + } + }) + + t.Run("test MultiSave", func(t *testing.T) { + testMultiSaveRoot := path.Join(testMinIOKVRoot, "test_multisave") + + ctx, 
cancel := context.WithCancel(context.Background()) + defer cancel() + + testCM, err := newMinioChunkManager(ctx, testBucket, testMultiSaveRoot) + assert.NoError(t, err) + defer testCM.RemoveWithPrefix(ctx, testMultiSaveRoot) + + err = testCM.Write(ctx, path.Join(testMultiSaveRoot, "key_1"), []byte("111")) + assert.NoError(t, err) + + kvs := map[string][]byte{ + path.Join(testMultiSaveRoot, "key_1"): []byte("123"), + path.Join(testMultiSaveRoot, "key_2"): []byte("456"), + } + + err = testCM.MultiWrite(ctx, kvs) + assert.NoError(t, err) + + val, err := testCM.Read(ctx, path.Join(testMultiSaveRoot, "key_1")) + assert.NoError(t, err) + assert.Equal(t, []byte("123"), val) + + reader, err := testCM.Reader(ctx, path.Join(testMultiSaveRoot, "key_1")) + assert.NoError(t, err) + reader.Close() + }) + + t.Run("test Remove", func(t *testing.T) { + testRemoveRoot := path.Join(testMinIOKVRoot, "test_remove") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + testCM, err := newMinioChunkManager(ctx, testBucket, testRemoveRoot) + assert.NoError(t, err) + defer testCM.RemoveWithPrefix(ctx, testRemoveRoot) + + prepareTests := []struct { + k string + v []byte + }{ + {"key_1", []byte("123")}, + {"key_2", []byte("456")}, + {"mkey_1", []byte("111")}, + {"mkey_2", []byte("222")}, + {"mkey_3", []byte("333")}, + {"key_prefix_1", []byte("111")}, + {"key_prefix_2", []byte("222")}, + {"key_prefix_3", []byte("333")}, + } + + for _, test := range prepareTests { + k := path.Join(testRemoveRoot, test.k) + err = testCM.Write(ctx, k, test.v) + require.NoError(t, err) + } + + removeTests := []struct { + removeKey string + valueBeforeRemove []byte + + description string + }{ + {"key_1", []byte("123"), "remove key_1"}, + {"key_2", []byte("456"), "remove key_2"}, + } + + for _, test := range removeTests { + t.Run(test.description, func(t *testing.T) { + k := path.Join(testRemoveRoot, test.removeKey) + v, err := testCM.Read(ctx, k) + require.NoError(t, err) + require.Equal(t, 
test.valueBeforeRemove, v) + + err = testCM.Remove(ctx, k) + assert.NoError(t, err) + + v, err = testCM.Read(ctx, k) + require.Error(t, err) + require.Empty(t, v) + }) + } + + multiRemoveTest := []string{ + path.Join(testRemoveRoot, "mkey_1"), + path.Join(testRemoveRoot, "mkey_2"), + path.Join(testRemoveRoot, "mkey_3"), + } + + lv, err := testCM.MultiRead(ctx, multiRemoveTest) + require.NoError(t, err) + require.ElementsMatch(t, [][]byte{[]byte("111"), []byte("222"), []byte("333")}, lv) + + err = testCM.MultiRemove(ctx, multiRemoveTest) + assert.NoError(t, err) + + for _, k := range multiRemoveTest { + v, err := testCM.Read(ctx, k) + assert.Error(t, err) + assert.Empty(t, v) + } + + removeWithPrefixTest := []string{ + path.Join(testRemoveRoot, "key_prefix_1"), + path.Join(testRemoveRoot, "key_prefix_2"), + path.Join(testRemoveRoot, "key_prefix_3"), + } + removePrefix := path.Join(testRemoveRoot, "key_prefix") + + lv, err = testCM.MultiRead(ctx, removeWithPrefixTest) + require.NoError(t, err) + require.ElementsMatch(t, [][]byte{[]byte("111"), []byte("222"), []byte("333")}, lv) + + err = testCM.RemoveWithPrefix(ctx, removePrefix) + assert.NoError(t, err) + + for _, k := range removeWithPrefixTest { + v, err := testCM.Read(ctx, k) + assert.Error(t, err) + assert.Empty(t, v) + } + }) + + t.Run("test ReadAt", func(t *testing.T) { + testLoadPartialRoot := path.Join(testMinIOKVRoot, "load_partial") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + testCM, err := newMinioChunkManager(ctx, testBucket, testLoadPartialRoot) + require.NoError(t, err) + defer testCM.RemoveWithPrefix(ctx, testLoadPartialRoot) + + key := path.Join(testLoadPartialRoot, "TestMinIOKV_LoadPartial_key") + value := []byte("TestMinIOKV_LoadPartial_value") + + err = testCM.Write(ctx, key, value) + assert.NoError(t, err) + + var off, length int64 + var partial []byte + + off, length = 1, 1 + partial, err = testCM.ReadAt(ctx, key, off, length) + assert.NoError(t, err) + 
assert.ElementsMatch(t, partial, value[off:off+length]) + + off, length = 0, int64(len(value)) + partial, err = testCM.ReadAt(ctx, key, off, length) + assert.NoError(t, err) + assert.ElementsMatch(t, partial, value[off:off+length]) + + // error case + off, length = 5, -2 + _, err = testCM.ReadAt(ctx, key, off, length) + assert.Error(t, err) + + off, length = -1, 2 + _, err = testCM.ReadAt(ctx, key, off, length) + assert.Error(t, err) + + off, length = 1, -2 + _, err = testCM.ReadAt(ctx, key, off, length) + assert.Error(t, err) + + err = testCM.Remove(ctx, key) + assert.NoError(t, err) + off, length = 1, 1 + _, err = testCM.ReadAt(ctx, key, off, length) + assert.Error(t, err) + }) + + t.Run("test Size", func(t *testing.T) { + testGetSizeRoot := path.Join(testMinIOKVRoot, "get_size") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + testCM, err := newMinioChunkManager(ctx, testBucket, testGetSizeRoot) + require.NoError(t, err) + defer testCM.RemoveWithPrefix(ctx, testGetSizeRoot) + + key := path.Join(testGetSizeRoot, "TestMinIOKV_GetSize_key") + value := []byte("TestMinIOKV_GetSize_value") + + err = testCM.Write(ctx, key, value) + assert.NoError(t, err) + + size, err := testCM.Size(ctx, key) + assert.NoError(t, err) + assert.Equal(t, size, int64(len(value))) + + key2 := path.Join(testGetSizeRoot, "TestMemoryKV_GetSize_key2") + + size, err = testCM.Size(ctx, key2) + assert.Error(t, err) + assert.Equal(t, int64(0), size) + }) + + t.Run("test Path", func(t *testing.T) { + testGetPathRoot := path.Join(testMinIOKVRoot, "get_path") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + testCM, err := newMinioChunkManager(ctx, testBucket, testGetPathRoot) + require.NoError(t, err) + defer testCM.RemoveWithPrefix(ctx, testGetPathRoot) + + key := path.Join(testGetPathRoot, "TestMinIOKV_GetSize_key") + value := []byte("TestMinIOKV_GetSize_value") + + err = testCM.Write(ctx, key, value) + assert.NoError(t, err) + + p, err := 
testCM.Path(ctx, key) + assert.NoError(t, err) + assert.Equal(t, p, key) + + key2 := path.Join(testGetPathRoot, "TestMemoryKV_GetSize_key2") + + p, err = testCM.Path(ctx, key2) + assert.Error(t, err) + assert.Equal(t, p, "") + }) + + t.Run("test Mmap", func(t *testing.T) { + testMmapRoot := path.Join(testMinIOKVRoot, "mmap") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + testCM, err := newMinioChunkManager(ctx, testBucket, testMmapRoot) + require.NoError(t, err) + defer testCM.RemoveWithPrefix(ctx, testMmapRoot) + + key := path.Join(testMmapRoot, "TestMinIOKV_GetSize_key") + value := []byte("TestMinIOKV_GetSize_value") + + err = testCM.Write(ctx, key, value) + assert.NoError(t, err) + + r, err := testCM.Mmap(ctx, key) + assert.Error(t, err) + assert.Nil(t, r) + }) + + t.Run("test Prefix", func(t *testing.T) { + testPrefix := path.Join(testMinIOKVRoot, "prefix") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + testCM, err := newMinioChunkManager(ctx, testBucket, testPrefix) + require.NoError(t, err) + defer testCM.RemoveWithPrefix(ctx, testPrefix) + + pathB := path.Join("a", "b") + + key := path.Join(testPrefix, pathB) + value := []byte("a") + + err = testCM.Write(ctx, key, value) + assert.NoError(t, err) + + pathC := path.Join("a", "c") + key = path.Join(testPrefix, pathC) + err = testCM.Write(ctx, key, value) + assert.NoError(t, err) + + pathPrefix := path.Join(testPrefix, "a") + r, m, err := testCM.ListWithPrefix(ctx, pathPrefix, true) + assert.NoError(t, err) + assert.Equal(t, len(r), 2) + assert.Equal(t, len(m), 2) + + key = path.Join(testPrefix, "b", "b", "b") + err = testCM.Write(ctx, key, value) + assert.NoError(t, err) + + key = path.Join(testPrefix, "b", "a", "b") + err = testCM.Write(ctx, key, value) + assert.NoError(t, err) + + key = path.Join(testPrefix, "bc", "a", "b") + err = testCM.Write(ctx, key, value) + assert.NoError(t, err) + dirs, mods, err := testCM.ListWithPrefix(ctx, 
testPrefix+"/", true) + assert.NoError(t, err) + assert.Equal(t, 5, len(dirs)) + assert.Equal(t, 5, len(mods)) + + dirs, mods, err = testCM.ListWithPrefix(ctx, path.Join(testPrefix, "b"), true) + assert.NoError(t, err) + assert.Equal(t, 3, len(dirs)) + assert.Equal(t, 3, len(mods)) + + testCM.RemoveWithPrefix(ctx, testPrefix) + r, m, err = testCM.ListWithPrefix(ctx, pathPrefix, true) + assert.NoError(t, err) + assert.Equal(t, 0, len(r)) + assert.Equal(t, 0, len(m)) + + // test wrong prefix + b := make([]byte, 2048) + pathWrong := path.Join(testPrefix, string(b)) + _, _, err = testCM.ListWithPrefix(ctx, pathWrong, true) + assert.Error(t, err) + }) + + t.Run("test NoSuchKey", func(t *testing.T) { + testPrefix := path.Join(testMinIOKVRoot, "nokey") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + testCM, err := newMinioChunkManager(ctx, testBucket, testPrefix) + require.NoError(t, err) + defer testCM.RemoveWithPrefix(ctx, testPrefix) + + key := "a" + + _, err = testCM.Read(ctx, key) + assert.Error(t, err) + assert.True(t, errors.Is(err, ErrNoSuchKey)) + + file, err := testCM.Reader(ctx, key) + assert.NoError(t, err) // todo + file.Close() + + _, err = testCM.ReadAt(ctx, key, 100, 1) + assert.Error(t, err) + assert.True(t, errors.Is(err, ErrNoSuchKey)) + }) +} + +func TestAzureChunkManager(t *testing.T) { + testBucket := Params.MinioCfg.BucketName.GetValue() + + configRoot := Params.MinioCfg.RootPath.GetValue() + + testMinIOKVRoot := path.Join(configRoot, "milvus-minio-ut-root") + + t.Run("test load", func(t *testing.T) { + testLoadRoot := path.Join(testMinIOKVRoot, "test_load") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + testCM, err := newAzureChunkManager(ctx, testBucket, testLoadRoot) + require.NoError(t, err) + defer testCM.RemoveWithPrefix(ctx, testLoadRoot) + + assert.Equal(t, testLoadRoot, testCM.RootPath()) + + prepareTests := []struct { + key string + value []byte + }{ + {"abc", []byte("123")}, + 
{"abcd", []byte("1234")}, + {"key_1", []byte("111")}, + {"key_2", []byte("222")}, + {"key_3", []byte("333")}, + } + + for _, test := range prepareTests { + err = testCM.Write(ctx, path.Join(testLoadRoot, test.key), test.value) + require.NoError(t, err) + } + + loadTests := []struct { + isvalid bool + loadKey string + expectedValue []byte + + description string + }{ + {true, "abc", []byte("123"), "load valid key abc"}, + {true, "abcd", []byte("1234"), "load valid key abcd"}, + {true, "key_1", []byte("111"), "load valid key key_1"}, + {true, "key_2", []byte("222"), "load valid key key_2"}, + {true, "key_3", []byte("333"), "load valid key key_3"}, + {false, "key_not_exist", []byte(""), "load invalid key key_not_exist"}, + {false, "/", []byte(""), "load leading slash"}, + } + + for _, test := range loadTests { + t.Run(test.description, func(t *testing.T) { + if test.isvalid { + got, err := testCM.Read(ctx, path.Join(testLoadRoot, test.loadKey)) + assert.NoError(t, err) + assert.Equal(t, test.expectedValue, got) + } else { + if test.loadKey == "/" { + got, err := testCM.Read(ctx, test.loadKey) + assert.Error(t, err) + assert.Empty(t, got) + return + } + got, err := testCM.Read(ctx, path.Join(testLoadRoot, test.loadKey)) + assert.Error(t, err) + assert.Empty(t, got) + } + }) + } + + loadWithPrefixTests := []struct { + isvalid bool + prefix string + expectedValue [][]byte + + description string + }{ + {true, "abc", [][]byte{[]byte("123"), []byte("1234")}, "load with valid prefix abc"}, + {true, "key_", [][]byte{[]byte("111"), []byte("222"), []byte("333")}, "load with valid prefix key_"}, + {true, "prefix", [][]byte{}, "load with valid but not exist prefix prefix"}, + } + + for _, test := range loadWithPrefixTests { + t.Run(test.description, func(t *testing.T) { + gotk, gotv, err := testCM.ReadWithPrefix(ctx, path.Join(testLoadRoot, test.prefix)) + assert.NoError(t, err) + assert.Equal(t, len(test.expectedValue), len(gotk)) + assert.Equal(t, len(test.expectedValue), 
len(gotv)) + assert.ElementsMatch(t, test.expectedValue, gotv) + }) + } + + multiLoadTests := []struct { + isvalid bool + multiKeys []string + + expectedValue [][]byte + description string + }{ + {false, []string{"key_1", "key_not_exist"}, [][]byte{[]byte("111"), nil}, "multiload 1 exist 1 not"}, + {true, []string{"abc", "key_3"}, [][]byte{[]byte("123"), []byte("333")}, "multiload 2 exist"}, + } + + for _, test := range multiLoadTests { + t.Run(test.description, func(t *testing.T) { + for i := range test.multiKeys { + test.multiKeys[i] = path.Join(testLoadRoot, test.multiKeys[i]) + } + if test.isvalid { + got, err := testCM.MultiRead(ctx, test.multiKeys) + assert.NoError(t, err) + assert.Equal(t, test.expectedValue, got) + } else { + got, err := testCM.MultiRead(ctx, test.multiKeys) + assert.Error(t, err) + assert.Equal(t, test.expectedValue, got) + } + }) + } + }) + + t.Run("test MultiSave", func(t *testing.T) { + testMultiSaveRoot := path.Join(testMinIOKVRoot, "test_multisave") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + testCM, err := newAzureChunkManager(ctx, testBucket, testMultiSaveRoot) + assert.NoError(t, err) + defer testCM.RemoveWithPrefix(ctx, testMultiSaveRoot) + + err = testCM.Write(ctx, path.Join(testMultiSaveRoot, "key_1"), []byte("111")) + assert.NoError(t, err) + + kvs := map[string][]byte{ + path.Join(testMultiSaveRoot, "key_1"): []byte("123"), + path.Join(testMultiSaveRoot, "key_2"): []byte("456"), + } + + err = testCM.MultiWrite(ctx, kvs) + assert.NoError(t, err) + + val, err := testCM.Read(ctx, path.Join(testMultiSaveRoot, "key_1")) + assert.NoError(t, err) + assert.Equal(t, []byte("123"), val) + + reader, err := testCM.Reader(ctx, path.Join(testMultiSaveRoot, "key_1")) + assert.NoError(t, err) + reader.Close() + }) + + t.Run("test Remove", func(t *testing.T) { + testRemoveRoot := path.Join(testMinIOKVRoot, "test_remove") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + testCM, 
err := newAzureChunkManager(ctx, testBucket, testRemoveRoot) + assert.NoError(t, err) + defer testCM.RemoveWithPrefix(ctx, testRemoveRoot) + + prepareTests := []struct { + k string + v []byte + }{ + {"key_1", []byte("123")}, + {"key_2", []byte("456")}, + {"mkey_1", []byte("111")}, + {"mkey_2", []byte("222")}, + {"mkey_3", []byte("333")}, + {"key_prefix_1", []byte("111")}, + {"key_prefix_2", []byte("222")}, + {"key_prefix_3", []byte("333")}, + } + + for _, test := range prepareTests { + k := path.Join(testRemoveRoot, test.k) + err = testCM.Write(ctx, k, test.v) + require.NoError(t, err) + } + + removeTests := []struct { + removeKey string + valueBeforeRemove []byte + + description string + }{ + {"key_1", []byte("123"), "remove key_1"}, + {"key_2", []byte("456"), "remove key_2"}, + } + + for _, test := range removeTests { + t.Run(test.description, func(t *testing.T) { + k := path.Join(testRemoveRoot, test.removeKey) + v, err := testCM.Read(ctx, k) + require.NoError(t, err) + require.Equal(t, test.valueBeforeRemove, v) + + err = testCM.Remove(ctx, k) + assert.NoError(t, err) + + v, err = testCM.Read(ctx, k) + require.Error(t, err) + require.Empty(t, v) + }) + } + + multiRemoveTest := []string{ + path.Join(testRemoveRoot, "mkey_1"), + path.Join(testRemoveRoot, "mkey_2"), + path.Join(testRemoveRoot, "mkey_3"), + } + + lv, err := testCM.MultiRead(ctx, multiRemoveTest) + require.NoError(t, err) + require.ElementsMatch(t, [][]byte{[]byte("111"), []byte("222"), []byte("333")}, lv) + + err = testCM.MultiRemove(ctx, multiRemoveTest) + assert.NoError(t, err) + + for _, k := range multiRemoveTest { + v, err := testCM.Read(ctx, k) + assert.Error(t, err) + assert.Empty(t, v) + } + + removeWithPrefixTest := []string{ + path.Join(testRemoveRoot, "key_prefix_1"), + path.Join(testRemoveRoot, "key_prefix_2"), + path.Join(testRemoveRoot, "key_prefix_3"), + } + removePrefix := path.Join(testRemoveRoot, "key_prefix") + + lv, err = testCM.MultiRead(ctx, removeWithPrefixTest) + 
require.NoError(t, err) + require.ElementsMatch(t, [][]byte{[]byte("111"), []byte("222"), []byte("333")}, lv) + + err = testCM.RemoveWithPrefix(ctx, removePrefix) + assert.NoError(t, err) + + for _, k := range removeWithPrefixTest { + v, err := testCM.Read(ctx, k) + assert.Error(t, err) + assert.Empty(t, v) + } + }) + + t.Run("test ReadAt", func(t *testing.T) { + testLoadPartialRoot := path.Join(testMinIOKVRoot, "load_partial") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + testCM, err := newAzureChunkManager(ctx, testBucket, testLoadPartialRoot) + require.NoError(t, err) + defer testCM.RemoveWithPrefix(ctx, testLoadPartialRoot) + + key := path.Join(testLoadPartialRoot, "TestMinIOKV_LoadPartial_key") + value := []byte("TestMinIOKV_LoadPartial_value") + + err = testCM.Write(ctx, key, value) + assert.NoError(t, err) + + var off, length int64 + var partial []byte + + off, length = 1, 1 + partial, err = testCM.ReadAt(ctx, key, off, length) + assert.NoError(t, err) + assert.ElementsMatch(t, partial, value[off:off+length]) + + off, length = 0, int64(len(value)) + partial, err = testCM.ReadAt(ctx, key, off, length) + assert.NoError(t, err) + assert.ElementsMatch(t, partial, value[off:off+length]) + + // error case + off, length = 5, -2 + _, err = testCM.ReadAt(ctx, key, off, length) + assert.Error(t, err) + + off, length = -1, 2 + _, err = testCM.ReadAt(ctx, key, off, length) + assert.Error(t, err) + + off, length = 1, -2 + _, err = testCM.ReadAt(ctx, key, off, length) + assert.Error(t, err) + + err = testCM.Remove(ctx, key) + assert.NoError(t, err) + off, length = 1, 1 + _, err = testCM.ReadAt(ctx, key, off, length) + assert.Error(t, err) + }) + + t.Run("test Size", func(t *testing.T) { + testGetSizeRoot := path.Join(testMinIOKVRoot, "get_size") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + testCM, err := newAzureChunkManager(ctx, testBucket, testGetSizeRoot) + require.NoError(t, err) + defer 
testCM.RemoveWithPrefix(ctx, testGetSizeRoot) + + key := path.Join(testGetSizeRoot, "TestMinIOKV_GetSize_key") + value := []byte("TestMinIOKV_GetSize_value") + + err = testCM.Write(ctx, key, value) + assert.NoError(t, err) + + size, err := testCM.Size(ctx, key) + assert.NoError(t, err) + assert.Equal(t, size, int64(len(value))) + + key2 := path.Join(testGetSizeRoot, "TestMemoryKV_GetSize_key2") + + size, err = testCM.Size(ctx, key2) + assert.Error(t, err) + assert.Equal(t, int64(0), size) + }) + + t.Run("test Path", func(t *testing.T) { + testGetPathRoot := path.Join(testMinIOKVRoot, "get_path") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + testCM, err := newAzureChunkManager(ctx, testBucket, testGetPathRoot) + require.NoError(t, err) + defer testCM.RemoveWithPrefix(ctx, testGetPathRoot) + + key := path.Join(testGetPathRoot, "TestMinIOKV_GetSize_key") + value := []byte("TestMinIOKV_GetSize_value") + + err = testCM.Write(ctx, key, value) + assert.NoError(t, err) + + p, err := testCM.Path(ctx, key) + assert.NoError(t, err) + assert.Equal(t, p, key) + + key2 := path.Join(testGetPathRoot, "TestMemoryKV_GetSize_key2") + + p, err = testCM.Path(ctx, key2) + assert.Error(t, err) + assert.Equal(t, p, "") + }) + + t.Run("test Mmap", func(t *testing.T) { + testMmapRoot := path.Join(testMinIOKVRoot, "mmap") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + testCM, err := newAzureChunkManager(ctx, testBucket, testMmapRoot) + require.NoError(t, err) + defer testCM.RemoveWithPrefix(ctx, testMmapRoot) + + key := path.Join(testMmapRoot, "TestMinIOKV_GetSize_key") + value := []byte("TestMinIOKV_GetSize_value") + + err = testCM.Write(ctx, key, value) + assert.NoError(t, err) + + r, err := testCM.Mmap(ctx, key) + assert.Error(t, err) + assert.Nil(t, r) + }) + + t.Run("test Prefix", func(t *testing.T) { + testPrefix := path.Join(testMinIOKVRoot, "prefix") + ctx, cancel := context.WithCancel(context.Background()) + defer 
cancel() + + testCM, err := newAzureChunkManager(ctx, testBucket, testPrefix) + require.NoError(t, err) + defer testCM.RemoveWithPrefix(ctx, testPrefix) + + pathB := path.Join("a", "b") + + key := path.Join(testPrefix, pathB) + value := []byte("a") + + err = testCM.Write(ctx, key, value) + assert.NoError(t, err) + + pathC := path.Join("a", "c") + key = path.Join(testPrefix, pathC) + err = testCM.Write(ctx, key, value) + assert.NoError(t, err) + + pathPrefix := path.Join(testPrefix, "a") + r, m, err := testCM.ListWithPrefix(ctx, pathPrefix, true) + assert.NoError(t, err) + assert.Equal(t, len(r), 2) + assert.Equal(t, len(m), 2) + + key = path.Join(testPrefix, "b", "b", "b") + err = testCM.Write(ctx, key, value) + assert.NoError(t, err) + + key = path.Join(testPrefix, "b", "a", "b") + err = testCM.Write(ctx, key, value) + assert.NoError(t, err) + + key = path.Join(testPrefix, "bc", "a", "b") + err = testCM.Write(ctx, key, value) + assert.NoError(t, err) + dirs, mods, err := testCM.ListWithPrefix(ctx, testPrefix+"/", true) + assert.NoError(t, err) + assert.Equal(t, 5, len(dirs)) + assert.Equal(t, 5, len(mods)) + + dirs, mods, err = testCM.ListWithPrefix(ctx, path.Join(testPrefix, "b"), true) + assert.NoError(t, err) + assert.Equal(t, 3, len(dirs)) + assert.Equal(t, 3, len(mods)) + + testCM.RemoveWithPrefix(ctx, testPrefix) + r, m, err = testCM.ListWithPrefix(ctx, pathPrefix, true) + assert.NoError(t, err) + assert.Equal(t, 0, len(r)) + assert.Equal(t, 0, len(m)) + + // test wrong prefix + b := make([]byte, 2048) + pathWrong := path.Join(testPrefix, string(b)) + _, _, err = testCM.ListWithPrefix(ctx, pathWrong, true) + assert.Error(t, err) + }) + + t.Run("test NoSuchKey", func(t *testing.T) { + testPrefix := path.Join(testMinIOKVRoot, "nokey") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + testCM, err := newAzureChunkManager(ctx, testBucket, testPrefix) + require.NoError(t, err) + defer testCM.RemoveWithPrefix(ctx, testPrefix) + + key := 
"a" + + _, err = testCM.Read(ctx, key) + assert.Error(t, err) + assert.True(t, errors.Is(err, ErrNoSuchKey)) + + _, err = testCM.Reader(ctx, key) + assert.Error(t, err) + assert.True(t, errors.Is(err, ErrNoSuchKey)) + + _, err = testCM.ReadAt(ctx, key, 100, 1) + assert.Error(t, err) + assert.True(t, errors.Is(err, ErrNoSuchKey)) + }) +} diff --git a/internal/storage/stats.go b/internal/storage/stats.go index c4d24eb9984c0..19522a042305e 100644 --- a/internal/storage/stats.go +++ b/internal/storage/stats.go @@ -21,6 +21,7 @@ import ( "fmt" "github.com/bits-and-blooms/bloom/v3" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" @@ -37,7 +38,7 @@ const ( type PrimaryKeyStats struct { FieldID int64 `json:"fieldID"` Max int64 `json:"max"` // useless, will delete - Min int64 `json:"min"` //useless, will delete + Min int64 `json:"min"` // useless, will delete BF *bloom.BloomFilter `json:"bf"` PkType int64 `json:"pkType"` MaxPk PrimaryKey `json:"maxPk"` @@ -154,7 +155,7 @@ func (stats *PrimaryKeyStats) UpdateByMsgs(msgs FieldData) { stats.BF.AddString(str) } default: - //TODO:: + // TODO:: } } @@ -172,7 +173,6 @@ func (stats *PrimaryKeyStats) Update(pk PrimaryKey) { default: log.Warn("Update pk stats with invalid data type") } - } // updatePk update minPk and maxPk value diff --git a/internal/storage/utils.go b/internal/storage/utils.go index f2e93690cf856..cc247aa88ddf1 100644 --- a/internal/storage/utils.go +++ b/internal/storage/utils.go @@ -25,7 +25,6 @@ import ( "strconv" "github.com/cockroachdb/errors" - "github.com/golang/protobuf/proto" "go.uber.org/zap" @@ -207,7 +206,7 @@ func ReadBinary(reader io.Reader, receiver interface{}, dataType schemapb.DataTy func readFloatVectors(blobReaders []io.Reader, dim int) []float32 { ret := make([]float32, 0) for _, r := range blobReaders { - var v = make([]float32, dim) + v := make([]float32, dim) ReadBinary(r, &v, 
schemapb.DataType_FloatVector) ret = append(ret, v...) } @@ -217,13 +216,23 @@ func readFloatVectors(blobReaders []io.Reader, dim int) []float32 { func readBinaryVectors(blobReaders []io.Reader, dim int) []byte { ret := make([]byte, 0) for _, r := range blobReaders { - var v = make([]byte, dim/8) + v := make([]byte, dim/8) ReadBinary(r, &v, schemapb.DataType_BinaryVector) ret = append(ret, v...) } return ret } +func readFloat16Vectors(blobReaders []io.Reader, dim int) []byte { + ret := make([]byte, 0) + for _, r := range blobReaders { + v := make([]byte, dim*2) + ReadBinary(r, &v, schemapb.DataType_Float16Vector) + ret = append(ret, v...) + } + return ret +} + func readBoolArray(blobReaders []io.Reader) []bool { ret := make([]bool, 0) for _, r := range blobReaders { @@ -321,6 +330,19 @@ func RowBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *schemap Dim: dim, } + case schemapb.DataType_Float16Vector: + dim, err := GetDimFromParams(field.TypeParams) + if err != nil { + log.Error("failed to get dim", zap.Error(err)) + return nil, err + } + + vecs := readFloat16Vectors(blobReaders, dim) + idata.Data[field.FieldID] = &Float16VectorFieldData{ + Data: vecs, + Dim: dim, + } + case schemapb.DataType_BinaryVector: var dim int dim, err := GetDimFromParams(field.TypeParams) @@ -435,6 +457,23 @@ func ColumnBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *sche idata.Data[field.FieldID] = fieldData + case schemapb.DataType_Float16Vector: + dim, err := GetDimFromParams(field.TypeParams) + if err != nil { + log.Error("failed to get dim", zap.Error(err)) + return nil, err + } + + srcData := srcFields[field.FieldID].GetVectors().GetFloat16Vector() + + fieldData := &Float16VectorFieldData{ + Data: make([]byte, 0, len(srcData)), + Dim: dim, + } + fieldData.Data = append(fieldData.Data, srcData...) 
+ + idata.Data[field.FieldID] = fieldData + case schemapb.DataType_Bool: srcData := srcFields[field.FieldID].GetScalars().GetBoolData().GetData() @@ -698,6 +737,18 @@ func mergeFloatVectorField(data *InsertData, fid FieldID, field *FloatVectorFiel fieldData.Data = append(fieldData.Data, field.Data...) } +func mergeFloat16VectorField(data *InsertData, fid FieldID, field *Float16VectorFieldData) { + if _, ok := data.Data[fid]; !ok { + fieldData := &Float16VectorFieldData{ + Data: nil, + Dim: field.Dim, + } + data.Data[fid] = fieldData + } + fieldData := data.Data[fid].(*Float16VectorFieldData) + fieldData.Data = append(fieldData.Data, field.Data...) +} + // MergeFieldData merge field into data. func MergeFieldData(data *InsertData, fid FieldID, field FieldData) { if field == nil { @@ -728,6 +779,8 @@ func MergeFieldData(data *InsertData, fid FieldID, field FieldData) { mergeBinaryVectorField(data, fid, field) case *FloatVectorFieldData: mergeFloatVectorField(data, fid, field) + case *Float16VectorFieldData: + mergeFloat16VectorField(data, fid, field) } } @@ -777,7 +830,7 @@ func GetPkFromInsertData(collSchema *schemapb.CollectionSchema, data *InsertData case schemapb.DataType_VarChar: realPfData, ok = pfData.(*StringFieldData) default: - //TODO + // TODO } if !ok { log.Warn("primary field not in Int64 or VarChar format", zap.Int64("fieldID", pf.FieldID)) diff --git a/internal/storage/utils_test.go b/internal/storage/utils_test.go index 77881f73bde16..f9458ca861e05 100644 --- a/internal/storage/utils_test.go +++ b/internal/storage/utils_test.go @@ -169,6 +169,10 @@ func TestTransferColumnBasedInsertDataToRowBased(t *testing.T) { Dim: 1, Data: []float32{0, 0, 0}, } + f11 := &Float16VectorFieldData{ + Dim: 1, + Data: []byte{1, 1, 2, 2, 3, 3}, + } data.Data[101] = f1 data.Data[102] = f2 @@ -180,6 +184,7 @@ func TestTransferColumnBasedInsertDataToRowBased(t *testing.T) { // data.Data[108] = f8 data.Data[109] = f9 data.Data[110] = f10 + data.Data[111] = f11 utss, rowIds, 
rows, err := TransferColumnBasedInsertDataToRowBased(data) assert.NoError(t, err) @@ -202,6 +207,7 @@ func TestTransferColumnBasedInsertDataToRowBased(t *testing.T) { // b + 1, // "1" 1, // 1 0, 0, 0, 0, // 0 + 1, 1, }, rows[0].Value) assert.ElementsMatch(t, @@ -216,6 +222,7 @@ func TestTransferColumnBasedInsertDataToRowBased(t *testing.T) { // b + 2, // "2" 2, // 2 0, 0, 0, 0, // 0 + 2, 2, }, rows[1].Value) assert.ElementsMatch(t, @@ -230,6 +237,7 @@ func TestTransferColumnBasedInsertDataToRowBased(t *testing.T) { // b + 3, // "3" 3, // 3 0, 0, 0, 0, // 0 + 3, 3, }, rows[2].Value) } @@ -313,7 +321,7 @@ func TestReadBinary(t *testing.T) { } } -func genAllFieldsSchema(fVecDim, bVecDim int) (schema *schemapb.CollectionSchema, pkFieldID UniqueID, fieldIDs []UniqueID) { +func genAllFieldsSchema(fVecDim, bVecDim, f16VecDim int) (schema *schemapb.CollectionSchema, pkFieldID UniqueID, fieldIDs []UniqueID) { schema = &schemapb.CollectionSchema{ Name: "all_fields_schema", Description: "all_fields_schema", @@ -359,6 +367,15 @@ func genAllFieldsSchema(fVecDim, bVecDim int) (schema *schemapb.CollectionSchema }, }, }, + { + DataType: schemapb.DataType_Float16Vector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: strconv.Itoa(f16VecDim), + }, + }, + }, { DataType: schemapb.DataType_Array, }, @@ -412,6 +429,16 @@ func generateBinaryVectors(numRows, dim int) []byte { return ret } +func generateFloat16Vectors(numRows, dim int) []byte { + total := (numRows * dim) * 2 + ret := make([]byte, total) + _, err := rand.Read(ret) + if err != nil { + panic(err) + } + return ret +} + func generateBoolArray(numRows int) []bool { ret := make([]bool, 0, numRows) for i := 0; i < numRows; i++ { @@ -474,8 +501,8 @@ func generateInt32ArrayList(numRows int) []*schemapb.ScalarField { return ret } -func genRowWithAllFields(fVecDim, bVecDim int) (blob *commonpb.Blob, pk int64, row []interface{}) { - schema, _, _ := genAllFieldsSchema(fVecDim, bVecDim) +func 
genRowWithAllFields(fVecDim, bVecDim, f16VecDim int) (blob *commonpb.Blob, pk int64, row []interface{}) { + schema, _, _ := genAllFieldsSchema(fVecDim, bVecDim, f16VecDim) ret := &commonpb.Blob{ Value: nil, } @@ -493,6 +520,11 @@ func genRowWithAllFields(fVecDim, bVecDim int) (blob *commonpb.Blob, pk int64, r _ = binary.Write(&buffer, common.Endian, bVec) ret.Value = append(ret.Value, buffer.Bytes()...) row = append(row, bVec) + case schemapb.DataType_Float16Vector: + f16Vec := generateFloat16Vectors(1, f16VecDim) + _ = binary.Write(&buffer, common.Endian, f16Vec) + ret.Value = append(ret.Value, buffer.Bytes()...) + row = append(row, f16Vec) case schemapb.DataType_Bool: data := rand.Int()%2 == 0 _ = binary.Write(&buffer, common.Endian, data) @@ -550,7 +582,7 @@ func genRowWithAllFields(fVecDim, bVecDim int) (blob *commonpb.Blob, pk int64, r return ret, pk, row } -func genRowBasedInsertMsg(numRows, fVecDim, bVecDim int) (msg *msgstream.InsertMsg, pks []int64, columns [][]interface{}) { +func genRowBasedInsertMsg(numRows, fVecDim, bVecDim, f16VecDim int) (msg *msgstream.InsertMsg, pks []int64, columns [][]interface{}) { msg = &msgstream.InsertMsg{ BaseMsg: msgstream.BaseMsg{ Ctx: nil, @@ -573,7 +605,7 @@ func genRowBasedInsertMsg(numRows, fVecDim, bVecDim int) (msg *msgstream.InsertM pks = make([]int64, 0) raws := make([][]interface{}, 0) for i := 0; i < numRows; i++ { - row, pk, raw := genRowWithAllFields(fVecDim, bVecDim) + row, pk, raw := genRowWithAllFields(fVecDim, bVecDim, f16VecDim) msg.InsertRequest.RowData = append(msg.InsertRequest.RowData, row) pks = append(pks, pk) raws = append(raws, raw) @@ -588,7 +620,7 @@ func genRowBasedInsertMsg(numRows, fVecDim, bVecDim int) (msg *msgstream.InsertM return msg, pks, columns } -func genColumnBasedInsertMsg(schema *schemapb.CollectionSchema, numRows, fVecDim, bVecDim int) (msg *msgstream.InsertMsg, pks []int64, columns [][]interface{}) { +func genColumnBasedInsertMsg(schema *schemapb.CollectionSchema, numRows, 
fVecDim, bVecDim, f16VecDim int) (msg *msgstream.InsertMsg, pks []int64, columns [][]interface{}) { msg = &msgstream.InsertMsg{ BaseMsg: msgstream.BaseMsg{ Ctx: nil, @@ -795,6 +827,25 @@ func genColumnBasedInsertMsg(schema *schemapb.CollectionSchema, numRows, fVecDim for nrows := 0; nrows < numRows; nrows++ { columns[idx] = append(columns[idx], data[nrows*bVecDim/8:(nrows+1)*bVecDim/8]) } + case schemapb.DataType_Float16Vector: + data := generateFloat16Vectors(numRows, f16VecDim) + f := &schemapb.FieldData{ + Type: schemapb.DataType_Float16Vector, + FieldName: field.Name, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: int64(f16VecDim), + Data: &schemapb.VectorField_Float16Vector{ + Float16Vector: data, + }, + }, + }, + FieldId: field.FieldID, + } + msg.FieldsData = append(msg.FieldsData, f) + for nrows := 0; nrows < numRows; nrows++ { + columns[idx] = append(columns[idx], data[nrows*f16VecDim*2:(nrows+1)*f16VecDim*2]) + } case schemapb.DataType_Array: data := generateInt32ArrayList(numRows) @@ -845,10 +896,10 @@ func genColumnBasedInsertMsg(schema *schemapb.CollectionSchema, numRows, fVecDim } func TestRowBasedInsertMsgToInsertData(t *testing.T) { - numRows, fVecDim, bVecDim := 10, 8, 8 - schema, _, fieldIDs := genAllFieldsSchema(fVecDim, bVecDim) + numRows, fVecDim, bVecDim, f16VecDim := 10, 8, 8, 8 + schema, _, fieldIDs := genAllFieldsSchema(fVecDim, bVecDim, f16VecDim) fieldIDs = fieldIDs[:len(fieldIDs)-2] - msg, _, columns := genRowBasedInsertMsg(numRows, fVecDim, bVecDim) + msg, _, columns := genRowBasedInsertMsg(numRows, fVecDim, bVecDim, f16VecDim) idata, err := RowBasedInsertMsgToInsertData(msg, schema) assert.NoError(t, err) @@ -864,9 +915,9 @@ func TestRowBasedInsertMsgToInsertData(t *testing.T) { } func TestColumnBasedInsertMsgToInsertData(t *testing.T) { - numRows, fVecDim, bVecDim := 2, 2, 8 - schema, _, fieldIDs := genAllFieldsSchema(fVecDim, bVecDim) - msg, _, columns := genColumnBasedInsertMsg(schema, numRows, 
fVecDim, bVecDim) + numRows, fVecDim, bVecDim, f16VecDim := 2, 2, 8, 2 + schema, _, fieldIDs := genAllFieldsSchema(fVecDim, bVecDim, f16VecDim) + msg, _, columns := genColumnBasedInsertMsg(schema, numRows, fVecDim, bVecDim, f16VecDim) idata, err := ColumnBasedInsertMsgToInsertData(msg, schema) assert.NoError(t, err) @@ -882,10 +933,10 @@ func TestColumnBasedInsertMsgToInsertData(t *testing.T) { } func TestInsertMsgToInsertData(t *testing.T) { - numRows, fVecDim, bVecDim := 10, 8, 8 - schema, _, fieldIDs := genAllFieldsSchema(fVecDim, bVecDim) + numRows, fVecDim, bVecDim, f16VecDim := 10, 8, 8, 8 + schema, _, fieldIDs := genAllFieldsSchema(fVecDim, bVecDim, f16VecDim) fieldIDs = fieldIDs[:len(fieldIDs)-2] - msg, _, columns := genRowBasedInsertMsg(numRows, fVecDim, bVecDim) + msg, _, columns := genRowBasedInsertMsg(numRows, fVecDim, bVecDim, f16VecDim) idata, err := InsertMsgToInsertData(msg, schema) assert.NoError(t, err) @@ -901,9 +952,9 @@ func TestInsertMsgToInsertData(t *testing.T) { } func TestInsertMsgToInsertData2(t *testing.T) { - numRows, fVecDim, bVecDim := 2, 2, 8 - schema, _, fieldIDs := genAllFieldsSchema(fVecDim, bVecDim) - msg, _, columns := genColumnBasedInsertMsg(schema, numRows, fVecDim, bVecDim) + numRows, fVecDim, bVecDim, f16VecDim := 2, 2, 8, 2 + schema, _, fieldIDs := genAllFieldsSchema(fVecDim, bVecDim, f16VecDim) + msg, _, columns := genColumnBasedInsertMsg(schema, numRows, fVecDim, bVecDim, f16VecDim) idata, err := InsertMsgToInsertData(msg, schema) assert.NoError(t, err) diff --git a/internal/storage/vector_chunk_manager.go b/internal/storage/vector_chunk_manager.go index b1dd545c2c41e..738bf8cde17d6 100644 --- a/internal/storage/vector_chunk_manager.go +++ b/internal/storage/vector_chunk_manager.go @@ -24,7 +24,6 @@ import ( "time" "github.com/cockroachdb/errors" - "go.uber.org/zap" "golang.org/x/exp/mmap" @@ -33,9 +32,7 @@ import ( "github.com/milvus-io/milvus/pkg/util/cache" ) -var ( - defaultLocalCacheSize = 64 -) +var 
defaultLocalCacheSize = 64 // VectorChunkManager is responsible for read and write vector data. type VectorChunkManager struct { @@ -291,6 +288,7 @@ func (vcm *VectorChunkManager) ReadAt(ctx context.Context, filePath string, off } return p, nil } + func (vcm *VectorChunkManager) Remove(ctx context.Context, filePath string) error { err := vcm.vectorStorage.Remove(ctx, filePath) if err != nil { diff --git a/internal/storage/vector_chunk_manager_test.go b/internal/storage/vector_chunk_manager_test.go index eb875a4319aba..89d5eb55e22d8 100644 --- a/internal/storage/vector_chunk_manager_test.go +++ b/internal/storage/vector_chunk_manager_test.go @@ -23,7 +23,6 @@ import ( "testing" "github.com/cockroachdb/errors" - "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" @@ -132,8 +131,10 @@ func buildVectorChunkManager(ctx context.Context, localPath string, localCacheEn return vcm, nil } -var Params = paramtable.Get() -var localPath = "/tmp/milvus_test/chunkmanager/" +var ( + Params = paramtable.Get() + localPath = "/tmp/milvus_test/chunkmanager/" +) func TestMain(m *testing.M) { paramtable.Init() @@ -322,6 +323,7 @@ func (m *mockFailedChunkManager) RemoveWithPrefix(ctx context.Context, prefix st } return nil } + func (m *mockFailedChunkManager) MultiRemove(ctx context.Context, key []string) error { if m.fail { return errors.New("multi remove error") @@ -465,7 +467,7 @@ func TestVectorChunkManager_Read(t *testing.T) { r, err = vcm.Mmap(ctx, "not exist") assert.Error(t, err) - assert.Nil(t, nil) + assert.Nil(t, r) } content, err = vcm.ReadAt(ctx, "109", 9999, 8*4) diff --git a/internal/tso/global_allocator.go b/internal/tso/global_allocator.go index 3b6424866a6c0..7d737387a5909 100644 --- a/internal/tso/global_allocator.go +++ b/internal/tso/global_allocator.go @@ -34,11 +34,11 @@ import ( "time" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/util/tsoutil" "go.uber.org/zap" 
"github.com/milvus-io/milvus/internal/kv" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -139,7 +139,7 @@ func (gta *GlobalTSOAllocator) GenerateTSO(count uint32) (uint64, error) { // Alloc allocates a batch of timestamps. What is returned is the starting timestamp. func (gta *GlobalTSOAllocator) Alloc(count uint32) (typeutil.Timestamp, error) { - //return gta.tso.SyncTimestamp() + // return gta.tso.SyncTimestamp() start, err := gta.GenerateTSO(count) if err != nil { return typeutil.ZeroTimestamp, err diff --git a/internal/tso/tso.go b/internal/tso/tso.go index 9d7f3f08742ed..495ec510d157c 100644 --- a/internal/tso/tso.go +++ b/internal/tso/tso.go @@ -36,11 +36,11 @@ import ( "unsafe" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus/pkg/util/tsoutil" "go.uber.org/zap" "github.com/milvus-io/milvus/internal/kv" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -81,7 +81,7 @@ func (t *timestampOracle) loadTimestamp() (time.Time, error) { return typeutil.ZeroTime, nil } - var binData = []byte(strData) + binData := []byte(strData) if len(binData) == 0 { return typeutil.ZeroTime, nil } @@ -91,7 +91,7 @@ func (t *timestampOracle) loadTimestamp() (time.Time, error) { // save timestamp, if lastTs is 0, we think the timestamp doesn't exist, so create it, // otherwise, update it. 
func (t *timestampOracle) saveTimestamp(ts time.Time) error { - //we use big endian here for compatibility issues + // we use big endian here for compatibility issues data := typeutil.Uint64ToBytesBigEndian(uint64(ts.UnixNano())) err := t.txnKV.Save(t.key, string(data)) if err != nil { diff --git a/internal/types/types.go b/internal/types/types.go index cad03c5d4b8aa..ad4021bc33eb2 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -18,7 +18,9 @@ package types import ( "context" + "io" + "github.com/tikv/client-go/v2/txnkv" clientv3 "go.etcd.io/etcd/client/v3" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" @@ -31,16 +33,11 @@ import ( "github.com/milvus-io/milvus/internal/proto/rootcoordpb" ) -// TimeTickProvider is the interface all services implement -type TimeTickProvider interface { - GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) -} - // Limiter defines the interface to perform request rate limiting. // If Limit function return true, the request will be rejected. // Otherwise, the request will pass. Limit also returns limit of limiter. 
type Limiter interface { - Check(collectionID int64, rt internalpb.RateType, n int) commonpb.ErrorCode + Check(collectionID int64, rt internalpb.RateType, n int) error } // Component is the interface all services implement @@ -48,57 +45,19 @@ type Component interface { Init() error Start() error Stop() error - GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) - GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) Register() error - //SetAddress(address string) - //GetAddress() string +} + +// DataNodeClient is the client interface for datanode server +type DataNodeClient interface { + io.Closer + datapb.DataNodeClient } // DataNode is the interface `datanode` package implements type DataNode interface { Component - - // Deprecated - WatchDmChannels(ctx context.Context, req *datapb.WatchDmChannelsRequest) (*commonpb.Status, error) - - // FlushSegments notifies DataNode to flush the segments req provids. The flush tasks are async to this - // rpc, DataNode will flush the segments in the background. - // - // Return UnexpectedError code in status: - // If DataNode isn't in HEALTHY: states not HEALTHY or dynamic checks not HEALTHY - // If DataNode doesn't find the correspounding segmentID in its memeory replica - // Return Success code in status and trigers background flush: - // Log an info log if a segment is under flushing - FlushSegments(ctx context.Context, req *datapb.FlushSegmentsRequest) (*commonpb.Status, error) - - // ShowConfigurations gets specified configurations param of DataNode - ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) - // GetMetrics gets the metrics about DataNode. 
- GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) - - // Compaction will add a compaction task according to the request plan - Compaction(ctx context.Context, req *datapb.CompactionPlan) (*commonpb.Status, error) - // GetCompactionState get states of all compation tasks - GetCompactionState(ctx context.Context, req *datapb.CompactionStateRequest) (*datapb.CompactionStateResponse, error) - // SyncSegments is called by DataCoord, to sync the segments meta when complete compaction - SyncSegments(ctx context.Context, req *datapb.SyncSegmentsRequest) (*commonpb.Status, error) - - // Import data files(json, numpy, etc.) on MinIO/S3 storage, read and parse them into sealed segments - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including file path and options - // - // Return status indicates if this operation is processed successfully or fail cause; - // error is always nil - Import(ctx context.Context, req *datapb.ImportTaskRequest) (*commonpb.Status, error) - - // ResendSegmentStats resend un-flushed segment stats back upstream to DataCoord by resending DataNode time tick message. - // It returns a list of segments to be sent. - ResendSegmentStats(ctx context.Context, req *datapb.ResendSegmentStatsRequest) (*datapb.ResendSegmentStatsResponse, error) - - // AddImportSegment puts the given import segment to current DataNode's flow graph. - AddImportSegment(ctx context.Context, req *datapb.AddImportSegmentRequest) (*datapb.AddImportSegmentResponse, error) + datapb.DataNodeServer } // DataNodeComponent is used by grpc server of DataNode @@ -120,266 +79,35 @@ type DataNodeComponent interface { // SetEtcdClient set etcd client for DataNode SetEtcdClient(etcdClient *clientv3.Client) - // SetRootCoord set RootCoord for DataNode + // SetRootCoordClient set SetRootCoordClient for DataNode // `rootCoord` is a client of root coordinator. 
// // Return a generic error in status: // If the rootCoord is nil or the rootCoord has been set before. // Return nil in status: // The rootCoord is not nil. - SetRootCoord(rootCoord RootCoord) error + SetRootCoordClient(rootCoord RootCoordClient) error - // SetDataCoord set DataCoord for DataNode + // SetDataCoordClient set DataCoord for DataNode // `dataCoord` is a client of data coordinator. // // Return a generic error in status: // If the dataCoord is nil or the dataCoord has been set before. // Return nil in status: // The dataCoord is not nil. - SetDataCoord(dataCoord DataCoord) error + SetDataCoordClient(dataCoord DataCoordClient) error +} + +// DataCoordClient is the client interface for datacoord server +type DataCoordClient interface { + io.Closer + datapb.DataCoordClient } // DataCoord is the interface `datacoord` package implements type DataCoord interface { Component - TimeTickProvider - - // Flush notifies DataCoord to flush all current growing segments of specified Collection - // ctx is the context to control request deadline and cancellation - // req contains the request params, which are database name(not used for now) and collection id - // - // response struct `FlushResponse` contains - // 1, related db id - // 2, related collection id - // 3, affected segment ids - // 4, already flush/flushing segment ids of related collection before this request - // 5, timeOfSeal, all data before timeOfSeal is guaranteed to be sealed or flushed - // error is returned only when some communication issue occurs - // if some error occurs in the process of `Flush`, it will be recorded and returned in `Status` field of response - // - // `Flush` returns when all growing segments of specified collection is "sealed" - // the flush procedure will wait corresponding data node(s) proceeds to the safe timestamp - // and the `Flush` operation will be truly invoked - // If the Datacoord or Datanode crashes in the flush procedure, recovery process will replay the ts check 
until all requirement is met - // - // Flushed segments can be check via `GetFlushedSegments` API - Flush(ctx context.Context, req *datapb.FlushRequest) (*datapb.FlushResponse, error) - - // AssignSegmentID applies allocations for specified Coolection/Partition and related Channel Name(Virtial Channel) - // - // ctx is the context to control request deadline and cancellation - // req contains the requester's info(id and role) and the list of Assignment Request, - // which contains the specified collection, partition id, the related VChannel Name and row count it needs - // - // response struct `AssignSegmentIDResponse` contains the assignment result for each request - // error is returned only when some communication issue occurs - // if some error occurs in the process of `AssignSegmentID`, it will be recorded and returned in `Status` field of response - // - // `AssignSegmentID` will applies current configured allocation policies for each request - // if the VChannel is newly used, `WatchDmlChannels` will be invoked to notify a `DataNode`(selected by policy) to watch it - // if there is anything make the allocation impossible, the response will not contain the corresponding result - AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest) (*datapb.AssignSegmentIDResponse, error) - - // GetSegmentStates requests segment state information - // - // ctx is the context to control request deadline and cancellation - // req contains the list of segment id to query - // - // response struct `GetSegmentStatesResponse` contains the list of each state query result - // when the segment is not found, the state entry will has the field `Status` to identify failure - // otherwise the Segment State and Start position information will be returned - // error is returned only when some communication issue occurs - GetSegmentStates(ctx context.Context, req *datapb.GetSegmentStatesRequest) (*datapb.GetSegmentStatesResponse, error) - - // GetInsertBinlogPaths 
requests binlog paths for specified segment - // - // ctx is the context to control request deadline and cancellation - // req contains the segment id to query - // - // response struct `GetInsertBinlogPathsResponse` contains the fields list - // and corresponding binlog path list - // error is returned only when some communication issue occurs - GetInsertBinlogPaths(ctx context.Context, req *datapb.GetInsertBinlogPathsRequest) (*datapb.GetInsertBinlogPathsResponse, error) - - // GetSegmentInfoChannel DEPRECATED - // legacy api to get SegmentInfo Channel name - GetSegmentInfoChannel(ctx context.Context) (*milvuspb.StringResponse, error) - - // GetCollectionStatistics requests collection statistics - // - // ctx is the context to control request deadline and cancellation - // req contains the collection id to query - // - // response struct `GetCollectionStatisticsResponse` contains the key-value list fields returning related data - // only row count for now - // error is returned only when some communication issue occurs - GetCollectionStatistics(ctx context.Context, req *datapb.GetCollectionStatisticsRequest) (*datapb.GetCollectionStatisticsResponse, error) - - // GetPartitionStatistics requests partition statistics - // - // ctx is the context to control request deadline and cancellation - // req contains the collection and partition id to query - // - // response struct `GetPartitionStatisticsResponse` contains the key-value list fields returning related data - // only row count for now - // error is returned only when some communication issue occurs - GetPartitionStatistics(ctx context.Context, req *datapb.GetPartitionStatisticsRequest) (*datapb.GetPartitionStatisticsResponse, error) - - // GetSegmentInfo requests segment info - // - // ctx is the context to control request deadline and cancellation - // req contains the list of segment ids to query - // - // response struct `GetSegmentInfoResponse` contains the list of segment info - // error is returned only 
when some communication issue occurs - GetSegmentInfo(ctx context.Context, req *datapb.GetSegmentInfoRequest) (*datapb.GetSegmentInfoResponse, error) - - // GetRecoveryInfo request segment recovery info of collection/partition - // - // ctx is the context to control request deadline and cancellation - // req contains the collection/partition id to query - // - // response struct `GetRecoveryInfoResponse` contains the list of segments info and corresponding vchannel info - // error is returned only when some communication issue occurs - GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInfoRequest) (*datapb.GetRecoveryInfoResponse, error) - - // GetRecoveryInfoV2 request segment recovery info of collection or batch partitions - // - // ctx is the context to control request deadline and cancellation - // req contains the collection/partitions id to query - // - // response struct `GetRecoveryInfoResponseV2` contains the list of segments info and corresponding vchannel info - // error is returned only when some communication issue occurs - GetRecoveryInfoV2(ctx context.Context, req *datapb.GetRecoveryInfoRequestV2) (*datapb.GetRecoveryInfoResponseV2, error) - - // SaveBinlogPaths updates segments binlogs(including insert binlogs, stats logs and delta logs) - // and related message stream positions - // - // ctx is the context to control request deadline and cancellation - // req contains the segment binlogs and checkpoint informations. 
- // - // response status contains the status/error code and failing reason if any - // error is returned only when some communication issue occurs - // - // there is a constraint that the `SaveBinlogPaths` requests of same segment shall be passed in sequence - // the root reason is each `SaveBinlogPaths` will overwrite the checkpoint position - // if the constraint is broken, the checkpoint position will not be monotonically increasing and the integrity will be compromised - SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPathsRequest) (*commonpb.Status, error) - - // GetSegmentsByStates returns segment list of requested collection/partition in given states - // - // ctx is the context to control request deadline and cancellation - // req contains the collection/partition id and states to query - // when partition is lesser or equal to 0, all flushed segments of collection will be returned - // - // response struct `GetSegmentsByStatesResponse` contains segment id list - // error is returned only when some communication issue occurs - GetSegmentsByStates(ctx context.Context, req *datapb.GetSegmentsByStatesRequest) (*datapb.GetSegmentsByStatesResponse, error) - - // GetFlushedSegments returns flushed segment list of requested collection/partition - // - // ctx is the context to control request deadline and cancellation - // req contains the collection/partition id to query - // when partition is lesser or equal to 0, all flushed segments of collection will be returned - // - // response struct `GetFlushedSegmentsResponse` contains flushed segment id list - // error is returned only when some communication issue occurs - GetFlushedSegments(ctx context.Context, req *datapb.GetFlushedSegmentsRequest) (*datapb.GetFlushedSegmentsResponse, error) - - // ShowConfigurations gets specified configurations para of DataCoord - ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) - // 
GetMetrics gets the metrics about DataCoord. - GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) - // ManualCompaction triggers a compaction for a collection - ManualCompaction(ctx context.Context, req *milvuspb.ManualCompactionRequest) (*milvuspb.ManualCompactionResponse, error) - // GetCompactionState gets the state of a compaction - GetCompactionState(ctx context.Context, req *milvuspb.GetCompactionStateRequest) (*milvuspb.GetCompactionStateResponse, error) - // GetCompactionStateWithPlans get the state of requested plan id - GetCompactionStateWithPlans(ctx context.Context, req *milvuspb.GetCompactionPlansRequest) (*milvuspb.GetCompactionPlansResponse, error) - - // WatchChannels notifies DataCoord to watch vchannels of a collection - WatchChannels(ctx context.Context, req *datapb.WatchChannelsRequest) (*datapb.WatchChannelsResponse, error) - // GetFlushState gets the flush state of multiple segments - GetFlushState(ctx context.Context, req *milvuspb.GetFlushStateRequest) (*milvuspb.GetFlushStateResponse, error) - // GetFlushAllState checks if all DML messages before `FlushAllTs` have been flushed. - GetFlushAllState(ctx context.Context, req *milvuspb.GetFlushAllStateRequest) (*milvuspb.GetFlushAllStateResponse, error) - // SetSegmentState updates a segment's state explicitly. 
- SetSegmentState(ctx context.Context, req *datapb.SetSegmentStateRequest) (*datapb.SetSegmentStateResponse, error) - - // DropVirtualChannel notifies DataCoord a virtual channel is dropped and - // updates related segments binlogs(including insert binlogs, stats logs and delta logs) - // and related message stream positions - // - // ctx is the context to control request deadline and cancellation - // req contains the dropped virtual channel name and related segment information - // - // response status contains the status/error code and failing reason if any - // error is returned only when some communication issue occurs - DropVirtualChannel(ctx context.Context, req *datapb.DropVirtualChannelRequest) (*datapb.DropVirtualChannelResponse, error) - - // Import data files(json, numpy, etc.) on MinIO/S3 storage, read and parse them into sealed segments - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including file path and options - // - // The `Status` in response struct `ImportResponse` indicates if this operation is processed successfully or fail cause; - // the `tasks` in `ImportResponse` return an id list of tasks. - // error is always nil - Import(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) - - // UpdateSegmentStatistics updates a segment's stats. - UpdateSegmentStatistics(ctx context.Context, req *datapb.UpdateSegmentStatisticsRequest) (*commonpb.Status, error) - // UpdateChannelCheckpoint updates channel checkpoint in dataCoord. - UpdateChannelCheckpoint(ctx context.Context, req *datapb.UpdateChannelCheckpointRequest) (*commonpb.Status, error) - // ReportDataNodeTtMsgs report DataNodeTtMsgs to dataCoord, called by datanode. 
- ReportDataNodeTtMsgs(ctx context.Context, req *datapb.ReportDataNodeTtMsgsRequest) (*commonpb.Status, error) - - // SaveImportSegment saves the import segment binlog paths data and then looks for the right DataNode to add the - // segment to that DataNode. - SaveImportSegment(ctx context.Context, req *datapb.SaveImportSegmentRequest) (*commonpb.Status, error) - - // UnsetIsImportingState unsets the `isImport` state of the given segments so that they can get compacted normally. - UnsetIsImportingState(ctx context.Context, req *datapb.UnsetIsImportingStateRequest) (*commonpb.Status, error) - - // MarkSegmentsDropped marks the given segments as `dropped` state. - MarkSegmentsDropped(ctx context.Context, req *datapb.MarkSegmentsDroppedRequest) (*commonpb.Status, error) - - BroadcastAlteredCollection(ctx context.Context, req *datapb.AlterCollectionRequest) (*commonpb.Status, error) - - CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) - - GcConfirm(ctx context.Context, request *datapb.GcConfirmRequest) (*datapb.GcConfirmResponse, error) - - // CreateIndex create an index on collection. - // Index building is asynchronous, so when an index building request comes, an IndexID is assigned to the task and - // will get all flushed segments from DataCoord and record tasks with these segments. The background process - // indexBuilder will find this task and assign it to IndexNode for execution. - CreateIndex(ctx context.Context, req *indexpb.CreateIndexRequest) (*commonpb.Status, error) - - // GetIndexState gets the index state of the index name in the request from Proxy. - // Deprecated: use DescribeIndex instead - GetIndexState(ctx context.Context, req *indexpb.GetIndexStateRequest) (*indexpb.GetIndexStateResponse, error) - - // GetSegmentIndexState gets the index state of the segments in the request from RootCoord. 
- GetSegmentIndexState(ctx context.Context, req *indexpb.GetSegmentIndexStateRequest) (*indexpb.GetSegmentIndexStateResponse, error) - - // GetIndexInfos gets the index files of the IndexBuildIDs in the request from RootCoordinator. - GetIndexInfos(ctx context.Context, req *indexpb.GetIndexInfoRequest) (*indexpb.GetIndexInfoResponse, error) - - // DescribeIndex describe the index info of the collection. - DescribeIndex(ctx context.Context, req *indexpb.DescribeIndexRequest) (*indexpb.DescribeIndexResponse, error) - - // GetIndexStatistics get the statistics of the index. - GetIndexStatistics(ctx context.Context, req *indexpb.GetIndexStatisticsRequest) (*indexpb.GetIndexStatisticsResponse, error) - - // GetIndexBuildProgress get the index building progress by num rows. - // Deprecated: use DescribeIndex instead - GetIndexBuildProgress(ctx context.Context, req *indexpb.GetIndexBuildProgressRequest) (*indexpb.GetIndexBuildProgressResponse, error) - - // DropIndex deletes indexes based on IndexID. One IndexID corresponds to the index of an entire column. A column is - // divided into many segments, and each segment corresponds to an IndexBuildID. IndexCoord uses IndexBuildID to record - // index tasks. Therefore, when DropIndex is called, delete all tasks corresponding to IndexBuildID corresponding to IndexID. - DropIndex(ctx context.Context, req *indexpb.DropIndexRequest) (*commonpb.Status, error) + datapb.DataCoordServer } // DataCoordComponent defines the interface of DataCoord component. 
@@ -393,42 +121,28 @@ type DataCoordComponent interface { // `etcdClient` is a client of etcd SetEtcdClient(etcdClient *clientv3.Client) - SetRootCoord(rootCoord RootCoord) + // SetTiKVClient set TiKV client for QueryNode + SetTiKVClient(client *txnkv.Client) + + SetRootCoordClient(rootCoord RootCoordClient) // SetDataNodeCreator set DataNode client creator func for DataCoord - SetDataNodeCreator(func(context.Context, string, int64) (DataNode, error)) + SetDataNodeCreator(func(context.Context, string, int64) (DataNodeClient, error)) + + // SetIndexNodeCreator set Index client creator func for DataCoord + SetIndexNodeCreator(func(context.Context, string, int64) (IndexNodeClient, error)) +} - //SetIndexNodeCreator set Index client creator func for DataCoord - SetIndexNodeCreator(func(context.Context, string, int64) (IndexNode, error)) +// IndexNodeClient is the client interface for indexnode server +type IndexNodeClient interface { + io.Closer + indexpb.IndexNodeClient } // IndexNode is the interface `indexnode` package implements type IndexNode interface { Component - //TimeTickProvider - - // BuildIndex receives request from IndexCoordinator to build an index. - // Index building is asynchronous, so when an index building request comes, IndexNode records the task and returns. - //BuildIndex(ctx context.Context, req *datapb.BuildIndexRequest) (*commonpb.Status, error) - //GetTaskSlots(ctx context.Context, req *datapb.GetTaskSlotsRequest) (*datapb.GetTaskSlotsResponse, error) - // - //// GetMetrics gets the metrics about IndexNode. - //GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) - - // CreateJob receive index building job from indexcoord. Notes that index building is asynchronous, task is recorded - // in indexnode and then request is finished. - CreateJob(context.Context, *indexpb.CreateJobRequest) (*commonpb.Status, error) - // QueryJobs returns states of index building jobs specified by BuildIDs. 
There are four states of index building task - // Unissued, InProgress, Finished, Failed - QueryJobs(context.Context, *indexpb.QueryJobsRequest) (*indexpb.QueryJobsResponse, error) - // DropJobs cancel index building jobs specified by BuildIDs. Notes that dropping task may have finished. - DropJobs(context.Context, *indexpb.DropJobsRequest) (*commonpb.Status, error) - // GetJobStats returns metrics of indexnode, including available job queue info, available task slots and finished job infos. - GetJobStats(context.Context, *indexpb.GetJobStatsRequest) (*indexpb.GetJobStatsResponse, error) - - ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) - // GetMetrics gets the metrics about IndexNode. - GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) + indexpb.IndexNodeServer } // IndexNodeComponent is used by grpc server of IndexNode @@ -445,370 +159,18 @@ type IndexNodeComponent interface { UpdateStateCode(stateCode commonpb.StateCode) } +// RootCoordClient is the client interface for rootcoord server +type RootCoordClient interface { + io.Closer + rootcoordpb.RootCoordClient +} + // RootCoord is the interface `rootcoord` package implements // //go:generate mockery --name=RootCoord --output=../mocks --filename=mock_rootcoord.go --with-expecter type RootCoord interface { Component - TimeTickProvider - - // CreateDatabase notifies RootCoord to create a database - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including a database name - // - // The `ErrorCode` of `Status` is `Success` if create database successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error`, and the `Reason` of `Status` will record the fail cause. 
- // error is always nil - CreateDatabase(ctx context.Context, req *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) - - // DropDatabase notifies RootCoord to drop a database - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including a database name - // - // The `ErrorCode` of `Status` is `Success` if drop database successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error`, and the `Reason` of `Status` will record the fail cause. - // error is always nil - DropDatabase(ctx context.Context, req *milvuspb.DropDatabaseRequest) (*commonpb.Status, error) - - // ListDatabases notifies RootCoord to list all database names at specified timestamp - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(not used), collection name and timestamp - // - // The `Status` in response struct `ListDatabasesResponse` indicates if this operation is processed successfully or fail cause; - // other fields in `ListDatabasesResponse` are filled with all database names, error is always nil - ListDatabases(ctx context.Context, req *milvuspb.ListDatabasesRequest) (*milvuspb.ListDatabasesResponse, error) - - // CreateCollection notifies RootCoord to create a collection - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(not used), collection name, collection schema and - // physical channel num for inserting data - // - // The `ErrorCode` of `Status` is `Success` if create collection successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error`, and the `Reason` of `Status` will record the fail cause. 
- // error is always nil - CreateCollection(ctx context.Context, req *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) - - // DropCollection notifies RootCoord to drop a collection - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(not used) and collection name - // - // The `ErrorCode` of `Status` is `Success` if drop collection successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error`, and the `Reason` of `Status` will record the fail cause. - // error is always nil - DropCollection(ctx context.Context, req *milvuspb.DropCollectionRequest) (*commonpb.Status, error) - - // HasCollection notifies RootCoord to check a collection's existence at specified timestamp - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(not used), collection name and timestamp - // - // The `Status` in response struct `BoolResponse` indicates if this operation is processed successfully or fail cause; - // the `Value` in `BoolResponse` is `true` if the collection exists at the specified timestamp, `false` otherwise. - // Timestamp is ignored if set to 0. 
- // error is always nil - HasCollection(ctx context.Context, req *milvuspb.HasCollectionRequest) (*milvuspb.BoolResponse, error) - - // DescribeCollection notifies RootCoord to get all information about this collection at specified timestamp - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(not used), collection name or collection id, and timestamp - // - // The `Status` in response struct `DescribeCollectionResponse` indicates if this operation is processed successfully or fail cause; - // other fields in `DescribeCollectionResponse` are filled with this collection's schema, collection id, - // physical channel names, virtual channel names, created time, alias names, and so on. - // - // If timestamp is set a non-zero value and collection does not exist at this timestamp, - // the `ErrorCode` of `Status` in `DescribeCollectionResponse` will be set to `Error`. - // Timestamp is ignored if set to 0. - // error is always nil - DescribeCollection(ctx context.Context, req *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) - - // DescribeCollectionInternal same to DescribeCollection, only used in internal RPC. - // Besides, it'll also return unavailable collection, for example, creating, dropping. 
- DescribeCollectionInternal(ctx context.Context, req *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) - - // ShowCollections notifies RootCoord to list all collection names and other info in database at specified timestamp - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(not used), collection name and timestamp - // - // The `Status` in response struct `ShowCollectionsResponse` indicates if this operation is processed successfully or fail cause; - // other fields in `ShowCollectionsResponse` are filled with all collection names, collection ids, - // created times, created UTC times, and so on. - // error is always nil - ShowCollections(ctx context.Context, req *milvuspb.ShowCollectionsRequest) (*milvuspb.ShowCollectionsResponse, error) - - // AlterCollection notifies Proxy to create a collection - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), collection name and collection properties - // - // The `ErrorCode` of `Status` is `Success` if create collection successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error`, and the `Reason` of `Status` will record the fail cause. - // error is always nil - AlterCollection(ctx context.Context, request *milvuspb.AlterCollectionRequest) (*commonpb.Status, error) - - // CreatePartition notifies RootCoord to create a partition - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(not used), collection name and partition name - // - // The `ErrorCode` of `Status` is `Success` if create partition successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error`, and the `Reason` of `Status` will record the fail cause. 
- // error is always nil - CreatePartition(ctx context.Context, req *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) - - // DropPartition notifies RootCoord to drop a partition - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(not used), collection name and partition name - // - // The `ErrorCode` of `Status` is `Success` if drop partition successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error`, and the `Reason` of `Status` will record the fail cause. - // error is always nil - // - // Default partition cannot be dropped - DropPartition(ctx context.Context, req *milvuspb.DropPartitionRequest) (*commonpb.Status, error) - - // HasPartition notifies RootCoord to check if a partition with specified name exists in the collection - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(not used), collection name and partition name - // - // The `Status` in response struct `BoolResponse` indicates if this operation is processed successfully or fail cause; - // the `Value` in `BoolResponse` is `true` if the partition exists in the collection, `false` otherwise. - // error is always nil - HasPartition(ctx context.Context, req *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error) - - // ShowPartitions notifies RootCoord to list all partition names and other info in the collection - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(not used), collection name or collection id, and partition names - // - // The `Status` in response struct `ShowPartitionsResponse` indicates if this operation is processed successfully or fail cause; - // other fields in `ShowPartitionsResponse` are filled with all partition names, partition ids, - // created times, created UTC times, and so on. 
- // error is always nil - ShowPartitions(ctx context.Context, req *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) - - // ShowPartitionsInternal same to ShowPartitions, but will return unavailable resources and only used in internal. - ShowPartitionsInternal(ctx context.Context, req *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) - - // CreateIndex notifies RootCoord to create an index for the specified field in the collection - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(not used), collection name, field name and index parameters - // - // The `ErrorCode` of `Status` is `Success` if create index successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error`, and the `Reason` of `Status` will record the fail cause. - // error is always nil - // - // RootCoord forwards this request to IndexCoord to create index - //CreateIndex(ctx context.Context, req *milvuspb.CreateIndexRequest) (*commonpb.Status, error) - - // DescribeIndex notifies RootCoord to get specified index information for specified field - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(not used), collection name, field name and index name - // - // The `Status` in response struct `DescribeIndexResponse` indicates if this operation is processed successfully or fail cause; - // index information is filled in `IndexDescriptions` - // error is always nil - //DescribeIndex(ctx context.Context, req *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error) - - //GetIndexState(ctx context.Context, req *milvuspb.GetIndexStateRequest) (*milvuspb.GetIndexStateResponse, error) - - // DropIndex notifies RootCoord to drop the specified index for the specified field - // - // ctx is the context to control request deadline and cancellation - // req contains the 
request params, including database name(not used), collection name, field name and index name - // - // The `ErrorCode` of `Status` is `Success` if drop index successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error`, and the `Reason` of `Status` will record the fail cause. - // error is always nil - // - // RootCoord forwards this request to IndexCoord to drop index - //DropIndex(ctx context.Context, req *milvuspb.DropIndexRequest) (*commonpb.Status, error) - - // CreateAlias notifies RootCoord to create an alias for the collection - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including collection name and alias - // - // The `ErrorCode` of `Status` is `Success` if create alias successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error`, and the `Reason` of `Status` will record the fail cause. - // error is always nil - CreateAlias(ctx context.Context, req *milvuspb.CreateAliasRequest) (*commonpb.Status, error) - - // DropAlias notifies RootCoord to drop an alias for the collection - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including alias - // - // The `ErrorCode` of `Status` is `Success` if drop alias successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error`, and the `Reason` of `Status` will record the fail cause. - // error is always nil - DropAlias(ctx context.Context, req *milvuspb.DropAliasRequest) (*commonpb.Status, error) - - // AlterAlias notifies RootCoord to alter alias for the collection - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including collection name and new alias - // - // The `ErrorCode` of `Status` is `Success` if alter alias successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error`, and the `Reason` of `Status` will record the fail cause. 
- // error is always nil - AlterAlias(ctx context.Context, req *milvuspb.AlterAliasRequest) (*commonpb.Status, error) - - // AllocTimestamp notifies RootCoord to alloc timestamps - // - // ctx is the context to control request deadline and cancellation - // req contains the count of timestamps need to be allocated - // - // The `Status` in response struct `AllocTimestampResponse` indicates if this operation is processed successfully or fail cause; - // `Timestamp` is the first available timestamp allocated - // error is always nil - AllocTimestamp(ctx context.Context, req *rootcoordpb.AllocTimestampRequest) (*rootcoordpb.AllocTimestampResponse, error) - - // AllocID notifies RootCoord to alloc IDs - // - // ctx is the context to control request deadline and cancellation - // req contains the count of IDs need to be allocated - // - // The `Status` in response struct `AllocTimestampResponse` indicates if this operation is processed successfully or fail cause; - // `ID` is the first available id allocated - // error is always nil - AllocID(ctx context.Context, req *rootcoordpb.AllocIDRequest) (*rootcoordpb.AllocIDResponse, error) - - // UpdateChannelTimeTick notifies RootCoord to update each Proxy's safe timestamp - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including physical channel names, channels' safe timestamps and default timestamp - // - // The `ErrorCode` of `Status` is `Success` if update channel timetick successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error`, and the `Reason` of `Status` will record the fail cause. 
- // error is always nil - UpdateChannelTimeTick(ctx context.Context, req *internalpb.ChannelTimeTickMsg) (*commonpb.Status, error) - - // DescribeSegment notifies RootCoord to get specified segment information in the collection - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including collection id and segment id - // - // The `Status` in response struct `DescribeSegmentResponse` indicates if this operation is processed successfully or fail cause; - // segment index information is filled in `IndexID`, `BuildID` and `EnableIndex`. - // error is always nil - //DescribeSegment(ctx context.Context, req *milvuspb.DescribeSegmentRequest) (*milvuspb.DescribeSegmentResponse, error) - - // ShowSegments notifies RootCoord to list all segment ids in the collection or partition - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including collection id and partition id - // - // The `Status` in response struct `ShowSegmentsResponse` indicates if this operation is processed successfully or fail cause; - // `SegmentIDs` in `ShowSegmentsResponse` records all segment ids. - // error is always nil - ShowSegments(ctx context.Context, req *milvuspb.ShowSegmentsRequest) (*milvuspb.ShowSegmentsResponse, error) - - // InvalidateCollectionMetaCache notifies RootCoord to clear the meta cache of specific collection in Proxies. - // If `CollectionID` is specified in request, all the collection meta cache with the specified collectionID will be - // invalidated, if only the `CollectionName` is specified in request, only the collection meta cache with the - // specified collectionName will be invalidated. - // - // ctx is the request to control request deadline and cancellation. - // request contains the request params, which are database id(not used) and collection id. 
- // - // The `ErrorCode` of `Status` is `Success` if drop index successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error`, and the `Reason` of `Status` will record the fail cause. - // error is always nil - // - // RootCoord just forwards this request to Proxy client - InvalidateCollectionMetaCache(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) - - // SegmentFlushCompleted notifies RootCoord that specified segment has been flushed - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including SegmentInfo - // - // The `ErrorCode` of `Status` is `Success` if process successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error`, and the `Reason` of `Status` will record the fail cause. - // error is always nil - // - // This interface is only used by DataCoord, when RootCoord receives this request, RootCoord will notify IndexCoord - // to build index for this segment. - //SegmentFlushCompleted(ctx context.Context, in *datapb.SegmentFlushCompletedMsg) (*commonpb.Status, error) - - // ShowConfigurations gets specified configurations para of RootCoord - ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) - // GetMetrics notifies RootCoord to collect metrics for specified component - GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) - - // Import data files(json, numpy, etc.) 
on MinIO/S3 storage, read and parse them into sealed segments - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including file path and options - // - // Return status indicates if this operation is processed successfully or fail cause; - // error is always nil - Import(ctx context.Context, req *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) - - // GetImportState checks import task state from datanode - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including a task id - // - // The `Status` in response struct `GetImportStateResponse` indicates if this operation is processed successfully or fail cause; - // the `state` in `GetImportStateResponse` return the state of the import task. - // error is always nil - GetImportState(ctx context.Context, req *milvuspb.GetImportStateRequest) (*milvuspb.GetImportStateResponse, error) - - // List id array of all import tasks - // - // ctx is the context to control request deadline and cancellation - // req contains the request params - // - // The `Status` in response struct `ListImportTasksResponse` indicates if this operation is processed successfully or fail cause; - // the `Tasks` in `ListImportTasksResponse` return the id array of all import tasks. 
- // error is always nil - ListImportTasks(ctx context.Context, req *milvuspb.ListImportTasksRequest) (*milvuspb.ListImportTasksResponse, error) - - // ReportImport reports import task state to rootCoord - // - // ctx is the context to control request deadline and cancellation - // req contains the import results, including imported row count and an id list of generated segments - // - // response status contains the status/error code and failing reason if any error is returned - // error is always nil - ReportImport(ctx context.Context, req *rootcoordpb.ImportResult) (*commonpb.Status, error) - - // CreateCredential create new user and password - CreateCredential(ctx context.Context, req *internalpb.CredentialInfo) (*commonpb.Status, error) - // UpdateCredential update password for a user - UpdateCredential(ctx context.Context, req *internalpb.CredentialInfo) (*commonpb.Status, error) - // DeleteCredential delete a user - DeleteCredential(ctx context.Context, req *milvuspb.DeleteCredentialRequest) (*commonpb.Status, error) - // ListCredUsers list all usernames - ListCredUsers(ctx context.Context, req *milvuspb.ListCredUsersRequest) (*milvuspb.ListCredUsersResponse, error) - // GetCredential get credential by username - GetCredential(ctx context.Context, req *rootcoordpb.GetCredentialRequest) (*rootcoordpb.GetCredentialResponse, error) - - CreateRole(ctx context.Context, req *milvuspb.CreateRoleRequest) (*commonpb.Status, error) - DropRole(ctx context.Context, req *milvuspb.DropRoleRequest) (*commonpb.Status, error) - OperateUserRole(ctx context.Context, req *milvuspb.OperateUserRoleRequest) (*commonpb.Status, error) - SelectRole(ctx context.Context, req *milvuspb.SelectRoleRequest) (*milvuspb.SelectRoleResponse, error) - SelectUser(ctx context.Context, req *milvuspb.SelectUserRequest) (*milvuspb.SelectUserResponse, error) - OperatePrivilege(ctx context.Context, req *milvuspb.OperatePrivilegeRequest) (*commonpb.Status, error) - SelectGrant(ctx context.Context, req 
*milvuspb.SelectGrantRequest) (*milvuspb.SelectGrantResponse, error) - ListPolicy(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) - - CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) - - RenameCollection(ctx context.Context, req *milvuspb.RenameCollectionRequest) (*commonpb.Status, error) + rootcoordpb.RootCoordServer } // RootCoordComponent is used by grpc server of RootCoord @@ -820,72 +182,43 @@ type RootCoordComponent interface { // `etcdClient` is a client of etcd SetEtcdClient(etcdClient *clientv3.Client) + // SetTiKVClient set TiKV client for RootCoord + SetTiKVClient(client *txnkv.Client) + // UpdateStateCode updates state code for RootCoord // State includes: Initializing, Healthy and Abnormal UpdateStateCode(commonpb.StateCode) - // SetDataCoord set DataCoord for RootCoord + // SetDataCoordClient set SetDataCoordClient for RootCoord // `dataCoord` is a client of data coordinator. // // Always return nil. - SetDataCoord(dataCoord DataCoord) error + SetDataCoordClient(dataCoord DataCoordClient) error // SetQueryCoord set QueryCoord for RootCoord // `queryCoord` is a client of query coordinator. // // Always return nil. 
- SetQueryCoord(queryCoord QueryCoord) error + SetQueryCoordClient(queryCoord QueryCoordClient) error // SetProxyCreator set Proxy client creator func for RootCoord - SetProxyCreator(func(ctx context.Context, addr string, nodeID int64) (Proxy, error)) + SetProxyCreator(func(ctx context.Context, addr string, nodeID int64) (ProxyClient, error)) // GetMetrics notifies RootCoordComponent to collect metrics for specified component GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) } +// ProxyClient is the client interface for proxy server +type ProxyClient interface { + io.Closer + proxypb.ProxyClient +} + // Proxy is the interface `proxy` package implements type Proxy interface { Component - - // InvalidateCollectionMetaCache notifies Proxy to clear the meta cache of specific collection. - // If `CollectionID` is specified in request, all the collection meta cache with the specified collectionID will be - // invalidated, if only the `CollectionName` is specified in request, only the collection meta cache with the - // specified collectionName will be invalidated. - // - // InvalidateCollectionMetaCache should be called when there are any meta changes in specific collection. - // Such as `DropCollection`, `CreatePartition`, `DropPartition`, etc. - // - // ctx is the request to control request deadline and cancellation. - // request contains the request params, which are database name(not used now) and collection name. - // - // InvalidateCollectionMetaCache should always succeed even though the specific collection doesn't exist in Proxy. - // So the code of response `Status` should be always `Success`. - // - // error is returned only when some communication issue occurs. - InvalidateCollectionMetaCache(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) - - // InvalidateCredentialCache notifies Proxy to clear all the credential cache of specified username. 
- // - // InvalidateCredentialCache should be called when there are credential changes for specified username. - // Such as `CreateCredential`, `UpdateCredential`, `DeleteCredential`, etc. - // - // InvalidateCredentialCache should always succeed even though the specified username doesn't exist in Proxy. - // So the code of response `Status` should be always `Success`. - // - // error is returned only when some communication issue occurs. - InvalidateCredentialCache(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) (*commonpb.Status, error) - - UpdateCredentialCache(ctx context.Context, request *proxypb.UpdateCredCacheRequest) (*commonpb.Status, error) - - // SetRates notifies Proxy to limit rates of requests. - SetRates(ctx context.Context, req *proxypb.SetRatesRequest) (*commonpb.Status, error) - - // GetProxyMetrics gets the metrics of proxy, it's an internal interface which is different from GetMetrics interface, - // because it only obtains the metrics of Proxy, not including the topological metrics of Query cluster and Data cluster. - GetProxyMetrics(ctx context.Context, request *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) - RefreshPolicyInfoCache(ctx context.Context, req *proxypb.RefreshPolicyInfoCacheRequest) (*commonpb.Status, error) - - ListClientInfos(ctx context.Context, req *proxypb.ListClientInfosRequest) (*proxypb.ListClientInfosResponse, error) + proxypb.ProxyServer + milvuspb.MilvusServiceServer } // ProxyComponent defines the interface of proxy component. @@ -900,24 +233,24 @@ type ProxyComponent interface { // `etcdClient` is a client of etcd SetEtcdClient(etcdClient *clientv3.Client) - //SetRootCoordClient set RootCoord for Proxy + // SetRootCoordClient set RootCoord for Proxy // `rootCoord` is a client of root coordinator. - SetRootCoordClient(rootCoord RootCoord) + SetRootCoordClient(rootCoord RootCoordClient) // SetDataCoordClient set DataCoord for Proxy // `dataCoord` is a client of data coordinator. 
- SetDataCoordClient(dataCoord DataCoord) + SetDataCoordClient(dataCoord DataCoordClient) // SetIndexCoordClient set IndexCoord for Proxy // `indexCoord` is a client of index coordinator. - //SetIndexCoordClient(indexCoord IndexCoord) + // SetIndexCoordClient(indexCoord IndexCoord) // SetQueryCoordClient set QueryCoord for Proxy // `queryCoord` is a client of query coordinator. - SetQueryCoordClient(queryCoord QueryCoord) + SetQueryCoordClient(queryCoord QueryCoordClient) // SetQueryNodeCreator set QueryNode client creator func for Proxy - SetQueryNodeCreator(func(ctx context.Context, addr string, nodeID int64) (QueryNode, error)) + SetQueryNodeCreator(func(ctx context.Context, addr string, nodeID int64) (QueryNodeClient, error)) // GetRateLimiter returns the rateLimiter in Proxy GetRateLimiter() (Limiter, error) @@ -925,538 +258,17 @@ type ProxyComponent interface { // UpdateStateCode updates state code for Proxy // `stateCode` is current statement of this proxy node, indicating whether it's healthy. UpdateStateCode(stateCode commonpb.StateCode) +} - // CreateDatabase notifies Proxy to create a database - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including a database name - // - // The `ErrorCode` of `Status` is `Success` if create database successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error`, and the `Reason` of `Status` will record the fail cause. - // error is always nil - CreateDatabase(ctx context.Context, req *milvuspb.CreateDatabaseRequest) (*commonpb.Status, error) - - // DropDatabase notifies Proxy to drop a database - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including a database name - // - // The `ErrorCode` of `Status` is `Success` if drop database successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error`, and the `Reason` of `Status` will record the fail cause. 
- // error is always nil - DropDatabase(ctx context.Context, req *milvuspb.DropDatabaseRequest) (*commonpb.Status, error) - - // ListDatabases notifies Proxy to list all database names at specified timestamp - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(not used), collection name and timestamp - // - // The `Status` in response struct `ListDatabasesResponse` indicates if this operation is processed successfully or fail cause; - // other fields in `ListDatabasesResponse` are filled with all database names, error is always nil - ListDatabases(ctx context.Context, req *milvuspb.ListDatabasesRequest) (*milvuspb.ListDatabasesResponse, error) - - // CreateCollection notifies Proxy to create a collection - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), collection name, collection schema - // - // The `ErrorCode` of `Status` is `Success` if create collection successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error`, and the `Reason` of `Status` will record the fail cause. - // error is always nil - CreateCollection(ctx context.Context, request *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) - // DropCollection notifies Proxy to drop a collection - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved) and collection name - // - // The `ErrorCode` of `Status` is `Success` if drop collection successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error`, and the `Reason` of `Status` will record the fail cause. 
- // error is always nil - DropCollection(ctx context.Context, request *milvuspb.DropCollectionRequest) (*commonpb.Status, error) - - // HasCollection notifies Proxy to check a collection's existence at specified timestamp - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), collection name and timestamp - // - // The `Status` in response struct `BoolResponse` indicates if this operation is processed successfully or fail cause; - // the `Value` in `BoolResponse` is `true` if the collection exists at the specified timestamp, `false` otherwise. - // Timestamp is ignored if set to 0. - // error is always nil - HasCollection(ctx context.Context, request *milvuspb.HasCollectionRequest) (*milvuspb.BoolResponse, error) - - // LoadCollection notifies Proxy to load a collection's data - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), collection name - // - // The `ErrorCode` of `Status` is `Success` if load collection successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error - // error is always nil - LoadCollection(ctx context.Context, request *milvuspb.LoadCollectionRequest) (*commonpb.Status, error) - - // ReleaseCollection notifies Proxy to release a collection's data - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), collection name - // - // The `ErrorCode` of `Status` is `Success` if release collection successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error - // error is always nil - ReleaseCollection(ctx context.Context, request *milvuspb.ReleaseCollectionRequest) (*commonpb.Status, error) - - // DescribeCollection notifies Proxy to return a collection's description - // - // ctx is the context to control request deadline and cancellation - // req 
contains the request params, including database name(reserved), collection name or collection id - // - // The `Status` in response struct `DescribeCollectionResponse` indicates if this operation is processed successfully or fail cause; - // the `Schema` in `DescribeCollectionResponse` return collection's schema. - // error is always nil - DescribeCollection(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) - - // GetCollectionStatistics notifies Proxy to return a collection's statistics - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), collection name - // - // The `Status` in response struct `GetCollectionStatisticsResponse` indicates if this operation is processed successfully or fail cause; - // the `Stats` in `GetCollectionStatisticsResponse` return collection's statistics in key-value format. - // error is always nil - GetCollectionStatistics(ctx context.Context, request *milvuspb.GetCollectionStatisticsRequest) (*milvuspb.GetCollectionStatisticsResponse, error) - - // ShowCollections notifies Proxy to return collections list in current db at specified timestamp - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), timestamp - // - // The `Status` in response struct `ShowCollectionsResponse` indicates if this operation is processed successfully or fail cause; - // the `CollectionNames` in `ShowCollectionsResponse` return collection names list. - // the `CollectionIds` in `ShowCollectionsResponse` return collection ids list. 
- // error is always nil - ShowCollections(ctx context.Context, request *milvuspb.ShowCollectionsRequest) (*milvuspb.ShowCollectionsResponse, error) - - // AlterCollection notifies Proxy to create a collection - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), collection name and collection properties - // - // The `ErrorCode` of `Status` is `Success` if create collection successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error`, and the `Reason` of `Status` will record the fail cause. - // error is always nil - AlterCollection(ctx context.Context, request *milvuspb.AlterCollectionRequest) (*commonpb.Status, error) - - // CreatePartition notifies Proxy to create a partition - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), collection name, partition name - // - // The `ErrorCode` of `Status` is `Success` if create partition successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error`, and the `Reason` of `Status` will record the fail cause. - // error is always nil - CreatePartition(ctx context.Context, request *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) - - // DropPartition notifies Proxy to drop a partition - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), collection name, partition name - // - // The `ErrorCode` of `Status` is `Success` if drop partition successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error`, and the `Reason` of `Status` will record the fail cause. 
- // error is always nil - DropPartition(ctx context.Context, request *milvuspb.DropPartitionRequest) (*commonpb.Status, error) - - // HasPartition notifies Proxy to check a partition's existence - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), collection name, partition name - // - // The `ErrorCode` of `Status` is `Success` if check partition's existence successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error`, and the `Reason` of `Status` will record the fail cause. - // error is always nil - HasPartition(ctx context.Context, request *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error) - - // LoadPartitions notifies Proxy to load partition's data - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), collection name, partition names - // - // The `ErrorCode` of `Status` is `Success` if load partitions successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error - // error is always nil - LoadPartitions(ctx context.Context, request *milvuspb.LoadPartitionsRequest) (*commonpb.Status, error) - - // ReleasePartitions notifies Proxy to release collection's data - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), collection name, partition names - // - // The `ErrorCode` of `Status` is `Success` if release collection successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error - // error is always nil - ReleasePartitions(ctx context.Context, request *milvuspb.ReleasePartitionsRequest) (*commonpb.Status, error) - - // GetPartitionStatistics notifies Proxy to return a partiiton's statistics - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), 
collection name, partition name - // - // The `Status` in response struct `GetPartitionStatisticsResponse` indicates if this operation is processed successfully or fail cause; - // the `Stats` in `GetPartitionStatisticsResponse` return collection's statistics in key-value format. - // error is always nil - GetPartitionStatistics(ctx context.Context, request *milvuspb.GetPartitionStatisticsRequest) (*milvuspb.GetPartitionStatisticsResponse, error) - - // ShowPartitions notifies Proxy to return collections list in current db at specified timestamp - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), collection name, partition names(optional) - // When partition names is specified, will return these patitions's inMemory_percentages. - // - // The `Status` in response struct `ShowPartitionsResponse` indicates if this operation is processed successfully or fail cause; - // the `PartitionNames` in `ShowPartitionsResponse` return partition names list. - // the `PartitionIds` in `ShowPartitionsResponse` return partition ids list. - // the `InMemoryPercentages` in `ShowPartitionsResponse` return partitions's inMemory_percentages if the partition names of req is specified. 
- // error is always nil - ShowPartitions(ctx context.Context, request *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) - - // GetLoadingProgress get the collection or partitions loading progress - GetLoadingProgress(ctx context.Context, request *milvuspb.GetLoadingProgressRequest) (*milvuspb.GetLoadingProgressResponse, error) - // GetLoadState get the collection or partitions load state - GetLoadState(ctx context.Context, request *milvuspb.GetLoadStateRequest) (*milvuspb.GetLoadStateResponse, error) - - // CreateIndex notifies Proxy to create index of a field - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), collection name, field name, index parameters - // - // The `ErrorCode` of `Status` is `Success` if create index successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error`, and the `Reason` of `Status` will record the fail cause. - // error is always nil - CreateIndex(ctx context.Context, request *milvuspb.CreateIndexRequest) (*commonpb.Status, error) - - // DropIndex notifies Proxy to drop an index - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), collection name, field name, index name - // - // The `ErrorCode` of `Status` is `Success` if drop index successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error`, and the `Reason` of `Status` will record the fail cause. 
- // error is always nil - DropIndex(ctx context.Context, request *milvuspb.DropIndexRequest) (*commonpb.Status, error) - - // DescribeIndex notifies Proxy to return index's description - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), collection name, field name, index name - // - // The `Status` in response struct `DescribeIndexResponse` indicates if this operation is processed successfully or fail cause; - // the `IndexDescriptions` in `DescribeIndexResponse` return index's description. - // error is always nil - DescribeIndex(ctx context.Context, request *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error) - - // GetIndexStatistics notifies Proxy to return index's statistics - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), collection name, field name, index name - // - // The `Status` in response struct `GetIndexStatisticsResponse` indicates if this operation is processed successfully or fail cause; - // the `IndexDescriptions` in `GetIndexStatisticsResponse` return index's statistics. - // error is always nil - GetIndexStatistics(ctx context.Context, request *milvuspb.GetIndexStatisticsRequest) (*milvuspb.GetIndexStatisticsResponse, error) - - // GetIndexBuildProgress notifies Proxy to return index build progress - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), collection name, field name, index name - // - // The `Status` in response struct `GetIndexBuildProgressResponse` indicates if this operation is processed successfully or fail cause; - // the `IndexdRows` in `GetIndexBuildProgressResponse` return the num of indexed rows. - // the `TotalRows` in `GetIndexBuildProgressResponse` return the total number of segment rows. 
- // error is always nil - // Deprecated: use DescribeIndex instead - GetIndexBuildProgress(ctx context.Context, request *milvuspb.GetIndexBuildProgressRequest) (*milvuspb.GetIndexBuildProgressResponse, error) - - // GetIndexState notifies Proxy to return index state - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), collection name, field name, index name - // - // The `Status` in response struct `GetIndexStateResponse` indicates if this operation is processed successfully or fail cause; - // the `State` in `GetIndexStateResponse` return the state of index: Unissued/InProgress/Finished/Failed. - // error is always nil - // Deprecated: use DescribeIndex instead - GetIndexState(ctx context.Context, request *milvuspb.GetIndexStateRequest) (*milvuspb.GetIndexStateResponse, error) - - // Insert notifies Proxy to insert rows - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), collection name, partition name(optional), fields data - // - // The `Status` in response struct `MutationResult` indicates if this operation is processed successfully or fail cause; - // the `IDs` in `MutationResult` return the id list of inserted rows. - // the `SuccIndex` in `MutationResult` return the succeed number of inserted rows. - // the `ErrIndex` in `MutationResult` return the failed number of insert rows. 
- // error is always nil - Insert(ctx context.Context, request *milvuspb.InsertRequest) (*milvuspb.MutationResult, error) - - // Delete notifies Proxy to delete rows - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), collection name, partition name(optional), filter expression - // - // The `Status` in response struct `MutationResult` indicates if this operation is processed successfully or fail cause; - // the `IDs` in `MutationResult` return the id list of deleted rows. - // the `SuccIndex` in `MutationResult` return the succeed number of deleted rows. - // the `ErrIndex` in `MutationResult` return the failed number of delete rows. - // error is always nil - Delete(ctx context.Context, request *milvuspb.DeleteRequest) (*milvuspb.MutationResult, error) - - // Upsert notifies Proxy to upsert rows - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), collection name, partition name(optional), fields data - // - // The `Status` in response struct `MutationResult` indicates if this operation is processed successfully or fail cause; - // the `IDs` in `MutationResult` return the id list of upserted rows. - // the `SuccIndex` in `MutationResult` return the succeed number of upserted rows. - // the `ErrIndex` in `MutationResult` return the failed number of upsert rows. 
- // error is always nil - Upsert(ctx context.Context, request *milvuspb.UpsertRequest) (*milvuspb.MutationResult, error) - - // Search notifies Proxy to do search - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), collection name, partition name(optional), filter expression - // - // The `Status` in response struct `SearchResults` indicates if this operation is processed successfully or fail cause; - // the `Results` in `SearchResults` return search results. - // error is always nil - Search(ctx context.Context, request *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) - - // Flush notifies Proxy to flush buffer into storage - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), collection name - // - // The `Status` in response struct `FlushResponse` indicates if this operation is processed successfully or fail cause; - // error is always nil - Flush(ctx context.Context, request *milvuspb.FlushRequest) (*milvuspb.FlushResponse, error) - - // Query notifies Proxy to query rows - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), collection name, partition names(optional), filter expression, output fields - // - // The `Status` in response struct `QueryResults` indicates if this operation is processed successfully or fail cause; - // the `FieldsData` in `QueryResults` return query results. 
- // error is always nil - Query(ctx context.Context, request *milvuspb.QueryRequest) (*milvuspb.QueryResults, error) - - // CalcDistance notifies Proxy to calculate distance between specified vectors - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), vectors to calculate - // - // The `Status` in response struct `CalcDistanceResults` indicates if this operation is processed successfully or fail cause; - // The `Array` in response struct `CalcDistanceResults` return distance result - // Return generic error when specified vectors not found or float/binary vectors mismatch, otherwise return nil - CalcDistance(ctx context.Context, request *milvuspb.CalcDistanceRequest) (*milvuspb.CalcDistanceResults, error) - - // FlushAll notifies Proxy to flush all collection's DML messages, including those in message stream. - // - // ctx is the context to control request deadline and cancellation - // - // The `Status` in response struct `FlushAllResponse` indicates if this operation is processed successfully or fail cause; - // The `FlushAllTs` field in the `FlushAllResponse` response struct is used to check the flushAll state at the - // `GetFlushAllState` interface. `GetFlushAllState` would check if all DML messages before `FlushAllTs` have been flushed. 
- // error is always nil - FlushAll(ctx context.Context, request *milvuspb.FlushAllRequest) (*milvuspb.FlushAllResponse, error) - - // Not yet implemented - GetDdChannel(ctx context.Context, request *internalpb.GetDdChannelRequest) (*milvuspb.StringResponse, error) - - // GetPersistentSegmentInfo notifies Proxy to return sealed segments's information of a collection from data coord - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), collection name - // - // The `Status` in response struct `GetPersistentSegmentInfoResponse` indicates if this operation is processed successfully or fail cause; - // the `Infos` in `GetPersistentSegmentInfoResponse` return sealed segments's information of a collection. - // error is always nil - GetPersistentSegmentInfo(ctx context.Context, request *milvuspb.GetPersistentSegmentInfoRequest) (*milvuspb.GetPersistentSegmentInfoResponse, error) - - // GetQuerySegmentInfo notifies Proxy to return growing segments's information of a collection from query coord - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), collection name - // - // The `Status` in response struct `GetQuerySegmentInfoResponse` indicates if this operation is processed successfully or fail cause; - // the `Infos` in `GetQuerySegmentInfoResponse` return growing segments's information of a collection. 
- // error is always nil - GetQuerySegmentInfo(ctx context.Context, request *milvuspb.GetQuerySegmentInfoRequest) (*milvuspb.GetQuerySegmentInfoResponse, error) - - // For internal usage - Dummy(ctx context.Context, request *milvuspb.DummyRequest) (*milvuspb.DummyResponse, error) - - // RegisterLink notifies Proxy to its state code - // - // ctx is the context to control request deadline and cancellation - // - // The `Status` in response struct `RegisterLinkResponse` indicates if this proxy is healthy or not - // error is always nil - RegisterLink(ctx context.Context, request *milvuspb.RegisterLinkRequest) (*milvuspb.RegisterLinkResponse, error) - - // GetMetrics gets the metrics of the proxy. - GetMetrics(ctx context.Context, request *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) - - // LoadBalance would do a load balancing operation between query nodes. - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including source query node ids and sealed segment ids to balance - // - // The `ErrorCode` of `Status` is `Success` if load balance successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error`, and the `Reason` of `Status` will record the fail cause. - // error is always nil - LoadBalance(ctx context.Context, request *milvuspb.LoadBalanceRequest) (*commonpb.Status, error) - - // CreateAlias notifies Proxy to create alias for a collection - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), collection name, alias - // - // The `ErrorCode` of `Status` is `Success` if create alias successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error`, and the `Reason` of `Status` will record the fail cause. 
- // error is always nil - CreateAlias(ctx context.Context, request *milvuspb.CreateAliasRequest) (*commonpb.Status, error) - - // DropAlias notifies Proxy to drop an alias - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), collection name, alias - // - // The `ErrorCode` of `Status` is `Success` if drop alias successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error`, and the `Reason` of `Status` will record the fail cause. - // error is always nil - DropAlias(ctx context.Context, request *milvuspb.DropAliasRequest) (*commonpb.Status, error) - - // AlterAlias notifies Proxy to alter an alias from a colection to another - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including database name(reserved), collection name, alias - // - // The `ErrorCode` of `Status` is `Success` if alter alias successfully; - // otherwise, the `ErrorCode` of `Status` will be `Error`, and the `Reason` of `Status` will record the fail cause. - // error is always nil - AlterAlias(ctx context.Context, request *milvuspb.AlterAliasRequest) (*commonpb.Status, error) - GetCompactionState(ctx context.Context, req *milvuspb.GetCompactionStateRequest) (*milvuspb.GetCompactionStateResponse, error) - ManualCompaction(ctx context.Context, req *milvuspb.ManualCompactionRequest) (*milvuspb.ManualCompactionResponse, error) - GetCompactionStateWithPlans(ctx context.Context, req *milvuspb.GetCompactionPlansRequest) (*milvuspb.GetCompactionPlansResponse, error) - // GetFlushState gets the flush state of multiple segments - GetFlushState(ctx context.Context, req *milvuspb.GetFlushStateRequest) (*milvuspb.GetFlushStateResponse, error) - // GetFlushAllState checks if all DML messages before `FlushAllTs` have been flushed. 
- GetFlushAllState(ctx context.Context, req *milvuspb.GetFlushAllStateRequest) (*milvuspb.GetFlushAllStateResponse, error) - - // Import data files(json, numpy, etc.) on MinIO/S3 storage, read and parse them into sealed segments - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including file path and options - // - // The `Status` in response struct `ImportResponse` indicates if this operation is processed successfully or fail cause; - // the `tasks` in `ImportResponse` return an id list of tasks. - // error is always nil - Import(ctx context.Context, req *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) - - // Check import task state from datanode - // - // ctx is the context to control request deadline and cancellation - // req contains the request params, including a task id - // - // The `Status` in response struct `GetImportStateResponse` indicates if this operation is processed successfully or fail cause; - // the `state` in `GetImportStateResponse` return the state of the import task. - // error is always nil - GetImportState(ctx context.Context, req *milvuspb.GetImportStateRequest) (*milvuspb.GetImportStateResponse, error) - - // List id array of all import tasks - // - // ctx is the context to control request deadline and cancellation - // req contains the request params - // - // The `Status` in response struct `ListImportTasksResponse` indicates if this operation is processed successfully or fail cause; - // the `Tasks` in `ListImportTasksResponse` return the id array of all import tasks. 
- // error is always nil - ListImportTasks(ctx context.Context, req *milvuspb.ListImportTasksRequest) (*milvuspb.ListImportTasksResponse, error) - - GetReplicas(ctx context.Context, req *milvuspb.GetReplicasRequest) (*milvuspb.GetReplicasResponse, error) - - // CreateCredential create new user and password - CreateCredential(ctx context.Context, req *milvuspb.CreateCredentialRequest) (*commonpb.Status, error) - // UpdateCredential update password for a user - UpdateCredential(ctx context.Context, req *milvuspb.UpdateCredentialRequest) (*commonpb.Status, error) - // DeleteCredential delete a user - DeleteCredential(ctx context.Context, req *milvuspb.DeleteCredentialRequest) (*commonpb.Status, error) - // ListCredUsers list all usernames - ListCredUsers(ctx context.Context, req *milvuspb.ListCredUsersRequest) (*milvuspb.ListCredUsersResponse, error) - - CreateRole(ctx context.Context, req *milvuspb.CreateRoleRequest) (*commonpb.Status, error) - DropRole(ctx context.Context, req *milvuspb.DropRoleRequest) (*commonpb.Status, error) - OperateUserRole(ctx context.Context, req *milvuspb.OperateUserRoleRequest) (*commonpb.Status, error) - SelectRole(ctx context.Context, req *milvuspb.SelectRoleRequest) (*milvuspb.SelectRoleResponse, error) - SelectUser(ctx context.Context, req *milvuspb.SelectUserRequest) (*milvuspb.SelectUserResponse, error) - OperatePrivilege(ctx context.Context, req *milvuspb.OperatePrivilegeRequest) (*commonpb.Status, error) - SelectGrant(ctx context.Context, req *milvuspb.SelectGrantRequest) (*milvuspb.SelectGrantResponse, error) - - CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) - - // RenameCollection rename collection from old name to new name - RenameCollection(ctx context.Context, req *milvuspb.RenameCollectionRequest) (*commonpb.Status, error) - - CreateResourceGroup(ctx context.Context, req *milvuspb.CreateResourceGroupRequest) (*commonpb.Status, error) - DropResourceGroup(ctx 
context.Context, req *milvuspb.DropResourceGroupRequest) (*commonpb.Status, error) - TransferNode(ctx context.Context, req *milvuspb.TransferNodeRequest) (*commonpb.Status, error) - TransferReplica(ctx context.Context, req *milvuspb.TransferReplicaRequest) (*commonpb.Status, error) - ListResourceGroups(ctx context.Context, req *milvuspb.ListResourceGroupsRequest) (*milvuspb.ListResourceGroupsResponse, error) - DescribeResourceGroup(ctx context.Context, req *milvuspb.DescribeResourceGroupRequest) (*milvuspb.DescribeResourceGroupResponse, error) - - Connect(ctx context.Context, req *milvuspb.ConnectRequest) (*milvuspb.ConnectResponse, error) - - AllocTimestamp(ctx context.Context, req *milvuspb.AllocTimestampRequest) (*milvuspb.AllocTimestampResponse, error) +type QueryNodeClient interface { + io.Closer + querypb.QueryNodeClient } // QueryNode is the interface `querynode` package implements type QueryNode interface { Component - TimeTickProvider - - WatchDmChannels(ctx context.Context, req *querypb.WatchDmChannelsRequest) (*commonpb.Status, error) - UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmChannelRequest) (*commonpb.Status, error) - // LoadSegments notifies QueryNode to load the sealed segments from storage. The load tasks are sync to this - // rpc, QueryNode will return after all the sealed segments are loaded. - // - // Return UnexpectedError code in status: - // If QueryNode isn't in HEALTHY: states not HEALTHY or dynamic checks not HEALTHY. - // If any segment is loaded failed in QueryNode. - // Return Success code in status: - // All the sealed segments are loaded. 
- LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error) - ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) - LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) - ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) - ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) - GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) - - GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) (*internalpb.GetStatisticsResponse, error) - Search(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) - SearchSegments(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) - Query(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) - QuerySegments(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) - SyncReplicaSegments(ctx context.Context, req *querypb.SyncReplicaSegmentsRequest) (*commonpb.Status, error) - - ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) - // GetMetrics gets the metrics about QueryNode. 
- GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) - GetDataDistribution(context.Context, *querypb.GetDataDistributionRequest) (*querypb.GetDataDistributionResponse, error) - SyncDistribution(context.Context, *querypb.SyncDistributionRequest) (*commonpb.Status, error) - Delete(context.Context, *querypb.DeleteRequest) (*commonpb.Status, error) + querypb.QueryNodeServer } // QueryNodeComponent is used by grpc server of QueryNode @@ -1476,36 +288,16 @@ type QueryNodeComponent interface { SetEtcdClient(etcdClient *clientv3.Client) } +// QueryCoordClient is the client interface for querycoord server +type QueryCoordClient interface { + io.Closer + querypb.QueryCoordClient +} + // QueryCoord is the interface `querycoord` package implements type QueryCoord interface { Component - TimeTickProvider - - ShowCollections(ctx context.Context, req *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) - LoadCollection(ctx context.Context, req *querypb.LoadCollectionRequest) (*commonpb.Status, error) - ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) - ShowPartitions(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) - LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) - ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) - GetPartitionStates(ctx context.Context, req *querypb.GetPartitionStatesRequest) (*querypb.GetPartitionStatesResponse, error) - GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) - SyncNewCreatedPartition(ctx context.Context, req *querypb.SyncNewCreatedPartitionRequest) (*commonpb.Status, error) - LoadBalance(ctx context.Context, req *querypb.LoadBalanceRequest) (*commonpb.Status, error) - - ShowConfigurations(ctx context.Context, 
req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) - GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) - - GetReplicas(ctx context.Context, req *milvuspb.GetReplicasRequest) (*milvuspb.GetReplicasResponse, error) - GetShardLeaders(ctx context.Context, req *querypb.GetShardLeadersRequest) (*querypb.GetShardLeadersResponse, error) - - CheckHealth(ctx context.Context, req *milvuspb.CheckHealthRequest) (*milvuspb.CheckHealthResponse, error) - - CreateResourceGroup(ctx context.Context, req *milvuspb.CreateResourceGroupRequest) (*commonpb.Status, error) - DropResourceGroup(ctx context.Context, req *milvuspb.DropResourceGroupRequest) (*commonpb.Status, error) - TransferNode(ctx context.Context, req *milvuspb.TransferNodeRequest) (*commonpb.Status, error) - TransferReplica(ctx context.Context, req *querypb.TransferReplicaRequest) (*commonpb.Status, error) - ListResourceGroups(ctx context.Context, req *milvuspb.ListResourceGroupsRequest) (*milvuspb.ListResourceGroupsResponse, error) - DescribeResourceGroup(ctx context.Context, req *querypb.DescribeResourceGroupRequest) (*querypb.DescribeResourceGroupResponse, error) + querypb.QueryCoordServer } // QueryCoordComponent is used by grpc server of QueryCoord @@ -1519,28 +311,31 @@ type QueryCoordComponent interface { // SetEtcdClient set etcd client for QueryCoord SetEtcdClient(etcdClient *clientv3.Client) + // SetTiKVClient set TiKV client for QueryCoord + SetTiKVClient(client *txnkv.Client) + // UpdateStateCode updates state code for QueryCoord // `stateCode` is current statement of this QueryCoord, indicating whether it's healthy. UpdateStateCode(stateCode commonpb.StateCode) - // SetDataCoord set DataCoord for QueryCoord + // SetDataCoordClient set SetDataCoordClient for QueryCoord // `dataCoord` is a client of data coordinator. // // Return a generic error in status: // If the dataCoord is nil. 
// Return nil in status: // The dataCoord is not nil. - SetDataCoord(dataCoord DataCoord) error + SetDataCoordClient(dataCoord DataCoordClient) error - // SetRootCoord set RootCoord for QueryCoord + // SetRootCoordClient set SetRootCoordClient for QueryCoord // `rootCoord` is a client of root coordinator. // // Return a generic error in status: // If the rootCoord is nil. // Return nil in status: // The rootCoord is not nil. - SetRootCoord(rootCoord RootCoord) error + SetRootCoordClient(rootCoord RootCoordClient) error // SetQueryNodeCreator set QueryNode client creator func for QueryCoord - SetQueryNodeCreator(func(ctx context.Context, addr string, nodeID int64) (QueryNode, error)) + SetQueryNodeCreator(func(ctx context.Context, addr string, nodeID int64) (QueryNodeClient, error)) } diff --git a/internal/util/componentutil/componentutil.go b/internal/util/componentutil/componentutil.go index 0219c69d2a62c..96b74732d199a 100644 --- a/internal/util/componentutil/componentutil.go +++ b/internal/util/componentutil/componentutil.go @@ -21,23 +21,26 @@ import ( "fmt" "time" - "github.com/cockroachdb/errors" + "google.golang.org/grpc" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/retry" ) // WaitForComponentStates wait for component's state to be one of the specific states -func WaitForComponentStates(ctx context.Context, service types.Component, serviceName string, states []commonpb.StateCode, attempts uint, sleep time.Duration) error { +func WaitForComponentStates[T interface { + GetComponentStates(ctx context.Context, _ *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) +}](ctx context.Context, client T, serviceName string, states []commonpb.StateCode, attempts uint, sleep time.Duration) error { checkFunc := func() error { - 
resp, err := service.GetComponentStates(ctx) + resp, err := client.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) if err != nil { return err } - if resp.Status.ErrorCode != commonpb.ErrorCode_Success { - return errors.New(resp.Status.Reason) + if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + return merr.Error(resp.GetStatus()) } meet := false @@ -59,16 +62,22 @@ func WaitForComponentStates(ctx context.Context, service types.Component, servic } // WaitForComponentInitOrHealthy wait for component's state to be initializing or healthy -func WaitForComponentInitOrHealthy(ctx context.Context, service types.Component, serviceName string, attempts uint, sleep time.Duration) error { - return WaitForComponentStates(ctx, service, serviceName, []commonpb.StateCode{commonpb.StateCode_Initializing, commonpb.StateCode_Healthy}, attempts, sleep) +func WaitForComponentInitOrHealthy[T interface { + GetComponentStates(ctx context.Context, _ *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) +}](ctx context.Context, client T, serviceName string, attempts uint, sleep time.Duration) error { + return WaitForComponentStates(ctx, client, serviceName, []commonpb.StateCode{commonpb.StateCode_Initializing, commonpb.StateCode_Healthy}, attempts, sleep) } // WaitForComponentInit wait for component's state to be initializing -func WaitForComponentInit(ctx context.Context, service types.Component, serviceName string, attempts uint, sleep time.Duration) error { - return WaitForComponentStates(ctx, service, serviceName, []commonpb.StateCode{commonpb.StateCode_Initializing}, attempts, sleep) +func WaitForComponentInit[T interface { + GetComponentStates(ctx context.Context, _ *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) +}](ctx context.Context, client T, serviceName string, attempts uint, sleep time.Duration) error { + return WaitForComponentStates(ctx, client, 
serviceName, []commonpb.StateCode{commonpb.StateCode_Initializing}, attempts, sleep) } // WaitForComponentHealthy wait for component's state to be healthy -func WaitForComponentHealthy(ctx context.Context, service types.Component, serviceName string, attempts uint, sleep time.Duration) error { - return WaitForComponentStates(ctx, service, serviceName, []commonpb.StateCode{commonpb.StateCode_Healthy}, attempts, sleep) +func WaitForComponentHealthy[T interface { + GetComponentStates(ctx context.Context, _ *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) +}](ctx context.Context, client T, serviceName string, attempts uint, sleep time.Duration) error { + return WaitForComponentStates(ctx, client, serviceName, []commonpb.StateCode{commonpb.StateCode_Healthy}, attempts, sleep) } diff --git a/internal/util/componentutil/componentutil_test.go b/internal/util/componentutil/componentutil_test.go index bfa7f56b8c14f..107dc7258ab04 100644 --- a/internal/util/componentutil/componentutil_test.go +++ b/internal/util/componentutil/componentutil_test.go @@ -22,11 +22,13 @@ import ( "time" "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/stretchr/testify/assert" ) type MockComponent struct { @@ -55,11 +57,11 @@ func (mc *MockComponent) Stop() error { return nil } -func (mc *MockComponent) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { +func (mc *MockComponent) GetComponentStates(ctx context.Context, req *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { return mc.compState, mc.compErr } -func (mc *MockComponent) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, 
error) { +func (mc *MockComponent) GetStatisticsChannel(ctx context.Context, req *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { return mc.strResp, nil } diff --git a/internal/util/dependency/factory.go b/internal/util/dependency/factory.go index 6dc63263803d1..761143459781a 100644 --- a/internal/util/dependency/factory.go +++ b/internal/util/dependency/factory.go @@ -4,12 +4,13 @@ import ( "context" "github.com/cockroachdb/errors" + "go.uber.org/zap" + smsgstream "github.com/milvus-io/milvus/internal/mq/msgstream" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/paramtable" - "go.uber.org/zap" ) const ( diff --git a/internal/util/flowgraph/flow_graph.go b/internal/util/flowgraph/flow_graph.go index 8f5e9777b49af..46b6cd79b61e4 100644 --- a/internal/util/flowgraph/flow_graph.go +++ b/internal/util/flowgraph/flow_graph.go @@ -18,6 +18,7 @@ package flowgraph import ( "context" + "fmt" "sync" "github.com/cockroachdb/errors" @@ -125,3 +126,21 @@ func NewTimeTickedFlowGraph(ctx context.Context) *TimeTickedFlowGraph { return &flowGraph } + +func (fg *TimeTickedFlowGraph) AssembleNodes(orderedNodes ...Node) error { + for _, node := range orderedNodes { + fg.AddNode(node) + } + + for i, node := range orderedNodes { + // Set edge to the next node + if i < len(orderedNodes)-1 { + err := fg.SetEdges(node.Name(), []string{orderedNodes[i+1].Name()}) + if err != nil { + errMsg := fmt.Sprintf("set edges failed for flow graph, node=%s", node.Name()) + return errors.New(errMsg) + } + } + } + return nil +} diff --git a/internal/util/flowgraph/flow_graph_test.go b/internal/util/flowgraph/flow_graph_test.go index f79145e4b12ef..dbe87b371740a 100644 --- a/internal/util/flowgraph/flow_graph_test.go +++ b/internal/util/flowgraph/flow_graph_test.go @@ -24,8 +24,9 @@ import ( "testing" "time" - 
"github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/pkg/util/paramtable" ) // Flow graph basic example: count `c = pow(a) + 2` @@ -137,7 +138,7 @@ func createExampleFlowGraph() (*TimeTickedFlowGraph, chan float64, chan float64, fg.AddNode(b) fg.AddNode(c) - var err = fg.SetEdges(a.Name(), + err := fg.SetEdges(a.Name(), []string{b.Name()}, ) if err != nil { diff --git a/internal/util/flowgraph/input_node.go b/internal/util/flowgraph/input_node.go index a847c983664ba..6100f7cdf21ca 100644 --- a/internal/util/flowgraph/input_node.go +++ b/internal/util/flowgraph/input_node.go @@ -21,17 +21,16 @@ import ( "fmt" "sync" - "github.com/milvus-io/milvus/pkg/util/tsoutil" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/trace" - - "github.com/milvus-io/milvus/pkg/util/typeutil" + "go.uber.org/atomic" + "go.uber.org/zap" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/mq/msgstream" - "go.uber.org/atomic" - "go.uber.org/zap" + "github.com/milvus-io/milvus/pkg/util/tsoutil" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) const ( diff --git a/internal/util/flowgraph/input_node_test.go b/internal/util/flowgraph/input_node_test.go index 6d2ce6edc4741..84c1c396ed68c 100644 --- a/internal/util/flowgraph/input_node_test.go +++ b/internal/util/flowgraph/input_node_test.go @@ -20,9 +20,9 @@ import ( "context" "testing" - "github.com/milvus-io/milvus/internal/util/dependency" "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" ) @@ -32,7 +32,7 @@ func TestInputNode(t *testing.T) { msgStream, _ := factory.NewMsgStream(context.TODO()) channels := []string{"cc"} - msgStream.AsConsumer(channels, "sub", mqwrapper.SubscriptionPositionEarliest) + msgStream.AsConsumer(context.Background(), channels, "sub", mqwrapper.SubscriptionPositionEarliest) msgPack := 
generateMsgPack() produceStream, _ := factory.NewMsgStream(context.TODO()) diff --git a/internal/util/flowgraph/message_test.go b/internal/util/flowgraph/message_test.go index adb35ac1a5759..720aa631f27c6 100644 --- a/internal/util/flowgraph/message_test.go +++ b/internal/util/flowgraph/message_test.go @@ -20,8 +20,9 @@ import ( "context" "testing" - "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/pkg/mq/msgstream" ) type MockMsg struct { @@ -77,7 +78,6 @@ func (bm *MockMsg) Position() *MsgPosition { } func (bm *MockMsg) SetPosition(position *MsgPosition) { - } func (bm *MockMsg) Size() int { diff --git a/internal/util/flowgraph/node.go b/internal/util/flowgraph/node.go index 75cbe4980d4ae..9962b5c765e66 100644 --- a/internal/util/flowgraph/node.go +++ b/internal/util/flowgraph/node.go @@ -21,16 +21,18 @@ import ( "sync" "time" - "github.com/milvus-io/milvus/pkg/util/timerecord" + "go.uber.org/zap" "github.com/milvus-io/milvus/pkg/log" - "go.uber.org/zap" + "github.com/milvus-io/milvus/pkg/util/timerecord" ) const ( // TODO: better to be configured nodeCtxTtInterval = 2 * time.Minute enableTtChecker = true + // blockAll should wait no more than 10 seconds + blockAllWait = 10 * time.Second ) // Node is the interface defines the behavior of flowgraph @@ -74,7 +76,13 @@ func (nodeCtx *nodeCtx) Start() { func (nodeCtx *nodeCtx) Block() { // input node operate function will be blocking if !nodeCtx.node.IsInputNode() { + startTs := time.Now() nodeCtx.blockMutex.Lock() + if time.Since(startTs) >= blockAllWait { + log.Warn("flow graph wait for long time", + zap.String("name", nodeCtx.node.Name()), + zap.Duration("wait time", time.Since(startTs))) + } } } @@ -200,6 +208,11 @@ func (node *BaseNode) IsValidInMsg(in []Msg) bool { return false } + if len(in) == 0 { + // avoid printing too many logs. 
+ return false + } + if len(in) != 1 { log.Warn("Invalid operate message input", zap.Int("input length", len(in))) return false diff --git a/internal/util/flowgraph/node_test.go b/internal/util/flowgraph/node_test.go index 9900f98848e7e..a88dbcc151f14 100644 --- a/internal/util/flowgraph/node_test.go +++ b/internal/util/flowgraph/node_test.go @@ -62,7 +62,7 @@ func TestNodeCtx_Start(t *testing.T) { msgStream, _ := factory.NewMsgStream(context.TODO()) channels := []string{"cc"} - msgStream.AsConsumer(channels, "sub", mqwrapper.SubscriptionPositionEarliest) + msgStream.AsConsumer(context.TODO(), channels, "sub", mqwrapper.SubscriptionPositionEarliest) produceStream, _ := factory.NewMsgStream(context.TODO()) produceStream.AsProducer(channels) diff --git a/internal/util/funcutil/count_util.go b/internal/util/funcutil/count_util.go index 9521f7fdef22c..f00b3c430d896 100644 --- a/internal/util/funcutil/count_util.go +++ b/internal/util/funcutil/count_util.go @@ -3,9 +3,8 @@ package funcutil import ( "fmt" - "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/segcorepb" diff --git a/internal/util/funcutil/count_util_test.go b/internal/util/funcutil/count_util_test.go index 620582c571707..db4c430d72b5d 100644 --- a/internal/util/funcutil/count_util_test.go +++ b/internal/util/funcutil/count_util_test.go @@ -3,11 +3,11 @@ package funcutil import ( "testing" - "github.com/milvus-io/milvus/internal/proto/segcorepb" + "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/internalpb" - "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus/internal/proto/segcorepb" ) func TestCntOfInternalResult(t *testing.T) { @@ 
-20,7 +20,6 @@ func TestCntOfInternalResult(t *testing.T) { }) t.Run("normal case", func(t *testing.T) { - res := WrapCntToInternalResult(5) cnt, err := CntOfInternalResult(res) assert.NoError(t, err) @@ -38,7 +37,6 @@ func TestCntOfSegCoreResult(t *testing.T) { }) t.Run("normal case", func(t *testing.T) { - res := WrapCntToSegCoreResult(5) cnt, err := CntOfSegCoreResult(res) assert.NoError(t, err) @@ -56,7 +54,6 @@ func TestCntOfFieldData(t *testing.T) { }) t.Run("not long data", func(t *testing.T) { - f := &schemapb.FieldData{ Field: &schemapb.FieldData_Scalars{ Scalars: &schemapb.ScalarField{ @@ -69,7 +66,6 @@ func TestCntOfFieldData(t *testing.T) { }) t.Run("more than one row", func(t *testing.T) { - f := &schemapb.FieldData{ Field: &schemapb.FieldData_Scalars{ Scalars: &schemapb.ScalarField{ @@ -86,7 +82,6 @@ func TestCntOfFieldData(t *testing.T) { }) t.Run("more than one row", func(t *testing.T) { - f := WrapCntToFieldData(1000) cnt, err := CntOfFieldData(f) assert.NoError(t, err) diff --git a/internal/util/grpcclient/client.go b/internal/util/grpcclient/client.go index 7dc8a12685ed7..058d6cf0d303e 100644 --- a/internal/util/grpcclient/client.go +++ b/internal/util/grpcclient/client.go @@ -19,22 +19,23 @@ package grpcclient import ( "context" "crypto/tls" - "fmt" "strings" "sync" "time" + "github.com/cockroachdb/errors" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "go.uber.org/atomic" "go.uber.org/zap" - "golang.org/x/sync/singleflight" "google.golang.org/grpc" "google.golang.org/grpc/backoff" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/log" @@ -46,6 +47,7 @@ 
import ( "github.com/milvus-io/milvus/pkg/util/interceptor" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/retry" ) // GrpcClient abstracts client of grpc @@ -88,31 +90,42 @@ type ClientBase[T interface { KeepAliveTime time.Duration KeepAliveTimeout time.Duration - MaxAttempts int - InitialBackoff float32 - MaxBackoff float32 - BackoffMultiplier float32 - NodeID atomic.Int64 - sess *sessionutil.Session - - sf singleflight.Group + MaxAttempts int + InitialBackoff float64 + MaxBackoff float64 + // resetInterval is the minimal duration to reset connection + minResetInterval time.Duration + lastReset atomic.Time + // sessionCheckInterval is the minmal duration to check session, preventing too much etcd pulll + minSessionCheckInterval time.Duration + lastSessionCheck atomic.Time + + // counter for canceled or deadline exceeded + ctxCounter atomic.Int32 + maxCancelError int32 + + NodeID atomic.Int64 + sess *sessionutil.Session } func NewClientBase[T interface { GetComponentStates(ctx context.Context, in *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) -}](config *paramtable.GrpcClientConfig, serviceName string) *ClientBase[T] { +}](config *paramtable.GrpcClientConfig, serviceName string, +) *ClientBase[T] { return &ClientBase[T]{ - ClientMaxRecvSize: config.ClientMaxRecvSize.GetAsInt(), - ClientMaxSendSize: config.ClientMaxSendSize.GetAsInt(), - DialTimeout: config.DialTimeout.GetAsDuration(time.Millisecond), - KeepAliveTime: config.KeepAliveTime.GetAsDuration(time.Millisecond), - KeepAliveTimeout: config.KeepAliveTimeout.GetAsDuration(time.Millisecond), - RetryServiceNameConfig: serviceName, - MaxAttempts: config.MaxAttempts.GetAsInt(), - InitialBackoff: float32(config.InitialBackoff.GetAsFloat()), - MaxBackoff: float32(config.MaxBackoff.GetAsFloat()), - BackoffMultiplier: float32(config.BackoffMultiplier.GetAsFloat()), - 
CompressionEnabled: config.CompressionEnabled.GetAsBool(), + ClientMaxRecvSize: config.ClientMaxRecvSize.GetAsInt(), + ClientMaxSendSize: config.ClientMaxSendSize.GetAsInt(), + DialTimeout: config.DialTimeout.GetAsDuration(time.Millisecond), + KeepAliveTime: config.KeepAliveTime.GetAsDuration(time.Millisecond), + KeepAliveTimeout: config.KeepAliveTimeout.GetAsDuration(time.Millisecond), + RetryServiceNameConfig: serviceName, + MaxAttempts: config.MaxAttempts.GetAsInt(), + InitialBackoff: config.InitialBackoff.GetAsFloat(), + MaxBackoff: config.MaxBackoff.GetAsFloat(), + CompressionEnabled: config.CompressionEnabled.GetAsBool(), + minResetInterval: config.MinResetInterval.GetAsDuration(time.Millisecond), + minSessionCheckInterval: config.MinSessionCheckInterval.GetAsDuration(time.Millisecond), + maxCancelError: config.MaxCancelError.GetAsInt32(), } } @@ -171,8 +184,14 @@ func (c *ClientBase[T]) GetGrpcClient(ctx context.Context) (T, error) { } func (c *ClientBase[T]) resetConnection(client T) { + if time.Since(c.lastReset.Load()) < c.minResetInterval { + return + } c.grpcClientMtx.Lock() defer c.grpcClientMtx.Unlock() + if time.Since(c.lastReset.Load()) < c.minResetInterval { + return + } if generic.IsZero(c.grpcClient) { return } @@ -185,6 +204,7 @@ func (c *ClientBase[T]) resetConnection(client T) { c.conn = nil c.addr.Store("") c.grpcClient = generic.Zero[T]() + c.lastReset.Store(time.Now()) } func (c *ClientBase[T]) connect(ctx context.Context) error { @@ -196,18 +216,6 @@ func (c *ClientBase[T]) connect(ctx context.Context) error { opts := tracer.GetInterceptorOpts() dialContext, cancel := context.WithTimeout(ctx, c.DialTimeout) - // refer to https://github.com/grpc/grpc-proto/blob/master/grpc/service_config/service_config.proto - retryPolicy := fmt.Sprintf(`{ - "methodConfig": [{ - "name": [{"service": "%s"}], - "retryPolicy": { - "MaxAttempts": %d, - "InitialBackoff": "%fs", - "MaxBackoff": "%fs", - "BackoffMultiplier": %f, - "RetryableStatusCodes": [ 
"UNAVAILABLE" ] - } - }]}`, c.RetryServiceNameConfig, c.MaxAttempts, c.InitialBackoff, c.MaxBackoff, c.BackoffMultiplier) var conn *grpc.ClientConn compress := None @@ -236,7 +244,6 @@ func (c *ClientBase[T]) connect(ctx context.Context) error { interceptor.ClusterInjectionStreamClientInterceptor(), interceptor.ServerIDInjectionStreamClientInterceptor(c.GetNodeID()), )), - grpc.WithDefaultServiceConfig(retryPolicy), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: c.KeepAliveTime, Timeout: c.KeepAliveTimeout, @@ -254,6 +261,7 @@ func (c *ClientBase[T]) connect(ctx context.Context) error { grpc.WithPerRPCCredentials(&Token{Value: crypto.Base64Encode(util.MemberCredID)}), grpc.FailOnNonTempDialError(true), grpc.WithReturnConnectionError(), + grpc.WithDisableRetry(), ) } else { conn, err = grpc.DialContext( @@ -276,7 +284,6 @@ func (c *ClientBase[T]) connect(ctx context.Context) error { interceptor.ClusterInjectionStreamClientInterceptor(), interceptor.ServerIDInjectionStreamClientInterceptor(c.GetNodeID()), )), - grpc.WithDefaultServiceConfig(retryPolicy), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: c.KeepAliveTime, Timeout: c.KeepAliveTimeout, @@ -294,6 +301,7 @@ func (c *ClientBase[T]) connect(ctx context.Context) error { grpc.WithPerRPCCredentials(&Token{Value: crypto.Base64Encode(util.MemberCredID)}), grpc.FailOnNonTempDialError(true), grpc.WithReturnConnectionError(), + grpc.WithDisableRetry(), ) } @@ -307,59 +315,169 @@ func (c *ClientBase[T]) connect(ctx context.Context) error { c.conn = conn c.addr.Store(addr) + c.ctxCounter.Store(0) c.grpcClient = c.newGrpcClient(c.conn) return nil } -func (c *ClientBase[T]) callOnce(ctx context.Context, caller func(client T) (any, error)) (any, error) { - log := log.Ctx(ctx).With(zap.String("role", c.GetRole())) - client, err := c.GetGrpcClient(ctx) - if err != nil { - return generic.Zero[T](), err +func (c *ClientBase[T]) verifySession(ctx context.Context) error { + if funcutil.CheckCtxValid(ctx) { 
+ return nil + } + log := log.Ctx(ctx).With(zap.String("clientRole", c.GetRole())) + if time.Since(c.lastSessionCheck.Load()) < c.minSessionCheckInterval { + log.Debug("skip session check, verify too frequent") + return nil } + c.lastSessionCheck.Store(time.Now()) + if c.sess != nil { + sessions, _, getSessionErr := c.sess.GetSessions(c.GetRole()) + if getSessionErr != nil { + // Only log but not handle this error as it is an auxiliary logic + log.Warn("fail to get session", zap.Error(getSessionErr)) + } + if coordSess, exist := sessions[c.GetRole()]; exist { + if c.GetNodeID() != coordSess.ServerID { + log.Warn("server id mismatch, may connected to a old server, start to reset connection", + zap.Int64("client_node", c.GetNodeID()), + zap.Int64("current_node", coordSess.ServerID)) + return merr.WrapErrNodeNotMatch(c.GetNodeID(), coordSess.ServerID) + } + } else { + return merr.WrapErrNodeNotFound(c.GetNodeID(), "session not found", c.GetRole()) + } + } + return nil +} - ret, err := caller(client) - if err == nil { - return ret, nil +func (c *ClientBase[T]) needResetCancel() (needReset bool) { + val := c.ctxCounter.Add(1) + if val > c.maxCancelError { + c.ctxCounter.Store(0) + return true } + return false +} - if IsCrossClusterRoutingErr(err) { - log.Warn("CrossClusterRoutingErr, start to reset connection", zap.Error(err)) - c.resetConnection(client) - return ret, merr.ErrServiceUnavailable // For concealing ErrCrossClusterRouting from the client +func (c *ClientBase[T]) checkErr(ctx context.Context, err error) (needRetry, needReset bool, retErr error) { + log := log.Ctx(ctx).With(zap.String("clientRole", c.GetRole())) + switch { + case funcutil.IsGrpcErr(err): + // grpc err + log.Warn("call received grpc error", zap.Error(err)) + if funcutil.IsGrpcErr(err, codes.Canceled, codes.DeadlineExceeded) { + // canceled or deadline exceeded + return true, c.needResetCancel(), err + } + + if funcutil.IsGrpcErr(err, codes.Unimplemented) { + return false, false, 
merr.WrapErrServiceUnimplemented(err) + } + return true, true, err + case IsServerIDMismatchErr(err): + fallthrough + case IsCrossClusterRoutingErr(err): + return true, true, err + default: + log.Warn("fail to grpc call because of unknown error", zap.Error(err)) + // Unknown err + return false, false, err + } +} + +func (c *ClientBase[T]) call(ctx context.Context, caller func(client T) (any, error)) (any, error) { + log := log.Ctx(ctx).With(zap.String("client_role", c.GetRole())) + var ( + ret any + clientErr error + client T + ) + + client, clientErr = c.GetGrpcClient(ctx) + if clientErr != nil { + log.Warn("fail to get grpc client", zap.Error(clientErr)) } - if IsServerIDMismatchErr(err) { - log.Warn("Server ID mismatch, start to reset connection", zap.Error(err)) + + resetClientFunc := func() { c.resetConnection(client) - return ret, err + client, clientErr = c.GetGrpcClient(ctx) + if clientErr != nil { + log.Warn("fail to get grpc client in the retry state", zap.Error(clientErr)) + } } - if !funcutil.CheckCtxValid(ctx) { - // check if server ID matches coord session, if not, reset connection - if c.sess != nil { - sessions, _, getSessionErr := c.sess.GetSessions(c.GetRole()) - if getSessionErr != nil { - // Only log but not handle this error as it is an auxiliary logic - log.Warn("Fail to GetSessions", zap.Error(getSessionErr)) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + err := retry.Do(ctx, func() error { + if generic.IsZero(client) { + err := errors.Wrap(clientErr, "empty grpc client") + log.Warn("grpc client is nil, maybe fail to get client in the retry state", zap.Error(err)) + resetClientFunc() + return err + } + var err error + ret, err = caller(client) + if err != nil { + var needRetry, needReset bool + needRetry, needReset, err = c.checkErr(ctx, err) + if !needRetry { + // stop retry + err = retry.Unrecoverable(err) } - if coordSess, exist := sessions[c.GetRole()]; exist { - if c.GetNodeID() != coordSess.ServerID { - log.Warn("Server 
ID mismatch, may connected to a old server, start to reset connection", zap.Error(err)) - c.resetConnection(client) - return ret, err + if needReset { + log.Warn("start to reset connection because of specific reasons", zap.Error(err)) + resetClientFunc() + } else { + err := c.verifySession(ctx) + if err != nil { + log.Warn("failed to verify session, reset connection", zap.Error(err)) + resetClientFunc() } } + return err } - // start bg check in case of https://github.com/milvus-io/milvus/issues/22435 - go c.bgHealthCheck(client) - return generic.Zero[T](), err - } - if !funcutil.IsGrpcErr(err) { - log.Warn("ClientBase:isNotGrpcErr", zap.Error(err)) + // reset counter + c.ctxCounter.Store(0) + + var status *commonpb.Status + switch res := ret.(type) { + case *commonpb.Status: + status = res + case interface{ GetStatus() *commonpb.Status }: + status = res.GetStatus() + default: + // it will directly return the result + log.Warn("unknown return type", zap.Any("return", ret)) + return nil + } + + if status == nil { + log.Warn("status is nil, please fix it", zap.Stack("stack")) + return nil + } + + err = merr.Error(status) + if err != nil && merr.IsRetryableErr(err) { + return err + } + return nil + }, retry.Attempts(uint(c.MaxAttempts)), + // Because the previous InitialBackoff and MaxBackoff were float, and the unit was s. + // For compatibility, this is multiplied by 1000. 
+ retry.Sleep(time.Duration(c.InitialBackoff*1000)*time.Millisecond), + retry.MaxSleepTime(time.Duration(c.MaxBackoff*1000)*time.Millisecond)) + // default value list: MaxAttempts 10, InitialBackoff 0.2s, MaxBackoff 10s + // and consume 52.8s if all retry failed + if err != nil { + // make the error more friendly to user + if IsCrossClusterRoutingErr(err) { + err = merr.ErrServiceUnavailable + } + return generic.Zero[T](), err } - log.Info("ClientBase grpc error, start to reset connection", zap.Error(err)) - c.resetConnection(client) - return ret, err + + return ret, nil } // Call does a grpc call @@ -368,10 +486,10 @@ func (c *ClientBase[T]) Call(ctx context.Context, caller func(client T) (any, er return generic.Zero[T](), ctx.Err() } - ret, err := c.callOnce(ctx, caller) + ret, err := c.call(ctx, caller) if err != nil { - traceErr := fmt.Errorf("err: %w\n, %s", err, tracer.StackTrace()) - log.Ctx(ctx).Warn("ClientBase Call grpc first call get error", + traceErr := errors.Wrapf(err, "stack trace: %s", tracer.StackTrace()) + log.Ctx(ctx).Warn("ClientBase Call grpc call get error", zap.String("role", c.GetRole()), zap.String("address", c.GetAddr()), zap.Error(traceErr), @@ -383,44 +501,8 @@ func (c *ClientBase[T]) Call(ctx context.Context, caller func(client T) (any, er // ReCall does the grpc call twice func (c *ClientBase[T]) ReCall(ctx context.Context, caller func(client T) (any, error)) (any, error) { - if !funcutil.CheckCtxValid(ctx) { - return generic.Zero[T](), ctx.Err() - } - - ret, err := c.callOnce(ctx, caller) - if err == nil { - return ret, nil - } - - log := log.Ctx(ctx).With(zap.String("role", c.GetRole()), zap.String("address", c.GetAddr())) - traceErr := fmt.Errorf("err: %w\n, %s", err, tracer.StackTrace()) - log.Warn("ClientBase ReCall grpc first call get error ", zap.Error(traceErr)) - - if !funcutil.CheckCtxValid(ctx) { - return generic.Zero[T](), ctx.Err() - } - - ret, err = c.callOnce(ctx, caller) - if err != nil { - traceErr = fmt.Errorf("err: 
%w\n, %s", err, tracer.StackTrace()) - log.Warn("ClientBase ReCall grpc second call get error", zap.Error(traceErr)) - return generic.Zero[T](), traceErr - } - return ret, err -} - -func (c *ClientBase[T]) bgHealthCheck(client T) { - c.sf.Do("healthcheck", func() (any, error) { - ctx, cancel := context.WithTimeout(context.Background(), paramtable.Get().CommonCfg.SessionTTL.GetAsDuration(time.Second)) - defer cancel() - - _, err := client.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) - if err != nil { - c.resetConnection(client) - } - - return struct{}{}, nil - }) + // All retry operations are done in `call` function. + return c.Call(ctx, caller) } // Close close the client connection @@ -451,7 +533,7 @@ func (c *ClientBase[T]) SetSession(sess *sessionutil.Session) { func IsCrossClusterRoutingErr(err error) bool { // GRPC utilizes `status.Status` to encapsulate errors, // hence it is not viable to employ the `errors.Is` for assessment. - return strings.Contains(err.Error(), merr.ErrCrossClusterRouting.Error()) + return strings.Contains(err.Error(), merr.ErrServiceCrossClusterRouting.Error()) } func IsServerIDMismatchErr(err error) bool { diff --git a/internal/util/grpcclient/client_test.go b/internal/util/grpcclient/client_test.go index 17272e70bcbd0..dd66b57bd702c 100644 --- a/internal/util/grpcclient/client_test.go +++ b/internal/util/grpcclient/client_test.go @@ -23,25 +23,23 @@ import ( "net" "os" "strings" - "sync" "testing" "time" "github.com/cockroachdb/errors" - + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/examples/helloworld/helloworld" "google.golang.org/grpc/keepalive" + "google.golang.org/grpc/reflection" + "google.golang.org/grpc/status" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" 
"github.com/milvus-io/milvus/pkg/util/typeutil" - "google.golang.org/grpc" - "google.golang.org/grpc/reflection" - - "github.com/stretchr/testify/assert" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" ) func TestMain(m *testing.M) { @@ -104,13 +102,23 @@ func TestClientBase_CompressCall(t *testing.T) { func testCall(t *testing.T, compressed bool) { // mock client with nothing - base := ClientBase[*mockClient]{} + base := ClientBase[*mockClient]{ + maxCancelError: 10, + MaxAttempts: 3, + } base.CompressionEnabled = compressed - base.grpcClientMtx.Lock() - base.grpcClient = &mockClient{} - base.grpcClientMtx.Unlock() + initClient := func() { + base.grpcClientMtx.Lock() + base.grpcClient = &mockClient{} + base.grpcClientMtx.Unlock() + } + base.MaxAttempts = 1 + base.SetGetAddrFunc(func() (string, error) { + return "", errors.New("mocked address error") + }) t.Run("Call normal return", func(t *testing.T) { + initClient() _, err := base.Call(context.Background(), func(client *mockClient) (any, error) { return struct{}{}, nil }) @@ -118,6 +126,7 @@ func testCall(t *testing.T, compressed bool) { }) t.Run("Call with canceled context", func(t *testing.T) { + initClient() ctx, cancel := context.WithCancel(context.Background()) cancel() _, err := base.Call(ctx, func(client *mockClient) (any, error) { @@ -128,6 +137,7 @@ func testCall(t *testing.T, compressed bool) { }) t.Run("Call canceled in caller func", func(t *testing.T) { + initClient() ctx, cancel := context.WithCancel(context.Background()) errMock := errors.New("mocked") _, err := base.Call(ctx, func(client *mockClient) (any, error) { @@ -143,11 +153,13 @@ func testCall(t *testing.T, compressed bool) { base.grpcClientMtx.RUnlock() }) - t.Run("Call canceled in caller func", func(t *testing.T) { + t.Run("Call returns non-grpc error", func(t *testing.T) { + initClient() ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + errMock := errors.New("mocked") _, err := base.Call(ctx, 
func(client *mockClient) (any, error) { - cancel() return nil, errMock }) @@ -159,26 +171,31 @@ func testCall(t *testing.T, compressed bool) { base.grpcClientMtx.RUnlock() }) - t.Run("Call returns non-grpc error", func(t *testing.T) { + t.Run("Call returns Unavailable grpc error", func(t *testing.T) { + initClient() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - errMock := errors.New("mocked") + errGrpc := status.Error(codes.Unavailable, "mocked") _, err := base.Call(ctx, func(client *mockClient) (any, error) { - return nil, errMock + return nil, errGrpc }) assert.Error(t, err) - assert.True(t, errors.Is(err, errMock)) + assert.True(t, errors.Is(err, errGrpc)) base.grpcClientMtx.RLock() // client shall not be reset - assert.NotNil(t, base.grpcClient) + assert.Nil(t, base.grpcClient) base.grpcClientMtx.RUnlock() }) - t.Run("Call returns grpc error", func(t *testing.T) { + t.Run("Call returns canceled grpc error within limit", func(t *testing.T) { + initClient() ctx, cancel := context.WithCancel(context.Background()) defer cancel() - errGrpc := status.Error(codes.Unknown, "mocked") + defer func() { + base.ctxCounter.Store(0) + }() + errGrpc := status.Error(codes.Canceled, "mocked") _, err := base.Call(ctx, func(client *mockClient) (any, error) { return nil, errGrpc }) @@ -187,9 +204,28 @@ func testCall(t *testing.T, compressed bool) { assert.True(t, errors.Is(err, errGrpc)) base.grpcClientMtx.RLock() // client shall not be reset - assert.Nil(t, base.grpcClient) + assert.NotNil(t, base.grpcClient) base.grpcClientMtx.RUnlock() + }) + t.Run("Call returns canceled grpc error exceed limit", func(t *testing.T) { + initClient() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + base.ctxCounter.Store(10) + defer func() { + base.ctxCounter.Store(0) + }() + errGrpc := status.Error(codes.Canceled, "mocked") + _, err := base.Call(ctx, func(client *mockClient) (any, error) { + return nil, errGrpc + }) + assert.Error(t, err) + 
assert.True(t, errors.Is(err, errGrpc)) + base.grpcClientMtx.RLock() + // client shall not be reset + assert.Nil(t, base.grpcClient) + base.grpcClientMtx.RUnlock() }) base.grpcClientMtx.Lock() @@ -211,11 +247,18 @@ func testCall(t *testing.T, compressed bool) { func TestClientBase_Recall(t *testing.T) { // mock client with nothing base := ClientBase[*mockClient]{} - base.grpcClientMtx.Lock() - base.grpcClient = &mockClient{} - base.grpcClientMtx.Unlock() + initClient := func() { + base.grpcClientMtx.Lock() + base.grpcClient = &mockClient{} + base.grpcClientMtx.Unlock() + } + base.MaxAttempts = 1 + base.SetGetAddrFunc(func() (string, error) { + return "", errors.New("mocked address error") + }) t.Run("Recall normal return", func(t *testing.T) { + initClient() _, err := base.ReCall(context.Background(), func(client *mockClient) (any, error) { return struct{}{}, nil }) @@ -223,6 +266,7 @@ func TestClientBase_Recall(t *testing.T) { }) t.Run("ReCall with canceled context", func(t *testing.T) { + initClient() ctx, cancel := context.WithCancel(context.Background()) cancel() _, err := base.ReCall(ctx, func(client *mockClient) (any, error) { @@ -232,24 +276,8 @@ func TestClientBase_Recall(t *testing.T) { assert.True(t, errors.Is(err, context.Canceled)) }) - t.Run("ReCall fails first and success second", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - flag := false - var mut sync.Mutex - _, err := base.ReCall(ctx, func(client *mockClient) (any, error) { - mut.Lock() - defer mut.Unlock() - if flag { - return struct{}{}, nil - } - flag = true - return nil, errors.New("mock first") - }) - assert.NoError(t, err) - }) - t.Run("ReCall canceled in caller func", func(t *testing.T) { + initClient() ctx, cancel := context.WithCancel(context.Background()) errMock := errors.New("mocked") _, err := base.ReCall(ctx, func(client *mockClient) (any, error) { @@ -258,7 +286,7 @@ func TestClientBase_Recall(t *testing.T) { }) assert.Error(t, 
err) - assert.True(t, errors.Is(err, context.Canceled)) + assert.True(t, errors.Is(err, errMock)) base.grpcClientMtx.RLock() // client shall not be reset assert.NotNil(t, base.grpcClient) @@ -279,7 +307,21 @@ func TestClientBase_Recall(t *testing.T) { assert.Error(t, err) assert.True(t, errors.Is(err, ErrConnect)) }) +} + +func TestClientBase_CheckError(t *testing.T) { + base := ClientBase[*mockClient]{} + base.grpcClient = &mockClient{} + base.MaxAttempts = 1 + ctx := context.Background() + retry, reset, _ := base.checkErr(ctx, status.Errorf(codes.Canceled, "fake context canceled")) + assert.True(t, retry) + assert.True(t, reset) + + retry, reset, _ = base.checkErr(ctx, status.Errorf(codes.Unimplemented, "fake context canceled")) + assert.False(t, retry) + assert.False(t, reset) } type server struct { @@ -305,16 +347,16 @@ func TestClientBase_RetryPolicy(t *testing.T) { if err != nil { log.Fatalf("failed to listen: %v", err) } - var kaep = keepalive.EnforcementPolicy{ + kaep := keepalive.EnforcementPolicy{ MinTime: 5 * time.Second, PermitWithoutStream: true, } - var kasp = keepalive.ServerParameters{ + kasp := keepalive.ServerParameters{ Time: 60 * time.Second, Timeout: 60 * time.Second, } - maxAttempts := 5 + maxAttempts := 1 s := grpc.NewServer( grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp), @@ -338,7 +380,6 @@ func TestClientBase_RetryPolicy(t *testing.T) { MaxAttempts: maxAttempts, InitialBackoff: 10.0, MaxBackoff: 60.0, - BackoffMultiplier: 2.0, } clientBase.SetRole(typeutil.DataCoordRole) clientBase.SetGetAddrFunc(func() (string, error) { @@ -352,9 +393,11 @@ func TestClientBase_RetryPolicy(t *testing.T) { ctx := context.Background() randID := rand.Int63() res, err := clientBase.Call(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { - return &milvuspb.ComponentStates{State: &milvuspb.ComponentInfo{ - NodeID: randID, - }}, nil + return &milvuspb.ComponentStates{ + State: &milvuspb.ComponentInfo{ + NodeID: randID, + }, + }, nil 
}) assert.NoError(t, err) assert.Equal(t, res.(*milvuspb.ComponentStates).GetState().GetNodeID(), randID) @@ -366,16 +409,16 @@ func TestClientBase_Compression(t *testing.T) { if err != nil { log.Fatalf("failed to listen: %v", err) } - var kaep = keepalive.EnforcementPolicy{ + kaep := keepalive.EnforcementPolicy{ MinTime: 5 * time.Second, PermitWithoutStream: true, } - var kasp = keepalive.ServerParameters{ + kasp := keepalive.ServerParameters{ Time: 60 * time.Second, Timeout: 60 * time.Second, } - maxAttempts := 5 + maxAttempts := 1 s := grpc.NewServer( grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp), @@ -399,7 +442,6 @@ func TestClientBase_Compression(t *testing.T) { MaxAttempts: maxAttempts, InitialBackoff: 10.0, MaxBackoff: 60.0, - BackoffMultiplier: 2.0, CompressionEnabled: true, } clientBase.SetRole(typeutil.DataCoordRole) @@ -414,9 +456,12 @@ func TestClientBase_Compression(t *testing.T) { ctx := context.Background() randID := rand.Int63() res, err := clientBase.Call(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { - return &milvuspb.ComponentStates{State: &milvuspb.ComponentInfo{ - NodeID: randID, - }}, nil + return &milvuspb.ComponentStates{ + State: &milvuspb.ComponentInfo{ + NodeID: randID, + }, + Status: merr.Success(), + }, nil }) assert.NoError(t, err) assert.Equal(t, res.(*milvuspb.ComponentStates).GetState().GetNodeID(), randID) diff --git a/internal/util/grpcclient/grpc_encoder.go b/internal/util/grpcclient/grpc_encoder.go index fbc7c1369d47f..276008eae709c 100644 --- a/internal/util/grpcclient/grpc_encoder.go +++ b/internal/util/grpcclient/grpc_encoder.go @@ -25,8 +25,10 @@ import ( "google.golang.org/grpc/encoding" ) -const None = "" -const Zstd = "zstd" +const ( + None = "" + Zstd = "zstd" +) type grpcCompressor struct { encoder *zstd.Encoder diff --git a/internal/util/importutil/binlog_adapter.go b/internal/util/importutil/binlog_adapter.go index 2b14094b108ec..b06428e587573 100644 --- 
a/internal/util/importutil/binlog_adapter.go +++ b/internal/util/importutil/binlog_adapter.go @@ -24,13 +24,13 @@ import ( "strings" "github.com/cockroachdb/errors" + "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/typeutil" - "go.uber.org/zap" ) // A struct to hold insert log paths and delta log paths of a segment @@ -74,18 +74,13 @@ func NewBinlogAdapter(ctx context.Context, chunkManager storage.ChunkManager, flushFunc ImportFlushFunc, tsStartPoint uint64, - tsEndPoint uint64) (*BinlogAdapter, error) { + tsEndPoint uint64, +) (*BinlogAdapter, error) { if collectionInfo == nil { log.Warn("Binlog adapter: collection schema is nil") return nil, errors.New("collection schema is nil") } - // binlog import doesn't support partition key, the caller must specify one partition for importing - if len(collectionInfo.PartitionIDs) != 1 { - log.Warn("Binlog adapter: target partition must be only one", zap.Int("partitions", len(collectionInfo.PartitionIDs))) - return nil, errors.New("target partition must be only one") - } - if chunkManager == nil { log.Warn("Binlog adapter: chunk manager pointer is nil") return nil, errors.New("chunk manager pointer is nil") @@ -513,7 +508,8 @@ func (p *BinlogAdapter) readPrimaryKeys(logPath string) ([]int64, []string, erro func (p *BinlogAdapter) getShardingListByPrimaryInt64(primaryKeys []int64, timestampList []int64, memoryData []ShardData, - intDeletedList map[int64]uint64) ([]int32, error) { + intDeletedList map[int64]uint64, +) ([]int32, error) { if len(timestampList) != len(primaryKeys) { log.Warn("Binlog adapter: primary key length is not equal to timestamp list length", zap.Int("primaryKeysLen", len(primaryKeys)), zap.Int("timestampLen", len(timestampList))) @@ -566,7 +562,8 @@ func (p *BinlogAdapter) getShardingListByPrimaryInt64(primaryKeys 
[]int64, func (p *BinlogAdapter) getShardingListByPrimaryVarchar(primaryKeys []string, timestampList []int64, memoryData []ShardData, - strDeletedList map[string]uint64) ([]int32, error) { + strDeletedList map[string]uint64, +) ([]int32, error) { if len(timestampList) != len(primaryKeys) { log.Warn("Binlog adapter: primary key length is not equal to timestamp list length", zap.Int("primaryKeysLen", len(primaryKeys)), zap.Int("timestampLen", len(timestampList))) @@ -637,7 +634,8 @@ func (p *BinlogAdapter) verifyField(fieldID storage.FieldID, memoryData []ShardD // the no.2, no.4, no.6, no.8, no.10 will be put into shard_1 // Note: the row count of insert log need to be equal to length of shardList func (p *BinlogAdapter) readInsertlog(fieldID storage.FieldID, logPath string, - memoryData []ShardData, shardList []int32) error { + memoryData []ShardData, shardList []int32, +) error { err := p.verifyField(fieldID, memoryData) if err != nil { log.Warn("Binlog adapter: could not read binlog file", zap.String("logPath", logPath), zap.Error(err)) @@ -779,7 +777,8 @@ func (p *BinlogAdapter) readInsertlog(fieldID storage.FieldID, logPath string, } func (p *BinlogAdapter) dispatchBoolToShards(data []bool, memoryData []ShardData, - shardList []int32, fieldID storage.FieldID) error { + shardList []int32, fieldID storage.FieldID, +) error { // verify row count if len(data) != len(shardList) { log.Warn("Binlog adapter: bool field row count is not equal to shard list row count %d", zap.Int("dataLen", len(data)), zap.Int("shardLen", len(shardList))) @@ -813,7 +812,8 @@ func (p *BinlogAdapter) dispatchBoolToShards(data []bool, memoryData []ShardData } func (p *BinlogAdapter) dispatchInt8ToShards(data []int8, memoryData []ShardData, - shardList []int32, fieldID storage.FieldID) error { + shardList []int32, fieldID storage.FieldID, +) error { // verify row count if len(data) != len(shardList) { log.Warn("Binlog adapter: int8 field row count is not equal to shard list row count", 
zap.Int("dataLen", len(data)), zap.Int("shardLen", len(shardList))) @@ -847,7 +847,8 @@ func (p *BinlogAdapter) dispatchInt8ToShards(data []int8, memoryData []ShardData } func (p *BinlogAdapter) dispatchInt16ToShards(data []int16, memoryData []ShardData, - shardList []int32, fieldID storage.FieldID) error { + shardList []int32, fieldID storage.FieldID, +) error { // verify row count if len(data) != len(shardList) { log.Warn("Binlog adapter: int16 field row count is not equal to shard list row count", zap.Int("dataLen", len(data)), zap.Int("shardLen", len(shardList))) @@ -881,7 +882,8 @@ func (p *BinlogAdapter) dispatchInt16ToShards(data []int16, memoryData []ShardDa } func (p *BinlogAdapter) dispatchInt32ToShards(data []int32, memoryData []ShardData, - shardList []int32, fieldID storage.FieldID) error { + shardList []int32, fieldID storage.FieldID, +) error { // verify row count if len(data) != len(shardList) { log.Warn("Binlog adapter: int32 field row count is not equal to shard list row count", zap.Int("dataLen", len(data)), zap.Int("shardLen", len(shardList))) @@ -915,7 +917,8 @@ func (p *BinlogAdapter) dispatchInt32ToShards(data []int32, memoryData []ShardDa } func (p *BinlogAdapter) dispatchInt64ToShards(data []int64, memoryData []ShardData, - shardList []int32, fieldID storage.FieldID) error { + shardList []int32, fieldID storage.FieldID, +) error { // verify row count if len(data) != len(shardList) { log.Warn("Binlog adapter: int64 field row count is not equal to shard list row count", zap.Int("dataLen", len(data)), zap.Int("shardLen", len(shardList))) @@ -949,7 +952,8 @@ func (p *BinlogAdapter) dispatchInt64ToShards(data []int64, memoryData []ShardDa } func (p *BinlogAdapter) dispatchFloatToShards(data []float32, memoryData []ShardData, - shardList []int32, fieldID storage.FieldID) error { + shardList []int32, fieldID storage.FieldID, +) error { // verify row count if len(data) != len(shardList) { log.Warn("Binlog adapter: float field row count is not equal 
to shard list row count", zap.Int("dataLen", len(data)), zap.Int("shardLen", len(shardList))) @@ -983,7 +987,8 @@ func (p *BinlogAdapter) dispatchFloatToShards(data []float32, memoryData []Shard } func (p *BinlogAdapter) dispatchDoubleToShards(data []float64, memoryData []ShardData, - shardList []int32, fieldID storage.FieldID) error { + shardList []int32, fieldID storage.FieldID, +) error { // verify row count if len(data) != len(shardList) { log.Warn("Binlog adapter: double field row count is not equal to shard list row count", zap.Int("dataLen", len(data)), zap.Int("shardLen", len(shardList))) @@ -1017,7 +1022,8 @@ func (p *BinlogAdapter) dispatchDoubleToShards(data []float64, memoryData []Shar } func (p *BinlogAdapter) dispatchVarcharToShards(data []string, memoryData []ShardData, - shardList []int32, fieldID storage.FieldID) error { + shardList []int32, fieldID storage.FieldID, +) error { // verify row count if len(data) != len(shardList) { log.Warn("Binlog adapter: varchar field row count is not equal to shard list row count", zap.Int("dataLen", len(data)), zap.Int("shardLen", len(shardList))) @@ -1051,7 +1057,8 @@ func (p *BinlogAdapter) dispatchVarcharToShards(data []string, memoryData []Shar } func (p *BinlogAdapter) dispatchBytesToShards(data [][]byte, memoryData []ShardData, - shardList []int32, fieldID storage.FieldID) error { + shardList []int32, fieldID storage.FieldID, +) error { // verify row count if len(data) != len(shardList) { log.Warn("Binlog adapter: JSON field row count is not equal to shard list row count", zap.Int("dataLen", len(data)), zap.Int("shardLen", len(shardList))) @@ -1085,7 +1092,8 @@ func (p *BinlogAdapter) dispatchBytesToShards(data [][]byte, memoryData []ShardD } func (p *BinlogAdapter) dispatchBinaryVecToShards(data []byte, dim int, memoryData []ShardData, - shardList []int32, fieldID storage.FieldID) error { + shardList []int32, fieldID storage.FieldID, +) error { // verify row count bytesPerVector := dim / 8 count := 
len(data) / bytesPerVector @@ -1132,7 +1140,8 @@ func (p *BinlogAdapter) dispatchBinaryVecToShards(data []byte, dim int, memoryDa } func (p *BinlogAdapter) dispatchFloatVecToShards(data []float32, dim int, memoryData []ShardData, - shardList []int32, fieldID storage.FieldID) error { + shardList []int32, fieldID storage.FieldID, +) error { // verify row count count := len(data) / dim if count != len(shardList) { diff --git a/internal/util/importutil/binlog_adapter_test.go b/internal/util/importutil/binlog_adapter_test.go index 7fda1d3667727..866169a7971c3 100644 --- a/internal/util/importutil/binlog_adapter_test.go +++ b/internal/util/importutil/binlog_adapter_test.go @@ -23,11 +23,11 @@ import ( "testing" "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/pkg/common" - "github.com/stretchr/testify/assert" ) const ( diff --git a/internal/util/importutil/binlog_file.go b/internal/util/importutil/binlog_file.go index 5b65a66377925..2353de7299d4f 100644 --- a/internal/util/importutil/binlog_file.go +++ b/internal/util/importutil/binlog_file.go @@ -21,11 +21,11 @@ import ( "fmt" "github.com/cockroachdb/errors" + "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/pkg/log" - "go.uber.org/zap" ) // BinlogFile class is a wrapper of storage.BinlogReader, to read binlog file, block by block. 
@@ -524,6 +524,49 @@ func (p *BinlogFile) ReadBinaryVector() ([]byte, int, error) { return result, dim, nil } +func (p *BinlogFile) ReadFloat16Vector() ([]byte, int, error) { + if p.reader == nil { + log.Warn("Binlog file: binlog reader not yet initialized") + return nil, 0, errors.New("binlog reader not yet initialized") + } + + dim := 0 + result := make([]byte, 0) + for { + event, err := p.reader.NextEventReader() + if err != nil { + log.Warn("Binlog file: failed to iterate events reader", zap.Error(err)) + return nil, 0, fmt.Errorf("failed to iterate events reader, error: %w", err) + } + + // end of the file + if event == nil { + break + } + + if event.TypeCode != storage.InsertEventType { + log.Warn("Binlog file: binlog file is not insert log") + return nil, 0, errors.New("binlog file is not insert log") + } + + if p.DataType() != schemapb.DataType_Float16Vector { + log.Warn("Binlog file: binlog data type is not float16 vector") + return nil, 0, errors.New("binlog data type is not float16 vector") + } + + data, dimenson, err := event.PayloadReaderInterface.GetFloat16VectorFromPayload() + if err != nil { + log.Warn("Binlog file: failed to read float16 vector data", zap.Error(err)) + return nil, 0, fmt.Errorf("failed to read float16 vector data, error: %w", err) + } + + dim = dimenson + result = append(result, data...) + } + + return result, dim, nil +} + // ReadFloatVector method reads all the blocks of a binlog by a data type. // A binlog is designed to support multiple blocks, but so far each binlog always contains only one block. 
// return vectors data and the dimension diff --git a/internal/util/importutil/binlog_file_test.go b/internal/util/importutil/binlog_file_test.go index 1bad508431206..4a80983d8fe2b 100644 --- a/internal/util/importutil/binlog_file_test.go +++ b/internal/util/importutil/binlog_file_test.go @@ -21,10 +21,10 @@ import ( "testing" "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/storage" - "github.com/stretchr/testify/assert" ) func createBinlogBuf(t *testing.T, dataType schemapb.DataType, data interface{}) []byte { @@ -43,6 +43,11 @@ func createBinlogBuf(t *testing.T, dataType schemapb.DataType, data interface{}) if len(vectors) > 0 { dim = len(vectors[0]) } + } else if dataType == schemapb.DataType_Float16Vector { + vectors := data.([][]byte) + if len(vectors) > 0 { + dim = len(vectors[0]) / 2 + } } evt, err := w.NextInsertEventWriter(dim) @@ -144,6 +149,16 @@ func createBinlogBuf(t *testing.T, dataType schemapb.DataType, data interface{}) // the "original_size" is come from storage.originalSizeKey sizeTotal := len(vectors) * dim * 4 w.AddExtra("original_size", fmt.Sprintf("%v", sizeTotal)) + case schemapb.DataType_Float16Vector: + vectors := data.([][]byte) + for i := 0; i < len(vectors); i++ { + err = evt.AddFloat16VectorToPayload(vectors[i], dim) + assert.NoError(t, err) + } + // without the two lines, the case will crash at here. 
+ // the "original_size" is come from storage.originalSizeKey + sizeTotal := len(vectors) * dim * 2 + w.AddExtra("original_size", fmt.Sprintf("%v", sizeTotal)) default: assert.True(t, false) return nil @@ -256,6 +271,11 @@ func Test_BinlogFileOpen(t *testing.T) { assert.Nil(t, dataFloatVector) assert.Equal(t, 0, dim) assert.Error(t, err) + + dataFloat16Vector, dim, err := binlogFile.ReadFloat16Vector() + assert.Nil(t, dataFloat16Vector) + assert.Equal(t, 0, dim) + assert.Error(t, err) } func Test_BinlogFileBool(t *testing.T) { @@ -894,3 +914,71 @@ func Test_BinlogFileFloatVector(t *testing.T) { binlogFile.Close() } + +func Test_BinlogFileFloat16Vector(t *testing.T) { + vectors := make([][]byte, 0) + vectors = append(vectors, []byte{1, 3, 5, 7}) + vectors = append(vectors, []byte{2, 4, 6, 8}) + dim := len(vectors[0]) / 2 + vecCount := len(vectors) + + chunkManager := &MockChunkManager{ + readBuf: map[string][]byte{ + "dummy": createBinlogBuf(t, schemapb.DataType_Float16Vector, vectors), + }, + } + + binlogFile, err := NewBinlogFile(chunkManager) + assert.NoError(t, err) + assert.NotNil(t, binlogFile) + + // correct reading + err = binlogFile.Open("dummy") + assert.NoError(t, err) + assert.Equal(t, schemapb.DataType_Float16Vector, binlogFile.DataType()) + + data, d, err := binlogFile.ReadFloat16Vector() + assert.NoError(t, err) + assert.Equal(t, dim, d) + assert.NotNil(t, data) + assert.Equal(t, vecCount*dim*2, len(data)) + for i := 0; i < vecCount; i++ { + for j := 0; j < dim*2; j++ { + assert.Equal(t, vectors[i][j], data[i*dim*2+j]) + } + } + + binlogFile.Close() + + // wrong data type reading + binlogFile, err = NewBinlogFile(chunkManager) + assert.NoError(t, err) + err = binlogFile.Open("dummy") + assert.NoError(t, err) + + dt, d, err := binlogFile.ReadFloatVector() + assert.Zero(t, len(dt)) + assert.Zero(t, d) + assert.Error(t, err) + + binlogFile.Close() + + // wrong log type + chunkManager.readBuf["dummy"] = createDeltalogBuf(t, []int64{1}, false) + err = 
binlogFile.Open("dummy") + assert.NoError(t, err) + + data, d, err = binlogFile.ReadFloat16Vector() + assert.Zero(t, len(data)) + assert.Zero(t, d) + assert.Error(t, err) + + // failed to iterate events reader + binlogFile.reader.Close() + data, d, err = binlogFile.ReadFloat16Vector() + assert.Zero(t, len(data)) + assert.Zero(t, d) + assert.Error(t, err) + + binlogFile.Close() +} diff --git a/internal/util/importutil/binlog_parser.go b/internal/util/importutil/binlog_parser.go index cd6d56e3bb2e6..9ea2ce039a55f 100644 --- a/internal/util/importutil/binlog_parser.go +++ b/internal/util/importutil/binlog_parser.go @@ -25,10 +25,10 @@ import ( "strings" "github.com/cockroachdb/errors" + "go.uber.org/zap" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/pkg/log" - "go.uber.org/zap" ) type BinlogParser struct { @@ -60,7 +60,8 @@ func NewBinlogParser(ctx context.Context, flushFunc ImportFlushFunc, updateProgressFunc func(percent int64), tsStartPoint uint64, - tsEndPoint uint64) (*BinlogParser, error) { + tsEndPoint uint64, +) (*BinlogParser, error) { if collectionInfo == nil { log.Warn("Binlog parser: collection schema is nil") return nil, errors.New("collection schema is nil") diff --git a/internal/util/importutil/binlog_parser_test.go b/internal/util/importutil/binlog_parser_test.go index 8d7a152ee2afa..d2be3fd63f3d8 100644 --- a/internal/util/importutil/binlog_parser_test.go +++ b/internal/util/importutil/binlog_parser_test.go @@ -23,9 +23,9 @@ import ( "testing" "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/stretchr/testify/assert" ) func Test_BinlogParserNew(t *testing.T) { diff --git a/internal/util/importutil/collection_info.go b/internal/util/importutil/collection_info.go index ad0194a45fe67..f7fc31270e874 100644 --- a/internal/util/importutil/collection_info.go +++ b/internal/util/importutil/collection_info.go @@ -23,14 +23,13 @@ import ( 
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/common" - "github.com/milvus-io/milvus/pkg/util/typeutil" ) type CollectionInfo struct { Schema *schemapb.CollectionSchema ShardNum int32 - PartitionIDs []int64 // target partitions of bulkinsert, one partition for non-partition-key collection, or all partiitons for partition-key collection + PartitionIDs []int64 // target partitions of bulkinsert PrimaryKey *schemapb.FieldSchema PartitionKey *schemapb.FieldSchema @@ -39,23 +38,10 @@ type CollectionInfo struct { Name2FieldID map[string]int64 // this member is for Numpy file name validation and JSON row validation } -func DeduceTargetPartitions(partitions map[string]int64, collectionSchema *schemapb.CollectionSchema, defaultPartition int64) ([]int64, error) { - // if no partition key, rutrn the default partition ID as target partition - _, err := typeutil.GetPartitionKeyFieldSchema(collectionSchema) - if err != nil { - return []int64{defaultPartition}, nil - } - - _, partitionIDs, err := typeutil.RearrangePartitionsForPartitionKey(partitions) - if err != nil { - return nil, err - } - return partitionIDs, nil -} - func NewCollectionInfo(collectionSchema *schemapb.CollectionSchema, shardNum int32, - partitionIDs []int64) (*CollectionInfo, error) { + partitionIDs []int64, +) (*CollectionInfo, error) { if shardNum <= 0 { return nil, fmt.Errorf("illegal shard number %d", shardNum) } diff --git a/internal/util/importutil/collection_info_test.go b/internal/util/importutil/collection_info_test.go index e1ec73c06808f..71994e6b74a73 100644 --- a/internal/util/importutil/collection_info_test.go +++ b/internal/util/importutil/collection_info_test.go @@ -18,33 +18,10 @@ package importutil import ( "testing" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/stretchr/testify/assert" -) - -func Test_DeduceTargetPartitions(t *testing.T) { - schema := sampleSchema() - partitions := map[string]int64{ - "part_0": 100, - "part_1": 
200, - } - partitionIDs, err := DeduceTargetPartitions(partitions, schema, int64(1)) - assert.NoError(t, err) - assert.Equal(t, 1, len(partitionIDs)) - assert.Equal(t, int64(1), partitionIDs[0]) - - schema.Fields[7].IsPartitionKey = true - partitionIDs, err = DeduceTargetPartitions(partitions, schema, int64(1)) - assert.NoError(t, err) - assert.Equal(t, len(partitions), len(partitionIDs)) - partitions = map[string]int64{ - "part_a": 100, - } - partitionIDs, err = DeduceTargetPartitions(partitions, schema, int64(1)) - assert.Error(t, err) - assert.Nil(t, partitionIDs) -} + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) func Test_CollectionInfoNew(t *testing.T) { t.Run("succeed", func(t *testing.T) { diff --git a/internal/util/importutil/import_options_test.go b/internal/util/importutil/import_options_test.go index e7de6dfb0a3ea..9bb58476a27f2 100644 --- a/internal/util/importutil/import_options_test.go +++ b/internal/util/importutil/import_options_test.go @@ -20,12 +20,12 @@ import ( "math" "testing" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" ) func Test_ValidateOptions(t *testing.T) { - assert.NoError(t, ValidateOptions([]*commonpb.KeyValuePair{})) assert.NoError(t, ValidateOptions([]*commonpb.KeyValuePair{ {Key: "start_ts", Value: "1666007457"}, diff --git a/internal/util/importutil/import_util.go b/internal/util/importutil/import_util.go index 636f6beafdb8f..bed5d8aaecd35 100644 --- a/internal/util/importutil/import_util.go +++ b/internal/util/importutil/import_util.go @@ -37,8 +37,10 @@ import ( "github.com/milvus-io/milvus/pkg/util/typeutil" ) -type BlockData map[storage.FieldID]storage.FieldData // a map of field ID to field data -type ShardData map[int64]BlockData // a map of partition ID to block data +type ( + BlockData map[storage.FieldID]storage.FieldData // a map of field ID to field data + ShardData map[int64]BlockData // a map of 
partition ID to block data +) func isCanceled(ctx context.Context) bool { // canceled? @@ -154,7 +156,7 @@ type Validator struct { isString bool // for string field dimension int // only for vector field fieldName string // field name - fieldID int64 //field ID + fieldID int64 // field ID } // initValidators constructs valiator methods and data conversion methods @@ -469,8 +471,8 @@ func tryFlushBlocks(ctx context.Context, callFlushFunc ImportFlushFunc, blockSize int64, maxTotalSize int64, - force bool) error { - + force bool, +) error { totalSize := 0 biggestSize := 0 biggestItem := -1 diff --git a/internal/util/importutil/import_util_test.go b/internal/util/importutil/import_util_test.go index 88bb820f1cb9f..859b746f46688 100644 --- a/internal/util/importutil/import_util_test.go +++ b/internal/util/importutil/import_util_test.go @@ -24,13 +24,13 @@ import ( "testing" "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/typeutil" - "github.com/stretchr/testify/assert" ) // sampleSchema() return a schema contains all supported data types with an int64 primary key @@ -381,14 +381,14 @@ func createBlockData(collectionSchema *schemapb.CollectionSchema, fieldsData map } func createShardsData(collectionSchema *schemapb.CollectionSchema, fieldsData map[storage.FieldID]interface{}, - shardNum int32, partitionIDs []int64) []ShardData { + shardNum int32, partitionIDs []int64, +) []ShardData { shardsData := make([]ShardData, 0, shardNum) for i := 0; i < int(shardNum); i++ { shardData := make(ShardData) for p := 0; p < len(partitionIDs); p++ { blockData := createBlockData(collectionSchema, fieldsData) shardData[partitionIDs[p]] = blockData - } shardsData = append(shardsData, shardData) } diff --git 
a/internal/util/importutil/import_wrapper.go b/internal/util/importutil/import_wrapper.go index 70ef5cfe1b3a1..35fa92b67886c 100644 --- a/internal/util/importutil/import_wrapper.go +++ b/internal/util/importutil/import_wrapper.go @@ -68,11 +68,13 @@ const ( // ReportImportAttempts is the maximum # of attempts to retry when import fails. var ReportImportAttempts uint = 10 -type ImportFlushFunc func(fields BlockData, shardID int, partID int64) error -type AssignSegmentFunc func(shardID int, partID int64) (int64, string, error) -type CreateBinlogsFunc func(fields BlockData, segmentID int64, partID int64) ([]*datapb.FieldBinlog, []*datapb.FieldBinlog, error) -type SaveSegmentFunc func(fieldsInsert []*datapb.FieldBinlog, fieldsStats []*datapb.FieldBinlog, segmentID int64, targetChName string, rowCount int64, partID int64) error -type ReportFunc func(res *rootcoordpb.ImportResult) error +type ( + ImportFlushFunc func(fields BlockData, shardID int, partID int64) error + AssignSegmentFunc func(shardID int, partID int64) (int64, string, error) + CreateBinlogsFunc func(fields BlockData, segmentID int64, partID int64) ([]*datapb.FieldBinlog, []*datapb.FieldBinlog, error) + SaveSegmentFunc func(fieldsInsert []*datapb.FieldBinlog, fieldsStats []*datapb.FieldBinlog, segmentID int64, targetChName string, rowCount int64, partID int64) error + ReportFunc func(res *rootcoordpb.ImportResult) error +) type WorkingSegment struct { segmentID int64 // segment ID @@ -107,7 +109,8 @@ type ImportWrapper struct { func NewImportWrapper(ctx context.Context, collectionInfo *CollectionInfo, segmentSize int64, idAlloc *allocator.IDAllocator, cm storage.ChunkManager, importResult *rootcoordpb.ImportResult, - reportFunc func(res *rootcoordpb.ImportResult) error) *ImportWrapper { + reportFunc func(res *rootcoordpb.ImportResult) error, +) *ImportWrapper { if collectionInfo == nil || collectionInfo.Schema == nil { log.Warn("import wrapper: collection schema is nil") return nil @@ -424,7 +427,7 @@ func 
(p *ImportWrapper) parseRowBasedJSON(filePath string, onlyValidate bool) er } } else { flushFunc = func(fields BlockData, shardID int, partitionID int64) error { - var filePaths = []string{filePath} + filePaths := []string{filePath} printFieldsDataInfo(fields, "import wrapper: prepare to flush binlogs", filePaths) return p.flushFunc(fields, shardID, partitionID) } diff --git a/internal/util/importutil/import_wrapper_test.go b/internal/util/importutil/import_wrapper_test.go index 6b7611b52d860..e52e214786f77 100644 --- a/internal/util/importutil/import_wrapper_test.go +++ b/internal/util/importutil/import_wrapper_test.go @@ -29,7 +29,6 @@ import ( "time" "github.com/cockroachdb/errors" - "github.com/stretchr/testify/assert" "golang.org/x/exp/mmap" @@ -40,6 +39,7 @@ import ( "github.com/milvus-io/milvus/internal/querycoordv2/params" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/timerecord" ) @@ -170,7 +170,8 @@ func createMockCallbackFunctions(t *testing.T, rowCounter *rowCounterTest) (Assi } saveSegmentFunc := func(fieldsInsert []*datapb.FieldBinlog, fieldsStats []*datapb.FieldBinlog, - segmentID int64, targetChName string, rowCount int64, partID int64) error { + segmentID int64, targetChName string, rowCount int64, partID int64, + ) error { return nil } @@ -214,7 +215,8 @@ func Test_ImportWrapperNew(t *testing.T) { return nil, nil, nil } saveBinFunc := func(fieldsInsert []*datapb.FieldBinlog, fieldsStats []*datapb.FieldBinlog, - segmentID int64, targetChName string, rowCount int64, partID int64) error { + segmentID int64, targetChName string, rowCount int64, partID int64, + ) error { return nil } @@ -265,9 +267,7 @@ func Test_ImportWrapperRowBased(t *testing.T) { assignSegmentFunc, flushFunc, saveSegmentFunc := createMockCallbackFunctions(t, rowCounter) importResult := 
&rootcoordpb.ImportResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), TaskId: 1, DatanodeId: 1, State: commonpb.ImportState_ImportStarted, @@ -345,9 +345,7 @@ func Test_ImportWrapperColumnBased_numpy(t *testing.T) { assignSegmentFunc, flushFunc, saveSegmentFunc := createMockCallbackFunctions(t, rowCounter) importResult := &rootcoordpb.ImportResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), TaskId: 1, DatanodeId: 1, State: commonpb.ImportState_ImportStarted, @@ -500,9 +498,7 @@ func Test_ImportWrapperRowBased_perf(t *testing.T) { schema := perfSchema(dim) importResult := &rootcoordpb.ImportResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), TaskId: 1, DatanodeId: 1, State: commonpb.ImportState_ImportStarted, @@ -676,9 +672,7 @@ func Test_ImportWrapperReportFailRowBased(t *testing.T) { // success case importResult := &rootcoordpb.ImportResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), TaskId: 1, DatanodeId: 1, State: commonpb.ImportState_ImportStarted, @@ -725,9 +719,7 @@ func Test_ImportWrapperReportFailColumnBased_numpy(t *testing.T) { // success case importResult := &rootcoordpb.ImportResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), TaskId: 1, DatanodeId: 1, State: commonpb.ImportState_ImportStarted, @@ -870,9 +862,7 @@ func Test_ImportWrapperDoBinlogImport(t *testing.T) { return nil } wrapper.importResult = &rootcoordpb.ImportResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), TaskId: 1, DatanodeId: 1, State: commonpb.ImportState_ImportStarted, @@ -891,9 +881,7 @@ func Test_ImportWrapperReportPersisted(t *testing.T) { tr := timerecord.NewTimeRecorder("test") importResult := &rootcoordpb.ImportResult{ - Status: 
&commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), TaskId: 1, DatanodeId: 1, State: commonpb.ImportState_ImportStarted, @@ -921,7 +909,8 @@ func Test_ImportWrapperReportPersisted(t *testing.T) { // error when closing segments wrapper.saveSegmentFunc = func(fieldsInsert []*datapb.FieldBinlog, fieldsStats []*datapb.FieldBinlog, - segmentID int64, targetChName string, rowCount int64, partID int64) error { + segmentID int64, targetChName string, rowCount int64, partID int64, + ) error { return errors.New("error") } wrapper.workingSegments[0] = map[int64]*WorkingSegment{ @@ -932,7 +921,8 @@ func Test_ImportWrapperReportPersisted(t *testing.T) { // failed to report wrapper.saveSegmentFunc = func(fieldsInsert []*datapb.FieldBinlog, fieldsStats []*datapb.FieldBinlog, - segmentID int64, targetChName string, rowCount int64, partID int64) error { + segmentID int64, targetChName string, rowCount int64, partID int64, + ) error { return nil } wrapper.reportFunc = func(res *rootcoordpb.ImportResult) error { @@ -971,9 +961,7 @@ func Test_ImportWrapperFlushFunc(t *testing.T) { assignSegmentFunc, flushFunc, saveSegmentFunc := createMockCallbackFunctions(t, rowCounter) importResult := &rootcoordpb.ImportResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, + Status: merr.Success(), TaskId: 1, DatanodeId: 1, State: commonpb.ImportState_ImportStarted, @@ -1010,7 +998,8 @@ func Test_ImportWrapperFlushFunc(t *testing.T) { t.Run("close segment, saveSegmentFunc returns error", func(t *testing.T) { wrapper.saveSegmentFunc = func(fieldsInsert []*datapb.FieldBinlog, fieldsStats []*datapb.FieldBinlog, - segmentID int64, targetChName string, rowCount int64, partID int64) error { + segmentID int64, targetChName string, rowCount int64, partID int64, + ) error { return errors.New("error") } wrapper.segmentSize = 1 @@ -1035,7 +1024,8 @@ func Test_ImportWrapperFlushFunc(t *testing.T) { t.Run("createBinlogsFunc returns error", func(t 
*testing.T) { wrapper.saveSegmentFunc = func(fieldsInsert []*datapb.FieldBinlog, fieldsStats []*datapb.FieldBinlog, - segmentID int64, targetChName string, rowCount int64, partID int64) error { + segmentID int64, targetChName string, rowCount int64, partID int64, + ) error { return nil } wrapper.assignSegmentFunc = func(shardID int, partID int64) (int64, string, error) { diff --git a/internal/util/importutil/json_handler.go b/internal/util/importutil/json_handler.go index 6cef1eeb24b2d..eec84d3d849ac 100644 --- a/internal/util/importutil/json_handler.go +++ b/internal/util/importutil/json_handler.go @@ -23,7 +23,6 @@ import ( "strconv" "github.com/cockroachdb/errors" - "go.uber.org/zap" "github.com/milvus-io/milvus/internal/allocator" @@ -72,7 +71,8 @@ func NewJSONRowConsumer(ctx context.Context, collectionInfo *CollectionInfo, idAlloc *allocator.IDAllocator, blockSize int64, - flushFunc ImportFlushFunc) (*JSONRowConsumer, error) { + flushFunc ImportFlushFunc, +) (*JSONRowConsumer, error) { if collectionInfo == nil { log.Warn("JSON row consumer: collection schema is nil") return nil, errors.New("collection schema is nil") diff --git a/internal/util/importutil/json_handler_test.go b/internal/util/importutil/json_handler_test.go index 9e00e84be1141..7f5db26db096a 100644 --- a/internal/util/importutil/json_handler_test.go +++ b/internal/util/importutil/json_handler_test.go @@ -23,28 +23,25 @@ import ( "testing" "github.com/cockroachdb/errors" - "github.com/stretchr/testify/assert" + "google.golang.org/grpc" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/allocator" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/util/merr" ) type mockIDAllocator struct { allocErr error } -func (a *mockIDAllocator) AllocID(ctx context.Context, req *rootcoordpb.AllocIDRequest) 
(*rootcoordpb.AllocIDResponse, error) { +func (a *mockIDAllocator) AllocID(ctx context.Context, req *rootcoordpb.AllocIDRequest, opts ...grpc.CallOption) (*rootcoordpb.AllocIDResponse, error) { return &rootcoordpb.AllocIDResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, - ID: int64(1), - Count: req.Count, + Status: merr.Success(), + ID: int64(1), + Count: req.Count, }, a.allocErr } diff --git a/internal/util/importutil/json_parser.go b/internal/util/importutil/json_parser.go index a0332d4c4e49c..d187b5c4b8ce5 100644 --- a/internal/util/importutil/json_parser.go +++ b/internal/util/importutil/json_parser.go @@ -24,13 +24,13 @@ import ( "strings" "github.com/cockroachdb/errors" + "go.uber.org/zap" + "golang.org/x/exp/maps" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/typeutil" - "go.uber.org/zap" - "golang.org/x/exp/maps" ) const ( @@ -322,6 +322,7 @@ func (p *JSONParser) ParseRows(reader *IOReader, handler JSONRowHandler) error { return errors.New("import task was canceled") } + // nolint // this break means we require the first node must be RowRootNode // once the RowRootNode is parsed, just finish break diff --git a/internal/util/importutil/json_parser_test.go b/internal/util/importutil/json_parser_test.go index ceadbec0c3856..e241c1dc63083 100644 --- a/internal/util/importutil/json_parser_test.go +++ b/internal/util/importutil/json_parser_test.go @@ -26,12 +26,12 @@ import ( "testing" "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/pkg/common" - "github.com/stretchr/testify/assert" ) // mock class of JSONRowCounsumer @@ -83,7 +83,6 @@ func Test_AdjustBufSize(t *testing.T) { 
assert.NotNil(t, parser) assert.Greater(t, parser.bufRowCount, 0) adjustBufSize(parser, schema) - } func Test_JSONParserParseRows_IntPK(t *testing.T) { diff --git a/internal/util/importutil/numpy_adapter.go b/internal/util/importutil/numpy_adapter.go index 17c6832717a10..a1c2f414305a9 100644 --- a/internal/util/importutil/numpy_adapter.go +++ b/internal/util/importutil/numpy_adapter.go @@ -21,7 +21,6 @@ import ( "encoding/binary" "fmt" "io" - "io/ioutil" "os" "reflect" "regexp" @@ -29,13 +28,13 @@ import ( "unicode/utf8" "github.com/cockroachdb/errors" - - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/pkg/log" "github.com/sbinet/npyio" "github.com/sbinet/npyio/npy" "go.uber.org/zap" "golang.org/x/text/encoding/unicode" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/log" ) var ( @@ -589,7 +588,7 @@ func (n *NumpyAdapter) ReadString(count int) ([]string, error) { // the character "a" occupys 2*4=8 bytes(0x97,0x00,0x00,0x00,0x00,0x00,0x00,0x00), // the "bb" occupys 8 bytes(0x97,0x00,0x00,0x00,0x98,0x00,0x00,0x00) // for non-ascii characters, the unicode could be 1 ~ 4 bytes, each character occupys 4 bytes, too - raw, err := ioutil.ReadAll(io.LimitReader(n.reader, utf8.UTFMax*int64(maxLen)*int64(batchRead))) + raw, err := io.ReadAll(io.LimitReader(n.reader, utf8.UTFMax*int64(maxLen)*int64(batchRead))) if err != nil { log.Warn("Numpy adapter: failed to read utf32 bytes from numpy file", zap.Int("readDone", readDone), zap.Error(err)) @@ -610,7 +609,7 @@ func (n *NumpyAdapter) ReadString(count int) ([]string, error) { } else { // in the numpy file with ansi encoding, the dType could be like "S2", maxLen is 2, each string occupys 2 bytes // bytes.Index(buf, []byte{0}) tell us which position is the end of the string - buf, err := ioutil.ReadAll(io.LimitReader(n.reader, int64(maxLen)*int64(batchRead))) + buf, err := io.ReadAll(io.LimitReader(n.reader, int64(maxLen)*int64(batchRead))) if 
err != nil { log.Warn("Numpy adapter: failed to read ascii bytes from numpy file", zap.Int("readDone", readDone), zap.Error(err)) diff --git a/internal/util/importutil/numpy_adapter_test.go b/internal/util/importutil/numpy_adapter_test.go index 3e5f1fa58b84b..06ac172be5892 100644 --- a/internal/util/importutil/numpy_adapter_test.go +++ b/internal/util/importutil/numpy_adapter_test.go @@ -25,13 +25,13 @@ import ( "strings" "testing" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/sbinet/npyio/npy" "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) -type MockReader struct { -} +type MockReader struct{} func (r *MockReader) Read(p []byte) (n int, err error) { return 0, io.EOF diff --git a/internal/util/importutil/numpy_parser.go b/internal/util/importutil/numpy_parser.go index 3b3f74b039d67..ef5b0f951a96a 100644 --- a/internal/util/importutil/numpy_parser.go +++ b/internal/util/importutil/numpy_parser.go @@ -72,7 +72,8 @@ func NewNumpyParser(ctx context.Context, blockSize int64, chunkManager storage.ChunkManager, flushFunc ImportFlushFunc, - updateProgressFunc func(percent int64)) (*NumpyParser, error) { + updateProgressFunc func(percent int64), +) (*NumpyParser, error) { if collectionInfo == nil { log.Warn("Numper parser: collection schema is nil") return nil, errors.New("collection schema is nil") @@ -600,7 +601,6 @@ func (p *NumpyParser) readData(columnReader *NumpyColumnReader, rowCount int) (s log.Warn("Numpy parser: illegal value in float vector array", zap.Error(err)) return nil, fmt.Errorf("illegal value in float vector array: %s", err.Error()) } - } else if elementType == schemapb.DataType_Double { data = make([]float32, 0, columnReader.rowCount) data64, err := columnReader.reader.ReadFloat64(rowCount * columnReader.dimension) diff --git a/internal/util/importutil/numpy_parser_test.go b/internal/util/importutil/numpy_parser_test.go index 6b687ef351261..be545fe48a0c4 100644 --- 
a/internal/util/importutil/numpy_parser_test.go +++ b/internal/util/importutil/numpy_parser_test.go @@ -24,7 +24,6 @@ import ( "testing" "github.com/cockroachdb/errors" - "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" diff --git a/internal/util/indexcgowrapper/build_index_info.go b/internal/util/indexcgowrapper/build_index_info.go index 82e030bfb5133..9199067ae0380 100644 --- a/internal/util/indexcgowrapper/build_index_info.go +++ b/internal/util/indexcgowrapper/build_index_info.go @@ -48,6 +48,7 @@ func NewBuildIndexInfo(config *indexpb.StorageConfig) (*BuildIndexInfo, error) { cAccessValue := C.CString(config.SecretAccessKey) cRootPath := C.CString(config.RootPath) cStorageType := C.CString(config.StorageType) + cCloudProvider := C.CString(config.CloudProvider) cIamEndPoint := C.CString(config.IAMEndpoint) cRegion := C.CString(config.Region) defer C.free(unsafe.Pointer(cAddress)) @@ -56,6 +57,7 @@ func NewBuildIndexInfo(config *indexpb.StorageConfig) (*BuildIndexInfo, error) { defer C.free(unsafe.Pointer(cAccessValue)) defer C.free(unsafe.Pointer(cRootPath)) defer C.free(unsafe.Pointer(cStorageType)) + defer C.free(unsafe.Pointer(cCloudProvider)) defer C.free(unsafe.Pointer(cIamEndPoint)) defer C.free(unsafe.Pointer(cRegion)) storageConfig := C.CStorageConfig{ @@ -65,11 +67,13 @@ func NewBuildIndexInfo(config *indexpb.StorageConfig) (*BuildIndexInfo, error) { access_key_value: cAccessValue, root_path: cRootPath, storage_type: cStorageType, + cloud_provider: cCloudProvider, iam_endpoint: cIamEndPoint, useSSL: C.bool(config.UseSSL), useIAM: C.bool(config.UseIAM), region: cRegion, useVirtualHost: C.bool(config.UseVirtualHost), + requestTimeoutMs: C.int64_t(config.RequestTimeoutMs), } status := C.NewBuildIndexInfo(&cBuildIndexInfo, storageConfig) @@ -160,3 +164,10 @@ func (bi *BuildIndexInfo) AppendInsertFile(filePath string) error { status := C.AppendInsertFilePath(bi.cBuildIndexInfo, cInsertFilePath) return 
HandleCStatus(&status, "appendInsertFile failed") } + +func (bi *BuildIndexInfo) AppendIndexEngineVersion(indexEngineVersion int32) error { + cIndexEngineVersion := C.int32_t(indexEngineVersion) + + status := C.AppendIndexEngineVersionToBuildInfo(bi.cBuildIndexInfo, cIndexEngineVersion) + return HandleCStatus(&status, "AppendIndexEngineVersion failed") +} diff --git a/internal/util/indexcgowrapper/helper.go b/internal/util/indexcgowrapper/helper.go index 5de29422f4419..142b9f76a769d 100644 --- a/internal/util/indexcgowrapper/helper.go +++ b/internal/util/indexcgowrapper/helper.go @@ -13,10 +13,8 @@ import ( "fmt" "unsafe" - "github.com/cockroachdb/errors" - - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/merr" ) func GetBinarySetKeys(cBinarySet C.CBinarySet) ([]string, error) { @@ -65,18 +63,13 @@ func HandleCStatus(status *C.CStatus, extraInfo string) error { if status.error_code == 0 { return nil } - errorCode := status.error_code - errorName, ok := commonpb.ErrorCode_name[int32(errorCode)] - if !ok { - errorName = "UnknownError" - } + errorCode := int(status.error_code) errorMsg := C.GoString(status.error_msg) defer C.free(unsafe.Pointer(status.error_msg)) - finalMsg := fmt.Sprintf("[%s] %s", errorName, errorMsg) - logMsg := fmt.Sprintf("%s, C Runtime Exception: %s\n", extraInfo, finalMsg) + logMsg := fmt.Sprintf("%s, C Runtime Exception: %s\n", extraInfo, errorMsg) log.Warn(logMsg) - return errors.New(finalMsg) + return merr.WrapErrSegcore(int32(errorCode), logMsg) } func GetLocalUsedSize(path string) (int64, error) { diff --git a/internal/util/indexcgowrapper/index.go b/internal/util/indexcgowrapper/index.go index fd638f3ad55ba..2f62f99f9c6ea 100644 --- a/internal/util/indexcgowrapper/index.go +++ b/internal/util/indexcgowrapper/index.go @@ -7,6 +7,7 @@ package indexcgowrapper #include "indexbuilder/index_c.h" */ import "C" + import ( "context" "fmt" @@ -18,7 +19,6 @@ import 
( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/internal/proto/indexcgopb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/pkg/log" @@ -41,9 +41,7 @@ type CodecIndex interface { UpLoad() (map[string]int64, error) } -var ( - _ CodecIndex = (*CgoIndex)(nil) -) +var _ CodecIndex = (*CgoIndex)(nil) type CgoIndex struct { indexPtr C.CIndex @@ -130,6 +128,8 @@ func (index *CgoIndex) Build(dataset *Dataset) error { return fmt.Errorf("build index on supported data type: %s", dataset.DType.String()) case schemapb.DataType_FloatVector: return index.buildFloatVecIndex(dataset) + case schemapb.DataType_Float16Vector: + return fmt.Errorf("build index on supported data type: %s", dataset.DType.String()) case schemapb.DataType_BinaryVector: return index.buildBinaryVecIndex(dataset) case schemapb.DataType_Bool: diff --git a/internal/util/indexcgowrapper/index_test.go b/internal/util/indexcgowrapper/index_test.go index 97369bbc137b6..8678bc227fff1 100644 --- a/internal/util/indexcgowrapper/index_test.go +++ b/internal/util/indexcgowrapper/index_test.go @@ -185,9 +185,9 @@ func TestCIndex_Codec(t *testing.T) { err = copyIndex.Load(blobs) assert.Equal(t, err, nil) // IVF_FLAT_NM index don't support load and serialize - //copyBlobs, err := copyIndex.Serialize() - //assert.Equal(t, err, nil) - //assert.Equal(t, len(blobs), len(copyBlobs)) + // copyBlobs, err := copyIndex.Serialize() + // assert.Equal(t, err, nil) + // assert.Equal(t, len(blobs), len(copyBlobs)) // TODO: check key, value and more err = index.Delete() @@ -224,10 +224,11 @@ func TestCIndex_Error(t *testing.T) { }) t.Run("Load error", func(t *testing.T) { - blobs := []*Blob{{ - Key: "test", - Value: []byte("value"), - }, + blobs := []*Blob{ + { + Key: "test", + Value: []byte("value"), + }, } err = indexPtr.Load(blobs) assert.Error(t, err) diff --git a/internal/util/initcore/init_core.go 
b/internal/util/initcore/init_core.go index a356a021c1588..a9cd20d56112d 100644 --- a/internal/util/initcore/init_core.go +++ b/internal/util/initcore/init_core.go @@ -62,6 +62,7 @@ func InitRemoteChunkManager(params *paramtable.ComponentParam) error { cAccessValue := C.CString(params.MinioCfg.SecretAccessKey.GetValue()) cRootPath := C.CString(params.MinioCfg.RootPath.GetValue()) cStorageType := C.CString(params.CommonCfg.StorageType.GetValue()) + cCloudProvider := C.CString(params.MinioCfg.CloudProvider.GetValue()) cIamEndPoint := C.CString(params.MinioCfg.IAMEndpoint.GetValue()) cLogLevel := C.CString(params.MinioCfg.LogLevel.GetValue()) cRegion := C.CString(params.MinioCfg.Region.GetValue()) @@ -71,6 +72,7 @@ func InitRemoteChunkManager(params *paramtable.ComponentParam) error { defer C.free(unsafe.Pointer(cAccessValue)) defer C.free(unsafe.Pointer(cRootPath)) defer C.free(unsafe.Pointer(cStorageType)) + defer C.free(unsafe.Pointer(cCloudProvider)) defer C.free(unsafe.Pointer(cIamEndPoint)) defer C.free(unsafe.Pointer(cLogLevel)) defer C.free(unsafe.Pointer(cRegion)) @@ -81,18 +83,29 @@ func InitRemoteChunkManager(params *paramtable.ComponentParam) error { access_key_value: cAccessValue, root_path: cRootPath, storage_type: cStorageType, + cloud_provider: cCloudProvider, iam_endpoint: cIamEndPoint, useSSL: C.bool(params.MinioCfg.UseSSL.GetAsBool()), useIAM: C.bool(params.MinioCfg.UseIAM.GetAsBool()), log_level: cLogLevel, region: cRegion, useVirtualHost: C.bool(params.MinioCfg.UseVirtualHost.GetAsBool()), + requestTimeoutMs: C.int64_t(params.MinioCfg.RequestTimeoutMs.GetAsInt64()), } status := C.InitRemoteChunkManagerSingleton(storageConfig) return HandleCStatus(&status, "InitRemoteChunkManagerSingleton failed") } +func InitChunkCache(mmapDirPath string, readAheadPolicy string) error { + cMmapDirPath := C.CString(mmapDirPath) + defer C.free(unsafe.Pointer(cMmapDirPath)) + cReadAheadPolicy := C.CString(readAheadPolicy) + defer 
C.free(unsafe.Pointer(cReadAheadPolicy)) + status := C.InitChunkCacheSingleton(cMmapDirPath, cReadAheadPolicy) + return HandleCStatus(&status, "InitChunkCacheSingleton failed") +} + func CleanRemoteChunkManager() { C.CleanRemoteChunkManagerSingleton() } diff --git a/internal/util/metrics/c_registry.go b/internal/util/metrics/c_registry.go index 3e8973fc7599b..89771aff13d28 100644 --- a/internal/util/metrics/c_registry.go +++ b/internal/util/metrics/c_registry.go @@ -26,18 +26,19 @@ package metrics */ import "C" + import ( "sort" "strings" "sync" "unsafe" - "github.com/milvus-io/milvus/pkg/log" "github.com/prometheus/client_golang/prometheus" + dto "github.com/prometheus/client_model/go" "github.com/prometheus/common/expfmt" "go.uber.org/zap" - dto "github.com/prometheus/client_model/go" + "github.com/milvus-io/milvus/pkg/log" ) // metricSorter is a sortable slice of *dto.Metric. @@ -119,9 +120,7 @@ type CRegistry struct { // Gather implements Gatherer. func (r *CRegistry) Gather() (res []*dto.MetricFamily, err error) { - var ( - parser expfmt.TextParser - ) + var parser expfmt.TextParser r.mtx.RLock() cMetricsStr := C.GetKnowhereMetrics() diff --git a/internal/util/mock/datacoord_client.go b/internal/util/mock/datacoord_client.go index 0414adde01cd5..8276afaa336bb 100644 --- a/internal/util/mock/datacoord_client.go +++ b/internal/util/mock/datacoord_client.go @@ -19,12 +19,13 @@ package mock import ( "context" + "google.golang.org/grpc" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/types" - "google.golang.org/grpc" ) // DataCoordClient mocks of DataCoordClient @@ -112,7 +113,8 @@ func (m *DataCoordClient) GetCompactionStateWithPlans(ctx context.Context, req * func (m *DataCoordClient) WatchChannels(ctx context.Context, req *datapb.WatchChannelsRequest, opts 
...grpc.CallOption) (*datapb.WatchChannelsResponse, error) { return &datapb.WatchChannelsResponse{}, m.Err } -func (m *DataCoordClient) GetFlushState(ctx context.Context, req *milvuspb.GetFlushStateRequest, opts ...grpc.CallOption) (*milvuspb.GetFlushStateResponse, error) { + +func (m *DataCoordClient) GetFlushState(ctx context.Context, req *datapb.GetFlushStateRequest, opts ...grpc.CallOption) (*milvuspb.GetFlushStateResponse, error) { return &milvuspb.GetFlushStateResponse{}, m.Err } diff --git a/internal/util/mock/grpc_datacoord_client.go b/internal/util/mock/grpc_datacoord_client.go index a3b2fb9791ded..b7dfe111cd52f 100644 --- a/internal/util/mock/grpc_datacoord_client.go +++ b/internal/util/mock/grpc_datacoord_client.go @@ -19,14 +19,15 @@ package mock import ( "context" - "github.com/milvus-io/milvus/internal/proto/indexpb" - "google.golang.org/grpc" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/uniquegenerator" ) var _ datapb.DataCoordClient = &GrpcDataCoordClient{} @@ -45,7 +46,16 @@ func (m *GrpcDataCoordClient) CheckHealth(ctx context.Context, in *milvuspb.Chec } func (m *GrpcDataCoordClient) GetComponentStates(ctx context.Context, in *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { - return &milvuspb.ComponentStates{}, m.Err + return &milvuspb.ComponentStates{ + State: &milvuspb.ComponentInfo{ + NodeID: int64(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()), + Role: "MockDataCoord", + StateCode: commonpb.StateCode_Healthy, + ExtraInfo: nil, + }, + SubcomponentStates: nil, + Status: merr.Success(), + }, m.Err } func (m *GrpcDataCoordClient) GetTimeTickChannel(ctx context.Context, in 
*internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { @@ -136,7 +146,7 @@ func (m *GrpcDataCoordClient) WatchChannels(ctx context.Context, req *datapb.Wat return &datapb.WatchChannelsResponse{}, m.Err } -func (m *GrpcDataCoordClient) GetFlushState(ctx context.Context, req *milvuspb.GetFlushStateRequest, opts ...grpc.CallOption) (*milvuspb.GetFlushStateResponse, error) { +func (m *GrpcDataCoordClient) GetFlushState(ctx context.Context, req *datapb.GetFlushStateRequest, opts ...grpc.CallOption) (*milvuspb.GetFlushStateResponse, error) { return &milvuspb.GetFlushStateResponse{}, m.Err } @@ -178,7 +188,6 @@ func (m *GrpcDataCoordClient) MarkSegmentsDropped(context.Context, *datapb.MarkS func (m *GrpcDataCoordClient) BroadcastAlteredCollection(ctx context.Context, in *datapb.AlterCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return &commonpb.Status{}, m.Err - } func (m *GrpcDataCoordClient) CreateIndex(ctx context.Context, req *indexpb.CreateIndexRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { @@ -221,3 +230,7 @@ func (m *GrpcDataCoordClient) GetIndexBuildProgress(ctx context.Context, req *in func (m *GrpcDataCoordClient) ReportDataNodeTtMsgs(ctx context.Context, in *datapb.ReportDataNodeTtMsgsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return &commonpb.Status{}, m.Err } + +func (m *GrpcDataCoordClient) Close() error { + return nil +} diff --git a/internal/util/mock/grpc_datanode_client.go b/internal/util/mock/grpc_datanode_client.go index 7c9031946e93f..6d382dca7eb1d 100644 --- a/internal/util/mock/grpc_datanode_client.go +++ b/internal/util/mock/grpc_datanode_client.go @@ -80,3 +80,15 @@ func (m *GrpcDataNodeClient) AddImportSegment(ctx context.Context, in *datapb.Ad func (m *GrpcDataNodeClient) SyncSegments(ctx context.Context, in *datapb.SyncSegmentsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return &commonpb.Status{}, m.Err } + +func (m 
*GrpcDataNodeClient) FlushChannels(ctx context.Context, in *datapb.FlushChannelsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *GrpcDataNodeClient) NotifyChannelOperation(ctx context.Context, in *datapb.ChannelOperationsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *GrpcDataNodeClient) CheckChannelOperationProgress(ctx context.Context, req *datapb.ChannelWatchInfo, opts ...grpc.CallOption) (*datapb.ChannelOperationProgressResponse, error) { + return &datapb.ChannelOperationProgressResponse{}, m.Err +} diff --git a/internal/util/mock/grpc_indexnode_client.go b/internal/util/mock/grpc_indexnode_client.go index 77d3e314b314c..d8bbbd57c7b3c 100644 --- a/internal/util/mock/grpc_indexnode_client.go +++ b/internal/util/mock/grpc_indexnode_client.go @@ -68,3 +68,7 @@ func (m *GrpcIndexNodeClient) GetMetrics(ctx context.Context, in *milvuspb.GetMe func (m *GrpcIndexNodeClient) ShowConfigurations(ctx context.Context, in *internalpb.ShowConfigurationsRequest, opts ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error) { return &internalpb.ShowConfigurationsResponse{}, m.Err } + +func (m *GrpcIndexNodeClient) Close() error { + return m.Err +} diff --git a/internal/util/mock/grpc_querynode_client.go b/internal/util/mock/grpc_querynode_client.go index 31b4f002b43b0..e20dc0d635422 100644 --- a/internal/util/mock/grpc_querynode_client.go +++ b/internal/util/mock/grpc_querynode_client.go @@ -25,6 +25,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/util/streamrpc" ) var _ querypb.QueryNodeClient = &GrpcQueryNodeClient{} @@ -89,10 +90,18 @@ func (m *GrpcQueryNodeClient) Query(ctx context.Context, in *querypb.QueryReques return &internalpb.RetrieveResults{}, m.Err } +func (m 
*GrpcQueryNodeClient) QueryStream(ctx context.Context, in *querypb.QueryRequest, opts ...grpc.CallOption) (querypb.QueryNode_QueryStreamClient, error) { + return &streamrpc.LocalQueryClient{}, m.Err +} + func (m *GrpcQueryNodeClient) QuerySegments(ctx context.Context, in *querypb.QueryRequest, opts ...grpc.CallOption) (*internalpb.RetrieveResults, error) { return &internalpb.RetrieveResults{}, m.Err } +func (m *GrpcQueryNodeClient) QueryStreamSegments(ctx context.Context, in *querypb.QueryRequest, opts ...grpc.CallOption) (querypb.QueryNode_QueryStreamSegmentsClient, error) { + return &streamrpc.LocalQueryClient{}, m.Err +} + func (m *GrpcQueryNodeClient) SyncReplicaSegments(ctx context.Context, in *querypb.SyncReplicaSegmentsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return &commonpb.Status{}, m.Err } @@ -120,3 +129,7 @@ func (m *GrpcQueryNodeClient) UnsubDmChannel(ctx context.Context, req *querypb.U func (m *GrpcQueryNodeClient) Delete(ctx context.Context, in *querypb.DeleteRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return &commonpb.Status{}, m.Err } + +func (m *GrpcQueryNodeClient) Close() error { + return m.Err +} diff --git a/internal/util/mock/grpc_rootcoord_client.go b/internal/util/mock/grpc_rootcoord_client.go index b5bad586a5fe6..98be9c7c4f87e 100644 --- a/internal/util/mock/grpc_rootcoord_client.go +++ b/internal/util/mock/grpc_rootcoord_client.go @@ -28,6 +28,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/proxypb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/uniquegenerator" ) var _ rootcoordpb.RootCoordClient = &GrpcRootCoordClient{} @@ -49,7 +50,7 @@ func (m *GrpcRootCoordClient) ListDatabases(ctx context.Context, in *milvuspb.Li } func (m *GrpcRootCoordClient) RenameCollection(ctx context.Context, in *milvuspb.RenameCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - return 
merr.Status(nil), nil + return merr.Success(), nil } func (m *GrpcRootCoordClient) CheckHealth(ctx context.Context, in *milvuspb.CheckHealthRequest, opts ...grpc.CallOption) (*milvuspb.CheckHealthResponse, error) { @@ -89,7 +90,16 @@ func (m *GrpcRootCoordClient) ListPolicy(ctx context.Context, in *internalpb.Lis } func (m *GrpcRootCoordClient) GetComponentStates(ctx context.Context, in *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { - return &milvuspb.ComponentStates{}, m.Err + return &milvuspb.ComponentStates{ + State: &milvuspb.ComponentInfo{ + NodeID: int64(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()), + Role: "MockRootCoord", + StateCode: commonpb.StateCode_Healthy, + ExtraInfo: nil, + }, + SubcomponentStates: nil, + Status: merr.Success(), + }, m.Err } func (m *GrpcRootCoordClient) GetTimeTickChannel(ctx context.Context, in *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { @@ -251,3 +261,7 @@ func (m *GrpcRootCoordClient) GetCredential(ctx context.Context, in *rootcoordpb func (m *GrpcRootCoordClient) AlterCollection(ctx context.Context, in *milvuspb.AlterCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return &commonpb.Status{}, m.Err } + +func (m *GrpcRootCoordClient) Close() error { + return nil +} diff --git a/internal/util/mock/grpcclient.go b/internal/util/mock/grpcclient.go index 50d83a0a222e7..b466f097c3759 100644 --- a/internal/util/mock/grpcclient.go +++ b/internal/util/mock/grpcclient.go @@ -58,7 +58,6 @@ func (c *GRPCClientBase[T]) SetRole(role string) { } func (c *GRPCClientBase[T]) EnableEncryption() { - } func (c *GRPCClientBase[T]) SetNewGrpcClientFunc(f func(cc *grpc.ClientConn) T) { diff --git a/internal/util/mock/querynode_client.go b/internal/util/mock/querynode_client.go deleted file mode 100644 index 15f6eedc48071..0000000000000 --- a/internal/util/mock/querynode_client.go +++ /dev/null @@ -1,138 +0,0 @@ -// 
Licensed to the LF AI & Data foundation under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package mock - -import ( - "context" - - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" - "github.com/milvus-io/milvus/internal/proto/internalpb" - "github.com/milvus-io/milvus/internal/proto/querypb" - "github.com/milvus-io/milvus/internal/types" -) - -var _ types.QueryNode = &QueryNodeClient{} - -type QueryNodeClient struct { - grpcClient *GrpcQueryNodeClient - Err error -} - -func (q QueryNodeClient) GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) (*internalpb.GetStatisticsResponse, error) { - return q.grpcClient.GetStatistics(ctx, req) -} - -func (q QueryNodeClient) Init() error { - return nil -} - -func (q QueryNodeClient) Start() error { - return nil -} - -func (q QueryNodeClient) Stop() error { - return nil -} - -func (q QueryNodeClient) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { - return q.grpcClient.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) -} - -func (q QueryNodeClient) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - return q.grpcClient.GetStatisticsChannel(ctx, 
&internalpb.GetStatisticsChannelRequest{}) -} - -func (q QueryNodeClient) Register() error { - return nil -} - -func (q QueryNodeClient) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - return q.grpcClient.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{}) -} - -func (q QueryNodeClient) WatchDmChannels(ctx context.Context, req *querypb.WatchDmChannelsRequest) (*commonpb.Status, error) { - return q.grpcClient.WatchDmChannels(ctx, req) -} - -func (q QueryNodeClient) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error) { - return q.grpcClient.LoadSegments(ctx, req) -} - -func (q QueryNodeClient) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) { - return q.grpcClient.ReleaseCollection(ctx, req) -} - -func (q QueryNodeClient) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) { - return q.grpcClient.LoadPartitions(ctx, req) -} - -func (q QueryNodeClient) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) { - return q.grpcClient.ReleasePartitions(ctx, req) -} - -func (q QueryNodeClient) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) { - return q.grpcClient.ReleaseSegments(ctx, req) -} - -func (q QueryNodeClient) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) { - return q.grpcClient.GetSegmentInfo(ctx, req) -} - -func (q QueryNodeClient) Search(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) { - return q.grpcClient.Search(ctx, req) -} - -func (q QueryNodeClient) SearchSegments(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) { - return q.grpcClient.Search(ctx, req) -} - -func (q QueryNodeClient) Query(ctx context.Context, req *querypb.QueryRequest) 
(*internalpb.RetrieveResults, error) { - return q.grpcClient.Query(ctx, req) -} - -func (q QueryNodeClient) QuerySegments(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) { - return q.grpcClient.Query(ctx, req) -} - -func (q QueryNodeClient) SyncReplicaSegments(ctx context.Context, req *querypb.SyncReplicaSegmentsRequest) (*commonpb.Status, error) { - return q.grpcClient.SyncReplicaSegments(ctx, req) -} - -func (q QueryNodeClient) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - return q.grpcClient.GetMetrics(ctx, req) -} - -func (q QueryNodeClient) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { - return q.grpcClient.ShowConfigurations(ctx, req) -} - -func (q QueryNodeClient) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmChannelRequest) (*commonpb.Status, error) { - return q.grpcClient.UnsubDmChannel(ctx, req) -} - -func (q QueryNodeClient) GetDataDistribution(ctx context.Context, req *querypb.GetDataDistributionRequest) (*querypb.GetDataDistributionResponse, error) { - return q.grpcClient.GetDataDistribution(ctx, req) -} - -func (q QueryNodeClient) SyncDistribution(ctx context.Context, req *querypb.SyncDistributionRequest) (*commonpb.Status, error) { - return q.grpcClient.SyncDistribution(ctx, req) -} - -func (q QueryNodeClient) Delete(ctx context.Context, req *querypb.DeleteRequest) (*commonpb.Status, error) { - return q.grpcClient.Delete(ctx, req) -} diff --git a/internal/util/pipeline/node.go b/internal/util/pipeline/node.go index e4fc7242b5064..ad42e6318fe51 100644 --- a/internal/util/pipeline/node.go +++ b/internal/util/pipeline/node.go @@ -20,9 +20,10 @@ import ( "fmt" "sync" + "go.uber.org/zap" + "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/timerecord" - "go.uber.org/zap" ) type Node interface { @@ -66,7 +67,7 @@ func (c *nodeCtx) work() { 
for { select { - //close + // close case <-c.closeCh: c.node.Close() close(c.inputChannel) diff --git a/internal/util/pipeline/pipeline.go b/internal/util/pipeline/pipeline.go index 17bb3eea517f7..047bf65f48714 100644 --- a/internal/util/pipeline/pipeline.go +++ b/internal/util/pipeline/pipeline.go @@ -19,9 +19,10 @@ package pipeline import ( "time" + "go.uber.org/zap" + "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/timerecord" - "go.uber.org/zap" ) type Pipeline interface { diff --git a/internal/util/pipeline/pipeline_test.go b/internal/util/pipeline/pipeline_test.go index 82793091c6f6b..8ddeb9c35534a 100644 --- a/internal/util/pipeline/pipeline_test.go +++ b/internal/util/pipeline/pipeline_test.go @@ -19,8 +19,9 @@ package pipeline import ( "testing" - "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus/pkg/mq/msgstream" ) type testNode struct { diff --git a/internal/util/pipeline/stream_pipeline.go b/internal/util/pipeline/stream_pipeline.go index 79b0e262e2fce..6cb6b6900e04e 100644 --- a/internal/util/pipeline/stream_pipeline.go +++ b/internal/util/pipeline/stream_pipeline.go @@ -17,16 +17,18 @@ package pipeline import ( + "context" "sync" "time" + "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/msgdispatcher" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" "github.com/milvus-io/milvus/pkg/util/tsoutil" - "go.uber.org/zap" ) type StreamPipeline interface { @@ -68,7 +70,7 @@ func (p *streamPipeline) ConsumeMsgStream(position *msgpb.MsgPosition) error { } start := time.Now() - p.input, err = p.dispatcher.Register(p.vChannel, position, mqwrapper.SubscriptionPositionUnknown) + p.input, err = p.dispatcher.Register(context.TODO(), p.vChannel, position, mqwrapper.SubscriptionPositionUnknown) if err != nil { 
log.Error("dispatcher register failed", zap.String("channel", position.ChannelName)) return WrapErrRegDispather(err) diff --git a/internal/util/pipeline/stream_pipeline_test.go b/internal/util/pipeline/stream_pipeline_test.go index bf0d5d9f0cd01..7bf28a5a0c351 100644 --- a/internal/util/pipeline/stream_pipeline_test.go +++ b/internal/util/pipeline/stream_pipeline_test.go @@ -20,12 +20,13 @@ import ( "fmt" "testing" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/pkg/mq/msgdispatcher" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/suite" ) type StreamPipelineSuite struct { @@ -33,10 +34,10 @@ type StreamPipelineSuite struct { pipeline StreamPipeline inChannel chan *msgstream.MsgPack outChannel chan msgstream.Timestamp - //data + // data length int channel string - //mock + // mock msgDispatcher *msgdispatcher.MockClient } @@ -45,7 +46,7 @@ func (suite *StreamPipelineSuite) SetupTest() { suite.inChannel = make(chan *msgstream.MsgPack, 1) suite.outChannel = make(chan msgstream.Timestamp) suite.msgDispatcher = msgdispatcher.NewMockClient(suite.T()) - suite.msgDispatcher.EXPECT().Register(suite.channel, mock.Anything, mqwrapper.SubscriptionPositionUnknown).Return(suite.inChannel, nil) + suite.msgDispatcher.EXPECT().Register(mock.Anything, suite.channel, mock.Anything, mqwrapper.SubscriptionPositionUnknown).Return(suite.inChannel, nil) suite.msgDispatcher.EXPECT().Deregister(suite.channel) suite.pipeline = NewPipelineWithStream(suite.msgDispatcher, 0, false, suite.channel) suite.length = 4 diff --git a/internal/util/segmentutil/utils.go b/internal/util/segmentutil/utils.go index de20278070ffb..3f183ef32b067 100644 --- a/internal/util/segmentutil/utils.go +++ b/internal/util/segmentutil/utils.go @@ -1,9 +1,10 @@ package segmentutil import ( 
+ "go.uber.org/zap" + "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/pkg/log" - "go.uber.org/zap" ) // ReCalcRowCount re-calculates number of rows of `oldSeg` based on its bin log count, and correct its value in its diff --git a/internal/util/sessionutil/session_util.go b/internal/util/sessionutil/session_util.go index 30d1ac1a487b5..1c08e94bfe528 100644 --- a/internal/util/sessionutil/session_util.go +++ b/internal/util/sessionutil/session_util.go @@ -20,8 +20,11 @@ import ( "context" "encoding/json" "fmt" + "os" "path" + "path/filepath" "strconv" + "strings" "sync" "time" @@ -78,6 +81,27 @@ const ( SessionUpdateEvent ) +type IndexEngineVersion struct { + MinimalIndexVersion int32 `json:"MinimalIndexVersion,omitempty"` + CurrentIndexVersion int32 `json:"CurrentIndexVersion,omitempty"` +} + +// SessionRaw the persistent part of Session. +type SessionRaw struct { + ServerID int64 `json:"ServerID,omitempty"` + ServerName string `json:"ServerName,omitempty"` + Address string `json:"Address,omitempty"` + Exclusive bool `json:"Exclusive,omitempty"` + Stopping bool `json:"Stopping,omitempty"` + TriggerKill bool + Version string `json:"Version"` + IndexEngineVersion IndexEngineVersion `json:"IndexEngineVersion,omitempty"` + LeaseID *clientv3.LeaseID `json:"LeaseID,omitempty"` + + HostName string `json:"HostName,omitempty"` + EnableDisk bool `json:"EnableDisk,omitempty"` +} + // Session is a struct to store service's session, including ServerID, ServerName, // Address. // Exclusive indicates that this server can only start one. 
@@ -89,19 +113,14 @@ type Session struct { keepAliveCancel context.CancelFunc keepAliveCtx context.Context - ServerID int64 `json:"ServerID,omitempty"` - ServerName string `json:"ServerName,omitempty"` - Address string `json:"Address,omitempty"` - Exclusive bool `json:"Exclusive,omitempty"` - Stopping bool `json:"Stopping,omitempty"` - TriggerKill bool - Version semver.Version `json:"Version,omitempty"` + SessionRaw + + Version semver.Version `json:"Version,omitempty"` liveChOnce sync.Once liveCh chan struct{} etcdCli *clientv3.Client - leaseID *clientv3.LeaseID watchSessionKeyCh clientv3.WatchChan watchCancel atomic.Pointer[context.CancelFunc] wg sync.WaitGroup @@ -134,6 +153,20 @@ func WithResueNodeID(b bool) SessionOption { return func(session *Session) { session.reuseNodeID = b } } +// WithIndexEngineVersion should be only used by querynode. +func WithIndexEngineVersion(minimal, current int32) SessionOption { + return func(session *Session) { + session.IndexEngineVersion.MinimalIndexVersion = minimal + session.IndexEngineVersion.CurrentIndexVersion = current + } +} + +func WithEnableDisk(enableDisk bool) SessionOption { + return func(s *Session) { + s.EnableDisk = enableDisk + } +} + func (s *Session) apply(opts ...SessionOption) { for _, opt := range opts { opt(s) @@ -142,62 +175,25 @@ func (s *Session) apply(opts ...SessionOption) { // UnmarshalJSON unmarshal bytes to Session. 
func (s *Session) UnmarshalJSON(data []byte) error { - var raw struct { - ServerID int64 `json:"ServerID,omitempty"` - ServerName string `json:"ServerName,omitempty"` - Address string `json:"Address,omitempty"` - Exclusive bool `json:"Exclusive,omitempty"` - Stopping bool `json:"Stopping,omitempty"` - TriggerKill bool - Version string `json:"Version"` - LeaseID *clientv3.LeaseID `json:"LeaseID,omitempty"` - } - err := json.Unmarshal(data, &raw) + err := json.Unmarshal(data, &s.SessionRaw) if err != nil { return err } - if raw.Version != "" { - s.Version, err = semver.Parse(raw.Version) + if s.SessionRaw.Version != "" { + s.Version, err = semver.Parse(s.SessionRaw.Version) if err != nil { return err } } - s.ServerID = raw.ServerID - s.ServerName = raw.ServerName - s.Address = raw.Address - s.Exclusive = raw.Exclusive - s.Stopping = raw.Stopping - s.TriggerKill = raw.TriggerKill - s.leaseID = raw.LeaseID return nil } // MarshalJSON marshals session to bytes. func (s *Session) MarshalJSON() ([]byte, error) { - - verStr := s.Version.String() - return json.Marshal(&struct { - ServerID int64 `json:"ServerID,omitempty"` - ServerName string `json:"ServerName,omitempty"` - Address string `json:"Address,omitempty"` - Exclusive bool `json:"Exclusive,omitempty"` - Stopping bool `json:"Stopping,omitempty"` - TriggerKill bool - Version string `json:"Version"` - LeaseID *clientv3.LeaseID `json:"LeaseID,omitempty"` - }{ - ServerID: s.ServerID, - ServerName: s.ServerName, - Address: s.Address, - Exclusive: s.Exclusive, - Stopping: s.Stopping, - TriggerKill: s.TriggerKill, - Version: verStr, - LeaseID: s.leaseID, - }) - + s.SessionRaw.Version = s.Version.String() + return json.Marshal(s.SessionRaw) } // NewSession is a helper to build Session object. @@ -205,11 +201,20 @@ func (s *Session) MarshalJSON() ([]byte, error) { // metaRoot is a path in etcd to save session information. 
// etcdEndpoints is to init etcdCli when NewSession func NewSession(ctx context.Context, metaRoot string, client *clientv3.Client, opts ...SessionOption) *Session { + hostName, hostNameErr := os.Hostname() + if hostNameErr != nil { + log.Error("get host name fail", zap.Error(hostNameErr)) + } + session := &Session{ ctx: ctx, metaRoot: metaRoot, Version: common.Version, + SessionRaw: SessionRaw{ + HostName: hostName, + }, + // options sessionTTL: paramtable.Get().CommonCfg.SessionTTL.GetAsInt64(), sessionRetryTimes: paramtable.Get().CommonCfg.SessionRetryTimes.GetAsInt64(), @@ -374,10 +379,10 @@ func (s *Session) initWatchSessionCh(ctx context.Context) error { err = retry.Do(ctx, func() error { getResp, err = s.etcdCli.Get(ctx, s.getSessionKey()) - log.Warn("fail to get the session key from the etcd", zap.Error(err)) return err }, retry.Attempts(uint(s.sessionRetryTimes))) if err != nil { + log.Warn("fail to get the session key from the etcd", zap.Error(err)) return err } s.watchSessionKeyCh = s.etcdCli.Watch(ctx, s.getSessionKey(), clientv3.WithRev(getResp.Header.Revision)) @@ -413,7 +418,7 @@ func (s *Session) registerService() (<-chan *clientv3.LeaseKeepAliveResponse, er log.Error("register service", zap.Error(err)) return err } - s.leaseID = &resp.ID + s.LeaseID = &resp.ID sessionJSON, err := json.Marshal(s) if err != nil { @@ -426,7 +431,6 @@ func (s *Session) registerService() (<-chan *clientv3.LeaseKeepAliveResponse, er "=", 0)). 
Then(clientv3.OpPut(completeKey, string(sessionJSON), clientv3.WithLease(resp.ID))).Commit() - if err != nil { log.Warn("compare and swap error, maybe the key has already been registered", zap.Error(err)) return err @@ -487,22 +491,21 @@ func (s *Session) processKeepAliveResponse(ch <-chan *clientv3.LeaseKeepAliveRes err := retry.Do(s.ctx, func() error { ctx, cancel := context.WithTimeout(s.keepAliveCtx, time.Second*10) defer cancel() - resp, err := s.etcdCli.KeepAliveOnce(ctx, *s.leaseID) + resp, err := s.etcdCli.KeepAliveOnce(ctx, *s.LeaseID) keepAliveOnceResp = resp return err }, retry.Attempts(3)) - if err != nil { - log.Warn("fail to retry keepAliveOnce", zap.String("serverName", s.ServerName), zap.Int64("leaseID", int64(*s.leaseID)), zap.Error(err)) + log.Warn("fail to retry keepAliveOnce", zap.String("serverName", s.ServerName), zap.Int64("LeaseID", int64(*s.LeaseID)), zap.Error(err)) s.safeCloseLiveCh() return } - log.Info("succeed to KeepAliveOnce", zap.String("serverName", s.ServerName), zap.Int64("leaseID", int64(*s.leaseID)), zap.Any("resp", keepAliveOnceResp)) + log.Info("succeed to KeepAliveOnce", zap.String("serverName", s.ServerName), zap.Int64("LeaseID", int64(*s.LeaseID)), zap.Any("resp", keepAliveOnceResp)) var chNew <-chan *clientv3.LeaseKeepAliveResponse keepAliveFunc := func() error { var err1 error - chNew, err1 = s.etcdCli.KeepAlive(s.keepAliveCtx, *s.leaseID) + chNew, err1 = s.etcdCli.KeepAlive(s.keepAliveCtx, *s.LeaseID) return err1 } err = fnWithTimeout(keepAliveFunc, time.Second*10) @@ -560,9 +563,10 @@ func (s *Session) GetSessions(prefix string) (map[string]*Session, int64, error) return nil, 0, err } _, mapKey := path.Split(string(kv.Key)) - log.Debug("SessionUtil GetSessions ", zap.Any("prefix", prefix), + log.Debug("SessionUtil GetSessions", + zap.String("prefix", prefix), zap.String("key", mapKey), - zap.Any("address", session.Address)) + zap.String("address", session.Address)) res[mapKey] = session } return res, 
resp.Header.Revision, nil @@ -598,7 +602,7 @@ func (s *Session) GetSessionsWithVersionRange(prefix string, r semver.Range) (ma } func (s *Session) GoingStop() error { - if s == nil || s.etcdCli == nil || s.leaseID == nil { + if s == nil || s.etcdCli == nil || s.LeaseID == nil { return errors.New("the session hasn't been init") } @@ -621,7 +625,7 @@ func (s *Session) GoingStop() error { log.Error("fail to marshal the session", zap.String("key", completeKey)) return err } - _, err = s.etcdCli.Put(s.ctx, completeKey, string(sessionJSON), clientv3.WithLease(*s.leaseID)) + _, err = s.etcdCli.Put(s.ctx, completeKey, string(sessionJSON), clientv3.WithLease(*s.LeaseID)) if err != nil { log.Error("fail to update the session to stopping state", zap.String("key", completeKey)) return err @@ -757,7 +761,7 @@ func (w *sessionWatcher) handleWatchResponse(wresp clientv3.WatchResponse) { func (w *sessionWatcher) handleWatchErr(err error) error { // if not ErrCompacted, just close the channel if err != v3rpc.ErrCompacted { - //close event channel + // close event channel log.Warn("Watch service found error", zap.Error(err)) close(w.eventCh) return err @@ -788,16 +792,30 @@ func (w *sessionWatcher) handleWatchErr(err error) error { // LivenessCheck performs liveness check with provided context and channel // ctx controls the liveness check loop // ch is the liveness signal channel, ch is closed only when the session is expired -// callback is the function to call when ch is closed, note that callback will not be invoked when loop exits due to context +// callback must be called before liveness check exit, to close the session's owner component func (s *Session) LivenessCheck(ctx context.Context, callback func()) { err := s.initWatchSessionCh(ctx) if err != nil { log.Error("failed to get session for liveness check", zap.Error(err)) s.cancelKeepAlive() + if callback != nil { + go callback() + } + return } + s.wg.Add(1) go func() { defer s.wg.Done() + if callback != nil { + // before 
exit liveness check, callback to exit the session owner + defer func() { + if ctx.Err() == nil { + go callback() + } + }() + } + defer s.SetDisconnected(true) for { select { case _, ok := <-s.liveCh: @@ -807,10 +825,6 @@ func (s *Session) LivenessCheck(ctx context.Context, callback func()) { } // not ok, connection lost log.Warn("connection lost detected, shuting down") - s.SetDisconnected(true) - if callback != nil { - go callback() - } return case <-ctx.Done(): log.Warn("liveness exits due to context done") @@ -826,7 +840,7 @@ func (s *Session) LivenessCheck(ctx context.Context, callback func()) { if resp.Err() != nil { // if not ErrCompacted, just close the channel if resp.Err() != v3rpc.ErrCompacted { - //close event channel + // close event channel log.Warn("Watch service found error", zap.Error(resp.Err())) s.cancelKeepAlive() return @@ -867,12 +881,12 @@ func (s *Session) Stop() { s.wg.Wait() } -// Revoke revokes the internal leaseID for the session key +// Revoke revokes the internal LeaseID for the session key func (s *Session) Revoke(timeout time.Duration) { if s == nil { return } - if s.etcdCli == nil || s.leaseID == nil { + if s.etcdCli == nil || s.LeaseID == nil { return } if s.Disconnected() { @@ -882,7 +896,7 @@ func (s *Session) Revoke(timeout time.Duration) { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() // ignores resp & error, just do best effort to revoke - _, _ = s.etcdCli.Revoke(ctx, *s.leaseID) + _, _ = s.etcdCli.Revoke(ctx, *s.LeaseID) } // UpdateRegistered update the state of registered. @@ -959,7 +973,7 @@ func (s *Session) ProcessActiveStandBy(activateFunc func() error) error { clientv3.Version(s.activeKey), "=", 0)). 
- Then(clientv3.OpPut(s.activeKey, string(sessionJSON), clientv3.WithLease(*s.leaseID))).Commit() + Then(clientv3.OpPut(s.activeKey, string(sessionJSON), clientv3.WithLease(*s.LeaseID))).Commit() if err != nil { log.Error("register active key to etcd failed", zap.Error(err)) return false, -1, err @@ -1019,7 +1033,7 @@ func (s *Session) ProcessActiveStandBy(activateFunc func() error) error { } s.updateStandby(false) - log.Info(fmt.Sprintf("serverName: %v quit STANDBY mode, this node will become ACTIVE", s.ServerName)) + log.Info(fmt.Sprintf("serverName: %v quit STANDBY mode, this node will become ACTIVE, ID: %d", s.ServerName, s.ServerID)) if activateFunc != nil { return activateFunc() } @@ -1046,8 +1060,8 @@ func (s *Session) ForceActiveStandby(activateFunc func() error) error { if len(sessions) != 0 { activeSess := sessions[s.ServerName] - if activeSess == nil || activeSess.leaseID == nil { - //force delete all old sessions + if activeSess == nil || activeSess.LeaseID == nil { + // force delete all old sessions s.etcdCli.Delete(s.ctx, s.activeKey) for _, sess := range sessions { if sess.ServerID != s.ServerID { @@ -1058,7 +1072,7 @@ func (s *Session) ForceActiveStandby(activateFunc func() error) error { } } else { // force release old active session - _, _ = s.etcdCli.Revoke(s.ctx, *activeSess.leaseID) + _, _ = s.etcdCli.Revoke(s.ctx, *activeSess.LeaseID) } } @@ -1068,7 +1082,7 @@ func (s *Session) ForceActiveStandby(activateFunc func() error) error { clientv3.Version(s.activeKey), "=", 0)). 
- Then(clientv3.OpPut(s.activeKey, string(sessionJSON), clientv3.WithLease(*s.leaseID))).Commit() + Then(clientv3.OpPut(s.activeKey, string(sessionJSON), clientv3.WithLease(*s.LeaseID))).Commit() if !resp.Succeeded { msg := fmt.Sprintf("failed to force register ACTIVE %s", s.ServerName) @@ -1086,9 +1100,73 @@ func (s *Session) ForceActiveStandby(activateFunc func() error) error { return err } s.updateStandby(false) - log.Info(fmt.Sprintf("serverName: %v quit STANDBY mode, this node will become ACTIVE", s.ServerName)) + log.Info(fmt.Sprintf("serverName: %v quit STANDBY mode, this node will become ACTIVE, ID: %d", s.ServerName, s.ServerID)) if activateFunc != nil { return activateFunc() } return nil } + +func filterEmptyStrings(s []string) []string { + var filtered []string + for _, str := range s { + if str != "" { + filtered = append(filtered, str) + } + } + return filtered +} + +func GetSessions(pid int) []string { + fileFullName := GetServerInfoFilePath(pid) + if _, err := os.Stat(fileFullName); errors.Is(err, os.ErrNotExist) { + log.Warn("not found server info file path", zap.String("filePath", fileFullName), zap.Error(err)) + return []string{} + } + + v, err := os.ReadFile(fileFullName) + if err != nil { + log.Warn("read server info file path failed", zap.String("filePath", fileFullName), zap.Error(err)) + return []string{} + } + + return filterEmptyStrings(strings.Split(string(v), "\n")) +} + +func RemoveServerInfoFile(pid int) { + fullPath := GetServerInfoFilePath(pid) + _ = os.Remove(fullPath) +} + +// GetServerInfoFilePath get server info file path, eg: /tmp/milvus/server_id_123456789 +// Notes: this method will not support Windows OS +// return file path +func GetServerInfoFilePath(pid int) string { + tmpDir := "/tmp/milvus" + _ = os.Mkdir(tmpDir, os.ModePerm) + fileName := fmt.Sprintf("server_id_%d", pid) + filePath := filepath.Join(tmpDir, fileName) + return filePath +} + +func saveServerInfoInternal(role string, serverID int64, pid int) { + fileFullPath 
:= GetServerInfoFilePath(pid) + fd, err := os.OpenFile(fileFullPath, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0o664) + if err != nil { + log.Warn("open server info file fail", zap.String("filePath", fileFullPath), zap.Error(err)) + return + } + defer fd.Close() + + data := fmt.Sprintf("%s-%d\n", role, serverID) + _, err = fd.WriteString(data) + if err != nil { + log.Warn("write server info file fail", zap.String("filePath", fileFullPath), zap.Error(err)) + } + + log.Info("save server info into file", zap.String("content", data), zap.String("filePath", fileFullPath)) +} + +func SaveServerInfo(role string, serverID int64) { + saveServerInfoInternal(role, serverID, os.Getpid()) +} diff --git a/internal/util/sessionutil/session_util_test.go b/internal/util/sessionutil/session_util_test.go index 9adda794e1dad..25f9d7a80efb1 100644 --- a/internal/util/sessionutil/session_util_test.go +++ b/internal/util/sessionutil/session_util_test.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "io/ioutil" "math/rand" "net/url" "os" @@ -24,6 +23,7 @@ import ( clientv3 "go.etcd.io/etcd/client/v3" "go.etcd.io/etcd/server/v3/embed" "go.etcd.io/etcd/server/v3/etcdserver/api/v3client" + "go.uber.org/atomic" "go.uber.org/zap" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" @@ -32,6 +32,7 @@ import ( "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) func TestGetServerIDConcurrently(t *testing.T) { @@ -54,7 +55,7 @@ func TestGetServerIDConcurrently(t *testing.T) { defer etcdKV.RemoveWithPrefix("") var wg sync.WaitGroup - var muList = sync.Mutex{} + muList := sync.Mutex{} s := NewSession(ctx, metaRoot, etcdCli) res := make([]int64, 0) @@ -97,7 +98,7 @@ func TestInit(t *testing.T) { s := NewSession(ctx, metaRoot, etcdCli) s.Init("inittest", "testAddr", false, false) - assert.NotEqual(t, int64(0), s.leaseID) + assert.NotEqual(t, int64(0), 
s.LeaseID) assert.NotEqual(t, int64(0), s.ServerID) s.Register() sessions, _, err := s.GetSessions("inittest") @@ -122,7 +123,7 @@ func TestUpdateSessions(t *testing.T) { defer etcdKV.RemoveWithPrefix("") var wg sync.WaitGroup - var muList = sync.Mutex{} + muList := sync.Mutex{} s := NewSession(ctx, metaRoot, etcdCli, WithResueNodeID(false)) @@ -194,39 +195,55 @@ func TestSessionLivenessCheck(t *testing.T) { etcdCli, err := etcd.GetRemoteEtcdClient(etcdEndpoints) require.NoError(t, err) s := NewSession(context.Background(), metaRoot, etcdCli) - ctx := context.Background() + s.Register() ch := make(chan struct{}) s.liveCh = ch signal := make(chan struct{}, 1) - flag := false - - s.LivenessCheck(ctx, func() { - flag = true + flag := atomic.NewBool(false) + s.LivenessCheck(context.Background(), func() { + flag.Store(true) signal <- struct{}{} }) + assert.False(t, flag.Load()) - assert.False(t, flag) + // test liveCh receive event, liveness won't exit, callback won't trigger ch <- struct{}{} + assert.False(t, flag.Load()) - assert.False(t, flag) + // test close liveCh, liveness exit, callback should trigger close(ch) - <-signal - assert.True(t, flag) + assert.True(t, flag.Load()) - ctx, cancel := context.WithCancel(ctx) - cancel() - ch = make(chan struct{}) - s.liveCh = ch - flag = false + // test context done, liveness exit, callback shouldn't trigger + metaRoot = fmt.Sprintf("%d/%s", rand.Int(), DefaultServiceRoot) + s1 := NewSession(context.Background(), metaRoot, etcdCli) + s1.Register() + ctx, cancel := context.WithCancel(context.Background()) + flag.Store(false) - s.LivenessCheck(ctx, func() { - flag = true + s1.LivenessCheck(ctx, func() { + flag.Store(true) signal <- struct{}{} }) + cancel() + assert.False(t, flag.Load()) - assert.False(t, flag) + // test context done, liveness start failed, callback should trigger + metaRoot = fmt.Sprintf("%d/%s", rand.Int(), DefaultServiceRoot) + s2 := NewSession(context.Background(), metaRoot, etcdCli) + s2.Register() + ctx, 
cancel = context.WithCancel(context.Background()) + signal = make(chan struct{}, 1) + flag.Store(false) + cancel() + s2.LivenessCheck(ctx, func() { + flag.Store(true) + signal <- struct{}{} + }) + <-signal + assert.True(t, flag.Load()) } func TestWatcherHandleWatchResp(t *testing.T) { @@ -365,9 +382,7 @@ func TestWatcherHandleWatchResp(t *testing.T) { assert.Panics(t, func() { w.handleWatchResponse(wresp) }) - }) - } func TestSession_Registered(t *testing.T) { @@ -385,10 +400,12 @@ func TestSession_String(t *testing.T) { func TestSesssionMarshal(t *testing.T) { s := &Session{ - ServerID: 1, - ServerName: "test", - Address: "localhost", - Version: common.Version, + SessionRaw: SessionRaw{ + ServerID: 1, + ServerName: "test", + Address: "localhost", + }, + Version: common.Version, } bs, err := json.Marshal(s) @@ -430,7 +447,7 @@ type SessionWithVersionSuite struct { // SetupSuite setup suite env func (suite *SessionWithVersionSuite) SetupSuite() { - dir, err := ioutil.TempDir(os.TempDir(), "milvus_ut") + dir, err := os.MkdirTemp(os.TempDir(), "milvus_ut") suite.Require().NoError(err) suite.tmpDir = dir suite.T().Log("using tmp dir:", dir) @@ -491,7 +508,6 @@ func (suite *SessionWithVersionSuite) SetupTest() { s3.Register() suite.sessions = append(suite.sessions, s3) - } func (suite *SessionWithVersionSuite) TearDownTest() { @@ -603,8 +619,7 @@ func TestSessionProcessActiveStandBy(t *testing.T) { defer etcdKV.RemoveWithPrefix("") var wg sync.WaitGroup - ch := make(chan struct{}) - signal := make(chan struct{}, 1) + signal := make(chan struct{}) flag := false // register session 1, will be active @@ -614,15 +629,16 @@ func TestSessionProcessActiveStandBy(t *testing.T) { s1.SetEnableActiveStandBy(true) s1.Register() wg.Add(1) - s1.liveCh = ch s1.ProcessActiveStandBy(func() error { log.Debug("Session 1 become active") wg.Done() return nil }) + wg.Wait() s1.LivenessCheck(ctx1, func() { + log.Debug("Session 1 livenessCheck callback") flag = true - signal <- struct{}{} + 
close(signal) s1.cancelKeepAlive() }) assert.False(t, s1.isStandby.Load().(bool)) @@ -641,18 +657,31 @@ func TestSessionProcessActiveStandBy(t *testing.T) { }) assert.True(t, s2.isStandby.Load().(bool)) - //assert.True(t, s2.watchingPrimaryKeyLock) + // assert.True(t, s2.watchingPrimaryKeyLock) // stop session 1, session 2 will take over primary service log.Debug("Stop session 1, session 2 will take over primary service") assert.False(t, flag) - ch <- struct{}{} - assert.False(t, flag) + s1.safeCloseLiveCh() - <-signal + { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + _, _ = s1.etcdCli.Revoke(ctx, *s1.LeaseID) + } + select { + case <-signal: + log.Debug("receive s1 signal") + case <-time.After(10 * time.Second): + log.Debug("wait to fail Liveness Check timeout") + t.FailNow() + } assert.True(t, flag) + log.Debug("session s1 stop") wg.Wait() + log.Debug("session s2 wait done") assert.False(t, s2.isStandby.Load().(bool)) + s2.Stop() } func TestSessionEventType_String(t *testing.T) { @@ -673,6 +702,42 @@ func TestSessionEventType_String(t *testing.T) { } } +func TestServerInfoOp(t *testing.T) { + t.Run("test with specified pid", func(t *testing.T) { + pid := 9999999 + serverID := int64(999) + + filePath := GetServerInfoFilePath(pid) + defer os.RemoveAll(filePath) + + saveServerInfoInternal(typeutil.QueryCoordRole, serverID, pid) + saveServerInfoInternal(typeutil.DataCoordRole, serverID, pid) + saveServerInfoInternal(typeutil.ProxyRole, serverID, pid) + + sessions := GetSessions(pid) + assert.Equal(t, 3, len(sessions)) + assert.ElementsMatch(t, sessions, []string{ + "querycoord-999", + "datacoord-999", + "proxy-999", + }) + + RemoveServerInfoFile(pid) + sessions = GetSessions(pid) + assert.Equal(t, 0, len(sessions)) + }) + + t.Run("test with os pid", func(t *testing.T) { + serverID := int64(9999) + filePath := GetServerInfoFilePath(os.Getpid()) + defer os.RemoveAll(filePath) + + SaveServerInfo(typeutil.QueryCoordRole, 
serverID) + sessions := GetSessions(os.Getpid()) + assert.Equal(t, 1, len(sessions)) + }) +} + func TestSession_apply(t *testing.T) { session := &Session{} opts := []SessionOption{WithTTL(100), WithRetryTimes(200)} @@ -718,7 +783,7 @@ type SessionSuite struct { func (s *SessionSuite) SetupSuite() { paramtable.Init() - dir, err := ioutil.TempDir(os.TempDir(), "milvus_ut") + dir, err := os.MkdirTemp(os.TempDir(), "milvus_ut") s.Require().NoError(err) s.tmpDir = dir s.T().Log("using tmp dir:", dir) @@ -920,17 +985,15 @@ func (s *SessionSuite) TestKeepAliveRetryActiveCancel() { // Register ch, err := session.registerService() - if err != nil { - panic(err) - } + s.Require().NoError(err) session.liveCh = make(chan struct{}) session.processKeepAliveResponse(ch) session.LivenessCheck(ctx, nil) // active cancel, should not retry connect session.cancelKeepAlive() - // sleep a while wait goroutine process - time.Sleep(time.Millisecond * 100) + // wait workers exit + session.wg.Wait() // expected Disconnected = true, means session is closed assert.Equal(s.T(), true, session.Disconnected()) } diff --git a/internal/util/streamrpc/in_memory_streamer.go b/internal/util/streamrpc/in_memory_streamer.go new file mode 100644 index 0000000000000..a1f98aa1a875f --- /dev/null +++ b/internal/util/streamrpc/in_memory_streamer.go @@ -0,0 +1,161 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package streamrpc + +import ( + "context" + "io" + "sync" + + "go.uber.org/atomic" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + + "github.com/milvus-io/milvus/pkg/util/generic" + "github.com/milvus-io/milvus/pkg/util/merr" +) + +// InMemoryStreamer is a utility to wrap in-memory stream methods. +type InMemoryStreamer[Msg any] struct { + grpc.ClientStream + grpc.ServerStream + + ctx context.Context + closed atomic.Bool + closeOnce sync.Once + buffer chan Msg +} + +// SetHeader sets the header metadata. It may be called multiple times. +// When call multiple times, all the provided metadata will be merged. +// All the metadata will be sent out when one of the following happens: +// - ServerStream.SendHeader() is called; +// - The first response is sent out; +// - An RPC status is sent out (error or success). +func (s *InMemoryStreamer[Msg]) SetHeader(_ metadata.MD) error { + return merr.WrapErrServiceInternal("shall not be called") +} + +// SendHeader sends the header metadata. +// The provided md and headers set by SetHeader() will be sent. +// It fails if called multiple times. +func (s *InMemoryStreamer[Msg]) SendHeader(_ metadata.MD) error { + return merr.WrapErrServiceInternal("shall not be called") +} + +// SetTrailer sets the trailer metadata which will be sent with the RPC status. +// When called more than once, all the provided metadata will be merged. +func (s *InMemoryStreamer[Msg]) SetTrailer(_ metadata.MD) {} + +// SendMsg sends a message. On error, SendMsg aborts the stream and the +// error is returned directly. 
+// +// SendMsg blocks until: +// - There is sufficient flow control to schedule m with the transport, or +// - The stream is done, or +// - The stream breaks. +// +// SendMsg does not wait until the message is received by the client. An +// untimely stream closure may result in lost messages. +// +// It is safe to have a goroutine calling SendMsg and another goroutine +// calling RecvMsg on the same stream at the same time, but it is not safe +// to call SendMsg on the same stream in different goroutines. +// +// It is not safe to modify the message after calling SendMsg. Tracing +// libraries and stats handlers may use the message lazily. +func (s *InMemoryStreamer[Msg]) SendMsg(m interface{}) error { + return merr.WrapErrServiceInternal("shall not be called") +} + +// RecvMsg blocks until it receives a message into m or the stream is +// done. It returns io.EOF when the client has performed a CloseSend. On +// any non-EOF error, the stream is aborted and the error contains the +// RPC status. +// +// It is safe to have a goroutine calling SendMsg and another goroutine +// calling RecvMsg on the same stream at the same time, but it is not +// safe to call RecvMsg on the same stream in different goroutines. +func (s *InMemoryStreamer[Msg]) RecvMsg(m interface{}) error { + return merr.WrapErrServiceInternal("shall not be called") +} + +// Header returns the header metadata received from the server if there +// is any. It blocks if the metadata is not ready to read. +func (s *InMemoryStreamer[Msg]) Header() (metadata.MD, error) { + return nil, merr.WrapErrServiceInternal("shall not be called") +} + +// Trailer returns the trailer metadata from the server, if there is any. +// It must only be called after stream.CloseAndRecv has returned, or +// stream.Recv has returned a non-nil error (including io.EOF). +func (s *InMemoryStreamer[Msg]) Trailer() metadata.MD { + return nil +} + +// CloseSend closes the send direction of the stream. 
It closes the stream +// when non-nil error is met. It is also not safe to call CloseSend +// concurrently with SendMsg. +func (s *InMemoryStreamer[Msg]) CloseSend() error { + return merr.WrapErrServiceInternal("shall not be called") +} + +func NewInMemoryStreamer[Msg any](ctx context.Context, bufferSize int) *InMemoryStreamer[Msg] { + return &InMemoryStreamer[Msg]{ + ctx: ctx, + buffer: make(chan Msg, bufferSize), + } +} + +func (s *InMemoryStreamer[Msg]) Context() context.Context { + return s.ctx +} + +func (s *InMemoryStreamer[Msg]) Recv() (Msg, error) { + select { + case result, ok := <-s.buffer: + if !ok { + return generic.Zero[Msg](), io.EOF + } + return result, nil + case <-s.ctx.Done(): + return generic.Zero[Msg](), io.EOF + } +} + +func (s *InMemoryStreamer[Msg]) Send(req Msg) error { + if s.closed.Load() || s.ctx.Err() != nil { + return merr.WrapErrIoFailedReason("streamer closed") + } + select { + case s.buffer <- req: + return nil + case <-s.ctx.Done(): + return io.EOF + } +} + +func (s *InMemoryStreamer[Msg]) Close() { + s.closeOnce.Do(func() { + s.closed.Store(true) + close(s.buffer) + }) +} + +func (s *InMemoryStreamer[Msg]) IsClosed() bool { + return s.closed.Load() +} diff --git a/internal/util/streamrpc/in_memory_streamer_test.go b/internal/util/streamrpc/in_memory_streamer_test.go new file mode 100644 index 0000000000000..b128b515aa7a8 --- /dev/null +++ b/internal/util/streamrpc/in_memory_streamer_test.go @@ -0,0 +1,94 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package streamrpc + +import ( + "context" + "io" + "testing" + + "github.com/stretchr/testify/suite" + "google.golang.org/grpc/metadata" +) + +type InMemoryStreamerSuite struct { + suite.Suite +} + +func (s *InMemoryStreamerSuite) TestBufferedClose() { + streamer := NewInMemoryStreamer[int64](context.Background(), 10) + err := streamer.Send(1) + s.NoError(err) + err = streamer.Send(2) + s.NoError(err) + + streamer.Close() + + r, err := streamer.Recv() + s.NoError(err) + s.EqualValues(1, r) + + r, err = streamer.Recv() + s.NoError(err) + s.EqualValues(2, r) + + _, err = streamer.Recv() + s.Error(err) +} + +func (s *InMemoryStreamerSuite) TestStreamerCtxCanceled() { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + streamer := NewInMemoryStreamer[int64](ctx, 10) + err := streamer.Send(1) + s.Error(err) + + _, err = streamer.Recv() + s.Error(err) + s.ErrorIs(err, io.EOF) +} + +func (s *InMemoryStreamerSuite) TestMockedMethods() { + streamer := NewInMemoryStreamer[int64](context.Background(), 10) + + s.NotPanics(func() { + err := streamer.SetHeader(make(metadata.MD)) + s.Error(err) + + err = streamer.SendHeader(make(metadata.MD)) + s.Error(err) + + streamer.SetTrailer(make(metadata.MD)) + + err = streamer.SendMsg(1) + s.Error(err) + + err = streamer.RecvMsg(1) + s.Error(err) + + trailer := streamer.Trailer() + s.Nil(trailer) + + err = streamer.CloseSend() + s.Error(err) + }) +} + +func TestInMemoryStreamer(t *testing.T) { + suite.Run(t, new(InMemoryStreamerSuite)) +} diff --git 
a/internal/util/streamrpc/mocks/mock_query_stream_segments_server.go b/internal/util/streamrpc/mocks/mock_query_stream_segments_server.go new file mode 100644 index 0000000000000..b71672a9983ba --- /dev/null +++ b/internal/util/streamrpc/mocks/mock_query_stream_segments_server.go @@ -0,0 +1,325 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package mocks + +import ( + context "context" + + internalpb "github.com/milvus-io/milvus/internal/proto/internalpb" + metadata "google.golang.org/grpc/metadata" + + mock "github.com/stretchr/testify/mock" +) + +// MockQueryStreamSegmentsServer is an autogenerated mock type for the QueryNode_QueryStreamSegmentsServer type +type MockQueryStreamSegmentsServer struct { + mock.Mock +} + +type MockQueryStreamSegmentsServer_Expecter struct { + mock *mock.Mock +} + +func (_m *MockQueryStreamSegmentsServer) EXPECT() *MockQueryStreamSegmentsServer_Expecter { + return &MockQueryStreamSegmentsServer_Expecter{mock: &_m.Mock} +} + +// Context provides a mock function with given fields: +func (_m *MockQueryStreamSegmentsServer) Context() context.Context { + ret := _m.Called() + + var r0 context.Context + if rf, ok := ret.Get(0).(func() context.Context); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(context.Context) + } + } + + return r0 +} + +// MockQueryStreamSegmentsServer_Context_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Context' +type MockQueryStreamSegmentsServer_Context_Call struct { + *mock.Call +} + +// Context is a helper method to define mock.On call +func (_e *MockQueryStreamSegmentsServer_Expecter) Context() *MockQueryStreamSegmentsServer_Context_Call { + return &MockQueryStreamSegmentsServer_Context_Call{Call: _e.mock.On("Context")} +} + +func (_c *MockQueryStreamSegmentsServer_Context_Call) Run(run func()) *MockQueryStreamSegmentsServer_Context_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c 
*MockQueryStreamSegmentsServer_Context_Call) Return(_a0 context.Context) *MockQueryStreamSegmentsServer_Context_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockQueryStreamSegmentsServer_Context_Call) RunAndReturn(run func() context.Context) *MockQueryStreamSegmentsServer_Context_Call { + _c.Call.Return(run) + return _c +} + +// RecvMsg provides a mock function with given fields: m +func (_m *MockQueryStreamSegmentsServer) RecvMsg(m interface{}) error { + ret := _m.Called(m) + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}) error); ok { + r0 = rf(m) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockQueryStreamSegmentsServer_RecvMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RecvMsg' +type MockQueryStreamSegmentsServer_RecvMsg_Call struct { + *mock.Call +} + +// RecvMsg is a helper method to define mock.On call +// - m interface{} +func (_e *MockQueryStreamSegmentsServer_Expecter) RecvMsg(m interface{}) *MockQueryStreamSegmentsServer_RecvMsg_Call { + return &MockQueryStreamSegmentsServer_RecvMsg_Call{Call: _e.mock.On("RecvMsg", m)} +} + +func (_c *MockQueryStreamSegmentsServer_RecvMsg_Call) Run(run func(m interface{})) *MockQueryStreamSegmentsServer_RecvMsg_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(interface{})) + }) + return _c +} + +func (_c *MockQueryStreamSegmentsServer_RecvMsg_Call) Return(_a0 error) *MockQueryStreamSegmentsServer_RecvMsg_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockQueryStreamSegmentsServer_RecvMsg_Call) RunAndReturn(run func(interface{}) error) *MockQueryStreamSegmentsServer_RecvMsg_Call { + _c.Call.Return(run) + return _c +} + +// Send provides a mock function with given fields: _a0 +func (_m *MockQueryStreamSegmentsServer) Send(_a0 *internalpb.RetrieveResults) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(*internalpb.RetrieveResults) error); ok { + r0 = rf(_a0) + } else { + r0 = 
ret.Error(0) + } + + return r0 +} + +// MockQueryStreamSegmentsServer_Send_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Send' +type MockQueryStreamSegmentsServer_Send_Call struct { + *mock.Call +} + +// Send is a helper method to define mock.On call +// - _a0 *internalpb.RetrieveResults +func (_e *MockQueryStreamSegmentsServer_Expecter) Send(_a0 interface{}) *MockQueryStreamSegmentsServer_Send_Call { + return &MockQueryStreamSegmentsServer_Send_Call{Call: _e.mock.On("Send", _a0)} +} + +func (_c *MockQueryStreamSegmentsServer_Send_Call) Run(run func(_a0 *internalpb.RetrieveResults)) *MockQueryStreamSegmentsServer_Send_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*internalpb.RetrieveResults)) + }) + return _c +} + +func (_c *MockQueryStreamSegmentsServer_Send_Call) Return(_a0 error) *MockQueryStreamSegmentsServer_Send_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockQueryStreamSegmentsServer_Send_Call) RunAndReturn(run func(*internalpb.RetrieveResults) error) *MockQueryStreamSegmentsServer_Send_Call { + _c.Call.Return(run) + return _c +} + +// SendHeader provides a mock function with given fields: _a0 +func (_m *MockQueryStreamSegmentsServer) SendHeader(_a0 metadata.MD) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(metadata.MD) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockQueryStreamSegmentsServer_SendHeader_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendHeader' +type MockQueryStreamSegmentsServer_SendHeader_Call struct { + *mock.Call +} + +// SendHeader is a helper method to define mock.On call +// - _a0 metadata.MD +func (_e *MockQueryStreamSegmentsServer_Expecter) SendHeader(_a0 interface{}) *MockQueryStreamSegmentsServer_SendHeader_Call { + return &MockQueryStreamSegmentsServer_SendHeader_Call{Call: _e.mock.On("SendHeader", _a0)} +} + +func (_c 
*MockQueryStreamSegmentsServer_SendHeader_Call) Run(run func(_a0 metadata.MD)) *MockQueryStreamSegmentsServer_SendHeader_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(metadata.MD)) + }) + return _c +} + +func (_c *MockQueryStreamSegmentsServer_SendHeader_Call) Return(_a0 error) *MockQueryStreamSegmentsServer_SendHeader_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockQueryStreamSegmentsServer_SendHeader_Call) RunAndReturn(run func(metadata.MD) error) *MockQueryStreamSegmentsServer_SendHeader_Call { + _c.Call.Return(run) + return _c +} + +// SendMsg provides a mock function with given fields: m +func (_m *MockQueryStreamSegmentsServer) SendMsg(m interface{}) error { + ret := _m.Called(m) + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}) error); ok { + r0 = rf(m) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockQueryStreamSegmentsServer_SendMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendMsg' +type MockQueryStreamSegmentsServer_SendMsg_Call struct { + *mock.Call +} + +// SendMsg is a helper method to define mock.On call +// - m interface{} +func (_e *MockQueryStreamSegmentsServer_Expecter) SendMsg(m interface{}) *MockQueryStreamSegmentsServer_SendMsg_Call { + return &MockQueryStreamSegmentsServer_SendMsg_Call{Call: _e.mock.On("SendMsg", m)} +} + +func (_c *MockQueryStreamSegmentsServer_SendMsg_Call) Run(run func(m interface{})) *MockQueryStreamSegmentsServer_SendMsg_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(interface{})) + }) + return _c +} + +func (_c *MockQueryStreamSegmentsServer_SendMsg_Call) Return(_a0 error) *MockQueryStreamSegmentsServer_SendMsg_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockQueryStreamSegmentsServer_SendMsg_Call) RunAndReturn(run func(interface{}) error) *MockQueryStreamSegmentsServer_SendMsg_Call { + _c.Call.Return(run) + return _c +} + +// SetHeader provides a mock function with given fields: _a0 
+func (_m *MockQueryStreamSegmentsServer) SetHeader(_a0 metadata.MD) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(metadata.MD) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockQueryStreamSegmentsServer_SetHeader_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetHeader' +type MockQueryStreamSegmentsServer_SetHeader_Call struct { + *mock.Call +} + +// SetHeader is a helper method to define mock.On call +// - _a0 metadata.MD +func (_e *MockQueryStreamSegmentsServer_Expecter) SetHeader(_a0 interface{}) *MockQueryStreamSegmentsServer_SetHeader_Call { + return &MockQueryStreamSegmentsServer_SetHeader_Call{Call: _e.mock.On("SetHeader", _a0)} +} + +func (_c *MockQueryStreamSegmentsServer_SetHeader_Call) Run(run func(_a0 metadata.MD)) *MockQueryStreamSegmentsServer_SetHeader_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(metadata.MD)) + }) + return _c +} + +func (_c *MockQueryStreamSegmentsServer_SetHeader_Call) Return(_a0 error) *MockQueryStreamSegmentsServer_SetHeader_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockQueryStreamSegmentsServer_SetHeader_Call) RunAndReturn(run func(metadata.MD) error) *MockQueryStreamSegmentsServer_SetHeader_Call { + _c.Call.Return(run) + return _c +} + +// SetTrailer provides a mock function with given fields: _a0 +func (_m *MockQueryStreamSegmentsServer) SetTrailer(_a0 metadata.MD) { + _m.Called(_a0) +} + +// MockQueryStreamSegmentsServer_SetTrailer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetTrailer' +type MockQueryStreamSegmentsServer_SetTrailer_Call struct { + *mock.Call +} + +// SetTrailer is a helper method to define mock.On call +// - _a0 metadata.MD +func (_e *MockQueryStreamSegmentsServer_Expecter) SetTrailer(_a0 interface{}) *MockQueryStreamSegmentsServer_SetTrailer_Call { + return 
&MockQueryStreamSegmentsServer_SetTrailer_Call{Call: _e.mock.On("SetTrailer", _a0)} +} + +func (_c *MockQueryStreamSegmentsServer_SetTrailer_Call) Run(run func(_a0 metadata.MD)) *MockQueryStreamSegmentsServer_SetTrailer_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(metadata.MD)) + }) + return _c +} + +func (_c *MockQueryStreamSegmentsServer_SetTrailer_Call) Return() *MockQueryStreamSegmentsServer_SetTrailer_Call { + _c.Call.Return() + return _c +} + +func (_c *MockQueryStreamSegmentsServer_SetTrailer_Call) RunAndReturn(run func(metadata.MD)) *MockQueryStreamSegmentsServer_SetTrailer_Call { + _c.Call.Return(run) + return _c +} + +// NewMockQueryStreamSegmentsServer creates a new instance of MockQueryStreamSegmentsServer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockQueryStreamSegmentsServer(t interface { + mock.TestingT + Cleanup(func()) +}) *MockQueryStreamSegmentsServer { + mock := &MockQueryStreamSegmentsServer{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/util/streamrpc/mocks/mock_query_stream_server.go b/internal/util/streamrpc/mocks/mock_query_stream_server.go new file mode 100644 index 0000000000000..bebb58f3124cf --- /dev/null +++ b/internal/util/streamrpc/mocks/mock_query_stream_server.go @@ -0,0 +1,325 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. 
+ +package mocks + +import ( + context "context" + + internalpb "github.com/milvus-io/milvus/internal/proto/internalpb" + metadata "google.golang.org/grpc/metadata" + + mock "github.com/stretchr/testify/mock" +) + +// MockQueryStreamServer is an autogenerated mock type for the QueryNode_QueryStreamServer type +type MockQueryStreamServer struct { + mock.Mock +} + +type MockQueryStreamServer_Expecter struct { + mock *mock.Mock +} + +func (_m *MockQueryStreamServer) EXPECT() *MockQueryStreamServer_Expecter { + return &MockQueryStreamServer_Expecter{mock: &_m.Mock} +} + +// Context provides a mock function with given fields: +func (_m *MockQueryStreamServer) Context() context.Context { + ret := _m.Called() + + var r0 context.Context + if rf, ok := ret.Get(0).(func() context.Context); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(context.Context) + } + } + + return r0 +} + +// MockQueryStreamServer_Context_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Context' +type MockQueryStreamServer_Context_Call struct { + *mock.Call +} + +// Context is a helper method to define mock.On call +func (_e *MockQueryStreamServer_Expecter) Context() *MockQueryStreamServer_Context_Call { + return &MockQueryStreamServer_Context_Call{Call: _e.mock.On("Context")} +} + +func (_c *MockQueryStreamServer_Context_Call) Run(run func()) *MockQueryStreamServer_Context_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockQueryStreamServer_Context_Call) Return(_a0 context.Context) *MockQueryStreamServer_Context_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockQueryStreamServer_Context_Call) RunAndReturn(run func() context.Context) *MockQueryStreamServer_Context_Call { + _c.Call.Return(run) + return _c +} + +// RecvMsg provides a mock function with given fields: m +func (_m *MockQueryStreamServer) RecvMsg(m interface{}) error { + ret := _m.Called(m) + + var r0 error + if rf, ok 
:= ret.Get(0).(func(interface{}) error); ok { + r0 = rf(m) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockQueryStreamServer_RecvMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RecvMsg' +type MockQueryStreamServer_RecvMsg_Call struct { + *mock.Call +} + +// RecvMsg is a helper method to define mock.On call +// - m interface{} +func (_e *MockQueryStreamServer_Expecter) RecvMsg(m interface{}) *MockQueryStreamServer_RecvMsg_Call { + return &MockQueryStreamServer_RecvMsg_Call{Call: _e.mock.On("RecvMsg", m)} +} + +func (_c *MockQueryStreamServer_RecvMsg_Call) Run(run func(m interface{})) *MockQueryStreamServer_RecvMsg_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(interface{})) + }) + return _c +} + +func (_c *MockQueryStreamServer_RecvMsg_Call) Return(_a0 error) *MockQueryStreamServer_RecvMsg_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockQueryStreamServer_RecvMsg_Call) RunAndReturn(run func(interface{}) error) *MockQueryStreamServer_RecvMsg_Call { + _c.Call.Return(run) + return _c +} + +// Send provides a mock function with given fields: _a0 +func (_m *MockQueryStreamServer) Send(_a0 *internalpb.RetrieveResults) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(*internalpb.RetrieveResults) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockQueryStreamServer_Send_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Send' +type MockQueryStreamServer_Send_Call struct { + *mock.Call +} + +// Send is a helper method to define mock.On call +// - _a0 *internalpb.RetrieveResults +func (_e *MockQueryStreamServer_Expecter) Send(_a0 interface{}) *MockQueryStreamServer_Send_Call { + return &MockQueryStreamServer_Send_Call{Call: _e.mock.On("Send", _a0)} +} + +func (_c *MockQueryStreamServer_Send_Call) Run(run func(_a0 *internalpb.RetrieveResults)) *MockQueryStreamServer_Send_Call { 
+ _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*internalpb.RetrieveResults)) + }) + return _c +} + +func (_c *MockQueryStreamServer_Send_Call) Return(_a0 error) *MockQueryStreamServer_Send_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockQueryStreamServer_Send_Call) RunAndReturn(run func(*internalpb.RetrieveResults) error) *MockQueryStreamServer_Send_Call { + _c.Call.Return(run) + return _c +} + +// SendHeader provides a mock function with given fields: _a0 +func (_m *MockQueryStreamServer) SendHeader(_a0 metadata.MD) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(metadata.MD) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockQueryStreamServer_SendHeader_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendHeader' +type MockQueryStreamServer_SendHeader_Call struct { + *mock.Call +} + +// SendHeader is a helper method to define mock.On call +// - _a0 metadata.MD +func (_e *MockQueryStreamServer_Expecter) SendHeader(_a0 interface{}) *MockQueryStreamServer_SendHeader_Call { + return &MockQueryStreamServer_SendHeader_Call{Call: _e.mock.On("SendHeader", _a0)} +} + +func (_c *MockQueryStreamServer_SendHeader_Call) Run(run func(_a0 metadata.MD)) *MockQueryStreamServer_SendHeader_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(metadata.MD)) + }) + return _c +} + +func (_c *MockQueryStreamServer_SendHeader_Call) Return(_a0 error) *MockQueryStreamServer_SendHeader_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockQueryStreamServer_SendHeader_Call) RunAndReturn(run func(metadata.MD) error) *MockQueryStreamServer_SendHeader_Call { + _c.Call.Return(run) + return _c +} + +// SendMsg provides a mock function with given fields: m +func (_m *MockQueryStreamServer) SendMsg(m interface{}) error { + ret := _m.Called(m) + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}) error); ok { + r0 = rf(m) + } else { + r0 
= ret.Error(0) + } + + return r0 +} + +// MockQueryStreamServer_SendMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendMsg' +type MockQueryStreamServer_SendMsg_Call struct { + *mock.Call +} + +// SendMsg is a helper method to define mock.On call +// - m interface{} +func (_e *MockQueryStreamServer_Expecter) SendMsg(m interface{}) *MockQueryStreamServer_SendMsg_Call { + return &MockQueryStreamServer_SendMsg_Call{Call: _e.mock.On("SendMsg", m)} +} + +func (_c *MockQueryStreamServer_SendMsg_Call) Run(run func(m interface{})) *MockQueryStreamServer_SendMsg_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(interface{})) + }) + return _c +} + +func (_c *MockQueryStreamServer_SendMsg_Call) Return(_a0 error) *MockQueryStreamServer_SendMsg_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockQueryStreamServer_SendMsg_Call) RunAndReturn(run func(interface{}) error) *MockQueryStreamServer_SendMsg_Call { + _c.Call.Return(run) + return _c +} + +// SetHeader provides a mock function with given fields: _a0 +func (_m *MockQueryStreamServer) SetHeader(_a0 metadata.MD) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(metadata.MD) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockQueryStreamServer_SetHeader_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetHeader' +type MockQueryStreamServer_SetHeader_Call struct { + *mock.Call +} + +// SetHeader is a helper method to define mock.On call +// - _a0 metadata.MD +func (_e *MockQueryStreamServer_Expecter) SetHeader(_a0 interface{}) *MockQueryStreamServer_SetHeader_Call { + return &MockQueryStreamServer_SetHeader_Call{Call: _e.mock.On("SetHeader", _a0)} +} + +func (_c *MockQueryStreamServer_SetHeader_Call) Run(run func(_a0 metadata.MD)) *MockQueryStreamServer_SetHeader_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(metadata.MD)) + }) + 
return _c +} + +func (_c *MockQueryStreamServer_SetHeader_Call) Return(_a0 error) *MockQueryStreamServer_SetHeader_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockQueryStreamServer_SetHeader_Call) RunAndReturn(run func(metadata.MD) error) *MockQueryStreamServer_SetHeader_Call { + _c.Call.Return(run) + return _c +} + +// SetTrailer provides a mock function with given fields: _a0 +func (_m *MockQueryStreamServer) SetTrailer(_a0 metadata.MD) { + _m.Called(_a0) +} + +// MockQueryStreamServer_SetTrailer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetTrailer' +type MockQueryStreamServer_SetTrailer_Call struct { + *mock.Call +} + +// SetTrailer is a helper method to define mock.On call +// - _a0 metadata.MD +func (_e *MockQueryStreamServer_Expecter) SetTrailer(_a0 interface{}) *MockQueryStreamServer_SetTrailer_Call { + return &MockQueryStreamServer_SetTrailer_Call{Call: _e.mock.On("SetTrailer", _a0)} +} + +func (_c *MockQueryStreamServer_SetTrailer_Call) Run(run func(_a0 metadata.MD)) *MockQueryStreamServer_SetTrailer_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(metadata.MD)) + }) + return _c +} + +func (_c *MockQueryStreamServer_SetTrailer_Call) Return() *MockQueryStreamServer_SetTrailer_Call { + _c.Call.Return() + return _c +} + +func (_c *MockQueryStreamServer_SetTrailer_Call) RunAndReturn(run func(metadata.MD)) *MockQueryStreamServer_SetTrailer_Call { + _c.Call.Return(run) + return _c +} + +// NewMockQueryStreamServer creates a new instance of MockQueryStreamServer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. 
+func NewMockQueryStreamServer(t interface { + mock.TestingT + Cleanup(func()) +}) *MockQueryStreamServer { + mock := &MockQueryStreamServer{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/util/streamrpc/streamer.go b/internal/util/streamrpc/streamer.go new file mode 100644 index 0000000000000..53571672eeb8c --- /dev/null +++ b/internal/util/streamrpc/streamer.go @@ -0,0 +1,137 @@ +package streamrpc + +import ( + "context" + "io" + "sync" + + "google.golang.org/grpc" + + "github.com/milvus-io/milvus/internal/proto/internalpb" +) + +type QueryStreamServer interface { + Send(*internalpb.RetrieveResults) error + Context() context.Context +} +type QueryStreamClient interface { + Recv() (*internalpb.RetrieveResults, error) + Context() context.Context + CloseSend() error +} + +type ConcurrentQueryStreamServer struct { + server QueryStreamServer + mu sync.Mutex +} + +func (s *ConcurrentQueryStreamServer) Send(result *internalpb.RetrieveResults) error { + s.mu.Lock() + defer s.mu.Unlock() + return s.server.Send(result) +} + +func (s *ConcurrentQueryStreamServer) Context() context.Context { + return s.server.Context() +} + +func NewConcurrentQueryStreamServer(srv QueryStreamServer) *ConcurrentQueryStreamServer { + return &ConcurrentQueryStreamServer{ + server: srv, + mu: sync.Mutex{}, + } +} + +// TODO LOCAL SERVER AND CLIENT FOR STANDALONE +// ONLY FOR TEST +type LocalQueryServer struct { + grpc.ServerStream + + resultCh chan *internalpb.RetrieveResults + ctx context.Context + + finishOnce sync.Once + errCh chan error + mu sync.Mutex +} + +func (s *LocalQueryServer) Send(result *internalpb.RetrieveResults) error { + select { + case <-s.ctx.Done(): + return s.ctx.Err() + default: + s.resultCh <- result + return nil + } +} + +func (s *LocalQueryServer) FinishError() error { + return <-s.errCh +} + +func (s *LocalQueryServer) Context() context.Context { + return s.ctx +} + +func (s *LocalQueryServer) 
FinishSend(err error) error { + s.finishOnce.Do(func() { + close(s.resultCh) + if err != nil { + s.errCh <- err + } else { + s.errCh <- io.EOF + } + }) + return nil +} + +type LocalQueryClient struct { + grpc.ClientStream + + server *LocalQueryServer + resultCh chan *internalpb.RetrieveResults + ctx context.Context +} + +func (s *LocalQueryClient) RecvMsg(m interface{}) error { + // TODO implement me + panic("implement me") +} + +func (s *LocalQueryClient) Recv() (*internalpb.RetrieveResults, error) { + select { + case <-s.ctx.Done(): + return nil, s.ctx.Err() + default: + result, ok := <-s.resultCh + if !ok { + return nil, s.server.FinishError() + } + return result, nil + } +} + +func (s *LocalQueryClient) Context() context.Context { + return s.ctx +} + +func (s *LocalQueryClient) CloseSend() error { + return nil +} + +func (s *LocalQueryClient) CreateServer() *LocalQueryServer { + s.server = &LocalQueryServer{ + resultCh: s.resultCh, + ctx: s.ctx, + mu: sync.Mutex{}, + errCh: make(chan error, 1), + } + return s.server +} + +func NewLocalQueryClient(ctx context.Context) *LocalQueryClient { + return &LocalQueryClient{ + resultCh: make(chan *internalpb.RetrieveResults, 64), + ctx: ctx, + } +} diff --git a/internal/util/tsoutil/tso.go b/internal/util/tsoutil/tso.go index 4f259b043146e..9d2a83d32facc 100644 --- a/internal/util/tsoutil/tso.go +++ b/internal/util/tsoutil/tso.go @@ -19,13 +19,20 @@ package tsoutil import ( "path" + "github.com/tikv/client-go/v2/txnkv" clientv3 "go.etcd.io/etcd/client/v3" "github.com/milvus-io/milvus/internal/kv" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" + "github.com/milvus-io/milvus/internal/kv/tikv" ) // NewTSOKVBase returns a kv.TxnKV object func NewTSOKVBase(client *clientv3.Client, tsoRoot, subPath string) kv.TxnKV { return etcdkv.NewEtcdKV(client, path.Join(tsoRoot, subPath)) } + +// NewTSOTiKVBase returns a kv.TxnKV object +func NewTSOTiKVBase(client *txnkv.Client, tsoRoot, subPath string) kv.TxnKV { + return 
tikv.NewTiKV(client, path.Join(tsoRoot, subPath)) +} diff --git a/internal/util/typeutil/hash.go b/internal/util/typeutil/hash.go index bdf6f6bca9248..8815768336593 100644 --- a/internal/util/typeutil/hash.go +++ b/internal/util/typeutil/hash.go @@ -35,7 +35,6 @@ func HashKey2Partitions(fieldSchema *schemapb.FieldSchema, keys []*planpb.Generi } default: return nil, errors.New("currently only support DataType Int64 or VarChar as partition keys") - } result := make([]string, 0) diff --git a/internal/util/typeutil/result_helper_test.go b/internal/util/typeutil/result_helper_test.go index 5db74e48eb5a5..8d1cca1902203 100644 --- a/internal/util/typeutil/result_helper_test.go +++ b/internal/util/typeutil/result_helper_test.go @@ -3,19 +3,15 @@ package typeutil import ( "testing" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus/pkg/common" - "github.com/milvus-io/milvus/pkg/util/typeutil" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" - + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/internalpb" - "github.com/milvus-io/milvus/internal/proto/segcorepb" - - "github.com/stretchr/testify/assert" - - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) func fieldDataEmpty(data *schemapb.FieldData) bool { @@ -67,6 +63,7 @@ func TestGenEmptyFieldData(t *testing.T) { vectorTypes := []schemapb.DataType{ schemapb.DataType_BinaryVector, schemapb.DataType_FloatVector, + schemapb.DataType_Float16Vector, } field := &schemapb.FieldSchema{Name: "field_name", FieldID: 100} diff --git a/internal/util/utils.go b/internal/util/utils.go deleted file mode 100644 index 7fa8aef322667..0000000000000 --- a/internal/util/utils.go +++ /dev/null @@ -1,50 +0,0 @@ -// Licensed to the LF AI & Data foundation under 
one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package util - -import ( - "fmt" - "strings" - - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" -) - -// WrapStatus wraps status with given error code, message and errors -func WrapStatus(code commonpb.ErrorCode, msg string, errs ...error) *commonpb.Status { - status := &commonpb.Status{ - ErrorCode: code, - Reason: msg, - } - - for _, err := range errs { - status.Reason = fmt.Sprintf("%s, err=%v", status.Reason, err) - } - - return status -} - -// SuccessStatus returns a success status with given message -func SuccessStatus(msgs ...string) *commonpb.Status { - return &commonpb.Status{ - Reason: strings.Join(msgs, "; "), - } -} - -// WrapError wraps error with given message -func WrapError(msg string, err error) error { - return fmt.Errorf("%s[%w]", msg, err) -} diff --git a/internal/util/wrappers/qn_wrapper.go b/internal/util/wrappers/qn_wrapper.go new file mode 100644 index 0000000000000..63147c0116986 --- /dev/null +++ b/internal/util/wrappers/qn_wrapper.go @@ -0,0 +1,155 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wrappers + +import ( + "context" + + "google.golang.org/grpc" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/streamrpc" +) + +type qnServerWrapper struct { + types.QueryNode +} + +func (qn *qnServerWrapper) Close() error { + return nil +} + +func (qn *qnServerWrapper) GetComponentStates(ctx context.Context, in *milvuspb.GetComponentStatesRequest, opts ...grpc.CallOption) (*milvuspb.ComponentStates, error) { + return qn.QueryNode.GetComponentStates(ctx, in) +} + +func (qn *qnServerWrapper) GetTimeTickChannel(ctx context.Context, in *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + return qn.QueryNode.GetTimeTickChannel(ctx, in) +} + +func (qn *qnServerWrapper) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + return qn.QueryNode.GetStatisticsChannel(ctx, in) +} + +func (qn *qnServerWrapper) WatchDmChannels(ctx context.Context, in *querypb.WatchDmChannelsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return qn.QueryNode.WatchDmChannels(ctx, in) +} 
+ +func (qn *qnServerWrapper) UnsubDmChannel(ctx context.Context, in *querypb.UnsubDmChannelRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return qn.QueryNode.UnsubDmChannel(ctx, in) +} + +func (qn *qnServerWrapper) LoadSegments(ctx context.Context, in *querypb.LoadSegmentsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return qn.QueryNode.LoadSegments(ctx, in) +} + +func (qn *qnServerWrapper) ReleaseCollection(ctx context.Context, in *querypb.ReleaseCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return qn.QueryNode.ReleaseCollection(ctx, in) +} + +func (qn *qnServerWrapper) LoadPartitions(ctx context.Context, in *querypb.LoadPartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return qn.QueryNode.LoadPartitions(ctx, in) +} + +func (qn *qnServerWrapper) ReleasePartitions(ctx context.Context, in *querypb.ReleasePartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return qn.QueryNode.ReleasePartitions(ctx, in) +} + +func (qn *qnServerWrapper) ReleaseSegments(ctx context.Context, in *querypb.ReleaseSegmentsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return qn.QueryNode.ReleaseSegments(ctx, in) +} + +func (qn *qnServerWrapper) GetSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest, opts ...grpc.CallOption) (*querypb.GetSegmentInfoResponse, error) { + return qn.QueryNode.GetSegmentInfo(ctx, in) +} + +func (qn *qnServerWrapper) SyncReplicaSegments(ctx context.Context, in *querypb.SyncReplicaSegmentsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return qn.QueryNode.SyncReplicaSegments(ctx, in) +} + +func (qn *qnServerWrapper) GetStatistics(ctx context.Context, in *querypb.GetStatisticsRequest, opts ...grpc.CallOption) (*internalpb.GetStatisticsResponse, error) { + return qn.QueryNode.GetStatistics(ctx, in) +} + +func (qn *qnServerWrapper) Search(ctx context.Context, in *querypb.SearchRequest, opts ...grpc.CallOption) 
(*internalpb.SearchResults, error) { + return qn.QueryNode.Search(ctx, in) +} + +func (qn *qnServerWrapper) SearchSegments(ctx context.Context, in *querypb.SearchRequest, opts ...grpc.CallOption) (*internalpb.SearchResults, error) { + return qn.QueryNode.SearchSegments(ctx, in) +} + +func (qn *qnServerWrapper) Query(ctx context.Context, in *querypb.QueryRequest, opts ...grpc.CallOption) (*internalpb.RetrieveResults, error) { + return qn.QueryNode.Query(ctx, in) +} + +func (qn *qnServerWrapper) QueryStream(ctx context.Context, in *querypb.QueryRequest, opts ...grpc.CallOption) (querypb.QueryNode_QueryStreamClient, error) { + streamer := streamrpc.NewInMemoryStreamer[*internalpb.RetrieveResults](ctx, 16) + + go func() { + qn.QueryNode.QueryStream(in, streamer) + streamer.Close() + }() + + return streamer, nil +} + +func (qn *qnServerWrapper) QuerySegments(ctx context.Context, in *querypb.QueryRequest, opts ...grpc.CallOption) (*internalpb.RetrieveResults, error) { + return qn.QueryNode.QuerySegments(ctx, in) +} + +func (qn *qnServerWrapper) QueryStreamSegments(ctx context.Context, in *querypb.QueryRequest, opts ...grpc.CallOption) (querypb.QueryNode_QueryStreamSegmentsClient, error) { + streamer := streamrpc.NewInMemoryStreamer[*internalpb.RetrieveResults](ctx, 16) + + go func() { + qn.QueryNode.QueryStreamSegments(in, streamer) + streamer.Close() + }() + + return streamer, nil +} + +func (qn *qnServerWrapper) ShowConfigurations(ctx context.Context, in *internalpb.ShowConfigurationsRequest, opts ...grpc.CallOption) (*internalpb.ShowConfigurationsResponse, error) { + return qn.QueryNode.ShowConfigurations(ctx, in) +} + +// https://wiki.lfaidata.foundation/display/MIL/MEP+8+--+Add+metrics+for+proxy +func (qn *qnServerWrapper) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { + return qn.QueryNode.GetMetrics(ctx, in) +} + +func (qn *qnServerWrapper) GetDataDistribution(ctx context.Context, in 
*querypb.GetDataDistributionRequest, opts ...grpc.CallOption) (*querypb.GetDataDistributionResponse, error) { + return qn.QueryNode.GetDataDistribution(ctx, in) +} + +func (qn *qnServerWrapper) SyncDistribution(ctx context.Context, in *querypb.SyncDistributionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return qn.QueryNode.SyncDistribution(ctx, in) +} + +func (qn *qnServerWrapper) Delete(ctx context.Context, in *querypb.DeleteRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return qn.QueryNode.Delete(ctx, in) +} + +func WrapQueryNodeServerAsClient(qn types.QueryNode) types.QueryNodeClient { + return &qnServerWrapper{ + QueryNode: qn, + } +} diff --git a/internal/util/wrappers/qn_wrapper_test.go b/internal/util/wrappers/qn_wrapper_test.go new file mode 100644 index 0000000000000..94719ee2daae2 --- /dev/null +++ b/internal/util/wrappers/qn_wrapper_test.go @@ -0,0 +1,294 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package wrappers + +import ( + "context" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/util/merr" +) + +type QnWrapperSuite struct { + suite.Suite + + qn *mocks.MockQueryNode + client types.QueryNodeClient +} + +func (s *QnWrapperSuite) SetupTest() { + s.qn = mocks.NewMockQueryNode(s.T()) + s.client = WrapQueryNodeServerAsClient(s.qn) +} + +func (s *QnWrapperSuite) TearDownTest() { + s.client = nil + s.qn = nil +} + +func (s *QnWrapperSuite) TestGetComponentStates() { + s.qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything). + Return(&milvuspb.ComponentStates{Status: merr.Status(nil)}, nil) + + resp, err := s.client.GetComponentStates(context.Background(), &milvuspb.GetComponentStatesRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestGetTimeTickChannel() { + s.qn.EXPECT().GetTimeTickChannel(mock.Anything, mock.Anything). + Return(&milvuspb.StringResponse{Status: merr.Status(nil)}, nil) + + resp, err := s.client.GetTimeTickChannel(context.Background(), &internalpb.GetTimeTickChannelRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestGetStatisticsChannel() { + s.qn.EXPECT().GetStatisticsChannel(mock.Anything, mock.Anything). + Return(&milvuspb.StringResponse{Status: merr.Status(nil)}, nil) + + resp, err := s.client.GetStatisticsChannel(context.Background(), &internalpb.GetStatisticsChannelRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestWatchDmChannels() { + s.qn.EXPECT().WatchDmChannels(mock.Anything, mock.Anything). 
+ Return(merr.Status(nil), nil) + + resp, err := s.client.WatchDmChannels(context.Background(), &querypb.WatchDmChannelsRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestUnsubDmChannel() { + s.qn.EXPECT().UnsubDmChannel(mock.Anything, mock.Anything). + Return(merr.Status(nil), nil) + + resp, err := s.client.UnsubDmChannel(context.Background(), &querypb.UnsubDmChannelRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestLoadSegments() { + s.qn.EXPECT().LoadSegments(mock.Anything, mock.Anything). + Return(merr.Status(nil), nil) + + resp, err := s.client.LoadSegments(context.Background(), &querypb.LoadSegmentsRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestReleaseCollection() { + s.qn.EXPECT().ReleaseCollection(mock.Anything, mock.Anything). + Return(merr.Status(nil), nil) + + resp, err := s.client.ReleaseCollection(context.Background(), &querypb.ReleaseCollectionRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestLoadPartitions() { + s.qn.EXPECT().LoadPartitions(mock.Anything, mock.Anything). + Return(merr.Status(nil), nil) + + resp, err := s.client.LoadPartitions(context.Background(), &querypb.LoadPartitionsRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestReleasePartitions() { + s.qn.EXPECT().ReleasePartitions(mock.Anything, mock.Anything). + Return(merr.Status(nil), nil) + + resp, err := s.client.ReleasePartitions(context.Background(), &querypb.ReleasePartitionsRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestReleaseSegments() { + s.qn.EXPECT().ReleaseSegments(mock.Anything, mock.Anything). 
+ Return(merr.Status(nil), nil) + + resp, err := s.client.ReleaseSegments(context.Background(), &querypb.ReleaseSegmentsRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestGetSegmentInfo() { + s.qn.EXPECT().GetSegmentInfo(mock.Anything, mock.Anything). + Return(&querypb.GetSegmentInfoResponse{Status: merr.Status(nil)}, nil) + + resp, err := s.client.GetSegmentInfo(context.Background(), &querypb.GetSegmentInfoRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestSyncReplicaSegments() { + s.qn.EXPECT().SyncReplicaSegments(mock.Anything, mock.Anything). + Return(merr.Status(nil), nil) + + resp, err := s.client.SyncReplicaSegments(context.Background(), &querypb.SyncReplicaSegmentsRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestGetStatistics() { + s.qn.EXPECT().GetStatistics(mock.Anything, mock.Anything). + Return(&internalpb.GetStatisticsResponse{Status: merr.Status(nil)}, nil) + + resp, err := s.client.GetStatistics(context.Background(), &querypb.GetStatisticsRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestSearch() { + s.qn.EXPECT().Search(mock.Anything, mock.Anything). + Return(&internalpb.SearchResults{Status: merr.Status(nil)}, nil) + + resp, err := s.client.Search(context.Background(), &querypb.SearchRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestSearchSegments() { + s.qn.EXPECT().SearchSegments(mock.Anything, mock.Anything). + Return(&internalpb.SearchResults{Status: merr.Status(nil)}, nil) + + resp, err := s.client.SearchSegments(context.Background(), &querypb.SearchRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestQuery() { + s.qn.EXPECT().Query(mock.Anything, mock.Anything). 
+ Return(&internalpb.RetrieveResults{Status: merr.Status(nil)}, nil) + + resp, err := s.client.Query(context.Background(), &querypb.QueryRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestQuerySegments() { + s.qn.EXPECT().QuerySegments(mock.Anything, mock.Anything). + Return(&internalpb.RetrieveResults{Status: merr.Status(nil)}, nil) + + resp, err := s.client.QuerySegments(context.Background(), &querypb.QueryRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestShowConfigurations() { + s.qn.EXPECT().ShowConfigurations(mock.Anything, mock.Anything). + Return(&internalpb.ShowConfigurationsResponse{Status: merr.Status(nil)}, nil) + + resp, err := s.client.ShowConfigurations(context.Background(), &internalpb.ShowConfigurationsRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestGetMetrics() { + s.qn.EXPECT().GetMetrics(mock.Anything, mock.Anything). + Return(&milvuspb.GetMetricsResponse{Status: merr.Status(nil)}, nil) + + resp, err := s.client.GetMetrics(context.Background(), &milvuspb.GetMetricsRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestGetDataDistribution() { + s.qn.EXPECT().GetDataDistribution(mock.Anything, mock.Anything). + Return(&querypb.GetDataDistributionResponse{Status: merr.Status(nil)}, nil) + + resp, err := s.client.GetDataDistribution(context.Background(), &querypb.GetDataDistributionRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestSyncDistribution() { + s.qn.EXPECT().SyncDistribution(mock.Anything, mock.Anything). + Return(merr.Status(nil), nil) + + resp, err := s.client.SyncDistribution(context.Background(), &querypb.SyncDistributionRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +func (s *QnWrapperSuite) TestDelete() { + s.qn.EXPECT().Delete(mock.Anything, mock.Anything). 
+ Return(merr.Status(nil), nil) + + resp, err := s.client.Delete(context.Background(), &querypb.DeleteRequest{}) + err = merr.CheckRPCCall(resp, err) + s.NoError(err) +} + +// Race caused by mock parameter check on once +/* +func (s *QnWrapperSuite) TestQueryStream() { + s.qn.EXPECT().QueryStream(mock.Anything, mock.Anything). + Run(func(_ *querypb.QueryRequest, server querypb.QueryNode_QueryStreamServer) { + server.Send(&internalpb.RetrieveResults{}) + }). + Return(nil) + + streamer, err := s.client.QueryStream(context.Background(), &querypb.QueryRequest{}) + s.NoError(err) + inMemStreamer, ok := streamer.(*streamrpc.InMemoryStreamer[*internalpb.RetrieveResults]) + s.Require().True(ok) + + r, err := streamer.Recv() + err = merr.CheckRPCCall(r, err) + s.NoError(err) + + s.Eventually(func() bool { + return inMemStreamer.IsClosed() + }, time.Second, time.Millisecond*100) +} + +func (s *QnWrapperSuite) TestQueryStreamSegments() { + s.qn.EXPECT().QueryStreamSegments(mock.Anything, mock.Anything). + Run(func(_ *querypb.QueryRequest, server querypb.QueryNode_QueryStreamSegmentsServer) { + server.Send(&internalpb.RetrieveResults{}) + }). 
+ Return(nil) + + streamer, err := s.client.QueryStreamSegments(context.Background(), &querypb.QueryRequest{}) + s.NoError(err) + inMemStreamer, ok := streamer.(*streamrpc.InMemoryStreamer[*internalpb.RetrieveResults]) + s.Require().True(ok) + + r, err := streamer.Recv() + err = merr.CheckRPCCall(r, err) + s.NoError(err) + s.Eventually(func() bool { + return inMemStreamer.IsClosed() + }, time.Second, time.Millisecond*100) +}*/ + +func TestQnServerWrapper(t *testing.T) { + suite.Run(t, new(QnWrapperSuite)) +} diff --git a/pkg/Makefile b/pkg/Makefile index b478d71d19575..639bf54f17042 100644 --- a/pkg/Makefile +++ b/pkg/Makefile @@ -13,6 +13,7 @@ getdeps: generate-mockery: getdeps $(INSTALL_PATH)/mockery --name=MsgStream --dir=$(PWD)/mq/msgstream --output=$(PWD)/mq/msgstream --filename=mock_msgstream.go --with-expecter --structname=MockMsgStream --outpkg=msgstream --inpackage + $(INSTALL_PATH)/mockery --name=Factory --dir=$(PWD)/mq/msgstream --output=$(PWD)/mq/msgstream --filename=mock_msgstream_factory.go --with-expecter --structname=MockFactory --outpkg=msgstream --inpackage $(INSTALL_PATH)/mockery --name=Client --dir=$(PWD)/mq/msgdispatcher --output=$(PWD)/mq/msgsdispatcher --filename=mock_client.go --with-expecter --structname=MockClient --outpkg=msgdispatcher --inpackage $(INSTALL_PATH)/mockery --name=Logger --dir=$(PWD)/eventlog --output=$(PWD)/eventlog --filename=mock_logger.go --with-expecter --structname=MockLogger --outpkg=eventlog --inpackage - + $(INSTALL_PATH)/mockery --name=MessageID --dir=$(PWD)/mq/msgstream/mqwrapper --output=$(PWD)/mq/msgstream/mqwrapper --filename=mock_id.go --with-expecter --structname=MockMessageID --outpkg=mqwrapper --inpackage diff --git a/pkg/common/byte_slice_test.go b/pkg/common/byte_slice_test.go index 46f0c492d1587..8a6385f07812c 100644 --- a/pkg/common/byte_slice_test.go +++ b/pkg/common/byte_slice_test.go @@ -30,7 +30,8 @@ func TestCloneByteSlice(t *testing.T) { { args: args{s: []byte{0xf0}}, want: []byte{0xf0}, - }, { + 
}, + { args: args{s: []byte{0x0, 0xff, 0x0f, 0xf0}}, want: []byte{0x0, 0xff, 0x0f, 0xf0}, }, diff --git a/pkg/common/common.go b/pkg/common/common.go index 67b56953accf3..1ec3bfe3bd51d 100644 --- a/pkg/common/common.go +++ b/pkg/common/common.go @@ -94,6 +94,7 @@ const ( MetricTypeKey = "metric_type" DimKey = "dim" MaxLengthKey = "max_length" + MaxCapacityKey = "max_capacity" ) // Collection properties key diff --git a/pkg/common/error.go b/pkg/common/error.go index 81d815160efde..128151e9dc26d 100644 --- a/pkg/common/error.go +++ b/pkg/common/error.go @@ -22,10 +22,8 @@ import ( "github.com/cockroachdb/errors" ) -var ( - // ErrNodeIDNotMatch stands for the error that grpc target id and node session id not match. - ErrNodeIDNotMatch = errors.New("target node id not match") -) +// ErrNodeIDNotMatch stands for the error that grpc target id and node session id not match. +var ErrNodeIDNotMatch = errors.New("target node id not match") // WrapNodeIDNotMatchError wraps `ErrNodeIDNotMatch` with targetID and sessionID. 
func WrapNodeIDNotMatchError(targetID, nodeID int64) error { @@ -55,22 +53,3 @@ func IsIgnorableError(err error) bool { _, ok := err.(*IgnorableError) return ok } - -var _ error = &KeyNotExistError{} - -func NewKeyNotExistError(key string) error { - return &KeyNotExistError{key: key} -} - -func IsKeyNotExistError(err error) bool { - _, ok := err.(*KeyNotExistError) - return ok -} - -type KeyNotExistError struct { - key string -} - -func (k *KeyNotExistError) Error() string { - return fmt.Sprintf("there is no value on key = %s", k.key) -} diff --git a/pkg/common/error_test.go b/pkg/common/error_test.go index d35b3fc64e8a8..1fe68b95b9a68 100644 --- a/pkg/common/error_test.go +++ b/pkg/common/error_test.go @@ -29,9 +29,3 @@ func TestIgnorableError(t *testing.T) { assert.True(t, IsIgnorableError(iErr)) assert.False(t, IsIgnorableError(err)) } - -func TestNotExistError(t *testing.T) { - err := errors.New("err") - assert.Equal(t, false, IsKeyNotExistError(err)) - assert.Equal(t, true, IsKeyNotExistError(NewKeyNotExistError("foo"))) -} diff --git a/pkg/common/key_data_pairs_test.go b/pkg/common/key_data_pairs_test.go index 9beb7d0c316d7..ce9c74c1c035b 100644 --- a/pkg/common/key_data_pairs_test.go +++ b/pkg/common/key_data_pairs_test.go @@ -3,8 +3,9 @@ package common import ( "testing" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" ) func TestCloneKeyDataPairs(t *testing.T) { diff --git a/pkg/common/key_value_pairs_test.go b/pkg/common/key_value_pairs_test.go index 0a18f2b5ea255..761030921ec06 100644 --- a/pkg/common/key_value_pairs_test.go +++ b/pkg/common/key_value_pairs_test.go @@ -3,8 +3,9 @@ package common import ( "testing" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" ) func TestCloneKeyValuePairs(t *testing.T) { diff --git a/pkg/config/config.go 
b/pkg/config/config.go index de93941e6e5cb..4e81220790717 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -37,7 +37,6 @@ func Init(opts ...Option) (*Manager, error) { if o.FileInfo != nil { s := NewFileSource(o.FileInfo) sourceManager.AddSource(s) - } if o.EnvKeyFormatter != nil { sourceManager.AddSource(NewEnvSource(o.EnvKeyFormatter)) @@ -50,7 +49,6 @@ func Init(opts ...Option) (*Manager, error) { sourceManager.AddSource(s) } return sourceManager, nil - } func formatKey(key string) string { diff --git a/pkg/config/env_source.go b/pkg/config/env_source.go index 9c2138b5c3dfb..abef8bb821cf0 100644 --- a/pkg/config/env_source.go +++ b/pkg/config/env_source.go @@ -79,11 +79,10 @@ func (es EnvSource) GetSourceName() string { } func (es EnvSource) SetEventHandler(eh EventHandler) { - } + func (es EnvSource) UpdateOptions(opts Options) { } func (es EnvSource) Close() { - } diff --git a/pkg/config/event.go b/pkg/config/event.go index 993ac04e9b415..e2f2173b7e308 100644 --- a/pkg/config/event.go +++ b/pkg/config/event.go @@ -38,7 +38,6 @@ func newEvent(eventSource, eventType string, key string, value string) *Event { Value: value, HasUpdated: false, } - } func PopulateEvents(source string, currentConfig, updatedConfig map[string]string) ([]*Event, error) { diff --git a/pkg/config/manager.go b/pkg/config/manager.go index 11cf407a403b1..7a28797bbd0ab 100644 --- a/pkg/config/manager.go +++ b/pkg/config/manager.go @@ -324,7 +324,6 @@ func (m *Manager) updateEvent(e *Event) error { m.keySourceMap[e.Key] = source.GetSourceName() } } - } log.Info("receive update event", zap.Any("event", e)) @@ -387,7 +386,7 @@ func (m *Manager) getHighPrioritySource(srcNameA, srcNameB string) Source { return sourceA } - if sourceA.GetPriority() < sourceB.GetPriority() { //less value has high priority + if sourceA.GetPriority() < sourceB.GetPriority() { // less value has high priority return sourceA } diff --git a/pkg/config/manager_test.go b/pkg/config/manager_test.go index 
8f39cd5d348de..2635f07979390 100644 --- a/pkg/config/manager_test.go +++ b/pkg/config/manager_test.go @@ -38,8 +38,8 @@ func TestAllConfigFromManager(t *testing.T) { func TestConfigChangeEvent(t *testing.T) { dir, _ := os.MkdirTemp("", "milvus") - os.WriteFile(path.Join(dir, "milvus.yaml"), []byte("a.b: 1\nc.d: 2"), 0600) - os.WriteFile(path.Join(dir, "user.yaml"), []byte("a.b: 3"), 0600) + os.WriteFile(path.Join(dir, "milvus.yaml"), []byte("a.b: 1\nc.d: 2"), 0o600) + os.WriteFile(path.Join(dir, "user.yaml"), []byte("a.b: 3"), 0o600) fs := NewFileSource(&FileInfo{[]string{path.Join(dir, "milvus.yaml"), path.Join(dir, "user.yaml")}, 1}) mgr, _ := Init() @@ -48,7 +48,7 @@ func TestConfigChangeEvent(t *testing.T) { res, err := mgr.GetConfig("a.b") assert.NoError(t, err) assert.Equal(t, res, "3") - os.WriteFile(path.Join(dir, "user.yaml"), []byte("a.b: 6"), 0600) + os.WriteFile(path.Join(dir, "user.yaml"), []byte("a.b: 6"), 0o600) time.Sleep(3 * time.Second) res, err = mgr.GetConfig("a.b") assert.NoError(t, err) @@ -69,8 +69,7 @@ func TestAllDupliateSource(t *testing.T) { assert.Error(t, err, "invalid source or source not added") } -type ErrSource struct { -} +type ErrSource struct{} func (e ErrSource) Close() { } @@ -95,7 +94,6 @@ func (ErrSource) GetSourceName() string { } func (e ErrSource) SetEventHandler(eh EventHandler) { - } func (e ErrSource) UpdateOptions(opt Options) { diff --git a/pkg/config/refresher.go b/pkg/config/refresher.go index 353629cf000f9..4987d986b36e5 100644 --- a/pkg/config/refresher.go +++ b/pkg/config/refresher.go @@ -77,7 +77,6 @@ func (r *refresher) refreshPeriodically(name string) { return } } - } func (r *refresher) fireEvents(name string, source, target map[string]string) error { diff --git a/pkg/config/source.go b/pkg/config/source.go index 4095971e18760..8382915797f53 100644 --- a/pkg/config/source.go +++ b/pkg/config/source.go @@ -44,7 +44,7 @@ type EtcdInfo struct { CaCertFile string MinVersion string - //Pull Configuration interval, 
unit is second + // Pull Configuration interval, unit is second RefreshInterval time.Duration } diff --git a/pkg/config/source_test.go b/pkg/config/source_test.go index 0836bc9be833f..1b9068faf95a1 100644 --- a/pkg/config/source_test.go +++ b/pkg/config/source_test.go @@ -39,8 +39,8 @@ func TestLoadFromFileSource(t *testing.T) { t.Run("multiple files", func(t *testing.T) { dir, _ := os.MkdirTemp("", "milvus") - os.WriteFile(path.Join(dir, "milvus.yaml"), []byte("a.b: 1\nc.d: 2"), 0600) - os.WriteFile(path.Join(dir, "user.yaml"), []byte("a.b: 3"), 0600) + os.WriteFile(path.Join(dir, "milvus.yaml"), []byte("a.b: 1\nc.d: 2"), 0o600) + os.WriteFile(path.Join(dir, "user.yaml"), []byte("a.b: 3"), 0o600) fs := NewFileSource(&FileInfo{[]string{path.Join(dir, "milvus.yaml"), path.Join(dir, "user.yaml")}, -1}) fs.loadFromFile() diff --git a/pkg/eventlog/global.go b/pkg/eventlog/global.go index bd363d5096f70..13549ac234a2b 100644 --- a/pkg/eventlog/global.go +++ b/pkg/eventlog/global.go @@ -17,9 +17,10 @@ package eventlog import ( + "go.uber.org/atomic" + "github.com/milvus-io/milvus/pkg/util/conc" "github.com/milvus-io/milvus/pkg/util/typeutil" - "go.uber.org/atomic" ) var ( diff --git a/pkg/eventlog/grpc.go b/pkg/eventlog/grpc.go index ddb515869b9d4..e3cd9911804a9 100644 --- a/pkg/eventlog/grpc.go +++ b/pkg/eventlog/grpc.go @@ -21,11 +21,12 @@ import ( "sync" "time" + "go.uber.org/atomic" + "google.golang.org/grpc" + "github.com/milvus-io/milvus/pkg/util/conc" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/typeutil" - "go.uber.org/atomic" - "google.golang.org/grpc" ) var ( diff --git a/pkg/eventlog/handler.go b/pkg/eventlog/handler.go index 4ba5ea496d57d..7461d2c627442 100644 --- a/pkg/eventlog/handler.go +++ b/pkg/eventlog/handler.go @@ -20,8 +20,9 @@ import ( "encoding/json" "net/http" - "github.com/milvus-io/milvus/pkg/log" "go.uber.org/zap" + + "github.com/milvus-io/milvus/pkg/log" ) const ( @@ -34,8 +35,7 @@ const ( ) // 
eventLogHandler is the event log http handler -type eventLogHandler struct { -} +type eventLogHandler struct{} func Handler() http.Handler { return &eventLogHandler{} diff --git a/pkg/eventlog/handler_test.go b/pkg/eventlog/handler_test.go index 54c4487bbfa0b..fd87f29d61ca0 100644 --- a/pkg/eventlog/handler_test.go +++ b/pkg/eventlog/handler_test.go @@ -18,7 +18,7 @@ package eventlog import ( "encoding/json" - "io/ioutil" + "io" "net/http" "net/http/httptest" "testing" @@ -55,7 +55,7 @@ func (s *HandlerSuite) TestServerHTTP() { res := w.Result() defer res.Body.Close() - data, err := ioutil.ReadAll(res.Body) + data, err := io.ReadAll(res.Body) s.Require().NoError(err) resp := eventLogResponse{} diff --git a/pkg/eventlog/mock_logger.go b/pkg/eventlog/mock_logger.go index 8d5c8c3306e78..566126521a368 100644 --- a/pkg/eventlog/mock_logger.go +++ b/pkg/eventlog/mock_logger.go @@ -130,7 +130,8 @@ func (_c *MockLogger_RecordFunc_Call) RunAndReturn(run func(Level, func() Evt)) func NewMockLogger(t interface { mock.TestingT Cleanup(func()) -}) *MockLogger { +}, +) *MockLogger { mock := &MockLogger{} mock.Mock.Test(t) diff --git a/pkg/go.mod b/pkg/go.mod index 17cf2afda9734..903280645888e 100644 --- a/pkg/go.mod +++ b/pkg/go.mod @@ -13,11 +13,12 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 github.com/klauspost/compress v1.16.5 github.com/lingdor/stackerror v0.0.0-20191119040541-976d8885ed76 - github.com/milvus-io/milvus-proto/go-api/v2 v2.3.0-dev.1.0.20230716112827-c3fe148f5e1d + github.com/milvus-io/milvus-proto/go-api/v2 v2.3.2-0.20231008032233-5d64d443769d github.com/nats-io/nats-server/v2 v2.9.17 github.com/nats-io/nats.go v1.24.0 github.com/panjf2000/ants/v2 v2.7.2 github.com/prometheus/client_golang v1.14.0 + github.com/quasilyte/go-ruleguard/dsl v0.3.22 github.com/samber/lo v1.27.0 github.com/shirou/gopsutil/v3 v3.22.9 github.com/spaolacci/murmur3 v1.1.0 @@ -25,6 +26,7 @@ require ( github.com/spf13/viper v1.8.1 github.com/streamnative/pulsarctl 
v0.5.0 github.com/stretchr/testify v1.8.3 + github.com/tikv/client-go/v2 v2.0.4 github.com/uber/jaeger-client-go v2.30.0+incompatible go.etcd.io/etcd/client/v3 v3.5.5 go.etcd.io/etcd/server/v3 v3.5.5 @@ -37,9 +39,10 @@ require ( go.opentelemetry.io/otel/trace v1.13.0 go.uber.org/atomic v1.10.0 go.uber.org/automaxprocs v1.5.2 - go.uber.org/zap v1.17.0 - golang.org/x/crypto v0.9.0 + go.uber.org/zap v1.20.0 + golang.org/x/crypto v0.14.0 golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 + golang.org/x/net v0.17.0 golang.org/x/sync v0.1.0 google.golang.org/grpc v1.54.0 google.golang.org/protobuf v1.30.0 @@ -53,6 +56,7 @@ require ( github.com/BurntSushi/toml v1.2.1 // indirect github.com/DataDog/zstd v1.5.0 // indirect github.com/ardielle/ardielle-go v1.5.2 // indirect + github.com/benbjohnson/clock v1.1.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cenkalti/backoff/v4 v4.2.0 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect @@ -60,8 +64,10 @@ require ( github.com/cockroachdb/redact v1.1.3 // indirect github.com/coreos/go-semver v0.3.0 // indirect github.com/coreos/go-systemd/v22 v22.3.2 // indirect + github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548 // indirect github.com/danieljoos/wincred v1.1.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 // indirect github.com/docker/go-units v0.4.0 // indirect github.com/dustin/go-humanize v1.0.0 // indirect github.com/dvsekhvalnov/jose2go v1.5.0 // indirect @@ -76,7 +82,8 @@ require ( github.com/gogo/protobuf v1.3.2 // indirect github.com/golang-jwt/jwt v3.2.2+incompatible // indirect github.com/golang/snappy v0.0.4 // indirect - github.com/google/btree v1.0.1 // indirect + github.com/google/btree v1.1.2 // indirect + github.com/google/uuid v1.3.0 // indirect github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00 // indirect github.com/gorilla/websocket v1.4.2 // indirect 
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 // indirect @@ -104,14 +111,21 @@ require ( github.com/nats-io/nkeys v0.4.4 // indirect github.com/nats-io/nuid v1.0.1 // indirect github.com/opencontainers/runtime-spec v1.0.2 // indirect + github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/pelletier/go-toml v1.9.3 // indirect github.com/pierrec/lz4 v2.5.2+incompatible // indirect + github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect + github.com/pingcap/failpoint v0.0.0-20210918120811-547c13e3eb00 // indirect + github.com/pingcap/goleveldb v0.0.0-20191226122134-f82aafb29989 // indirect + github.com/pingcap/kvproto v0.0.0-20221129023506-621ec37aac7a // indirect + github.com/pingcap/log v1.1.1-0.20221015072633-39906604fb81 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect github.com/prometheus/client_model v0.3.0 // indirect github.com/prometheus/common v0.42.0 // indirect github.com/prometheus/procfs v0.9.0 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect github.com/rogpeppe/go-internal v1.8.1 // indirect github.com/sirupsen/logrus v1.8.1 // indirect github.com/smartystreets/assertions v1.1.0 // indirect @@ -119,11 +133,15 @@ require ( github.com/spf13/afero v1.6.0 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/stathat/consistent v1.0.0 // indirect github.com/stretchr/objx v0.5.0 // indirect github.com/subosito/gotenv v1.2.0 // indirect + github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a // indirect + github.com/tikv/pd/client v0.0.0-20221031025758-80f0d8ca4d07 // indirect github.com/tklauser/go-sysconf v0.3.10 // indirect github.com/tklauser/numcpus v0.4.0 // indirect github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802 // indirect + 
github.com/twmb/murmur3 v1.1.3 // indirect github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 // indirect github.com/yusufpapurcu/wmi v1.2.2 // indirect go.etcd.io/bbolt v1.3.6 // indirect @@ -136,12 +154,11 @@ require ( go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.13.0 // indirect go.opentelemetry.io/otel/metric v0.35.0 // indirect go.opentelemetry.io/proto/otlp v0.19.0 // indirect - go.uber.org/multierr v1.6.0 // indirect - golang.org/x/net v0.10.0 // indirect + go.uber.org/multierr v1.7.0 // indirect golang.org/x/oauth2 v0.6.0 // indirect - golang.org/x/sys v0.8.0 // indirect - golang.org/x/term v0.8.0 // indirect - golang.org/x/text v0.9.0 // indirect + golang.org/x/sys v0.13.0 // indirect + golang.org/x/term v0.13.0 // indirect + golang.org/x/text v0.13.0 // indirect golang.org/x/time v0.3.0 // indirect google.golang.org/appengine v1.6.7 // indirect google.golang.org/genproto v0.0.0-20230331144136-dcfb400f0633 // indirect diff --git a/pkg/go.sum b/pkg/go.sum index 3ed3a90b8a2d4..53dfd644f9863 100644 --- a/pkg/go.sum +++ b/pkg/go.sum @@ -67,6 +67,7 @@ github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuy github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= +github.com/antihax/optional v0.0.0-20180407024304-ca021399b1a6/go.mod h1:V8iCPQYkqmusNa815XgQio277wI47sdRh1dUOLdyC6Q= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/ardielle/ardielle-go v1.5.2 h1:TilHTpHIQJ27R1Tl/iITBzMwiUGSlVfiVhwDNGM3Zj4= github.com/ardielle/ardielle-go v1.5.2/go.mod h1:I4hy1n795cUhaVt/ojz83SNVCYIGsAFAONtv2Dr7HUI= @@ -77,6 +78,8 @@ github.com/armon/go-metrics 
v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmV github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= github.com/aws/aws-sdk-go v1.32.6/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= github.com/aymerick/raymond v2.0.3-0.20180322193309-b565731e1464+incompatible/go.mod h1:osfaiScAUVup+UC9Nfq76eWqDhXlp+4UYaA8uhTBO6g= +github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= +github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/benesch/cgosymbolizer v0.0.0-20190515212042-bec6fe6e597b h1:5JgaFtHFRnOPReItxvhMDXbvuBkjSWE+9glJyF466yw= github.com/benesch/cgosymbolizer v0.0.0-20190515212042-bec6fe6e597b/go.mod h1:eMD2XUcPsHYbakFEocKrWZp47G0MRJYoC60qFblGjpA= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= @@ -141,6 +144,8 @@ github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwc github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.11/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548 h1:iwZdTE0PVqJCos1vaoKsclOGD3ADKpshg3SRtYBbwso= +github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548/go.mod h1:e6NPNENfs9mPDVNRekM7lKScauxd5kXTr1Mfyig6TDM= github.com/danieljoos/wincred v1.1.2 h1:QLdCxFs1/Yl4zduvBdcHB8goaYk9RARS2SgLLRuAyr0= github.com/danieljoos/wincred v1.1.2/go.mod h1:GijpziifJoIBfYh+S7BbkdUTU4LfM+QnGqR5Vl2tAx0= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -148,6 +153,7 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgraph-io/badger 
v1.6.0/go.mod h1:zwt7syl517jmP8s94KqSxTlM6IMsdhYy6psNgSztDR4= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= +github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= github.com/dimfeld/httptreemux v5.0.1+incompatible h1:Qj3gVcDNoOthBAqftuD596rm4wg/adLLz5xh5CmpiCA= @@ -211,6 +217,7 @@ github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre github.com/go-martini/martini v0.0.0-20170121215854-22fa46961aab/go.mod h1:/P9AEU963A2AYjv4d1V5eVL1CQbEJq6aCNHDDjibzu8= github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= +github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= @@ -274,8 +281,9 @@ github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEW github.com/gomodule/redigo v1.7.1-0.20190724094224-574c33c3df38/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= github.com/google/btree v1.0.1/go.mod 
h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= +github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= +github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.2.1-0.20190312032427-6f77996f0c42/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= @@ -328,6 +336,7 @@ github.com/grpc-ecosystem/go-grpc-middleware v1.3.0/go.mod h1:z0ButlSOZa5vEBq9m2 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 h1:Ovs26xHkKqVztRpIrF/92BcuyuQ/YW4NSIpoGtfXNho= github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= github.com/grpc-ecosystem/grpc-gateway v1.9.0/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= +github.com/grpc-ecosystem/grpc-gateway v1.12.1/go.mod h1:8XEsbTttt/W+VvjtQhLACqCisSPWTxCZ7sBRjU6iH9c= github.com/grpc-ecosystem/grpc-gateway v1.16.0 h1:gmcG1KaJ57LophUzW0Hy8NmPhnMZb4M0+kPpLofRdBo= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/grpc-ecosystem/grpc-gateway/v2 v2.7.0 h1:BZHcxBETFHIdVyhyEfOvn/RdU/QGdLI4y34qQGjGWO0= @@ -468,8 +477,8 @@ github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfr github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i4n7wVopoX3x7Bv8= github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= -github.com/milvus-io/milvus-proto/go-api/v2 v2.3.0-dev.1.0.20230716112827-c3fe148f5e1d h1:XsQQ/MigebXEE2VXPKKmA3K7OHC+mkEUiErWvaWMikI= -github.com/milvus-io/milvus-proto/go-api/v2 v2.3.0-dev.1.0.20230716112827-c3fe148f5e1d/go.mod h1:1OIl0v5PQeNxIJhCvY+K55CBUOYDZevw9g9380u1Wek= 
+github.com/milvus-io/milvus-proto/go-api/v2 v2.3.2-0.20231008032233-5d64d443769d h1:K8yyzz8BCBm+wirhRgySyB8wN+sw33eB3VsLz6Slu5s= +github.com/milvus-io/milvus-proto/go-api/v2 v2.3.2-0.20231008032233-5d64d443769d/go.mod h1:1OIl0v5PQeNxIJhCvY+K55CBUOYDZevw9g9380u1Wek= github.com/milvus-io/pulsar-client-go v0.6.10 h1:eqpJjU+/QX0iIhEo3nhOqMNXL+TyInAs1IAHZCrCM/A= github.com/milvus-io/pulsar-client-go v0.6.10/go.mod h1:lQqCkgwDF8YFYjKA+zOheTk1tev2B+bKj5j7+nm8M1w= github.com/minio/highwayhash v1.0.2 h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA/g= @@ -531,6 +540,7 @@ github.com/onsi/gomega v1.19.0/go.mod h1:LY+I3pBVzYsTBU1AnDwOSxaYi9WoWiqgwooUqq9 github.com/opencontainers/runtime-spec v1.0.2 h1:UfAcuLBJB9Coz72x1hgl8O5RVzTdNiaglX6v2DM6FI0= github.com/opencontainers/runtime-spec v1.0.2/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= +github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= github.com/panjf2000/ants/v2 v2.7.2 h1:2NUt9BaZFO5kQzrieOmK/wdb/tQ/K+QHaxN8sOgD63U= github.com/panjf2000/ants/v2 v2.7.2/go.mod h1:KIBmYG9QQX5U2qzFP/yQJaq/nSb6rahS9iEHkrCMgM8= @@ -541,8 +551,19 @@ github.com/pelletier/go-toml v1.9.3/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCko github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pierrec/lz4 v2.5.2+incompatible h1:WCjObylUIOlKy/+7Abdn34TLIkXiA4UWUMhxq9m9ZXI= github.com/pierrec/lz4 v2.5.2+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= -github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4= +github.com/pingcap/errors v0.11.0/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= +github.com/pingcap/errors 
v0.11.5-0.20211224045212-9687c2b0f87c h1:xpW9bvK+HuuTmyFqUwr+jcCvpVkK7sumiz+ko5H9eq4= +github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c/go.mod h1:X2r9ueLEUZgtx2cIogM0v4Zj5uvvzhuuiu7Pn8HzMPg= +github.com/pingcap/failpoint v0.0.0-20210918120811-547c13e3eb00 h1:C3N3itkduZXDZFh4N3vQ5HEtld3S+Y+StULhWVvumU0= +github.com/pingcap/failpoint v0.0.0-20210918120811-547c13e3eb00/go.mod h1:4qGtCB0QK0wBzKtFEGDhxXnSnbQApw1gc9siScUl8ew= +github.com/pingcap/goleveldb v0.0.0-20191226122134-f82aafb29989 h1:surzm05a8C9dN8dIUmo4Be2+pMRb6f55i+UIYrluu2E= +github.com/pingcap/goleveldb v0.0.0-20191226122134-f82aafb29989/go.mod h1:O17XtbryoCJhkKGbT62+L2OlrniwqiGLSqrmdHCMzZw= +github.com/pingcap/kvproto v0.0.0-20221026112947-f8d61344b172/go.mod h1:OYtxs0786qojVTmkVeufx93xe+jUgm56GUYRIKnmaGI= +github.com/pingcap/kvproto v0.0.0-20221129023506-621ec37aac7a h1:LzIZsQpXQlj8yF7+yvyOg680OaPq7bmPuDuszgXfHsw= +github.com/pingcap/kvproto v0.0.0-20221129023506-621ec37aac7a/go.mod h1:OYtxs0786qojVTmkVeufx93xe+jUgm56GUYRIKnmaGI= +github.com/pingcap/log v1.1.1-0.20221015072633-39906604fb81 h1:URLoJ61DmmY++Sa/yyPEQHG2s/ZBeV1FbIswHEMrdoY= +github.com/pingcap/log v1.1.1-0.20221015072633-39906604fb81/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -559,6 +580,7 @@ github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso= github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= +github.com/prometheus/client_golang v1.11.0/go.mod 
h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= github.com/prometheus/client_golang v1.11.1/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= github.com/prometheus/client_golang v1.14.0 h1:nJdhIvne2eSX/XRAFV9PcvFFRbrjbcTUj0VP62TMhnw= github.com/prometheus/client_golang v1.14.0/go.mod h1:8vpkKitgIVNcqrRBWh1C4TIUQgYNtG/XQE4E/Zae36Y= @@ -583,6 +605,10 @@ github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1 github.com/prometheus/procfs v0.9.0 h1:wzCHvIvM5SxWqYvwgVL7yJY8Lz3PKn49KQtpgMYJfhI= github.com/prometheus/procfs v0.9.0/go.mod h1:+pB4zwohETzFnmlpe6yd2lSc+0/46IYZRB/chUwxUZY= github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= +github.com/quasilyte/go-ruleguard/dsl v0.3.22 h1:wd8zkOhSNr+I+8Qeciml08ivDt1pSXe60+5DqOpCjPE= +github.com/quasilyte/go-ruleguard/dsl v0.3.22/go.mod h1:KeCP03KrjuSO0H1kTuZQCWlQPulDV6YMIXmpQss17rU= +github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 h1:OdAsTTz6OkFY5QxjkYwrChwuRruF69c169dPK26NUlk= +github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/clock v0.0.0-20190514195947-2896927a307a/go.mod h1:4r5QyqhjIWCcK8DO4KMclc5Iknq5qVBAlbYYzAbUScQ= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= @@ -601,6 +627,7 @@ github.com/santhosh-tekuri/jsonschema/v5 v5.0.0/go.mod h1:FKdcjfQW6rpZSnxxUvEA5H github.com/schollz/closestmatch v2.1.0+incompatible/go.mod h1:RtP1ddjLong6gTkbtmuhtR2uUrrJOpYzYRvbcPAid+g= github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= +github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/shirou/gopsutil/v3 
v3.22.9 h1:yibtJhIVEMcdw+tCTbOPiF1VcsuDeTE4utJ8Dm4c5eA= github.com/shirou/gopsutil/v3 v3.22.9/go.mod h1:bBYl1kjgEJpWpxeHmLI+dVHWtyAwfcmSBLDsp2TNT8A= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= @@ -641,6 +668,8 @@ github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DM github.com/spf13/viper v1.7.0/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5qpdg= github.com/spf13/viper v1.8.1 h1:Kq1fyeebqsBfbjZj4EL7gj2IO0mMaiyjYUWcUsl2O44= github.com/spf13/viper v1.8.1/go.mod h1:o0Pch8wJ9BVSWGQMbra6iw0oQ5oktSIBaujf1rJH9Ns= +github.com/stathat/consistent v1.0.0 h1:ZFJ1QTRn8npNBKW065raSZ8xfOqhpb8vLOkfp4CcL/U= +github.com/stathat/consistent v1.0.0/go.mod h1:uajTPbgSygZBJ+V+0mY7meZ8i0XAcZs7AQ6V121XSxw= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= @@ -663,6 +692,12 @@ github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl github.com/subosito/gotenv v1.2.0 h1:Slr1R9HxAlEKefgq5jn9U+DnETlIUa6HfgEzj0g5d7s= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/thoas/go-funk v0.9.1 h1:O549iLZqPpTUQ10ykd26sZhzD+rmR5pWhuElrhbC20M= +github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a h1:J/YdBZ46WKpXsxsW93SG+q0F8KI+yFrcIDT4c/RNoc4= +github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a/go.mod h1:h4xBhSNtOeEosLJ4P7JyKXX7Cabg7AVkWCK5gV2vOrM= +github.com/tikv/client-go/v2 v2.0.4 h1:cPtMXTExqjzk8L40qhrgB/mXiBXKP5LRU0vwjtI2Xxo= +github.com/tikv/client-go/v2 v2.0.4/go.mod h1:v52O5zDtv2BBus4lm5yrSQhxGW4Z4RaXWfg0U1Kuyqo= +github.com/tikv/pd/client v0.0.0-20221031025758-80f0d8ca4d07 h1:ckPpxKcl75mO2N6a4cJXiZH43hvcHPpqc9dh1TmH1nc= +github.com/tikv/pd/client v0.0.0-20221031025758-80f0d8ca4d07/go.mod 
h1:CipBxPfxPUME+BImx9MUYXCnAVLS3VJUr3mnSJwh40A= github.com/tklauser/go-sysconf v0.3.10 h1:IJ1AZGZRWbY8T5Vfk04D9WOA5WSejdflXxP03OUqALw= github.com/tklauser/go-sysconf v0.3.10/go.mod h1:C8XykCvCb+Gn0oNCWPIlcb0RuglQTYaQ2hGm7jmxEFk= github.com/tklauser/numcpus v0.4.0 h1:E53Dm1HjH1/R2/aoCtXtPgzmElmn51aOkhCFSuZq//o= @@ -670,6 +705,8 @@ github.com/tklauser/numcpus v0.4.0/go.mod h1:1+UI3pD8NW14VMwdgJNJ1ESk2UnwhAnz5hM github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802 h1:uruHq4dN7GR16kFc5fp3d1RIYzJW5onx8Ybykw2YQFA= github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= +github.com/twmb/murmur3 v1.1.3 h1:D83U0XYKcHRYwYIpBKf3Pks91Z0Byda/9SJ8B6EMRcA= +github.com/twmb/murmur3 v1.1.3/go.mod h1:Qq/R7NUyOfr65zD+6Q5IHKsJLwP7exErjN6lyyq3OSQ= github.com/uber/jaeger-client-go v2.30.0+incompatible h1:D6wyKGCecFaSRUpo8lCVbaOOb6ThwMmTEbhRwtKR97o= github.com/uber/jaeger-client-go v2.30.0+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk= github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc= @@ -759,18 +796,25 @@ go.opentelemetry.io/proto/otlp v0.9.0/go.mod h1:1vKfU9rv61e9EVGthD1zNvUbiwPcimSs go.opentelemetry.io/proto/otlp v0.19.0 h1:IVN6GR+mhC4s5yfcTbmzHYODqvWAp3ZedA2SJPI1Nnw= go.opentelemetry.io/proto/otlp v0.19.0/go.mod h1:H7XAot3MsfNsj7EXtrA2q5xSNQ10UqI405h3+duxN4U= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= go.uber.org/atomic v1.10.0/go.mod 
h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/automaxprocs v1.5.2 h1:2LxUOGiR3O6tw8ui5sZa2LAaHnsviZdVOUZw4fvbnME= go.uber.org/automaxprocs v1.5.2/go.mod h1:eRbA25aqJrxAbsLO0xy5jVwPt7FQnRgjW+efnwa1WM0= +go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A= +go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= -go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= +go.uber.org/multierr v1.7.0 h1:zaiO/rmgFjbmCXdSYJWQcdvOCsthmdaHfr3Gm2Kx4Ec= +go.uber.org/multierr v1.7.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= -go.uber.org/zap v1.17.0 h1:MTjgFu6ZLKvY6Pvaqk97GlxNBuMpV4Hy/3P6tRGlI2U= go.uber.org/zap v1.17.0/go.mod h1:MXVU+bhUf/A7Xi2HNOnopQOrmycQ5Ih87HtOu4q5SSo= +go.uber.org/zap v1.19.0/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI= +go.uber.org/zap v1.20.0 h1:N4oPlghZwYG55MlU6LXk/Zp00FVNE9X9wrYO8CEs4lc= +go.uber.org/zap v1.20.0/go.mod h1:wjWOCqI0f2ZZrJF/UufIOkiC8ii6tm1iqIsLo76RfJw= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= @@ -785,8 +829,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto 
v0.0.0-20220411220226-7b82a4e95df4/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= -golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= +golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= +golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -844,6 +888,7 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20190628185345-da137c7871d7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190724013045-ca1201d0de80/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20191002035440-2ec189313ef0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -874,8 +919,8 @@ golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20211008194852-3b03d305991f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod 
h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= -golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= +golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -981,12 +1026,12 @@ golang.org/x/sys v0.0.0-20220204135822-1c1b9b1eba6a/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.8.0 h1:n5xxQn2i3PC0yLAbjTpNT85q/Kgzcr2gIoX9OrJUols= -golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.13.0 h1:bb+I9cTfFazGW51MZqBVmZy7+JEJMouUHTUSKVQLBek= +golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod 
h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -996,8 +1041,8 @@ golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -1025,6 +1070,8 @@ golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgw golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191112195655-aa38f8e97acc/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools 
v0.0.0-20191113191852-77e3bb0ad9e7/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191115202509-3a792d9c32b2/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= @@ -1063,6 +1110,7 @@ golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= golang.org/x/tools v0.1.2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.3/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -1106,6 +1154,7 @@ google.golang.org/genproto v0.0.0-20190502173448-54afdca5d873/go.mod h1:VzzqZJRn google.golang.org/genproto v0.0.0-20190801165951-fa694d86fc64/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20190911173649-1774047e7e51/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8= +google.golang.org/genproto v0.0.0-20190927181202-20e1ac93f88c/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8= google.golang.org/genproto v0.0.0-20191108220845-16a3f7862a1a/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= google.golang.org/genproto v0.0.0-20191115194625-c23dd37a84c9/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= google.golang.org/genproto v0.0.0-20191216164720-4f79533eabd1/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= @@ -1150,6 +1199,7 @@ google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZi 
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.24.0/go.mod h1:XDChyiUovWa60DnaeDeZmSW86xtLtjtZbwvSiRnRtcA= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= @@ -1170,6 +1220,7 @@ google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQ google.golang.org/grpc v1.40.0/go.mod h1:ogyxbiOoUXAkP+4+xa6PZSE9DZgIHtSpzjDTB9KAK34= google.golang.org/grpc v1.41.0/go.mod h1:U3l9uK9J0sini8mHphKoXyaqDA/8VyGnDee1zzIUK6k= google.golang.org/grpc v1.42.0/go.mod h1:k+4IHHFw41K8+bbowsex27ge2rCb65oeWqe4jJ590SU= +google.golang.org/grpc v1.43.0/go.mod h1:k+4IHHFw41K8+bbowsex27ge2rCb65oeWqe4jJ590SU= google.golang.org/grpc v1.46.0/go.mod h1:vN9eftEi1UMyUsIF80+uQXhHjbXYbm0uXoFCACuMGWk= google.golang.org/grpc v1.54.0 h1:EhTqbhiYeixwWQtAEZAxmV9MGqcjEU2mFx52xCzNyag= google.golang.org/grpc v1.54.0/go.mod h1:PUSEXI6iWghWaB6lXM4knEgpJNu2qUcKfDtNci3EC2g= @@ -1243,3 +1294,4 @@ rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= sigs.k8s.io/yaml v1.2.0 h1:kr/MCeFWJWTwyaHoR9c8EjH9OumOmoF9YGiZd7lFm/Q= sigs.k8s.io/yaml v1.2.0/go.mod h1:yfXDCHCao9+ENCvLSE62v9VSji2MKu5jeNfTrofGhJc= +stathat.com/c/consistent v1.0.0 h1:ezyc51EGcRPJUxfHGSgJjWzJdj3NiMU9pNfLNGiXV0c= diff --git a/pkg/log/global.go b/pkg/log/global.go index 8fc1df50af4dd..297c906fcf0eb 100644 --- a/pkg/log/global.go +++ b/pkg/log/global.go @@ -22,9 +22,7 @@ import ( type ctxLogKeyType struct{} -var ( - CtxLogKey = ctxLogKeyType{} -) +var CtxLogKey = ctxLogKeyType{} // Debug logs a 
message at DebugLevel. The message includes any fields passed // at the log site, as well as any fields accumulated on the logger. diff --git a/pkg/log/log.go b/pkg/log/log.go index 25de71bde18ac..2a9d1947c17e5 100644 --- a/pkg/log/log.go +++ b/pkg/log/log.go @@ -64,7 +64,6 @@ func init() { r := utils.NewRateLimiter(1.0, 60.0) _globalR.Store(r) - } // InitLogger initializes a zap logger. @@ -218,8 +217,10 @@ func ReplaceGlobals(logger *zap.Logger, props *ZapProperties) { } func replaceLeveledLoggers(debugLogger *zap.Logger) { - levels := []zapcore.Level{zapcore.DebugLevel, zapcore.InfoLevel, zapcore.WarnLevel, zapcore.ErrorLevel, - zapcore.DPanicLevel, zapcore.PanicLevel, zapcore.FatalLevel} + levels := []zapcore.Level{ + zapcore.DebugLevel, zapcore.InfoLevel, zapcore.WarnLevel, zapcore.ErrorLevel, + zapcore.DPanicLevel, zapcore.PanicLevel, zapcore.FatalLevel, + } for _, level := range levels { levelL := debugLogger.WithOptions(zap.IncreaseLevel(level)) _globalLevelLogger.Store(level, levelL) diff --git a/pkg/log/log_test.go b/pkg/log/log_test.go index fa6d64897d2ce..d42673244ec3a 100644 --- a/pkg/log/log_test.go +++ b/pkg/log/log_test.go @@ -253,7 +253,6 @@ func TestLeveledLogger(t *testing.T) { SetLevel(zapcore.FatalLevel + 1) assert.Equal(t, ctxL(), L()) SetLevel(orgLevel) - } func TestStdAndFileLogger(t *testing.T) { diff --git a/pkg/log/mlogger_test.go b/pkg/log/mlogger_test.go index fcdf9339c595a..63a71f7140b3e 100644 --- a/pkg/log/mlogger_test.go +++ b/pkg/log/mlogger_test.go @@ -28,6 +28,7 @@ func TestExporterV2(t *testing.T) { ts.assertMessagesContains("traceID=mock-trace") ts.CleanBuffer() + // nolint Ctx(nil).Info("empty context") ts.assertMessagesNotContains("traceID") diff --git a/pkg/log/zap_log_test.go b/pkg/log/zap_log_test.go index c54ff4bf622da..54edbe471f217 100644 --- a/pkg/log/zap_log_test.go +++ b/pkg/log/zap_log_test.go @@ -32,9 +32,9 @@ package log import ( "fmt" - "io/ioutil" "math" "net" + "os" "strings" "testing" "time" @@ -226,7 +226,7 
@@ func TestRotateLog(t *testing.T) { logger.Info(string(data)) data = data[:0] } - files, _ := ioutil.ReadDir(tempDir) + files, _ := os.ReadDir(tempDir) assert.Len(t, files, c.expectedFileNum) }) } diff --git a/pkg/metrics/datacoord_metrics.go b/pkg/metrics/datacoord_metrics.go index ab401dc583b84..aa8ac822624a1 100644 --- a/pkg/metrics/datacoord_metrics.go +++ b/pkg/metrics/datacoord_metrics.go @@ -19,9 +19,10 @@ package metrics import ( "fmt" + "github.com/prometheus/client_golang/prometheus" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/pkg/util/typeutil" - "github.com/prometheus/client_golang/prometheus" ) const ( @@ -41,7 +42,7 @@ const ( ) var ( - //DataCoordNumDataNodes records the num of data nodes managed by DataCoord. + // DataCoordNumDataNodes records the num of data nodes managed by DataCoord. DataCoordNumDataNodes = prometheus.NewGaugeVec( prometheus.GaugeOpts{ Namespace: milvusNamespace, @@ -60,7 +61,7 @@ var ( segmentStateLabelName, }) - //DataCoordCollectionNum records the num of collections managed by DataCoord. + // DataCoordCollectionNum records the num of collections managed by DataCoord. DataCoordNumCollections = prometheus.NewGaugeVec( prometheus.GaugeOpts{ Namespace: milvusNamespace, @@ -148,6 +149,26 @@ var ( Buckets: buckets, }, []string{segmentFileTypeLabelName}) + /* garbage collector related metrics */ + + // GarbageCollectorListLatency metrics for gc scan storage files. 
+ GarbageCollectorListLatency = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: milvusNamespace, + Subsystem: typeutil.DataCoordRole, + Name: "gc_list_latency", + Help: "latency of list objects in storage while garbage collecting (in milliseconds)", + Buckets: longTaskBuckets, + }, []string{nodeIDLabelName, segmentFileTypeLabelName}) + + GarbageCollectorRunCount = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: milvusNamespace, + Subsystem: typeutil.DataCoordRole, + Name: "gc_run_count", + Help: "garbage collection running count", + }, []string{nodeIDLabelName}) + /* hard to implement, commented now DataCoordSegmentSizeRatio = prometheus.NewHistogramVec( prometheus.HistogramOpts{ diff --git a/pkg/metrics/datanode_metrics.go b/pkg/metrics/datanode_metrics.go index 89a9c5cb2c03e..1403c10a056c1 100644 --- a/pkg/metrics/datanode_metrics.go +++ b/pkg/metrics/datanode_metrics.go @@ -154,7 +154,7 @@ var ( Subsystem: typeutil.DataNodeRole, Name: "compaction_latency", Help: "latency of compaction operation", - Buckets: []float64{0.001, 0.1, 0.5, 1, 5, 10, 20, 50, 100, 250, 500, 1000, 3600, 5000, 10000}, // unit seconds + Buckets: longTaskBuckets, }, []string{ nodeIDLabelName, }) diff --git a/pkg/metrics/indexnode_metrics.go b/pkg/metrics/indexnode_metrics.go index f85f59a116f2f..3a02286307ead 100644 --- a/pkg/metrics/indexnode_metrics.go +++ b/pkg/metrics/indexnode_metrics.go @@ -23,6 +23,9 @@ import ( ) var ( + // unit second, from 1ms to 2hrs + indexBucket = []float64{0.001, 0.1, 0.5, 1, 5, 10, 20, 50, 100, 250, 500, 1000, 3600, 5000, 10000} + IndexNodeBuildIndexTaskCounter = prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: milvusNamespace, @@ -37,7 +40,7 @@ var ( Subsystem: typeutil.IndexNodeRole, Name: "load_field_latency", Help: "latency of loading the field data", - Buckets: buckets, + Buckets: indexBucket, }, []string{nodeIDLabelName}) IndexNodeDecodeFieldLatency = prometheus.NewHistogramVec( @@ -46,7 +49,7 @@ var ( 
Subsystem: typeutil.IndexNodeRole, Name: "decode_field_latency", Help: "latency of decode field data", - Buckets: buckets, + Buckets: indexBucket, }, []string{nodeIDLabelName}) IndexNodeKnowhereBuildIndexLatency = prometheus.NewHistogramVec( @@ -55,7 +58,7 @@ var ( Subsystem: typeutil.IndexNodeRole, Name: "knowhere_build_index_latency", Help: "latency of building the index by knowhere", - Buckets: buckets, + Buckets: indexBucket, }, []string{nodeIDLabelName}) IndexNodeEncodeIndexFileLatency = prometheus.NewHistogramVec( @@ -64,7 +67,7 @@ var ( Subsystem: typeutil.IndexNodeRole, Name: "encode_index_latency", Help: "latency of encoding the index file", - Buckets: buckets, + Buckets: indexBucket, }, []string{nodeIDLabelName}) IndexNodeSaveIndexFileLatency = prometheus.NewHistogramVec( @@ -73,7 +76,7 @@ var ( Subsystem: typeutil.IndexNodeRole, Name: "save_index_latency", Help: "latency of saving the index file", - Buckets: buckets, + Buckets: indexBucket, }, []string{nodeIDLabelName}) IndexNodeIndexTaskLatencyInQueue = prometheus.NewHistogramVec( @@ -91,7 +94,7 @@ var ( Subsystem: typeutil.IndexNodeRole, Name: "build_index_latency", Help: "latency of build index for segment", - Buckets: buckets, + Buckets: indexBucket, }, []string{nodeIDLabelName}) ) diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index 015e7c25e10f6..050c2a5ced511 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -95,6 +95,9 @@ var ( // [1 2 4 8 16 32 64 128 256 512 1024 2048 4096 8192 16384 32768 65536 1.31072e+05] buckets = prometheus.ExponentialBuckets(1, 2, 18) + // longTaskBuckets provides long task duration in milliseconds + longTaskBuckets = []float64{1, 100, 500, 1000, 5000, 10000, 20000, 50000, 100000, 250000, 500000, 1000000, 3600000, 5000000, 10000000} // unit milliseconds + NumNodes = prometheus.NewGaugeVec( prometheus.GaugeOpts{ Namespace: milvusNamespace, @@ -113,10 +116,23 @@ var ( lockType, lockOp, }) + + metricRegisterer prometheus.Registerer ) +// 
GetRegisterer returns the global prometheus registerer +// metricsRegistry must be call after Register is called or no Register is called. +func GetRegisterer() prometheus.Registerer { + if metricRegisterer == nil { + return prometheus.DefaultRegisterer + } + return metricRegisterer +} + // Register serves prometheus http service -func Register(r *prometheus.Registry) { +// Should be called by init function. +func Register(r prometheus.Registerer) { r.MustRegister(NumNodes) r.MustRegister(LockCosts) + metricRegisterer = r } diff --git a/pkg/metrics/metrics_test.go b/pkg/metrics/metrics_test.go index e146627827265..9709349c555b6 100644 --- a/pkg/metrics/metrics_test.go +++ b/pkg/metrics/metrics_test.go @@ -39,3 +39,14 @@ func TestRegisterMetrics(t *testing.T) { RegisterMsgStreamMetrics(r) }) } + +func TestGetRegisterer(t *testing.T) { + register := GetRegisterer() + assert.NotNil(t, register) + assert.Equal(t, prometheus.DefaultRegisterer, register) + r := prometheus.NewRegistry() + Register(r) + register = GetRegisterer() + assert.NotNil(t, register) + assert.Equal(t, r, register) +} diff --git a/pkg/metrics/proxy_metrics.go b/pkg/metrics/proxy_metrics.go index 32ecab9677d5d..3b45aaf43fd04 100644 --- a/pkg/metrics/proxy_metrics.go +++ b/pkg/metrics/proxy_metrics.go @@ -334,26 +334,48 @@ func RegisterProxy(registry *prometheus.Registry) { } func CleanupCollectionMetrics(nodeID int64, collection string) { - ProxyCollectionSQLatency.Delete(prometheus.Labels{nodeIDLabelName: strconv.FormatInt(nodeID, 10), - queryTypeLabelName: SearchLabel, collectionName: collection}) - ProxyCollectionSQLatency.Delete(prometheus.Labels{nodeIDLabelName: strconv.FormatInt(nodeID, 10), - queryTypeLabelName: QueryLabel, collectionName: collection}) - ProxyCollectionMutationLatency.Delete(prometheus.Labels{nodeIDLabelName: strconv.FormatInt(nodeID, 10), - msgTypeLabelName: InsertLabel, collectionName: collection}) - ProxyCollectionMutationLatency.Delete(prometheus.Labels{nodeIDLabelName: 
strconv.FormatInt(nodeID, 10), - msgTypeLabelName: DeleteLabel, collectionName: collection}) - ProxyReceivedNQ.Delete(prometheus.Labels{nodeIDLabelName: strconv.FormatInt(nodeID, 10), - queryTypeLabelName: SearchLabel, collectionName: collection}) - ProxyReceivedNQ.Delete(prometheus.Labels{nodeIDLabelName: strconv.FormatInt(nodeID, 10), - queryTypeLabelName: QueryLabel, collectionName: collection}) - ProxyReceiveBytes.Delete(prometheus.Labels{nodeIDLabelName: strconv.FormatInt(nodeID, 10), - msgTypeLabelName: SearchLabel, collectionName: collection}) - ProxyReceiveBytes.Delete(prometheus.Labels{nodeIDLabelName: strconv.FormatInt(nodeID, 10), - msgTypeLabelName: QueryLabel, collectionName: collection}) - ProxyReceiveBytes.Delete(prometheus.Labels{nodeIDLabelName: strconv.FormatInt(nodeID, 10), - msgTypeLabelName: InsertLabel, collectionName: collection}) - ProxyReceiveBytes.Delete(prometheus.Labels{nodeIDLabelName: strconv.FormatInt(nodeID, 10), - msgTypeLabelName: DeleteLabel, collectionName: collection}) - ProxyReceiveBytes.Delete(prometheus.Labels{nodeIDLabelName: strconv.FormatInt(nodeID, 10), - msgTypeLabelName: UpsertLabel, collectionName: collection}) + ProxyCollectionSQLatency.Delete(prometheus.Labels{ + nodeIDLabelName: strconv.FormatInt(nodeID, 10), + queryTypeLabelName: SearchLabel, collectionName: collection, + }) + ProxyCollectionSQLatency.Delete(prometheus.Labels{ + nodeIDLabelName: strconv.FormatInt(nodeID, 10), + queryTypeLabelName: QueryLabel, collectionName: collection, + }) + ProxyCollectionMutationLatency.Delete(prometheus.Labels{ + nodeIDLabelName: strconv.FormatInt(nodeID, 10), + msgTypeLabelName: InsertLabel, collectionName: collection, + }) + ProxyCollectionMutationLatency.Delete(prometheus.Labels{ + nodeIDLabelName: strconv.FormatInt(nodeID, 10), + msgTypeLabelName: DeleteLabel, collectionName: collection, + }) + ProxyReceivedNQ.Delete(prometheus.Labels{ + nodeIDLabelName: strconv.FormatInt(nodeID, 10), + queryTypeLabelName: SearchLabel, 
collectionName: collection, + }) + ProxyReceivedNQ.Delete(prometheus.Labels{ + nodeIDLabelName: strconv.FormatInt(nodeID, 10), + queryTypeLabelName: QueryLabel, collectionName: collection, + }) + ProxyReceiveBytes.Delete(prometheus.Labels{ + nodeIDLabelName: strconv.FormatInt(nodeID, 10), + msgTypeLabelName: SearchLabel, collectionName: collection, + }) + ProxyReceiveBytes.Delete(prometheus.Labels{ + nodeIDLabelName: strconv.FormatInt(nodeID, 10), + msgTypeLabelName: QueryLabel, collectionName: collection, + }) + ProxyReceiveBytes.Delete(prometheus.Labels{ + nodeIDLabelName: strconv.FormatInt(nodeID, 10), + msgTypeLabelName: InsertLabel, collectionName: collection, + }) + ProxyReceiveBytes.Delete(prometheus.Labels{ + nodeIDLabelName: strconv.FormatInt(nodeID, 10), + msgTypeLabelName: DeleteLabel, collectionName: collection, + }) + ProxyReceiveBytes.Delete(prometheus.Labels{ + nodeIDLabelName: strconv.FormatInt(nodeID, 10), + msgTypeLabelName: UpsertLabel, collectionName: collection, + }) } diff --git a/pkg/metrics/rootcoord_metrics.go b/pkg/metrics/rootcoord_metrics.go index 1ce0e94a68c72..c73238f470b37 100644 --- a/pkg/metrics/rootcoord_metrics.go +++ b/pkg/metrics/rootcoord_metrics.go @@ -36,7 +36,7 @@ var ( Help: "count of DDL operations", }, []string{functionLabelName, statusLabelName}) - //RootCoordDDLReqLatency records the latency for read type of DDL operations. + // RootCoordDDLReqLatency records the latency for read type of DDL operations. RootCoordDDLReqLatency = prometheus.NewHistogramVec( prometheus.HistogramOpts{ Namespace: milvusNamespace, @@ -65,7 +65,7 @@ var ( Help: "count of ID allocated", }) - //RootCoordTimestamp records the number of timestamp allocations in RootCoord. + // RootCoordTimestamp records the number of timestamp allocations in RootCoord. 
RootCoordTimestamp = prometheus.NewGauge( prometheus.GaugeOpts{ Namespace: milvusNamespace, diff --git a/pkg/mq/msgdispatcher/client.go b/pkg/mq/msgdispatcher/client.go index 38a4c84f5d093..762bb1fe41a92 100644 --- a/pkg/mq/msgdispatcher/client.go +++ b/pkg/mq/msgdispatcher/client.go @@ -17,11 +17,12 @@ package msgdispatcher import ( + "context" "sync" - "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" @@ -35,7 +36,7 @@ type ( ) type Client interface { - Register(vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error) + Register(ctx context.Context, vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error) Deregister(vchannel string) Close() } @@ -60,7 +61,7 @@ func NewClient(factory msgstream.Factory, role string, nodeID int64) Client { } } -func (c *client) Register(vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error) { +func (c *client) Register(ctx context.Context, vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error) { log := log.With(zap.String("role", c.role), zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel)) pchannel := funcutil.ToPhysicalChannel(vchannel) @@ -73,7 +74,7 @@ func (c *client) Register(vchannel string, pos *Pos, subPos SubPos) (<-chan *Msg c.managers[pchannel] = manager go manager.Run() } - ch, err := manager.Add(vchannel, pos, subPos) + ch, err := manager.Add(ctx, vchannel, pos, subPos) if err != nil { if manager.Num() == 0 { manager.Close() diff --git a/pkg/mq/msgdispatcher/client_test.go b/pkg/mq/msgdispatcher/client_test.go index 359ea06740b45..6d24f64cc017e 100644 --- a/pkg/mq/msgdispatcher/client_test.go +++ b/pkg/mq/msgdispatcher/client_test.go @@ -17,10 +17,12 @@ package msgdispatcher import ( + "context" "fmt" "math/rand" "sync" "testing" + "time" 
"github.com/stretchr/testify/assert" "go.uber.org/atomic" @@ -32,10 +34,25 @@ import ( func TestClient(t *testing.T) { client := NewClient(newMockFactory(), typeutil.ProxyRole, 1) assert.NotNil(t, client) - _, err := client.Register("mock_vchannel_0", nil, mqwrapper.SubscriptionPositionUnknown) + _, err := client.Register(context.Background(), "mock_vchannel_0", nil, mqwrapper.SubscriptionPositionUnknown) + assert.NoError(t, err) + _, err = client.Register(context.Background(), "mock_vchannel_1", nil, mqwrapper.SubscriptionPositionUnknown) assert.NoError(t, err) assert.NotPanics(t, func() { client.Deregister("mock_vchannel_0") + client.Close() + }) + + t.Run("with timeout ctx", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Millisecond) + defer cancel() + <-time.After(2 * time.Millisecond) + + client := NewClient(newMockFactory(), typeutil.DataNodeRole, 1) + defer client.Close() + assert.NotNil(t, client) + _, err := client.Register(ctx, "mock_vchannel_1", nil, mqwrapper.SubscriptionPositionUnknown) + assert.Error(t, err) }) } @@ -49,7 +66,7 @@ func TestClient_Concurrency(t *testing.T) { vchannel := fmt.Sprintf("mock-vchannel-%d-%d", i, rand.Int()) wg.Add(1) go func() { - _, err := client1.Register(vchannel, nil, mqwrapper.SubscriptionPositionUnknown) + _, err := client1.Register(context.Background(), vchannel, nil, mqwrapper.SubscriptionPositionUnknown) assert.NoError(t, err) for j := 0; j < rand.Intn(2); j++ { client1.Deregister(vchannel) diff --git a/pkg/mq/msgdispatcher/dispatcher.go b/pkg/mq/msgdispatcher/dispatcher.go index fdf9e81b13ce4..ee552046ddc08 100644 --- a/pkg/mq/msgdispatcher/dispatcher.go +++ b/pkg/mq/msgdispatcher/dispatcher.go @@ -22,10 +22,10 @@ import ( "sync" "time" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "go.uber.org/atomic" "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" 
"github.com/milvus-io/milvus/pkg/mq/msgstream" @@ -78,7 +78,8 @@ type Dispatcher struct { stream msgstream.MsgStream } -func NewDispatcher(factory msgstream.Factory, +func NewDispatcher(ctx context.Context, + factory msgstream.Factory, isMain bool, pchannel string, position *Pos, @@ -90,14 +91,19 @@ func NewDispatcher(factory msgstream.Factory, log := log.With(zap.String("pchannel", pchannel), zap.String("subName", subName), zap.Bool("isMain", isMain)) log.Info("creating dispatcher...") - stream, err := factory.NewTtMsgStream(context.Background()) + stream, err := factory.NewTtMsgStream(ctx) if err != nil { return nil, err } if position != nil && len(position.MsgID) != 0 { position.ChannelName = funcutil.ToPhysicalChannel(position.ChannelName) - stream.AsConsumer([]string{pchannel}, subName, mqwrapper.SubscriptionPositionUnknown) - err = stream.Seek([]*Pos{position}) + err = stream.AsConsumer(ctx, []string{pchannel}, subName, mqwrapper.SubscriptionPositionUnknown) + if err != nil { + log.Error("asConsumer failed", zap.Error(err)) + return nil, err + } + + err = stream.Seek(ctx, []*Pos{position}) if err != nil { stream.Close() log.Error("seek failed", zap.Error(err)) @@ -107,7 +113,11 @@ func NewDispatcher(factory msgstream.Factory, log.Info("seek successfully", zap.Time("posTime", posTime), zap.Duration("tsLag", time.Since(posTime))) } else { - stream.AsConsumer([]string{pchannel}, subName, subPos) + err := stream.AsConsumer(ctx, []string{pchannel}, subName, subPos) + if err != nil { + log.Error("asConsumer failed", zap.Error(err)) + return nil, err + } log.Info("asConsumer successfully") } diff --git a/pkg/mq/msgdispatcher/dispatcher_test.go b/pkg/mq/msgdispatcher/dispatcher_test.go index ad642b994b4fb..e7c79b54fc0fa 100644 --- a/pkg/mq/msgdispatcher/dispatcher_test.go +++ b/pkg/mq/msgdispatcher/dispatcher_test.go @@ -21,15 +21,19 @@ import ( "testing" "time" + "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" + 
"github.com/stretchr/testify/mock" + "golang.org/x/net/context" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" ) func TestDispatcher(t *testing.T) { + ctx := context.Background() t.Run("test base", func(t *testing.T) { - d, err := NewDispatcher(newMockFactory(), true, "mock_pchannel_0", nil, + d, err := NewDispatcher(ctx, newMockFactory(), true, "mock_pchannel_0", nil, "mock_subName_0", mqwrapper.SubscriptionPositionEarliest, nil, nil) assert.NoError(t, err) assert.NotPanics(t, func() { @@ -49,8 +53,23 @@ func TestDispatcher(t *testing.T) { assert.Equal(t, pos.Timestamp, curTs) }) + t.Run("test AsConsumer fail", func(t *testing.T) { + ms := msgstream.NewMockMsgStream(t) + ms.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("mock error")) + factory := &msgstream.MockMqFactory{ + NewMsgStreamFunc: func(ctx context.Context) (msgstream.MsgStream, error) { + return ms, nil + }, + } + d, err := NewDispatcher(ctx, factory, true, "mock_pchannel_0", nil, + "mock_subName_0", mqwrapper.SubscriptionPositionEarliest, nil, nil) + + assert.Error(t, err) + assert.Nil(t, d) + }) + t.Run("test target", func(t *testing.T) { - d, err := NewDispatcher(newMockFactory(), true, "mock_pchannel_0", nil, + d, err := NewDispatcher(ctx, newMockFactory(), true, "mock_pchannel_0", nil, "mock_subName_0", mqwrapper.SubscriptionPositionEarliest, nil, nil) assert.NoError(t, err) output := make(chan *msgstream.MsgPack, 1024) @@ -113,7 +132,7 @@ func TestDispatcher(t *testing.T) { } func BenchmarkDispatcher_handle(b *testing.B) { - d, err := NewDispatcher(newMockFactory(), true, "mock_pchannel_0", nil, + d, err := NewDispatcher(context.Background(), newMockFactory(), true, "mock_pchannel_0", nil, "mock_subName_0", mqwrapper.SubscriptionPositionEarliest, nil, nil) assert.NoError(b, err) diff --git a/pkg/mq/msgdispatcher/manager.go b/pkg/mq/msgdispatcher/manager.go index 4546dc6bf9914..ecd3d079aeb37 
100644 --- a/pkg/mq/msgdispatcher/manager.go +++ b/pkg/mq/msgdispatcher/manager.go @@ -22,7 +22,6 @@ import ( "sync" "time" - "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/prometheus/client_golang/prometheus" "go.uber.org/zap" @@ -31,15 +30,14 @@ import ( "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" "github.com/milvus-io/milvus/pkg/util/retry" + "github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/milvus-io/milvus/pkg/util/typeutil" ) -var ( - CheckPeriod = 1 * time.Second // TODO: dyh, move to config -) +var CheckPeriod = 1 * time.Second // TODO: dyh, move to config type DispatcherManager interface { - Add(vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error) + Add(ctx context.Context, vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error) Remove(vchannel string) Num() int Run() @@ -85,14 +83,14 @@ func (c *dispatcherManager) constructSubName(vchannel string, isMain bool) strin return fmt.Sprintf("%s-%d-%s-%t", c.role, c.nodeID, vchannel, isMain) } -func (c *dispatcherManager) Add(vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error) { +func (c *dispatcherManager) Add(ctx context.Context, vchannel string, pos *Pos, subPos SubPos) (<-chan *MsgPack, error) { log := log.With(zap.String("role", c.role), zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel)) c.mu.Lock() defer c.mu.Unlock() isMain := c.mainDispatcher == nil - d, err := NewDispatcher(c.factory, isMain, c.pchannel, pos, + d, err := NewDispatcher(ctx, c.factory, isMain, c.pchannel, pos, c.constructSubName(vchannel, isMain), subPos, c.lagNotifyChan, c.lagTargets) if err != nil { return nil, err @@ -236,7 +234,7 @@ func (c *dispatcherManager) split(t *target) { var newSolo *Dispatcher err := retry.Do(context.Background(), func() error { var err error - newSolo, err = NewDispatcher(c.factory, false, c.pchannel, t.pos, + newSolo, err = NewDispatcher(context.Background(), c.factory, 
false, c.pchannel, t.pos, c.constructSubName(t.vchannel, false), mqwrapper.SubscriptionPositionUnknown, c.lagNotifyChan, c.lagTargets) return err }, retry.Attempts(10)) diff --git a/pkg/mq/msgdispatcher/manager_test.go b/pkg/mq/msgdispatcher/manager_test.go index d7a2c38bfc8bd..621767dd6e42b 100644 --- a/pkg/mq/msgdispatcher/manager_test.go +++ b/pkg/mq/msgdispatcher/manager_test.go @@ -25,10 +25,10 @@ import ( "testing" "time" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -46,7 +46,7 @@ func TestManager(t *testing.T) { for j := 0; j < r; j++ { offset++ t.Logf("dyh add, %s", fmt.Sprintf("mock-pchannel-0_vchannel_%d", offset)) - _, err := c.Add(fmt.Sprintf("mock-pchannel-0_vchannel_%d", offset), nil, mqwrapper.SubscriptionPositionUnknown) + _, err := c.Add(context.Background(), fmt.Sprintf("mock-pchannel-0_vchannel_%d", offset), nil, mqwrapper.SubscriptionPositionUnknown) assert.NoError(t, err) assert.Equal(t, offset, c.Num()) } @@ -61,13 +61,14 @@ func TestManager(t *testing.T) { t.Run("test merge and split", func(t *testing.T) { prefix := fmt.Sprintf("mock%d", time.Now().UnixNano()) + ctx := context.Background() c := NewDispatcherManager(prefix+"_pchannel_0", typeutil.ProxyRole, 1, newMockFactory()) assert.NotNil(t, c) - _, err := c.Add("mock_vchannel_0", nil, mqwrapper.SubscriptionPositionUnknown) + _, err := c.Add(ctx, "mock_vchannel_0", nil, mqwrapper.SubscriptionPositionUnknown) assert.NoError(t, err) - _, err = c.Add("mock_vchannel_1", nil, mqwrapper.SubscriptionPositionUnknown) + _, err = c.Add(ctx, "mock_vchannel_1", nil, mqwrapper.SubscriptionPositionUnknown) assert.NoError(t, err) - _, err = c.Add("mock_vchannel_2", nil, mqwrapper.SubscriptionPositionUnknown) + _, 
err = c.Add(ctx, "mock_vchannel_2", nil, mqwrapper.SubscriptionPositionUnknown) assert.NoError(t, err) assert.Equal(t, 3, c.Num()) @@ -85,13 +86,14 @@ func TestManager(t *testing.T) { t.Run("test run and close", func(t *testing.T) { prefix := fmt.Sprintf("mock%d", time.Now().UnixNano()) + ctx := context.Background() c := NewDispatcherManager(prefix+"_pchannel_0", typeutil.ProxyRole, 1, newMockFactory()) assert.NotNil(t, c) - _, err := c.Add("mock_vchannel_0", nil, mqwrapper.SubscriptionPositionUnknown) + _, err := c.Add(ctx, "mock_vchannel_0", nil, mqwrapper.SubscriptionPositionUnknown) assert.NoError(t, err) - _, err = c.Add("mock_vchannel_1", nil, mqwrapper.SubscriptionPositionUnknown) + _, err = c.Add(ctx, "mock_vchannel_1", nil, mqwrapper.SubscriptionPositionUnknown) assert.NoError(t, err) - _, err = c.Add("mock_vchannel_2", nil, mqwrapper.SubscriptionPositionUnknown) + _, err = c.Add(ctx, "mock_vchannel_2", nil, mqwrapper.SubscriptionPositionUnknown) assert.NoError(t, err) assert.Equal(t, 3, c.Num()) @@ -105,6 +107,28 @@ func TestManager(t *testing.T) { c.Close() }) }) + + t.Run("test add timeout", func(t *testing.T) { + prefix := fmt.Sprintf("mock%d", time.Now().UnixNano()) + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Millisecond*2) + defer cancel() + time.Sleep(time.Millisecond * 2) + c := NewDispatcherManager(prefix+"_pchannel_0", typeutil.ProxyRole, 1, newMockFactory()) + go c.Run() + assert.NotNil(t, c) + _, err := c.Add(ctx, "mock_vchannel_0", nil, mqwrapper.SubscriptionPositionUnknown) + assert.Error(t, err) + _, err = c.Add(ctx, "mock_vchannel_1", nil, mqwrapper.SubscriptionPositionUnknown) + assert.Error(t, err) + _, err = c.Add(ctx, "mock_vchannel_2", nil, mqwrapper.SubscriptionPositionUnknown) + assert.Error(t, err) + assert.Equal(t, 0, c.Num()) + + assert.NotPanics(t, func() { + c.Close() + }) + }) } type vchannelHelper struct { @@ -232,7 +256,7 @@ func (suite *SimulationSuite) consumeMsg(ctx context.Context, wg 
*sync.WaitGroup } func (suite *SimulationSuite) produceTimeTickOnly(ctx context.Context) { - var tt = 1 + tt := 1 ticker := time.NewTicker(10 * time.Millisecond) defer ticker.Stop() for { @@ -255,7 +279,7 @@ func (suite *SimulationSuite) TestDispatchToVchannels() { suite.vchannels = make(map[string]*vchannelHelper, vchannelNum) for i := 0; i < vchannelNum; i++ { vchannel := fmt.Sprintf("%s_vchannelv%d", suite.pchannel, i) - output, err := suite.manager.Add(vchannel, nil, mqwrapper.SubscriptionPositionEarliest) + output, err := suite.manager.Add(context.Background(), vchannel, nil, mqwrapper.SubscriptionPositionEarliest) assert.NoError(suite.T(), err) suite.vchannels[vchannel] = &vchannelHelper{output: output} } @@ -289,7 +313,7 @@ func (suite *SimulationSuite) TestMerge() { for i := 0; i < vchannelNum; i++ { vchannel := fmt.Sprintf("%s_vchannelv%d", suite.pchannel, i) - output, err := suite.manager.Add(vchannel, positions[rand.Intn(len(positions))], + output, err := suite.manager.Add(context.Background(), vchannel, positions[rand.Intn(len(positions))], mqwrapper.SubscriptionPositionUnknown) // seek from random position assert.NoError(suite.T(), err) suite.vchannels[vchannel] = &vchannelHelper{output: output} @@ -325,7 +349,7 @@ func (suite *SimulationSuite) TestSplit() { DefaultTargetChanSize = 10 } vchannel := fmt.Sprintf("%s_vchannelv%d", suite.pchannel, i) - _, err := suite.manager.Add(vchannel, nil, mqwrapper.SubscriptionPositionEarliest) + _, err := suite.manager.Add(context.Background(), vchannel, nil, mqwrapper.SubscriptionPositionEarliest) assert.NoError(suite.T(), err) } @@ -345,7 +369,6 @@ func (suite *SimulationSuite) TearDownTest() { } func (suite *SimulationSuite) TearDownSuite() { - } func TestSimulation(t *testing.T) { diff --git a/pkg/mq/msgdispatcher/mock_client.go b/pkg/mq/msgdispatcher/mock_client.go index 5a9a00d86dced..4b99b5e8f45a5 100644 --- a/pkg/mq/msgdispatcher/mock_client.go +++ b/pkg/mq/msgdispatcher/mock_client.go @@ -3,10 +3,13 @@ 
package msgdispatcher import ( - msgpb "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + context "context" + mqwrapper "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" mock "github.com/stretchr/testify/mock" + msgpb "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + msgstream "github.com/milvus-io/milvus/pkg/mq/msgstream" ) @@ -88,25 +91,25 @@ func (_c *MockClient_Deregister_Call) RunAndReturn(run func(string)) *MockClient return _c } -// Register provides a mock function with given fields: vchannel, pos, subPos -func (_m *MockClient) Register(vchannel string, pos *msgpb.MsgPosition, subPos mqwrapper.SubscriptionInitialPosition) (<-chan *msgstream.MsgPack, error) { - ret := _m.Called(vchannel, pos, subPos) +// Register provides a mock function with given fields: ctx, vchannel, pos, subPos +func (_m *MockClient) Register(ctx context.Context, vchannel string, pos *msgpb.MsgPosition, subPos mqwrapper.SubscriptionInitialPosition) (<-chan *msgstream.MsgPack, error) { + ret := _m.Called(ctx, vchannel, pos, subPos) var r0 <-chan *msgstream.MsgPack var r1 error - if rf, ok := ret.Get(0).(func(string, *msgpb.MsgPosition, mqwrapper.SubscriptionInitialPosition) (<-chan *msgstream.MsgPack, error)); ok { - return rf(vchannel, pos, subPos) + if rf, ok := ret.Get(0).(func(context.Context, string, *msgpb.MsgPosition, mqwrapper.SubscriptionInitialPosition) (<-chan *msgstream.MsgPack, error)); ok { + return rf(ctx, vchannel, pos, subPos) } - if rf, ok := ret.Get(0).(func(string, *msgpb.MsgPosition, mqwrapper.SubscriptionInitialPosition) <-chan *msgstream.MsgPack); ok { - r0 = rf(vchannel, pos, subPos) + if rf, ok := ret.Get(0).(func(context.Context, string, *msgpb.MsgPosition, mqwrapper.SubscriptionInitialPosition) <-chan *msgstream.MsgPack); ok { + r0 = rf(ctx, vchannel, pos, subPos) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(<-chan *msgstream.MsgPack) } } - if rf, ok := ret.Get(1).(func(string, *msgpb.MsgPosition, mqwrapper.SubscriptionInitialPosition) error); 
ok { - r1 = rf(vchannel, pos, subPos) + if rf, ok := ret.Get(1).(func(context.Context, string, *msgpb.MsgPosition, mqwrapper.SubscriptionInitialPosition) error); ok { + r1 = rf(ctx, vchannel, pos, subPos) } else { r1 = ret.Error(1) } @@ -120,16 +123,17 @@ type MockClient_Register_Call struct { } // Register is a helper method to define mock.On call +// - ctx context.Context // - vchannel string // - pos *msgpb.MsgPosition // - subPos mqwrapper.SubscriptionInitialPosition -func (_e *MockClient_Expecter) Register(vchannel interface{}, pos interface{}, subPos interface{}) *MockClient_Register_Call { - return &MockClient_Register_Call{Call: _e.mock.On("Register", vchannel, pos, subPos)} +func (_e *MockClient_Expecter) Register(ctx interface{}, vchannel interface{}, pos interface{}, subPos interface{}) *MockClient_Register_Call { + return &MockClient_Register_Call{Call: _e.mock.On("Register", ctx, vchannel, pos, subPos)} } -func (_c *MockClient_Register_Call) Run(run func(vchannel string, pos *msgpb.MsgPosition, subPos mqwrapper.SubscriptionInitialPosition)) *MockClient_Register_Call { +func (_c *MockClient_Register_Call) Run(run func(ctx context.Context, vchannel string, pos *msgpb.MsgPosition, subPos mqwrapper.SubscriptionInitialPosition)) *MockClient_Register_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(*msgpb.MsgPosition), args[2].(mqwrapper.SubscriptionInitialPosition)) + run(args[0].(context.Context), args[1].(string), args[2].(*msgpb.MsgPosition), args[3].(mqwrapper.SubscriptionInitialPosition)) }) return _c } @@ -139,7 +143,7 @@ func (_c *MockClient_Register_Call) Return(_a0 <-chan *msgstream.MsgPack, _a1 er return _c } -func (_c *MockClient_Register_Call) RunAndReturn(run func(string, *msgpb.MsgPosition, mqwrapper.SubscriptionInitialPosition) (<-chan *msgstream.MsgPack, error)) *MockClient_Register_Call { +func (_c *MockClient_Register_Call) RunAndReturn(run func(context.Context, string, *msgpb.MsgPosition, 
mqwrapper.SubscriptionInitialPosition) (<-chan *msgstream.MsgPack, error)) *MockClient_Register_Call { _c.Call.Return(run) return _c } @@ -149,7 +153,8 @@ func (_c *MockClient_Register_Call) RunAndReturn(run func(string, *msgpb.MsgPosi func NewMockClient(t interface { mock.TestingT Cleanup(func()) -}) *MockClient { +}, +) *MockClient { mock := &MockClient{} mock.Mock.Test(t) diff --git a/pkg/mq/msgdispatcher/mock_test.go b/pkg/mq/msgdispatcher/mock_test.go index e8035c9c0d16d..b1685cf0c3db7 100644 --- a/pkg/mq/msgdispatcher/mock_test.go +++ b/pkg/mq/msgdispatcher/mock_test.go @@ -66,7 +66,7 @@ func getSeekPositions(factory msgstream.Factory, pchannel string, maxNum int) ([ return nil, err } defer stream.Close() - stream.AsConsumer([]string{pchannel}, fmt.Sprintf("%d", rand.Int()), mqwrapper.SubscriptionPositionEarliest) + stream.AsConsumer(context.TODO(), []string{pchannel}, fmt.Sprintf("%d", rand.Int()), mqwrapper.SubscriptionPositionEarliest) positions := make([]*msgstream.MsgPosition, 0) timeoutCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -199,7 +199,6 @@ func defaultInsertRepackFunc( tsMsgs []msgstream.TsMsg, hashKeys [][]int32, ) (map[int32]*msgstream.MsgPack, error) { - if len(hashKeys) < len(tsMsgs) { return nil, fmt.Errorf( "the length of hash keys (%d) is less than the length of messages (%d)", diff --git a/pkg/mq/msgstream/common_mq_factory.go b/pkg/mq/msgstream/common_mq_factory.go index eefdbd323a7f4..1e301f3943719 100644 --- a/pkg/mq/msgstream/common_mq_factory.go +++ b/pkg/mq/msgstream/common_mq_factory.go @@ -4,6 +4,7 @@ import ( "context" "github.com/cockroachdb/errors" + "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" ) @@ -14,7 +15,7 @@ var _ Factory = &CommonFactory{} // It contains a function field named newer, which is a function that creates // an mqwrapper.Client when called. 
type CommonFactory struct { - Newer func() (mqwrapper.Client, error) // client constructor + Newer func(context.Context) (mqwrapper.Client, error) // client constructor DispatcherFactory ProtoUDFactory ReceiveBufSize int64 MQBufSize int64 @@ -23,7 +24,7 @@ type CommonFactory struct { // NewMsgStream is used to generate a new Msgstream object func (f *CommonFactory) NewMsgStream(ctx context.Context) (ms MsgStream, err error) { defer wrapError(&err, "NewMsgStream") - cli, err := f.Newer() + cli, err := f.Newer(ctx) if err != nil { return nil, err } @@ -33,7 +34,7 @@ func (f *CommonFactory) NewMsgStream(ctx context.Context) (ms MsgStream, err err // NewTtMsgStream is used to generate a new TtMsgstream object func (f *CommonFactory) NewTtMsgStream(ctx context.Context) (ms MsgStream, err error) { defer wrapError(&err, "NewTtMsgStream") - cli, err := f.Newer() + cli, err := f.Newer(ctx) if err != nil { return nil, err } @@ -50,7 +51,7 @@ func (f *CommonFactory) NewMsgStreamDisposer(ctx context.Context) func([]string, if err != nil { return err } - msgs.AsConsumer(channels, subName, mqwrapper.SubscriptionPositionUnknown) + msgs.AsConsumer(ctx, channels, subName, mqwrapper.SubscriptionPositionUnknown) msgs.Close() return nil } diff --git a/pkg/mq/msgstream/factory_stream_test.go b/pkg/mq/msgstream/factory_stream_test.go index 8e6ec850df35e..cb7ff8702cd08 100644 --- a/pkg/mq/msgstream/factory_stream_test.go +++ b/pkg/mq/msgstream/factory_stream_test.go @@ -8,12 +8,13 @@ import ( "runtime" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) type streamNewer func(ctx context.Context) (MsgStream, error) @@ -764,8 +765,8 @@ func consume(ctx 
context.Context, mq MsgStream) *MsgPack { func createAndSeekConsumer(ctx context.Context, t *testing.T, newer streamNewer, channels []string, seekPositions []*msgpb.MsgPosition) MsgStream { consumer, err := newer(ctx) assert.NoError(t, err) - consumer.AsConsumer(channels, funcutil.RandomString(8), mqwrapper.SubscriptionPositionUnknown) - err = consumer.Seek(seekPositions) + consumer.AsConsumer(context.Background(), channels, funcutil.RandomString(8), mqwrapper.SubscriptionPositionUnknown) + err = consumer.Seek(context.Background(), seekPositions) assert.NoError(t, err) return consumer } @@ -780,14 +781,14 @@ func createProducer(ctx context.Context, t *testing.T, newer streamNewer, channe func createConsumer(ctx context.Context, t *testing.T, newer streamNewer, channels []string) MsgStream { consumer, err := newer(ctx) assert.NoError(t, err) - consumer.AsConsumer(channels, funcutil.RandomString(8), mqwrapper.SubscriptionPositionEarliest) + consumer.AsConsumer(context.Background(), channels, funcutil.RandomString(8), mqwrapper.SubscriptionPositionEarliest) return consumer } func createLatestConsumer(ctx context.Context, t *testing.T, newer streamNewer, channels []string) MsgStream { consumer, err := newer(ctx) assert.NoError(t, err) - consumer.AsConsumer(channels, funcutil.RandomString(8), mqwrapper.SubscriptionPositionLatest) + consumer.AsConsumer(context.Background(), channels, funcutil.RandomString(8), mqwrapper.SubscriptionPositionLatest) return consumer } @@ -801,7 +802,7 @@ func createStream(ctx context.Context, t *testing.T, newer []streamNewer, channe consumer, err := newer[1](ctx) assert.NoError(t, err) - consumer.AsConsumer(channels, funcutil.RandomString(8), mqwrapper.SubscriptionPositionEarliest) + consumer.AsConsumer(context.Background(), channels, funcutil.RandomString(8), mqwrapper.SubscriptionPositionEarliest) return producer, consumer } diff --git a/pkg/mq/msgstream/factory_test.go b/pkg/mq/msgstream/factory_test.go index 7c0ca90b6b01a..0f6661cc978ca 
100644 --- a/pkg/mq/msgstream/factory_test.go +++ b/pkg/mq/msgstream/factory_test.go @@ -45,7 +45,7 @@ func TestNmq(t *testing.T) { f1 := NewNatsmqFactory() f2 := NewNatsmqFactory() - client, err := nmq.NewClientWithDefaultOptions() + client, err := nmq.NewClientWithDefaultOptions(context.Background()) if err != nil { panic(err) } diff --git a/pkg/mq/msgstream/mock_mq_factory.go b/pkg/mq/msgstream/mock_mq_factory.go index d457b3707a6c4..5b253e9aa84a3 100644 --- a/pkg/mq/msgstream/mock_mq_factory.go +++ b/pkg/mq/msgstream/mock_mq_factory.go @@ -14,3 +14,7 @@ func NewMockMqFactory() *MockMqFactory { func (m MockMqFactory) NewMsgStream(ctx context.Context) (MsgStream, error) { return m.NewMsgStreamFunc(ctx) } + +func (m MockMqFactory) NewTtMsgStream(ctx context.Context) (MsgStream, error) { + return m.NewMsgStreamFunc(ctx) +} diff --git a/pkg/mq/msgstream/mock_msgstream.go b/pkg/mq/msgstream/mock_msgstream.go index 42363d0b1073f..18be8faa46aee 100644 --- a/pkg/mq/msgstream/mock_msgstream.go +++ b/pkg/mq/msgstream/mock_msgstream.go @@ -3,9 +3,12 @@ package msgstream import ( - msgpb "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + context "context" + mqwrapper "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" mock "github.com/stretchr/testify/mock" + + msgpb "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" ) // MockMsgStream is an autogenerated mock type for the MsgStream type @@ -21,9 +24,18 @@ func (_m *MockMsgStream) EXPECT() *MockMsgStream_Expecter { return &MockMsgStream_Expecter{mock: &_m.Mock} } -// AsConsumer provides a mock function with given fields: channels, subName, position -func (_m *MockMsgStream) AsConsumer(channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) { - _m.Called(channels, subName, position) +// AsConsumer provides a mock function with given fields: ctx, channels, subName, position +func (_m *MockMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position 
mqwrapper.SubscriptionInitialPosition) error { + ret := _m.Called(ctx, channels, subName, position) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, []string, string, mqwrapper.SubscriptionInitialPosition) error); ok { + r0 = rf(ctx, channels, subName, position) + } else { + r0 = ret.Error(0) + } + + return r0 } // MockMsgStream_AsConsumer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AsConsumer' @@ -32,26 +44,27 @@ type MockMsgStream_AsConsumer_Call struct { } // AsConsumer is a helper method to define mock.On call +// - ctx context.Context // - channels []string // - subName string // - position mqwrapper.SubscriptionInitialPosition -func (_e *MockMsgStream_Expecter) AsConsumer(channels interface{}, subName interface{}, position interface{}) *MockMsgStream_AsConsumer_Call { - return &MockMsgStream_AsConsumer_Call{Call: _e.mock.On("AsConsumer", channels, subName, position)} +func (_e *MockMsgStream_Expecter) AsConsumer(ctx interface{}, channels interface{}, subName interface{}, position interface{}) *MockMsgStream_AsConsumer_Call { + return &MockMsgStream_AsConsumer_Call{Call: _e.mock.On("AsConsumer", ctx, channels, subName, position)} } -func (_c *MockMsgStream_AsConsumer_Call) Run(run func(channels []string, subName string, position mqwrapper.SubscriptionInitialPosition)) *MockMsgStream_AsConsumer_Call { +func (_c *MockMsgStream_AsConsumer_Call) Run(run func(ctx context.Context, channels []string, subName string, position mqwrapper.SubscriptionInitialPosition)) *MockMsgStream_AsConsumer_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].([]string), args[1].(string), args[2].(mqwrapper.SubscriptionInitialPosition)) + run(args[0].(context.Context), args[1].([]string), args[2].(string), args[3].(mqwrapper.SubscriptionInitialPosition)) }) return _c } -func (_c *MockMsgStream_AsConsumer_Call) Return() *MockMsgStream_AsConsumer_Call { - _c.Call.Return() +func (_c *MockMsgStream_AsConsumer_Call) 
Return(_a0 error) *MockMsgStream_AsConsumer_Call { + _c.Call.Return(_a0) return _c } -func (_c *MockMsgStream_AsConsumer_Call) RunAndReturn(run func([]string, string, mqwrapper.SubscriptionInitialPosition)) *MockMsgStream_AsConsumer_Call { +func (_c *MockMsgStream_AsConsumer_Call) RunAndReturn(run func(context.Context, []string, string, mqwrapper.SubscriptionInitialPosition) error) *MockMsgStream_AsConsumer_Call { _c.Call.Return(run) return _c } @@ -260,6 +273,39 @@ func (_c *MockMsgStream_Close_Call) RunAndReturn(run func()) *MockMsgStream_Clos return _c } +// EnableProduce provides a mock function with given fields: can +func (_m *MockMsgStream) EnableProduce(can bool) { + _m.Called(can) +} + +// MockMsgStream_EnableProduce_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'EnableProduce' +type MockMsgStream_EnableProduce_Call struct { + *mock.Call +} + +// EnableProduce is a helper method to define mock.On call +// - can bool +func (_e *MockMsgStream_Expecter) EnableProduce(can interface{}) *MockMsgStream_EnableProduce_Call { + return &MockMsgStream_EnableProduce_Call{Call: _e.mock.On("EnableProduce", can)} +} + +func (_c *MockMsgStream_EnableProduce_Call) Run(run func(can bool)) *MockMsgStream_EnableProduce_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(bool)) + }) + return _c +} + +func (_c *MockMsgStream_EnableProduce_Call) Return() *MockMsgStream_EnableProduce_Call { + _c.Call.Return() + return _c +} + +func (_c *MockMsgStream_EnableProduce_Call) RunAndReturn(run func(bool)) *MockMsgStream_EnableProduce_Call { + _c.Call.Return(run) + return _c +} + // GetLatestMsgID provides a mock function with given fields: channel func (_m *MockMsgStream) GetLatestMsgID(channel string) (mqwrapper.MessageID, error) { ret := _m.Called(channel) @@ -399,13 +445,13 @@ func (_c *MockMsgStream_Produce_Call) RunAndReturn(run func(*MsgPack) error) *Mo return _c } -// Seek provides a mock function with given fields: offset 
-func (_m *MockMsgStream) Seek(offset []*msgpb.MsgPosition) error { - ret := _m.Called(offset) +// Seek provides a mock function with given fields: ctx, offset +func (_m *MockMsgStream) Seek(ctx context.Context, offset []*msgpb.MsgPosition) error { + ret := _m.Called(ctx, offset) var r0 error - if rf, ok := ret.Get(0).(func([]*msgpb.MsgPosition) error); ok { - r0 = rf(offset) + if rf, ok := ret.Get(0).(func(context.Context, []*msgpb.MsgPosition) error); ok { + r0 = rf(ctx, offset) } else { r0 = ret.Error(0) } @@ -419,14 +465,15 @@ type MockMsgStream_Seek_Call struct { } // Seek is a helper method to define mock.On call +// - ctx context.Context // - offset []*msgpb.MsgPosition -func (_e *MockMsgStream_Expecter) Seek(offset interface{}) *MockMsgStream_Seek_Call { - return &MockMsgStream_Seek_Call{Call: _e.mock.On("Seek", offset)} +func (_e *MockMsgStream_Expecter) Seek(ctx interface{}, offset interface{}) *MockMsgStream_Seek_Call { + return &MockMsgStream_Seek_Call{Call: _e.mock.On("Seek", ctx, offset)} } -func (_c *MockMsgStream_Seek_Call) Run(run func(offset []*msgpb.MsgPosition)) *MockMsgStream_Seek_Call { +func (_c *MockMsgStream_Seek_Call) Run(run func(ctx context.Context, offset []*msgpb.MsgPosition)) *MockMsgStream_Seek_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].([]*msgpb.MsgPosition)) + run(args[0].(context.Context), args[1].([]*msgpb.MsgPosition)) }) return _c } @@ -436,7 +483,7 @@ func (_c *MockMsgStream_Seek_Call) Return(_a0 error) *MockMsgStream_Seek_Call { return _c } -func (_c *MockMsgStream_Seek_Call) RunAndReturn(run func([]*msgpb.MsgPosition) error) *MockMsgStream_Seek_Call { +func (_c *MockMsgStream_Seek_Call) RunAndReturn(run func(context.Context, []*msgpb.MsgPosition) error) *MockMsgStream_Seek_Call { _c.Call.Return(run) return _c } @@ -479,7 +526,8 @@ func (_c *MockMsgStream_SetRepackFunc_Call) RunAndReturn(run func(RepackFunc)) * func NewMockMsgStream(t interface { mock.TestingT Cleanup(func()) -}) *MockMsgStream { +}, +) 
*MockMsgStream { mock := &MockMsgStream{} mock.Mock.Test(t) diff --git a/pkg/mq/msgstream/mock_msgstream_factory.go b/pkg/mq/msgstream/mock_msgstream_factory.go new file mode 100644 index 0000000000000..1d0c6ff129ed2 --- /dev/null +++ b/pkg/mq/msgstream/mock_msgstream_factory.go @@ -0,0 +1,188 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package msgstream + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" +) + +// MockFactory is an autogenerated mock type for the Factory type +type MockFactory struct { + mock.Mock +} + +type MockFactory_Expecter struct { + mock *mock.Mock +} + +func (_m *MockFactory) EXPECT() *MockFactory_Expecter { + return &MockFactory_Expecter{mock: &_m.Mock} +} + +// NewMsgStream provides a mock function with given fields: ctx +func (_m *MockFactory) NewMsgStream(ctx context.Context) (MsgStream, error) { + ret := _m.Called(ctx) + + var r0 MsgStream + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (MsgStream, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) MsgStream); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(MsgStream) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockFactory_NewMsgStream_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'NewMsgStream' +type MockFactory_NewMsgStream_Call struct { + *mock.Call +} + +// NewMsgStream is a helper method to define mock.On call +// - ctx context.Context +func (_e *MockFactory_Expecter) NewMsgStream(ctx interface{}) *MockFactory_NewMsgStream_Call { + return &MockFactory_NewMsgStream_Call{Call: _e.mock.On("NewMsgStream", ctx)} +} + +func (_c *MockFactory_NewMsgStream_Call) Run(run func(ctx context.Context)) *MockFactory_NewMsgStream_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func 
(_c *MockFactory_NewMsgStream_Call) Return(_a0 MsgStream, _a1 error) *MockFactory_NewMsgStream_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockFactory_NewMsgStream_Call) RunAndReturn(run func(context.Context) (MsgStream, error)) *MockFactory_NewMsgStream_Call { + _c.Call.Return(run) + return _c +} + +// NewMsgStreamDisposer provides a mock function with given fields: ctx +func (_m *MockFactory) NewMsgStreamDisposer(ctx context.Context) func([]string, string) error { + ret := _m.Called(ctx) + + var r0 func([]string, string) error + if rf, ok := ret.Get(0).(func(context.Context) func([]string, string) error); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(func([]string, string) error) + } + } + + return r0 +} + +// MockFactory_NewMsgStreamDisposer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'NewMsgStreamDisposer' +type MockFactory_NewMsgStreamDisposer_Call struct { + *mock.Call +} + +// NewMsgStreamDisposer is a helper method to define mock.On call +// - ctx context.Context +func (_e *MockFactory_Expecter) NewMsgStreamDisposer(ctx interface{}) *MockFactory_NewMsgStreamDisposer_Call { + return &MockFactory_NewMsgStreamDisposer_Call{Call: _e.mock.On("NewMsgStreamDisposer", ctx)} +} + +func (_c *MockFactory_NewMsgStreamDisposer_Call) Run(run func(ctx context.Context)) *MockFactory_NewMsgStreamDisposer_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockFactory_NewMsgStreamDisposer_Call) Return(_a0 func([]string, string) error) *MockFactory_NewMsgStreamDisposer_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockFactory_NewMsgStreamDisposer_Call) RunAndReturn(run func(context.Context) func([]string, string) error) *MockFactory_NewMsgStreamDisposer_Call { + _c.Call.Return(run) + return _c +} + +// NewTtMsgStream provides a mock function with given fields: ctx +func (_m *MockFactory) NewTtMsgStream(ctx 
context.Context) (MsgStream, error) { + ret := _m.Called(ctx) + + var r0 MsgStream + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (MsgStream, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) MsgStream); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(MsgStream) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockFactory_NewTtMsgStream_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'NewTtMsgStream' +type MockFactory_NewTtMsgStream_Call struct { + *mock.Call +} + +// NewTtMsgStream is a helper method to define mock.On call +// - ctx context.Context +func (_e *MockFactory_Expecter) NewTtMsgStream(ctx interface{}) *MockFactory_NewTtMsgStream_Call { + return &MockFactory_NewTtMsgStream_Call{Call: _e.mock.On("NewTtMsgStream", ctx)} +} + +func (_c *MockFactory_NewTtMsgStream_Call) Run(run func(ctx context.Context)) *MockFactory_NewTtMsgStream_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockFactory_NewTtMsgStream_Call) Return(_a0 MsgStream, _a1 error) *MockFactory_NewTtMsgStream_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockFactory_NewTtMsgStream_Call) RunAndReturn(run func(context.Context) (MsgStream, error)) *MockFactory_NewTtMsgStream_Call { + _c.Call.Return(run) + return _c +} + +// NewMockFactory creates a new instance of MockFactory. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. 
+func NewMockFactory(t interface { + mock.TestingT + Cleanup(func()) +}) *MockFactory { + mock := &MockFactory{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/mq/msgstream/mq_factory.go b/pkg/mq/msgstream/mq_factory.go index fc5aaf870b018..201d22457c1e5 100644 --- a/pkg/mq/msgstream/mq_factory.go +++ b/pkg/mq/msgstream/mq_factory.go @@ -23,11 +23,13 @@ import ( "github.com/apache/pulsar-client-go/pulsar" "github.com/cockroachdb/errors" + "github.com/prometheus/client_golang/prometheus" "github.com/streamnative/pulsarctl/pkg/cli" "github.com/streamnative/pulsarctl/pkg/pulsar/utils" "go.uber.org/zap" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" kafkawrapper "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper/kafka" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper/nmq" @@ -49,11 +51,12 @@ type PmsFactory struct { PulsarTenant string PulsarNameSpace string RequestTimeout time.Duration + metricRegisterer prometheus.Registerer } func NewPmsFactory(serviceParam *paramtable.ServiceParam) *PmsFactory { config := &serviceParam.PulsarCfg - return &PmsFactory{ + f := &PmsFactory{ MQBufSize: serviceParam.MQCfg.MQBufSize.GetAsInt64(), ReceiveBufSize: serviceParam.MQCfg.ReceiveBufSize.GetAsInt64(), PulsarAddress: config.Address.GetValue(), @@ -64,18 +67,33 @@ func NewPmsFactory(serviceParam *paramtable.ServiceParam) *PmsFactory { PulsarNameSpace: config.Namespace.GetValue(), RequestTimeout: config.RequestTimeout.GetAsDuration(time.Second), } + if config.EnableClientMetrics.GetAsBool() { + // Enable client metrics if config.EnableClientMetrics is true, use pkg-defined registerer. 
+ f.metricRegisterer = metrics.GetRegisterer() + } + return f } // NewMsgStream is used to generate a new Msgstream object func (f *PmsFactory) NewMsgStream(ctx context.Context) (MsgStream, error) { + var timeout time.Duration = f.RequestTimeout + + if deadline, ok := ctx.Deadline(); ok { + if deadline.Before(time.Now()) { + return nil, errors.New("context timeout when NewMsgStream") + } + timeout = time.Until(deadline) + } + auth, err := f.getAuthentication() if err != nil { return nil, err } clientOpts := pulsar.ClientOptions{ - URL: f.PulsarAddress, - Authentication: auth, - OperationTimeout: f.RequestTimeout, + URL: f.PulsarAddress, + Authentication: auth, + OperationTimeout: timeout, + MetricsRegisterer: f.metricRegisterer, } pulsarClient, err := pulsarmqwrapper.NewClient(f.PulsarTenant, f.PulsarNameSpace, clientOpts) @@ -87,13 +105,22 @@ func (f *PmsFactory) NewMsgStream(ctx context.Context) (MsgStream, error) { // NewTtMsgStream is used to generate a new TtMsgstream object func (f *PmsFactory) NewTtMsgStream(ctx context.Context) (MsgStream, error) { + var timeout time.Duration = f.RequestTimeout + if deadline, ok := ctx.Deadline(); ok { + if deadline.Before(time.Now()) { + return nil, errors.New("context timeout when NewTtMsgStream") + } + timeout = time.Until(deadline) + } auth, err := f.getAuthentication() if err != nil { return nil, err } clientOpts := pulsar.ClientOptions{ - URL: f.PulsarAddress, - Authentication: auth, + URL: f.PulsarAddress, + Authentication: auth, + OperationTimeout: timeout, + MetricsRegisterer: f.metricRegisterer, } pulsarClient, err := pulsarmqwrapper.NewClient(f.PulsarTenant, f.PulsarNameSpace, clientOpts) @@ -156,12 +183,18 @@ type KmsFactory struct { } func (f *KmsFactory) NewMsgStream(ctx context.Context) (MsgStream, error) { - kafkaClient := kafkawrapper.NewKafkaClientInstanceWithConfig(f.config) + kafkaClient, err := kafkawrapper.NewKafkaClientInstanceWithConfig(ctx, f.config) + if err != nil { + return nil, err + } return 
NewMqMsgStream(ctx, f.ReceiveBufSize, f.MQBufSize, kafkaClient, f.dispatcherFactory.NewUnmarshalDispatcher()) } func (f *KmsFactory) NewTtMsgStream(ctx context.Context) (MsgStream, error) { - kafkaClient := kafkawrapper.NewKafkaClientInstanceWithConfig(f.config) + kafkaClient, err := kafkawrapper.NewKafkaClientInstanceWithConfig(ctx, f.config) + if err != nil { + return nil, err + } return NewMqTtMsgStream(ctx, f.ReceiveBufSize, f.MQBufSize, kafkaClient, f.dispatcherFactory.NewUnmarshalDispatcher()) } @@ -171,7 +204,7 @@ func (f *KmsFactory) NewMsgStreamDisposer(ctx context.Context) func([]string, st if err != nil { return err } - msgstream.AsConsumer(channels, subname, mqwrapper.SubscriptionPositionUnknown) + msgstream.AsConsumer(ctx, channels, subname, mqwrapper.SubscriptionPositionUnknown) msgstream.Close() return nil } diff --git a/pkg/mq/msgstream/mq_factory_test.go b/pkg/mq/msgstream/mq_factory_test.go index 5ae1738e3a0b2..1f1d161335d6f 100644 --- a/pkg/mq/msgstream/mq_factory_test.go +++ b/pkg/mq/msgstream/mq_factory_test.go @@ -19,6 +19,7 @@ package msgstream import ( "context" "testing" + "time" "github.com/stretchr/testify/assert" ) @@ -26,15 +27,51 @@ import ( func TestPmsFactory(t *testing.T) { pmsFactory := NewPmsFactory(&Params.ServiceParam) - ctx := context.Background() - _, err := pmsFactory.NewMsgStream(ctx) - assert.NoError(t, err) - - _, err = pmsFactory.NewTtMsgStream(ctx) + err := pmsFactory.NewMsgStreamDisposer(context.Background())([]string{"hello"}, "xx") assert.NoError(t, err) - err = pmsFactory.NewMsgStreamDisposer(ctx)([]string{"hello"}, "xx") - assert.NoError(t, err) + tests := []struct { + description string + withTimeout bool + ctxTimeouted bool + expectedError bool + }{ + {"normal ctx", false, false, false}, + {"timeout ctx not timeout", true, false, false}, + {"timeout ctx timeout", true, true, true}, + } + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + var cancel context.CancelFunc + ctx := 
context.Background() + if test.withTimeout { + ctx, cancel = context.WithTimeout(ctx, time.Millisecond) + defer cancel() + } + + if test.ctxTimeouted { + time.Sleep(time.Millisecond) + } + stream, err := pmsFactory.NewMsgStream(ctx) + if test.expectedError { + assert.Error(t, err) + assert.Nil(t, stream) + } else { + assert.NoError(t, err) + assert.NotNil(t, stream) + } + + ttStream, err := pmsFactory.NewTtMsgStream(ctx) + if test.expectedError { + assert.Error(t, err) + assert.Nil(t, ttStream) + } else { + assert.NoError(t, err) + assert.NotNil(t, ttStream) + } + }) + } } func TestPmsFactoryWithAuth(t *testing.T) { @@ -63,19 +100,52 @@ func TestPmsFactoryWithAuth(t *testing.T) { _, err = pmsFactory.NewTtMsgStream(ctx) assert.Error(t, err) - } func TestKafkaFactory(t *testing.T) { kmsFactory := NewKmsFactory(&Params.ServiceParam) - ctx := context.Background() - _, err := kmsFactory.NewMsgStream(ctx) - assert.NoError(t, err) - - _, err = kmsFactory.NewTtMsgStream(ctx) - assert.NoError(t, err) - - // err = kmsFactory.NewMsgStreamDisposer(ctx)([]string{"hello"}, "xx") - // assert.NoError(t, err) + tests := []struct { + description string + withTimeout bool + ctxTimeouted bool + expectedError bool + }{ + {"normal ctx", false, false, false}, + {"timeout ctx not timeout", true, false, false}, + {"timeout ctx timeout", true, true, true}, + } + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + var cancel context.CancelFunc + ctx := context.Background() + timeoutDur := time.Millisecond * 30 + if test.withTimeout { + ctx, cancel = context.WithTimeout(ctx, timeoutDur) + defer cancel() + } + + if test.ctxTimeouted { + time.Sleep(timeoutDur) + } + stream, err := kmsFactory.NewMsgStream(ctx) + if test.expectedError { + assert.Error(t, err) + assert.Nil(t, stream) + } else { + assert.NoError(t, err) + assert.NotNil(t, stream) + } + + ttStream, err := kmsFactory.NewTtMsgStream(ctx) + if test.expectedError { + assert.Error(t, err) + assert.Nil(t, 
ttStream) + } else { + assert.NoError(t, err) + assert.NotNil(t, ttStream) + } + }) + } } diff --git a/pkg/mq/msgstream/mq_kafka_msgstream_test.go b/pkg/mq/msgstream/mq_kafka_msgstream_test.go index ca1da97ad89c0..468d4e054a96f 100644 --- a/pkg/mq/msgstream/mq_kafka_msgstream_test.go +++ b/pkg/mq/msgstream/mq_kafka_msgstream_test.go @@ -23,10 +23,10 @@ import ( "testing" "github.com/confluentinc/confluent-kafka-go/kafka" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" kafkawrapper "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper/kafka" "github.com/milvus-io/milvus/pkg/util/funcutil" @@ -145,7 +145,7 @@ func TestStream_KafkaMsgStream_SeekToLast(t *testing.T) { defer outputStream2.Close() assert.NoError(t, err) - err = outputStream2.Seek([]*msgpb.MsgPosition{seekPosition}) + err = outputStream2.Seek(ctx, []*msgpb.MsgPosition{seekPosition}) assert.NoError(t, err) cnt := 0 @@ -392,7 +392,6 @@ func TestStream_KafkaTtMsgStream_2(t *testing.T) { cnt1 := (len(msgPacks1)/2 - 1) * len(msgPacks1[0].Msgs) cnt2 := (len(msgPacks2)/2 - 1) * len(msgPacks2[0].Msgs) assert.Equal(t, (cnt1 + cnt2), msgCount) - } func TestStream_KafkaTtMsgStream_DataNodeTimetickMsgstream(t *testing.T) { @@ -408,7 +407,7 @@ func TestStream_KafkaTtMsgStream_DataNodeTimetickMsgstream(t *testing.T) { factory := ProtoUDFactory{} kafkaClient := kafkawrapper.NewKafkaClientInstance(kafkaAddress) outputStream, _ := NewMqTtMsgStream(ctx, 100, 100, kafkaClient, factory.NewUnmarshalDispatcher()) - outputStream.AsConsumer(consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionLatest) + outputStream.AsConsumer(context.Background(), consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionLatest) var wg sync.WaitGroup wg.Add(1) 
@@ -462,7 +461,7 @@ func getKafkaOutputStream(ctx context.Context, kafkaAddress string, consumerChan factory := ProtoUDFactory{} kafkaClient := kafkawrapper.NewKafkaClientInstance(kafkaAddress) outputStream, _ := NewMqMsgStream(ctx, 100, 100, kafkaClient, factory.NewUnmarshalDispatcher()) - outputStream.AsConsumer(consumerChannels, consumerSubName, position) + outputStream.AsConsumer(context.Background(), consumerChannels, consumerSubName, position) return outputStream } @@ -470,7 +469,7 @@ func getKafkaTtOutputStream(ctx context.Context, kafkaAddress string, consumerCh factory := ProtoUDFactory{} kafkaClient := kafkawrapper.NewKafkaClientInstance(kafkaAddress) outputStream, _ := NewMqTtMsgStream(ctx, 100, 100, kafkaClient, factory.NewUnmarshalDispatcher()) - outputStream.AsConsumer(consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest) + outputStream.AsConsumer(context.Background(), consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest) return outputStream } @@ -482,7 +481,7 @@ func getKafkaTtOutputStreamAndSeek(ctx context.Context, kafkaAddress string, pos for _, c := range positions { consumerName = append(consumerName, c.ChannelName) } - outputStream.AsConsumer(consumerName, funcutil.RandomString(8), mqwrapper.SubscriptionPositionUnknown) - outputStream.Seek(positions) + outputStream.AsConsumer(context.Background(), consumerName, funcutil.RandomString(8), mqwrapper.SubscriptionPositionUnknown) + outputStream.Seek(context.Background(), positions) return outputStream } diff --git a/pkg/mq/msgstream/mq_msgstream.go b/pkg/mq/msgstream/mq_msgstream.go index 85f931d0c1993..81aa95c06631b 100644 --- a/pkg/mq/msgstream/mq_msgstream.go +++ b/pkg/mq/msgstream/mq_msgstream.go @@ -20,19 +20,22 @@ import ( "context" "fmt" "path/filepath" + "strconv" "sync" "sync/atomic" "time" "github.com/cockroachdb/errors" "github.com/golang/protobuf/proto" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - 
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/samber/lo" "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus/pkg/config" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/retry" "github.com/milvus-io/milvus/pkg/util/tsoutil" @@ -49,16 +52,17 @@ type mqMsgStream struct { consumers map[string]mqwrapper.Consumer consumerChannels []string - repackFunc RepackFunc - unmarshal UnmarshalDispatcher - receiveBuf chan *MsgPack - closeRWMutex *sync.RWMutex - streamCancel func() - bufSize int64 - producerLock *sync.Mutex - consumerLock *sync.Mutex - closed int32 - onceChan sync.Once + repackFunc RepackFunc + unmarshal UnmarshalDispatcher + receiveBuf chan *MsgPack + closeRWMutex *sync.RWMutex + streamCancel func() + bufSize int64 + producerLock *sync.RWMutex + consumerLock *sync.Mutex + closed int32 + onceChan sync.Once + enableProduce atomic.Value } // NewMqMsgStream is used to generate a new mqMsgStream object @@ -66,8 +70,8 @@ func NewMqMsgStream(ctx context.Context, receiveBufSize int64, bufSize int64, client mqwrapper.Client, - unmarshal UnmarshalDispatcher) (*mqMsgStream, error) { - + unmarshal UnmarshalDispatcher, +) (*mqMsgStream, error) { streamCtx, streamCancel := context.WithCancel(ctx) producers := make(map[string]mqwrapper.Producer) consumers := make(map[string]mqwrapper.Consumer) @@ -87,11 +91,23 @@ func NewMqMsgStream(ctx context.Context, bufSize: bufSize, receiveBuf: receiveBuf, streamCancel: streamCancel, - producerLock: &sync.Mutex{}, + producerLock: &sync.RWMutex{}, consumerLock: &sync.Mutex{}, closeRWMutex: &sync.RWMutex{}, closed: 0, } + ctxLog := log.Ctx(ctx) + stream.enableProduce.Store(paramtable.Get().CommonCfg.TTMsgEnabled.GetAsBool()) + 
paramtable.Get().Watch(paramtable.Get().CommonCfg.TTMsgEnabled.Key, config.NewHandler("enable send tt msg", func(event *config.Event) { + value, err := strconv.ParseBool(event.Value) + if err != nil { + ctxLog.Warn("Failed to parse bool value", zap.String("v", event.Value), zap.Error(err)) + return + } + stream.enableProduce.Store(value) + ctxLog.Info("Msg Stream state updated", zap.Bool("can_produce", stream.isEnabledProduce())) + })) + ctxLog.Info("Msg Stream state", zap.Bool("can_produce", stream.isEnabledProduce())) return stream, nil } @@ -146,7 +162,7 @@ func (ms *mqMsgStream) CheckTopicValid(channel string) error { // AsConsumerWithPosition Create consumer to receive message from channels, with initial position // if initial position is set to latest, last message in the channel is exclusive -func (ms *mqMsgStream) AsConsumer(channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) { +func (ms *mqMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) error { for _, channel := range channels { if _, ok := ms.consumers[channel]; ok { continue @@ -171,14 +187,19 @@ func (ms *mqMsgStream) AsConsumer(channels []string, subName string, position mq ms.consumerChannels = append(ms.consumerChannels, channel) return nil } - // TODO if know the former subscribe is invalid, should we use pulsarctl to accelerate recovery speed - err := retry.Do(context.TODO(), fn, retry.Attempts(50), retry.Sleep(time.Millisecond*200), retry.MaxSleepTime(5*time.Second)) + + err := retry.Do(ctx, fn, retry.Attempts(20), retry.Sleep(time.Millisecond*200), retry.MaxSleepTime(5*time.Second)) if err != nil { - errMsg := "Failed to create consumer " + channel + ", error = " + err.Error() - panic(errMsg) + errMsg := fmt.Sprintf("Failed to create consumer %s", channel) + if merr.IsCanceledOrTimeout(err) { + return errors.Wrapf(err, errMsg) + } + + panic(fmt.Sprintf("%s, errors = %s", errMsg, 
err.Error())) } log.Info("Successfully create consumer", zap.String("channel", channel), zap.String("subname", subName)) } + return nil } func (ms *mqMsgStream) SetRepackFunc(repackFunc RepackFunc) { @@ -208,7 +229,6 @@ func (ms *mqMsgStream) Close() { ms.client.Close() close(ms.receiveBuf) - } func (ms *mqMsgStream) ComputeProduceChannelIndexes(tsMsgs []TsMsg) [][]int32 { @@ -236,7 +256,19 @@ func (ms *mqMsgStream) GetProduceChannels() []string { return ms.producerChannels } +func (ms *mqMsgStream) EnableProduce(can bool) { + ms.enableProduce.Store(can) +} + +func (ms *mqMsgStream) isEnabledProduce() bool { + return ms.enableProduce.Load().(bool) +} + func (ms *mqMsgStream) Produce(msgPack *MsgPack) error { + if !ms.isEnabledProduce() { + log.Warn("can't produce the msg in the backup instance", zap.Stack("stack")) + return merr.ErrDenyProduceMsg + } if msgPack == nil || len(msgPack.Msgs) <= 0 { log.Debug("Warning: Receive empty msgPack") return nil @@ -283,13 +315,13 @@ func (ms *mqMsgStream) Produce(msgPack *MsgPack) error { msg := &mqwrapper.ProducerMessage{Payload: m, Properties: map[string]string{}} InjectCtx(spanCtx, msg.Properties) - ms.producerLock.Lock() + ms.producerLock.RLock() if _, err := ms.producers[channel].Send(spanCtx, msg); err != nil { - ms.producerLock.Unlock() + ms.producerLock.RUnlock() sp.RecordError(err) return err } - ms.producerLock.Unlock() + ms.producerLock.RUnlock() } } return nil @@ -302,6 +334,14 @@ func (ms *mqMsgStream) Broadcast(msgPack *MsgPack) (map[string][]MessageID, erro if msgPack == nil || len(msgPack.Msgs) <= 0 { return ids, errors.New("empty msgs") } + // Only allow to create collection msg in backup instance + // However, there may be a problem of ts disorder here, but because the start position of the collection only uses offsets, not time, there is no problem for the time being + isCreateCollectionMsg := len(msgPack.Msgs) == 1 && msgPack.Msgs[0].Type() == commonpb.MsgType_CreateCollection + + if !ms.isEnabledProduce() 
&& !isCreateCollectionMsg { + log.Warn("can't broadcast the msg in the backup instance", zap.Stack("stack")) + return ids, merr.ErrDenyProduceMsg + } for _, v := range msgPack.Msgs { spanCtx, sp := MsgSpanFromCtx(v.TraceCtx(), v) @@ -352,7 +392,6 @@ func (ms *mqMsgStream) getTsMsgFromConsumerMsg(msg mqwrapper.Message) (TsMsg, er return nil, fmt.Errorf("failed to unmarshal tsMsg, err %s", err.Error()) } - // set msg info to tsMsg tsMsg.SetPosition(&MsgPosition{ ChannelName: filepath.Base(msg.Topic()), MsgID: msg.ID().Serialize(), @@ -381,9 +420,11 @@ func (ms *mqMsgStream) receiveMsg(consumer mqwrapper.Consumer) { log.Warn("MqMsgStream get msg whose payload is nil") continue } + // not need to check the preCreatedTopic is empty, related issue: https://github.com/milvus-io/milvus/issues/27295 + // if the message not belong to the topic, will skip it tsMsg, err := ms.getTsMsgFromConsumerMsg(msg) if err != nil { - log.Error("Failed to getTsMsgFromConsumerMsg", zap.Error(err)) + log.Warn("Failed to getTsMsgFromConsumerMsg", zap.Error(err)) continue } pos := tsMsg.Position() @@ -425,7 +466,7 @@ func (ms *mqMsgStream) Chan() <-chan *MsgPack { // Seek reset the subscription associated with this consumer to a specific position, the seek position is exclusive // User has to ensure mq_msgstream is not closed before seek, and the seek position is already written. 
-func (ms *mqMsgStream) Seek(msgPositions []*msgpb.MsgPosition) error { +func (ms *mqMsgStream) Seek(ctx context.Context, msgPositions []*msgpb.MsgPosition) error { for _, mp := range msgPositions { consumer, ok := ms.consumers[mp.ChannelName] if !ok { @@ -468,7 +509,8 @@ func NewMqTtMsgStream(ctx context.Context, receiveBufSize int64, bufSize int64, client mqwrapper.Client, - unmarshal UnmarshalDispatcher) (*MqTtMsgStream, error) { + unmarshal UnmarshalDispatcher, +) (*MqTtMsgStream, error) { msgStream, err := NewMqMsgStream(ctx, receiveBufSize, bufSize, client, unmarshal) if err != nil { return nil, err @@ -509,7 +551,7 @@ func (ms *MqTtMsgStream) addConsumer(consumer mqwrapper.Consumer, channel string } // AsConsumerWithPosition subscribes channels as consumer for a MsgStream and seeks to a certain position. -func (ms *MqTtMsgStream) AsConsumer(channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) { +func (ms *MqTtMsgStream) AsConsumer(ctx context.Context, channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) error { for _, channel := range channels { if _, ok := ms.consumers[channel]; ok { continue @@ -533,12 +575,19 @@ func (ms *MqTtMsgStream) AsConsumer(channels []string, subName string, position ms.addConsumer(pc, channel) return nil } - err := retry.Do(context.TODO(), fn, retry.Attempts(20), retry.Sleep(time.Millisecond*200), retry.MaxSleepTime(5*time.Second)) + + err := retry.Do(ctx, fn, retry.Attempts(20), retry.Sleep(time.Millisecond*200), retry.MaxSleepTime(5*time.Second)) if err != nil { - errMsg := "Failed to create consumer " + channel + ", error = " + err.Error() - panic(errMsg) + errMsg := fmt.Sprintf("Failed to create consumer %s", channel) + if merr.IsCanceledOrTimeout(err) { + return errors.Wrapf(err, errMsg) + } + + panic(fmt.Sprintf("%s, errors = %s", errMsg, err.Error())) } } + + return nil } // Close will stop goroutine and free internal producers and consumers @@ -726,9 +775,11 @@ 
func (ms *MqTtMsgStream) consumeToTtMsg(consumer mqwrapper.Consumer) { log.Warn("MqTtMsgStream get msg whose payload is nil") continue } + // not need to check the preCreatedTopic is empty, related issue: https://github.com/milvus-io/milvus/issues/27295 + // if the message not belong to the topic, will skip it tsMsg, err := ms.getTsMsgFromConsumerMsg(msg) if err != nil { - log.Error("Failed to getTsMsgFromConsumerMsg", zap.Error(err)) + log.Warn("Failed to getTsMsgFromConsumerMsg", zap.Error(err)) continue } @@ -773,7 +824,7 @@ func (ms *MqTtMsgStream) allChanReachSameTtMsg(chanTtMsgSync map[mqwrapper.Consu } // Seek to the specified position -func (ms *MqTtMsgStream) Seek(msgPositions []*msgpb.MsgPosition) error { +func (ms *MqTtMsgStream) Seek(ctx context.Context, msgPositions []*msgpb.MsgPosition) error { var consumer mqwrapper.Consumer var mp *MsgPosition var err error @@ -797,7 +848,7 @@ func (ms *MqTtMsgStream) Seek(msgPositions []*msgpb.MsgPosition) error { if err != nil { log.Warn("Failed to seek", zap.String("channel", mp.ChannelName), zap.Error(err)) // stop retry if consumer topic not exist - if errors.Is(err, mqwrapper.ErrTopicNotExist) { + if errors.Is(err, merr.ErrMqTopicNotFound) { return retry.Unrecoverable(err) } return err @@ -815,7 +866,7 @@ func (ms *MqTtMsgStream) Seek(msgPositions []*msgpb.MsgPosition) error { if len(mp.MsgID) == 0 { return fmt.Errorf("when msgID's length equal to 0, please use AsConsumer interface") } - err = retry.Do(context.TODO(), fn, retry.Attempts(20), retry.Sleep(time.Millisecond*200), retry.MaxSleepTime(5*time.Second)) + err = retry.Do(ctx, fn, retry.Attempts(20), retry.Sleep(time.Millisecond*200), retry.MaxSleepTime(5*time.Second)) if err != nil { return fmt.Errorf("failed to seek, error %s", err.Error()) } @@ -828,6 +879,8 @@ func (ms *MqTtMsgStream) Seek(msgPositions []*msgpb.MsgPosition) error { select { case <-ms.ctx.Done(): return ms.ctx.Err() + case <-ctx.Done(): + return ctx.Err() case msg, ok := 
<-consumer.Chan(): if !ok { return fmt.Errorf("consumer closed") @@ -845,7 +898,6 @@ func (ms *MqTtMsgStream) Seek(msgPositions []*msgpb.MsgPosition) error { } if tsMsg.Type() == commonpb.MsgType_TimeTick && tsMsg.BeginTs() >= mp.Timestamp { runLoop = false - break } else if tsMsg.BeginTs() > mp.Timestamp { ctx, _ := ExtractCtx(tsMsg, msg.Properties()) tsMsg.SetTraceCtx(ctx) @@ -855,6 +907,8 @@ func (ms *MqTtMsgStream) Seek(msgPositions []*msgpb.MsgPosition) error { MsgID: msg.ID().Serialize(), }) ms.chanMsgBuf[consumer] = append(ms.chanMsgBuf[consumer], tsMsg) + } else { + log.Info("skip msg", zap.Any("msg", tsMsg)) } } } diff --git a/pkg/mq/msgstream/mq_msgstream_test.go b/pkg/mq/msgstream/mq_msgstream_test.go index c5b34a04eecad..2d55265085dfd 100644 --- a/pkg/mq/msgstream/mq_msgstream_test.go +++ b/pkg/mq/msgstream/mq_msgstream_test.go @@ -29,12 +29,12 @@ import ( "github.com/apache/pulsar-client-go/pulsar" "github.com/cockroachdb/errors" "github.com/confluentinc/confluent-kafka-go/kafka" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/atomic" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" pulsarwrapper "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper/pulsar" "github.com/milvus-io/milvus/pkg/util/funcutil" @@ -111,6 +111,13 @@ func TestStream_PulsarMsgStream_Insert(t *testing.T) { inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels) outputStream := getPulsarOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName) + { + inputStream.EnableProduce(false) + err := inputStream.Produce(&msgPack) + require.Error(t, err) + } + + inputStream.EnableProduce(true) err := inputStream.Produce(&msgPack) require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) @@ 
-177,6 +184,13 @@ func TestStream_PulsarMsgStream_BroadCast(t *testing.T) { inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels) outputStream := getPulsarOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName) + { + inputStream.EnableProduce(false) + _, err := inputStream.Broadcast(&msgPack) + require.Error(t, err) + } + + inputStream.EnableProduce(true) _, err := inputStream.Broadcast(&msgPack) require.NoErrorf(t, err, fmt.Sprintf("broadcast error = %v", err)) @@ -250,7 +264,7 @@ func TestStream_PulsarMsgStream_InsertRepackFunc(t *testing.T) { pulsarClient2, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress}) outputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient2, factory.NewUnmarshalDispatcher()) - outputStream.AsConsumer(consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest) + outputStream.AsConsumer(ctx, consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest) var output MsgStream = outputStream err := (*inputStream).Produce(&msgPack) @@ -301,7 +315,7 @@ func TestStream_PulsarMsgStream_DeleteRepackFunc(t *testing.T) { pulsarClient2, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress}) outputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient2, factory.NewUnmarshalDispatcher()) - outputStream.AsConsumer(consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest) + outputStream.AsConsumer(ctx, consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest) var output MsgStream = outputStream err := (*inputStream).Produce(&msgPack) @@ -333,7 +347,7 @@ func TestStream_PulsarMsgStream_DefaultRepackFunc(t *testing.T) { pulsarClient2, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress}) outputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient2, 
factory.NewUnmarshalDispatcher()) - outputStream.AsConsumer(consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest) + outputStream.AsConsumer(ctx, consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest) var output MsgStream = outputStream err := (*inputStream).Produce(&msgPack) @@ -482,12 +496,12 @@ func TestStream_PulsarMsgStream_SeekToLast(t *testing.T) { factory := ProtoUDFactory{} pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress}) outputStream2, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher()) - outputStream2.AsConsumer(consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest) + outputStream2.AsConsumer(ctx, consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest) lastMsgID, err := outputStream2.GetLatestMsgID(c) defer outputStream2.Close() assert.NoError(t, err) - err = outputStream2.Seek([]*msgpb.MsgPosition{seekPosition}) + err = outputStream2.Seek(ctx, []*msgpb.MsgPosition{seekPosition}) assert.NoError(t, err) cnt := 0 @@ -521,8 +535,34 @@ func TestStream_PulsarMsgStream_SeekToLast(t *testing.T) { assert.Equal(t, 4, cnt) } +func TestStream_MsgStream_AsConsumerCtxDone(t *testing.T) { + pulsarAddress := getPulsarAddress() + + t.Run("MsgStream AsConsumer with timeout context", func(t *testing.T) { + c1 := funcutil.RandomString(8) + consumerChannels := []string{c1} + consumerSubName := funcutil.RandomString(8) + + ctx := context.Background() + factory := ProtoUDFactory{} + pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress}) + outputStream, _ := NewMqTtMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher()) + + ctx, cancel := context.WithTimeout(ctx, time.Millisecond) + defer cancel() + <-time.After(2 * time.Millisecond) + err := outputStream.AsConsumer(ctx, consumerChannels, 
consumerSubName, mqwrapper.SubscriptionPositionEarliest) + assert.Error(t, err) + + omsgstream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher()) + err = omsgstream.AsConsumer(ctx, consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest) + assert.Error(t, err) + }) +} + func TestStream_PulsarTtMsgStream_Seek(t *testing.T) { pulsarAddress := getPulsarAddress() + c1 := funcutil.RandomString(8) producerChannels := []string{c1} consumerChannels := []string{c1} @@ -889,8 +929,8 @@ func TestStream_MqMsgStream_Seek(t *testing.T) { factory := ProtoUDFactory{} pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress}) outputStream2, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher()) - outputStream2.AsConsumer(consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest) - outputStream2.Seek([]*msgpb.MsgPosition{seekPosition}) + outputStream2.AsConsumer(ctx, consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest) + outputStream2.Seek(ctx, []*msgpb.MsgPosition{seekPosition}) for i := 6; i < 10; i++ { result := consumer(ctx, outputStream2) @@ -930,7 +970,7 @@ func TestStream_MqMsgStream_SeekInvalidMessage(t *testing.T) { factory := ProtoUDFactory{} pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress}) outputStream2, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher()) - outputStream2.AsConsumer(consumerChannels, funcutil.RandomString(8), mqwrapper.SubscriptionPositionEarliest) + outputStream2.AsConsumer(ctx, consumerChannels, funcutil.RandomString(8), mqwrapper.SubscriptionPositionEarliest) defer outputStream2.Close() messageID, _ := pulsar.DeserializeMessageID(seekPosition.MsgID) // try to seek to not written position @@ -945,7 +985,7 @@ func TestStream_MqMsgStream_SeekInvalidMessage(t 
*testing.T) { }, } - err = outputStream2.Seek(p) + err = outputStream2.Seek(ctx, p) assert.NoError(t, err) for i := 10; i < 20; i++ { @@ -979,7 +1019,7 @@ func TestStream_MqMsgStream_SeekLatest(t *testing.T) { factory := ProtoUDFactory{} pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress}) outputStream2, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher()) - outputStream2.AsConsumer(consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionLatest) + outputStream2.AsConsumer(ctx, consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionLatest) msgPack.Msgs = nil // produce another 10 tsMs @@ -1321,7 +1361,7 @@ func getPulsarOutputStream(ctx context.Context, pulsarAddress string, consumerCh factory := ProtoUDFactory{} pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress}) outputStream, _ := NewMqMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher()) - outputStream.AsConsumer(consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest) + outputStream.AsConsumer(context.Background(), consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest) return outputStream } @@ -1329,7 +1369,7 @@ func getPulsarTtOutputStream(ctx context.Context, pulsarAddress string, consumer factory := ProtoUDFactory{} pulsarClient, _ := pulsarwrapper.NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress}) outputStream, _ := NewMqTtMsgStream(ctx, 100, 100, pulsarClient, factory.NewUnmarshalDispatcher()) - outputStream.AsConsumer(consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest) + outputStream.AsConsumer(context.Background(), consumerChannels, consumerSubName, mqwrapper.SubscriptionPositionEarliest) return outputStream } @@ -1341,8 +1381,8 @@ func getPulsarTtOutputStreamAndSeek(ctx 
context.Context, pulsarAddress string, p for _, c := range positions { consumerName = append(consumerName, c.ChannelName) } - outputStream.AsConsumer(consumerName, funcutil.RandomString(8), mqwrapper.SubscriptionPositionUnknown) - outputStream.Seek(positions) + outputStream.AsConsumer(context.Background(), consumerName, funcutil.RandomString(8), mqwrapper.SubscriptionPositionUnknown) + outputStream.Seek(context.Background(), positions) return outputStream } diff --git a/pkg/mq/msgstream/mqwrapper/consumer.go b/pkg/mq/msgstream/mqwrapper/consumer.go index 4046f8d75361d..f8b49e40601f6 100644 --- a/pkg/mq/msgstream/mqwrapper/consumer.go +++ b/pkg/mq/msgstream/mqwrapper/consumer.go @@ -59,7 +59,7 @@ type Consumer interface { // Get Message channel, once you chan you can not seek again Chan() <-chan Message - // Seek to the uniqueID position + // Seek to the uniqueID position, the second bool param indicates whether the message is included in the position Seek(MessageID, bool) error //nolint:govet // Ack make sure that msg is received diff --git a/pkg/mq/msgstream/mqwrapper/errors.go b/pkg/mq/msgstream/mqwrapper/errors.go deleted file mode 100644 index 9683d9f297bd9..0000000000000 --- a/pkg/mq/msgstream/mqwrapper/errors.go +++ /dev/null @@ -1,6 +0,0 @@ -package mqwrapper - -import "github.com/cockroachdb/errors" - -// ErrTopicNotExist topic not exist error. 
-var ErrTopicNotExist = errors.New("topic not exist") diff --git a/pkg/mq/msgstream/mqwrapper/kafka/kafka_client.go b/pkg/mq/msgstream/mqwrapper/kafka/kafka_client.go index e12ac1f5881a8..d8e45bdcbb0f3 100644 --- a/pkg/mq/msgstream/mqwrapper/kafka/kafka_client.go +++ b/pkg/mq/msgstream/mqwrapper/kafka/kafka_client.go @@ -1,10 +1,13 @@ package kafka import ( + "context" "fmt" "strconv" "sync" + "time" + "github.com/cockroachdb/errors" "github.com/confluentinc/confluent-kafka-go/kafka" "go.uber.org/atomic" "go.uber.org/zap" @@ -43,7 +46,6 @@ func getBasicConfig(address string) kafka.ConfigMap { func NewKafkaClientInstance(address string) *kafkaClient { config := getBasicConfig(address) return NewKafkaClientInstanceWithConfigMap(config, kafka.ConfigMap{}, kafka.ConfigMap{}) - } func NewKafkaClientInstanceWithConfigMap(config kafka.ConfigMap, extraConsumerConfig kafka.ConfigMap, extraProducerConfig kafka.ConfigMap) *kafkaClient { @@ -54,9 +56,18 @@ func NewKafkaClientInstanceWithConfigMap(config kafka.ConfigMap, extraConsumerCo return &kafkaClient{basicConfig: config, consumerConfig: extraConsumerConfig, producerConfig: extraProducerConfig} } -func NewKafkaClientInstanceWithConfig(config *paramtable.KafkaConfig) *kafkaClient { +func NewKafkaClientInstanceWithConfig(ctx context.Context, config *paramtable.KafkaConfig) (*kafkaClient, error) { kafkaConfig := getBasicConfig(config.Address.GetValue()) + // connection setup timeout, default as 30000ms + if deadline, ok := ctx.Deadline(); ok { + if deadline.Before(time.Now()) { + return nil, errors.New("context timeout when new kafka client") + } + timeout := time.Until(deadline).Milliseconds() + kafkaConfig.SetKey("socket.connection.setup.timeout.ms", timeout) + } + if (config.SaslUsername.GetValue() == "" && config.SaslPassword.GetValue() != "") || (config.SaslUsername.GetValue() != "" && config.SaslPassword.GetValue() == "") { panic("enable security mode need config username and password at the same time!") @@ -77,8 +88,10 
@@ func NewKafkaClientInstanceWithConfig(config *paramtable.KafkaConfig) *kafkaClie return kafkaConfigMap } - return NewKafkaClientInstanceWithConfigMap(kafkaConfig, specExtraConfig(config.ConsumerExtraConfig.GetValue()), specExtraConfig(config.ProducerExtraConfig.GetValue())) - + return NewKafkaClientInstanceWithConfigMap( + kafkaConfig, + specExtraConfig(config.ConsumerExtraConfig.GetValue()), + specExtraConfig(config.ProducerExtraConfig.GetValue())), nil } func cloneKafkaConfig(config kafka.ConfigMap) *kafka.ConfigMap { @@ -137,7 +150,7 @@ func (kc *kafkaClient) newProducerConfig() *kafka.ConfigMap { // we want to ensure tt send out as soon as possible newConf.SetKey("linger.ms", 2) - //special producer config + // special producer config kc.specialExtraConfig(newConf, kc.producerConfig) return newConf @@ -148,9 +161,9 @@ func (kc *kafkaClient) newConsumerConfig(group string, offset mqwrapper.Subscrip newConf.SetKey("group.id", group) newConf.SetKey("enable.auto.commit", false) - //Kafka default will not create topics if consumer's the topics don't exist. - //In order to compatible with other MQ, we need to enable the following configuration, - //meanwhile, some implementation also try to consume a non-exist topic, such as dataCoordTimeTick. + // Kafka default will not create topics if consumer's the topics don't exist. + // In order to compatible with other MQ, we need to enable the following configuration, + // meanwhile, some implementation also try to consume a non-exist topic, such as dataCoordTimeTick. 
newConf.SetKey("allow.auto.create.topics", true) kc.specialExtraConfig(newConf, kc.consumerConfig) diff --git a/pkg/mq/msgstream/mqwrapper/kafka/kafka_client_test.go b/pkg/mq/msgstream/mqwrapper/kafka/kafka_client_test.go index a006e574d1d96..b33d425a2e1e7 100644 --- a/pkg/mq/msgstream/mqwrapper/kafka/kafka_client_test.go +++ b/pkg/mq/msgstream/mqwrapper/kafka/kafka_client_test.go @@ -35,6 +35,7 @@ func TestMain(m *testing.M) { broker := mockCluster.BootstrapServers() Params.Save("kafka.brokerList", broker) + log.Info("start testing kafka broker", zap.String("address", broker)) exitCode := m.Run() os.Exit(exitCode) @@ -364,10 +365,10 @@ func createKafkaConfig(opts ...kafkaCfgOption) *paramtable.KafkaConfig { func TestKafkaClient_NewKafkaClientInstanceWithConfig(t *testing.T) { config1 := createKafkaConfig(withAddr("addr"), withPasswd("password")) - assert.Panics(t, func() { NewKafkaClientInstanceWithConfig(config1) }) + assert.Panics(t, func() { NewKafkaClientInstanceWithConfig(context.Background(), config1) }) config2 := createKafkaConfig(withAddr("addr"), withUsername("username")) - assert.Panics(t, func() { NewKafkaClientInstanceWithConfig(config2) }) + assert.Panics(t, func() { NewKafkaClientInstanceWithConfig(context.Background(), config2) }) producerConfig := make(map[string]string) producerConfig["client.id"] = "dc1" @@ -378,7 +379,8 @@ func TestKafkaClient_NewKafkaClientInstanceWithConfig(t *testing.T) { config.ConsumerExtraConfig = paramtable.ParamGroup{GetFunc: func() map[string]string { return consumerConfig }} config.ProducerExtraConfig = paramtable.ParamGroup{GetFunc: func() map[string]string { return producerConfig }} - client := NewKafkaClientInstanceWithConfig(config) + client, err := NewKafkaClientInstanceWithConfig(context.Background(), config) + assert.NoError(t, err) assert.NotNil(t, client) assert.NotNil(t, client.basicConfig) @@ -406,7 +408,8 @@ func createConsumer(t *testing.T, kc *kafkaClient, topic string, groupID string, - initPosition 
mqwrapper.SubscriptionInitialPosition) mqwrapper.Consumer { + initPosition mqwrapper.SubscriptionInitialPosition, +) mqwrapper.Consumer { consumer, err := kc.Subscribe(mqwrapper.ConsumerOptions{ Topic: topic, SubscriptionName: groupID, diff --git a/pkg/mq/msgstream/mqwrapper/kafka/kafka_consumer.go b/pkg/mq/msgstream/mqwrapper/kafka/kafka_consumer.go index 2b497922995e0..f1b0b9d125e0c 100644 --- a/pkg/mq/msgstream/mqwrapper/kafka/kafka_consumer.go +++ b/pkg/mq/msgstream/mqwrapper/kafka/kafka_consumer.go @@ -11,6 +11,7 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) type Consumer struct { @@ -125,22 +126,13 @@ func (kc *Consumer) Chan() <-chan mqwrapper.Message { for { select { case <-kc.closeCh: - log.Info("close consumer ", zap.String("topic", kc.topic), zap.String("groupID", kc.groupID)) - start := time.Now() - err := kc.c.Close() - if err != nil { - log.Warn("failed to close ", zap.String("topic", kc.topic), zap.Error(err)) - } - cost := time.Since(start).Milliseconds() - if cost > 200 { - log.Warn("close consumer costs too long time", zap.Any("topic", kc.topic), zap.String("groupID", kc.groupID), zap.Int64("time(ms)", cost)) - } if kc.msgChannel != nil { close(kc.msgChannel) } return default: - e, err := kc.c.ReadMessage(30 * time.Second) + readTimeout := paramtable.Get().KafkaCfg.ReadTimeout.GetAsDuration(time.Second) + e, err := kc.c.ReadMessage(readTimeout) if err != nil { // if we failed to read message in 30 Seconds, print out a warn message since there should always be a tt log.Warn("consume msg failed", zap.Any("topic", kc.topic), zap.String("groupID", kc.groupID), zap.Error(err)) @@ -195,7 +187,8 @@ func (kc *Consumer) internalSeek(offset kafka.Offset, inclusive bool) error { if err := kc.c.Seek(kafka.TopicPartition{ Topic: &kc.topic, Partition: mqwrapper.DefaultPartitionIdx, - Offset: offset}, 
timeout); err != nil { + Offset: offset, + }, timeout); err != nil { return err } cost = time.Since(start).Milliseconds() @@ -229,32 +222,43 @@ func (kc *Consumer) GetLatestMsgID() (mqwrapper.MessageID, error) { } func (kc *Consumer) CheckTopicValid(topic string) error { - latestMsgID, err := kc.GetLatestMsgID() + _, err := kc.GetLatestMsgID() log.With(zap.String("topic", kc.topic)) // check topic is existed if err != nil { switch v := err.(type) { case kafka.Error: - if v.Code() == kafka.ErrUnknownTopic || v.Code() == kafka.ErrUnknownPartition || v.Code() == kafka.ErrUnknownTopicOrPart { - return merr.WrapErrTopicNotFound(topic, "topic get latest msg ID failed, topic or partition does not exists") + if v.Code() == kafka.ErrUnknownTopic || v.Code() == kafka.ErrUnknownTopicOrPart { + return merr.WrapErrMqTopicNotFound(topic, err.Error()) } + return merr.WrapErrMqInternal(err) default: return err } } - // check topic is empty - if !latestMsgID.AtEarliestPosition() { - return merr.WrapErrTopicNotEmpty(topic, "topic is not empty") - } - log.Info("created topic is empty") - return nil } +func (kc *Consumer) closeInternal() { + log.Info("close consumer ", zap.String("topic", kc.topic), zap.String("groupID", kc.groupID)) + start := time.Now() + err := kc.c.Close() + if err != nil { + log.Warn("failed to close ", zap.String("topic", kc.topic), zap.Error(err)) + } + cost := time.Since(start).Milliseconds() + if cost > 200 { + log.Warn("close consumer costs too long time", zap.Any("topic", kc.topic), zap.String("groupID", kc.groupID), zap.Int64("time(ms)", cost)) + } +} + func (kc *Consumer) Close() { kc.closeOnce.Do(func() { close(kc.closeCh) - kc.wg.Wait() // wait worker exist and close the client + // wait work goroutine exit + kc.wg.Wait() + // close the client + kc.closeInternal() }) } diff --git a/pkg/mq/msgstream/mqwrapper/kafka/kafka_consumer_test.go b/pkg/mq/msgstream/mqwrapper/kafka/kafka_consumer_test.go index 5078c541e3c0d..43efe783addd2 100644 --- 
a/pkg/mq/msgstream/mqwrapper/kafka/kafka_consumer_test.go +++ b/pkg/mq/msgstream/mqwrapper/kafka/kafka_consumer_test.go @@ -113,8 +113,7 @@ func TestKafkaConsumer_ChanWithNoAssign(t *testing.T) { }) } -type mockMsgID struct { -} +type mockMsgID struct{} func (m2 mockMsgID) AtEarliestPosition() bool { return false @@ -269,3 +268,31 @@ func TestKafkaConsumer_CheckPreTopicValid(t *testing.T) { err = consumer.CheckTopicValid(topic) assert.NoError(t, err) } + +func TestKafkaConsumer_Close(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + topic := fmt.Sprintf("test-topicName-%d", rand.Int()) + + data1 := []int{111, 222, 333} + data2 := []string{"111", "222", "333"} + testKafkaConsumerProduceData(t, topic, data1, data2) + + t.Run("close after only get latest msgID", func(t *testing.T) { + groupID := fmt.Sprintf("test-groupid-%d", rand.Int()) + config := createConfig(groupID) + consumer, err := newKafkaConsumer(config, 16, topic, groupID, mqwrapper.SubscriptionPositionEarliest) + assert.NoError(t, err) + _, err = consumer.GetLatestMsgID() + assert.NoError(t, err) + consumer.Close() + }) + + t.Run("close after only Chan method is invoked", func(t *testing.T) { + groupID := fmt.Sprintf("test-groupid-%d", rand.Int()) + config := createConfig(groupID) + consumer, err := newKafkaConsumer(config, 16, topic, groupID, mqwrapper.SubscriptionPositionEarliest) + assert.NoError(t, err) + <-consumer.Chan() + consumer.Close() + }) +} diff --git a/pkg/mq/msgstream/mqwrapper/kafka/kafka_producer.go b/pkg/mq/msgstream/mqwrapper/kafka/kafka_producer.go index e9e1a2dd0193f..f2f0ec4e43cad 100644 --- a/pkg/mq/msgstream/mqwrapper/kafka/kafka_producer.go +++ b/pkg/mq/msgstream/mqwrapper/kafka/kafka_producer.go @@ -7,12 +7,13 @@ import ( "time" "github.com/confluentinc/confluent-kafka-go/kafka" + "go.uber.org/zap" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" 
"github.com/milvus-io/milvus/pkg/util/timerecord" - "go.uber.org/zap" ) type kafkaProducer struct { @@ -47,7 +48,6 @@ func (kp *kafkaProducer) Send(ctx context.Context, message *mqwrapper.ProducerMe Value: message.Payload, Headers: headers, }, kp.deliveryChan) - if err != nil { metrics.MsgStreamOpCounter.WithLabelValues(metrics.SendMsgLabel, metrics.FailLabel).Inc() return nil, err @@ -78,7 +78,7 @@ func (kp *kafkaProducer) Close() { kp.isClosed = true start := time.Now() - //flush in-flight msg within queue. + // flush in-flight msg within queue. i := kp.p.Flush(10000) if i > 0 { log.Warn("There are still un-flushed outstanding events", zap.Int("event_num", i), zap.Any("topic", kp.topic)) diff --git a/pkg/mq/msgstream/mqwrapper/kafka/kafka_producer_test.go b/pkg/mq/msgstream/mqwrapper/kafka/kafka_producer_test.go index 9c80a19d1ffb3..3ddbde026927a 100644 --- a/pkg/mq/msgstream/mqwrapper/kafka/kafka_producer_test.go +++ b/pkg/mq/msgstream/mqwrapper/kafka/kafka_producer_test.go @@ -44,7 +44,6 @@ func TestKafkaProducer_SendSuccess(t *testing.T) { func TestKafkaProducer_SendFail(t *testing.T) { kafkaAddress := getKafkaBrokerList() { - deliveryChan := make(chan kafka.Event, 1) rand.Seed(time.Now().UnixNano()) topic := fmt.Sprintf("test-topic-%d", rand.Int()) diff --git a/pkg/mq/msgstream/mqwrapper/mock_id.go b/pkg/mq/msgstream/mqwrapper/mock_id.go new file mode 100644 index 0000000000000..a8dad1aa9a2e9 --- /dev/null +++ b/pkg/mq/msgstream/mqwrapper/mock_id.go @@ -0,0 +1,220 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. 
+ +package mqwrapper + +import mock "github.com/stretchr/testify/mock" + +// MockMessageID is an autogenerated mock type for the MessageID type +type MockMessageID struct { + mock.Mock +} + +type MockMessageID_Expecter struct { + mock *mock.Mock +} + +func (_m *MockMessageID) EXPECT() *MockMessageID_Expecter { + return &MockMessageID_Expecter{mock: &_m.Mock} +} + +// AtEarliestPosition provides a mock function with given fields: +func (_m *MockMessageID) AtEarliestPosition() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// MockMessageID_AtEarliestPosition_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AtEarliestPosition' +type MockMessageID_AtEarliestPosition_Call struct { + *mock.Call +} + +// AtEarliestPosition is a helper method to define mock.On call +func (_e *MockMessageID_Expecter) AtEarliestPosition() *MockMessageID_AtEarliestPosition_Call { + return &MockMessageID_AtEarliestPosition_Call{Call: _e.mock.On("AtEarliestPosition")} +} + +func (_c *MockMessageID_AtEarliestPosition_Call) Run(run func()) *MockMessageID_AtEarliestPosition_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockMessageID_AtEarliestPosition_Call) Return(_a0 bool) *MockMessageID_AtEarliestPosition_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMessageID_AtEarliestPosition_Call) RunAndReturn(run func() bool) *MockMessageID_AtEarliestPosition_Call { + _c.Call.Return(run) + return _c +} + +// Equal provides a mock function with given fields: msgID +func (_m *MockMessageID) Equal(msgID []byte) (bool, error) { + ret := _m.Called(msgID) + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func([]byte) (bool, error)); ok { + return rf(msgID) + } + if rf, ok := ret.Get(0).(func([]byte) bool); ok { + r0 = rf(msgID) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := 
ret.Get(1).(func([]byte) error); ok { + r1 = rf(msgID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockMessageID_Equal_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Equal' +type MockMessageID_Equal_Call struct { + *mock.Call +} + +// Equal is a helper method to define mock.On call +// - msgID []byte +func (_e *MockMessageID_Expecter) Equal(msgID interface{}) *MockMessageID_Equal_Call { + return &MockMessageID_Equal_Call{Call: _e.mock.On("Equal", msgID)} +} + +func (_c *MockMessageID_Equal_Call) Run(run func(msgID []byte)) *MockMessageID_Equal_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].([]byte)) + }) + return _c +} + +func (_c *MockMessageID_Equal_Call) Return(_a0 bool, _a1 error) *MockMessageID_Equal_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockMessageID_Equal_Call) RunAndReturn(run func([]byte) (bool, error)) *MockMessageID_Equal_Call { + _c.Call.Return(run) + return _c +} + +// LessOrEqualThan provides a mock function with given fields: msgID +func (_m *MockMessageID) LessOrEqualThan(msgID []byte) (bool, error) { + ret := _m.Called(msgID) + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func([]byte) (bool, error)); ok { + return rf(msgID) + } + if rf, ok := ret.Get(0).(func([]byte) bool); ok { + r0 = rf(msgID) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func([]byte) error); ok { + r1 = rf(msgID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockMessageID_LessOrEqualThan_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LessOrEqualThan' +type MockMessageID_LessOrEqualThan_Call struct { + *mock.Call +} + +// LessOrEqualThan is a helper method to define mock.On call +// - msgID []byte +func (_e *MockMessageID_Expecter) LessOrEqualThan(msgID interface{}) *MockMessageID_LessOrEqualThan_Call { + return &MockMessageID_LessOrEqualThan_Call{Call: _e.mock.On("LessOrEqualThan", 
msgID)} +} + +func (_c *MockMessageID_LessOrEqualThan_Call) Run(run func(msgID []byte)) *MockMessageID_LessOrEqualThan_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].([]byte)) + }) + return _c +} + +func (_c *MockMessageID_LessOrEqualThan_Call) Return(_a0 bool, _a1 error) *MockMessageID_LessOrEqualThan_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockMessageID_LessOrEqualThan_Call) RunAndReturn(run func([]byte) (bool, error)) *MockMessageID_LessOrEqualThan_Call { + _c.Call.Return(run) + return _c +} + +// Serialize provides a mock function with given fields: +func (_m *MockMessageID) Serialize() []byte { + ret := _m.Called() + + var r0 []byte + if rf, ok := ret.Get(0).(func() []byte); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + return r0 +} + +// MockMessageID_Serialize_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Serialize' +type MockMessageID_Serialize_Call struct { + *mock.Call +} + +// Serialize is a helper method to define mock.On call +func (_e *MockMessageID_Expecter) Serialize() *MockMessageID_Serialize_Call { + return &MockMessageID_Serialize_Call{Call: _e.mock.On("Serialize")} +} + +func (_c *MockMessageID_Serialize_Call) Run(run func()) *MockMessageID_Serialize_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockMessageID_Serialize_Call) Return(_a0 []byte) *MockMessageID_Serialize_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMessageID_Serialize_Call) RunAndReturn(run func() []byte) *MockMessageID_Serialize_Call { + _c.Call.Return(run) + return _c +} + +// NewMockMessageID creates a new instance of MockMessageID. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. 
+func NewMockMessageID(t interface { + mock.TestingT + Cleanup(func()) +}) *MockMessageID { + mock := &MockMessageID{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/mq/msgstream/mqwrapper/nmq/nmq_client.go b/pkg/mq/msgstream/mqwrapper/nmq/nmq_client.go index 680addf8387b4..774adb5e7fb27 100644 --- a/pkg/mq/msgstream/mqwrapper/nmq/nmq_client.go +++ b/pkg/mq/msgstream/mqwrapper/nmq/nmq_client.go @@ -17,7 +17,9 @@ package nmq import ( + "context" "fmt" + "net" "strconv" "time" @@ -38,11 +40,33 @@ type nmqClient struct { conn *nats.Conn } +type nmqDialer struct { + ctx func() context.Context +} + +func (d *nmqDialer) Dial(network, address string) (net.Conn, error) { + ctx := d.ctx() + + dial := &net.Dialer{} + + // keep default 2s timeout + if _, ok := ctx.Deadline(); !ok { + dial.Timeout = 2 * time.Second + } + + return dial.DialContext(ctx, network, address) +} + // NewClientWithDefaultOptions returns a new NMQ client with default options. // It retrieves the NMQ client URL from the server configuration. 
-func NewClientWithDefaultOptions() (mqwrapper.Client, error) { +func NewClientWithDefaultOptions(ctx context.Context) (mqwrapper.Client, error) { url := Nmq.ClientURL() - return NewClient(url) + + opt := nats.SetCustomDialer(&nmqDialer{ + ctx: func() context.Context { return ctx }, + }) + + return NewClient(url, opt) } // NewClient returns a new nmqClient object diff --git a/pkg/mq/msgstream/mqwrapper/nmq/nmq_client_test.go b/pkg/mq/msgstream/mqwrapper/nmq/nmq_client_test.go index 5d86482c8052a..c32b325e9828a 100644 --- a/pkg/mq/msgstream/mqwrapper/nmq/nmq_client_test.go +++ b/pkg/mq/msgstream/mqwrapper/nmq/nmq_client_test.go @@ -38,6 +38,43 @@ func Test_NewNmqClient(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, client) client.Close() + + tests := []struct { + description string + withTimeout bool + ctxTimeouted bool + expectErr bool + }{ + {"without context", false, false, false}, + {"without timeout context, no timeout", true, false, false}, + {"without timeout context, timeout", true, true, true}, + } + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + ctx := context.Background() + var cancel context.CancelFunc + if test.withTimeout { + ctx, cancel = context.WithTimeout(ctx, time.Second) + if test.ctxTimeouted { + cancel() + } else { + defer cancel() + } + } + + client, err := NewClientWithDefaultOptions(ctx) + + if test.expectErr { + assert.Error(t, err) + assert.Nil(t, client) + } else { + assert.NoError(t, err) + assert.NotNil(t, client) + client.Close() + } + }) + } } func TestNmqClient_CreateProducer(t *testing.T) { diff --git a/pkg/mq/msgstream/mqwrapper/nmq/nmq_consumer.go b/pkg/mq/msgstream/mqwrapper/nmq/nmq_consumer.go index c972638f91733..43c4dcee49df0 100644 --- a/pkg/mq/msgstream/mqwrapper/nmq/nmq_consumer.go +++ b/pkg/mq/msgstream/mqwrapper/nmq/nmq_consumer.go @@ -149,7 +149,6 @@ func (nc *Consumer) GetLatestMsgID() (mqwrapper.MessageID, error) { // CheckTopicValid verifies if the given topic is valid for 
this consumer. // 1. topic should exist. -// 2. topic should be empty. func (nc *Consumer) CheckTopicValid(topic string) error { if err := nc.closed(); err != nil { return err @@ -162,18 +161,13 @@ func (nc *Consumer) CheckTopicValid(topic string) error { } // check if topic valid or exist. - streamInfo, err := nc.js.StreamInfo(topic) + _, err := nc.js.StreamInfo(topic) if errors.Is(err, nats.ErrStreamNotFound) { - return merr.WrapErrTopicNotFound(topic, err.Error()) + return merr.WrapErrMqTopicNotFound(topic, err.Error()) } else if err != nil { log.Warn("fail to get stream info of nats", zap.String("topic", nc.topic), zap.Error(err)) return errors.Wrap(err, "failed to get stream info of nats jetstream") } - - // check if topic stream is empty. - if streamInfo.State.Msgs > 0 { - return merr.WrapErrTopicNotEmpty(topic, "stream in nats is not empty") - } return nil } diff --git a/pkg/mq/msgstream/mqwrapper/nmq/nmq_consumer_test.go b/pkg/mq/msgstream/mqwrapper/nmq/nmq_consumer_test.go index 11f500546bc6c..7425493871243 100644 --- a/pkg/mq/msgstream/mqwrapper/nmq/nmq_consumer_test.go +++ b/pkg/mq/msgstream/mqwrapper/nmq/nmq_consumer_test.go @@ -26,7 +26,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" - "github.com/milvus-io/milvus/pkg/util/merr" ) func TestNatsConsumer_Subscription(t *testing.T) { @@ -219,7 +218,7 @@ func TestCheckTopicValid(t *testing.T) { err = consumer.CheckTopicValid("BadTopic") assert.Error(t, err) - // non empty topic should fail + // not empty topic can pass pub, err := client.CreateProducer(mqwrapper.ProducerOptions{ Topic: topic, }) @@ -230,7 +229,7 @@ func TestCheckTopicValid(t *testing.T) { assert.NoError(t, err) err = consumer.CheckTopicValid(topic) - assert.ErrorIs(t, err, merr.ErrTopicNotEmpty) + assert.NoError(t, err) consumer.Close() err = consumer.CheckTopicValid(topic) diff --git a/pkg/mq/msgstream/mqwrapper/nmq/nmq_id_test.go b/pkg/mq/msgstream/mqwrapper/nmq/nmq_id_test.go 
index 4044031d04ce4..5593499f0095b 100644 --- a/pkg/mq/msgstream/mqwrapper/nmq/nmq_id_test.go +++ b/pkg/mq/msgstream/mqwrapper/nmq/nmq_id_test.go @@ -79,7 +79,6 @@ func Test_Equal(t *testing.T) { ret, err := rid1.Equal(rid1.Serialize()) assert.NoError(t, err) assert.True(t, ret) - } { diff --git a/pkg/mq/msgstream/mqwrapper/nmq/nmq_message.go b/pkg/mq/msgstream/mqwrapper/nmq/nmq_message.go index cdcb2e281afe3..833245aa85862 100644 --- a/pkg/mq/msgstream/mqwrapper/nmq/nmq_message.go +++ b/pkg/mq/msgstream/mqwrapper/nmq/nmq_message.go @@ -19,9 +19,10 @@ package nmq import ( "log" - "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" "github.com/nats-io/nats.go" "go.uber.org/zap" + + "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" ) // Check nmqMessage implements ConsumerMessage diff --git a/pkg/mq/msgstream/mqwrapper/nmq/nmq_producer.go b/pkg/mq/msgstream/mqwrapper/nmq/nmq_producer.go index e6335f8c65b02..26c627e5aa1fb 100644 --- a/pkg/mq/msgstream/mqwrapper/nmq/nmq_producer.go +++ b/pkg/mq/msgstream/mqwrapper/nmq/nmq_producer.go @@ -19,12 +19,13 @@ package nmq import ( "context" + "github.com/nats-io/nats.go" + "go.uber.org/zap" + "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" "github.com/milvus-io/milvus/pkg/util/timerecord" - "github.com/nats-io/nats.go" - "go.uber.org/zap" ) var _ mqwrapper.Producer = (*nmqProducer)(nil) diff --git a/pkg/mq/msgstream/mqwrapper/nmq/nmq_server_test.go b/pkg/mq/msgstream/mqwrapper/nmq/nmq_server_test.go index 53a5b1509d2ed..c65e169919671 100644 --- a/pkg/mq/msgstream/mqwrapper/nmq/nmq_server_test.go +++ b/pkg/mq/msgstream/mqwrapper/nmq/nmq_server_test.go @@ -22,9 +22,10 @@ import ( "testing" "time" - "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/nats-io/nats-server/v2/server" "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/pkg/util/paramtable" ) var natsServerAddress string diff --git 
a/pkg/mq/msgstream/mqwrapper/producer.go b/pkg/mq/msgstream/mqwrapper/producer.go index 270e1dec5149f..caf43688d977b 100644 --- a/pkg/mq/msgstream/mqwrapper/producer.go +++ b/pkg/mq/msgstream/mqwrapper/producer.go @@ -40,7 +40,7 @@ type ProducerMessage struct { // Producer is the interface that provides operations of producer type Producer interface { // return the topic which producer is publishing to - //Topic() string + // Topic() string // publish a message Send(ctx context.Context, message *ProducerMessage) (MessageID, error) diff --git a/pkg/mq/msgstream/mqwrapper/pulsar/pulsar_client.go b/pkg/mq/msgstream/mqwrapper/pulsar/pulsar_client.go index b56a809f40fa5..3d71d7c76e395 100644 --- a/pkg/mq/msgstream/mqwrapper/pulsar/pulsar_client.go +++ b/pkg/mq/msgstream/mqwrapper/pulsar/pulsar_client.go @@ -48,8 +48,10 @@ var once sync.Once // NewClient creates a pulsarClient object // according to the parameter opts of type pulsar.ClientOptions func NewClient(tenant string, namespace string, opts pulsar.ClientOptions) (*pulsarClient, error) { + var err error once.Do(func() { - c, err := pulsar.NewClient(opts) + var c pulsar.Client + c, err = pulsar.NewClient(opts) if err != nil { log.Error("Failed to set pulsar client: ", zap.Error(err)) return @@ -61,7 +63,7 @@ func NewClient(tenant string, namespace string, opts pulsar.ClientOptions) (*pul } sc = cli }) - return sc, nil + return sc, err } // CreateProducer create a pulsar producer from options diff --git a/pkg/mq/msgstream/mqwrapper/pulsar/pulsar_client_test.go b/pkg/mq/msgstream/mqwrapper/pulsar/pulsar_client_test.go index 68d5f48764a61..751532a75017a 100644 --- a/pkg/mq/msgstream/mqwrapper/pulsar/pulsar_client_test.go +++ b/pkg/mq/msgstream/mqwrapper/pulsar/pulsar_client_test.go @@ -135,7 +135,7 @@ func Consume1(ctx context.Context, t *testing.T, pc *pulsarClient, topic string, consumer.Ack(msg) VerifyMessage(t, msg) (*total)++ - //log.Debug("total", zap.Int("val", *total)) + // log.Debug("total", zap.Int("val", 
*total)) } } c <- msg.ID() @@ -174,7 +174,7 @@ func Consume2(ctx context.Context, t *testing.T, pc *pulsarClient, topic string, consumer.Ack(msg) VerifyMessage(t, msg) (*total)++ - //log.Debug("total", zap.Int("val", *total)) + // log.Debug("total", zap.Int("val", *total)) } } } @@ -201,7 +201,7 @@ func Consume3(ctx context.Context, t *testing.T, pc *pulsarClient, topic string, consumer.Ack(msg) VerifyMessage(t, msg) (*total)++ - //log.Debug("total", zap.Int("val", *total)) + // log.Debug("total", zap.Int("val", *total)) } } } @@ -284,7 +284,7 @@ func Consume21(ctx context.Context, t *testing.T, pc *pulsarClient, topic string v := BytesToInt(msg.Payload()) log.Info("RECV", zap.Any("v", v)) (*total)++ - //log.Debug("total", zap.Int("val", *total)) + // log.Debug("total", zap.Int("val", *total)) } } c <- &pulsarID{messageID: msg.ID()} @@ -324,7 +324,7 @@ func Consume22(ctx context.Context, t *testing.T, pc *pulsarClient, topic string v := BytesToInt(msg.Payload()) log.Info("RECV", zap.Any("v", v)) (*total)++ - //log.Debug("total", zap.Int("val", *total)) + // log.Debug("total", zap.Int("val", *total)) } } } @@ -352,7 +352,7 @@ func Consume23(ctx context.Context, t *testing.T, pc *pulsarClient, topic string v := BytesToInt(msg.Payload()) log.Info("RECV", zap.Any("v", v)) (*total)++ - //log.Debug("total", zap.Int("val", *total)) + // log.Debug("total", zap.Int("val", *total)) } } } diff --git a/pkg/mq/msgstream/mqwrapper/pulsar/pulsar_consumer.go b/pkg/mq/msgstream/mqwrapper/pulsar/pulsar_consumer.go index eb6c99d4c7945..9a644a6e36a73 100644 --- a/pkg/mq/msgstream/mqwrapper/pulsar/pulsar_consumer.go +++ b/pkg/mq/msgstream/mqwrapper/pulsar/pulsar_consumer.go @@ -28,7 +28,6 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" - "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/retry" ) @@ -158,16 +157,11 @@ func (pc *Consumer) GetLatestMsgID() (mqwrapper.MessageID, error) { } func (pc 
*Consumer) CheckTopicValid(topic string) error { - latestMsgID, err := pc.GetLatestMsgID() + _, err := pc.GetLatestMsgID() // Pulsar creates that topic under the namespace provided in the topic name automatically if err != nil { return err } - - if !latestMsgID.AtEarliestPosition() { - return merr.WrapErrTopicNotEmpty(topic, "topic is not empty") - } - log.Info("created topic is empty", zap.String("topic", topic)) return nil } diff --git a/pkg/mq/msgstream/msg.go b/pkg/mq/msgstream/msg.go index f05f8684072af..2db757a0f6ed5 100644 --- a/pkg/mq/msgstream/msg.go +++ b/pkg/mq/msgstream/msg.go @@ -23,10 +23,10 @@ import ( "github.com/cockroachdb/errors" "github.com/golang/protobuf/proto" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/typeutil" diff --git a/pkg/mq/msgstream/msg_for_collection.go b/pkg/mq/msgstream/msg_for_collection.go new file mode 100644 index 0000000000000..4411684cce87f --- /dev/null +++ b/pkg/mq/msgstream/msg_for_collection.go @@ -0,0 +1,190 @@ +/* + * Licensed to the LF AI & Data foundation under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package msgstream + +import ( + "github.com/golang/protobuf/proto" + + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" +) + +// LoadCollectionMsg is a message pack that contains load collection request +type LoadCollectionMsg struct { + BaseMsg + milvuspb.LoadCollectionRequest +} + +// interface implementation validation +var _ TsMsg = &LoadCollectionMsg{} + +func (l *LoadCollectionMsg) ID() UniqueID { + return l.Base.MsgID +} + +func (l *LoadCollectionMsg) SetID(id UniqueID) { + l.Base.MsgID = id +} + +func (l *LoadCollectionMsg) Type() MsgType { + return l.Base.MsgType +} + +func (l *LoadCollectionMsg) SourceID() int64 { + return l.Base.SourceID +} + +func (l *LoadCollectionMsg) Marshal(input TsMsg) (MarshalType, error) { + loadCollectionMsg := input.(*LoadCollectionMsg) + createIndexRequest := &loadCollectionMsg.LoadCollectionRequest + mb, err := proto.Marshal(createIndexRequest) + if err != nil { + return nil, err + } + return mb, nil +} + +func (l *LoadCollectionMsg) Unmarshal(input MarshalType) (TsMsg, error) { + loadCollectionRequest := milvuspb.LoadCollectionRequest{} + in, err := convertToByteArray(input) + if err != nil { + return nil, err + } + err = proto.Unmarshal(in, &loadCollectionRequest) + if err != nil { + return nil, err + } + loadCollectionMsg := &LoadCollectionMsg{LoadCollectionRequest: loadCollectionRequest} + loadCollectionMsg.BeginTimestamp = loadCollectionMsg.GetBase().GetTimestamp() + loadCollectionMsg.EndTimestamp = loadCollectionMsg.GetBase().GetTimestamp() + + return loadCollectionMsg, nil +} + +func (l *LoadCollectionMsg) Size() int { + return proto.Size(&l.LoadCollectionRequest) +} + +// ReleaseCollectionMsg is a message pack that contains release collection request +type ReleaseCollectionMsg struct { + BaseMsg + milvuspb.ReleaseCollectionRequest +} + +var _ TsMsg = &ReleaseCollectionMsg{} + +func (r 
*ReleaseCollectionMsg) ID() UniqueID { + return r.Base.MsgID +} + +func (r *ReleaseCollectionMsg) SetID(id UniqueID) { + r.Base.MsgID = id +} + +func (r *ReleaseCollectionMsg) Type() MsgType { + return r.Base.MsgType +} + +func (r *ReleaseCollectionMsg) SourceID() int64 { + return r.Base.SourceID +} + +func (r *ReleaseCollectionMsg) Marshal(input TsMsg) (MarshalType, error) { + releaseCollectionMsg := input.(*ReleaseCollectionMsg) + releaseCollectionRequest := &releaseCollectionMsg.ReleaseCollectionRequest + mb, err := proto.Marshal(releaseCollectionRequest) + if err != nil { + return nil, err + } + return mb, nil +} + +func (r *ReleaseCollectionMsg) Unmarshal(input MarshalType) (TsMsg, error) { + releaseCollectionRequest := milvuspb.ReleaseCollectionRequest{} + in, err := convertToByteArray(input) + if err != nil { + return nil, err + } + err = proto.Unmarshal(in, &releaseCollectionRequest) + if err != nil { + return nil, err + } + releaseCollectionMsg := &ReleaseCollectionMsg{ReleaseCollectionRequest: releaseCollectionRequest} + releaseCollectionMsg.BeginTimestamp = releaseCollectionMsg.GetBase().GetTimestamp() + releaseCollectionMsg.EndTimestamp = releaseCollectionMsg.GetBase().GetTimestamp() + + return releaseCollectionMsg, nil +} + +func (r *ReleaseCollectionMsg) Size() int { + return proto.Size(&r.ReleaseCollectionRequest) +} + +type FlushMsg struct { + BaseMsg + milvuspb.FlushRequest +} + +var _ TsMsg = &FlushMsg{} + +func (f *FlushMsg) ID() UniqueID { + return f.Base.MsgID +} + +func (f *FlushMsg) SetID(id UniqueID) { + f.Base.MsgID = id +} + +func (f *FlushMsg) Type() MsgType { + return f.Base.MsgType +} + +func (f *FlushMsg) SourceID() int64 { + return f.Base.SourceID +} + +func (f *FlushMsg) Marshal(input TsMsg) (MarshalType, error) { + flushMsg := input.(*FlushMsg) + flushRequest := &flushMsg.FlushRequest + mb, err := proto.Marshal(flushRequest) + if err != nil { + return nil, err + } + return mb, nil +} + +func (f *FlushMsg) Unmarshal(input 
MarshalType) (TsMsg, error) { + flushRequest := milvuspb.FlushRequest{} + in, err := convertToByteArray(input) + if err != nil { + return nil, err + } + err = proto.Unmarshal(in, &flushRequest) + if err != nil { + return nil, err + } + flushMsg := &FlushMsg{FlushRequest: flushRequest} + flushMsg.BeginTimestamp = flushMsg.GetBase().GetTimestamp() + flushMsg.EndTimestamp = flushMsg.GetBase().GetTimestamp() + + return flushMsg, nil +} + +func (f *FlushMsg) Size() int { + return proto.Size(&f.FlushRequest) +} diff --git a/pkg/mq/msgstream/msg_for_collection_test.go b/pkg/mq/msgstream/msg_for_collection_test.go new file mode 100644 index 0000000000000..5f9f42a7480d2 --- /dev/null +++ b/pkg/mq/msgstream/msg_for_collection_test.go @@ -0,0 +1,139 @@ +/* + * Licensed to the LF AI & Data foundation under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package msgstream + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" +) + +func TestFlushMsg(t *testing.T) { + var msg TsMsg = &FlushMsg{ + FlushRequest: milvuspb.FlushRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Flush, + MsgID: 100, + Timestamp: 1000, + SourceID: 10000, + TargetID: 100000, + ReplicateInfo: nil, + }, + DbName: "unit_db", + CollectionNames: []string{"col1", "col2"}, + }, + } + assert.EqualValues(t, 100, msg.ID()) + msg.SetID(200) + assert.EqualValues(t, 200, msg.ID()) + assert.Equal(t, commonpb.MsgType_Flush, msg.Type()) + assert.EqualValues(t, 10000, msg.SourceID()) + + msgBytes, err := msg.Marshal(msg) + assert.NoError(t, err) + + var newMsg TsMsg = &FlushMsg{} + _, err = newMsg.Unmarshal("1") + assert.Error(t, err) + + newMsg, err = newMsg.Unmarshal(msgBytes) + assert.NoError(t, err) + assert.EqualValues(t, 200, newMsg.ID()) + assert.EqualValues(t, 1000, newMsg.BeginTs()) + assert.EqualValues(t, 1000, newMsg.EndTs()) + + assert.True(t, msg.Size() > 0) +} + +func TestLoadCollection(t *testing.T) { + var msg TsMsg = &LoadCollectionMsg{ + LoadCollectionRequest: milvuspb.LoadCollectionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_LoadCollection, + MsgID: 100, + Timestamp: 1000, + SourceID: 10000, + TargetID: 100000, + ReplicateInfo: nil, + }, + DbName: "unit_db", + CollectionName: "col1", + }, + } + assert.EqualValues(t, 100, msg.ID()) + msg.SetID(200) + assert.EqualValues(t, 200, msg.ID()) + assert.Equal(t, commonpb.MsgType_LoadCollection, msg.Type()) + assert.EqualValues(t, 10000, msg.SourceID()) + + msgBytes, err := msg.Marshal(msg) + assert.NoError(t, err) + + var newMsg TsMsg = &LoadCollectionMsg{} + _, err = newMsg.Unmarshal("1") + assert.Error(t, err) + + newMsg, err = newMsg.Unmarshal(msgBytes) + assert.NoError(t, err) + assert.EqualValues(t, 200, newMsg.ID()) + 
assert.EqualValues(t, 1000, newMsg.BeginTs()) + assert.EqualValues(t, 1000, newMsg.EndTs()) + + assert.True(t, msg.Size() > 0) +} + +func TestReleaseCollection(t *testing.T) { + var msg TsMsg = &ReleaseCollectionMsg{ + ReleaseCollectionRequest: milvuspb.ReleaseCollectionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_ReleaseCollection, + MsgID: 100, + Timestamp: 1000, + SourceID: 10000, + TargetID: 100000, + ReplicateInfo: nil, + }, + DbName: "unit_db", + CollectionName: "col1", + }, + } + assert.EqualValues(t, 100, msg.ID()) + msg.SetID(200) + assert.EqualValues(t, 200, msg.ID()) + assert.Equal(t, commonpb.MsgType_ReleaseCollection, msg.Type()) + assert.EqualValues(t, 10000, msg.SourceID()) + + msgBytes, err := msg.Marshal(msg) + assert.NoError(t, err) + + var newMsg TsMsg = &ReleaseCollectionMsg{} + _, err = newMsg.Unmarshal("1") + assert.Error(t, err) + + newMsg, err = newMsg.Unmarshal(msgBytes) + assert.NoError(t, err) + assert.EqualValues(t, 200, newMsg.ID()) + assert.EqualValues(t, 1000, newMsg.BeginTs()) + assert.EqualValues(t, 1000, newMsg.EndTs()) + + assert.True(t, msg.Size() > 0) +} diff --git a/pkg/mq/msgstream/msg_for_database.go b/pkg/mq/msgstream/msg_for_database.go new file mode 100644 index 0000000000000..b08bc98e01b49 --- /dev/null +++ b/pkg/mq/msgstream/msg_for_database.go @@ -0,0 +1,133 @@ +/* + * Licensed to the LF AI & Data foundation under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package msgstream + +import ( + "github.com/golang/protobuf/proto" + + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" +) + +type CreateDatabaseMsg struct { + BaseMsg + milvuspb.CreateDatabaseRequest +} + +var _ TsMsg = &CreateDatabaseMsg{} + +func (c *CreateDatabaseMsg) ID() UniqueID { + return c.Base.MsgID +} + +func (c *CreateDatabaseMsg) SetID(id UniqueID) { + c.Base.MsgID = id +} + +func (c *CreateDatabaseMsg) Type() MsgType { + return c.Base.MsgType +} + +func (c *CreateDatabaseMsg) SourceID() int64 { + return c.Base.SourceID +} + +func (c *CreateDatabaseMsg) Marshal(input TsMsg) (MarshalType, error) { + createDataBaseMsg := input.(*CreateDatabaseMsg) + createDatabaseRequest := &createDataBaseMsg.CreateDatabaseRequest + mb, err := proto.Marshal(createDatabaseRequest) + if err != nil { + return nil, err + } + return mb, nil +} + +func (c *CreateDatabaseMsg) Unmarshal(input MarshalType) (TsMsg, error) { + createDatabaseRequest := milvuspb.CreateDatabaseRequest{} + in, err := convertToByteArray(input) + if err != nil { + return nil, err + } + err = proto.Unmarshal(in, &createDatabaseRequest) + if err != nil { + return nil, err + } + createDatabaseMsg := &CreateDatabaseMsg{CreateDatabaseRequest: createDatabaseRequest} + createDatabaseMsg.BeginTimestamp = createDatabaseMsg.GetBase().GetTimestamp() + createDatabaseMsg.EndTimestamp = createDatabaseMsg.GetBase().GetTimestamp() + + return createDatabaseMsg, nil +} + +func (c *CreateDatabaseMsg) Size() int { + return proto.Size(&c.CreateDatabaseRequest) +} + +type DropDatabaseMsg struct { + 
BaseMsg + milvuspb.DropDatabaseRequest +} + +var _ TsMsg = &DropDatabaseMsg{} + +func (d *DropDatabaseMsg) ID() UniqueID { + return d.Base.MsgID +} + +func (d *DropDatabaseMsg) SetID(id UniqueID) { + d.Base.MsgID = id +} + +func (d *DropDatabaseMsg) Type() MsgType { + return d.Base.MsgType +} + +func (d *DropDatabaseMsg) SourceID() int64 { + return d.Base.SourceID +} + +func (d *DropDatabaseMsg) Marshal(input TsMsg) (MarshalType, error) { + dropDataBaseMsg := input.(*DropDatabaseMsg) + dropDatabaseRequest := &dropDataBaseMsg.DropDatabaseRequest + mb, err := proto.Marshal(dropDatabaseRequest) + if err != nil { + return nil, err + } + return mb, nil +} + +func (d *DropDatabaseMsg) Unmarshal(input MarshalType) (TsMsg, error) { + dropDatabaseRequest := milvuspb.DropDatabaseRequest{} + in, err := convertToByteArray(input) + if err != nil { + return nil, err + } + err = proto.Unmarshal(in, &dropDatabaseRequest) + if err != nil { + return nil, err + } + dropDatabaseMsg := &DropDatabaseMsg{DropDatabaseRequest: dropDatabaseRequest} + dropDatabaseMsg.BeginTimestamp = dropDatabaseMsg.GetBase().GetTimestamp() + dropDatabaseMsg.EndTimestamp = dropDatabaseMsg.GetBase().GetTimestamp() + + return dropDatabaseMsg, nil +} + +func (d *DropDatabaseMsg) Size() int { + return proto.Size(&d.DropDatabaseRequest) +} diff --git a/pkg/mq/msgstream/msg_for_database_test.go b/pkg/mq/msgstream/msg_for_database_test.go new file mode 100644 index 0000000000000..d7cfc80eebd21 --- /dev/null +++ b/pkg/mq/msgstream/msg_for_database_test.go @@ -0,0 +1,100 @@ +/* + * Licensed to the LF AI & Data foundation under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package msgstream + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" +) + +func TestCreateDatabase(t *testing.T) { + var msg TsMsg = &CreateDatabaseMsg{ + CreateDatabaseRequest: milvuspb.CreateDatabaseRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_CreateDatabase, + MsgID: 100, + Timestamp: 1000, + SourceID: 10000, + TargetID: 100000, + ReplicateInfo: nil, + }, + DbName: "unit_db", + }, + } + assert.EqualValues(t, 100, msg.ID()) + msg.SetID(200) + assert.EqualValues(t, 200, msg.ID()) + assert.Equal(t, commonpb.MsgType_CreateDatabase, msg.Type()) + assert.EqualValues(t, 10000, msg.SourceID()) + + msgBytes, err := msg.Marshal(msg) + assert.NoError(t, err) + + var newMsg TsMsg = &ReleaseCollectionMsg{} + _, err = newMsg.Unmarshal("1") + assert.Error(t, err) + + newMsg, err = newMsg.Unmarshal(msgBytes) + assert.NoError(t, err) + assert.EqualValues(t, 200, newMsg.ID()) + assert.EqualValues(t, 1000, newMsg.BeginTs()) + assert.EqualValues(t, 1000, newMsg.EndTs()) + + assert.True(t, msg.Size() > 0) +} + +func TestDropDatabase(t *testing.T) { + var msg TsMsg = &DropDatabaseMsg{ + DropDatabaseRequest: milvuspb.DropDatabaseRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_DropDatabase, + MsgID: 100, + Timestamp: 1000, + SourceID: 10000, + TargetID: 100000, + ReplicateInfo: nil, + }, + DbName: "unit_db", + }, + } + assert.EqualValues(t, 100, msg.ID()) + msg.SetID(200) + assert.EqualValues(t, 200, 
msg.ID()) + assert.Equal(t, commonpb.MsgType_DropDatabase, msg.Type()) + assert.EqualValues(t, 10000, msg.SourceID()) + + msgBytes, err := msg.Marshal(msg) + assert.NoError(t, err) + + var newMsg TsMsg = &DropDatabaseMsg{} + _, err = newMsg.Unmarshal("1") + assert.Error(t, err) + + newMsg, err = newMsg.Unmarshal(msgBytes) + assert.NoError(t, err) + assert.EqualValues(t, 200, newMsg.ID()) + assert.EqualValues(t, 1000, newMsg.BeginTs()) + assert.EqualValues(t, 1000, newMsg.EndTs()) + + assert.True(t, msg.Size() > 0) +} diff --git a/pkg/mq/msgstream/msg_for_index.go b/pkg/mq/msgstream/msg_for_index.go new file mode 100644 index 0000000000000..063e008daa74a --- /dev/null +++ b/pkg/mq/msgstream/msg_for_index.go @@ -0,0 +1,142 @@ +/* + * Licensed to the LF AI & Data foundation under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package msgstream + +import ( + "github.com/golang/protobuf/proto" + + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" +) + +// CreateIndexMsg is a message pack that contains create index request +type CreateIndexMsg struct { + BaseMsg + milvuspb.CreateIndexRequest +} + +// interface implementation validation +var _ TsMsg = &CreateIndexMsg{} + +// ID returns the ID of this message pack +func (it *CreateIndexMsg) ID() UniqueID { + return it.Base.MsgID +} + +// SetID set the ID of this message pack +func (it *CreateIndexMsg) SetID(id UniqueID) { + it.Base.MsgID = id +} + +// Type returns the type of this message pack +func (it *CreateIndexMsg) Type() MsgType { + return it.Base.MsgType +} + +// SourceID indicates which component generated this message +func (it *CreateIndexMsg) SourceID() int64 { + return it.Base.SourceID +} + +// Marshal is used to serialize a message pack to byte array +func (it *CreateIndexMsg) Marshal(input TsMsg) (MarshalType, error) { + createIndexMsg := input.(*CreateIndexMsg) + createIndexRequest := &createIndexMsg.CreateIndexRequest + mb, err := proto.Marshal(createIndexRequest) + if err != nil { + return nil, err + } + return mb, nil +} + +// Unmarshal is used to deserialize a message pack from byte array +func (it *CreateIndexMsg) Unmarshal(input MarshalType) (TsMsg, error) { + createIndexRequest := milvuspb.CreateIndexRequest{} + in, err := convertToByteArray(input) + if err != nil { + return nil, err + } + err = proto.Unmarshal(in, &createIndexRequest) + if err != nil { + return nil, err + } + createIndexMsg := &CreateIndexMsg{CreateIndexRequest: createIndexRequest} + createIndexMsg.BeginTimestamp = createIndexMsg.GetBase().GetTimestamp() + createIndexMsg.EndTimestamp = createIndexMsg.GetBase().GetTimestamp() + + return createIndexMsg, nil +} + +func (it *CreateIndexMsg) Size() int { + return proto.Size(&it.CreateIndexRequest) +} + +// DropIndexMsg is a message pack that contains drop index request +type DropIndexMsg struct 
{ + BaseMsg + milvuspb.DropIndexRequest +} + +var _ TsMsg = &DropIndexMsg{} + +func (d *DropIndexMsg) ID() UniqueID { + return d.Base.MsgID +} + +func (d *DropIndexMsg) SetID(id UniqueID) { + d.Base.MsgID = id +} + +func (d *DropIndexMsg) Type() MsgType { + return d.Base.MsgType +} + +func (d *DropIndexMsg) SourceID() int64 { + return d.Base.SourceID +} + +func (d *DropIndexMsg) Marshal(input TsMsg) (MarshalType, error) { + dropIndexMsg := input.(*DropIndexMsg) + dropIndexRequest := &dropIndexMsg.DropIndexRequest + mb, err := proto.Marshal(dropIndexRequest) + if err != nil { + return nil, err + } + return mb, nil +} + +func (d *DropIndexMsg) Unmarshal(input MarshalType) (TsMsg, error) { + dropIndexRequest := milvuspb.DropIndexRequest{} + in, err := convertToByteArray(input) + if err != nil { + return nil, err + } + err = proto.Unmarshal(in, &dropIndexRequest) + if err != nil { + return nil, err + } + dropIndexMsg := &DropIndexMsg{DropIndexRequest: dropIndexRequest} + dropIndexMsg.BeginTimestamp = dropIndexMsg.GetBase().GetTimestamp() + dropIndexMsg.EndTimestamp = dropIndexMsg.GetBase().GetTimestamp() + + return dropIndexMsg, nil +} + +func (d *DropIndexMsg) Size() int { + return proto.Size(&d.DropIndexRequest) +} diff --git a/pkg/mq/msgstream/msg_for_index_test.go b/pkg/mq/msgstream/msg_for_index_test.go new file mode 100644 index 0000000000000..cccc1e09b9bea --- /dev/null +++ b/pkg/mq/msgstream/msg_for_index_test.go @@ -0,0 +1,100 @@ +/* + * Licensed to the LF AI & Data foundation under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package msgstream + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" +) + +func TestCreateIndex(t *testing.T) { + var msg TsMsg = &CreateIndexMsg{ + CreateIndexRequest: milvuspb.CreateIndexRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_CreateIndex, + MsgID: 100, + Timestamp: 1000, + SourceID: 10000, + TargetID: 100000, + ReplicateInfo: nil, + }, + DbName: "unit_db", + }, + } + assert.EqualValues(t, 100, msg.ID()) + msg.SetID(200) + assert.EqualValues(t, 200, msg.ID()) + assert.Equal(t, commonpb.MsgType_CreateIndex, msg.Type()) + assert.EqualValues(t, 10000, msg.SourceID()) + + msgBytes, err := msg.Marshal(msg) + assert.NoError(t, err) + + var newMsg TsMsg = &ReleaseCollectionMsg{} + _, err = newMsg.Unmarshal("1") + assert.Error(t, err) + + newMsg, err = newMsg.Unmarshal(msgBytes) + assert.NoError(t, err) + assert.EqualValues(t, 200, newMsg.ID()) + assert.EqualValues(t, 1000, newMsg.BeginTs()) + assert.EqualValues(t, 1000, newMsg.EndTs()) + + assert.True(t, msg.Size() > 0) +} + +func TestDropIndex(t *testing.T) { + var msg TsMsg = &DropIndexMsg{ + DropIndexRequest: milvuspb.DropIndexRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_DropIndex, + MsgID: 100, + Timestamp: 1000, + SourceID: 10000, + TargetID: 100000, + ReplicateInfo: nil, + }, + DbName: "unit_db", + }, + } + assert.EqualValues(t, 100, msg.ID()) + msg.SetID(200) + assert.EqualValues(t, 200, msg.ID()) + assert.Equal(t, 
commonpb.MsgType_DropIndex, msg.Type()) + assert.EqualValues(t, 10000, msg.SourceID()) + + msgBytes, err := msg.Marshal(msg) + assert.NoError(t, err) + + var newMsg TsMsg = &ReleaseCollectionMsg{} + _, err = newMsg.Unmarshal("1") + assert.Error(t, err) + + newMsg, err = newMsg.Unmarshal(msgBytes) + assert.NoError(t, err) + assert.EqualValues(t, 200, newMsg.ID()) + assert.EqualValues(t, 1000, newMsg.BeginTs()) + assert.EqualValues(t, 1000, newMsg.EndTs()) + + assert.True(t, msg.Size() > 0) +} diff --git a/pkg/mq/msgstream/msg_test.go b/pkg/mq/msgstream/msg_test.go index 4acad24596642..20e7b4c81c1e7 100644 --- a/pkg/mq/msgstream/msg_test.go +++ b/pkg/mq/msgstream/msg_test.go @@ -20,10 +20,11 @@ import ( "context" "testing" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/stretchr/testify/assert" ) func TestBaseMsg(t *testing.T) { diff --git a/pkg/mq/msgstream/msgstream.go b/pkg/mq/msgstream/msgstream.go index 05b967300a62a..184d44967d098 100644 --- a/pkg/mq/msgstream/msgstream.go +++ b/pkg/mq/msgstream/msgstream.go @@ -20,7 +20,6 @@ import ( "context" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" - "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -62,12 +61,14 @@ type MsgStream interface { GetProduceChannels() []string Broadcast(*MsgPack) (map[string][]MessageID, error) - AsConsumer(channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) + AsConsumer(ctx context.Context, channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) error Chan() <-chan *MsgPack - Seek(offset []*MsgPosition) error + Seek(ctx context.Context, offset []*MsgPosition) error GetLatestMsgID(channel string) (MessageID, error) CheckTopicValid(channel string) error + + EnableProduce(can bool) } type Factory interface 
{ diff --git a/pkg/mq/msgstream/msgstream_util.go b/pkg/mq/msgstream/msgstream_util.go index 9d05d4bda936a..f442eac838dc7 100644 --- a/pkg/mq/msgstream/msgstream_util.go +++ b/pkg/mq/msgstream/msgstream_util.go @@ -18,10 +18,13 @@ package msgstream import ( "context" + "fmt" + "math/rand" "go.uber.org/zap" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" ) // unsubscribeChannels create consumer first, and unsubscribe channel through msgStream.close() @@ -34,3 +37,25 @@ func UnsubscribeChannels(ctx context.Context, factory Factory, subName string, c panic(err) } } + +func GetChannelLatestMsgID(ctx context.Context, factory Factory, channelName string) ([]byte, error) { + dmlStream, err := factory.NewMsgStream(ctx) + if err != nil { + log.Warn("fail to NewMsgStream", zap.String("channelName", channelName), zap.Error(err)) + return nil, err + } + defer dmlStream.Close() + + subName := fmt.Sprintf("get-latest_msg_id-%s-%d", channelName, rand.Int()) + err = dmlStream.AsConsumer(ctx, []string{channelName}, subName, mqwrapper.SubscriptionPositionUnknown) + if err != nil { + log.Warn("fail to AsConsumer", zap.String("channelName", channelName), zap.Error(err)) + return nil, err + } + id, err := dmlStream.GetLatestMsgID(channelName) + if err != nil { + log.Error("fail to GetLatestMsgID", zap.String("channelName", channelName), zap.Error(err)) + return nil, err + } + return id.Serialize(), nil +} diff --git a/pkg/mq/msgstream/msgstream_util_test.go b/pkg/mq/msgstream/msgstream_util_test.go index 5bdeab4edf039..69fb4a8622fa1 100644 --- a/pkg/mq/msgstream/msgstream_util_test.go +++ b/pkg/mq/msgstream/msgstream_util_test.go @@ -20,7 +20,11 @@ import ( "context" "testing" + "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" ) func TestPulsarMsgUtil(t *testing.T) { @@ -36,3 +40,43 @@ func TestPulsarMsgUtil(t *testing.T) 
{ UnsubscribeChannels(ctx, pmsFactory, "sub", []string{"test"}) } + +func TestGetLatestMsgID(t *testing.T) { + factory := NewMockMqFactory() + ctx := context.Background() + { + factory.NewMsgStreamFunc = func(ctx context.Context) (MsgStream, error) { + return nil, errors.New("mock") + } + _, err := GetChannelLatestMsgID(ctx, factory, "test") + assert.Error(t, err) + } + stream := NewMockMsgStream(t) + factory.NewMsgStreamFunc = func(ctx context.Context) (MsgStream, error) { + return stream, nil + } + stream.EXPECT().Close().Return() + + { + stream.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("mock")).Once() + _, err := GetChannelLatestMsgID(ctx, factory, "test") + assert.Error(t, err) + } + + { + stream.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() + stream.EXPECT().GetLatestMsgID(mock.Anything).Return(nil, errors.New("mock")).Once() + _, err := GetChannelLatestMsgID(ctx, factory, "test") + assert.Error(t, err) + } + + { + mockMsgID := mqwrapper.NewMockMessageID(t) + mockMsgID.EXPECT().Serialize().Return([]byte("mock")).Once() + stream.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() + stream.EXPECT().GetLatestMsgID(mock.Anything).Return(mockMsgID, nil).Once() + id, err := GetChannelLatestMsgID(ctx, factory, "test") + assert.NoError(t, err) + assert.Equal(t, []byte("mock"), id) + } +} diff --git a/pkg/mq/msgstream/repack_func.go b/pkg/mq/msgstream/repack_func.go index 54c86cc83983a..bbc38b64fe64d 100644 --- a/pkg/mq/msgstream/repack_func.go +++ b/pkg/mq/msgstream/repack_func.go @@ -20,6 +20,7 @@ import ( "fmt" "github.com/cockroachdb/errors" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" ) diff --git a/pkg/mq/msgstream/stream_bench_test.go b/pkg/mq/msgstream/stream_bench_test.go index 2c0994f794d4e..823fbf637d437 100644 --- a/pkg/mq/msgstream/stream_bench_test.go +++ 
b/pkg/mq/msgstream/stream_bench_test.go @@ -8,11 +8,12 @@ import ( "sync" "testing" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper/nmq" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/stretchr/testify/assert" ) func BenchmarkProduceAndConsumeNatsMQ(b *testing.B) { @@ -25,7 +26,7 @@ func BenchmarkProduceAndConsumeNatsMQ(b *testing.B) { cfg.Opts.StoreDir = storeDir nmq.MustInitNatsMQ(cfg) - client, err := nmq.NewClientWithDefaultOptions() + client, err := nmq.NewClientWithDefaultOptions(context.Background()) if err != nil { panic(err) } diff --git a/pkg/mq/msgstream/stream_test.go b/pkg/mq/msgstream/stream_test.go index d297ea4365919..fea2746fbe37c 100644 --- a/pkg/mq/msgstream/stream_test.go +++ b/pkg/mq/msgstream/stream_test.go @@ -9,11 +9,12 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper" "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/stretchr/testify/assert" - "go.uber.org/zap" ) func testStreamOperation(t *testing.T, mqClient mqwrapper.Client) { diff --git a/pkg/mq/msgstream/trace.go b/pkg/mq/msgstream/trace.go index 55719ae50eb9b..db1d027615750 100644 --- a/pkg/mq/msgstream/trace.go +++ b/pkg/mq/msgstream/trace.go @@ -19,11 +19,12 @@ package msgstream import ( "context" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/propagation" "go.opentelemetry.io/otel/trace" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" ) // ExtractCtx extracts trace span from msg.properties. 
diff --git a/pkg/mq/msgstream/unmarshal.go b/pkg/mq/msgstream/unmarshal.go index 89e66e1f79275..31cee49d8bca5 100644 --- a/pkg/mq/msgstream/unmarshal.go +++ b/pkg/mq/msgstream/unmarshal.go @@ -18,6 +18,7 @@ package msgstream import ( "github.com/cockroachdb/errors" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" ) @@ -62,6 +63,16 @@ func (pudf *ProtoUDFactory) NewUnmarshalDispatcher() *ProtoUnmarshalDispatcher { dropPartitionMsg := DropPartitionMsg{} dataNodeTtMsg := DataNodeTtMsg{} + createIndexMsg := CreateIndexMsg{} + dropIndexMsg := DropIndexMsg{} + + loadCollectionMsg := LoadCollectionMsg{} + releaseCollectionMsg := ReleaseCollectionMsg{} + flushMsg := FlushMsg{} + + createDatabaseMsg := CreateDatabaseMsg{} + dropDatabaseMsg := DropDatabaseMsg{} + p := &ProtoUnmarshalDispatcher{} p.TempMap = make(map[commonpb.MsgType]UnmarshalFunc) p.TempMap[commonpb.MsgType_Insert] = insertMsg.Unmarshal @@ -72,6 +83,13 @@ func (pudf *ProtoUDFactory) NewUnmarshalDispatcher() *ProtoUnmarshalDispatcher { p.TempMap[commonpb.MsgType_CreatePartition] = createPartitionMsg.Unmarshal p.TempMap[commonpb.MsgType_DropPartition] = dropPartitionMsg.Unmarshal p.TempMap[commonpb.MsgType_DataNodeTt] = dataNodeTtMsg.Unmarshal + p.TempMap[commonpb.MsgType_CreateIndex] = createIndexMsg.Unmarshal + p.TempMap[commonpb.MsgType_DropIndex] = dropIndexMsg.Unmarshal + p.TempMap[commonpb.MsgType_LoadCollection] = loadCollectionMsg.Unmarshal + p.TempMap[commonpb.MsgType_ReleaseCollection] = releaseCollectionMsg.Unmarshal + p.TempMap[commonpb.MsgType_Flush] = flushMsg.Unmarshal + p.TempMap[commonpb.MsgType_CreateDatabase] = createDatabaseMsg.Unmarshal + p.TempMap[commonpb.MsgType_DropDatabase] = dropDatabaseMsg.Unmarshal return p } diff --git a/pkg/mq/msgstream/unmarshal_test.go b/pkg/mq/msgstream/unmarshal_test.go index b2413e743e7d5..962102bafed5c 100644 --- a/pkg/mq/msgstream/unmarshal_test.go +++ b/pkg/mq/msgstream/unmarshal_test.go @@ -20,9 +20,10 @@ import ( "testing" 
"github.com/golang/protobuf/proto" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" - "github.com/stretchr/testify/assert" ) func Test_ProtoUnmarshalDispatcher(t *testing.T) { diff --git a/pkg/rules.go b/pkg/rules.go new file mode 100644 index 0000000000000..5bc3422c9b450 --- /dev/null +++ b/pkg/rules.go @@ -0,0 +1,409 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package gorules + +import ( + "github.com/quasilyte/go-ruleguard/dsl" +) + +// This is a collection of rules for ruleguard: https://github.com/quasilyte/go-ruleguard + +// Remove extra conversions: mdempsky/unconvert +func unconvert(m dsl.Matcher) { + m.Match("int($x)").Where(m["x"].Type.Is("int") && !m["x"].Const).Report("unnecessary conversion").Suggest("$x") + + m.Match("float32($x)").Where(m["x"].Type.Is("float32") && !m["x"].Const).Report("unnecessary conversion").Suggest("$x") + m.Match("float64($x)").Where(m["x"].Type.Is("float64") && !m["x"].Const).Report("unnecessary conversion").Suggest("$x") + + // m.Match("byte($x)").Where(m["x"].Type.Is("byte")).Report("unnecessary conversion").Suggest("$x") + // m.Match("rune($x)").Where(m["x"].Type.Is("rune")).Report("unnecessary conversion").Suggest("$x") + m.Match("bool($x)").Where(m["x"].Type.Is("bool") && !m["x"].Const).Report("unnecessary conversion").Suggest("$x") + + m.Match("int8($x)").Where(m["x"].Type.Is("int8") && !m["x"].Const).Report("unnecessary conversion").Suggest("$x") + m.Match("int16($x)").Where(m["x"].Type.Is("int16") && !m["x"].Const).Report("unnecessary conversion").Suggest("$x") + m.Match("int32($x)").Where(m["x"].Type.Is("int32") && !m["x"].Const).Report("unnecessary conversion").Suggest("$x") + m.Match("int64($x)").Where(m["x"].Type.Is("int64") && !m["x"].Const).Report("unnecessary conversion").Suggest("$x") + + m.Match("uint8($x)").Where(m["x"].Type.Is("uint8") && !m["x"].Const).Report("unnecessary conversion").Suggest("$x") + m.Match("uint16($x)").Where(m["x"].Type.Is("uint16") && !m["x"].Const).Report("unnecessary conversion").Suggest("$x") + m.Match("uint32($x)").Where(m["x"].Type.Is("uint32") && !m["x"].Const).Report("unnecessary conversion").Suggest("$x") + m.Match("uint64($x)").Where(m["x"].Type.Is("uint64") && !m["x"].Const).Report("unnecessary conversion").Suggest("$x") + + m.Match("time.Duration($x)").Where(m["x"].Type.Is("time.Duration") && 
!m["x"].Text.Matches("^[0-9]*$")).Report("unnecessary conversion").Suggest("$x") +} + +// Don't use == or != with time.Time +// https://github.com/dominikh/go-tools/issues/47 : Wontfix +func timeeq(m dsl.Matcher) { + m.Match("$t0 == $t1").Where(m["t0"].Type.Is("time.Time")).Report("using == with time.Time") + m.Match("$t0 != $t1").Where(m["t0"].Type.Is("time.Time")).Report("using != with time.Time") + m.Match(`map[$k]$v`).Where(m["k"].Type.Is("time.Time")).Report("map with time.Time keys are easy to misuse") +} + +// err but no an error +func errnoterror(m dsl.Matcher) { + // Would be easier to check for all err identifiers instead, but then how do we get the type from m[] ? + + m.Match( + "if $*_, err := $x; $err != nil { $*_ } else if $_ { $*_ }", + "if $*_, err := $x; $err != nil { $*_ } else { $*_ }", + "if $*_, err := $x; $err != nil { $*_ }", + + "if $*_, err = $x; $err != nil { $*_ } else if $_ { $*_ }", + "if $*_, err = $x; $err != nil { $*_ } else { $*_ }", + "if $*_, err = $x; $err != nil { $*_ }", + + "$*_, err := $x; if $err != nil { $*_ } else if $_ { $*_ }", + "$*_, err := $x; if $err != nil { $*_ } else { $*_ }", + "$*_, err := $x; if $err != nil { $*_ }", + + "$*_, err = $x; if $err != nil { $*_ } else if $_ { $*_ }", + "$*_, err = $x; if $err != nil { $*_ } else { $*_ }", + "$*_, err = $x; if $err != nil { $*_ }", + ). + Where(m["err"].Text == "err" && !m["err"].Type.Is("error") && m["x"].Text != "recover()"). + Report("err variable not error type") +} + +// Identical if and else bodies +func ifbodythenbody(m dsl.Matcher) { + m.Match("if $*_ { $body } else { $body }"). + Report("identical if and else bodies") + + // Lots of false positives. + // m.Match("if $*_ { $body } else if $*_ { $body }"). + // Report("identical if and else bodies") +} + +// Odd inequality: A - B < 0 instead of != +// Too many false positives. 
+/* +func subtractnoteq(m dsl.Matcher) { + m.Match("$a - $b < 0").Report("consider $a != $b") + m.Match("$a - $b > 0").Report("consider $a != $b") + m.Match("0 < $a - $b").Report("consider $a != $b") + m.Match("0 > $a - $b").Report("consider $a != $b") +} +*/ + +// Self-assignment +func selfassign(m dsl.Matcher) { + m.Match("$x = $x").Report("useless self-assignment") +} + +// Odd nested ifs +func oddnestedif(m dsl.Matcher) { + m.Match("if $x { if $x { $*_ }; $*_ }", + "if $x == $y { if $x != $y {$*_ }; $*_ }", + "if $x != $y { if $x == $y {$*_ }; $*_ }", + "if $x { if !$x { $*_ }; $*_ }", + "if !$x { if $x { $*_ }; $*_ }"). + Report("odd nested ifs") + + m.Match("for $x { if $x { $*_ }; $*_ }", + "for $x == $y { if $x != $y {$*_ }; $*_ }", + "for $x != $y { if $x == $y {$*_ }; $*_ }", + "for $x { if !$x { $*_ }; $*_ }", + "for !$x { if $x { $*_ }; $*_ }"). + Report("odd nested for/ifs") +} + +// odd bitwise expressions +func oddbitwise(m dsl.Matcher) { + m.Match("$x | $x", + "$x | ^$x", + "^$x | $x"). + Report("odd bitwise OR") + + m.Match("$x & $x", + "$x & ^$x", + "^$x & $x"). + Report("odd bitwise AND") + + m.Match("$x &^ $x"). 
+ Report("odd bitwise AND-NOT") +} + +// odd sequence of if tests with return +func ifreturn(m dsl.Matcher) { + m.Match("if $x { return $*_ }; if $x {$*_ }").Report("odd sequence of if test") + m.Match("if $x { return $*_ }; if !$x {$*_ }").Report("odd sequence of if test") + m.Match("if !$x { return $*_ }; if $x {$*_ }").Report("odd sequence of if test") + m.Match("if $x == $y { return $*_ }; if $x != $y {$*_ }").Report("odd sequence of if test") + m.Match("if $x != $y { return $*_ }; if $x == $y {$*_ }").Report("odd sequence of if test") +} + +func oddifsequence(m dsl.Matcher) { + /* + m.Match("if $x { $*_ }; if $x {$*_ }").Report("odd sequence of if test") + + m.Match("if $x == $y { $*_ }; if $y == $x {$*_ }").Report("odd sequence of if tests") + m.Match("if $x != $y { $*_ }; if $y != $x {$*_ }").Report("odd sequence of if tests") + + m.Match("if $x < $y { $*_ }; if $y > $x {$*_ }").Report("odd sequence of if tests") + m.Match("if $x <= $y { $*_ }; if $y >= $x {$*_ }").Report("odd sequence of if tests") + + m.Match("if $x > $y { $*_ }; if $y < $x {$*_ }").Report("odd sequence of if tests") + m.Match("if $x >= $y { $*_ }; if $y <= $x {$*_ }").Report("odd sequence of if tests") + */ +} + +// odd sequence of nested if tests +func nestedifsequence(m dsl.Matcher) { + /* + m.Match("if $x < $y { if $x >= $y {$*_ }; $*_ }").Report("odd sequence of nested if tests") + m.Match("if $x <= $y { if $x > $y {$*_ }; $*_ }").Report("odd sequence of nested if tests") + m.Match("if $x > $y { if $x <= $y {$*_ }; $*_ }").Report("odd sequence of nested if tests") + m.Match("if $x >= $y { if $x < $y {$*_ }; $*_ }").Report("odd sequence of nested if tests") + */ +} + +// odd sequence of assignments +func identicalassignments(m dsl.Matcher) { + m.Match("$x = $y; $y = $x").Report("odd sequence of assignments") +} + +func oddcompoundop(m dsl.Matcher) { + m.Match("$x += $x + $_", + "$x += $x - $_"). + Report("odd += expression") + + m.Match("$x -= $x + $_", + "$x -= $x - $_"). 
+ Report("odd -= expression") +} + +func constswitch(m dsl.Matcher) { + m.Match("switch $x { $*_ }", "switch $*_; $x { $*_ }"). + Where(m["x"].Const && !m["x"].Text.Matches(`^runtime\.`)). + Report("constant switch") +} + +func oddcomparisons(m dsl.Matcher) { + m.Match( + "$x - $y == 0", + "$x - $y != 0", + "$x - $y < 0", + "$x - $y <= 0", + "$x - $y > 0", + "$x - $y >= 0", + "$x ^ $y == 0", + "$x ^ $y != 0", + ).Report("odd comparison") +} + +func oddmathbits(m dsl.Matcher) { + m.Match( + "64 - bits.LeadingZeros64($x)", + "32 - bits.LeadingZeros32($x)", + "16 - bits.LeadingZeros16($x)", + "8 - bits.LeadingZeros8($x)", + ).Report("odd math/bits expression: use bits.Len*() instead?") +} + +// func floateq(m dsl.Matcher) { +// m.Match( +// "$x == $y", +// "$x != $y", +// ). +// Where(m["x"].Type.Is("float32") && !m["x"].Const && !m["y"].Text.Matches("0(.0+)?") && !m.File().Name.Matches("floating_comparision.go")). +// Report("floating point tested for equality") + +// m.Match( +// "$x == $y", +// "$x != $y", +// ). +// Where(m["x"].Type.Is("float64") && !m["x"].Const && !m["y"].Text.Matches("0(.0+)?") && !m.File().Name.Matches("floating_comparision.go")). +// Report("floating point tested for equality") + +// m.Match("switch $x { $*_ }", "switch $*_; $x { $*_ }"). +// Where(m["x"].Type.Is("float32")). +// Report("floating point as switch expression") + +// m.Match("switch $x { $*_ }", "switch $*_; $x { $*_ }"). +// Where(m["x"].Type.Is("float64")). +// Report("floating point as switch expression") + +// } + +func badexponent(m dsl.Matcher) { + m.Match( + "2 ^ $x", + "10 ^ $x", + ). + Report("caret (^) is not exponentiation") +} + +func floatloop(m dsl.Matcher) { + m.Match( + "for $i := $x; $i < $y; $i += $z { $*_ }", + "for $i = $x; $i < $y; $i += $z { $*_ }", + ). + Where(m["i"].Type.Is("float64")). + Report("floating point for loop counter") + + m.Match( + "for $i := $x; $i < $y; $i += $z { $*_ }", + "for $i = $x; $i < $y; $i += $z { $*_ }", + ). 
+ Where(m["i"].Type.Is("float32")). + Report("floating point for loop counter") +} + +func urlredacted(m dsl.Matcher) { + m.Match( + "log.Println($x, $*_)", + "log.Println($*_, $x, $*_)", + "log.Println($*_, $x)", + "log.Printf($*_, $x, $*_)", + "log.Printf($*_, $x)", + + "log.Println($x, $*_)", + "log.Println($*_, $x, $*_)", + "log.Println($*_, $x)", + "log.Printf($*_, $x, $*_)", + "log.Printf($*_, $x)", + ). + Where(m["x"].Type.Is("*url.URL")). + Report("consider $x.Redacted() when outputting URLs") +} + +func sprinterr(m dsl.Matcher) { + m.Match(`fmt.Sprint($err)`, + `fmt.Sprintf("%s", $err)`, + `fmt.Sprintf("%v", $err)`, + ). + Where(m["err"].Type.Is("error")). + Report("maybe call $err.Error() instead of fmt.Sprint()?") +} + +// disable this check, because it can not apply to generic type +//func largeloopcopy(m dsl.Matcher) { +// m.Match( +// `for $_, $v := range $_ { $*_ }`, +// ). +// Where(m["v"].Type.Size > 1024). +// Report(`loop copies large value each iteration`) +//} + +func joinpath(m dsl.Matcher) { + m.Match( + `strings.Join($_, "/")`, + `strings.Join($_, "\\")`, + "strings.Join($_, `\\`)", + ). 
+ Report(`did you mean path.Join() or filepath.Join() ?`) +} + +func readfull(m dsl.Matcher) { + m.Match(`$n, $err := io.ReadFull($_, $slice) + if $err != nil || $n != len($slice) { + $*_ + }`, + `$n, $err := io.ReadFull($_, $slice) + if $n != len($slice) || $err != nil { + $*_ + }`, + `$n, $err = io.ReadFull($_, $slice) + if $err != nil || $n != len($slice) { + $*_ + }`, + `$n, $err = io.ReadFull($_, $slice) + if $n != len($slice) || $err != nil { + $*_ + }`, + `if $n, $err := io.ReadFull($_, $slice); $n != len($slice) || $err != nil { + $*_ + }`, + `if $n, $err := io.ReadFull($_, $slice); $err != nil || $n != len($slice) { + $*_ + }`, + `if $n, $err = io.ReadFull($_, $slice); $n != len($slice) || $err != nil { + $*_ + }`, + `if $n, $err = io.ReadFull($_, $slice); $err != nil || $n != len($slice) { + $*_ + }`, + ).Report("io.ReadFull() returns err == nil iff n == len(slice)") +} + +func nilerr(m dsl.Matcher) { + m.Match( + `if err == nil { return err }`, + `if err == nil { return $*_, err }`, + ). + Report(`return nil error instead of nil value`) +} + +func mailaddress(m dsl.Matcher) { + m.Match( + "fmt.Sprintf(`\"%s\" <%s>`, $NAME, $EMAIL)", + "fmt.Sprintf(`\"%s\"<%s>`, $NAME, $EMAIL)", + "fmt.Sprintf(`%s <%s>`, $NAME, $EMAIL)", + "fmt.Sprintf(`%s<%s>`, $NAME, $EMAIL)", + `fmt.Sprintf("\"%s\"<%s>", $NAME, $EMAIL)`, + `fmt.Sprintf("\"%s\" <%s>", $NAME, $EMAIL)`, + `fmt.Sprintf("%s<%s>", $NAME, $EMAIL)`, + `fmt.Sprintf("%s <%s>", $NAME, $EMAIL)`, + ). + Report("use net/mail Address.String() instead of fmt.Sprintf()"). + Suggest("(&mail.Address{Name:$NAME, Address:$EMAIL}).String()") +} + +func errnetclosed(m dsl.Matcher) { + m.Match( + `strings.Contains($err.Error(), $text)`, + ). + Where(m["text"].Text.Matches("\".*closed network connection.*\"")). + Report(`String matching against error texts is fragile; use net.ErrClosed instead`). + Suggest(`errors.Is($err, net.ErrClosed)`) +} + +func httpheaderadd(m dsl.Matcher) { + m.Match( + `$H.Add($KEY, $VALUE)`, + ). 
+ Where(m["H"].Type.Is("http.Header")). + Report("use http.Header.Set method instead of Add to overwrite all existing header values"). + Suggest(`$H.Set($KEY, $VALUE)`) +} + +func hmacnew(m dsl.Matcher) { + m.Match("hmac.New(func() hash.Hash { return $x }, $_)", + `$f := func() hash.Hash { return $x } + $*_ + hmac.New($f, $_)`, + ).Where(m["x"].Pure). + Report("invalid hash passed to hmac.New()") +} + +func writestring(m dsl.Matcher) { + m.Match(`io.WriteString($w, string($b))`). + Where(m["b"].Type.Is("[]byte")). + Suggest("$w.Write($b)") +} + +func badlock(m dsl.Matcher) { + // Shouldn't give many false positives without type filter + // as Lock+Unlock pairs in combination with defer gives us pretty + // a good chance to guess correctly. If we constrain the type to sync.Mutex + // then it'll be harder to match embedded locks and custom methods + // that may forward the call to the sync.Mutex (or other synchronization primitive). + + m.Match(`$mu.Lock(); defer $mu.RUnlock()`).Report(`maybe $mu.RLock() was intended?`) + m.Match(`$mu.RLock(); defer $mu.Unlock()`).Report(`maybe $mu.Lock() was intended?`) +} diff --git a/pkg/tracer/interceptor_suite.go b/pkg/tracer/interceptor_suite.go index 2d94bfc582523..15ac8b0fe4787 100644 --- a/pkg/tracer/interceptor_suite.go +++ b/pkg/tracer/interceptor_suite.go @@ -21,21 +21,19 @@ import ( "go.opentelemetry.io/otel" ) -var ( - filterFunc = func(info *otelgrpc.InterceptorInfo) bool { - var fullMethod string - if info.UnaryServerInfo != nil { - fullMethod = info.UnaryServerInfo.FullMethod - } else if info.StreamServerInfo != nil { - fullMethod = info.StreamServerInfo.FullMethod - } - if fullMethod == `/milvus.proto.rootcoord.RootCoord/UpdateChannelTimeTick` || - fullMethod == `/milvus.proto.rootcoord.RootCoord/AllocTimestamp` { - return false - } - return true +var filterFunc = func(info *otelgrpc.InterceptorInfo) bool { + var fullMethod string + if info.UnaryServerInfo != nil { + fullMethod = info.UnaryServerInfo.FullMethod + } 
else if info.StreamServerInfo != nil { + fullMethod = info.StreamServerInfo.FullMethod } -) + if fullMethod == `/milvus.proto.rootcoord.RootCoord/UpdateChannelTimeTick` || + fullMethod == `/milvus.proto.rootcoord.RootCoord/AllocTimestamp` { + return false + } + return true +} // GetInterceptorOpts returns the Option of gRPC open-tracing func GetInterceptorOpts() []otelgrpc.Option { diff --git a/pkg/util/cache/hash_test.go b/pkg/util/cache/hash_test.go index df599e1ecceb1..6f6c4d86a9caa 100644 --- a/pkg/util/cache/hash_test.go +++ b/pkg/util/cache/hash_test.go @@ -45,7 +45,7 @@ func sumFNVu32(v uint32) uint64 { } func TestSum(t *testing.T) { - var tests = []struct { + tests := []struct { k interface{} h uint64 }{ diff --git a/pkg/util/cache/local_cache_test.go b/pkg/util/cache/local_cache_test.go index 508e1495c2623..07c7345383f11 100644 --- a/pkg/util/cache/local_cache_test.go +++ b/pkg/util/cache/local_cache_test.go @@ -353,12 +353,12 @@ func TestGetIfPresentExpired(t *testing.T) { c := NewCache(WithExpireAfterWrite[int, string](1*time.Second), WithInsertionListener(insFunc)) defer c.Close() - v, ok := c.GetIfPresent(0) + _, ok := c.GetIfPresent(0) assert.False(t, ok) wg.Add(1) c.Put(0, "0") - v, ok = c.GetIfPresent(0) + v, ok := c.GetIfPresent(0) assert.True(t, ok) assert.Equal(t, "0", v) diff --git a/pkg/util/commonpbutil/commonpbutil.go b/pkg/util/commonpbutil/commonpbutil.go index 42b505525791f..724b8978a6d4e 100644 --- a/pkg/util/commonpbutil/commonpbutil.go +++ b/pkg/util/commonpbutil/commonpbutil.go @@ -72,7 +72,6 @@ func FillMsgBaseFromClient(sourceID int64, options ...MsgBaseOptions) MsgBaseOpt op(msgBase) } } - } func newMsgBaseDefault() *commonpb.MsgBase { @@ -100,11 +99,3 @@ func UpdateMsgBase(msgBase *commonpb.MsgBase, options ...MsgBaseOptions) *common } return msgBaseRt } - -func IsHealthy(stateCode commonpb.StateCode) bool { - return stateCode == commonpb.StateCode_Healthy -} - -func IsHealthyOrStopping(stateCode commonpb.StateCode) bool { - return 
stateCode == commonpb.StateCode_Healthy || stateCode == commonpb.StateCode_Stopping -} diff --git a/pkg/util/commonpbutil/commonpbutil_test.go b/pkg/util/commonpbutil/commonpbutil_test.go deleted file mode 100644 index 4ab0964aa01e9..0000000000000 --- a/pkg/util/commonpbutil/commonpbutil_test.go +++ /dev/null @@ -1,66 +0,0 @@ -/* - * # Licensed to the LF AI & Data foundation under one - * # or more contributor license agreements. See the NOTICE file - * # distributed with this work for additional information - * # regarding copyright ownership. The ASF licenses this file - * # to you under the Apache License, Version 2.0 (the - * # "License"); you may not use this file except in compliance - * # with the License. You may obtain a copy of the License at - * # - * # http://www.apache.org/licenses/LICENSE-2.0 - * # - * # Unless required by applicable law or agreed to in writing, software - * # distributed under the License is distributed on an "AS IS" BASIS, - * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * # See the License for the specific language governing permissions and - * # limitations under the License. 
- */ - -package commonpbutil - -import ( - "testing" - - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/stretchr/testify/assert" -) - -func TestIsHealthy(t *testing.T) { - type testCase struct { - code commonpb.StateCode - expect bool - } - - cases := []testCase{ - {commonpb.StateCode_Healthy, true}, - {commonpb.StateCode_Initializing, false}, - {commonpb.StateCode_Abnormal, false}, - {commonpb.StateCode_StandBy, false}, - {commonpb.StateCode_Stopping, false}, - } - for _, tc := range cases { - t.Run(tc.code.String(), func(t *testing.T) { - assert.Equal(t, tc.expect, IsHealthy(tc.code)) - }) - } -} - -func TestIsHealthyOrStopping(t *testing.T) { - type testCase struct { - code commonpb.StateCode - expect bool - } - - cases := []testCase{ - {commonpb.StateCode_Healthy, true}, - {commonpb.StateCode_Initializing, false}, - {commonpb.StateCode_Abnormal, false}, - {commonpb.StateCode_StandBy, false}, - {commonpb.StateCode_Stopping, true}, - } - for _, tc := range cases { - t.Run(tc.code.String(), func(t *testing.T) { - assert.Equal(t, tc.expect, IsHealthyOrStopping(tc.code)) - }) - } -} diff --git a/pkg/util/conc/options.go b/pkg/util/conc/options.go index 4deb9d292b0b4..281ab675b3194 100644 --- a/pkg/util/conc/options.go +++ b/pkg/util/conc/options.go @@ -19,9 +19,10 @@ package conc import ( "time" - "github.com/milvus-io/milvus/pkg/log" "github.com/panjf2000/ants/v2" "go.uber.org/zap" + + "github.com/milvus-io/milvus/pkg/log" ) type poolOption struct { diff --git a/pkg/util/conc/pool.go b/pkg/util/conc/pool.go index 600518ca4b22a..1d19ea02eb00f 100644 --- a/pkg/util/conc/pool.go +++ b/pkg/util/conc/pool.go @@ -21,8 +21,9 @@ import ( "runtime" "sync" - "github.com/milvus-io/milvus/pkg/util/generic" ants "github.com/panjf2000/ants/v2" + + "github.com/milvus-io/milvus/pkg/util/generic" ) // A goroutine pool diff --git a/pkg/util/constant.go b/pkg/util/constant.go index 912d475bec7c4..ecc51c7448a37 100644 --- a/pkg/util/constant.go +++ 
b/pkg/util/constant.go @@ -26,6 +26,7 @@ import ( // Meta Prefix consts const ( MetaStoreTypeEtcd = "etcd" + MetaStoreTypeTiKV = "tikv" SegmentMetaPrefix = "queryCoord-segmentMeta" ChangeInfoMetaPrefix = "queryCoord-sealedSegmentChangeInfo" diff --git a/pkg/util/errorutil/util.go b/pkg/util/errorutil/util.go deleted file mode 100644 index f9967d3ad2bbe..0000000000000 --- a/pkg/util/errorutil/util.go +++ /dev/null @@ -1,46 +0,0 @@ -package errorutil - -import ( - "fmt" - - "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" - - "github.com/milvus-io/milvus/pkg/util/typeutil" -) - -func UnhealthyStatus(code commonpb.StateCode) *commonpb.Status { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: "proxy not healthy, StateCode=" + commonpb.StateCode_name[int32(code)], - } -} - -func UnhealthyError() error { - return errors.New("unhealthy node") -} - -func PermissionDenyError() error { - return errors.New("permission deny") -} - -func UnHealthReason(role string, nodeID typeutil.UniqueID, reason string) string { - return fmt.Sprintf("role %s[nodeID: %d] is unhealthy, reason: %s", role, nodeID, reason) -} - -func UnHealthReasonWithComponentStatesOrErr(role string, nodeID typeutil.UniqueID, cs *milvuspb.ComponentStates, err error) (bool, string) { - if err != nil { - return false, UnHealthReason(role, nodeID, fmt.Sprintf("inner error: %s", err.Error())) - } - - if cs != nil && cs.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { - return false, UnHealthReason(role, nodeID, fmt.Sprintf("rpc status error: %d", cs.GetStatus().GetErrorCode())) - } - - if cs != nil && cs.GetState().GetStateCode() != commonpb.StateCode_Healthy { - return false, UnHealthReason(role, nodeID, fmt.Sprintf("node is unhealthy, state code: %d", cs.GetState().GetStateCode())) - } - - return true, "" -} diff --git a/pkg/util/errorutil/util_test.go 
b/pkg/util/errorutil/util_test.go deleted file mode 100644 index b21a29e2e4406..0000000000000 --- a/pkg/util/errorutil/util_test.go +++ /dev/null @@ -1 +0,0 @@ -package errorutil diff --git a/pkg/util/etcd/etcd_util.go b/pkg/util/etcd/etcd_util.go index 9ee284e1391a5..77dc0ce040a40 100644 --- a/pkg/util/etcd/etcd_util.go +++ b/pkg/util/etcd/etcd_util.go @@ -20,7 +20,6 @@ import ( "crypto/tls" "crypto/x509" "fmt" - "io/ioutil" "net/url" "os" "time" @@ -33,9 +32,7 @@ import ( "github.com/milvus-io/milvus/pkg/log" ) -var ( - maxTxnNum = 128 -) +var maxTxnNum = 128 // GetEtcdClient returns etcd client func GetEtcdClient( @@ -45,7 +42,8 @@ func GetEtcdClient( certFile string, keyFile string, caCertFile string, - minVersion string) (*clientv3.Client, error) { + minVersion string, +) (*clientv3.Client, error) { log.Info("create etcd client", zap.Bool("useEmbedEtcd", useEmbedEtcd), zap.Bool("useSSL", useSSL), @@ -76,7 +74,7 @@ func GetRemoteEtcdSSLClient(endpoints []string, certFile string, keyFile string, if err != nil { return nil, errors.Wrap(err, "load etcd cert key pair error") } - caCert, err := ioutil.ReadFile(caCertFile) + caCert, err := os.ReadFile(caCertFile) if err != nil { return nil, errors.Wrapf(err, "load etcd CACert file error, filename = %s", caCertFile) } @@ -183,7 +181,7 @@ func buildKvGroup(keys, values []string) (map[string]string, error) { // StartTestEmbedEtcdServer returns a newly created embed etcd server. 
// ### USED FOR UNIT TEST ONLY ### func StartTestEmbedEtcdServer() (*embed.Etcd, string, error) { - dir, err := ioutil.TempDir(os.TempDir(), "milvus_ut") + dir, err := os.MkdirTemp(os.TempDir(), "milvus_ut") if err != nil { return nil, "", err } diff --git a/pkg/util/etcd/etcd_util_test.go b/pkg/util/etcd/etcd_util_test.go index ca67454604a1e..86a60ae4eab2f 100644 --- a/pkg/util/etcd/etcd_util_test.go +++ b/pkg/util/etcd/etcd_util_test.go @@ -42,27 +42,26 @@ func TestEtcd(t *testing.T) { assert.False(t, resp.Count < 1) assert.Equal(t, string(resp.Kvs[0].Value), "value") - etcdCli, err = GetEtcdClient(false, true, []string{}, + _, err = GetEtcdClient(false, true, []string{}, "../../../configs/cert/client.pem", "../../../configs/cert/client.key", "../../../configs/cert/ca.pem", "some not right word") assert.Error(t, err) - etcdCli, err = GetEtcdClient(false, true, []string{}, + _, err = GetEtcdClient(false, true, []string{}, "../../../configs/cert/client.pem", "../../../configs/cert/client.key", "wrong/file", "1.2") assert.Error(t, err) - etcdCli, err = GetEtcdClient(false, true, []string{}, + _, err = GetEtcdClient(false, true, []string{}, "wrong/file", "../../../configs/cert/client.key", "../../../configs/cert/ca.pem", "1.2") assert.Error(t, err) - } func Test_buildKvGroup(t *testing.T) { diff --git a/pkg/util/funcutil/func.go b/pkg/util/funcutil/func.go index ff8909a40518a..ffca8c19a14e9 100644 --- a/pkg/util/funcutil/func.go +++ b/pkg/util/funcutil/func.go @@ -29,12 +29,13 @@ import ( "time" "github.com/cockroachdb/errors" + "google.golang.org/grpc/codes" + grpcStatus "google.golang.org/grpc/status" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/util/typeutil" - "google.golang.org/grpc/codes" - grpcStatus "google.golang.org/grpc/status" ) // CheckGrpcReady wait for context timeout, or wait 100ms then send nil to 
targetCh @@ -49,6 +50,14 @@ func CheckGrpcReady(ctx context.Context, targetCh chan error) { } } +// GetIP return the ip address +func GetIP(ip string) string { + if len(ip) == 0 { + return GetLocalIP() + } + return ip +} + // GetLocalIP return the local ip address func GetLocalIP() string { addrs, err := net.InterfaceAddrs() @@ -108,7 +117,7 @@ func CheckCtxValid(ctx context.Context) bool { func GetVecFieldIDs(schema *schemapb.CollectionSchema) []int64 { var vecFieldIDs []int64 for _, field := range schema.Fields { - if field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_FloatVector { + if field.DataType == schemapb.DataType_BinaryVector || field.DataType == schemapb.DataType_FloatVector || field.DataType == schemapb.DataType_Float16Vector { vecFieldIDs = append(vecFieldIDs, field.FieldID) } } @@ -222,7 +231,17 @@ func GetNumRowsOfBinaryVectorField(bDatas []byte, dim int64) (uint64, error) { return uint64((8 * int64(l)) / dim), nil } -// GetNumRowOfFieldData return num rows of the field data +func GetNumRowsOfFloat16VectorField(f16Datas []byte, dim int64) (uint64, error) { + if dim <= 0 { + return 0, fmt.Errorf("dim(%d) should be greater than 0", dim) + } + l := len(f16Datas) + if int64(l)%dim != 0 { + return 0, fmt.Errorf("the length(%d) of float data should divide the dim(%d)", l, dim) + } + return uint64((int64(l)) / dim / 2), nil +} + func GetNumRowOfFieldData(fieldData *schemapb.FieldData) (uint64, error) { var fieldNumRows uint64 var err error @@ -264,6 +283,12 @@ func GetNumRowOfFieldData(fieldData *schemapb.FieldData) (uint64, error) { if err != nil { return 0, err } + case *schemapb.VectorField_Float16Vector: + dim := vectorField.GetDim() + fieldNumRows, err = GetNumRowsOfFloat16VectorField(vectorField.GetFloat16Vector(), dim) + if err != nil { + return 0, err + } default: return 0, fmt.Errorf("%s is not supported now", vectorFieldType) } diff --git a/pkg/util/funcutil/func_test.go b/pkg/util/funcutil/func_test.go index 
f60495a8fe634..cabf80982cfb1 100644 --- a/pkg/util/funcutil/func_test.go +++ b/pkg/util/funcutil/func_test.go @@ -27,11 +27,12 @@ import ( "time" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/stretchr/testify/assert" grpcCodes "google.golang.org/grpc/codes" grpcStatus "google.golang.org/grpc/status" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" ) func Test_CheckGrpcReady(t *testing.T) { @@ -55,6 +56,14 @@ func Test_GetLocalIP(t *testing.T) { assert.NotZero(t, len(ip)) } +func Test_GetIP(t *testing.T) { + ip := GetIP("") + assert.NotNil(t, ip) + assert.NotZero(t, len(ip)) + ip = GetIP("127.0.0") + assert.Equal(t, ip, "127.0.0") +} + func Test_ParseIndexParamsMap(t *testing.T) { num := 10 keys := make([]string, 0) @@ -219,6 +228,34 @@ func TestGetNumRowsOfFloatVectorField(t *testing.T) { } } +func TestGetNumRowsOfFloat16VectorField(t *testing.T) { + cases := []struct { + bDatas []byte + dim int64 + want uint64 + errIsNil bool + }{ + {[]byte{}, -1, 0, false}, // dim <= 0 + {[]byte{}, 0, 0, false}, // dim <= 0 + {[]byte{1.0}, 128, 0, false}, // length % dim != 0 + {[]byte{}, 128, 0, true}, + {[]byte{1.0, 2.0}, 1, 1, true}, + {[]byte{1.0, 2.0, 3.0, 4.0}, 2, 1, true}, + } + + for _, test := range cases { + got, err := GetNumRowsOfFloat16VectorField(test.bDatas, test.dim) + if test.errIsNil { + assert.Equal(t, nil, err) + if got != test.want { + t.Errorf("GetNumRowsOfFloat16VectorField(%v, %v) = %v, %v", test.bDatas, test.dim, test.want, nil) + } + } else { + assert.NotEqual(t, nil, err) + } + } +} + func TestGetNumRowsOfBinaryVectorField(t *testing.T) { cases := []struct { bDatas []byte @@ -310,7 +347,7 @@ func Test_ReadBinary(t *testing.T) { // float vector bs = []byte{0, 0, 0, 0, 0, 0, 0, 0} - var fs = make([]float32, 2) + fs := make([]float32, 2) assert.NoError(t, ReadBinary(endian, bs, 
&fs)) assert.ElementsMatch(t, []float32{0, 0}, fs) } diff --git a/pkg/util/funcutil/placeholdergroup.go b/pkg/util/funcutil/placeholdergroup.go new file mode 100644 index 0000000000000..e2e2ef163a64d --- /dev/null +++ b/pkg/util/funcutil/placeholdergroup.go @@ -0,0 +1,118 @@ +package funcutil + +import ( + "encoding/binary" + "math" + + "github.com/cockroachdb/errors" + "github.com/golang/protobuf/proto" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +func FieldDataToPlaceholderGroupBytes(fieldData *schemapb.FieldData) ([]byte, error) { + placeholderValue, err := fieldDataToPlaceholderValue(fieldData) + if err != nil { + return nil, err + } + + placeholderGroup := &commonpb.PlaceholderGroup{ + Placeholders: []*commonpb.PlaceholderValue{placeholderValue}, + } + + bytes, _ := proto.Marshal(placeholderGroup) + return bytes, nil +} + +func fieldDataToPlaceholderValue(fieldData *schemapb.FieldData) (*commonpb.PlaceholderValue, error) { + switch fieldData.Type { + case schemapb.DataType_FloatVector: + vectors := fieldData.GetVectors() + x, ok := vectors.GetData().(*schemapb.VectorField_FloatVector) + if !ok { + return nil, errors.New("vector data is not schemapb.VectorField_FloatVector") + } + + placeholderValue := &commonpb.PlaceholderValue{ + Tag: "$0", + Type: commonpb.PlaceholderType_FloatVector, + Values: flattenedFloatVectorsToByteVectors(x.FloatVector.Data, int(vectors.Dim)), + } + return placeholderValue, nil + case schemapb.DataType_BinaryVector: + vectors := fieldData.GetVectors() + x, ok := vectors.GetData().(*schemapb.VectorField_BinaryVector) + if !ok { + return nil, errors.New("vector data is not schemapb.VectorField_BinaryVector") + } + placeholderValue := &commonpb.PlaceholderValue{ + Tag: "$0", + Type: commonpb.PlaceholderType_BinaryVector, + Values: flattenedByteVectorsToByteVectors(x.BinaryVector, int(vectors.Dim)), + } + return placeholderValue, nil + case 
schemapb.DataType_Float16Vector: + vectors := fieldData.GetVectors() + x, ok := vectors.GetData().(*schemapb.VectorField_Float16Vector) + if !ok { + return nil, errors.New("vector data is not schemapb.VectorField_Float16Vector") + } + placeholderValue := &commonpb.PlaceholderValue{ + Tag: "$0", + Type: commonpb.PlaceholderType_Float16Vector, + Values: flattenedFloat16VectorsToByteVectors(x.Float16Vector, int(vectors.Dim)), + } + return placeholderValue, nil + default: + return nil, errors.New("field is not a vector field") + } +} + +func flattenedFloatVectorsToByteVectors(flattenedVectors []float32, dimension int) [][]byte { + floatVectors := flattenedFloatVectorsToFloatVectors(flattenedVectors, dimension) + result := make([][]byte, 0) + for _, floatVector := range floatVectors { + result = append(result, floatVectorToByteVector(floatVector)) + } + + return result +} + +func flattenedFloatVectorsToFloatVectors(flattenedVectors []float32, dimension int) [][]float32 { + result := make([][]float32, 0) + for i := 0; i < len(flattenedVectors); i += dimension { + result = append(result, flattenedVectors[i:i+dimension]) + } + return result +} + +func floatVectorToByteVector(vector []float32) []byte { + data := make([]byte, 0, 4*len(vector)) // float32 occupies 4 bytes + buf := make([]byte, 4) + for _, f := range vector { + binary.LittleEndian.PutUint32(buf, math.Float32bits(f)) + data = append(data, buf...) 
+ } + return data +} + +func flattenedByteVectorsToByteVectors(flattenedVectors []byte, dimension int) [][]byte { + result := make([][]byte, 0) + for i := 0; i < len(flattenedVectors); i += dimension { + result = append(result, flattenedVectors[i:i+dimension]) + } + return result +} + +func flattenedFloat16VectorsToByteVectors(flattenedVectors []byte, dimension int) [][]byte { + result := make([][]byte, 0) + + vectorBytes := 2 * dimension + + for i := 0; i < len(flattenedVectors); i += vectorBytes { + result = append(result, flattenedVectors[i:i+vectorBytes]) + } + + return result +} diff --git a/pkg/util/funcutil/placeholdergroup_test.go b/pkg/util/funcutil/placeholdergroup_test.go new file mode 100644 index 0000000000000..d53fb256b3820 --- /dev/null +++ b/pkg/util/funcutil/placeholdergroup_test.go @@ -0,0 +1,33 @@ +package funcutil + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_flattenedByteVectorsToByteVectors(t *testing.T) { + flattenedVectors := []byte{0, 1, 2, 3, 4, 5} + dimension := 3 + + actual := flattenedByteVectorsToByteVectors(flattenedVectors, dimension) + expected := [][]byte{ + {0, 1, 2}, + {3, 4, 5}, + } + + assert.Equal(t, expected, actual) +} + +func Test_flattenedFloat16VectorsToByteVectors(t *testing.T) { + flattenedVectors := []byte{0, 1, 2, 3, 4, 5, 6, 7} + dimension := 2 + + actual := flattenedFloat16VectorsToByteVectors(flattenedVectors, dimension) + expected := [][]byte{ + {0, 1, 2, 3}, + {4, 5, 6, 7}, + } + + assert.Equal(t, expected, actual) +} diff --git a/pkg/util/funcutil/policy.go b/pkg/util/funcutil/policy.go index 5405f25cc72f7..730f01f675fd8 100644 --- a/pkg/util/funcutil/policy.go +++ b/pkg/util/funcutil/policy.go @@ -6,12 +6,13 @@ import ( "github.com/golang/protobuf/descriptor" "github.com/golang/protobuf/proto" + "go.uber.org/zap" + "google.golang.org/protobuf/reflect/protoreflect" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" 
"github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util" - "go.uber.org/zap" - "google.golang.org/protobuf/reflect/protoreflect" ) func GetVersion(m proto.GeneratedMessage) (string, error) { @@ -39,7 +40,7 @@ func GetPrivilegeExtObj(m proto.GeneratedMessage) (commonpb.PrivilegeExt, error) extObj, err := proto.GetExtension(md.Options, commonpb.E_PrivilegeExtObj) if err != nil { - log.Warn("GetExtension fail", zap.Error(err)) + log.Info("GetExtension fail", zap.Error(err)) return commonpb.PrivilegeExt{}, err } privilegeExt := extObj.(*commonpb.PrivilegeExt) diff --git a/pkg/util/funcutil/policy_test.go b/pkg/util/funcutil/policy_test.go index a0a5ce51b438c..03bf498884f77 100644 --- a/pkg/util/funcutil/policy_test.go +++ b/pkg/util/funcutil/policy_test.go @@ -3,9 +3,10 @@ package funcutil import ( "testing" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" - "github.com/stretchr/testify/assert" ) func Test_GetPrivilegeExtObj(t *testing.T) { @@ -43,7 +44,6 @@ func Test_GetResourceName(t *testing.T) { request = &milvuspb.SelectUserRequest{} assert.Equal(t, "*", GetObjectName(request, 2)) } - } func Test_GetResourceNames(t *testing.T) { diff --git a/pkg/util/funcutil/verify_response.go b/pkg/util/funcutil/verify_response.go deleted file mode 100644 index cbb3f6bd242ec..0000000000000 --- a/pkg/util/funcutil/verify_response.go +++ /dev/null @@ -1,48 +0,0 @@ -package funcutil - -import ( - "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" -) - -// errors for VerifyResponse -var errNilResponse = errors.New("response is nil") - -var errNilStatusResponse = errors.New("response has nil status") - -var errUnknownResponseType = errors.New("unknown response type") - -// Response response interface for verification -type Response interface { - GetStatus() *commonpb.Status -} - -// VerifyResponse verify grpc Response 1. 
check error is nil 2. check response.GetStatus() with status success -func VerifyResponse(response interface{}, err error) error { - if err != nil { - return err - } - if response == nil { - return errNilResponse - } - switch resp := response.(type) { - case Response: - // note that resp will not be nil here, since it's still an interface - if resp.GetStatus() == nil { - return errNilStatusResponse - } - if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { - return errors.New(resp.GetStatus().GetReason()) - } - case *commonpb.Status: - if resp == nil { - return errNilResponse - } - if resp.ErrorCode != commonpb.ErrorCode_Success { - return errors.New(resp.GetReason()) - } - default: - return errUnknownResponseType - } - return nil -} diff --git a/pkg/util/generic/generic.go b/pkg/util/generic/generic.go index cbc51eee17284..87b6de1bc3f51 100644 --- a/pkg/util/generic/generic.go +++ b/pkg/util/generic/generic.go @@ -19,7 +19,8 @@ package generic import "reflect" func Zero[T any]() T { - return *new(T) + var zero T + return zero } func IsZero[T any](v T) bool { diff --git a/pkg/util/hardware/container_linux.go b/pkg/util/hardware/container_linux.go index 08eb37bfa8a81..49d5168054c3a 100644 --- a/pkg/util/hardware/container_linux.go +++ b/pkg/util/hardware/container_linux.go @@ -15,7 +15,6 @@ import ( "strings" "github.com/cockroachdb/errors" - "github.com/containerd/cgroups" ) diff --git a/pkg/util/indexparamcheck/base_checker.go b/pkg/util/indexparamcheck/base_checker.go index a416a2990e4de..a8c27776c7a36 100644 --- a/pkg/util/indexparamcheck/base_checker.go +++ b/pkg/util/indexparamcheck/base_checker.go @@ -2,11 +2,11 @@ package indexparamcheck import ( "github.com/cockroachdb/errors" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) -type baseChecker struct { -} +type baseChecker struct{} func (c baseChecker) CheckTrain(params map[string]string) error { if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) { diff --git 
a/pkg/util/indexparamcheck/base_checker_test.go b/pkg/util/indexparamcheck/base_checker_test.go index eee11c5af48c1..a016d4da88498 100644 --- a/pkg/util/indexparamcheck/base_checker_test.go +++ b/pkg/util/indexparamcheck/base_checker_test.go @@ -4,10 +4,10 @@ import ( "strconv" "testing" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/util/metric" - - "github.com/stretchr/testify/assert" ) func Test_baseChecker_CheckTrain(t *testing.T) { diff --git a/pkg/util/indexparamcheck/bin_flat_checker_test.go b/pkg/util/indexparamcheck/bin_flat_checker_test.go index 4fa8814cd4187..7c10f2e62b3d1 100644 --- a/pkg/util/indexparamcheck/bin_flat_checker_test.go +++ b/pkg/util/indexparamcheck/bin_flat_checker_test.go @@ -4,10 +4,10 @@ import ( "strconv" "testing" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/util/metric" - - "github.com/stretchr/testify/assert" ) func Test_binFlatChecker_CheckTrain(t *testing.T) { @@ -76,7 +76,6 @@ func Test_binFlatChecker_CheckTrain(t *testing.T) { } func Test_binFlatChecker_CheckValidDataType(t *testing.T) { - cases := []struct { dType schemapb.DataType errIsNil bool diff --git a/pkg/util/indexparamcheck/bin_ivf_flat_checker_test.go b/pkg/util/indexparamcheck/bin_ivf_flat_checker_test.go index 487e47198c33c..27ef913c2aee0 100644 --- a/pkg/util/indexparamcheck/bin_ivf_flat_checker_test.go +++ b/pkg/util/indexparamcheck/bin_ivf_flat_checker_test.go @@ -4,10 +4,10 @@ import ( "strconv" "testing" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/util/metric" - - "github.com/stretchr/testify/assert" ) func Test_binIVFFlatChecker_CheckTrain(t *testing.T) { @@ -127,7 +127,6 @@ func Test_binIVFFlatChecker_CheckTrain(t *testing.T) { } func Test_binIVFFlatChecker_CheckValidDataType(t *testing.T) { - cases := []struct 
{ dType schemapb.DataType errIsNil bool diff --git a/pkg/util/indexparamcheck/binary_vector_base_checker.go b/pkg/util/indexparamcheck/binary_vector_base_checker.go index 4fa69af2042dc..ccafd4f0a9de6 100644 --- a/pkg/util/indexparamcheck/binary_vector_base_checker.go +++ b/pkg/util/indexparamcheck/binary_vector_base_checker.go @@ -3,9 +3,8 @@ package indexparamcheck import ( "fmt" - "github.com/milvus-io/milvus/pkg/common" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/common" ) type binaryVectorBaseChecker struct { diff --git a/pkg/util/indexparamcheck/binary_vector_base_checker_test.go b/pkg/util/indexparamcheck/binary_vector_base_checker_test.go index d1b09cd449f0e..fc166fabd9218 100644 --- a/pkg/util/indexparamcheck/binary_vector_base_checker_test.go +++ b/pkg/util/indexparamcheck/binary_vector_base_checker_test.go @@ -3,12 +3,12 @@ package indexparamcheck import ( "testing" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) func Test_binaryVectorBaseChecker_CheckValidDataType(t *testing.T) { - cases := []struct { dType schemapb.DataType errIsNil bool diff --git a/pkg/util/indexparamcheck/conf_adapter_mgr.go b/pkg/util/indexparamcheck/conf_adapter_mgr.go index dd60ae638a136..5099da0ca3108 100644 --- a/pkg/util/indexparamcheck/conf_adapter_mgr.go +++ b/pkg/util/indexparamcheck/conf_adapter_mgr.go @@ -48,7 +48,7 @@ func (mgr *indexCheckerMgrImpl) registerIndexChecker() { mgr.checkers[IndexFaissIDMap] = newFlatChecker() mgr.checkers[IndexFaissIvfFlat] = newIVFBaseChecker() mgr.checkers[IndexFaissIvfPQ] = newIVFPQChecker() - mgr.checkers[IndexScaNN] = newIVFBaseChecker() + mgr.checkers[IndexScaNN] = newScaNNChecker() mgr.checkers[IndexFaissIvfSQ8] = newIVFSQChecker() mgr.checkers[IndexFaissBinIDMap] = newBinFlatChecker() mgr.checkers[IndexFaissBinIvfFlat] = newBinIVFFlatChecker() diff --git 
a/pkg/util/indexparamcheck/conf_adapter_mgr_test.go b/pkg/util/indexparamcheck/conf_adapter_mgr_test.go index 370a98e2c298c..6ab9469ee501d 100644 --- a/pkg/util/indexparamcheck/conf_adapter_mgr_test.go +++ b/pkg/util/indexparamcheck/conf_adapter_mgr_test.go @@ -44,7 +44,7 @@ func Test_GetConfAdapterMgrInstance(t *testing.T) { adapter, err = adapterMgr.GetChecker(IndexScaNN) assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) - _, ok = adapter.(*ivfBaseChecker) + _, ok = adapter.(*scaNNChecker) assert.Equal(t, true, ok) adapter, err = adapterMgr.GetChecker(IndexFaissIvfPQ) @@ -104,7 +104,7 @@ func TestConfAdapterMgrImpl_GetAdapter(t *testing.T) { adapter, err = adapterMgr.GetChecker(IndexScaNN) assert.Equal(t, nil, err) assert.NotEqual(t, nil, adapter) - _, ok = adapter.(*ivfBaseChecker) + _, ok = adapter.(*scaNNChecker) assert.Equal(t, true, ok) adapter, err = adapterMgr.GetChecker(IndexFaissIvfPQ) diff --git a/pkg/util/indexparamcheck/constraints.go b/pkg/util/indexparamcheck/constraints.go index f2e80db3d6b10..b30d16b86c64a 100644 --- a/pkg/util/indexparamcheck/constraints.go +++ b/pkg/util/indexparamcheck/constraints.go @@ -42,11 +42,13 @@ const ( var METRICS = []string{metric.L2, metric.IP, metric.COSINE} // const // BinIDMapMetrics is a set of all metric types supported for binary vector. 
-var BinIDMapMetrics = []string{metric.HAMMING, metric.JACCARD, metric.SUBSTRUCTURE, metric.SUPERSTRUCTURE} // const -var BinIvfMetrics = []string{metric.HAMMING, metric.JACCARD} // const -var HnswMetrics = []string{metric.L2, metric.IP, metric.COSINE, metric.HAMMING, metric.JACCARD} // const -var supportDimPerSubQuantizer = []int{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1} // const -var supportSubQuantizer = []int{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1} // const +var ( + BinIDMapMetrics = []string{metric.HAMMING, metric.JACCARD, metric.SUBSTRUCTURE, metric.SUPERSTRUCTURE} // const + BinIvfMetrics = []string{metric.HAMMING, metric.JACCARD} // const + HnswMetrics = []string{metric.L2, metric.IP, metric.COSINE, metric.HAMMING, metric.JACCARD} // const + supportDimPerSubQuantizer = []int{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1} // const + supportSubQuantizer = []int{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1} // const +) const ( FloatVectorDefaultMetricType = metric.IP diff --git a/pkg/util/indexparamcheck/diskann_checker_test.go b/pkg/util/indexparamcheck/diskann_checker_test.go index 11005e16115f5..411e8f97d8e91 100644 --- a/pkg/util/indexparamcheck/diskann_checker_test.go +++ b/pkg/util/indexparamcheck/diskann_checker_test.go @@ -4,10 +4,10 @@ import ( "strconv" "testing" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/util/metric" - - "github.com/stretchr/testify/assert" ) func Test_diskannChecker_CheckTrain(t *testing.T) { @@ -84,7 +84,6 @@ func Test_diskannChecker_CheckTrain(t *testing.T) { } func Test_diskannChecker_CheckValidDataType(t *testing.T) { - cases := []struct { dType schemapb.DataType errIsNil bool diff --git a/pkg/util/indexparamcheck/flat_checker_test.go b/pkg/util/indexparamcheck/flat_checker_test.go index c44a215dae102..115fd839317ea 100644 --- a/pkg/util/indexparamcheck/flat_checker_test.go +++ 
b/pkg/util/indexparamcheck/flat_checker_test.go @@ -4,13 +4,12 @@ import ( "strconv" "testing" - "github.com/milvus-io/milvus/pkg/util/metric" - "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/pkg/util/metric" ) func Test_flatChecker_CheckTrain(t *testing.T) { - p1 := map[string]string{ DIM: strconv.Itoa(128), Metric: metric.L2, diff --git a/pkg/util/indexparamcheck/float_vector_base_checker.go b/pkg/util/indexparamcheck/float_vector_base_checker.go index be94d79cfa790..de237c08ea643 100644 --- a/pkg/util/indexparamcheck/float_vector_base_checker.go +++ b/pkg/util/indexparamcheck/float_vector_base_checker.go @@ -3,9 +3,8 @@ package indexparamcheck import ( "fmt" - "github.com/milvus-io/milvus/pkg/common" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/common" ) type floatVectorBaseChecker struct { @@ -29,8 +28,8 @@ func (c floatVectorBaseChecker) CheckTrain(params map[string]string) error { } func (c floatVectorBaseChecker) CheckValidDataType(dType schemapb.DataType) error { - if dType != schemapb.DataType_FloatVector { - return fmt.Errorf("float vector is only supported") + if dType != schemapb.DataType_FloatVector && dType != schemapb.DataType_Float16Vector { + return fmt.Errorf("float or float16 vector are only supported") } return nil } diff --git a/pkg/util/indexparamcheck/float_vector_base_checker_test.go b/pkg/util/indexparamcheck/float_vector_base_checker_test.go index 22ae463e4db49..affc4d9d53c24 100644 --- a/pkg/util/indexparamcheck/float_vector_base_checker_test.go +++ b/pkg/util/indexparamcheck/float_vector_base_checker_test.go @@ -3,12 +3,12 @@ package indexparamcheck import ( "testing" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) func Test_floatVectorBaseChecker_CheckValidDataType(t *testing.T) { - cases := []struct { dType schemapb.DataType errIsNil bool diff --git 
a/pkg/util/indexparamcheck/hnsw_checker.go b/pkg/util/indexparamcheck/hnsw_checker.go index ccf3d3f79dba0..fa3df38c23d42 100644 --- a/pkg/util/indexparamcheck/hnsw_checker.go +++ b/pkg/util/indexparamcheck/hnsw_checker.go @@ -31,7 +31,7 @@ func (c hnswChecker) CheckTrain(params map[string]string) error { } func (c hnswChecker) CheckValidDataType(dType schemapb.DataType) error { - if dType != schemapb.DataType_FloatVector && dType != schemapb.DataType_BinaryVector { + if dType != schemapb.DataType_FloatVector && dType != schemapb.DataType_BinaryVector && dType != schemapb.DataType_Float16Vector { return fmt.Errorf("only support float vector or binary vector") } return nil diff --git a/pkg/util/indexparamcheck/hnsw_checker_test.go b/pkg/util/indexparamcheck/hnsw_checker_test.go index d2ea5d9f7028a..bcb7c482a1781 100644 --- a/pkg/util/indexparamcheck/hnsw_checker_test.go +++ b/pkg/util/indexparamcheck/hnsw_checker_test.go @@ -4,14 +4,13 @@ import ( "strconv" "testing" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/util/metric" - - "github.com/stretchr/testify/assert" ) func Test_hnswChecker_CheckTrain(t *testing.T) { - validParams := map[string]string{ DIM: strconv.Itoa(128), HNSWM: strconv.Itoa(16), @@ -105,7 +104,6 @@ func Test_hnswChecker_CheckTrain(t *testing.T) { } func Test_hnswChecker_CheckValidDataType(t *testing.T) { - cases := []struct { dType schemapb.DataType errIsNil bool diff --git a/pkg/util/indexparamcheck/ivf_base_checker_test.go b/pkg/util/indexparamcheck/ivf_base_checker_test.go index e9ed4c017d600..ad0ad42a2090c 100644 --- a/pkg/util/indexparamcheck/ivf_base_checker_test.go +++ b/pkg/util/indexparamcheck/ivf_base_checker_test.go @@ -4,10 +4,10 @@ import ( "strconv" "testing" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/util/metric" - - "github.com/stretchr/testify/assert" ) func 
Test_ivfBaseChecker_CheckTrain(t *testing.T) { diff --git a/pkg/util/indexparamcheck/ivf_pq_checker.go b/pkg/util/indexparamcheck/ivf_pq_checker.go index 07c830f26e94c..51da64e0ffe7a 100644 --- a/pkg/util/indexparamcheck/ivf_pq_checker.go +++ b/pkg/util/indexparamcheck/ivf_pq_checker.go @@ -53,7 +53,7 @@ func (c *ivfPQChecker) checkPQParams(params map[string]string) error { func (c *ivfPQChecker) checkCPUPQParams(dimension, m int) error { if (dimension % m) != 0 { - return fmt.Errorf("dimension must be abled to be divided by `m`, dimension: %d, m: %d", dimension, m) + return fmt.Errorf("dimension must be able to be divided by `m`, dimension: %d, m: %d", dimension, m) } return nil } diff --git a/pkg/util/indexparamcheck/ivf_pq_checker_test.go b/pkg/util/indexparamcheck/ivf_pq_checker_test.go index 11938473d0574..8c44f22c34edb 100644 --- a/pkg/util/indexparamcheck/ivf_pq_checker_test.go +++ b/pkg/util/indexparamcheck/ivf_pq_checker_test.go @@ -4,10 +4,10 @@ import ( "strconv" "testing" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/util/metric" - - "github.com/stretchr/testify/assert" ) func Test_ivfPQChecker_CheckTrain(t *testing.T) { @@ -151,7 +151,6 @@ func Test_ivfPQChecker_CheckTrain(t *testing.T) { } func Test_ivfPQChecker_CheckValidDataType(t *testing.T) { - cases := []struct { dType schemapb.DataType errIsNil bool diff --git a/pkg/util/indexparamcheck/ivf_sq_checker_test.go b/pkg/util/indexparamcheck/ivf_sq_checker_test.go index eef0a73251f5b..fa8a5a73c86ef 100644 --- a/pkg/util/indexparamcheck/ivf_sq_checker_test.go +++ b/pkg/util/indexparamcheck/ivf_sq_checker_test.go @@ -4,10 +4,10 @@ import ( "strconv" "testing" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/util/metric" - - "github.com/stretchr/testify/assert" ) func Test_ivfSQChecker_CheckTrain(t *testing.T) { diff --git 
a/pkg/util/indexparamcheck/raft_ivf_pq_checker.go b/pkg/util/indexparamcheck/raft_ivf_pq_checker.go index 8ce89d72e4121..65f6d1d1b7503 100644 --- a/pkg/util/indexparamcheck/raft_ivf_pq_checker.go +++ b/pkg/util/indexparamcheck/raft_ivf_pq_checker.go @@ -53,7 +53,7 @@ func (c *raftIVFPQChecker) checkPQParams(params map[string]string) error { return nil } if dimension%m != 0 { - return fmt.Errorf("dimension must be abled to be divided by `m`, dimension: %d, m: %d", dimension, m) + return fmt.Errorf("dimension must be able to be divided by `m`, dimension: %d, m: %d", dimension, m) } return nil } diff --git a/pkg/util/indexparamcheck/raft_ivf_pq_checker_test.go b/pkg/util/indexparamcheck/raft_ivf_pq_checker_test.go index 27d6939ed7597..f1b743359727f 100644 --- a/pkg/util/indexparamcheck/raft_ivf_pq_checker_test.go +++ b/pkg/util/indexparamcheck/raft_ivf_pq_checker_test.go @@ -4,14 +4,13 @@ import ( "strconv" "testing" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/util/metric" - - "github.com/stretchr/testify/assert" ) func Test_raftIVFPQChecker_CheckTrain(t *testing.T) { - validParams := map[string]string{ DIM: strconv.Itoa(128), NLIST: strconv.Itoa(1024), diff --git a/pkg/util/indexparamcheck/scalar_index_checker_test.go b/pkg/util/indexparamcheck/scalar_index_checker_test.go index 01a755d700ef3..3289cd00b2d87 100644 --- a/pkg/util/indexparamcheck/scalar_index_checker_test.go +++ b/pkg/util/indexparamcheck/scalar_index_checker_test.go @@ -3,8 +3,9 @@ package indexparamcheck import ( "testing" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) func TestCheckIndexValid(t *testing.T) { diff --git a/pkg/util/indexparamcheck/scann_checker.go b/pkg/util/indexparamcheck/scann_checker.go new file mode 100644 index 0000000000000..eecf2ded64bbf --- /dev/null +++ 
b/pkg/util/indexparamcheck/scann_checker.go @@ -0,0 +1,41 @@ +package indexparamcheck + +import ( + "fmt" + "strconv" +) + +// scaNNChecker checks if a SCANN index can be built. +type scaNNChecker struct { + ivfBaseChecker +} + +// CheckTrain checks if SCANN index can be built with the specific index parameters. +func (c *scaNNChecker) CheckTrain(params map[string]string) error { + if err := c.ivfBaseChecker.CheckTrain(params); err != nil { + return err + } + + return c.checkScaNNParams(params) +} + +func (c *scaNNChecker) checkScaNNParams(params map[string]string) error { + dimStr, dimensionExist := params[DIM] + if !dimensionExist { + return fmt.Errorf("dimension not found") + } + + dimension, err := strconv.Atoi(dimStr) + if err != nil { // invalid dimension + return fmt.Errorf("invalid dimension: %s", dimStr) + } + + if (dimension % 2) != 0 { + return fmt.Errorf("dimension must be able to be divided by 2, dimension: %d", dimension) + } + return nil +} + +func newScaNNChecker() IndexChecker { + return &scaNNChecker{} +} diff --git a/pkg/util/indexparamcheck/scann_checker_test.go b/pkg/util/indexparamcheck/scann_checker_test.go new file mode 100644 index 0000000000000..7e86beeb1f831 --- /dev/null +++ b/pkg/util/indexparamcheck/scann_checker_test.go @@ -0,0 +1,169 @@ +package indexparamcheck + +import ( + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/util/metric" +) + +func Test_scaNNChecker_CheckTrain(t *testing.T) { + validParams := map[string]string{ + DIM: strconv.Itoa(128), + NLIST: strconv.Itoa(1024), + Metric: metric.L2, + } + + paramsNotMultiplier := map[string]string{ + DIM: strconv.Itoa(127), + NLIST: strconv.Itoa(1024), + Metric: metric.L2, + } + + validParamsWithoutDim := map[string]string{ + NLIST: strconv.Itoa(1024), + Metric: metric.L2, + } + + invalidParamsDim := copyParams(validParams) + invalidParamsDim[DIM] = "NAN" + + p1 := 
map[string]string{ + DIM: strconv.Itoa(128), + NLIST: strconv.Itoa(1024), + Metric: metric.L2, + } + p2 := map[string]string{ + DIM: strconv.Itoa(128), + NLIST: strconv.Itoa(1024), + Metric: metric.IP, + } + p3 := map[string]string{ + DIM: strconv.Itoa(128), + NLIST: strconv.Itoa(1024), + Metric: metric.COSINE, + } + + p4 := map[string]string{ + DIM: strconv.Itoa(128), + NLIST: strconv.Itoa(1024), + Metric: metric.HAMMING, + } + p5 := map[string]string{ + DIM: strconv.Itoa(128), + NLIST: strconv.Itoa(1024), + Metric: metric.JACCARD, + } + p6 := map[string]string{ + DIM: strconv.Itoa(128), + NLIST: strconv.Itoa(1024), + Metric: metric.SUBSTRUCTURE, + } + p7 := map[string]string{ + DIM: strconv.Itoa(128), + NLIST: strconv.Itoa(1024), + Metric: metric.SUPERSTRUCTURE, + } + + cases := []struct { + params map[string]string + errIsNil bool + }{ + {validParams, true}, + {paramsNotMultiplier, false}, + {invalidIVFParamsMin(), false}, + {invalidIVFParamsMax(), false}, + {validParamsWithoutDim, false}, + {invalidParamsDim, false}, + {p1, true}, + {p2, true}, + {p3, true}, + {p4, false}, + {p5, false}, + {p6, false}, + {p7, false}, + } + + c := newScaNNChecker() + for _, test := range cases { + err := c.CheckTrain(test.params) + if test.errIsNil { + assert.NoError(t, err) + } else { + assert.Error(t, err) + } + } +} + +func Test_scaNNChecker_CheckValidDataType(t *testing.T) { + cases := []struct { + dType schemapb.DataType + errIsNil bool + }{ + { + dType: schemapb.DataType_Bool, + errIsNil: false, + }, + { + dType: schemapb.DataType_Int8, + errIsNil: false, + }, + { + dType: schemapb.DataType_Int16, + errIsNil: false, + }, + { + dType: schemapb.DataType_Int32, + errIsNil: false, + }, + { + dType: schemapb.DataType_Int64, + errIsNil: false, + }, + { + dType: schemapb.DataType_Float, + errIsNil: false, + }, + { + dType: schemapb.DataType_Double, + errIsNil: false, + }, + { + dType: schemapb.DataType_String, + errIsNil: false, + }, + { + dType: schemapb.DataType_VarChar, + 
errIsNil: false, + }, + { + dType: schemapb.DataType_Array, + errIsNil: false, + }, + { + dType: schemapb.DataType_JSON, + errIsNil: false, + }, + { + dType: schemapb.DataType_FloatVector, + errIsNil: true, + }, + { + dType: schemapb.DataType_BinaryVector, + errIsNil: false, + }, + } + + c := newScaNNChecker() + for _, test := range cases { + err := c.CheckValidDataType(test.dType) + if test.errIsNil { + assert.NoError(t, err) + } else { + assert.Error(t, err) + } + } +} diff --git a/pkg/util/indexparams/disk_index_params.go b/pkg/util/indexparams/disk_index_params.go index d9dee23d6cc06..564baef1515c2 100644 --- a/pkg/util/indexparams/disk_index_params.go +++ b/pkg/util/indexparams/disk_index_params.go @@ -208,17 +208,17 @@ func SetDiskIndexBuildParams(indexParams map[string]string, fieldDataSize int64) } searchCacheBudgetGBRatioStr, ok := indexParams[SearchCacheBudgetRatioKey] - if !ok { - return fmt.Errorf("index param searchCacheBudgetGBRatio not exist") - } - SearchCacheBudgetGBRatio, err := strconv.ParseFloat(searchCacheBudgetGBRatioStr, 64) - if err != nil { - return err + // set generate cache size when cache ratio param set + if ok { + SearchCacheBudgetGBRatio, err := strconv.ParseFloat(searchCacheBudgetGBRatioStr, 64) + if err != nil { + return err + } + indexParams[SearchCacheBudgetKey] = fmt.Sprintf("%f", float32(fieldDataSize)*float32(SearchCacheBudgetGBRatio)/(1<<30)) } indexParams[PQCodeBudgetKey] = fmt.Sprintf("%f", float32(fieldDataSize)*float32(pqCodeBudgetGBRatio)/(1<<30)) indexParams[NumBuildThreadKey] = strconv.Itoa(int(float32(hardware.GetCPUNum()) * float32(buildNumThreadsRatio))) indexParams[BuildDramBudgetKey] = fmt.Sprintf("%f", float32(hardware.GetFreeMemoryCount())/(1<<30)) - indexParams[SearchCacheBudgetKey] = fmt.Sprintf("%f", float32(fieldDataSize)*float32(SearchCacheBudgetGBRatio)/(1<<30)) return nil } @@ -260,7 +260,6 @@ func SetDiskIndexLoadParams(params *paramtable.ComponentParam, indexParams map[s if err != nil { return err } - } 
indexParams[SearchCacheBudgetKey] = fmt.Sprintf("%f", diff --git a/pkg/util/indexparams/disk_index_params_test.go b/pkg/util/indexparams/disk_index_params_test.go index 503321733f8b5..833301bab6138 100644 --- a/pkg/util/indexparams/disk_index_params_test.go +++ b/pkg/util/indexparams/disk_index_params_test.go @@ -22,10 +22,11 @@ import ( "strconv" "testing" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/hardware" "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/stretchr/testify/assert" ) func TestDiskIndexParams(t *testing.T) { @@ -132,18 +133,26 @@ func TestDiskIndexParams(t *testing.T) { err := SetDiskIndexBuildParams(indexParams, 100) assert.NoError(t, err) + _, ok := indexParams[SearchCacheBudgetKey] + assert.True(t, ok) + indexParams[SearchCacheBudgetRatioKey] = "aabb" err = SetDiskIndexBuildParams(indexParams, 100) assert.Error(t, err) - _, ok := indexParams[PQCodeBudgetKey] + delete(indexParams, SearchCacheBudgetRatioKey) + delete(indexParams, SearchCacheBudgetKey) + err = SetDiskIndexBuildParams(indexParams, 100) + assert.NoError(t, err) + + _, ok = indexParams[PQCodeBudgetKey] assert.True(t, ok) _, ok = indexParams[BuildDramBudgetKey] assert.True(t, ok) _, ok = indexParams[NumBuildThreadKey] assert.True(t, ok) _, ok = indexParams[SearchCacheBudgetKey] - assert.True(t, ok) + assert.False(t, ok) }) t.Run("set disk index load params without auto index param", func(t *testing.T) { diff --git a/pkg/util/interceptor/cluster_interceptor.go b/pkg/util/interceptor/cluster_interceptor.go index e27f90f5314f0..f943f4e468d8e 100644 --- a/pkg/util/interceptor/cluster_interceptor.go +++ b/pkg/util/interceptor/cluster_interceptor.go @@ -43,7 +43,7 @@ func ClusterValidationUnaryServerInterceptor() grpc.UnaryServerInterceptor { } cluster := clusters[0] if cluster != "" && cluster != paramtable.Get().CommonCfg.ClusterPrefix.GetValue() { - return nil, 
merr.WrapErrCrossClusterRouting(paramtable.Get().CommonCfg.ClusterPrefix.GetValue(), cluster) + return nil, merr.WrapErrServiceCrossClusterRouting(paramtable.Get().CommonCfg.ClusterPrefix.GetValue(), cluster) } return handler(ctx, req) } @@ -64,7 +64,7 @@ func ClusterValidationStreamServerInterceptor() grpc.StreamServerInterceptor { } cluster := clusters[0] if cluster != "" && cluster != paramtable.Get().CommonCfg.ClusterPrefix.GetValue() { - return merr.WrapErrCrossClusterRouting(paramtable.Get().CommonCfg.ClusterPrefix.GetValue(), cluster) + return merr.WrapErrServiceCrossClusterRouting(paramtable.Get().CommonCfg.ClusterPrefix.GetValue(), cluster) } return handler(srv, ss) } diff --git a/pkg/util/interceptor/cluster_interceptor_test.go b/pkg/util/interceptor/cluster_interceptor_test.go index 851e908683c57..a5bd041f505a4 100644 --- a/pkg/util/interceptor/cluster_interceptor_test.go +++ b/pkg/util/interceptor/cluster_interceptor_test.go @@ -90,7 +90,7 @@ func TestClusterInterceptor(t *testing.T) { md := metadata.Pairs(ClusterKey, "ins-1") ctx = metadata.NewIncomingContext(context.Background(), md) _, err = interceptor(ctx, req, serverInfo, handler) - assert.ErrorIs(t, err, merr.ErrCrossClusterRouting) + assert.ErrorIs(t, err, merr.ErrServiceCrossClusterRouting) // with same cluster md = metadata.Pairs(ClusterKey, paramtable.Get().CommonCfg.ClusterPrefix.GetValue()) @@ -118,7 +118,7 @@ func TestClusterInterceptor(t *testing.T) { md := metadata.Pairs(ClusterKey, "ins-1") ctx = metadata.NewIncomingContext(context.Background(), md) err = interceptor(nil, newMockSS(ctx), nil, handler) - assert.ErrorIs(t, err, merr.ErrCrossClusterRouting) + assert.ErrorIs(t, err, merr.ErrServiceCrossClusterRouting) // with same cluster md = metadata.Pairs(ClusterKey, paramtable.Get().CommonCfg.ClusterPrefix.GetValue()) diff --git a/pkg/util/lifetime/lifetime.go b/pkg/util/lifetime/lifetime.go index 80f8db8e34624..2847746ab7f68 100644 --- a/pkg/util/lifetime/lifetime.go +++ 
b/pkg/util/lifetime/lifetime.go @@ -23,33 +23,36 @@ import ( // Lifetime interface for lifetime control. type Lifetime[T any] interface { + SafeChan // SetState is the method to change lifetime state. SetState(state T) // GetState returns current state. GetState() T // Add records a task is running, returns false if the lifetime is not healthy. - Add(isHealthy IsHealthy[T]) bool + Add(isHealthy CheckHealth[T]) error // Done records a task is done. Done() // Wait waits until all tasks are done. Wait() } -// IsHealthy function type for lifetime healthy check. -type IsHealthy[T any] func(T) bool +// CheckHealth function type for lifetime healthy check. +type CheckHealth[T any] func(T) error var _ Lifetime[any] = (*lifetime[any])(nil) // NewLifetime returns a new instance of Lifetime with init state and isHealthy logic. func NewLifetime[T any](initState T) Lifetime[T] { return &lifetime[T]{ - state: initState, + safeChan: newSafeChan(), + state: initState, } } // lifetime implementation of Lifetime. // users shall not care about the internal fields of this struct. type lifetime[T any] struct { + *safeChan // wg is used for keeping record each running task. wg sync.WaitGroup // state is the "atomic" value to store component state. @@ -57,7 +60,7 @@ type lifetime[T any] struct { // mut is the rwmutex to control each task and state change event. mut sync.RWMutex // isHealthy is the method to check whether is legal to add a task. - isHealthy func(int32) bool + isHealthy func(int32) error } // SetState is the method to change lifetime state. @@ -77,17 +80,17 @@ func (l *lifetime[T]) GetState() T { } // Add records a task is running, returns false if the lifetime is not healthy. 
-func (l *lifetime[T]) Add(isHealthy IsHealthy[T]) bool { +func (l *lifetime[T]) Add(checkHealth CheckHealth[T]) error { l.mut.RLock() defer l.mut.RUnlock() // check lifetime healthy - if !isHealthy(l.state) { - return false + if err := checkHealth(l.state); err != nil { + return err } l.wg.Add(1) - return true + return nil } // Done records a task is done. diff --git a/pkg/util/lifetime/lifetime_test.go b/pkg/util/lifetime/lifetime_test.go index f964f56a98d9c..20b348e1deb33 100644 --- a/pkg/util/lifetime/lifetime_test.go +++ b/pkg/util/lifetime/lifetime_test.go @@ -21,6 +21,8 @@ import ( "time" "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus/pkg/util/merr" ) type LifetimeSuite struct { @@ -29,15 +31,20 @@ type LifetimeSuite struct { func (s *LifetimeSuite) TestNormal() { l := NewLifetime[int32](0) - isHealthy := func(state int32) bool { return state == 0 } + checkHealth := func(state int32) error { + if state == 0 { + return nil + } + return merr.WrapErrServiceNotReady("test", 0, "0") + } state := l.GetState() s.EqualValues(0, state) - s.True(l.Add(isHealthy)) + s.NoError(l.Add(checkHealth)) l.SetState(1) - s.False(l.Add(isHealthy)) + s.Error(l.Add(checkHealth)) signal := make(chan struct{}) go func() { diff --git a/pkg/util/lifetime/safe_chan.go b/pkg/util/lifetime/safe_chan.go new file mode 100644 index 0000000000000..ac877c215f7ed --- /dev/null +++ b/pkg/util/lifetime/safe_chan.go @@ -0,0 +1,45 @@ +package lifetime + +import ( + "sync" + + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +// SafeChan is the utility type combining chan struct{} & sync.Once. +// It provides double close protection internally. 
+type SafeChan interface { + IsClosed() bool + CloseCh() <-chan struct{} + Close() +} + +type safeChan struct { + closed chan struct{} + once sync.Once +} + +// NewSafeChan returns a SafeChan with internal channel initialized +func NewSafeChan() SafeChan { + return newSafeChan() +} + +func newSafeChan() *safeChan { + return &safeChan{ + closed: make(chan struct{}), + } +} + +func (sc *safeChan) CloseCh() <-chan struct{} { + return sc.closed +} + +func (sc *safeChan) IsClosed() bool { + return typeutil.IsChanClosed(sc.closed) +} + +func (sc *safeChan) Close() { + sc.once.Do(func() { + close(sc.closed) + }) +} diff --git a/pkg/util/lifetime/safe_chan_test.go b/pkg/util/lifetime/safe_chan_test.go new file mode 100644 index 0000000000000..05dd445134004 --- /dev/null +++ b/pkg/util/lifetime/safe_chan_test.go @@ -0,0 +1,38 @@ +package lifetime + +import ( + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type SafeChanSuite struct { + suite.Suite +} + +func (s *SafeChanSuite) TestClose() { + sc := NewSafeChan() + + s.False(sc.IsClosed(), "IsClosed() shall return false before Close()") + s.False(typeutil.IsChanClosed(sc.CloseCh()), "CloseCh() returned channel shall not be closed before Close()") + + s.NotPanics(func() { + sc.Close() + }, "SafeChan shall not panic during first close") + + s.True(sc.IsClosed(), "IsClosed() shall return true after Close()") + s.True(typeutil.IsChanClosed(sc.CloseCh()), "CloseCh() returned channel shall be closed after Close()") + + s.NotPanics(func() { + sc.Close() + }, "SafeChan shall not panic during second close") + + s.True(sc.IsClosed(), "IsClosed() shall return true after double Close()") + s.True(typeutil.IsChanClosed(sc.CloseCh()), "CloseCh() returned channel shall be still closed after double Close()") +} + +func TestSafeChan(t *testing.T) { + suite.Run(t, new(SafeChanSuite)) +} diff --git a/pkg/util/lifetime/state.go b/pkg/util/lifetime/state.go new file mode 100644 index 
0000000000000..c94be9e9443f6 --- /dev/null +++ b/pkg/util/lifetime/state.go @@ -0,0 +1,69 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package lifetime + +import ( + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +// Signal is an alias for chan struct{}. +type Signal chan struct{} + +// BiState provides pre-defined simple binary state - normal or closed. +type BiState int32 + +const ( + Normal BiState = 0 + Closed BiState = 1 +) + +// State provides pre-defined three-stage state. 
+type State int32 + +const ( + Initializing State = iota + Working + Stopped +) + +func (s State) String() string { + switch s { + case Initializing: + return "Initializing" + case Working: + return "Working" + case Stopped: + return "Stopped" + } + + return "Unknown" +} + +func NotStopped(state State) error { + if state != Stopped { + return nil + } + return merr.WrapErrServiceNotReady(paramtable.GetRole(), paramtable.GetNodeID(), state.String()) +} + +func IsWorking(state State) error { + if state == Working { + return nil + } + return merr.WrapErrServiceNotReady(paramtable.GetRole(), paramtable.GetNodeID(), state.String()) +} diff --git a/pkg/util/lock/key_lock.go b/pkg/util/lock/key_lock.go index 990d2f0d03e3d..97910aed7bd14 100644 --- a/pkg/util/lock/key_lock.go +++ b/pkg/util/lock/key_lock.go @@ -45,19 +45,19 @@ func newRefLock() *RefLock { return &c } -type KeyLock struct { +type KeyLock[K comparable] struct { keyLocksMutex sync.Mutex - refLocks map[string]*RefLock + refLocks map[K]*RefLock } -func NewKeyLock() *KeyLock { - keyLock := KeyLock{ - refLocks: make(map[string]*RefLock), +func NewKeyLock[K comparable]() *KeyLock[K] { + keyLock := KeyLock[K]{ + refLocks: make(map[K]*RefLock), } return &keyLock } -func (k *KeyLock) Lock(key string) { +func (k *KeyLock[K]) Lock(key K) { k.keyLocksMutex.Lock() // update the key map if keyLock, ok := k.refLocks[key]; ok { @@ -76,12 +76,12 @@ func (k *KeyLock) Lock(key string) { } } -func (k *KeyLock) Unlock(lockedKey string) { +func (k *KeyLock[K]) Unlock(lockedKey K) { k.keyLocksMutex.Lock() defer k.keyLocksMutex.Unlock() keyLock, ok := k.refLocks[lockedKey] if !ok { - log.Warn("Unlocking non-existing key", zap.String("key", lockedKey)) + log.Warn("Unlocking non-existing key", zap.Any("key", lockedKey)) return } keyLock.unref() @@ -91,7 +91,7 @@ func (k *KeyLock) Unlock(lockedKey string) { keyLock.mutex.Unlock() } -func (k *KeyLock) RLock(key string) { +func (k *KeyLock[K]) RLock(key K) { k.keyLocksMutex.Lock() // 
update the key map if keyLock, ok := k.refLocks[key]; ok { @@ -110,12 +110,12 @@ func (k *KeyLock) RLock(key string) { } } -func (k *KeyLock) RUnlock(lockedKey string) { +func (k *KeyLock[K]) RUnlock(lockedKey K) { k.keyLocksMutex.Lock() defer k.keyLocksMutex.Unlock() keyLock, ok := k.refLocks[lockedKey] if !ok { - log.Warn("Unlocking non-existing key", zap.String("key", lockedKey)) + log.Warn("Unlocking non-existing key", zap.Any("key", lockedKey)) return } keyLock.unref() @@ -125,7 +125,7 @@ func (k *KeyLock) RUnlock(lockedKey string) { keyLock.mutex.RUnlock() } -func (k *KeyLock) size() int { +func (k *KeyLock[K]) size() int { k.keyLocksMutex.Lock() defer k.keyLocksMutex.Unlock() return len(k.refLocks) diff --git a/pkg/util/lock/key_lock_test.go b/pkg/util/lock/key_lock_test.go index 9d06af0a82e4f..46002b9ed4176 100644 --- a/pkg/util/lock/key_lock_test.go +++ b/pkg/util/lock/key_lock_test.go @@ -11,7 +11,7 @@ import ( func TestKeyLock(t *testing.T) { keys := []string{"Milvus", "Blazing", "Fast"} - keyLock := NewKeyLock() + keyLock := NewKeyLock[string]() keyLock.Lock(keys[0]) keyLock.Lock(keys[1]) @@ -46,7 +46,7 @@ func TestKeyLock(t *testing.T) { func TestKeyRLock(t *testing.T) { keys := []string{"Milvus", "Blazing", "Fast"} - keyLock := NewKeyLock() + keyLock := NewKeyLock[string]() keyLock.RLock(keys[0]) keyLock.RLock(keys[0]) diff --git a/pkg/util/lock/metric_mutex.go b/pkg/util/lock/metric_mutex.go index 32c0a5117f869..4c8d4fb8efd78 100644 --- a/pkg/util/lock/metric_mutex.go +++ b/pkg/util/lock/metric_mutex.go @@ -5,10 +5,11 @@ import ( "time" "github.com/cockroachdb/errors" + "go.uber.org/zap" + "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/paramtable" - "go.uber.org/zap" ) type MetricsLockManager struct { diff --git a/pkg/util/lock/metrics_mutex_test.go b/pkg/util/lock/metrics_mutex_test.go index 8db946c9666e7..293f109b8c269 100644 --- a/pkg/util/lock/metrics_mutex_test.go +++ 
b/pkg/util/lock/metrics_mutex_test.go @@ -5,8 +5,9 @@ import ( "testing" "time" - "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/pkg/util/paramtable" ) func TestMetricsLockLock(t *testing.T) { @@ -23,15 +24,15 @@ func TestMetricsLockLock(t *testing.T) { testRWLock := lManager.applyRWLock(lName) wg := sync.WaitGroup{} testRWLock.Lock("main_thread") + wg.Add(1) go func() { - wg.Add(1) + defer wg.Done() before := time.Now() testRWLock.Lock("sub_thread") lkDuration := time.Since(before) assert.True(t, lkDuration >= lockDuration) testRWLock.UnLock("sub_threadXX") testRWLock.UnLock("sub_thread") - wg.Done() }() time.Sleep(lockDuration) testRWLock.UnLock("main_thread") @@ -52,14 +53,14 @@ func TestMetricsLockRLock(t *testing.T) { testRWLock := lManager.applyRWLock(lName) wg := sync.WaitGroup{} testRWLock.RLock("main_thread") + wg.Add(1) go func() { - wg.Add(1) + defer wg.Done() before := time.Now() testRWLock.Lock("sub_thread") lkDuration := time.Since(before) assert.True(t, lkDuration >= lockDuration) testRWLock.UnLock("sub_thread") - wg.Done() }() time.Sleep(lockDuration) assert.Equal(t, 1, len(testRWLock.acquireTimeMap)) diff --git a/pkg/util/logutil/logutil_test.go b/pkg/util/logutil/logutil_test.go index c177d8b092820..3506852327034 100644 --- a/pkg/util/logutil/logutil_test.go +++ b/pkg/util/logutil/logutil_test.go @@ -27,5 +27,4 @@ func TestName(t *testing.T) { wrapper.Error("Testing") wrapper.Errorln("Testing") wrapper.Errorf("%s", "Testing") - } diff --git a/pkg/util/merr/errors.go b/pkg/util/merr/errors.go index 92e19ddc3aa79..4348ccec7359c 100644 --- a/pkg/util/merr/errors.go +++ b/pkg/util/merr/errors.go @@ -22,7 +22,7 @@ import ( ) const ( - retriableFlag = 1 << 16 + retryableFlag = 1 << 16 CanceledCode int32 = 10000 TimeoutCode int32 = 10001 ) @@ -38,8 +38,11 @@ var ( ErrServiceMemoryLimitExceeded = newMilvusError("memory limit exceeded", 3, false) ErrServiceRequestLimitExceeded = 
newMilvusError("request limit exceeded", 4, true) ErrServiceInternal = newMilvusError("service internal error", 5, false) // Never return this error out of Milvus - ErrCrossClusterRouting = newMilvusError("cross cluster routing", 6, false) + ErrServiceCrossClusterRouting = newMilvusError("cross cluster routing", 6, false) ErrServiceDiskLimitExceeded = newMilvusError("disk limit exceeded", 7, false) + ErrServiceRateLimit = newMilvusError("rate limit exceeded", 8, true) + ErrServiceForceDeny = newMilvusError("force deny", 9, false) + ErrServiceUnimplemented = newMilvusError("service unimplemented", 10, false) // Collection related ErrCollectionNotFound = newMilvusError("collection not found", 100, false) @@ -48,9 +51,9 @@ var ( ErrCollectionNotFullyLoaded = newMilvusError("collection not fully loaded", 103, true) // Partition related - ErrPartitionNotFound = newMilvusError("partition not found", 202, false) - ErrPartitionNotLoaded = newMilvusError("partition not loaded", 203, false) - ErrPartitionNotFullyLoaded = newMilvusError("collection not fully loaded", 103, true) + ErrPartitionNotFound = newMilvusError("partition not found", 200, false) + ErrPartitionNotLoaded = newMilvusError("partition not loaded", 201, false) + ErrPartitionNotFullyLoaded = newMilvusError("partition not fully loaded", 202, true) // ResourceGroup related ErrResourceGroupNotFound = newMilvusError("resource group not found", 300, false) @@ -72,12 +75,14 @@ var ( ErrSegmentReduplicate = newMilvusError("segment reduplicates", 603, false) // Index related - ErrIndexNotFound = newMilvusError("index not found", 700, false) + ErrIndexNotFound = newMilvusError("index not found", 700, false) + ErrIndexNotSupported = newMilvusError("index type not supported", 701, false) + ErrIndexDuplicate = newMilvusError("index duplicates", 702, false) // Database related - ErrDatabaseNotfound = newMilvusError("database not found", 800, false) + ErrDatabaseNotFound = newMilvusError("database not found", 800, false) 
ErrDatabaseNumLimitExceeded = newMilvusError("exceeded the limit number of database", 801, false) - ErrInvalidedDatabaseName = newMilvusError("invalided database name", 802, false) + ErrDatabaseInvalidName = newMilvusError("invalid database name", 802, false) // Node related ErrNodeNotFound = newMilvusError("node not found", 901, false) @@ -96,15 +101,29 @@ var ( // Metrics related ErrMetricNotFound = newMilvusError("metric not found", 1200, false) - // Topic related - ErrTopicNotFound = newMilvusError("topic not found", 1300, false) - ErrTopicNotEmpty = newMilvusError("topic not empty", 1301, false) + // Message queue related + ErrMqTopicNotFound = newMilvusError("topic not found", 1300, false) + ErrMqTopicNotEmpty = newMilvusError("topic not empty", 1301, false) + ErrMqInternal = newMilvusError("message queue internal error", 1302, false) + ErrDenyProduceMsg = newMilvusError("deny to write the message to mq", 1303, false) + + // Privilege related + // this operation is denied because the user not authorized, user need to login in first + ErrPrivilegeNotAuthenticated = newMilvusError("not authenticated", 1400, false) + // this operation is denied because the user has no permission to do this, user need higher privilege + ErrPrivilegeNotPermitted = newMilvusError("privilege not permitted", 1401, false) + + // Alias related + ErrAliasNotFound = newMilvusError("alias not found", 1600, false) + ErrAliasCollectionNameConfilct = newMilvusError("alias and collection name conflict", 1601, false) + ErrAliasAlreadyExist = newMilvusError("alias already exist", 1602, false) // field related - ErrFieldNotFound = newMilvusError("field not found", 1700, false) + ErrFieldNotFound = newMilvusError("field not found", 1700, false) + ErrFieldInvalidName = newMilvusError("field name invalid", 1701, false) // high-level restful api related - ErrNeedAuthenticate = newMilvusError("user hasn't authenticate", 1800, false) + ErrNeedAuthenticate = newMilvusError("user hasn't authenticated", 
1800, false) ErrIncorrectParameterFormat = newMilvusError("can only accept json format request", 1801, false) ErrMissingRequiredParameters = newMilvusError("missing required parameters", 1802, false) ErrMarshalCollectionSchema = newMilvusError("fail to marshal collection schema", 1803, false) @@ -112,6 +131,15 @@ var ( ErrInvalidSearchResult = newMilvusError("fail to parse search result", 1805, false) ErrCheckPrimaryKey = newMilvusError("please check the primary key and its' type can only in [int, string]", 1806, false) + // replicate related + ErrDenyReplicateMessage = newMilvusError("deny to use the replicate message in the normal instance", 1900, false) + ErrInvalidMsgBytes = newMilvusError("invalid replicate msg bytes", 1901, false) + ErrNoAssignSegmentID = newMilvusError("no assign segment id", 1902, false) + ErrInvalidStreamObj = newMilvusError("invalid stream object", 1903, false) + + // Segcore related + ErrSegcore = newMilvusError("segcore error", 2000, false) + // Do NOT export this, // never allow programmer using this, keep only for converting unknown error to milvusError errUnexpected = newMilvusError("unexpected error", (1<<16)-1, false) @@ -124,7 +152,7 @@ type milvusError struct { func newMilvusError(msg string, code int32, retriable bool) milvusError { if retriable { - code |= retriableFlag + code |= retryableFlag } return milvusError{ msg: msg, diff --git a/pkg/util/merr/errors_test.go b/pkg/util/merr/errors_test.go index 4710ff7eb2921..42b2477c9321a 100644 --- a/pkg/util/merr/errors_test.go +++ b/pkg/util/merr/errors_test.go @@ -21,9 +21,9 @@ import ( "testing" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/stretchr/testify/suite" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -70,14 +70,15 @@ func (s *ErrSuite) TestStatusWithCode() { func (s *ErrSuite) TestWrap() { // Service related - s.ErrorIs(WrapErrServiceNotReady("init", 
"test init..."), ErrServiceNotReady) + s.ErrorIs(WrapErrServiceNotReady("test", 0, "test init..."), ErrServiceNotReady) s.ErrorIs(WrapErrServiceUnavailable("test", "test init"), ErrServiceUnavailable) s.ErrorIs(WrapErrServiceMemoryLimitExceeded(110, 100, "MLE"), ErrServiceMemoryLimitExceeded) s.ErrorIs(WrapErrServiceRequestLimitExceeded(100, "too many requests"), ErrServiceRequestLimitExceeded) s.ErrorIs(WrapErrServiceInternal("never throw out"), ErrServiceInternal) - s.ErrorIs(WrapErrCrossClusterRouting("ins-0", "ins-1"), ErrCrossClusterRouting) + s.ErrorIs(WrapErrServiceCrossClusterRouting("ins-0", "ins-1"), ErrServiceCrossClusterRouting) s.ErrorIs(WrapErrServiceDiskLimitExceeded(110, 100, "DLE"), ErrServiceDiskLimitExceeded) s.ErrorIs(WrapErrNodeNotMatch(0, 1, "SIM"), ErrNodeNotMatch) + s.ErrorIs(WrapErrServiceUnimplemented(errors.New("mock grpc err")), ErrServiceUnimplemented) // Collection related s.ErrorIs(WrapErrCollectionNotFound("test_collection", "failed to get collection"), ErrCollectionNotFound) @@ -109,6 +110,9 @@ func (s *ErrSuite) TestWrap() { // Index related s.ErrorIs(WrapErrIndexNotFound("failed to get Index"), ErrIndexNotFound) + s.ErrorIs(WrapErrIndexNotFoundForCollection("milvus_hello", "failed to get collection index"), ErrIndexNotFound) + s.ErrorIs(WrapErrIndexNotFoundForSegment(100, "failed to get collection index"), ErrIndexNotFound) + s.ErrorIs(WrapErrIndexNotSupported("wsnh", "failed to create index"), ErrIndexNotSupported) // Node related s.ErrorIs(WrapErrNodeNotFound(1, "failed to get node"), ErrNodeNotFound) @@ -126,14 +130,28 @@ func (s *ErrSuite) TestWrap() { // Metrics related s.ErrorIs(WrapErrMetricNotFound("unknown", "failed to get metric"), ErrMetricNotFound) - // Topic related - s.ErrorIs(WrapErrTopicNotFound("unknown", "failed to get topic"), ErrTopicNotFound) - s.ErrorIs(WrapErrTopicNotEmpty("unknown", "topic is not empty"), ErrTopicNotEmpty) + // Message queue related + s.ErrorIs(WrapErrMqTopicNotFound("unknown", "failed to 
get topic"), ErrMqTopicNotFound) + s.ErrorIs(WrapErrMqTopicNotEmpty("unknown", "topic is not empty"), ErrMqTopicNotEmpty) + s.ErrorIs(WrapErrMqInternal(errors.New("unknown"), "failed to consume"), ErrMqInternal) // field related s.ErrorIs(WrapErrFieldNotFound("meta", "failed to get field"), ErrFieldNotFound) } +func (s *ErrSuite) TestOldCode() { + s.ErrorIs(OldCodeToMerr(commonpb.ErrorCode_NotReadyServe), ErrServiceNotReady) + s.ErrorIs(OldCodeToMerr(commonpb.ErrorCode_CollectionNotExists), ErrCollectionNotFound) + s.ErrorIs(OldCodeToMerr(commonpb.ErrorCode_IllegalArgument), ErrParameterInvalid) + s.ErrorIs(OldCodeToMerr(commonpb.ErrorCode_NodeIDNotMatch), ErrNodeNotMatch) + s.ErrorIs(OldCodeToMerr(commonpb.ErrorCode_InsufficientMemoryToLoad), ErrServiceMemoryLimitExceeded) + s.ErrorIs(OldCodeToMerr(commonpb.ErrorCode_MemoryQuotaExhausted), ErrServiceMemoryLimitExceeded) + s.ErrorIs(OldCodeToMerr(commonpb.ErrorCode_DiskQuotaExhausted), ErrServiceDiskLimitExceeded) + s.ErrorIs(OldCodeToMerr(commonpb.ErrorCode_RateLimit), ErrServiceRateLimit) + s.ErrorIs(OldCodeToMerr(commonpb.ErrorCode_ForceDeny), ErrServiceForceDeny) + s.ErrorIs(OldCodeToMerr(commonpb.ErrorCode_UnexpectedError), errUnexpected) +} + func (s *ErrSuite) TestCombine() { var ( errFirst = errors.New("first") @@ -166,6 +184,46 @@ func (s *ErrSuite) TestCombineCode() { s.Equal(Code(ErrCollectionNotFound), Code(err)) } +func (s *ErrSuite) TestIsHealthy() { + type testCase struct { + code commonpb.StateCode + expect bool + } + + cases := []testCase{ + {commonpb.StateCode_Healthy, true}, + {commonpb.StateCode_Initializing, false}, + {commonpb.StateCode_Abnormal, false}, + {commonpb.StateCode_StandBy, false}, + {commonpb.StateCode_Stopping, false}, + } + for _, tc := range cases { + s.Run(tc.code.String(), func() { + s.Equal(tc.expect, IsHealthy(tc.code) == nil) + }) + } +} + +func (s *ErrSuite) TestIsHealthyOrStopping() { + type testCase struct { + code commonpb.StateCode + expect bool + } + + cases := 
[]testCase{ + {commonpb.StateCode_Healthy, true}, + {commonpb.StateCode_Initializing, false}, + {commonpb.StateCode_Abnormal, false}, + {commonpb.StateCode_StandBy, false}, + {commonpb.StateCode_Stopping, true}, + } + for _, tc := range cases { + s.Run(tc.code.String(), func() { + s.Equal(tc.expect, IsHealthyOrStopping(tc.code) == nil) + }) + } +} + func TestErrors(t *testing.T) { suite.Run(t, new(ErrSuite)) } diff --git a/pkg/util/merr/utils.go b/pkg/util/merr/utils.go index 3995299c49bcc..f5e174aab65a1 100644 --- a/pkg/util/merr/utils.go +++ b/pkg/util/merr/utils.go @@ -18,19 +18,13 @@ package merr import ( "context" - "fmt" "strings" "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" -) -var ( - // For compatibility - oldErrCodes = map[int32]commonpb.ErrorCode{ - ErrServiceNotReady.code(): commonpb.ErrorCode_NotReadyServe, - ErrCollectionNotFound.code(): commonpb.ErrorCode_CollectionNotExists, - } + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) // Code returns the error code of the given error, @@ -56,8 +50,12 @@ func Code(err error) int32 { } } -func IsRetriable(err error) bool { - return Code(err)&retriableFlag != 0 +func IsRetryableErr(err error) bool { + return IsRetryableCode(Code(err)) +} + +func IsRetryableCode(code int32) bool { + return code&retryableFlag != 0 } func IsCanceledOrTimeout(err error) bool { @@ -80,6 +78,30 @@ func Status(err error) *commonpb.Status { } } +func CheckRPCCall(resp any, err error) error { + if err != nil { + return err + } + if resp == nil { + return errUnexpected + } + switch resp := resp.(type) { + case interface{ GetStatus() *commonpb.Status }: + return Error(resp.GetStatus()) + case *commonpb.Status: + return Error(resp) + } + return nil +} + +func Success(reason ...string) *commonpb.Status { + status := Status(nil) + // NOLINT + status.Reason = strings.Join(reason, " 
") + return status +} + +// Deprecated func StatusWithErrorCode(err error, code commonpb.ErrorCode) *commonpb.Status { if err == nil { return &commonpb.Status{} @@ -96,25 +118,87 @@ func oldCode(code int32) commonpb.ErrorCode { switch code { case ErrServiceNotReady.code(): return commonpb.ErrorCode_NotReadyServe + case ErrCollectionNotFound.code(): return commonpb.ErrorCode_CollectionNotExists + case ErrParameterInvalid.code(): return commonpb.ErrorCode_IllegalArgument + case ErrNodeNotMatch.code(): return commonpb.ErrorCode_NodeIDNotMatch + case ErrCollectionNotFound.code(), ErrPartitionNotFound.code(), ErrReplicaNotFound.code(): return commonpb.ErrorCode_MetaFailed + case ErrReplicaNotAvailable.code(), ErrChannelNotAvailable.code(), ErrNodeNotAvailable.code(): return commonpb.ErrorCode_NoReplicaAvailable + case ErrServiceMemoryLimitExceeded.code(): return commonpb.ErrorCode_InsufficientMemoryToLoad + + case ErrServiceRateLimit.code(): + return commonpb.ErrorCode_RateLimit + + case ErrServiceForceDeny.code(): + return commonpb.ErrorCode_ForceDeny + + case ErrIndexNotFound.code(): + return commonpb.ErrorCode_IndexNotExist + + case ErrSegmentNotFound.code(): + return commonpb.ErrorCode_SegmentNotFound + + case ErrChannelLack.code(): + return commonpb.ErrorCode_MetaFailed + default: return commonpb.ErrorCode_UnexpectedError } } +func OldCodeToMerr(code commonpb.ErrorCode) error { + switch code { + case commonpb.ErrorCode_NotReadyServe: + return ErrServiceNotReady + + case commonpb.ErrorCode_CollectionNotExists: + return ErrCollectionNotFound + + case commonpb.ErrorCode_IllegalArgument: + return ErrParameterInvalid + + case commonpb.ErrorCode_NodeIDNotMatch: + return ErrNodeNotMatch + + case commonpb.ErrorCode_InsufficientMemoryToLoad, commonpb.ErrorCode_MemoryQuotaExhausted: + return ErrServiceMemoryLimitExceeded + + case commonpb.ErrorCode_DiskQuotaExhausted: + return ErrServiceDiskLimitExceeded + + case commonpb.ErrorCode_RateLimit: + return ErrServiceRateLimit + + 
case commonpb.ErrorCode_ForceDeny: + return ErrServiceForceDeny + + case commonpb.ErrorCode_IndexNotExist: + return ErrIndexNotFound + + case commonpb.ErrorCode_SegmentNotFound: + return ErrSegmentNotFound + + case commonpb.ErrorCode_MetaFailed: + return ErrChannelNotFound + + default: + return errUnexpected + } +} + func Ok(status *commonpb.Status) bool { - return status.ErrorCode == commonpb.ErrorCode_Success && status.Code == 0 + return status.GetErrorCode() == commonpb.ErrorCode_Success && status.GetCode() == 0 } // Error returns a error according to the given status, @@ -127,10 +211,9 @@ func Error(status *commonpb.Status) error { // use code first code := status.GetCode() if code == 0 { - return newMilvusError(fmt.Sprintf("legacy error code:%d, reason: %s", status.GetErrorCode(), status.GetReason()), errUnexpected.errCode, false) + return newMilvusError(status.GetReason(), Code(OldCodeToMerr(status.GetErrorCode())), false) } - - return newMilvusError(status.GetReason(), code, code&retriableFlag != 0) + return newMilvusError(status.GetReason(), code, code&retryableFlag != 0) } // CheckHealthy checks whether the state is healthy, @@ -138,15 +221,59 @@ func Error(status *commonpb.Status) error { // otherwise returns ErrServiceNotReady wrapped with current state func CheckHealthy(state commonpb.StateCode) error { if state != commonpb.StateCode_Healthy { - return WrapErrServiceNotReady(state.String()) + return WrapErrServiceNotReady(paramtable.GetRole(), paramtable.GetNodeID(), state.String()) + } + + return nil +} + +// CheckHealthyStandby checks whether the state is healthy or standby, +// returns nil if healthy or standby +// otherwise returns ErrServiceNotReady wrapped with current state +// this method only used in GetMetrics +func CheckHealthyStandby(state commonpb.StateCode) error { + if state != commonpb.StateCode_Healthy && state != commonpb.StateCode_StandBy { + return WrapErrServiceNotReady(paramtable.GetRole(), paramtable.GetNodeID(), state.String()) + 
} + + return nil +} + +func IsHealthy(stateCode commonpb.StateCode) error { + if stateCode == commonpb.StateCode_Healthy { + return nil + } + return CheckHealthy(stateCode) +} + +func IsHealthyOrStopping(stateCode commonpb.StateCode) error { + if stateCode == commonpb.StateCode_Healthy || stateCode == commonpb.StateCode_Stopping { + return nil + } + return CheckHealthy(stateCode) +} + +func AnalyzeState(role string, nodeID int64, state *milvuspb.ComponentStates) error { + if err := Error(state.GetStatus()); err != nil { + return errors.Wrapf(err, "%s=%d not healthy", role, nodeID) + } else if state := state.GetState().GetStateCode(); state != commonpb.StateCode_Healthy { + return WrapErrServiceNotReady(role, nodeID, state.String()) + } + + return nil +} + +func CheckTargetID(msg *commonpb.MsgBase) error { + if msg.GetTargetID() != paramtable.GetNodeID() { + return WrapErrNodeNotMatch(paramtable.GetNodeID(), msg.GetTargetID()) } return nil } // Service related -func WrapErrServiceNotReady(stage string, msg ...string) error { - err := errors.Wrapf(ErrServiceNotReady, "stage=%s", stage) +func WrapErrServiceNotReady(role string, sessionID int64, state string, msg ...string) error { + err := errors.Wrapf(ErrServiceNotReady, "%s=%d stage=%s", role, sessionID, state) if len(msg) > 0 { err = errors.Wrap(err, strings.Join(msg, "; ")) } @@ -184,8 +311,8 @@ func WrapErrServiceInternal(msg string, others ...string) error { return err } -func WrapErrCrossClusterRouting(expectedCluster, actualCluster string, msg ...string) error { - err := errors.Wrapf(ErrCrossClusterRouting, "expectedCluster=%s, actualCluster=%s", expectedCluster, actualCluster) +func WrapErrServiceCrossClusterRouting(expectedCluster, actualCluster string, msg ...string) error { + err := errors.Wrapf(ErrServiceCrossClusterRouting, "expectedCluster=%s, actualCluster=%s", expectedCluster, actualCluster) if len(msg) > 0 { err = errors.Wrap(err, strings.Join(msg, "; ")) } @@ -200,8 +327,24 @@ func 
WrapErrServiceDiskLimitExceeded(predict, limit float32, msg ...string) erro return err } +func WrapErrServiceRateLimit(rate float64) error { + err := errors.Wrapf(ErrServiceRateLimit, "rate=%v", rate) + return err +} + +func WrapErrServiceForceDeny(op string, reason error, method string) error { + err := errors.Wrapf(ErrServiceForceDeny, "deny to %s, reason: %s, req: %s", op, reason.Error(), method) + return err +} + +func WrapErrServiceUnimplemented(grpcErr error) error { + err := errors.Wrapf(ErrServiceUnimplemented, "err: %s", grpcErr.Error()) + return err +} + +// database related func WrapErrDatabaseNotFound(database any, msg ...string) error { - err := wrapWithField(ErrDatabaseNotfound, "database", database) + err := wrapWithField(ErrDatabaseNotFound, "database", database) if len(msg) > 0 { err = errors.Wrap(err, strings.Join(msg, "; ")) } @@ -216,8 +359,8 @@ func WrapErrDatabaseResourceLimitExceeded(msg ...string) error { return err } -func WrapErrInvalidedDatabaseName(database any, msg ...string) error { - err := wrapWithField(ErrInvalidedDatabaseName, "database", database) +func WrapErrDatabaseNameInvalid(database any, msg ...string) error { + err := wrapWithField(ErrDatabaseInvalidName, "database", database) if len(msg) > 0 { err = errors.Wrap(err, strings.Join(msg, "; ")) } @@ -265,6 +408,30 @@ func WrapErrCollectionNotFullyLoaded(collection any, msg ...string) error { return err } +func WrapErrAliasNotFound(db any, alias any, msg ...string) error { + err := errors.Wrapf(ErrAliasNotFound, "alias %v:%v", db, alias) + if len(msg) > 0 { + err = errors.Wrap(err, strings.Join(msg, "; ")) + } + return err +} + +func WrapErrAliasCollectionNameConflict(db any, alias any, msg ...string) error { + err := errors.Wrapf(ErrAliasCollectionNameConfilct, "alias %v:%v", db, alias) + if len(msg) > 0 { + err = errors.Wrap(err, strings.Join(msg, "; ")) + } + return err +} + +func WrapErrAliasAlreadyExist(db any, alias any, msg ...string) error { + err := 
errors.Wrapf(ErrAliasAlreadyExist, "alias %v:%v already exist", db, alias) + if len(msg) > 0 { + err = errors.Wrap(err, strings.Join(msg, "; ")) + } + return err +} + // Partition related func WrapErrPartitionNotFound(partition any, msg ...string) error { err := wrapWithField(ErrPartitionNotFound, "partition", partition) @@ -383,8 +550,40 @@ func WrapErrSegmentReduplicate(id int64, msg ...string) error { } // Index related -func WrapErrIndexNotFound(msg ...string) error { - err := error(ErrIndexNotFound) +func WrapErrIndexNotFound(indexName string, msg ...string) error { + err := wrapWithField(ErrIndexNotFound, "indexName", indexName) + if len(msg) > 0 { + err = errors.Wrap(err, strings.Join(msg, "; ")) + } + return err +} + +func WrapErrIndexNotFoundForSegment(segmentID int64, msg ...string) error { + err := wrapWithField(ErrIndexNotFound, "segmentID", segmentID) + if len(msg) > 0 { + err = errors.Wrap(err, strings.Join(msg, "; ")) + } + return err +} + +func WrapErrIndexNotFoundForCollection(collection string, msg ...string) error { + err := wrapWithField(ErrIndexNotFound, "collection", collection) + if len(msg) > 0 { + err = errors.Wrap(err, strings.Join(msg, "; ")) + } + return err +} + +func WrapErrIndexNotSupported(indexType string, msg ...string) error { + err := wrapWithField(ErrIndexNotSupported, "indexType", indexType) + if len(msg) > 0 { + err = errors.Wrap(err, strings.Join(msg, "; ")) + } + return err +} + +func WrapErrIndexDuplicate(indexName string, msg ...string) error { + err := wrapWithField(ErrIndexDuplicate, "indexName", indexName) if len(msg) > 0 { err = errors.Wrap(err, strings.Join(msg, "; ")) } @@ -416,6 +615,14 @@ func WrapErrNodeLack(expectedNum, actualNum int64, msg ...string) error { return err } +func WrapErrNodeLackAny(msg ...string) error { + err := error(ErrNodeLack) + if len(msg) > 0 { + err = errors.Wrap(err, strings.Join(msg, "; ")) + } + return err +} + func WrapErrNodeNotAvailable(id int64, msg ...string) error { err := 
wrapWithField(ErrNodeNotAvailable, "node", id) if len(msg) > 0 { @@ -449,6 +656,14 @@ func WrapErrIoFailed(key string, msg ...string) error { return err } +func WrapErrIoFailedReason(reason string, msg ...string) error { + err := errors.Wrapf(ErrIoFailed, reason) + if len(msg) > 0 { + err = errors.Wrap(err, strings.Join(msg, "; ")) + } + return err +} + // Parameter related func WrapErrParameterInvalid[T any](expected, actual T, msg ...string) error { err := errors.Wrapf(ErrParameterInvalid, "expected=%v, actual=%v", expected, actual) @@ -480,17 +695,44 @@ func WrapErrMetricNotFound(name string, msg ...string) error { return err } -// Topic related -func WrapErrTopicNotFound(name string, msg ...string) error { - err := errors.Wrapf(ErrTopicNotFound, "topic=%s", name) +// Message queue related +func WrapErrMqTopicNotFound(name string, msg ...string) error { + err := errors.Wrapf(ErrMqTopicNotFound, "topic=%s", name) + if len(msg) > 0 { + err = errors.Wrap(err, strings.Join(msg, "; ")) + } + return err +} + +func WrapErrMqTopicNotEmpty(name string, msg ...string) error { + err := errors.Wrapf(ErrMqTopicNotEmpty, "topic=%s", name) if len(msg) > 0 { err = errors.Wrap(err, strings.Join(msg, "; ")) } return err } -func WrapErrTopicNotEmpty(name string, msg ...string) error { - err := errors.Wrapf(ErrTopicNotEmpty, "topic=%s", name) +func WrapErrMqInternal(err error, msg ...string) error { + err = errors.Wrapf(ErrMqInternal, "internal=%v", err) + if len(msg) > 0 { + err = errors.Wrap(err, strings.Join(msg, "; ")) + } + return err +} + +func WrapErrPrivilegeNotAuthenticated(fmt string, args ...any) error { + err := errors.Wrapf(ErrPrivilegeNotAuthenticated, fmt, args...) + return err +} + +func WrapErrPrivilegeNotPermitted(fmt string, args ...any) error { + err := errors.Wrapf(ErrPrivilegeNotPermitted, fmt, args...) 
+ return err +} + +// Segcore related +func WrapErrSegcore(code int32, msg ...string) error { + err := errors.Wrapf(ErrSegcore, "internal code=%v", code) if len(msg) > 0 { err = errors.Wrap(err, strings.Join(msg, "; ")) } @@ -506,6 +748,14 @@ func WrapErrFieldNotFound[T any](field T, msg ...string) error { return err } +func WrapErrFieldNameInvalid(field any, msg ...string) error { + err := wrapWithField(ErrFieldInvalidName, "field", field) + if len(msg) > 0 { + err = errors.Wrap(err, strings.Join(msg, "; ")) + } + return err +} + func wrapWithField(err error, name string, value any) error { return errors.Wrapf(err, "%s=%v", name, value) } diff --git a/pkg/util/metricsinfo/cache.go b/pkg/util/metricsinfo/cache.go index 0dbb817f04496..aae12709e452d 100644 --- a/pkg/util/metricsinfo/cache.go +++ b/pkg/util/metricsinfo/cache.go @@ -106,7 +106,6 @@ func (manager *MetricsCacheManager) GetSystemInfoMetrics() (*milvuspb.GetMetrics if manager.systemInfoMetricsInvalid || manager.systemInfoMetrics == nil || time.Since(manager.systemInfoMetricsLastUpdatedTime) >= retention { - return nil, errInvalidSystemInfosMetricCache } diff --git a/pkg/util/metricsinfo/cache_test.go b/pkg/util/metricsinfo/cache_test.go index 97694ed25bac3..dfbddeb4b3039 100644 --- a/pkg/util/metricsinfo/cache_test.go +++ b/pkg/util/metricsinfo/cache_test.go @@ -15,8 +15,9 @@ import ( "testing" "time" - "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" ) func Test_NewMetricsCacheManager(t *testing.T) { diff --git a/pkg/util/metricsinfo/metric_type.go b/pkg/util/metricsinfo/metric_type.go index 85e0b183b99b0..60e050315293e 100644 --- a/pkg/util/metricsinfo/metric_type.go +++ b/pkg/util/metricsinfo/metric_type.go @@ -17,7 +17,6 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" - "github.com/milvus-io/milvus/pkg/common" 
"github.com/milvus-io/milvus/pkg/util/commonpbutil" ) @@ -52,7 +51,7 @@ func ConstructRequestByMetricType(metricType string) (*milvuspb.GetMetricsReques if err != nil { return nil, fmt.Errorf("failed to construct request by metric type %s: %s", metricType, err.Error()) } - //TODO:: switch metricType to different msgType and return err when metricType is not supported + // TODO:: switch metricType to different msgType and return err when metricType is not supported return &milvuspb.GetMetricsRequest{ Base: commonpbutil.NewMsgBase( commonpbutil.WithMsgType(commonpb.MsgType_SystemInfo), diff --git a/pkg/util/metricsinfo/metric_type_test.go b/pkg/util/metricsinfo/metric_type_test.go index 07b62df4365a8..12414d233cd9e 100644 --- a/pkg/util/metricsinfo/metric_type_test.go +++ b/pkg/util/metricsinfo/metric_type_test.go @@ -58,7 +58,6 @@ func Test_ParseMetricType(t *testing.T) { t.Errorf("ParseMetricType(%s) = %s, but got: %s", test.s, test.want, got) } } - } func Test_ConstructRequestByMetricType(t *testing.T) { diff --git a/pkg/util/metricsinfo/metrics_info.go b/pkg/util/metricsinfo/metrics_info.go index a8f2420a3d9c8..7673e5d0e8f37 100644 --- a/pkg/util/metricsinfo/metrics_info.go +++ b/pkg/util/metricsinfo/metrics_info.go @@ -16,8 +16,7 @@ import ( ) // ComponentInfos defines the interface of all component infos -type ComponentInfos interface { -} +type ComponentInfos interface{} // MarshalComponentInfos returns the json string of ComponentInfos func MarshalComponentInfos(infos ComponentInfos) (string, error) { diff --git a/pkg/util/metricsinfo/topology.go b/pkg/util/metricsinfo/topology.go index a9d8810aefd6a..774cfbcb27cab 100644 --- a/pkg/util/metricsinfo/topology.go +++ b/pkg/util/metricsinfo/topology.go @@ -27,8 +27,7 @@ func ConstructComponentName(role string, id typeutil.UniqueID) string { } // Topology defines the interface of topology graph between different components -type Topology interface { -} +type Topology interface{} // MarshalTopology returns the json 
string of Topology func MarshalTopology(topology Topology) (string, error) { diff --git a/pkg/util/parameterutil.go/get_max_len.go b/pkg/util/parameterutil.go/get_max_len.go index 58b3645a108b6..fc1aa6d7a3a46 100644 --- a/pkg/util/parameterutil.go/get_max_len.go +++ b/pkg/util/parameterutil.go/get_max_len.go @@ -12,7 +12,7 @@ import ( // GetMaxLength get max length of field. Maybe also helpful outside. func GetMaxLength(field *schemapb.FieldSchema) (int64, error) { - if !typeutil.IsStringType(field.GetDataType()) { + if !typeutil.IsStringType(field.GetDataType()) && !typeutil.IsStringType(field.GetElementType()) { msg := fmt.Sprintf("%s is not of string type", field.GetDataType()) return 0, merr.WrapErrParameterInvalid(schemapb.DataType_VarChar, field.GetDataType(), msg) } @@ -29,3 +29,23 @@ func GetMaxLength(field *schemapb.FieldSchema) (int64, error) { } return int64(maxLength), nil } + +// GetMaxCapacity get max capacity of array field. Maybe also helpful outside. +func GetMaxCapacity(field *schemapb.FieldSchema) (int64, error) { + if !typeutil.IsArrayType(field.GetDataType()) { + msg := fmt.Sprintf("%s is not of array type", field.GetDataType()) + return 0, merr.WrapErrParameterInvalid(schemapb.DataType_Array, field.GetDataType(), msg) + } + h := typeutil.NewKvPairs(append(field.GetIndexParams(), field.GetTypeParams()...)) + maxCapacityStr, err := h.Get(common.MaxCapacityKey) + if err != nil { + msg := "max capacity not found" + return 0, merr.WrapErrParameterInvalid("max capacity key in type parameters", "not found", msg) + } + maxCapacity, err := strconv.Atoi(maxCapacityStr) + if err != nil { + msg := fmt.Sprintf("invalid max capacity: %s", maxCapacityStr) + return 0, merr.WrapErrParameterInvalid("value of max length should be of int", maxCapacityStr, msg) + } + return int64(maxCapacity), nil +} diff --git a/pkg/util/parameterutil.go/get_max_len_test.go b/pkg/util/parameterutil.go/get_max_len_test.go index dc94744c9bb1b..cf27715fe03c0 100644 --- 
a/pkg/util/parameterutil.go/get_max_len_test.go +++ b/pkg/util/parameterutil.go/get_max_len_test.go @@ -3,12 +3,11 @@ package parameterutil import ( "testing" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus/pkg/common" - "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/common" ) func TestGetMaxLength(t *testing.T) { @@ -57,3 +56,53 @@ func TestGetMaxLength(t *testing.T) { assert.Equal(t, int64(100), maxLength) }) } + +func TestGetMaxCapacity(t *testing.T) { + t.Run("not array type", func(t *testing.T) { + f := &schemapb.FieldSchema{ + DataType: schemapb.DataType_Bool, + } + _, err := GetMaxCapacity(f) + assert.Error(t, err) + }) + + t.Run("max capacity not found", func(t *testing.T) { + f := &schemapb.FieldSchema{ + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Double, + } + _, err := GetMaxCapacity(f) + assert.Error(t, err) + }) + + t.Run("max capacity not int", func(t *testing.T) { + f := &schemapb.FieldSchema{ + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int64, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxCapacityKey, + Value: "not_int_aha", + }, + }, + } + _, err := GetMaxCapacity(f) + assert.Error(t, err) + }) + + t.Run("normal case", func(t *testing.T) { + f := &schemapb.FieldSchema{ + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int64, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxCapacityKey, + Value: "100", + }, + }, + } + maxCap, err := GetMaxCapacity(f) + assert.NoError(t, err) + assert.Equal(t, int64(100), maxCap) + }) +} diff --git a/pkg/util/paramtable/autoindex_param.go b/pkg/util/paramtable/autoindex_param.go index 9fff408ac4a5e..cdc4f5289fd32 100644 --- a/pkg/util/paramtable/autoindex_param.go +++ b/pkg/util/paramtable/autoindex_param.go @@ -19,10 +19,9 @@ package paramtable 
import ( "fmt" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/config" "github.com/milvus-io/milvus/pkg/util/funcutil" - - "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/indexparamcheck" ) diff --git a/pkg/util/paramtable/autoindex_param_test.go b/pkg/util/paramtable/autoindex_param_test.go index 0dcdd994544e9..1c4f262a4736c 100644 --- a/pkg/util/paramtable/autoindex_param_test.go +++ b/pkg/util/paramtable/autoindex_param_test.go @@ -21,13 +21,11 @@ import ( "strconv" "testing" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" - - "github.com/milvus-io/milvus/pkg/config" - "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/config" + "github.com/milvus-io/milvus/pkg/util/indexparamcheck" ) const ( @@ -41,8 +39,8 @@ func TestAutoIndexParams_build(t *testing.T) { CParams.Init(bt) t.Run("test parseBuildParams success", func(t *testing.T) { - //Params := CParams.AutoIndexConfig - //buildParams := make([string]interface) + // Params := CParams.AutoIndexConfig + // buildParams := make([string]interface) var err error map1 := map[string]any{ IndexTypeKey: "HNSW", diff --git a/pkg/util/paramtable/base_table.go b/pkg/util/paramtable/base_table.go index 697df59ceb032..db441b04b9545 100644 --- a/pkg/util/paramtable/base_table.go +++ b/pkg/util/paramtable/base_table.go @@ -51,11 +51,12 @@ const ( DefaultKnowhereThreadPoolNumRatioInBuild = 1 DefaultMinioRegion = "" DefaultMinioUseVirtualHost = "false" + DefaultMinioRequestTimeout = "3000" ) // Const of Global Config List func globalConfigPrefixs() []string { - return []string{"metastore", "localStorage", "etcd", "minio", "pulsar", "kafka", "rocksmq", "log", "grpc", "common", "quotaAndLimits"} + return []string{"metastore", "localStorage", "etcd", "tikv", "minio", "pulsar", "kafka", "rocksmq", "log", "grpc", "common", "quotaAndLimits"} } var defaultYaml = []string{"milvus.yaml"} @@ -146,7 +147,6 @@ func 
(bt *BaseTable) init() { if !bt.config.skipRemote { bt.initConfigsFromRemote() } - log.Info("Got Config", zap.Any("configs", bt.mgr.GetConfigs())) } func (bt *BaseTable) initConfigsFromLocal() { diff --git a/pkg/util/paramtable/base_table_test.go b/pkg/util/paramtable/base_table_test.go index bcba2b03647da..5fe37cb51ac5b 100644 --- a/pkg/util/paramtable/base_table_test.go +++ b/pkg/util/paramtable/base_table_test.go @@ -110,7 +110,7 @@ func TestBaseTable_Get(t *testing.T) { } func TestBaseTable_Pulsar(t *testing.T) { - //test PULSAR ADDRESS + // test PULSAR ADDRESS t.Setenv("PULSAR_ADDRESS", "pulsar://localhost:6650") baseParams.init() diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go index 0ac622d76a75a..61f3e105928fb 100644 --- a/pkg/util/paramtable/component_param.go +++ b/pkg/util/paramtable/component_param.go @@ -33,9 +33,9 @@ const ( DefaultIndexSliceSize = 16 DefaultGracefulTime = 5000 // ms DefaultGracefulStopTimeout = 1800 // s - DefaultHighPriorityThreadCoreCoefficient = 100 - DefaultMiddlePriorityThreadCoreCoefficient = 50 - DefaultLowPriorityThreadCoreCoefficient = 10 + DefaultHighPriorityThreadCoreCoefficient = 10 + DefaultMiddlePriorityThreadCoreCoefficient = 5 + DefaultLowPriorityThreadCoreCoefficient = 1 DefaultSessionTTL = 60 // s DefaultSessionRetryTimes = 30 @@ -162,6 +162,7 @@ type commonConfig struct { RootCoordTimeTick ParamItem `refreshable:"true"` RootCoordStatistics ParamItem `refreshable:"true"` RootCoordDml ParamItem `refreshable:"false"` + ReplicateMsgChannel ParamItem `refreshable:"false"` QueryCoordTimeTick ParamItem `refreshable:"true"` @@ -212,12 +213,13 @@ type commonConfig struct { MetricsPort ParamItem `refreshable:"false"` - //lock related params + // lock related params EnableLockMetrics ParamItem `refreshable:"false"` LockSlowLogInfoThreshold ParamItem `refreshable:"true"` LockSlowLogWarnThreshold ParamItem `refreshable:"true"` EnableStorageV2 ParamItem `refreshable:"false"` + 
TTMsgEnabled ParamItem `refreshable:"true"` } func (p *commonConfig) init(base *BaseTable) { @@ -268,6 +270,16 @@ func (p *commonConfig) init(base *BaseTable) { } p.RootCoordDml.Init(base.mgr) + p.ReplicateMsgChannel = ParamItem{ + Key: "msgChannel.chanNamePrefix.replicateMsg", + Version: "2.3.2", + FallbackKeys: []string{"common.chanNamePrefix.replicateMsg"}, + PanicIfEmpty: true, + Formatter: chanNamePrefix, + Export: true, + } + p.ReplicateMsgChannel.Init(base.mgr) + p.QueryCoordTimeTick = ParamItem{ Key: "msgChannel.chanNamePrefix.queryTimeTick", Version: "2.1.0", @@ -613,14 +625,22 @@ like the old password verification when updating the credential`, Doc: "minimum milliseconds for printing durations in warn level", Export: true, } - p.LockSlowLogWarnThreshold.Init(base.mgr) + p.EnableStorageV2 = ParamItem{ Key: "common.storage.enablev2", Version: "2.3.1", DefaultValue: "false", } p.EnableStorageV2.Init(base.mgr) + + p.TTMsgEnabled = ParamItem{ + Key: "common.ttMsgEnabled", + Version: "2.3.2", + DefaultValue: "true", + Doc: "Whether the instance disable sending ts messages", + } + p.TTMsgEnabled.Init(base.mgr) } type traceConfig struct { @@ -982,7 +1002,7 @@ So adjust at your risk!`, p.MaxTaskNum = ParamItem{ Key: "proxy.maxTaskNum", Version: "2.2.0", - DefaultValue: "1024", + DefaultValue: "10000", Doc: "max task number of proxy task queue", Export: true, } @@ -1096,7 +1116,7 @@ please adjust in embedded Milvus: false`, p.ShardLeaderCacheInterval = ParamItem{ Key: "proxy.shardLeaderCacheInterval", Version: "2.2.4", - DefaultValue: "10", + DefaultValue: "3", Doc: "time interval to update shard leader cache, in seconds", } p.ShardLeaderCacheInterval.Init(base.mgr) @@ -1152,11 +1172,11 @@ type queryCoordConfig struct { TaskMergeCap ParamItem `refreshable:"false"` TaskExecutionCap ParamItem `refreshable:"true"` - //---- Handoff --- - //Deprecated: Since 2.2.2 + // ---- Handoff --- + // Deprecated: Since 2.2.2 AutoHandoff ParamItem `refreshable:"true"` - //---- 
Balance --- + // ---- Balance --- AutoBalance ParamItem `refreshable:"true"` Balancer ParamItem `refreshable:"true"` GlobalRowCountFactor ParamItem `refreshable:"true"` @@ -1183,18 +1203,19 @@ type queryCoordConfig struct { // Deprecated: Since 2.2.2, use different interval for different checker CheckInterval ParamItem `refreshable:"true"` - NextTargetSurviveTime ParamItem `refreshable:"true"` - UpdateNextTargetInterval ParamItem `refreshable:"false"` - CheckNodeInReplicaInterval ParamItem `refreshable:"false"` - CheckResourceGroupInterval ParamItem `refreshable:"false"` - EnableRGAutoRecover ParamItem `refreshable:"true"` - CheckHealthInterval ParamItem `refreshable:"false"` - CheckHealthRPCTimeout ParamItem `refreshable:"true"` - BrokerTimeout ParamItem `refreshable:"false"` + NextTargetSurviveTime ParamItem `refreshable:"true"` + UpdateNextTargetInterval ParamItem `refreshable:"false"` + CheckNodeInReplicaInterval ParamItem `refreshable:"false"` + CheckResourceGroupInterval ParamItem `refreshable:"false"` + EnableRGAutoRecover ParamItem `refreshable:"true"` + CheckHealthInterval ParamItem `refreshable:"false"` + CheckHealthRPCTimeout ParamItem `refreshable:"true"` + BrokerTimeout ParamItem `refreshable:"false"` + CollectionRecoverTimesLimit ParamItem `refreshable:"true"` } func (p *queryCoordConfig) init(base *BaseTable) { - //---- Task --- + // ---- Task --- p.RetryNum = ParamItem{ Key: "queryCoord.task.retrynum", Version: "2.2.0", @@ -1492,6 +1513,16 @@ func (p *queryCoordConfig) init(base *BaseTable) { Export: true, } p.BrokerTimeout.Init(base.mgr) + + p.CollectionRecoverTimesLimit = ParamItem{ + Key: "queryCoord.collectionRecoverTimes", + Version: "2.3.3", + DefaultValue: "3", + PanicIfEmpty: true, + Doc: "if collection recover times reach the limit during loading state, release it", + Export: true, + } + p.CollectionRecoverTimesLimit.Init(base.mgr) } // ///////////////////////////////////////////////////////////////////////////// @@ -1527,6 +1558,9 @@ type 
queryNodeConfig struct { CacheMemoryLimit ParamItem `refreshable:"false"` MmapDirPath ParamItem `refreshable:"false"` + // chunk cache + ReadAheadPolicy ParamItem `refreshable:"false"` + GroupEnabled ParamItem `refreshable:"true"` MaxReceiveChanSize ParamItem `refreshable:"false"` MaxUnsolvedQueueSize ParamItem `refreshable:"true"` @@ -1713,6 +1747,14 @@ func (p *queryNodeConfig) init(base *BaseTable) { } p.MmapDirPath.Init(base.mgr) + p.ReadAheadPolicy = ParamItem{ + Key: "queryNode.cache.readAheadPolicy", + Version: "2.3.2", + DefaultValue: "willneed", + Doc: "The read ahead policy of chunk cache, options: `normal, random, sequential, willneed, dontneed`", + } + p.ReadAheadPolicy.Init(base.mgr) + p.GroupEnabled = ParamItem{ Key: "queryNode.grouping.enabled", Version: "2.0.0", @@ -1937,6 +1979,7 @@ type dataCoordConfig struct { WatchTimeoutInterval ParamItem `refreshable:"false"` ChannelBalanceSilentDuration ParamItem `refreshable:"true"` ChannelBalanceInterval ParamItem `refreshable:"true"` + ChannelOperationRPCTimeout ParamItem `refreshable:"true"` // --- SEGMENTS --- SegmentMaxSize ParamItem `refreshable:"false"` @@ -2014,6 +2057,15 @@ func (p *dataCoordConfig) init(base *BaseTable) { } p.ChannelBalanceInterval.Init(base.mgr) + p.ChannelOperationRPCTimeout = ParamItem{ + Key: "dataCoord.channel.notifyChannelOperationTimeout", + Version: "2.2.3", + DefaultValue: "5", + Doc: "Timeout notifing channel operations (in seconds).", + Export: true, + } + p.ChannelOperationRPCTimeout.Init(base.mgr) + p.SegmentMaxSize = ParamItem{ Key: "dataCoord.segment.maxSize", Version: "2.0.0", @@ -2348,9 +2400,12 @@ type dataNodeConfig struct { // watchEvent WatchEventTicklerInterval ParamItem `refreshable:"false"` - // io concurrency to fetch stats logs + // io concurrency to add segment IOConcurrency ParamItem `refreshable:"false"` + // Concurrency to handle compaction file read + FileReadConcurrency ParamItem `refreshable:"false"` + // memory management MemoryForceSyncEnable 
ParamItem `refreshable:"true"` MemoryForceSyncSegmentNum ParamItem `refreshable:"true"` @@ -2365,6 +2420,9 @@ type dataNodeConfig struct { // Skip BF SkipBFStatsLoad ParamItem `refreshable:"true"` + + // channel + ChannelWorkPoolSize ParamItem `refreshable:"true"` } func (p *dataNodeConfig) init(base *BaseTable) { @@ -2477,10 +2535,17 @@ func (p *dataNodeConfig) init(base *BaseTable) { p.IOConcurrency = ParamItem{ Key: "dataNode.dataSync.ioConcurrency", Version: "2.0.0", - DefaultValue: "10", + DefaultValue: "16", } p.IOConcurrency.Init(base.mgr) + p.FileReadConcurrency = ParamItem{ + Key: "dataNode.multiRead.concurrency", + Version: "2.0.0", + DefaultValue: "16", + } + p.FileReadConcurrency.Init(base.mgr) + p.DataNodeTimeTickByRPC = ParamItem{ Key: "datanode.timetick.byRPC", Version: "2.2.9", @@ -2512,6 +2577,14 @@ func (p *dataNodeConfig) init(base *BaseTable) { DefaultValue: "18000", } p.BulkInsertTimeoutSeconds.Init(base.mgr) + + p.ChannelWorkPoolSize = ParamItem{ + Key: "datanode.channel.workPoolSize", + Version: "2.3.2", + PanicIfEmpty: false, + DefaultValue: "-1", + } + p.ChannelWorkPoolSize.Init(base.mgr) } // ///////////////////////////////////////////////////////////////////////////// @@ -2600,12 +2673,15 @@ func (p *integrationTestConfig) init(base *BaseTable) { func (params *ComponentParam) Save(key string, value string) error { return params.baseTable.Save(key, value) } + func (params *ComponentParam) Remove(key string) error { return params.baseTable.Remove(key) } + func (params *ComponentParam) Reset(key string) error { return params.baseTable.Reset(key) } + func (params *ComponentParam) GetWithDefault(key string, dft string) string { return params.baseTable.GetWithDefault(key, dft) } diff --git a/pkg/util/paramtable/component_param_test.go b/pkg/util/paramtable/component_param_test.go index 5d7d4e196884f..583edf97edee4 100644 --- a/pkg/util/paramtable/component_param_test.go +++ b/pkg/util/paramtable/component_param_test.go @@ -278,6 +278,7 @@ func 
TestComponentParam(t *testing.T) { assert.Equal(t, 1000, Params.ChannelCheckInterval.GetAsInt()) assert.Equal(t, 10000, Params.BalanceCheckInterval.GetAsInt()) assert.Equal(t, 10000, Params.IndexCheckInterval.GetAsInt()) + assert.Equal(t, 3, Params.CollectionRecoverTimesLimit.GetAsInt()) }) t.Run("test queryNodeConfig", func(t *testing.T) { @@ -308,6 +309,9 @@ func TestComponentParam(t *testing.T) { assert.Equal(t, 10.0, Params.CPURatio.GetAsFloat()) assert.Equal(t, uint32(runtime.GOMAXPROCS(0)), Params.KnowhereThreadPoolSize.GetAsUint32()) + // chunk cache + assert.Equal(t, "willneed", Params.ReadAheadPolicy.GetValue()) + // test small indexNlist/NProbe default params.Remove("queryNode.segcore.smallIndex.nlist") params.Remove("queryNode.segcore.smallIndex.nprobe") @@ -376,6 +380,10 @@ func TestComponentParam(t *testing.T) { bulkinsertTimeout := &Params.BulkInsertTimeoutSeconds t.Logf("BulkInsertTimeoutSeconds: %v", bulkinsertTimeout) assert.Equal(t, "18000", Params.BulkInsertTimeoutSeconds.GetValue()) + + channelWorkPoolSize := Params.ChannelWorkPoolSize.GetAsInt() + t.Logf("channelWorkPoolSize: %d", channelWorkPoolSize) + assert.Equal(t, -1, Params.ChannelWorkPoolSize.GetAsInt()) }) t.Run("test indexNodeConfig", func(t *testing.T) { diff --git a/pkg/util/paramtable/grpc_param.go b/pkg/util/paramtable/grpc_param.go index 9760b20ef4d69..6264e4a71b9f7 100644 --- a/pkg/util/paramtable/grpc_param.go +++ b/pkg/util/paramtable/grpc_param.go @@ -43,12 +43,10 @@ const ( DefaultKeepAliveTimeout = 20000 // Grpc retry policy - DefaultMaxAttempts = 5 - DefaultInitialBackoff float64 = 1.0 - DefaultMaxBackoff float64 = 10.0 - DefaultBackoffMultiplier float64 = 2.0 - - DefaultCompressionEnabled bool = false + DefaultMaxAttempts = 10 + DefaultInitialBackoff float64 = 0.2 + DefaultMaxBackoff float64 = 10 + DefaultCompressionEnabled bool = true ProxyInternalPort = 19529 ProxyExternalPort = 19530 @@ -69,7 +67,14 @@ type grpcConfig struct { func (p *grpcConfig) init(domain string, 
base *BaseTable) { p.Domain = domain - p.IP = funcutil.GetLocalIP() + ipItem := ParamItem{ + Key: p.Domain + ".ip", + Version: "2.3.3", + DefaultValue: "", + Export: true, + } + ipItem.Init(base.mgr) + p.IP = funcutil.GetIP(ipItem.GetValue()) p.Port = ParamItem{ Key: p.Domain + ".port", @@ -131,6 +136,8 @@ type GrpcServerConfig struct { ServerMaxSendSize ParamItem `refreshable:"false"` ServerMaxRecvSize ParamItem `refreshable:"false"` + + GracefulStopTimeout ParamItem `refreshable:"true"` } func (p *GrpcServerConfig) Init(domain string, base *BaseTable) { @@ -179,6 +186,15 @@ func (p *GrpcServerConfig) Init(domain string, base *BaseTable) { Export: true, } p.ServerMaxRecvSize.Init(base.mgr) + + p.GracefulStopTimeout = ParamItem{ + Key: "grpc.gracefulStopTimeout", + Version: "2.3.1", + DefaultValue: "10", + Doc: "second, time to wait graceful stop finish", + Export: true, + } + p.GracefulStopTimeout.Init(base.mgr) } // GrpcClientConfig is configuration for grpc client. @@ -194,10 +210,12 @@ type GrpcClientConfig struct { KeepAliveTime ParamItem `refreshable:"false"` KeepAliveTimeout ParamItem `refreshable:"false"` - MaxAttempts ParamItem `refreshable:"false"` - InitialBackoff ParamItem `refreshable:"false"` - MaxBackoff ParamItem `refreshable:"false"` - BackoffMultiplier ParamItem `refreshable:"false"` + MaxAttempts ParamItem `refreshable:"false"` + InitialBackoff ParamItem `refreshable:"false"` + MaxBackoff ParamItem `refreshable:"false"` + MinResetInterval ParamItem `refreshable:"false"` + MaxCancelError ParamItem `refreshable:"false"` + MinSessionCheckInterval ParamItem `refreshable:"false"` } func (p *GrpcClientConfig) Init(domain string, base *BaseTable) { @@ -318,19 +336,13 @@ func (p *GrpcClientConfig) Init(domain string, base *BaseTable) { if v == "" { return maxAttempts } - iv, err := strconv.Atoi(v) + _, err := strconv.Atoi(v) if err != nil { log.Warn("Failed to convert int when parsing grpc.client.maxMaxAttempts, set to default", zap.String("role", 
p.Domain), zap.String("grpc.client.maxMaxAttempts", v)) return maxAttempts } - if iv < 2 || iv > 5 { - log.Warn("The value of %s should be greater than 1 and less than 6, set to default", - zap.String("role", p.Domain), - zap.String("grpc.client.maxMaxAttempts", v)) - return maxAttempts - } return v }, Export: true, @@ -345,7 +357,7 @@ func (p *GrpcClientConfig) Init(domain string, base *BaseTable) { if v == "" { return initialBackoff } - _, err := strconv.Atoi(v) + _, err := strconv.ParseFloat(v, 64) if err != nil { log.Warn("Failed to convert int when parsing grpc.client.initialBackoff, set to default", zap.String("role", p.Domain), @@ -379,45 +391,84 @@ func (p *GrpcClientConfig) Init(domain string, base *BaseTable) { } p.MaxBackoff.Init(base.mgr) - backoffMultiplier := fmt.Sprintf("%f", DefaultBackoffMultiplier) - p.BackoffMultiplier = ParamItem{ - Key: "grpc.client.backoffMultiplier", + compressionEnabled := fmt.Sprintf("%t", DefaultCompressionEnabled) + p.CompressionEnabled = ParamItem{ + Key: "grpc.client.compressionEnabled", Version: "2.0.0", Formatter: func(v string) string { if v == "" { - return backoffMultiplier + return compressionEnabled } - _, err := strconv.ParseFloat(v, 64) + _, err := strconv.ParseBool(v) if err != nil { - log.Warn("Failed to convert int when parsing grpc.client.backoffMultiplier, set to default", + log.Warn("Failed to convert int when parsing grpc.client.compressionEnabled, set to default", zap.String("role", p.Domain), - zap.String("grpc.client.backoffMultiplier", v)) - return backoffMultiplier + zap.String("grpc.client.compressionEnabled", v)) + return compressionEnabled } return v }, Export: true, } - p.BackoffMultiplier.Init(base.mgr) + p.CompressionEnabled.Init(base.mgr) - compressionEnabled := fmt.Sprintf("%t", DefaultCompressionEnabled) - p.CompressionEnabled = ParamItem{ - Key: "grpc.client.compressionEnabled", - Version: "2.0.0", + p.MinResetInterval = ParamItem{ + Key: "grpc.client.minResetInterval", + DefaultValue: 
"1000", Formatter: func(v string) string { if v == "" { - return compressionEnabled + return "1000" } - _, err := strconv.ParseBool(v) + _, err := strconv.Atoi(v) if err != nil { - log.Warn("Failed to convert int when parsing grpc.client.compressionEnabled, set to default", - zap.String("role", p.Domain), - zap.String("grpc.client.compressionEnabled", v)) - return backoffMultiplier + log.Warn("Failed to parse grpc.client.minResetInterval, set to default", + zap.String("role", p.Domain), zap.String("grpc.client.minResetInterval", v), + zap.Error(err)) + return "1000" } return v }, Export: true, } - p.CompressionEnabled.Init(base.mgr) + p.MinResetInterval.Init(base.mgr) + + p.MinSessionCheckInterval = ParamItem{ + Key: "grpc.client.minSessionCheckInterval", + DefaultValue: "200", + Formatter: func(v string) string { + if v == "" { + return "200" + } + _, err := strconv.Atoi(v) + if err != nil { + log.Warn("Failed to parse grpc.client.minSessionCheckInterval, set to default", + zap.String("role", p.Domain), zap.String("grpc.client.minSessionCheckInterval", v), + zap.Error(err)) + return "200" + } + return v + }, + Export: true, + } + p.MinSessionCheckInterval.Init(base.mgr) + + p.MaxCancelError = ParamItem{ + Key: "grpc.client.maxCancelError", + DefaultValue: "32", + Formatter: func(v string) string { + if v == "" { + return "32" + } + _, err := strconv.Atoi(v) + if err != nil { + log.Warn("Failed to parse grpc.client.maxCancelError, set to default", + zap.String("role", p.Domain), zap.String("grpc.client.maxCancelError", v), + zap.Error(err)) + return "32" + } + return v + }, + Export: true, + } + p.MaxCancelError.Init(base.mgr) } diff --git a/pkg/util/paramtable/grpc_param_test.go b/pkg/util/paramtable/grpc_param_test.go index 235f0668968ea..fd101dc25b8dd 100644 --- a/pkg/util/paramtable/grpc_param_test.go +++ b/pkg/util/paramtable/grpc_param_test.go @@ -61,6 +61,9 @@ func TestGrpcServerParams(t *testing.T) { base.Save("grpc.serverMaxSendSize", "a") assert.Equal(t, 
serverConfig.ServerMaxSendSize.GetAsInt(), DefaultServerMaxSendSize) + + base.Save(serverConfig.GracefulStopTimeout.Key, "1") + assert.Equal(t, serverConfig.GracefulStopTimeout.GetAsInt(), 1) } func TestGrpcClientParams(t *testing.T) { @@ -122,15 +125,14 @@ func TestGrpcClientParams(t *testing.T) { assert.Equal(t, clientConfig.MaxAttempts.GetAsInt(), DefaultMaxAttempts) base.Save("grpc.client.maxMaxAttempts", "a") assert.Equal(t, clientConfig.MaxAttempts.GetAsInt(), DefaultMaxAttempts) - base.Save("grpc.client.maxMaxAttempts", "1") - assert.Equal(t, clientConfig.MaxAttempts.GetAsInt(), DefaultMaxAttempts) - base.Save("grpc.client.maxMaxAttempts", "10") - assert.Equal(t, clientConfig.MaxAttempts.GetAsInt(), DefaultMaxAttempts) base.Save("grpc.client.maxMaxAttempts", "4") assert.Equal(t, clientConfig.MaxAttempts.GetAsInt(), 4) + assert.Equal(t, clientConfig.InitialBackoff.GetAsFloat(), DefaultInitialBackoff) base.Save("grpc.client.initialBackOff", "a") + assert.Equal(t, clientConfig.InitialBackoff.GetAsFloat(), DefaultInitialBackoff) base.Save("grpc.client.initialBackOff", "2.0") + assert.Equal(t, clientConfig.InitialBackoff.GetAsFloat(), 2.0) assert.Equal(t, clientConfig.MaxBackoff.GetAsFloat(), DefaultMaxBackoff) base.Save("grpc.client.maxBackOff", "a") @@ -138,18 +140,30 @@ func TestGrpcClientParams(t *testing.T) { base.Save("grpc.client.maxBackOff", "50.0") assert.Equal(t, clientConfig.MaxBackoff.GetAsFloat(), 50.0) - assert.Equal(t, clientConfig.BackoffMultiplier.GetAsFloat(), DefaultBackoffMultiplier) - base.Save("grpc.client.backoffMultiplier", "a") - assert.Equal(t, clientConfig.BackoffMultiplier.GetAsFloat(), DefaultBackoffMultiplier) - base.Save("grpc.client.backoffMultiplier", "3.0") - assert.Equal(t, clientConfig.BackoffMultiplier.GetAsFloat(), 3.0) - assert.Equal(t, clientConfig.CompressionEnabled.GetAsBool(), DefaultCompressionEnabled) base.Save("grpc.client.CompressionEnabled", "a") assert.Equal(t, clientConfig.CompressionEnabled.GetAsBool(), 
DefaultCompressionEnabled) base.Save("grpc.client.CompressionEnabled", "true") assert.Equal(t, clientConfig.CompressionEnabled.GetAsBool(), true) + assert.Equal(t, clientConfig.MinResetInterval.GetValue(), "1000") + base.Save("grpc.client.minResetInterval", "abc") + assert.Equal(t, clientConfig.MinResetInterval.GetValue(), "1000") + base.Save("grpc.client.minResetInterval", "5000") + assert.Equal(t, clientConfig.MinResetInterval.GetValue(), "5000") + + assert.Equal(t, clientConfig.MinSessionCheckInterval.GetValue(), "200") + base.Save("grpc.client.minSessionCheckInterval", "abc") + assert.Equal(t, clientConfig.MinSessionCheckInterval.GetValue(), "200") + base.Save("grpc.client.minSessionCheckInterval", "500") + assert.Equal(t, clientConfig.MinSessionCheckInterval.GetValue(), "500") + + assert.Equal(t, clientConfig.MaxCancelError.GetValue(), "32") + base.Save("grpc.client.maxCancelError", "abc") + assert.Equal(t, clientConfig.MaxCancelError.GetValue(), "32") + base.Save("grpc.client.maxCancelError", "64") + assert.Equal(t, clientConfig.MaxCancelError.GetValue(), "64") + base.Save("common.security.tlsMode", "1") base.Save("tls.serverPemPath", "/pem") base.Save("tls.serverKeyPath", "/key") diff --git a/pkg/util/paramtable/hook_config.go b/pkg/util/paramtable/hook_config.go index 37bef02437114..38881fd8f0ee1 100644 --- a/pkg/util/paramtable/hook_config.go +++ b/pkg/util/paramtable/hook_config.go @@ -1,7 +1,10 @@ package paramtable import ( + "go.uber.org/zap" + "github.com/milvus-io/milvus/pkg/config" + "github.com/milvus-io/milvus/pkg/log" ) const hookYamlFile = "hook.yaml" @@ -15,6 +18,7 @@ type hookConfig struct { func (h *hookConfig) init(base *BaseTable) { h.hookBase = base + log.Info("hook config", zap.Any("hook", base.FileConfigs())) h.SoPath = ParamItem{ Key: "soPath", diff --git a/pkg/util/paramtable/http_param.go b/pkg/util/paramtable/http_param.go index 46c2abf64d885..ea04befbab209 100644 --- a/pkg/util/paramtable/http_param.go +++ 
b/pkg/util/paramtable/http_param.go @@ -27,7 +27,7 @@ func (p *httpConfig) init(base *BaseTable) { p.Port = ParamItem{ Key: "proxy.http.port", - Version: "2.1.0", + Version: "2.3.0", Doc: "high-level restful api", PanicIfEmpty: false, Export: true, diff --git a/pkg/util/paramtable/quota_param_test.go b/pkg/util/paramtable/quota_param_test.go index d8bf1cd30d319..8387f83650df2 100644 --- a/pkg/util/paramtable/quota_param_test.go +++ b/pkg/util/paramtable/quota_param_test.go @@ -212,6 +212,5 @@ func TestQuotaParam(t *testing.T) { // test invalid config params.Save(params.QuotaConfig.DiskQuotaPerCollection.Key, "-1") assert.Equal(t, qc.DiskQuota.GetAsFloat(), qc.DiskQuotaPerCollection.GetAsFloat()) - }) } diff --git a/pkg/util/paramtable/runtime.go b/pkg/util/paramtable/runtime.go index 7831dc3755460..55856d3836499 100644 --- a/pkg/util/paramtable/runtime.go +++ b/pkg/util/paramtable/runtime.go @@ -24,9 +24,11 @@ const ( runtimeUpdateTimeKey = "runtime.updateTime" ) -var once sync.Once -var params ComponentParam -var hookParams hookConfig +var ( + once sync.Once + params ComponentParam + hookParams hookConfig +) func Init() { once.Do(func() { diff --git a/pkg/util/paramtable/service_param.go b/pkg/util/paramtable/service_param.go index 2e9196705f749..3cafa06f6a73f 100644 --- a/pkg/util/paramtable/service_param.go +++ b/pkg/util/paramtable/service_param.go @@ -45,6 +45,7 @@ type ServiceParam struct { LocalStorageCfg LocalStorageConfig MetaStoreCfg MetaStoreConfig EtcdCfg EtcdConfig + TiKVCfg TiKVConfig MQCfg MQConfig PulsarCfg PulsarConfig KafkaCfg KafkaConfig @@ -57,6 +58,7 @@ func (p *ServiceParam) init(bt *BaseTable) { p.LocalStorageCfg.Init(bt) p.MetaStoreCfg.Init(bt) p.EtcdCfg.Init(bt) + p.TiKVCfg.Init(bt) p.MQCfg.Init(bt) p.PulsarCfg.Init(bt) p.KafkaCfg.Init(bt) @@ -257,6 +259,129 @@ We recommend using version 1.2 and above.`, p.EtcdTLSMinVersion.Init(base.mgr) } +// ///////////////////////////////////////////////////////////////////////////// +// --- tikv --- 
+type TiKVConfig struct { + Endpoints ParamItem `refreshable:"false"` + RootPath ParamItem `refreshable:"false"` + MetaSubPath ParamItem `refreshable:"false"` + KvSubPath ParamItem `refreshable:"false"` + MetaRootPath CompositeParamItem `refreshable:"false"` + KvRootPath CompositeParamItem `refreshable:"false"` + RequestTimeout ParamItem `refreshable:"true"` + SnapshotScanSize ParamItem `refreshable:"true"` + TiKVUseSSL ParamItem `refreshable:"false"` + TiKVTLSCert ParamItem `refreshable:"false"` + TiKVTLSKey ParamItem `refreshable:"false"` + TiKVTLSCACert ParamItem `refreshable:"false"` +} + +func (p *TiKVConfig) Init(base *BaseTable) { + p.Endpoints = ParamItem{ + Key: "tikv.endpoints", + Version: "2.3.0", + DefaultValue: "localhost:2379", + PanicIfEmpty: true, + Export: true, + } + p.Endpoints.Init(base.mgr) + + p.RootPath = ParamItem{ + Key: "tikv.rootPath", + Version: "2.3.0", + DefaultValue: "by-dev", + PanicIfEmpty: true, + Doc: "The root path where data is stored in tikv", + Export: true, + } + p.RootPath.Init(base.mgr) + + p.MetaSubPath = ParamItem{ + Key: "tikv.metaSubPath", + Version: "2.3.0", + DefaultValue: "meta", + PanicIfEmpty: true, + Doc: "metaRootPath = rootPath + '/' + metaSubPath", + Export: true, + } + p.MetaSubPath.Init(base.mgr) + + p.MetaRootPath = CompositeParamItem{ + Items: []*ParamItem{&p.RootPath, &p.MetaSubPath}, + Format: func(kvs map[string]string) string { + return path.Join(kvs[p.RootPath.Key], kvs[p.MetaSubPath.Key]) + }, + } + + p.KvSubPath = ParamItem{ + Key: "tikv.kvSubPath", + Version: "2.3.0", + DefaultValue: "kv", + PanicIfEmpty: true, + Doc: "kvRootPath = rootPath + '/' + kvSubPath", + Export: true, + } + p.KvSubPath.Init(base.mgr) + + p.KvRootPath = CompositeParamItem{ + Items: []*ParamItem{&p.RootPath, &p.KvSubPath}, + Format: func(kvs map[string]string) string { + return path.Join(kvs[p.RootPath.Key], kvs[p.KvSubPath.Key]) + }, + } + + p.RequestTimeout = ParamItem{ + Key: "tikv.requestTimeout", + Version: "2.3.0", + 
DefaultValue: "10000", + Doc: "ms, tikv request timeout", + Export: true, + } + p.RequestTimeout.Init(base.mgr) + + p.SnapshotScanSize = ParamItem{ + Key: "tikv.snapshotScanSize", + Version: "2.3.0", + DefaultValue: "256", + Doc: "batch size of tikv snapshot scan", + Export: true, + } + p.SnapshotScanSize.Init(base.mgr) + + p.TiKVUseSSL = ParamItem{ + Key: "tikv.ssl.enabled", + DefaultValue: "false", + Version: "2.3.0", + Doc: "Whether to support TiKV secure connection mode", + Export: true, + } + p.TiKVUseSSL.Init(base.mgr) + + p.TiKVTLSCert = ParamItem{ + Key: "tikv.ssl.tlsCert", + Version: "2.3.0", + Doc: "path to your cert file", + Export: true, + } + p.TiKVTLSCert.Init(base.mgr) + + p.TiKVTLSKey = ParamItem{ + Key: "tikv.ssl.tlsKey", + Version: "2.3.0", + Doc: "path to your key file", + Export: true, + } + p.TiKVTLSKey.Init(base.mgr) + + p.TiKVTLSCACert = ParamItem{ + Key: "tikv.ssl.tlsCACert", + Version: "2.3.0", + Doc: "path to your CACert file", + Export: true, + } + p.TiKVTLSCACert.Init(base.mgr) +} + type LocalStorageConfig struct { Path ParamItem `refreshable:"false"` } @@ -281,7 +406,7 @@ func (p *MetaStoreConfig) Init(base *BaseTable) { Key: "metastore.type", Version: "2.2.0", DefaultValue: util.MetaStoreTypeEtcd, - Doc: `Default value: etcd, Valid values: etcd `, + Doc: `Default value: etcd, Valid values: [etcd, tikv] `, Export: true, } p.MetaStoreType.Init(base.mgr) @@ -377,6 +502,9 @@ type PulsarConfig struct { // Global request timeout RequestTimeout ParamItem `refreshable:"false"` + + // Enable Client side metrics + EnableClientMetrics ParamItem `refreshable:"false"` } func (p *PulsarConfig) Init(base *BaseTable) { @@ -490,6 +618,14 @@ func (p *PulsarConfig) Init(base *BaseTable) { Export: true, } p.RequestTimeout.Init(base.mgr) + + p.EnableClientMetrics = ParamItem{ + Key: "pulsar.enableClientMetrics", + Version: "2.3.0", + DefaultValue: "false", + Export: true, + } + p.EnableClientMetrics.Init(base.mgr) } // --- kafka --- @@ -501,6 +637,7 @@ 
type KafkaConfig struct { SecurityProtocol ParamItem `refreshable:"false"` ConsumerExtraConfig ParamGroup `refreshable:"false"` ProducerExtraConfig ParamGroup `refreshable:"false"` + ReadTimeout ParamItem `refreshable:"true"` } func (k *KafkaConfig) Init(base *BaseTable) { @@ -556,6 +693,14 @@ func (k *KafkaConfig) Init(base *BaseTable) { Version: "2.2.0", } k.ProducerExtraConfig.Init(base.mgr) + + k.ReadTimeout = ParamItem{ + Key: "kafka.readTimeout", + DefaultValue: "10", + Version: "2.3.1", + Export: true, + } + k.ReadTimeout.Init(base.mgr) } // ///////////////////////////////////////////////////////////////////////////// @@ -787,19 +932,20 @@ func (r *NatsmqConfig) Init(base *BaseTable) { // ///////////////////////////////////////////////////////////////////////////// // --- minio --- type MinioConfig struct { - Address ParamItem `refreshable:"false"` - Port ParamItem `refreshable:"false"` - AccessKeyID ParamItem `refreshable:"false"` - SecretAccessKey ParamItem `refreshable:"false"` - UseSSL ParamItem `refreshable:"false"` - BucketName ParamItem `refreshable:"false"` - RootPath ParamItem `refreshable:"false"` - UseIAM ParamItem `refreshable:"false"` - CloudProvider ParamItem `refreshable:"false"` - IAMEndpoint ParamItem `refreshable:"false"` - LogLevel ParamItem `refreshable:"false"` - Region ParamItem `refreshable:"false"` - UseVirtualHost ParamItem `refreshable:"false"` + Address ParamItem `refreshable:"false"` + Port ParamItem `refreshable:"false"` + AccessKeyID ParamItem `refreshable:"false"` + SecretAccessKey ParamItem `refreshable:"false"` + UseSSL ParamItem `refreshable:"false"` + BucketName ParamItem `refreshable:"false"` + RootPath ParamItem `refreshable:"false"` + UseIAM ParamItem `refreshable:"false"` + CloudProvider ParamItem `refreshable:"false"` + IAMEndpoint ParamItem `refreshable:"false"` + LogLevel ParamItem `refreshable:"false"` + Region ParamItem `refreshable:"false"` + UseVirtualHost ParamItem `refreshable:"false"` + RequestTimeoutMs 
ParamItem `refreshable:"false"` } func (p *MinioConfig) Init(base *BaseTable) { @@ -872,8 +1018,15 @@ func (p *MinioConfig) Init(base *BaseTable) { p.BucketName.Init(base.mgr) p.RootPath = ParamItem{ - Key: "minio.rootPath", - Version: "2.0.0", + Key: "minio.rootPath", + Version: "2.0.0", + Formatter: func(rootPath string) string { + if rootPath == "" { + return "" + } + rootPath = strings.TrimLeft(rootPath, "/") + return path.Clean(rootPath) + }, PanicIfEmpty: false, Doc: "The root path where the message is stored in MinIO/S3", Export: true, @@ -943,4 +1096,12 @@ Leave it empty if you want to use AWS default endpoint`, } p.UseVirtualHost.Init(base.mgr) + p.RequestTimeoutMs = ParamItem{ + Key: "minio.requestTimeoutMs", + Version: "2.3.2", + DefaultValue: DefaultMinioRequestTimeout, + Doc: "minio timeout for request time in milliseconds", + Export: true, + } + p.RequestTimeoutMs.Init(base.mgr) } diff --git a/pkg/util/paramtable/service_param_test.go b/pkg/util/paramtable/service_param_test.go index ff884a36e6edc..847301ce93867 100644 --- a/pkg/util/paramtable/service_param_test.go +++ b/pkg/util/paramtable/service_param_test.go @@ -13,6 +13,7 @@ package paramtable import ( "testing" + "time" "github.com/stretchr/testify/assert" @@ -63,6 +64,22 @@ func TestServiceParam(t *testing.T) { SParams.init(bt) }) + t.Run("test tikvConfig", func(t *testing.T) { + Params := &SParams.TiKVCfg + + assert.NotZero(t, len(Params.Endpoints.GetAsStrings())) + t.Logf("tikv endpoints = %s", Params.Endpoints.GetAsStrings()) + + assert.NotEqual(t, Params.MetaRootPath, "") + t.Logf("meta root path = %s", Params.MetaRootPath.GetValue()) + + assert.NotEqual(t, Params.KvRootPath, "") + t.Logf("kv root path = %s", Params.KvRootPath.GetValue()) + + t.Setenv(metricsinfo.DeployModeEnvKey, metricsinfo.StandaloneDeployMode) + SParams.init(bt) + }) + t.Run("test pulsarConfig", func(t *testing.T) { // test default value { @@ -151,6 +168,7 @@ func TestServiceParam(t *testing.T) { assert.Empty(t, 
kc.Address.GetValue()) assert.Equal(t, kc.SaslMechanisms.GetValue(), "PLAIN") assert.Equal(t, kc.SecurityProtocol.GetValue(), "SASL_SSL") + assert.Equal(t, kc.ReadTimeout.GetAsDuration(time.Second), 10*time.Second) } }) diff --git a/pkg/util/resource/resource_manager.go b/pkg/util/resource/resource_manager.go new file mode 100644 index 0000000000000..4211441aa6137 --- /dev/null +++ b/pkg/util/resource/resource_manager.go @@ -0,0 +1,301 @@ +package resource + +import ( + "sync" + "time" +) + +const ( + NoExpiration time.Duration = -1 + DefaultCheckInterval = 2 * time.Second + DefaultExpiration = 4 * time.Second +) + +type Resource interface { + Type() string + Name() string + Get() any + Close() + // KeepAliveTime returns the time duration of the resource keep alive if the resource isn't used. + KeepAliveTime() time.Duration +} + +type wrapper struct { + res Resource + obj any + typ string + name string + closeFunc func() + keepAliveTime time.Duration +} + +func (w *wrapper) Type() string { + if w.typ != "" { + return w.typ + } + if w.res == nil { + return "" + } + return w.res.Type() +} + +func (w *wrapper) Name() string { + if w.name != "" { + return w.name + } + if w.res == nil { + return "" + } + return w.res.Name() +} + +func (w *wrapper) Get() any { + if w.obj != nil { + return w.obj + } + if w.res == nil { + return nil + } + return w.res.Get() +} + +func (w *wrapper) Close() { + if w.res != nil { + w.res.Close() + } + if w.closeFunc != nil { + w.closeFunc() + } +} + +func (w *wrapper) KeepAliveTime() time.Duration { + if w.keepAliveTime != 0 { + return w.keepAliveTime + } + if w.res == nil { + return 0 + } + return w.res.KeepAliveTime() +} + +type Option func(res *wrapper) + +func WithResource(res Resource) Option { + return func(w *wrapper) { + w.res = res + } +} + +func WithType(typ string) Option { + return func(res *wrapper) { + res.typ = typ + } +} + +func WithName(name string) Option { + return func(res *wrapper) { + res.name = name + } +} + +func 
WithObj(obj any) Option { + return func(res *wrapper) { + res.obj = obj + } +} + +func WithCloseFunc(closeFunc func()) Option { + return func(res *wrapper) { + res.closeFunc = closeFunc + } +} + +func WithKeepAliveTime(keepAliveTime time.Duration) Option { + return func(res *wrapper) { + res.keepAliveTime = keepAliveTime + } +} + +func NewResource(opts ...Option) Resource { + w := &wrapper{} + for _, opt := range opts { + opt(w) + } + return w +} + +func NewSimpleResource(obj any, typ, name string, keepAliveTime time.Duration, closeFunc func()) Resource { + return NewResource(WithObj(obj), WithType(typ), WithName(name), WithKeepAliveTime(keepAliveTime), WithCloseFunc(closeFunc)) +} + +type Manager interface { + Get(typ, name string, newResourceFunc NewResourceFunc) (Resource, error) + Delete(typ, name string) Resource + Close() +} + +type item struct { + res Resource + updateTimeChan chan int64 + deleteMark chan struct{} + expiration int64 +} + +type manager struct { + resources map[string]map[string]*item // key: resource type, value: resource name -> resource + checkInterval time.Duration + defaultExpiration time.Duration + defaultTypeExpirations map[string]time.Duration // key: resource type, value: expiration + mu sync.RWMutex + wg sync.WaitGroup + stop chan struct{} + stopOnce sync.Once +} + +func NewManager(checkInterval, defaultExpiration time.Duration, defaultTypeExpirations map[string]time.Duration) Manager { + if checkInterval <= 0 { + checkInterval = DefaultCheckInterval + } + if defaultExpiration <= 0 { + defaultExpiration = DefaultExpiration + } + if defaultTypeExpirations == nil { + defaultTypeExpirations = make(map[string]time.Duration) + } + m := &manager{ + resources: make(map[string]map[string]*item), + checkInterval: checkInterval, + defaultExpiration: defaultExpiration, + defaultTypeExpirations: defaultTypeExpirations, + stop: make(chan struct{}), + } + m.wg.Add(1) + go m.backgroundGC() + return m +} + +func (m *manager) backgroundGC() { + 
ticker := time.NewTicker(m.checkInterval) + defer m.wg.Done() + defer ticker.Stop() + for { + select { + case <-ticker.C: + m.gc() + case <-m.stop: + m.mu.Lock() + for _, typMap := range m.resources { + for _, item := range typMap { + item.res.Close() + } + } + m.resources = nil + m.mu.Unlock() + return + } + } +} + +func (m *manager) gc() { + m.mu.Lock() + defer m.mu.Unlock() + now := time.Now().UnixNano() + for typ, typMap := range m.resources { + for resName, item := range typMap { + select { + case lastTime := <-item.updateTimeChan: + if item.expiration >= 0 { + item.expiration = lastTime + } + case <-item.deleteMark: + item.res.Close() + delete(typMap, resName) + default: + if item.expiration >= 0 && item.expiration <= now { + item.res.Close() + delete(typMap, resName) + } + } + } + if len(typMap) == 0 { + delete(m.resources, typ) + } + } +} + +func (m *manager) updateExpire(item *item) { + select { + case item.updateTimeChan <- time.Now().UnixNano() + item.res.KeepAliveTime().Nanoseconds(): + default: + } +} + +type NewResourceFunc func() (Resource, error) + +func (m *manager) Get(typ, name string, newResourceFunc NewResourceFunc) (Resource, error) { + m.mu.RLock() + typMap, ok := m.resources[typ] + if ok { + item := typMap[name] + if item != nil { + m.mu.RUnlock() + m.updateExpire(item) + return item.res, nil + } + } + m.mu.RUnlock() + m.mu.Lock() + defer m.mu.Unlock() + typMap, ok = m.resources[typ] + if !ok { + typMap = make(map[string]*item) + m.resources[typ] = typMap + } + ite, ok := typMap[name] + if !ok { + res, err := newResourceFunc() + if err != nil { + return nil, err + } + if res.KeepAliveTime() == 0 { + defaultExpiration := m.defaultTypeExpirations[typ] + if defaultExpiration == 0 { + defaultExpiration = m.defaultExpiration + } + res = NewResource(WithResource(res), WithKeepAliveTime(defaultExpiration)) + } + ite = &item{ + res: res, + updateTimeChan: make(chan int64, 1), + deleteMark: make(chan struct{}, 1), + } + typMap[name] = ite + } + 
m.updateExpire(ite) + return ite.res, nil +} + +func (m *manager) Delete(typ, name string) Resource { + m.mu.Lock() + defer m.mu.Unlock() + typMap, ok := m.resources[typ] + if !ok { + return nil + } + ite, ok := typMap[name] + if !ok { + return nil + } + select { + case ite.deleteMark <- struct{}{}: + default: + } + return ite.res +} + +func (m *manager) Close() { + m.stopOnce.Do(func() { + close(m.stop) + m.wg.Wait() + }) +} diff --git a/pkg/util/resource/resource_manager_test.go b/pkg/util/resource/resource_manager_test.go new file mode 100644 index 0000000000000..eacc1fd3ad830 --- /dev/null +++ b/pkg/util/resource/resource_manager_test.go @@ -0,0 +1,160 @@ +package resource + +import ( + "testing" + "time" + + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" +) + +func TestResourceManager(t *testing.T) { + { + manager := NewManager(0, 0, nil) + manager.Close() + } + + manager := NewManager(500*time.Millisecond, 2*time.Second, map[string]time.Duration{ + "test": time.Second, + }) + defer manager.Close() + { + assert.Nil(t, manager.Delete("test", "test")) + res, err := manager.Get("stream", "foo", func() (Resource, error) { + return NewSimpleResource("stream-foo", "stream", "foo", 0, nil), nil + }) + assert.NoError(t, err) + assert.Equal(t, 2*time.Second, res.KeepAliveTime()) + assert.Equal(t, "stream-foo", res.Get()) + } + { + _, err := manager.Get("err", "foo", func() (Resource, error) { + return nil, errors.New("mock test error") + }) + assert.Error(t, err) + } + { + res, err := manager.Get("test", "foo", func() (Resource, error) { + return NewSimpleResource("foo", "test", "foo", 0, nil), nil + }) + assert.NoError(t, err) + assert.Equal(t, "foo", res.Get()) + + assert.Nil(t, manager.Delete("test", "test")) + } + { + time.Sleep(500 * time.Millisecond) + res, err := manager.Get("test", "foo", func() (Resource, error) { + return NewSimpleResource("foox", "test", "foo", 0, nil), nil + }) + assert.NoError(t, err) + assert.Equal(t, "foo", 
res.Get()) + } + { + time.Sleep(3 * time.Second) + res, err := manager.Get("test", "foo", func() (Resource, error) { + return NewSimpleResource("foo2", "test", "foo", 0, nil), nil + }) + assert.NoError(t, err) + assert.Equal(t, "foo2", res.Get(), res.KeepAliveTime()) + } + { + res := manager.Delete("test", "foo") + assert.Equal(t, "foo2", res.Get()) + res = manager.Delete("test", "foo") + assert.Equal(t, "foo2", res.Get()) + time.Sleep(time.Second) + + res, err := manager.Get("test", "foo", func() (Resource, error) { + return NewSimpleResource("foo3", "test", "foo", 0, nil), nil + }) + assert.NoError(t, err) + assert.Equal(t, "foo3", res.Get()) + } + { + time.Sleep(2 * time.Second) + res, err := manager.Get("stream", "foo", func() (Resource, error) { + return NewSimpleResource("stream-foox", "stream", "foo", 0, nil), nil + }) + assert.NoError(t, err) + assert.Equal(t, "stream-foox", res.Get()) + } + { + var res Resource + var err error + res, err = manager.Get("ever", "foo", func() (Resource, error) { + return NewSimpleResource("ever-foo", "ever", "foo", NoExpiration, nil), nil + }) + assert.NoError(t, err) + assert.Equal(t, "ever-foo", res.Get()) + + res, err = manager.Get("ever", "foo", func() (Resource, error) { + return NewSimpleResource("ever-foo2", "ever", "foo", NoExpiration, nil), nil + }) + assert.NoError(t, err) + assert.Equal(t, "ever-foo", res.Get()) + + manager.Delete("ever", "foo") + time.Sleep(time.Second) + res, err = manager.Get("ever", "foo", func() (Resource, error) { + return NewSimpleResource("ever-foo3", "ever", "foo", NoExpiration, nil), nil + }) + assert.NoError(t, err) + assert.Equal(t, "ever-foo3", res.Get()) + } +} + +func TestResource(t *testing.T) { + { + isClose := false + res := NewSimpleResource("obj", "test", "foo", 0, func() { + isClose = true + }) + assert.Equal(t, "test", res.Type()) + assert.Equal(t, "foo", res.Name()) + assert.Equal(t, "obj", res.Get()) + assert.EqualValues(t, 0, res.KeepAliveTime()) + res.Close() + 
assert.True(t, isClose) + } + + { + res := NewResource() + assert.Empty(t, res.Type()) + assert.Empty(t, res.Name()) + assert.Empty(t, res.Get()) + assert.EqualValues(t, 0, res.KeepAliveTime()) + } + + { + isClose := false + res := NewSimpleResource("obj", "test", "foo", 0, func() { + isClose = true + }) + isClose2 := false + wrapper := NewResource(WithResource(res), WithType("test2"), WithName("foo2"), WithObj("obj2"), WithKeepAliveTime(time.Second), WithCloseFunc(func() { + isClose2 = true + })) + wrapper.Close() + assert.Equal(t, "test2", wrapper.Type()) + assert.Equal(t, "foo2", wrapper.Name()) + assert.Equal(t, "obj2", wrapper.Get()) + assert.Equal(t, time.Second, wrapper.KeepAliveTime()) + assert.True(t, isClose) + assert.True(t, isClose2) + } + + { + isClose := false + res := NewSimpleResource("obj", "test", "foo", 0, func() { + isClose = true + }) + wrapper := NewResource(WithResource(res)) + assert.Equal(t, "test", wrapper.Type()) + assert.Equal(t, "foo", wrapper.Name()) + assert.Equal(t, "obj", wrapper.Get()) + assert.EqualValues(t, 0, wrapper.KeepAliveTime()) + wrapper.Close() + assert.True(t, isClose) + } +} diff --git a/pkg/util/retry/retry.go b/pkg/util/retry/retry.go index eb56745365364..3f48a18d7e277 100644 --- a/pkg/util/retry/retry.go +++ b/pkg/util/retry/retry.go @@ -19,6 +19,7 @@ import ( "go.uber.org/zap" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/merr" ) @@ -26,8 +27,11 @@ import ( // fn is the func to run. // Option can control the retry times and timeout. 
func Do(ctx context.Context, fn func() error, opts ...Option) error { - log := log.Ctx(ctx) + if !funcutil.CheckCtxValid(ctx) { + return ctx.Err() + } + log := log.Ctx(ctx) c := newDefaultConfig() for _, opt := range opts { @@ -38,7 +42,7 @@ func Do(ctx context.Context, fn func() error, opts ...Option) error { for i := uint(0); i < c.attempts; i++ { if err := fn(); err != nil { - if i%10 == 0 { + if i%4 == 0 { log.Error("retry func failed", zap.Uint("retry time", i), zap.Error(err)) } @@ -52,8 +56,7 @@ func Do(ctx context.Context, fn func() error, opts ...Option) error { select { case <-time.After(c.sleep): case <-ctx.Done(): - el = merr.Combine(el, errors.Wrapf(ctx.Err(), "context done during sleep after run#%d", i)) - return el + return merr.Combine(el, ctx.Err()) } c.sleep *= 2 diff --git a/pkg/util/retry/retry_test.go b/pkg/util/retry/retry_test.go index ec7f7a544a782..d21522482ee76 100644 --- a/pkg/util/retry/retry_test.go +++ b/pkg/util/retry/retry_test.go @@ -20,6 +20,8 @@ import ( "github.com/cockroachdb/errors" "github.com/lingdor/stackerror" "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/pkg/util/merr" ) func TestDo(t *testing.T) { @@ -130,5 +132,16 @@ func TestContextCancel(t *testing.T) { err := Do(ctx, testFn) assert.Error(t, err) + assert.True(t, merr.IsCanceledOrTimeout(err)) t.Log(err) } + +func TestWrap(t *testing.T) { + err := merr.WrapErrSegmentNotFound(1, "failed to get Segment") + assert.True(t, errors.Is(err, merr.ErrSegmentNotFound)) + assert.True(t, IsRecoverable(err)) + err2 := Unrecoverable(err) + fmt.Println(err2) + assert.True(t, errors.Is(err2, merr.ErrSegmentNotFound)) + assert.False(t, IsRecoverable(err2)) +} diff --git a/internal/core/src/common/CGoHelper.h b/pkg/util/tikv/tikv_test_util.go similarity index 62% rename from internal/core/src/common/CGoHelper.h rename to pkg/util/tikv/tikv_test_util.go index 8f8ce80032310..21bbac9de7b2c 100644 --- a/internal/core/src/common/CGoHelper.h +++ 
b/pkg/util/tikv/tikv_test_util.go @@ -14,24 +14,23 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once +package tikv -#include -#include +import ( + "github.com/tikv/client-go/v2/testutils" + tilib "github.com/tikv/client-go/v2/tikv" + "github.com/tikv/client-go/v2/txnkv" +) -#include "common/type_c.h" - -namespace milvus { - -inline CStatus -SuccessCStatus() { - return CStatus{Success, ""}; +func SetupLocalTxn() *txnkv.Client { + client, cluster, pdClient, err := testutils.NewMockTiKV("", nil) + if err != nil { + panic(err) + } + testutils.BootstrapWithSingleStore(cluster) + store, err := tilib.NewTestTiKVStore(client, pdClient, nil, nil, 0) + if err != nil { + panic(err) + } + return &txnkv.Client{KVStore: store} } - -inline CStatus -FailureCStatus(ErrorCode error_code, const std::string_view str) { - auto str_dup = strdup(str.data()); - return CStatus{error_code, str_dup}; -} - -} // namespace milvus diff --git a/internal/indexnode/errors.go b/pkg/util/tikv/tikv_util.go similarity index 58% rename from internal/indexnode/errors.go rename to pkg/util/tikv/tikv_util.go index 4705b2cc891f2..53a1303bcf2fa 100644 --- a/internal/indexnode/errors.go +++ b/pkg/util/tikv/tikv_util.go @@ -14,25 +14,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -package indexnode +package tikv import ( - "fmt" + "github.com/tikv/client-go/v2/config" + "github.com/tikv/client-go/v2/txnkv" - "github.com/cockroachdb/errors" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) -var ( - ErrNoSuchKey = errors.New("NoSuchKey") - ErrEmptyInsertPaths = errors.New("empty insert paths") -) - -// msgIndexNodeIsUnhealthy return a message tha IndexNode is not healthy. -func msgIndexNodeIsUnhealthy(nodeID UniqueID) string { - return fmt.Sprintf("index node %d is not ready", nodeID) -} - -// errIndexNodeIsUnhealthy return an error that specified IndexNode is not healthy. 
-func errIndexNodeIsUnhealthy(nodeID UniqueID) error { - return errors.New(msgIndexNodeIsUnhealthy(nodeID)) +func GetTiKVClient(cfg *paramtable.TiKVConfig) (*txnkv.Client, error) { + if cfg.TiKVUseSSL.GetAsBool() { + f := func(conf *config.Config) { + conf.Security = config.NewSecurity(cfg.TiKVTLSCACert.GetValue(), cfg.TiKVTLSCert.GetValue(), cfg.TiKVTLSKey.GetValue(), []string{}) + } + config.UpdateGlobal(f) + return txnkv.NewClient([]string{cfg.Endpoints.GetValue()}) + } + return txnkv.NewClient([]string{cfg.Endpoints.GetValue()}) } diff --git a/pkg/util/tsoutil/tso_test.go b/pkg/util/tsoutil/tso_test.go index fac9bedce3eb3..a4c44a73f6035 100644 --- a/pkg/util/tsoutil/tso_test.go +++ b/pkg/util/tsoutil/tso_test.go @@ -62,11 +62,11 @@ func TestAddPhysicalDurationOnTs(t *testing.T) { duration := time.Millisecond * (20 * 1000) ts2 := AddPhysicalDurationOnTs(ts1, duration) ts3 := ComposeTSByTime(now.Add(duration), 0) - //diff := CalculateDuration(ts2, ts1) + // diff := CalculateDuration(ts2, ts1) assert.Equal(t, ts3, ts2) ts2 = AddPhysicalDurationOnTs(ts1, -duration) ts3 = ComposeTSByTime(now.Add(-duration), 0) - //diff := CalculateDuration(ts2, ts1) + // diff := CalculateDuration(ts2, ts1) assert.Equal(t, ts3, ts2) } diff --git a/pkg/util/typeutil/chan.go b/pkg/util/typeutil/chan.go new file mode 100644 index 0000000000000..1b33626d36613 --- /dev/null +++ b/pkg/util/typeutil/chan.go @@ -0,0 +1,12 @@ +package typeutil + +// IsChanClosed returns whether input signal channel is closed or not. +// this method accept `chan struct{}` type only in case of passing msg channels by mistake. 
+func IsChanClosed(ch <-chan struct{}) bool { + select { + case <-ch: + return true + default: + return false + } +} diff --git a/pkg/util/typeutil/conversion_test.go b/pkg/util/typeutil/conversion_test.go index 4f948ab68f282..da5a9623fb219 100644 --- a/pkg/util/typeutil/conversion_test.go +++ b/pkg/util/typeutil/conversion_test.go @@ -94,5 +94,4 @@ func TestConversion(t *testing.T) { ret1 := SliceRemoveDuplicate(arr) assert.Equal(t, 3, len(ret1)) }) - } diff --git a/pkg/util/typeutil/float_util_test.go b/pkg/util/typeutil/float_util_test.go index 6ac94aad96d04..16f204e3339ba 100644 --- a/pkg/util/typeutil/float_util_test.go +++ b/pkg/util/typeutil/float_util_test.go @@ -24,7 +24,7 @@ import ( ) func Test_VerifyFloat(t *testing.T) { - var value = math.NaN() + value := math.NaN() err := VerifyFloat(value) assert.Error(t, err) diff --git a/pkg/util/typeutil/gen_empty_field_data.go b/pkg/util/typeutil/gen_empty_field_data.go index 4123aad15f73f..818a8ec25f883 100644 --- a/pkg/util/typeutil/gen_empty_field_data.go +++ b/pkg/util/typeutil/gen_empty_field_data.go @@ -15,7 +15,8 @@ func genEmptyBoolFieldData(field *schemapb.FieldSchema) *schemapb.FieldData { Data: &schemapb.ScalarField_BoolData{BoolData: &schemapb.BoolArray{Data: nil}}, }, }, - FieldId: field.GetFieldID(), + FieldId: field.GetFieldID(), + IsDynamic: field.GetIsDynamic(), } } @@ -28,7 +29,8 @@ func genEmptyIntFieldData(field *schemapb.FieldSchema) *schemapb.FieldData { Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: nil}}, }, }, - FieldId: field.GetFieldID(), + FieldId: field.GetFieldID(), + IsDynamic: field.GetIsDynamic(), } } @@ -41,7 +43,8 @@ func genEmptyLongFieldData(field *schemapb.FieldSchema) *schemapb.FieldData { Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: nil}}, }, }, - FieldId: field.GetFieldID(), + FieldId: field.GetFieldID(), + IsDynamic: field.GetIsDynamic(), } } @@ -54,7 +57,8 @@ func genEmptyFloatFieldData(field *schemapb.FieldSchema) 
*schemapb.FieldData { Data: &schemapb.ScalarField_FloatData{FloatData: &schemapb.FloatArray{Data: nil}}, }, }, - FieldId: field.GetFieldID(), + FieldId: field.GetFieldID(), + IsDynamic: field.GetIsDynamic(), } } @@ -67,7 +71,8 @@ func genEmptyDoubleFieldData(field *schemapb.FieldSchema) *schemapb.FieldData { Data: &schemapb.ScalarField_DoubleData{DoubleData: &schemapb.DoubleArray{Data: nil}}, }, }, - FieldId: field.GetFieldID(), + FieldId: field.GetFieldID(), + IsDynamic: field.GetIsDynamic(), } } @@ -80,7 +85,8 @@ func genEmptyVarCharFieldData(field *schemapb.FieldSchema) *schemapb.FieldData { Data: &schemapb.ScalarField_StringData{StringData: &schemapb.StringArray{Data: nil}}, }, }, - FieldId: field.GetFieldID(), + FieldId: field.GetFieldID(), + IsDynamic: field.GetIsDynamic(), } } @@ -93,7 +99,8 @@ func genEmptyArrayFieldData(field *schemapb.FieldSchema) *schemapb.FieldData { Data: &schemapb.ScalarField_ArrayData{ArrayData: &schemapb.ArrayArray{Data: nil}}, }, }, - FieldId: field.GetFieldID(), + FieldId: field.GetFieldID(), + IsDynamic: field.GetIsDynamic(), } } @@ -106,7 +113,8 @@ func genEmptyJSONFieldData(field *schemapb.FieldSchema) *schemapb.FieldData { Data: &schemapb.ScalarField_JsonData{JsonData: &schemapb.JSONArray{Data: nil}}, }, }, - FieldId: field.GetFieldID(), + FieldId: field.GetFieldID(), + IsDynamic: field.GetIsDynamic(), } } @@ -126,7 +134,8 @@ func genEmptyBinaryVectorFieldData(field *schemapb.FieldSchema) (*schemapb.Field }, }, }, - FieldId: field.GetFieldID(), + FieldId: field.GetFieldID(), + IsDynamic: field.GetIsDynamic(), }, nil } @@ -146,7 +155,29 @@ func genEmptyFloatVectorFieldData(field *schemapb.FieldSchema) (*schemapb.FieldD }, }, }, - FieldId: field.GetFieldID(), + FieldId: field.GetFieldID(), + IsDynamic: field.GetIsDynamic(), + }, nil +} + +func genEmptyFloat16VectorFieldData(field *schemapb.FieldSchema) (*schemapb.FieldData, error) { + dim, err := GetDim(field) + if err != nil { + return nil, err + } + return &schemapb.FieldData{ 
+ Type: field.GetDataType(), + FieldName: field.GetName(), + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: dim, + Data: &schemapb.VectorField_Float16Vector{ + Float16Vector: nil, + }, + }, + }, + FieldId: field.GetFieldID(), + IsDynamic: field.GetIsDynamic(), }, nil } @@ -173,6 +204,8 @@ func GenEmptyFieldData(field *schemapb.FieldSchema) (*schemapb.FieldData, error) return genEmptyBinaryVectorFieldData(field) case schemapb.DataType_FloatVector: return genEmptyFloatVectorFieldData(field) + case schemapb.DataType_Float16Vector: + return genEmptyFloat16VectorFieldData(field) default: return nil, fmt.Errorf("unsupported data type: %s", dataType.String()) } diff --git a/pkg/util/typeutil/hash.go b/pkg/util/typeutil/hash.go index 68f9336ae5afe..331785f3054f5 100644 --- a/pkg/util/typeutil/hash.go +++ b/pkg/util/typeutil/hash.go @@ -93,7 +93,7 @@ func HashPK2Channels(primaryKeys *schemapb.IDs, shardNames []string) []uint32 { hashValues = append(hashValues, hash%numShard) } default: - //TODO:: + // TODO:: } return hashValues @@ -121,7 +121,6 @@ func HashKey2Partitions(keys *schemapb.FieldData, partitionNames []string) ([]ui } default: return nil, errors.New("currently only support DataType Int64 or VarChar as partition key Field") - } default: return nil, errors.New("currently not support vector field as partition keys") diff --git a/pkg/util/typeutil/hash_test.go b/pkg/util/typeutil/hash_test.go index fccbb3ad1234a..e561af91c9c09 100644 --- a/pkg/util/typeutil/hash_test.go +++ b/pkg/util/typeutil/hash_test.go @@ -21,13 +21,14 @@ import ( "testing" "unsafe" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) func TestUint64(t *testing.T) { var i int64 = -1 - var u = uint64(i) + u := uint64(i) t.Log(i) t.Log(u) } @@ -54,7 +55,7 @@ func TestHash32_Uint64(t *testing.T) { } func TestHash32_String(t *testing.T) { - var u = "ok" + u := "ok" h, 
err := Hash32String(u) assert.NoError(t, err) @@ -151,7 +152,7 @@ func TestHashPK2Channels(t *testing.T) { } ret := HashPK2Channels(int64IDs, channels) assert.Equal(t, 5, len(ret)) - //same pk hash to same channel + // same pk hash to same channel assert.Equal(t, ret[1], ret[2]) stringIDs := &schemapb.IDs{ diff --git a/pkg/util/typeutil/index_test.go b/pkg/util/typeutil/index_test.go index 856625ded6ab1..b6044aac4cec6 100644 --- a/pkg/util/typeutil/index_test.go +++ b/pkg/util/typeutil/index_test.go @@ -19,8 +19,9 @@ package typeutil import ( "testing" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" ) func TestCompareIndexParams(t *testing.T) { diff --git a/pkg/util/typeutil/kv_pair_helper_test.go b/pkg/util/typeutil/kv_pair_helper_test.go index 86de0d214d38f..576aea68e70de 100644 --- a/pkg/util/typeutil/kv_pair_helper_test.go +++ b/pkg/util/typeutil/kv_pair_helper_test.go @@ -3,9 +3,10 @@ package typeutil import ( "testing" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/pkg/common" - "github.com/stretchr/testify/assert" ) func TestNewKvPairs(t *testing.T) { diff --git a/pkg/util/typeutil/map.go b/pkg/util/typeutil/map.go index 6da8d3a02dccf..c11b1528dcc42 100644 --- a/pkg/util/typeutil/map.go +++ b/pkg/util/typeutil/map.go @@ -101,6 +101,14 @@ func (m *ConcurrentMap[K, V]) GetAndRemove(key K) (V, bool) { return value.(V), true } +// Remove removes the `key`, `value` set if `key` is in the map, +// does nothing if `key` not in the map. 
+func (m *ConcurrentMap[K, V]) Remove(key K) { + if _, loaded := m.inner.LoadAndDelete(key); loaded { + m.len.Dec() + } +} + func (m *ConcurrentMap[K, V]) Len() int { return int(m.len.Load()) } diff --git a/pkg/util/typeutil/map_test.go b/pkg/util/typeutil/map_test.go index 2e96fa016775b..b34a113ea9b73 100644 --- a/pkg/util/typeutil/map_test.go +++ b/pkg/util/typeutil/map_test.go @@ -127,6 +127,23 @@ func (suite *MapUtilSuite) TestConcurrentMap() { suite.FailNow("empty map range") return false }) + + suite.Run("TestRemove", func() { + currMap := NewConcurrentMap[int64, string]() + suite.Equal(0, currMap.Len()) + + currMap.Remove(100) + suite.Equal(0, currMap.Len()) + + suite.Equal(currMap.Len(), 0) + v, loaded := currMap.GetOrInsert(100, "v-100") + suite.Equal("v-100", v) + suite.Equal(false, loaded) + suite.Equal(1, currMap.Len()) + + currMap.Remove(100) + suite.Equal(0, currMap.Len()) + }) } func TestMapUtil(t *testing.T) { diff --git a/pkg/util/typeutil/schema.go b/pkg/util/typeutil/schema.go index 8d3f8056ea454..16f12fa40f688 100644 --- a/pkg/util/typeutil/schema.go +++ b/pkg/util/typeutil/schema.go @@ -20,8 +20,10 @@ import ( "fmt" "math" "strconv" + "unsafe" "github.com/cockroachdb/errors" + "github.com/samber/lo" "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" @@ -105,6 +107,17 @@ func EstimateSizePerRecord(schema *schemapb.CollectionSchema) (int, error) { break } } + case schemapb.DataType_Float16Vector: + for _, kv := range fs.TypeParams { + if kv.Key == common.DimKey { + v, err := strconv.Atoi(kv.Value) + if err != nil { + return -1, err + } + res += v * 2 + break + } + } } } return res, nil @@ -305,7 +318,7 @@ func (helper *SchemaHelper) GetVectorDimFromID(fieldID int64) (int, error) { // IsVectorType returns true if input is a vector type, otherwise false func IsVectorType(dataType schemapb.DataType) bool { switch dataType { - case schemapb.DataType_FloatVector, schemapb.DataType_BinaryVector: + case 
schemapb.DataType_FloatVector, schemapb.DataType_BinaryVector, schemapb.DataType_Float16Vector: return true default: return false @@ -371,7 +384,7 @@ func IsVariableDataType(dataType schemapb.DataType) bool { } // AppendFieldData appends fields data of specified index from src to dst -func AppendFieldData(dst []*schemapb.FieldData, src []*schemapb.FieldData, idx int64) { +func AppendFieldData(dst []*schemapb.FieldData, src []*schemapb.FieldData, idx int64) (appendSize int64) { for i, fieldData := range src { switch fieldType := fieldData.Field.(type) { case *schemapb.FieldData_Scalars: @@ -398,6 +411,8 @@ func AppendFieldData(dst []*schemapb.FieldData, src []*schemapb.FieldData, idx i } else { dstScalar.GetBoolData().Data = append(dstScalar.GetBoolData().Data, srcScalar.BoolData.Data[idx]) } + /* #nosec G103 */ + appendSize += int64(unsafe.Sizeof(srcScalar.BoolData.Data[idx])) case *schemapb.ScalarField_IntData: if dstScalar.GetIntData() == nil { dstScalar.Data = &schemapb.ScalarField_IntData{ @@ -408,6 +423,8 @@ func AppendFieldData(dst []*schemapb.FieldData, src []*schemapb.FieldData, idx i } else { dstScalar.GetIntData().Data = append(dstScalar.GetIntData().Data, srcScalar.IntData.Data[idx]) } + /* #nosec G103 */ + appendSize += int64(unsafe.Sizeof(srcScalar.IntData.Data[idx])) case *schemapb.ScalarField_LongData: if dstScalar.GetLongData() == nil { dstScalar.Data = &schemapb.ScalarField_LongData{ @@ -418,6 +435,8 @@ func AppendFieldData(dst []*schemapb.FieldData, src []*schemapb.FieldData, idx i } else { dstScalar.GetLongData().Data = append(dstScalar.GetLongData().Data, srcScalar.LongData.Data[idx]) } + /* #nosec G103 */ + appendSize += int64(unsafe.Sizeof(srcScalar.LongData.Data[idx])) case *schemapb.ScalarField_FloatData: if dstScalar.GetFloatData() == nil { dstScalar.Data = &schemapb.ScalarField_FloatData{ @@ -428,6 +447,8 @@ func AppendFieldData(dst []*schemapb.FieldData, src []*schemapb.FieldData, idx i } else { dstScalar.GetFloatData().Data = 
append(dstScalar.GetFloatData().Data, srcScalar.FloatData.Data[idx]) } + /* #nosec G103 */ + appendSize += int64(unsafe.Sizeof(srcScalar.FloatData.Data[idx])) case *schemapb.ScalarField_DoubleData: if dstScalar.GetDoubleData() == nil { dstScalar.Data = &schemapb.ScalarField_DoubleData{ @@ -438,6 +459,8 @@ func AppendFieldData(dst []*schemapb.FieldData, src []*schemapb.FieldData, idx i } else { dstScalar.GetDoubleData().Data = append(dstScalar.GetDoubleData().Data, srcScalar.DoubleData.Data[idx]) } + /* #nosec G103 */ + appendSize += int64(unsafe.Sizeof(srcScalar.DoubleData.Data[idx])) case *schemapb.ScalarField_StringData: if dstScalar.GetStringData() == nil { dstScalar.Data = &schemapb.ScalarField_StringData{ @@ -448,16 +471,21 @@ func AppendFieldData(dst []*schemapb.FieldData, src []*schemapb.FieldData, idx i } else { dstScalar.GetStringData().Data = append(dstScalar.GetStringData().Data, srcScalar.StringData.Data[idx]) } + /* #nosec G103 */ + appendSize += int64(unsafe.Sizeof(srcScalar.StringData.Data[idx])) case *schemapb.ScalarField_ArrayData: if dstScalar.GetArrayData() == nil { dstScalar.Data = &schemapb.ScalarField_ArrayData{ ArrayData: &schemapb.ArrayArray{ - Data: []*schemapb.ScalarField{srcScalar.ArrayData.Data[idx]}, + Data: []*schemapb.ScalarField{srcScalar.ArrayData.Data[idx]}, + ElementType: srcScalar.ArrayData.ElementType, }, } } else { dstScalar.GetArrayData().Data = append(dstScalar.GetArrayData().Data, srcScalar.ArrayData.Data[idx]) } + /* #nosec G103 */ + appendSize += int64(unsafe.Sizeof(srcScalar.ArrayData.Data[idx])) case *schemapb.ScalarField_JsonData: if dstScalar.GetJsonData() == nil { dstScalar.Data = &schemapb.ScalarField_JsonData{ @@ -468,6 +496,8 @@ func AppendFieldData(dst []*schemapb.FieldData, src []*schemapb.FieldData, idx i } else { dstScalar.GetJsonData().Data = append(dstScalar.GetJsonData().Data, srcScalar.JsonData.Data[idx]) } + /* #nosec G103 */ + appendSize += int64(unsafe.Sizeof(srcScalar.JsonData.Data[idx])) default: 
log.Error("Not supported field type", zap.String("field type", fieldData.Type.String())) } @@ -498,6 +528,8 @@ func AppendFieldData(dst []*schemapb.FieldData, src []*schemapb.FieldData, idx i dstBinaryVector := dstVector.Data.(*schemapb.VectorField_BinaryVector) dstBinaryVector.BinaryVector = append(dstBinaryVector.BinaryVector, srcVector.BinaryVector[idx*(dim/8):(idx+1)*(dim/8)]...) } + /* #nosec G103 */ + appendSize += int64(unsafe.Sizeof(srcVector.BinaryVector[idx*(dim/8) : (idx+1)*(dim/8)])) case *schemapb.VectorField_FloatVector: if dstVector.GetFloatVector() == nil { srcToCopy := srcVector.FloatVector.Data[idx*dim : (idx+1)*dim] @@ -510,11 +542,28 @@ func AppendFieldData(dst []*schemapb.FieldData, src []*schemapb.FieldData, idx i } else { dstVector.GetFloatVector().Data = append(dstVector.GetFloatVector().Data, srcVector.FloatVector.Data[idx*dim:(idx+1)*dim]...) } + /* #nosec G103 */ + appendSize += int64(unsafe.Sizeof(srcVector.FloatVector.Data[idx*dim : (idx+1)*dim])) + case *schemapb.VectorField_Float16Vector: + if dstVector.GetFloat16Vector() == nil { + srcToCopy := srcVector.Float16Vector[idx*(dim*2) : (idx+1)*(dim*2)] + dstVector.Data = &schemapb.VectorField_Float16Vector{ + Float16Vector: make([]byte, len(srcToCopy)), + } + copy(dstVector.Data.(*schemapb.VectorField_Float16Vector).Float16Vector, srcToCopy) + } else { + dstFloat16Vector := dstVector.Data.(*schemapb.VectorField_Float16Vector) + dstFloat16Vector.Float16Vector = append(dstFloat16Vector.Float16Vector, srcVector.Float16Vector[idx*(dim*2):(idx+1)*(dim*2)]...) 
+ } + /* #nosec G103 */ + appendSize += int64(unsafe.Sizeof(srcVector.Float16Vector[idx*(dim*2) : (idx+1)*(dim*2)])) default: log.Error("Not supported field type", zap.String("field type", fieldData.Type.String())) } } } + + return } // DeleteFieldData delete fields data appended last time @@ -558,6 +607,9 @@ func DeleteFieldData(dst []*schemapb.FieldData) { dstBinaryVector.BinaryVector = dstBinaryVector.BinaryVector[:len(dstBinaryVector.BinaryVector)-int(dim/8)] case *schemapb.VectorField_FloatVector: dstVector.GetFloatVector().Data = dstVector.GetFloatVector().Data[:len(dstVector.GetFloatVector().Data)-int(dim)] + case *schemapb.VectorField_Float16Vector: + dstFloat16Vector := dstVector.Data.(*schemapb.VectorField_Float16Vector) + dstFloat16Vector.Float16Vector = dstFloat16Vector.Float16Vector[:len(dstFloat16Vector.Float16Vector)-int(dim*2)] default: log.Error("wrong field type added", zap.String("field type", fieldData.Type.String())) } @@ -648,6 +700,17 @@ func MergeFieldData(dst []*schemapb.FieldData, src []*schemapb.FieldData) error } else { dstScalar.GetStringData().Data = append(dstScalar.GetStringData().Data, srcScalar.StringData.Data...) } + case *schemapb.ScalarField_ArrayData: + if dstScalar.GetArrayData() == nil { + dstScalar.Data = &schemapb.ScalarField_ArrayData{ + ArrayData: &schemapb.ArrayArray{ + Data: srcScalar.ArrayData.Data, + ElementType: srcScalar.ArrayData.ElementType, + }, + } + } else { + dstScalar.GetArrayData().Data = append(dstScalar.GetArrayData().Data, srcScalar.ArrayData.Data...) 
+ } case *schemapb.ScalarField_JsonData: if dstScalar.GetJsonData() == nil { dstScalar.Data = &schemapb.ScalarField_JsonData{ @@ -741,6 +804,16 @@ func GetPartitionKeyFieldSchema(schema *schemapb.CollectionSchema) (*schemapb.Fi return nil, errors.New("partition key field is not found") } +// HasPartitionKey check if a collection schema has PartitionKey field +func HasPartitionKey(schema *schemapb.CollectionSchema) bool { + for _, fieldSchema := range schema.Fields { + if fieldSchema.IsPartitionKey { + return true + } + } + return false +} + // GetPrimaryFieldData get primary field data from all field data inserted from sdk func GetPrimaryFieldData(datas []*schemapb.FieldData, primaryFieldSchema *schemapb.FieldSchema) (*schemapb.FieldData, error) { primaryFieldID := primaryFieldSchema.FieldID @@ -761,6 +834,12 @@ func GetPrimaryFieldData(datas []*schemapb.FieldData, primaryFieldSchema *schema return primaryFieldData, nil } +func GetField(schema *schemapb.CollectionSchema, fieldID int64) *schemapb.FieldSchema { + return lo.FindOrElse(schema.GetFields(), nil, func(field *schemapb.FieldSchema) bool { + return field.GetFieldID() == fieldID + }) +} + func IsPrimaryFieldDataExist(datas []*schemapb.FieldData, primaryFieldSchema *schemapb.FieldSchema) bool { primaryFieldID := primaryFieldSchema.FieldID primaryFieldName := primaryFieldSchema.Name @@ -799,7 +878,7 @@ func AppendIDs(dst *schemapb.IDs, src *schemapb.IDs, idx int) { dst.GetStrId().Data = append(dst.GetStrId().Data, src.GetStrId().Data[idx]) } default: - //TODO + // TODO } } @@ -815,7 +894,7 @@ func GetSizeOfIDs(data *schemapb.IDs) int { case *schemapb.IDs_StrId: result = len(data.GetStrId().GetData()) default: - //TODO:: + // TODO:: } return result @@ -873,6 +952,10 @@ func GetData(field *schemapb.FieldData, idx int) interface{} { dim := int(field.GetVectors().GetDim()) dataBytes := dim / 8 return field.GetVectors().GetBinaryVector()[idx*dataBytes : (idx+1)*dataBytes] + case schemapb.DataType_Float16Vector: + 
dim := int(field.GetVectors().GetDim()) + dataBytes := dim * 2 + return field.GetVectors().GetFloat16Vector()[idx*dataBytes : (idx+1)*dataBytes] } return nil } @@ -939,7 +1022,7 @@ type ResultWithID interface { } // SelectMinPK select the index of the minPK in results T of the cursors. -func SelectMinPK[T ResultWithID](results []T, cursors []int64) int { +func SelectMinPK[T ResultWithID](results []T, cursors []int64, stopForBest bool, realLimit int64) int { var ( sel = -1 minIntPK int64 = math.MaxInt64 @@ -950,6 +1033,18 @@ func SelectMinPK[T ResultWithID](results []T, cursors []int64) int { for i, cursor := range cursors { if int(cursor) >= GetSizeOfIDs(results[i].GetIds()) { + if realLimit == Unlimited { + // if there is no limit set and all possible results of one query unit(shard or segment) + // has drained all possible results without any leftover, so it's safe to continue the selection + // under this case + continue + } + if stopForBest && GetSizeOfIDs(results[i].GetIds()) >= int(realLimit) { + // if one query unit(shard or segment) has more than realLimit results, and it has run out of + // all results in this round, then we have to stop select since there may be further the latest result + // in the following result of current query unit + return -1 + } continue } diff --git a/pkg/util/typeutil/schema_test.go b/pkg/util/typeutil/schema_test.go index 594f98364255c..ba0f9eaf92936 100644 --- a/pkg/util/typeutil/schema_test.go +++ b/pkg/util/typeutil/schema_test.go @@ -21,12 +21,12 @@ import ( "reflect" "testing" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" ) @@ -210,7 +210,6 @@ func TestSchema(t *testing.T) { 
} func TestSchema_GetVectorFieldSchema(t *testing.T) { - schemaNormal := &schemapb.CollectionSchema{ Name: "testColl", Description: "", @@ -264,7 +263,6 @@ func TestSchema_GetVectorFieldSchema(t *testing.T) { _, err := GetVectorFieldSchema(schemaInvalid) assert.Error(t, err) }) - } func TestSchema_invalid(t *testing.T) { @@ -588,8 +586,21 @@ func genFieldData(fieldName string, fieldID int64, fieldType schemapb.DataType, }, FieldId: fieldID, } + case schemapb.DataType_Float16Vector: + fieldData = &schemapb.FieldData{ + Type: schemapb.DataType_Float16Vector, + FieldName: fieldName, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: dim, + Data: &schemapb.VectorField_Float16Vector{ + Float16Vector: fieldValue.([]byte), + }, + }, + }, + FieldId: fieldID, + } case schemapb.DataType_Array: - data := fieldValue.([][]int32) fieldData = &schemapb.FieldData{ Type: schemapb.DataType_Array, FieldName: fieldName, @@ -597,23 +608,13 @@ func genFieldData(fieldName string, fieldID int64, fieldType schemapb.DataType, Scalars: &schemapb.ScalarField{ Data: &schemapb.ScalarField_ArrayData{ ArrayData: &schemapb.ArrayArray{ - Data: []*schemapb.ScalarField{}, + Data: fieldValue.([]*schemapb.ScalarField), ElementType: schemapb.DataType_Int32, }, }, }, }, - } - - for _, list := range data { - arrayList := fieldData.GetScalars().GetArrayData() - arrayList.Data = append(arrayList.Data, &schemapb.ScalarField{ - Data: &schemapb.ScalarField_IntData{ - IntData: &schemapb.IntArray{ - Data: list, - }, - }, - }) + FieldId: fieldID, } case schemapb.DataType_JSON: @@ -640,21 +641,25 @@ func genFieldData(fieldName string, fieldID int64, fieldType schemapb.DataType, func TestAppendFieldData(t *testing.T) { const ( - Dim = 8 - BoolFieldName = "BoolField" - Int32FieldName = "Int32Field" - Int64FieldName = "Int64Field" - FloatFieldName = "FloatField" - DoubleFieldName = "DoubleField" - BinaryVectorFieldName = "BinaryVectorField" - FloatVectorFieldName = "FloatVectorField" - 
BoolFieldID = common.StartOfUserFieldID + 1 - Int32FieldID = common.StartOfUserFieldID + 2 - Int64FieldID = common.StartOfUserFieldID + 3 - FloatFieldID = common.StartOfUserFieldID + 4 - DoubleFieldID = common.StartOfUserFieldID + 5 - BinaryVectorFieldID = common.StartOfUserFieldID + 6 - FloatVectorFieldID = common.StartOfUserFieldID + 7 + Dim = 8 + BoolFieldName = "BoolField" + Int32FieldName = "Int32Field" + Int64FieldName = "Int64Field" + FloatFieldName = "FloatField" + DoubleFieldName = "DoubleField" + BinaryVectorFieldName = "BinaryVectorField" + FloatVectorFieldName = "FloatVectorField" + Float16VectorFieldName = "Float16VectorField" + ArrayFieldName = "ArrayField" + BoolFieldID = common.StartOfUserFieldID + 1 + Int32FieldID = common.StartOfUserFieldID + 2 + Int64FieldID = common.StartOfUserFieldID + 3 + FloatFieldID = common.StartOfUserFieldID + 4 + DoubleFieldID = common.StartOfUserFieldID + 5 + BinaryVectorFieldID = common.StartOfUserFieldID + 6 + FloatVectorFieldID = common.StartOfUserFieldID + 7 + Float16VectorFieldID = common.StartOfUserFieldID + 8 + ArrayFieldID = common.StartOfUserFieldID + 9 ) BoolArray := []bool{true, false} Int32Array := []int32{1, 2} @@ -663,8 +668,28 @@ func TestAppendFieldData(t *testing.T) { DoubleArray := []float64{11.0, 22.0} BinaryVector := []byte{0x12, 0x34} FloatVector := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0} + Float16Vector := []byte{ + 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, + 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, + } + ArrayArray := []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: []int32{1, 2, 3}, + }, + }, + }, + { + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: []int32{4, 5, 6}, + }, + }, + }, + } - result := make([]*schemapb.FieldData, 7) + result := 
make([]*schemapb.FieldData, 9) var fieldDataArray1 []*schemapb.FieldData fieldDataArray1 = append(fieldDataArray1, genFieldData(BoolFieldName, BoolFieldID, schemapb.DataType_Bool, BoolArray[0:1], 1)) fieldDataArray1 = append(fieldDataArray1, genFieldData(Int32FieldName, Int32FieldID, schemapb.DataType_Int32, Int32Array[0:1], 1)) @@ -673,6 +698,8 @@ func TestAppendFieldData(t *testing.T) { fieldDataArray1 = append(fieldDataArray1, genFieldData(DoubleFieldName, DoubleFieldID, schemapb.DataType_Double, DoubleArray[0:1], 1)) fieldDataArray1 = append(fieldDataArray1, genFieldData(BinaryVectorFieldName, BinaryVectorFieldID, schemapb.DataType_BinaryVector, BinaryVector[0:Dim/8], Dim)) fieldDataArray1 = append(fieldDataArray1, genFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:Dim], Dim)) + fieldDataArray1 = append(fieldDataArray1, genFieldData(Float16VectorFieldName, Float16VectorFieldID, schemapb.DataType_Float16Vector, Float16Vector[0:Dim*2], Dim)) + fieldDataArray1 = append(fieldDataArray1, genFieldData(ArrayFieldName, ArrayFieldID, schemapb.DataType_Array, ArrayArray[0:1], 1)) var fieldDataArray2 []*schemapb.FieldData fieldDataArray2 = append(fieldDataArray2, genFieldData(BoolFieldName, BoolFieldID, schemapb.DataType_Bool, BoolArray[1:2], 1)) @@ -682,6 +709,8 @@ func TestAppendFieldData(t *testing.T) { fieldDataArray2 = append(fieldDataArray2, genFieldData(DoubleFieldName, DoubleFieldID, schemapb.DataType_Double, DoubleArray[1:2], 1)) fieldDataArray2 = append(fieldDataArray2, genFieldData(BinaryVectorFieldName, BinaryVectorFieldID, schemapb.DataType_BinaryVector, BinaryVector[Dim/8:2*Dim/8], Dim)) fieldDataArray2 = append(fieldDataArray2, genFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[Dim:2*Dim], Dim)) + fieldDataArray2 = append(fieldDataArray2, genFieldData(Float16VectorFieldName, Float16VectorFieldID, schemapb.DataType_Float16Vector, Float16Vector[2*Dim:4*Dim], Dim)) 
+ fieldDataArray2 = append(fieldDataArray2, genFieldData(ArrayFieldName, ArrayFieldID, schemapb.DataType_Array, ArrayArray[1:2], 1)) AppendFieldData(result, fieldDataArray1, 0) AppendFieldData(result, fieldDataArray2, 0) @@ -693,19 +722,22 @@ func TestAppendFieldData(t *testing.T) { assert.Equal(t, DoubleArray, result[4].GetScalars().GetDoubleData().Data) assert.Equal(t, BinaryVector, result[5].GetVectors().Data.(*schemapb.VectorField_BinaryVector).BinaryVector) assert.Equal(t, FloatVector, result[6].GetVectors().GetFloatVector().Data) + assert.Equal(t, Float16Vector, result[7].GetVectors().Data.(*schemapb.VectorField_Float16Vector).Float16Vector) + assert.Equal(t, ArrayArray, result[8].GetScalars().GetArrayData().Data) } func TestDeleteFieldData(t *testing.T) { const ( - Dim = 8 - BoolFieldName = "BoolField" - Int32FieldName = "Int32Field" - Int64FieldName = "Int64Field" - FloatFieldName = "FloatField" - DoubleFieldName = "DoubleField" - JSONFieldName = "JSONField" - BinaryVectorFieldName = "BinaryVectorField" - FloatVectorFieldName = "FloatVectorField" + Dim = 8 + BoolFieldName = "BoolField" + Int32FieldName = "Int32Field" + Int64FieldName = "Int64Field" + FloatFieldName = "FloatField" + DoubleFieldName = "DoubleField" + JSONFieldName = "JSONField" + BinaryVectorFieldName = "BinaryVectorField" + FloatVectorFieldName = "FloatVectorField" + Float16VectorFieldName = "Float16VectorField" ) const ( @@ -717,6 +749,7 @@ func TestDeleteFieldData(t *testing.T) { JSONFieldID BinaryVectorFieldID FloatVectorFieldID + Float16VectorFieldID ) BoolArray := []bool{true, false} Int32Array := []int32{1, 2} @@ -726,9 +759,13 @@ func TestDeleteFieldData(t *testing.T) { JSONArray := [][]byte{[]byte("{\"hello\":0}"), []byte("{\"key\":1}")} BinaryVector := []byte{0x12, 0x34} FloatVector := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0} + Float16Vector := []byte{ + 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 
0xcc, 0xdd, 0xee, 0xff, + 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, + } - result1 := make([]*schemapb.FieldData, 8) - result2 := make([]*schemapb.FieldData, 8) + result1 := make([]*schemapb.FieldData, 9) + result2 := make([]*schemapb.FieldData, 9) var fieldDataArray1 []*schemapb.FieldData fieldDataArray1 = append(fieldDataArray1, genFieldData(BoolFieldName, BoolFieldID, schemapb.DataType_Bool, BoolArray[0:1], 1)) fieldDataArray1 = append(fieldDataArray1, genFieldData(Int32FieldName, Int32FieldID, schemapb.DataType_Int32, Int32Array[0:1], 1)) @@ -738,6 +775,7 @@ func TestDeleteFieldData(t *testing.T) { fieldDataArray1 = append(fieldDataArray1, genFieldData(JSONFieldName, JSONFieldID, schemapb.DataType_JSON, JSONArray[0:1], 1)) fieldDataArray1 = append(fieldDataArray1, genFieldData(BinaryVectorFieldName, BinaryVectorFieldID, schemapb.DataType_BinaryVector, BinaryVector[0:Dim/8], Dim)) fieldDataArray1 = append(fieldDataArray1, genFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:Dim], Dim)) + fieldDataArray1 = append(fieldDataArray1, genFieldData(Float16VectorFieldName, Float16VectorFieldID, schemapb.DataType_Float16Vector, Float16Vector[0:2*Dim], Dim)) var fieldDataArray2 []*schemapb.FieldData fieldDataArray2 = append(fieldDataArray2, genFieldData(BoolFieldName, BoolFieldID, schemapb.DataType_Bool, BoolArray[1:2], 1)) @@ -748,6 +786,7 @@ func TestDeleteFieldData(t *testing.T) { fieldDataArray2 = append(fieldDataArray2, genFieldData(JSONFieldName, JSONFieldID, schemapb.DataType_JSON, JSONArray[1:2], 1)) fieldDataArray2 = append(fieldDataArray2, genFieldData(BinaryVectorFieldName, BinaryVectorFieldID, schemapb.DataType_BinaryVector, BinaryVector[Dim/8:2*Dim/8], Dim)) fieldDataArray2 = append(fieldDataArray2, genFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[Dim:2*Dim], Dim)) + fieldDataArray2 = append(fieldDataArray2, 
genFieldData(Float16VectorFieldName, Float16VectorFieldID, schemapb.DataType_Float16Vector, Float16Vector[2*Dim:4*Dim], Dim)) AppendFieldData(result1, fieldDataArray1, 0) AppendFieldData(result1, fieldDataArray2, 0) @@ -760,6 +799,7 @@ func TestDeleteFieldData(t *testing.T) { assert.Equal(t, JSONArray[0:1], result1[JSONFieldID-common.StartOfUserFieldID].GetScalars().GetJsonData().Data) assert.Equal(t, BinaryVector[0:Dim/8], result1[BinaryVectorFieldID-common.StartOfUserFieldID].GetVectors().Data.(*schemapb.VectorField_BinaryVector).BinaryVector) assert.Equal(t, FloatVector[0:Dim], result1[FloatVectorFieldID-common.StartOfUserFieldID].GetVectors().GetFloatVector().Data) + assert.Equal(t, Float16Vector[0:2*Dim], result1[Float16VectorFieldID-common.StartOfUserFieldID].GetVectors().Data.(*schemapb.VectorField_Float16Vector).Float16Vector) AppendFieldData(result2, fieldDataArray2, 0) AppendFieldData(result2, fieldDataArray1, 0) @@ -772,6 +812,7 @@ func TestDeleteFieldData(t *testing.T) { assert.Equal(t, JSONArray[1:2], result2[JSONFieldID-common.StartOfUserFieldID].GetScalars().GetJsonData().Data) assert.Equal(t, BinaryVector[Dim/8:2*Dim/8], result2[BinaryVectorFieldID-common.StartOfUserFieldID].GetVectors().Data.(*schemapb.VectorField_BinaryVector).BinaryVector) assert.Equal(t, FloatVector[Dim:2*Dim], result2[FloatVectorFieldID-common.StartOfUserFieldID].GetVectors().GetFloatVector().Data) + assert.Equal(t, Float16Vector[2*Dim:4*Dim], result2[Float16VectorFieldID-common.StartOfUserFieldID].GetVectors().Data.(*schemapb.VectorField_Float16Vector).Float16Vector) } func TestGetPrimaryFieldSchema(t *testing.T) { @@ -794,11 +835,16 @@ func TestGetPrimaryFieldSchema(t *testing.T) { // no primary field error _, err := GetPrimaryFieldSchema(schema) assert.Error(t, err) - int64Field.IsPrimaryKey = true primaryField, err := GetPrimaryFieldSchema(schema) assert.NoError(t, err) assert.Equal(t, schemapb.DataType_Int64, primaryField.DataType) + + hasPartitionKey := 
HasPartitionKey(schema) + assert.False(t, hasPartitionKey) + int64Field.IsPartitionKey = true + hasPartitionKey2 := HasPartitionKey(schema) + assert.True(t, hasPartitionKey2) } func TestGetPK(t *testing.T) { @@ -968,7 +1014,22 @@ func TestCalcColumnSize(t *testing.T) { 105: []float64{0, 1}, 106: []string{"0", "1"}, 107: []float32{0, 1, 2, 3}, - 109: [][]int32{{1, 2, 3}, {4, 5, 6}}, + 109: []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: []int32{1, 2, 3}, + }, + }, + }, + { + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: []int32{4, 5, 6}, + }, + }, + }, + }, 110: [][]byte{[]byte(`{"key":"value"}`), []byte(`{"hello":"world"}`)}, } schema := &schemapb.CollectionSchema{ @@ -1056,9 +1117,9 @@ func TestCalcColumnSize(t *testing.T) { expected += len(v) } case schemapb.DataType_Array: - data := values.([][]int32) + data := values.([]*schemapb.ScalarField) for _, v := range data { - expected += binary.Size(v) + expected += binary.Size(v.GetIntData().GetData()) } case schemapb.DataType_JSON: data := values.([][]byte) @@ -1091,6 +1152,10 @@ func TestGetDataAndGetDataSize(t *testing.T) { VarCharArray := []string{"a", "b"} BinaryVector := []byte{0x12, 0x34} FloatVector := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0} + Float16Vector := []byte{ + 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, + 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, + } boolData := genFieldData(fieldName, fieldID, schemapb.DataType_Bool, BoolArray, 1) int8Data := genFieldData(fieldName, fieldID, schemapb.DataType_Int8, Int8Array, 1) @@ -1102,6 +1167,7 @@ func TestGetDataAndGetDataSize(t *testing.T) { varCharData := genFieldData(fieldName, fieldID, schemapb.DataType_VarChar, VarCharArray, 1) binVecData := genFieldData(fieldName, fieldID, schemapb.DataType_BinaryVector, 
BinaryVector, Dim) floatVecData := genFieldData(fieldName, fieldID, schemapb.DataType_FloatVector, FloatVector, Dim) + float16VecData := genFieldData(fieldName, fieldID, schemapb.DataType_Float16Vector, Float16Vector, Dim) invalidData := &schemapb.FieldData{ Type: schemapb.DataType_None, } @@ -1125,6 +1191,7 @@ func TestGetDataAndGetDataSize(t *testing.T) { varCharDataRes := GetData(varCharData, 0) binVecDataRes := GetData(binVecData, 0) floatVecDataRes := GetData(floatVecData, 0) + float16VecDataRes := GetData(float16VecData, 0) invalidDataRes := GetData(invalidData, 0) assert.Equal(t, BoolArray[0], boolDataRes) @@ -1137,40 +1204,122 @@ func TestGetDataAndGetDataSize(t *testing.T) { assert.Equal(t, VarCharArray[0], varCharDataRes) assert.ElementsMatch(t, BinaryVector[:Dim/8], binVecDataRes) assert.ElementsMatch(t, FloatVector[:Dim], floatVecDataRes) + assert.ElementsMatch(t, Float16Vector[:2*Dim], float16VecDataRes) assert.Nil(t, invalidDataRes) }) } func TestMergeFieldData(t *testing.T) { - dstFields := []*schemapb.FieldData{ - genFieldData("int64", 100, schemapb.DataType_Int64, []int64{1, 2, 3}, 1), - genFieldData("vector", 101, schemapb.DataType_FloatVector, []float32{1, 2, 3}, 1), - genFieldData("json", 102, schemapb.DataType_JSON, [][]byte{[]byte(`{"key":"value"}`), []byte(`{"hello":"world"}`)}, 1), - } + t.Run("merge data", func(t *testing.T) { + dstFields := []*schemapb.FieldData{ + genFieldData("int64", 100, schemapb.DataType_Int64, []int64{1, 2, 3}, 1), + genFieldData("vector", 101, schemapb.DataType_FloatVector, []float32{1, 2, 3}, 1), + genFieldData("json", 102, schemapb.DataType_JSON, [][]byte{[]byte(`{"key":"value"}`), []byte(`{"hello":"world"}`)}, 1), + genFieldData("array", 103, schemapb.DataType_Array, []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: []int32{1, 2, 3}, + }, + }, + }, + }, 1), + } - srcFields := []*schemapb.FieldData{ - genFieldData("int64", 100, schemapb.DataType_Int64, 
[]int64{4, 5, 6}, 1), - genFieldData("vector", 101, schemapb.DataType_FloatVector, []float32{4, 5, 6}, 1), - genFieldData("json", 102, schemapb.DataType_JSON, [][]byte{[]byte(`{"key":"value"}`), []byte(`{"hello":"world"}`)}, 1), - } + srcFields := []*schemapb.FieldData{ + genFieldData("int64", 100, schemapb.DataType_Int64, []int64{4, 5, 6}, 1), + genFieldData("vector", 101, schemapb.DataType_FloatVector, []float32{4, 5, 6}, 1), + genFieldData("json", 102, schemapb.DataType_JSON, [][]byte{[]byte(`{"key":"value"}`), []byte(`{"hello":"world"}`)}, 1), + genFieldData("array", 103, schemapb.DataType_Array, []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: []int32{4, 5, 6}, + }, + }, + }, + }, 1), + } - err := MergeFieldData(dstFields, srcFields) - assert.NoError(t, err) + err := MergeFieldData(dstFields, srcFields) + assert.NoError(t, err) - assert.Equal(t, []int64{1, 2, 3, 4, 5, 6}, dstFields[0].GetScalars().GetLongData().Data) - assert.Equal(t, []float32{1, 2, 3, 4, 5, 6}, dstFields[1].GetVectors().GetFloatVector().Data) - assert.Equal(t, [][]byte{[]byte(`{"key":"value"}`), []byte(`{"hello":"world"}`), []byte(`{"key":"value"}`), []byte(`{"hello":"world"}`)}, - dstFields[2].GetScalars().GetJsonData().Data) + assert.Equal(t, []int64{1, 2, 3, 4, 5, 6}, dstFields[0].GetScalars().GetLongData().Data) + assert.Equal(t, []float32{1, 2, 3, 4, 5, 6}, dstFields[1].GetVectors().GetFloatVector().Data) + assert.Equal(t, [][]byte{[]byte(`{"key":"value"}`), []byte(`{"hello":"world"}`), []byte(`{"key":"value"}`), []byte(`{"hello":"world"}`)}, + dstFields[2].GetScalars().GetJsonData().Data) + assert.Equal(t, []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: []int32{1, 2, 3}, + }, + }, + }, + { + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: []int32{4, 5, 6}, + }, + }, + }, + }, + dstFields[3].GetScalars().GetArrayData().Data) + }) - emptyField 
:= &schemapb.FieldData{ - Type: schemapb.DataType_None, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: nil, + t.Run("merge with nil", func(t *testing.T) { + srcFields := []*schemapb.FieldData{ + genFieldData("int64", 100, schemapb.DataType_Int64, []int64{1, 2, 3}, 1), + genFieldData("vector", 101, schemapb.DataType_FloatVector, []float32{1, 2, 3}, 1), + genFieldData("json", 102, schemapb.DataType_JSON, [][]byte{[]byte(`{"key":"value"}`), []byte(`{"hello":"world"}`)}, 1), + genFieldData("array", 103, schemapb.DataType_Array, []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: []int32{1, 2, 3}, + }, + }, + }, + }, 1), + } + + dstFields := []*schemapb.FieldData{ + {Type: schemapb.DataType_Int64, FieldName: "int64", Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{}}}, FieldId: 100}, + {Type: schemapb.DataType_FloatVector, FieldName: "vector", Field: &schemapb.FieldData_Vectors{Vectors: &schemapb.VectorField{Data: &schemapb.VectorField_FloatVector{}}}, FieldId: 101}, + {Type: schemapb.DataType_JSON, FieldName: "json", Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_JsonData{}}}, FieldId: 102}, + {Type: schemapb.DataType_Array, FieldName: "array", Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_ArrayData{}}}, FieldId: 103}, + } + + err := MergeFieldData(dstFields, srcFields) + assert.NoError(t, err) + + assert.Equal(t, []int64{1, 2, 3}, dstFields[0].GetScalars().GetLongData().Data) + assert.Equal(t, []float32{1, 2, 3}, dstFields[1].GetVectors().GetFloatVector().Data) + assert.Equal(t, [][]byte{[]byte(`{"key":"value"}`), []byte(`{"hello":"world"}`)}, + dstFields[2].GetScalars().GetJsonData().Data) + assert.Equal(t, []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: []int32{1, 2, 3}, + 
}, + }, }, }, - } + dstFields[3].GetScalars().GetArrayData().Data) + }) - err = MergeFieldData([]*schemapb.FieldData{emptyField}, []*schemapb.FieldData{emptyField}) - assert.Error(t, err) + t.Run("error case", func(t *testing.T) { + emptyField := &schemapb.FieldData{ + Type: schemapb.DataType_None, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: nil, + }, + }, + } + + err := MergeFieldData([]*schemapb.FieldData{emptyField}, []*schemapb.FieldData{emptyField}) + assert.Error(t, err) + }) } diff --git a/pkg/util/typeutil/string_util.go b/pkg/util/typeutil/string_util.go index 001aec378313e..031b23f2e2026 100644 --- a/pkg/util/typeutil/string_util.go +++ b/pkg/util/typeutil/string_util.go @@ -27,7 +27,7 @@ func AddOne(data string) string { if len(data) == 0 { return data } - var datab = []byte(data) + datab := []byte(data) if datab[len(datab)-1] != 255 { datab[len(datab)-1]++ } else { diff --git a/pkg/util/typeutil/time.go b/pkg/util/typeutil/time.go index bad3b97012e3b..565dba034e1b9 100644 --- a/pkg/util/typeutil/time.go +++ b/pkg/util/typeutil/time.go @@ -32,7 +32,7 @@ var ZeroTimestamp = Timestamp(0) // ParseTimestamp returns a timestamp for a given byte slice. func ParseTimestamp(data []byte) (time.Time, error) { - //we use big endian here for compatibility issues + // we use big endian here for compatibility issues nano, err := BigEndianBytesToUint64(data) if err != nil { return ZeroTime, err diff --git a/rules.go b/rules.go index c35790a8444ae..5bc3422c9b450 100644 --- a/rules.go +++ b/rules.go @@ -56,7 +56,6 @@ func timeeq(m dsl.Matcher) { // err but no an error func errnoterror(m dsl.Matcher) { - // Would be easier to check for all err identifiers instead, but then how do we get the type from m[] ? 
m.Match( @@ -146,7 +145,6 @@ func ifreturn(m dsl.Matcher) { m.Match("if !$x { return $*_ }; if $x {$*_ }").Report("odd sequence of if test") m.Match("if $x == $y { return $*_ }; if $x != $y {$*_ }").Report("odd sequence of if test") m.Match("if $x != $y { return $*_ }; if $x == $y {$*_ }").Report("odd sequence of if test") - } func oddifsequence(m dsl.Matcher) { @@ -267,7 +265,6 @@ func floatloop(m dsl.Matcher) { } func urlredacted(m dsl.Matcher) { - m.Match( "log.Println($x, $*_)", "log.Println($*_, $x, $*_)", @@ -292,7 +289,6 @@ func sprinterr(m dsl.Matcher) { ). Where(m["err"].Type.Is("error")). Report("maybe call $err.Error() instead of fmt.Sprint()?") - } // disable this check, because it can not apply to generic type @@ -351,7 +347,6 @@ func nilerr(m dsl.Matcher) { `if err == nil { return $*_, err }`, ). Report(`return nil error instead of nil value`) - } func mailaddress(m dsl.Matcher) { @@ -367,7 +362,6 @@ func mailaddress(m dsl.Matcher) { ). Report("use net/mail Address.String() instead of fmt.Sprintf()"). Suggest("(&mail.Address{Name:$NAME, Address:$EMAIL}).String()") - } func errnetclosed(m dsl.Matcher) { @@ -377,7 +371,6 @@ func errnetclosed(m dsl.Matcher) { Where(m["text"].Text.Matches("\".*closed network connection.*\"")). Report(`String matching against error texts is fragile; use net.ErrClosed instead`). 
Suggest(`errors.Is($err, net.ErrClosed)`) - } func httpheaderadd(m dsl.Matcher) { diff --git a/scripts/azure_build.sh b/scripts/azure_build.sh new file mode 100644 index 0000000000000..9c0279b349631 --- /dev/null +++ b/scripts/azure_build.sh @@ -0,0 +1,14 @@ +ROOT_DIR=$1 + +ARCHITECTURE=$(uname -m) +if [[ ${ARCHITECTURE} == "aarch64" ]]; then + export VCPKG_FORCE_SYSTEM_BINARIES="arm" +fi + +AZURE_CMAKE_CMD="cmake \ +-DCMAKE_INSTALL_LIBDIR=${ROOT_DIR}/internal/core/output/lib \ +${ROOT_DIR}/internal/core/src/storage/azure-blob-storage" +echo ${AZURE_CMAKE_CMD} +${AZURE_CMAKE_CMD} + +make & make install \ No newline at end of file diff --git a/scripts/core_build.sh b/scripts/core_build.sh index a510e01cfde1e..8ad95c05571c5 100755 --- a/scripts/core_build.sh +++ b/scripts/core_build.sh @@ -90,33 +90,22 @@ BUILD_OUTPUT_DIR="${ROOT_DIR}/cmake_build" BUILD_TYPE="Release" BUILD_UNITTEST="OFF" INSTALL_PREFIX="${CPP_SRC_DIR}/output" -MAKE_CLEAN="OFF" BUILD_COVERAGE="OFF" -DB_PATH="/tmp/milvus" -PROFILING="OFF" RUN_CPPLINT="OFF" CUDA_COMPILER=/usr/local/cuda/bin/nvcc GPU_VERSION="OFF" #defaults to CPU version -WITH_PROMETHEUS="ON" CUDA_ARCH="DEFAULT" -CUSTOM_THIRDPARTY_PATH="" EMBEDDED_MILVUS="OFF" BUILD_DISK_ANN="OFF" USE_ASAN="OFF" -OPEN_SIMD="OFF" USE_DYNAMIC_SIMD="OFF" +INDEX_ENGINE="KNOWHERE" -while getopts "p:d:t:s:f:n:i:y:a:ulrcghzmeb" arg; do +while getopts "p:d:t:s:f:n:i:y:a:x:ulrcghzmebZ" arg; do case $arg in - f) - CUSTOM_THIRDPARTY_PATH=$OPTARG - ;; p) INSTALL_PREFIX=$OPTARG ;; - d) - DB_PATH=$OPTARG - ;; t) BUILD_TYPE=$OPTARG # BUILD_TYPE ;; @@ -127,23 +116,12 @@ while getopts "p:d:t:s:f:n:i:y:a:ulrcghzmeb" arg; do l) RUN_CPPLINT="ON" ;; - r) - if [[ -d ${BUILD_OUTPUT_DIR} ]]; then - MAKE_CLEAN="ON" - fi - ;; c) BUILD_COVERAGE="ON" ;; - z) - PROFILING="ON" - ;; g) GPU_VERSION="ON" ;; - e) - WITH_PROMETHEUS="OFF" - ;; s) CUDA_ARCH=$OPTARG ;; @@ -155,40 +133,41 @@ while getopts "p:d:t:s:f:n:i:y:a:ulrcghzmeb" arg; do ;; a) ENV_VAL=$OPTARG - if [[ ${ENV_VAL} == 
'true' ]]; then + if [[ ${ENV_VAL} == 'ON' ]]; then echo "Set USE_ASAN to ON" USE_ASAN="ON" BUILD_TYPE=Debug fi ;; - i) - OPEN_SIMD=$OPTARG - ;; y) USE_DYNAMIC_SIMD=$OPTARG ;; + Z) + BUILD_WITHOUT_AZURE="on" + ;; + x) + INDEX_ENGINE=$OPTARG + ;; h) # help echo " parameter: --f: custom paths of thirdparty downloaded files(default: NULL) -p: install prefix(default: $(pwd)/milvus) -d: db data path(default: /tmp/milvus) -t: build type(default: Debug) -u: building unit test options(default: OFF) -l: run cpplint, clang-format and clang-tidy(default: OFF) --r: remove previous build directory(default: OFF) -c: code coverage(default: OFF) --z: profiling(default: OFF) -g: build GPU version(default: OFF) -e: build without prometheus(default: OFF) -s: build with CUDA arch(default:DEFAULT), for example '-gencode=compute_61,code=sm_61;-gencode=compute_75,code=sm_75' -b: build embedded milvus(default: OFF) -a: build milvus with AddressSanitizer(default: false) +-Z: build milvus without azure-sdk-for-cpp, so cannot use azure blob -h: help usage: -./core_build.sh -p \${INSTALL_PREFIX} -t \${BUILD_TYPE} -s \${CUDA_ARCH} -f\${CUSTOM_THIRDPARTY_PATH} [-u] [-l] [-r] [-c] [-z] [-g] [-m] [-e] [-h] [-b] +./core_build.sh -p \${INSTALL_PREFIX} -t \${BUILD_TYPE} -s \${CUDA_ARCH} [-u] [-l] [-r] [-c] [-z] [-g] [-m] [-e] [-h] [-b] " exit 0 ;; @@ -199,6 +178,33 @@ usage: esac done +if [ -z "$BUILD_WITHOUT_AZURE" ]; then + AZURE_BUILD_DIR="${ROOT_DIR}/cmake_build/azure" + if [ ! -d ${AZURE_BUILD_DIR} ]; then + mkdir -p ${AZURE_BUILD_DIR} + fi + pushd ${AZURE_BUILD_DIR} + env bash ${ROOT_DIR}/scripts/azure_build.sh ${ROOT_DIR} + if [ ! 
-e libblob-chunk-manager* ]; then + cat vcpkg-bootstrap.log + exit 1 + fi + popd + SYSTEM_NAME=$(uname -s) + if [[ ${SYSTEM_NAME} == "Darwin" ]]; then + SYSTEM_NAME="osx" + elif [[ ${SYSTEM_NAME} == "Linux" ]]; then + SYSTEM_NAME="linux" + fi + ARCHITECTURE=$(uname -m) + if [[ ${ARCHITECTURE} == "x86_64" ]]; then + ARCHITECTURE="x64" + elif [[ ${ARCHITECTURE} == "aarch64" ]]; then + ARCHITECTURE="arm64" + fi + VCPKG_TARGET_TRIPLET=${ARCHITECTURE}-${SYSTEM_NAME} +fi + if [[ ! -d ${BUILD_OUTPUT_DIR} ]]; then mkdir ${BUILD_OUTPUT_DIR} fi @@ -206,24 +212,6 @@ source ${ROOT_DIR}/scripts/setenv.sh CMAKE_GENERATOR="Unix Makefiles" -# MSYS system -if [ "$MSYSTEM" == "MINGW64" ] ; then - BUILD_COVERAGE=OFF - PROFILING=OFF - GPU_VERSION=OFF - WITH_PROMETHEUS=OFF - CUDA_ARCH=OFF - - # extra default cmake args for msys - CMAKE_GENERATOR="MSYS Makefiles" - - # clang tools path - export CLANG_TOOLS_PATH=/mingw64/bin - - # using system blas - export OpenBLAS_HOME="$(cygpath -w /mingw64)" -fi - # UBUNTU system build diskann index if [ "$OS_NAME" == "ubuntu20.04" ] ; then BUILD_DISK_ANN=ON @@ -235,13 +223,6 @@ pushd ${BUILD_OUTPUT_DIR} # Force update the variables each time make rebuild_cache >/dev/null 2>&1 - -if [[ ${MAKE_CLEAN} == "ON" ]]; then - echo "Runing make clean in ${BUILD_OUTPUT_DIR} ..." 
- make clean - exit 0 -fi - CPU_ARCH=$(get_cpu_arch $CPU_TARGET) arch=$(uname -m) @@ -250,23 +231,22 @@ ${CMAKE_EXTRA_ARGS} \ -DBUILD_UNIT_TEST=${BUILD_UNITTEST} \ -DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX} -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ --DOpenBLAS_SOURCE=AUTO \ -DCMAKE_CUDA_COMPILER=${CUDA_COMPILER} \ -DCMAKE_LIBRARY_ARCHITECTURE=${arch} \ -DBUILD_COVERAGE=${BUILD_COVERAGE} \ --DMILVUS_DB_PATH=${DB_PATH} \ --DENABLE_CPU_PROFILING=${PROFILING} \ -DMILVUS_GPU_VERSION=${GPU_VERSION} \ --DMILVUS_WITH_PROMETHEUS=${WITH_PROMETHEUS} \ -DMILVUS_CUDA_ARCH=${CUDA_ARCH} \ --DCUSTOM_THIRDPARTY_DOWNLOAD_PATH=${CUSTOM_THIRDPARTY_PATH} \ -DEMBEDDED_MILVUS=${EMBEDDED_MILVUS} \ -DBUILD_DISK_ANN=${BUILD_DISK_ANN} \ -DUSE_ASAN=${USE_ASAN} \ --DOPEN_SIMD=${OPEN_SIMD} \ --DUSE_DYNAMIC_SIMD=${USE_DYNAMIC_SIMD} +-DUSE_DYNAMIC_SIMD=${USE_DYNAMIC_SIMD} \ -DCPU_ARCH=${CPU_ARCH} \ -${CPP_SRC_DIR}" +-DINDEX_ENGINE=${INDEX_ENGINE} " +if [ -z "$BUILD_WITHOUT_AZURE" ]; then +CMAKE_CMD=${CMAKE_CMD}"-DAZURE_BUILD_DIR=${AZURE_BUILD_DIR} \ +-DVCPKG_TARGET_TRIPLET=${VCPKG_TARGET_TRIPLET} " +fi +CMAKE_CMD=${CMAKE_CMD}"${CPP_SRC_DIR}" echo "CC $CC" echo ${CMAKE_CMD} diff --git a/scripts/devcontainer.sh b/scripts/devcontainer.sh index 4781a06dc623c..a3e40297a450d 100755 --- a/scripts/devcontainer.sh +++ b/scripts/devcontainer.sh @@ -80,11 +80,15 @@ if [ "${1-}" = "build" ];then fi if [ "${1-}" = "up" ]; then - docker-compose -f $ROOT_DIR/docker-compose-devcontainer.yml up -d + docker-compose -f $ROOT_DIR/docker-compose-devcontainer.yml up -d $(docker-compose config --services | grep -wv "gpubuilder") fi if [ "${1-}" = "down" ]; then docker-compose -f $ROOT_DIR/docker-compose-devcontainer.yml down fi -popd +if [ "${1-}" = "gpu" -a "${2-}" = "up" ]; then + docker-compose -f $ROOT_DIR/docker-compose-devcontainer.yml up -d $(docker-compose config --services | grep -wv "builder") +fi + +popd \ No newline at end of file diff --git a/scripts/install_deps.sh b/scripts/install_deps.sh index 
01b7226bda563..c5b0e3f8a7fd6 100755 --- a/scripts/install_deps.sh +++ b/scripts/install_deps.sh @@ -30,7 +30,7 @@ function install_linux_deps() { sudo yum install -y epel-release centos-release-scl-rh sudo yum install -y wget curl which \ git make automake python3-devel \ - devtoolset-11-gcc devtoolset-11-gcc-c++ devtoolset-11-gcc-gfortran \ + devtoolset-11-gcc devtoolset-11-gcc-c++ devtoolset-11-gcc-gfortran devtoolset-11-libatomic-devel \ llvm-toolset-11.0-clang llvm-toolset-11.0-clang-tools-extra \ libaio libuuid-devel zip unzip \ ccache lcov libtool m4 autoconf automake @@ -56,7 +56,7 @@ function install_linux_deps() { function install_mac_deps() { sudo xcode-select --install > /dev/null 2>&1 - brew install libomp ninja cmake llvm@15 ccache grep pkg-config + brew install libomp ninja cmake llvm@15 ccache grep pkg-config zip unzip export PATH="/usr/local/opt/grep/libexec/gnubin:$PATH" brew update && brew upgrade && brew cleanup @@ -74,12 +74,6 @@ then exit fi -if ! command -v cmake &> /dev/null -then - echo "cmake could not be found, please install it" - exit -fi - unameOut="$(uname -s)" case "${unameOut}" in Linux*) install_linux_deps;; diff --git a/scripts/setenv.sh b/scripts/setenv.sh index 40e0a1a677c27..23a4555db34dc 100755 --- a/scripts/setenv.sh +++ b/scripts/setenv.sh @@ -47,7 +47,7 @@ case "${unameOut}" in export RPATH=$LD_LIBRARY_PATH;; Darwin*) # detect llvm version by valid list - for llvm_version in 16 15 14 NOT_FOUND ; do + for llvm_version in 17 16 15 14 NOT_FOUND ; do if brew ls --versions llvm@${llvm_version} > /dev/null; then break fi diff --git a/scripts/start_cluster.sh b/scripts/start_cluster.sh index 042dc8f2e7311..0a99dad73ee4e 100755 --- a/scripts/start_cluster.sh +++ b/scripts/start_cluster.sh @@ -27,25 +27,25 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then fi echo "Starting rootcoord..." 
-nohup ./bin/milvus run rootcoord > /tmp/rootcoord.log 2>&1 & +nohup ./bin/milvus run rootcoord --run-with-subprocess > /tmp/rootcoord.log 2>&1 & echo "Starting datacoord..." -nohup ./bin/milvus run datacoord > /tmp/datacoord.log 2>&1 & +nohup ./bin/milvus run datacoord --run-with-subprocess > /tmp/datacoord.log 2>&1 & echo "Starting datanode..." -nohup ./bin/milvus run datanode > /tmp/datanode.log 2>&1 & +nohup ./bin/milvus run datanode --run-with-subprocess > /tmp/datanode.log 2>&1 & echo "Starting proxy..." -nohup ./bin/milvus run proxy > /tmp/proxy.log 2>&1 & +nohup ./bin/milvus run proxy --run-with-subprocess > /tmp/proxy.log 2>&1 & echo "Starting querycoord..." -nohup ./bin/milvus run querycoord > /tmp/querycoord.log 2>&1 & +nohup ./bin/milvus run querycoord --run-with-subprocess > /tmp/querycoord.log 2>&1 & echo "Starting querynode..." -nohup ./bin/milvus run querynode > /tmp/querynode.log 2>&1 & +nohup ./bin/milvus run querynode --run-with-subprocess > /tmp/querynode.log 2>&1 & echo "Starting indexcoord..." -nohup ./bin/milvus run indexcoord > /tmp/indexcoord.log 2>&1 & +nohup ./bin/milvus run indexcoord --run-with-subprocess > /tmp/indexcoord.log 2>&1 & echo "Starting indexnode..." -nohup ./bin/milvus run indexnode > /tmp/indexnode.log 2>&1 & +nohup ./bin/milvus run indexnode --run-with-subprocess > /tmp/indexnode.log 2>&1 & diff --git a/scripts/start_standalone.sh b/scripts/start_standalone.sh index 31732ba0d34a8..6dce9e6d166f5 100755 --- a/scripts/start_standalone.sh +++ b/scripts/start_standalone.sh @@ -27,4 +27,4 @@ if [[ "$OSTYPE" == "linux-gnu"* ]]; then fi echo "Starting standalone..." -nohup ./bin/milvus run standalone > /tmp/standalone.log 2>&1 & +nohup ./bin/milvus run standalone --run-with-subprocess > /tmp/standalone.log 2>&1 & diff --git a/scripts/stop.sh b/scripts/stop.sh index 7b6248a6cd19d..61c28bddde7bc 100755 --- a/scripts/stop.sh +++ b/scripts/stop.sh @@ -15,11 +15,11 @@ # limitations under the License. echo "Stopping milvus..." 
-PROCESS=$(ps -e | grep milvus | grep -v grep | awk '{print $1}') +PROCESS=$(ps -e | grep milvus | grep -v grep | grep run-with-subprocess | awk '{print $1}') if [ -z "$PROCESS" ]; then echo "No milvus process" exit 0 fi -kill -9 $PROCESS +kill -15 $PROCESS echo "Milvus stopped" diff --git a/tests/docker/.env b/tests/docker/.env index ff7052d341e7e..3b29f89250165 100644 --- a/tests/docker/.env +++ b/tests/docker/.env @@ -3,5 +3,5 @@ MILVUS_SERVICE_PORT=19530 MILVUS_PYTEST_WORKSPACE=/milvus/tests/python_client MILVUS_PYTEST_LOG_PATH=/milvus/_artifacts/tests/pytest_logs IMAGE_REPO=milvusdb -IMAGE_TAG=20230830-a8e5dc3 -LATEST_IMAGE_TAG=20230830-a8e5dc3 +IMAGE_TAG=20231019-020ad9a +LATEST_IMAGE_TAG=20231019-020ad9a diff --git a/tests/integration/bulkinsert/bulkinsert_test.go b/tests/integration/bulkinsert/bulkinsert_test.go index ead972a62e471..1c8da6d7d55ac 100644 --- a/tests/integration/bulkinsert/bulkinsert_test.go +++ b/tests/integration/bulkinsert/bulkinsert_test.go @@ -65,7 +65,7 @@ func (s *BulkInsertSuite) TestBulkInsert() { prefix := "TestBulkInsert" dbName := "" collectionName := prefix + funcutil.GenRandomStr() - //floatVecField := floatVecField + // floatVecField := floatVecField dim := 128 schema := integration.ConstructSchema(collectionName, dim, true, @@ -187,7 +187,7 @@ func (s *BulkInsertSuite) TestBulkInsert() { s.WaitForLoad(ctx, collectionName) // search - expr := "" //fmt.Sprintf("%s > 0", int64Field) + expr := "" // fmt.Sprintf("%s > 0", int64Field) nq := 10 topk := 10 roundDecimal := -1 @@ -212,6 +212,7 @@ func (s *BulkInsertSuite) TestBulkInsert() { } func TestBulkInsert(t *testing.T) { + t.Skip("Skip integration test, need to refactor integration test framework") suite.Run(t, new(BulkInsertSuite)) } @@ -236,18 +237,18 @@ func GenerateNumpyFile(filePath string, rowCount int, dType schemapb.DataType, t if err != nil { return err } - //data := make([][]float32, rowCount) + // data := make([][]float32, rowCount) var data [][Dim]float32 for i := 0; 
i < rowCount; i++ { vec := [Dim]float32{} for j := 0; j < dim; j++ { vec[j] = 1.1 } - //v := reflect.Indirect(reflect.ValueOf(vec)) - //log.Info("type", zap.Any("type", v.Kind())) + // v := reflect.Indirect(reflect.ValueOf(vec)) + // log.Info("type", zap.Any("type", v.Kind())) data = append(data, vec) - //v2 := reflect.Indirect(reflect.ValueOf(data)) - //log.Info("type", zap.Any("type", v2.Kind())) + // v2 := reflect.Indirect(reflect.ValueOf(data)) + // log.Info("type", zap.Any("type", v2.Kind())) } err = importutil.CreateNumpyFile(filePath, data) if err != nil { @@ -259,6 +260,7 @@ func GenerateNumpyFile(filePath string, rowCount int, dType schemapb.DataType, t } func TestGenerateNumpyFile(t *testing.T) { + t.Skip("Skip integration test, need to refactor integration test framework") err := os.MkdirAll(TempFilesPath, os.ModePerm) require.NoError(t, err) err = GenerateNumpyFile(TempFilesPath+"embeddings.npy", 100, schemapb.DataType_FloatVector, []*commonpb.KeyValuePair{ diff --git a/tests/integration/crossclusterrouting/cross_cluster_routing_test.go b/tests/integration/crossclusterrouting/cross_cluster_routing_test.go index 681b22dd3b318..08cc1eeb99ade 100644 --- a/tests/integration/crossclusterrouting/cross_cluster_routing_test.go +++ b/tests/integration/crossclusterrouting/cross_cluster_routing_test.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "math/rand" + "strconv" "strings" "testing" "time" @@ -29,16 +30,6 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" - "github.com/milvus-io/milvus/internal/proto/datapb" - "github.com/milvus-io/milvus/internal/proto/indexpb" - "github.com/milvus-io/milvus/internal/proto/proxypb" - "github.com/milvus-io/milvus/internal/proto/querypb" - "github.com/milvus-io/milvus/internal/util/dependency" - "github.com/milvus-io/milvus/pkg/util/commonpbutil" - "github.com/milvus-io/milvus/pkg/util/etcd" - "github.com/milvus-io/milvus/pkg/util/merr" - 
"github.com/milvus-io/milvus/pkg/util/paramtable" - grpcdatacoord "github.com/milvus-io/milvus/internal/distributed/datacoord" grpcdatacoordclient "github.com/milvus-io/milvus/internal/distributed/datacoord/client" grpcdatanode "github.com/milvus-io/milvus/internal/distributed/datanode" @@ -53,6 +44,15 @@ import ( grpcquerynodeclient "github.com/milvus-io/milvus/internal/distributed/querynode/client" grpcrootcoord "github.com/milvus-io/milvus/internal/distributed/rootcoord" grpcrootcoordclient "github.com/milvus-io/milvus/internal/distributed/rootcoord/client" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/indexpb" + "github.com/milvus-io/milvus/internal/proto/proxypb" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/util/dependency" + "github.com/milvus-io/milvus/pkg/util/commonpbutil" + "github.com/milvus-io/milvus/pkg/util/etcd" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) type CrossClusterRoutingSuite struct { @@ -88,10 +88,13 @@ func (s *CrossClusterRoutingSuite) SetupSuite() { rand.Seed(time.Now().UnixNano()) paramtable.Init() + + paramtable.Get().Save("grpc.client.maxMaxAttempts", "1") s.factory = dependency.NewDefaultFactory(true) } func (s *CrossClusterRoutingSuite) TearDownSuite() { + paramtable.Get().Save("grpc.client.maxMaxAttempts", strconv.FormatInt(paramtable.DefaultMaxAttempts, 10)) } func (s *CrossClusterRoutingSuite) SetupTest() { diff --git a/tests/integration/getvector/get_vector_test.go b/tests/integration/getvector/get_vector_test.go index d5c474306921f..c3addb30f875e 100644 --- a/tests/integration/getvector/get_vector_test.go +++ b/tests/integration/getvector/get_vector_test.go @@ -47,9 +47,6 @@ type TestGetVectorSuite struct { metricType string pkType schemapb.DataType vecType schemapb.DataType - - // expected - searchFailed bool } func (s *TestGetVectorSuite) run() { @@ -125,6 +122,8 @@ 
func (s *TestGetVectorSuite) run() { var vecFieldData *schemapb.FieldData if s.vecType == schemapb.DataType_FloatVector { vecFieldData = integration.NewFloatVectorFieldData(vecFieldName, NB, dim) + } else if s.vecType == schemapb.DataType_Float16Vector { + vecFieldData = integration.NewFloat16VectorFieldData(vecFieldName, NB, dim) } else { vecFieldData = integration.NewBinaryVectorFieldData(vecFieldName, NB, dim) } @@ -150,12 +149,14 @@ func (s *TestGetVectorSuite) run() { ids := segmentIDs.GetData() s.Require().NotEmpty(segmentIDs) s.Require().True(has) + flushTs, has := flushResp.GetCollFlushTs()[collection] + s.Require().True(has) segments, err := s.Cluster.MetaWatcher.ShowSegments() s.Require().NoError(err) s.Require().NotEmpty(segments) - s.WaitForFlush(ctx, ids) + s.WaitForFlush(ctx, ids, flushTs, s.dbName, collection) // create index _, err = s.Cluster.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{ @@ -190,11 +191,6 @@ func (s *TestGetVectorSuite) run() { searchResp, err := s.Cluster.Proxy.Search(ctx, searchReq) s.Require().NoError(err) - if s.searchFailed { - s.Require().NotEqual(searchResp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success) - s.T().Logf("reason:%s", searchResp.GetStatus().GetReason()) - return - } s.Require().Equal(searchResp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success) result := searchResp.GetResults() @@ -205,7 +201,7 @@ func (s *TestGetVectorSuite) run() { } s.Require().Len(result.GetScores(), nq*topk) s.Require().GreaterOrEqual(len(result.GetFieldsData()), 1) - var vecFieldIndex = -1 + vecFieldIndex := -1 for i, fieldData := range result.GetFieldsData() { if typeutil.IsVectorType(fieldData.GetType()) { vecFieldIndex = i @@ -235,6 +231,25 @@ func (s *TestGetVectorSuite) run() { s.Require().ElementsMatch(expect, actual) } } + } else if s.vecType == schemapb.DataType_Float16Vector { + // s.Require().Len(result.GetFieldsData()[vecFieldIndex].GetVectors().GetFloat16Vector(), nq*topk*dim*2) + // rawData := 
vecFieldData.GetVectors().GetFloat16Vector() + // resData := result.GetFieldsData()[vecFieldIndex].GetVectors().GetFloat16Vector() + // if s.pkType == schemapb.DataType_Int64 { + // for i, id := range result.GetIds().GetIntId().GetData() { + // expect := rawData[int(id)*dim : (int(id)+1)*dim] + // actual := resData[i*dim : (i+1)*dim] + // s.Require().ElementsMatch(expect, actual) + // } + // } else { + // for i, idStr := range result.GetIds().GetStrId().GetData() { + // id, err := strconv.Atoi(idStr) + // s.Require().NoError(err) + // expect := rawData[id*dim : (id+1)*dim] + // actual := resData[i*dim : (i+1)*dim] + // s.Require().ElementsMatch(expect, actual) + // } + // } } else { s.Require().Len(result.GetFieldsData()[vecFieldIndex].GetVectors().GetBinaryVector(), nq*topk*dim/8) rawData := vecFieldData.GetVectors().GetBinaryVector() @@ -277,7 +292,16 @@ func (s *TestGetVectorSuite) TestGetVector_FLAT() { s.metricType = metric.L2 s.pkType = schemapb.DataType_Int64 s.vecType = schemapb.DataType_FloatVector - s.searchFailed = false + s.run() +} + +func (s *TestGetVectorSuite) TestGetVector_Float16Vector() { + s.nq = 10 + s.topK = 10 + s.indexType = integration.IndexFaissIDMap + s.metricType = metric.L2 + s.pkType = schemapb.DataType_Int64 + s.vecType = schemapb.DataType_Float16Vector s.run() } @@ -288,7 +312,6 @@ func (s *TestGetVectorSuite) TestGetVector_IVF_FLAT() { s.metricType = metric.L2 s.pkType = schemapb.DataType_Int64 s.vecType = schemapb.DataType_FloatVector - s.searchFailed = false s.run() } @@ -299,7 +322,6 @@ func (s *TestGetVectorSuite) TestGetVector_IVF_PQ() { s.metricType = metric.L2 s.pkType = schemapb.DataType_Int64 s.vecType = schemapb.DataType_FloatVector - s.searchFailed = true s.run() } @@ -310,7 +332,6 @@ func (s *TestGetVectorSuite) TestGetVector_SCANN() { s.metricType = metric.L2 s.pkType = schemapb.DataType_Int64 s.vecType = schemapb.DataType_FloatVector - s.searchFailed = false s.run() } @@ -321,7 +342,16 @@ func (s *TestGetVectorSuite) 
TestGetVector_IVF_SQ8() { s.metricType = metric.L2 s.pkType = schemapb.DataType_Int64 s.vecType = schemapb.DataType_FloatVector - s.searchFailed = true + s.run() +} + +func (s *TestGetVectorSuite) TestGetVector_IVF_SQ8_StrPK() { + s.nq = 10 + s.topK = 10 + s.indexType = integration.IndexFaissIvfSQ8 + s.metricType = metric.L2 + s.pkType = schemapb.DataType_VarChar + s.vecType = schemapb.DataType_FloatVector s.run() } @@ -332,7 +362,6 @@ func (s *TestGetVectorSuite) TestGetVector_HNSW() { s.metricType = metric.L2 s.pkType = schemapb.DataType_Int64 s.vecType = schemapb.DataType_FloatVector - s.searchFailed = false s.run() } @@ -343,7 +372,6 @@ func (s *TestGetVectorSuite) TestGetVector_IP() { s.metricType = metric.IP s.pkType = schemapb.DataType_Int64 s.vecType = schemapb.DataType_FloatVector - s.searchFailed = false s.run() } @@ -354,7 +382,6 @@ func (s *TestGetVectorSuite) TestGetVector_StringPK() { s.metricType = metric.L2 s.pkType = schemapb.DataType_VarChar s.vecType = schemapb.DataType_FloatVector - s.searchFailed = false s.run() } @@ -365,7 +392,6 @@ func (s *TestGetVectorSuite) TestGetVector_BinaryVector() { s.metricType = metric.JACCARD s.pkType = schemapb.DataType_Int64 s.vecType = schemapb.DataType_BinaryVector - s.searchFailed = false s.run() } @@ -377,7 +403,6 @@ func (s *TestGetVectorSuite) TestGetVector_Big_NQ_TOPK() { s.metricType = metric.L2 s.pkType = schemapb.DataType_Int64 s.vecType = schemapb.DataType_FloatVector - s.searchFailed = false s.run() } @@ -389,21 +414,30 @@ func (s *TestGetVectorSuite) TestGetVector_With_DB_Name() { s.metricType = metric.L2 s.pkType = schemapb.DataType_Int64 s.vecType = schemapb.DataType_FloatVector - s.searchFailed = false s.run() } -//func (s *TestGetVectorSuite) TestGetVector_DISKANN() { +//func (s *TestGetVectorSuite) TestGetVector_DISKANN_L2() { +// s.nq = 10 +// s.topK = 10 +// s.indexType = integration.IndexDISKANN +// s.metricType = metric.L2 +// s.pkType = schemapb.DataType_Int64 +// s.vecType = 
schemapb.DataType_FloatVector +// s.run() +//} + +//func (s *TestGetVectorSuite) TestGetVector_DISKANN_IP() { // s.nq = 10 // s.topK = 10 // s.indexType = integration.IndexDISKANN -// s.metricType = distance.L2 +// s.metricType = metric.IP // s.pkType = schemapb.DataType_Int64 // s.vecType = schemapb.DataType_FloatVector -// s.searchFailed = false // s.run() //} func TestGetVector(t *testing.T) { + t.Skip("Skip integration test, need to refactor integration test framework") suite.Run(t, new(TestGetVectorSuite)) } diff --git a/tests/integration/hellomilvus/hello_milvus_test.go b/tests/integration/hellomilvus/hello_milvus_test.go index ea9f3c4d40346..437c35e485d03 100644 --- a/tests/integration/hellomilvus/hello_milvus_test.go +++ b/tests/integration/hellomilvus/hello_milvus_test.go @@ -94,7 +94,9 @@ func (s *HelloMilvusSuite) TestHelloMilvus() { s.NoError(err) segmentIDs, has := flushResp.GetCollSegIDs()[collectionName] ids := segmentIDs.GetData() - s.NotEmpty(segmentIDs) + s.Require().NotEmpty(segmentIDs) + s.Require().True(has) + flushTs, has := flushResp.GetCollFlushTs()[collectionName] s.True(has) segments, err := c.MetaWatcher.ShowSegments() @@ -103,7 +105,7 @@ func (s *HelloMilvusSuite) TestHelloMilvus() { for _, segment := range segments { log.Info("ShowSegments result", zap.String("segment", segment.String())) } - s.WaitForFlush(ctx, ids) + s.WaitForFlush(ctx, ids, flushTs, dbName, collectionName) // create index createIndexStatus, err := c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{ @@ -151,9 +153,9 @@ func (s *HelloMilvusSuite) TestHelloMilvus() { s.Equal(commonpb.ErrorCode_Success, searchResult.GetStatus().GetErrorCode()) log.Info("TestHelloMilvus succeed") - } func TestHelloMilvus(t *testing.T) { + t.Skip("Skip integration test, need to refactor integration test framework") suite.Run(t, new(HelloMilvusSuite)) } diff --git a/tests/integration/indexstat/get_index_statistics_test.go b/tests/integration/indexstat/get_index_statistics_test.go index 
ae0c2c772e3d3..4f77f52464b2a 100644 --- a/tests/integration/indexstat/get_index_statistics_test.go +++ b/tests/integration/indexstat/get_index_statistics_test.go @@ -68,9 +68,11 @@ func (s *GetIndexStatisticsSuite) TestGetIndexStatistics() { s.NoError(err) segmentIDs, has := flushResp.GetCollSegIDs()[collectionName] ids := segmentIDs.GetData() - s.NotEmpty(segmentIDs) + s.Require().NotEmpty(segmentIDs) + s.Require().True(has) + flushTs, has := flushResp.GetCollFlushTs()[collectionName] s.Equal(true, has) - s.WaitForFlush(ctx, ids) + s.WaitForFlush(ctx, ids, flushTs, dbName, collectionName) // create index indexName := "_default" @@ -152,5 +154,6 @@ func (s *GetIndexStatisticsSuite) TestGetIndexStatistics() { } func TestGetIndexStat(t *testing.T) { + t.Skip("Skip integration test, need to refactor integration test framework") suite.Run(t, new(GetIndexStatisticsSuite)) } diff --git a/tests/integration/insert/insert_test.go b/tests/integration/insert/insert_test.go index 02b4f6cb5ce60..54c190b541f4a 100644 --- a/tests/integration/insert/insert_test.go +++ b/tests/integration/insert/insert_test.go @@ -118,9 +118,9 @@ func (s *InsertSuite) TestInsert() { log.Info("TestInsert succeed") log.Info("==================") log.Info("==================") - } func TestInsert(t *testing.T) { + t.Skip("Skip integration test, need to refactor integration test framework") suite.Run(t, new(InsertSuite)) } diff --git a/tests/integration/jsonexpr/json_expr_test.go b/tests/integration/jsonexpr/json_expr_test.go index 4b665434a6f20..a8a97a9394eb1 100644 --- a/tests/integration/jsonexpr/json_expr_test.go +++ b/tests/integration/jsonexpr/json_expr_test.go @@ -24,19 +24,18 @@ import ( "testing" "time" - "github.com/cockroachdb/errors" - "github.com/milvus-io/milvus/pkg/common" - "github.com/milvus-io/milvus/tests/integration" + "github.com/golang/protobuf/proto" "github.com/stretchr/testify/suite" + "go.uber.org/zap" - "github.com/golang/protobuf/proto" 
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/metric" - "go.uber.org/zap" + "github.com/milvus-io/milvus/tests/integration" ) type JSONExprSuite struct { @@ -733,7 +732,10 @@ func (s *JSONExprSuite) insertFlushIndexLoad(ctx context.Context, dbName, collec s.NoError(err) segmentIDs, has := flushResp.GetCollSegIDs()[collectionName] ids := segmentIDs.GetData() - s.NotEmpty(segmentIDs) + s.Require().NotEmpty(segmentIDs) + s.Require().True(has) + flushTs, has := flushResp.GetCollFlushTs()[collectionName] + s.True(has) segments, err := s.Cluster.MetaWatcher.ShowSegments() s.NoError(err) @@ -741,28 +743,7 @@ func (s *JSONExprSuite) insertFlushIndexLoad(ctx context.Context, dbName, collec for _, segment := range segments { log.Info("ShowSegments result", zap.String("segment", segment.String())) } - - if has && len(ids) > 0 { - flushed := func() bool { - resp, err := s.Cluster.Proxy.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{ - SegmentIDs: ids, - }) - if err != nil { - //panic(errors.New("GetFlushState failed")) - return false - } - return resp.GetFlushed() - } - for !flushed() { - // respect context deadline/cancel - select { - case <-ctx.Done(): - panic(errors.New("deadline exceeded")) - default: - } - time.Sleep(500 * time.Millisecond) - } - } + s.WaitForFlush(ctx, ids, flushTs, dbName, collectionName) // create index createIndexStatus, err := s.Cluster.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{ @@ -1160,5 +1141,6 @@ func (s *JSONExprSuite) TestJsonContains() { } func TestJsonExpr(t *testing.T) { + t.Skip("Skip integration test, need to refactor integration test framework") suite.Run(t, new(JSONExprSuite)) } diff --git a/tests/integration/meta_watcher.go 
b/tests/integration/meta_watcher.go index bc47488c671a0..b4cd2140b0c89 100644 --- a/tests/integration/meta_watcher.go +++ b/tests/integration/meta_watcher.go @@ -25,13 +25,13 @@ import ( "time" "github.com/golang/protobuf/proto" - "github.com/milvus-io/milvus/internal/util/sessionutil" - "github.com/milvus-io/milvus/pkg/log" clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/zap" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/util/sessionutil" + "github.com/milvus-io/milvus/pkg/log" ) // MetaWatcher to observe meta data of milvus cluster @@ -117,7 +117,6 @@ func listReplicas(cli *clientv3.Client, prefix string) ([]*querypb.Replica, erro ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) defer cancel() resp, err := cli.Get(ctx, prefix, clientv3.WithPrefix()) - if err != nil { return nil, err } diff --git a/tests/integration/meta_watcher_test.go b/tests/integration/meta_watcher_test.go index 77f56b439f924..9ce6bd00f81ad 100644 --- a/tests/integration/meta_watcher_test.go +++ b/tests/integration/meta_watcher_test.go @@ -22,7 +22,6 @@ import ( "testing" "time" - "github.com/cockroachdb/errors" "github.com/golang/protobuf/proto" "github.com/stretchr/testify/suite" "go.uber.org/zap" @@ -226,9 +225,6 @@ func (s *MetaWatcherSuite) TestShowReplicas() { CollectionNames: []string{collectionName}, }) s.NoError(err) - segmentIDs, has := flushResp.GetCollSegIDs()[collectionName] - ids := segmentIDs.GetData() - s.NotEmpty(segmentIDs) segments, err := c.MetaWatcher.ShowSegments() s.NoError(err) @@ -236,28 +232,13 @@ func (s *MetaWatcherSuite) TestShowReplicas() { for _, segment := range segments { log.Info("ShowSegments result", zap.String("segment", segment.String())) } - - if has && len(ids) > 0 { - flushed := func() bool { - resp, err := c.Proxy.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{ - SegmentIDs: ids, - }) - if err != nil { - 
//panic(errors.New("GetFlushState failed")) - return false - } - return resp.GetFlushed() - } - for !flushed() { - // respect context deadline/cancel - select { - case <-ctx.Done(): - panic(errors.New("deadline exceeded")) - default: - } - time.Sleep(500 * time.Millisecond) - } - } + segmentIDs, has := flushResp.GetCollSegIDs()[collectionName] + ids := segmentIDs.GetData() + s.Require().NotEmpty(segmentIDs) + s.Require().True(has) + flushTs, has := flushResp.GetCollFlushTs()[collectionName] + s.True(has) + s.WaitForFlush(ctx, ids, flushTs, dbName, collectionName) // create index createIndexStatus, err := c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{ @@ -325,5 +306,6 @@ func (s *MetaWatcherSuite) TestShowReplicas() { } func TestMetaWatcher(t *testing.T) { + t.Skip("Skip integration test, need to refactor integration test framework") suite.Run(t, new(MetaWatcherSuite)) } diff --git a/tests/integration/minicluster.go b/tests/integration/minicluster.go index f73d68674ed53..dd9221c4a611b 100644 --- a/tests/integration/minicluster.go +++ b/tests/integration/minicluster.go @@ -16,6 +16,8 @@ package integration +import "C" + import ( "context" "fmt" @@ -30,6 +32,13 @@ import ( "github.com/milvus-io/milvus/internal/datacoord" "github.com/milvus-io/milvus/internal/datanode" + datacoordclient "github.com/milvus-io/milvus/internal/distributed/datacoord/client" + datanodeclient "github.com/milvus-io/milvus/internal/distributed/datanode/client" + indexnodeclient "github.com/milvus-io/milvus/internal/distributed/indexnode/client" + proxyclient "github.com/milvus-io/milvus/internal/distributed/proxy/client" + querycoordclient "github.com/milvus-io/milvus/internal/distributed/querycoord/client" + querynodeclient "github.com/milvus-io/milvus/internal/distributed/querynode/client" + rootcoordclient "github.com/milvus-io/milvus/internal/distributed/rootcoord/client" "github.com/milvus-io/milvus/internal/indexnode" proxy2 "github.com/milvus-io/milvus/internal/proxy" querycoord 
"github.com/milvus-io/milvus/internal/querycoordv2" @@ -53,7 +62,7 @@ type Cluster interface { AddRootCoord(types.RootCoordComponent) error AddDataCoord(types.DataCoordComponent) error AddQueryCoord(types.QueryCoordComponent) error - //AddIndexCoord(types.IndexCoordComponent) error + // AddIndexCoord(types.IndexCoordComponent) error AddDataNode(types.DataNodeComponent) error AddQueryNode(types.QueryNodeComponent) error AddIndexNode(types.IndexNodeComponent) error @@ -61,7 +70,7 @@ type Cluster interface { RemoveRootCoord(types.RootCoordComponent) error RemoveDataCoord(types.DataCoordComponent) error RemoveQueryCoord(types.QueryCoordComponent) error - //RemoveIndexCoord(types.IndexCoordComponent) error + // RemoveIndexCoord(types.IndexCoordComponent) error RemoveDataNode(types.DataNodeComponent) error RemoveQueryNode(types.QueryNodeComponent) error RemoveIndexNode(types.IndexNodeComponent) error @@ -76,12 +85,12 @@ type Cluster interface { } type ClusterConfig struct { - //ProxyNum int + // ProxyNum int // todo coord num can be more than 1 if enable Active-Standby - //RootCoordNum int - //DataCoordNum int - //IndexCoordNum int - //QueryCoordNum int + // RootCoordNum int + // DataCoordNum int + // IndexCoordNum int + // QueryCoordNum int QueryNodeNum int DataNodeNum int IndexNodeNum int @@ -110,6 +119,10 @@ type MiniCluster struct { RootCoord types.RootCoordComponent QueryCoord types.QueryCoordComponent + DataCoordClient types.DataCoordClient + RootCoordClient types.RootCoordClient + QueryCoordClient types.QueryCoordClient + QueryNodes []types.QueryNodeComponent DataNodes []types.DataNodeComponent IndexNodes []types.IndexNodeComponent @@ -256,59 +269,59 @@ func StartMiniCluster(ctx context.Context, opts ...Option) (cluster *MiniCluster cluster.Proxy = proxy } - //cluster.dataCoord.SetIndexCoord(cluster.indexCoord) - cluster.DataCoord.SetRootCoord(cluster.RootCoord) + // cluster.dataCoord.SetIndexCoord(cluster.indexCoord) + 
cluster.DataCoord.SetRootCoordClient(cluster.GetRootCoordClient()) - err = cluster.RootCoord.SetDataCoord(cluster.DataCoord) + err = cluster.RootCoord.SetDataCoordClient(cluster.GetDataCoordClient()) if err != nil { - return + return nil, err } //err = cluster.rootCoord.SetIndexCoord(cluster.indexCoord) //if err != nil { // return //} - err = cluster.RootCoord.SetQueryCoord(cluster.QueryCoord) + err = cluster.RootCoord.SetQueryCoordClient(cluster.GetQueryCoordClient()) if err != nil { - return + return nil, err } - //err = cluster.queryCoord.SetIndexCoord(cluster.indexCoord) - if err != nil { - return - } - err = cluster.QueryCoord.SetDataCoord(cluster.DataCoord) + // err = cluster.queryCoord.SetIndexCoord(cluster.indexCoord) + //if err != nil { + // return + //} + err = cluster.QueryCoord.SetDataCoordClient(cluster.GetDataCoordClient()) if err != nil { - return + return nil, err } - err = cluster.QueryCoord.SetRootCoord(cluster.RootCoord) + err = cluster.QueryCoord.SetRootCoordClient(cluster.GetRootCoordClient()) if err != nil { - return + return nil, err } - //err = cluster.indexCoord.SetDataCoord(cluster.dataCoord) + //err = cluster.indexCoord.SetDataCoordClient(cluster.GetDataCoordClient()) //if err != nil { // return //} - //err = cluster.indexCoord.SetRootCoord(cluster.rootCoord) + //err = cluster.indexCoord.SetRootCoordClient(cluster.GetRootCoordClient()) //if err != nil { // return //} for _, dataNode := range cluster.DataNodes { - err = dataNode.SetDataCoord(cluster.DataCoord) + err = dataNode.SetDataCoordClient(cluster.GetDataCoordClient()) if err != nil { - return + return nil, err } - err = dataNode.SetRootCoord(cluster.RootCoord) + err = dataNode.SetRootCoordClient(cluster.GetRootCoordClient()) if err != nil { - return + return nil, err } } - cluster.Proxy.SetDataCoordClient(cluster.DataCoord) - //cluster.proxy.SetIndexCoordClient(cluster.indexCoord) - cluster.Proxy.SetQueryCoordClient(cluster.QueryCoord) - 
cluster.Proxy.SetRootCoordClient(cluster.RootCoord) + cluster.Proxy.SetDataCoordClient(cluster.GetDataCoordClient()) + // cluster.proxy.SetIndexCoordClient(cluster.indexCoord) + cluster.Proxy.SetQueryCoordClient(cluster.GetQueryCoordClient()) + cluster.Proxy.SetRootCoordClient(cluster.GetRootCoordClient()) return cluster, nil } @@ -438,7 +451,7 @@ func (cluster *MiniCluster) Stop() error { log.Info("mini cluster rootCoord stopped") cluster.DataCoord.Stop() log.Info("mini cluster dataCoord stopped") - //cluster.indexCoord.Stop() + // cluster.indexCoord.Stop() cluster.QueryCoord.Stop() log.Info("mini cluster queryCoord stopped") cluster.Proxy.Stop() @@ -474,6 +487,10 @@ func (cluster *MiniCluster) Stop() error { return nil } +func GetMetaRootPath(rootPath string) string { + return fmt.Sprintf("%s/%s", rootPath, params.EtcdCfg.MetaSubPath.GetValue()) +} + func DefaultParams() map[string]string { testPath := fmt.Sprintf("integration-test-%d", time.Now().Unix()) return map[string]string{ @@ -672,15 +689,15 @@ func (cluster *MiniCluster) AddRootCoord(rootCoord types.RootCoordComponent) err } // link - rootCoord.SetDataCoord(cluster.DataCoord) - rootCoord.SetQueryCoord(cluster.QueryCoord) - //rootCoord.SetIndexCoord(cluster.indexCoord) - cluster.DataCoord.SetRootCoord(rootCoord) - cluster.QueryCoord.SetRootCoord(rootCoord) - //cluster.indexCoord.SetRootCoord(rootCoord) - cluster.Proxy.SetRootCoordClient(rootCoord) + rootCoord.SetDataCoordClient(cluster.GetDataCoordClient()) + rootCoord.SetQueryCoordClient(cluster.GetQueryCoordClient()) + // rootCoord.SetIndexCoord(cluster.indexCoord) + cluster.DataCoord.SetRootCoordClient(cluster.GetRootCoordClient()) + cluster.QueryCoord.SetRootCoordClient(cluster.GetRootCoordClient()) + // cluster.indexCoord.SetRootCoordClient(rootCoord) + cluster.Proxy.SetRootCoordClient(cluster.GetRootCoordClient()) for _, dataNode := range cluster.DataNodes { - err = dataNode.SetRootCoord(rootCoord) + err = 
dataNode.SetRootCoordClient(cluster.GetRootCoordClient()) if err != nil { return err } @@ -740,23 +757,23 @@ func (cluster *MiniCluster) AddDataCoord(dataCoord types.DataCoordComponent) err } // link - //dataCoord.SetIndexCoord(cluster.indexCoord) - dataCoord.SetRootCoord(cluster.RootCoord) - err = cluster.RootCoord.SetDataCoord(cluster.DataCoord) + // dataCoord.SetIndexCoord(cluster.indexCoord) + dataCoord.SetRootCoordClient(cluster.GetRootCoordClient()) + err = cluster.RootCoord.SetDataCoordClient(cluster.GetDataCoordClient()) if err != nil { return err } - err = cluster.QueryCoord.SetDataCoord(cluster.DataCoord) + err = cluster.QueryCoord.SetDataCoordClient(cluster.GetDataCoordClient()) if err != nil { return err } - //err = cluster.indexCoord.SetDataCoord(cluster.dataCoord) + //err = cluster.indexCoord.SetDataCoordClient(cluster.GetDataCoordClient()) //if err != nil { // return err //} - cluster.Proxy.SetDataCoordClient(dataCoord) + cluster.Proxy.SetDataCoordClient(cluster.GetDataCoordClient()) for _, dataNode := range cluster.DataNodes { - err = dataNode.SetDataCoord(dataCoord) + err = dataNode.SetDataCoordClient(cluster.GetDataCoordClient()) if err != nil { return err } @@ -816,11 +833,11 @@ func (cluster *MiniCluster) AddQueryCoord(queryCoord types.QueryCoordComponent) } // link - queryCoord.SetRootCoord(cluster.RootCoord) - queryCoord.SetDataCoord(cluster.DataCoord) - //queryCoord.SetIndexCoord(cluster.indexCoord) - cluster.RootCoord.SetQueryCoord(queryCoord) - cluster.Proxy.SetQueryCoordClient(queryCoord) + queryCoord.SetRootCoordClient(cluster.GetRootCoordClient()) + queryCoord.SetDataCoordClient(cluster.GetDataCoordClient()) + // queryCoord.SetIndexCoord(cluster.indexCoord) + cluster.RootCoord.SetQueryCoordClient(cluster.GetQueryCoordClient()) + cluster.Proxy.SetQueryCoordClient(cluster.GetQueryCoordClient()) // start err = queryCoord.Init() @@ -876,8 +893,8 @@ func (cluster *MiniCluster) RemoveQueryCoord(queryCoord types.QueryCoordComponen // } // // // 
link -// indexCoord.SetDataCoord(cluster.dataCoord) -// indexCoord.SetRootCoord(cluster.rootCoord) +// indexCoord.SetDataCoordClient(cluster.GetDataCoordClient()) +// indexCoord.SetRootCoordClient(cluster.GetRootCoordClient()) // //cluster.dataCoord.SetIndexCoord(indexCoord) // cluster.queryCoord.SetIndexCoord(indexCoord) // //cluster.rootCoord.SetIndexCoord(indexCoord) @@ -932,11 +949,11 @@ func (cluster *MiniCluster) AddDataNode(dataNode types.DataNodeComponent) error return err } } - err = dataNode.SetDataCoord(cluster.DataCoord) + err = dataNode.SetDataCoordClient(cluster.GetDataCoordClient()) if err != nil { return err } - err = dataNode.SetRootCoord(cluster.RootCoord) + err = dataNode.SetRootCoordClient(cluster.GetRootCoordClient()) if err != nil { return err } @@ -1131,8 +1148,8 @@ func (cluster *MiniCluster) UpdateClusterSize(clusterConfig ClusterConfig) error return errors.New("Illegal cluster size config") } // todo concurrent concerns - //cluster.mu.Lock() - //defer cluster.mu.Unlock() + // cluster.mu.Lock() + // defer cluster.mu.Unlock() if clusterConfig.DataNodeNum > len(cluster.DataNodes) { needAdd := clusterConfig.DataNodeNum - len(cluster.DataNodes) for i := 0; i < needAdd; i++ { @@ -1180,43 +1197,88 @@ func (cluster *MiniCluster) UpdateClusterSize(clusterConfig ClusterConfig) error return nil } -func (cluster *MiniCluster) GetProxy(ctx context.Context, addr string, nodeID int64) (types.Proxy, error) { +func (cluster *MiniCluster) GetRootCoordClient() types.RootCoordClient { + cluster.mu.Lock() + defer cluster.mu.Unlock() + if cluster.RootCoordClient != nil { + return cluster.RootCoordClient + } + + client, err := rootcoordclient.NewClient(cluster.ctx, GetMetaRootPath(cluster.params[EtcdRootPath]), cluster.EtcdCli) + if err != nil { + panic(err) + } + cluster.RootCoordClient = client + return client +} + +func (cluster *MiniCluster) GetDataCoordClient() types.DataCoordClient { + cluster.mu.Lock() + defer cluster.mu.Unlock() + if 
cluster.DataCoordClient != nil { + return cluster.DataCoordClient + } + + client, err := datacoordclient.NewClient(cluster.ctx, GetMetaRootPath(cluster.params[EtcdRootPath]), cluster.EtcdCli) + if err != nil { + panic(err) + } + cluster.DataCoordClient = client + return client +} + +func (cluster *MiniCluster) GetQueryCoordClient() types.QueryCoordClient { + cluster.mu.Lock() + defer cluster.mu.Unlock() + if cluster.QueryCoordClient != nil { + return cluster.QueryCoordClient + } + + client, err := querycoordclient.NewClient(cluster.ctx, GetMetaRootPath(cluster.params[EtcdRootPath]), cluster.EtcdCli) + if err != nil { + panic(err) + } + cluster.QueryCoordClient = client + return client +} + +func (cluster *MiniCluster) GetProxy(ctx context.Context, addr string, nodeID int64) (types.ProxyClient, error) { cluster.mu.RLock() defer cluster.mu.RUnlock() if cluster.Proxy.GetAddress() == addr { - return cluster.Proxy, nil + return proxyclient.NewClient(ctx, addr, nodeID) } return nil, nil } -func (cluster *MiniCluster) GetQueryNode(ctx context.Context, addr string, nodeID int64) (types.QueryNode, error) { +func (cluster *MiniCluster) GetQueryNode(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) { cluster.mu.RLock() defer cluster.mu.RUnlock() for _, queryNode := range cluster.QueryNodes { if queryNode.GetAddress() == addr { - return queryNode, nil + return querynodeclient.NewClient(ctx, addr, nodeID) } } return nil, errors.New("no related queryNode found") } -func (cluster *MiniCluster) GetDataNode(ctx context.Context, addr string, nodeID int64) (types.DataNode, error) { +func (cluster *MiniCluster) GetDataNode(ctx context.Context, addr string, nodeID int64) (types.DataNodeClient, error) { cluster.mu.RLock() defer cluster.mu.RUnlock() for _, dataNode := range cluster.DataNodes { if dataNode.GetAddress() == addr { - return dataNode, nil + return datanodeclient.NewClient(ctx, addr, nodeID) } } return nil, errors.New("no related dataNode found") } 
-func (cluster *MiniCluster) GetIndexNode(ctx context.Context, addr string, nodeID int64) (types.IndexNode, error) { +func (cluster *MiniCluster) GetIndexNode(ctx context.Context, addr string, nodeID int64) (types.IndexNodeClient, error) { cluster.mu.RLock() defer cluster.mu.RUnlock() for _, indexNode := range cluster.IndexNodes { if indexNode.GetAddress() == addr { - return indexNode, nil + return indexnodeclient.NewClient(ctx, addr, nodeID, false) } } return nil, errors.New("no related indexNode found") diff --git a/tests/integration/minicluster_test.go b/tests/integration/minicluster_test.go index 3cd31ecfd626e..6aca36df52b68 100644 --- a/tests/integration/minicluster_test.go +++ b/tests/integration/minicluster_test.go @@ -32,7 +32,7 @@ type MiniClusterMethodsSuite struct { } func (s *MiniClusterMethodsSuite) TestStartAndStop() { - //Do nothing + // Do nothing } func (s *MiniClusterMethodsSuite) TestRemoveDataNode() { @@ -42,7 +42,7 @@ func (s *MiniClusterMethodsSuite) TestRemoveDataNode() { datanode := datanode.NewDataNode(ctx, c.factory) datanode.SetEtcdClient(c.EtcdCli) - //datanode := c.CreateDefaultDataNode() + // datanode := c.CreateDefaultDataNode() err := c.AddDataNode(datanode) s.NoError(err) @@ -77,7 +77,7 @@ func (s *MiniClusterMethodsSuite) TestRemoveQueryNode() { queryNode := querynodev2.NewQueryNode(ctx, c.factory) queryNode.SetEtcdClient(c.EtcdCli) - //queryNode := c.CreateDefaultQueryNode() + // queryNode := c.CreateDefaultQueryNode() err := c.AddQueryNode(queryNode) s.NoError(err) @@ -103,7 +103,6 @@ func (s *MiniClusterMethodsSuite) TestRemoveQueryNode() { s.Equal(1, c.clusterConfig.QueryNodeNum) s.Equal(1, len(c.QueryNodes)) - } func (s *MiniClusterMethodsSuite) TestRemoveIndexNode() { @@ -113,7 +112,7 @@ func (s *MiniClusterMethodsSuite) TestRemoveIndexNode() { indexNode := indexnode.NewIndexNode(ctx, c.factory) indexNode.SetEtcdClient(c.EtcdCli) - //indexNode := c.CreateDefaultIndexNode() + // indexNode := c.CreateDefaultIndexNode() err := 
c.AddIndexNode(indexNode) s.NoError(err) @@ -139,11 +138,9 @@ func (s *MiniClusterMethodsSuite) TestRemoveIndexNode() { s.Equal(1, c.clusterConfig.IndexNodeNum) s.Equal(1, len(c.IndexNodes)) - } func (s *MiniClusterMethodsSuite) TestUpdateClusterSize() { - c := s.Cluster err := c.UpdateClusterSize(ClusterConfig{ @@ -185,5 +182,6 @@ func (s *MiniClusterMethodsSuite) TestUpdateClusterSize() { } func TestMiniCluster(t *testing.T) { + t.Skip("Skip integration test, need to refactor integration test framework") suite.Run(t, new(MiniClusterMethodsSuite)) } diff --git a/tests/integration/rangesearch/range_search_test.go b/tests/integration/rangesearch/range_search_test.go index 62d58d1e4d843..aa54d2b332f48 100644 --- a/tests/integration/rangesearch/range_search_test.go +++ b/tests/integration/rangesearch/range_search_test.go @@ -92,9 +92,11 @@ func (s *RangeSearchSuite) TestRangeSearchIP() { }) s.NoError(err) segmentIDs, has := flushResp.GetCollSegIDs()[collectionName] - s.True(has) ids := segmentIDs.GetData() - s.NotEmpty(segmentIDs) + s.Require().NotEmpty(segmentIDs) + s.Require().True(has) + flushTs, has := flushResp.GetCollFlushTs()[collectionName] + s.True(has) segments, err := c.MetaWatcher.ShowSegments() s.NoError(err) @@ -102,7 +104,7 @@ func (s *RangeSearchSuite) TestRangeSearchIP() { for _, segment := range segments { log.Info("ShowSegments result", zap.String("segment", segment.String())) } - s.WaitForFlush(ctx, ids) + s.WaitForFlush(ctx, ids, flushTs, dbName, collectionName) // create index createIndexStatus, err := c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{ @@ -184,7 +186,6 @@ func (s *RangeSearchSuite) TestRangeSearchIP() { log.Info("TestRangeSearchIP succeed") log.Info("=========================") log.Info("=========================") - } func (s *RangeSearchSuite) TestRangeSearchL2() { @@ -240,9 +241,11 @@ func (s *RangeSearchSuite) TestRangeSearchL2() { }) s.NoError(err) segmentIDs, has := flushResp.GetCollSegIDs()[collectionName] - 
s.True(has) ids := segmentIDs.GetData() - s.NotEmpty(segmentIDs) + s.Require().NotEmpty(segmentIDs) + s.Require().True(has) + flushTs, has := flushResp.GetCollFlushTs()[collectionName] + s.True(has) segments, err := c.MetaWatcher.ShowSegments() s.NoError(err) @@ -250,7 +253,7 @@ func (s *RangeSearchSuite) TestRangeSearchL2() { for _, segment := range segments { log.Info("ShowSegments result", zap.String("segment", segment.String())) } - s.WaitForFlush(ctx, ids) + s.WaitForFlush(ctx, ids, flushTs, dbName, collectionName) // create index createIndexStatus, err := c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{ @@ -331,9 +334,9 @@ func (s *RangeSearchSuite) TestRangeSearchL2() { log.Info("TestRangeSearchL2 succeed") log.Info("=========================") log.Info("=========================") - } func TestRangeSearch(t *testing.T) { + t.Skip("Skip integration test, need to refactor integration test framework") suite.Run(t, new(RangeSearchSuite)) } diff --git a/tests/integration/refreshconfig/refresh_config_test.go b/tests/integration/refreshconfig/refresh_config_test.go index a2f369aeb0c31..e6b35c471d1db 100644 --- a/tests/integration/refreshconfig/refresh_config_test.go +++ b/tests/integration/refreshconfig/refresh_config_test.go @@ -23,6 +23,9 @@ import ( "time" "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/suite" + "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" @@ -30,8 +33,6 @@ import ( "github.com/milvus-io/milvus/pkg/util/metric" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/tests/integration" - "github.com/stretchr/testify/suite" - "go.uber.org/zap" ) type RefreshConfigSuite struct { @@ -65,7 +66,6 @@ func (s *RefreshConfigSuite) TestRefreshPasswordLength() { log.Debug("second create result", zap.Any("state", resp)) return commonpb.ErrorCode_Success == 
resp.GetErrorCode() }, time.Second*20, time.Millisecond*500) - } func (s *RefreshConfigSuite) TestRefreshDefaultIndexName() { @@ -118,11 +118,13 @@ func (s *RefreshConfigSuite) TestRefreshDefaultIndexName() { }) s.NoError(err) segmentIDs, has := flushResp.GetCollSegIDs()[collectionName] - s.True(has) ids := segmentIDs.GetData() - s.NotEmpty(segmentIDs) + s.Require().NotEmpty(segmentIDs) + s.Require().True(has) + flushTs, has := flushResp.GetCollFlushTs()[collectionName] + s.True(has) - s.WaitForFlush(ctx, ids) + s.WaitForFlush(ctx, ids, flushTs, dbName, collectionName) _, err = c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{ CollectionName: collectionName, @@ -142,5 +144,6 @@ func (s *RefreshConfigSuite) TestRefreshDefaultIndexName() { } func TestRefreshConfig(t *testing.T) { + t.Skip("Skip integration test, need to refactor integration test framework") suite.Run(t, new(RefreshConfigSuite)) } diff --git a/tests/integration/suite.go b/tests/integration/suite.go index d86662192f048..f83e30a93d8ed 100644 --- a/tests/integration/suite.go +++ b/tests/integration/suite.go @@ -23,10 +23,11 @@ import ( "strings" "time" - "github.com/milvus-io/milvus/pkg/util/etcd" - "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/stretchr/testify/suite" "go.etcd.io/etcd/server/v3/embed" + + "github.com/milvus-io/milvus/pkg/util/etcd" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) // EmbedEtcdSuite contains embed setup & teardown related logic diff --git a/tests/integration/upsert/upsert_test.go b/tests/integration/upsert/upsert_test.go index 00a83d0f17266..e24ec6d5ee663 100644 --- a/tests/integration/upsert/upsert_test.go +++ b/tests/integration/upsert/upsert_test.go @@ -22,6 +22,9 @@ import ( "testing" "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/suite" + "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/common" @@ -30,8 +33,6 @@ 
import ( "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metric" "github.com/milvus-io/milvus/tests/integration" - "github.com/stretchr/testify/suite" - "go.uber.org/zap" ) type UpsertSuite struct { @@ -92,9 +93,11 @@ func (s *UpsertSuite) TestUpsert() { }) s.NoError(err) segmentIDs, has := flushResp.GetCollSegIDs()[collectionName] - s.True(has) ids := segmentIDs.GetData() - s.NotEmpty(segmentIDs) + s.Require().NotEmpty(segmentIDs) + s.Require().True(has) + flushTs, has := flushResp.GetCollFlushTs()[collectionName] + s.True(has) segments, err := c.MetaWatcher.ShowSegments() s.NoError(err) @@ -102,7 +105,7 @@ func (s *UpsertSuite) TestUpsert() { for _, segment := range segments { log.Info("ShowSegments result", zap.String("segment", segment.String())) } - s.WaitForFlush(ctx, ids) + s.WaitForFlush(ctx, ids, flushTs, dbName, collectionName) // create index createIndexStatus, err := c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{ @@ -153,9 +156,9 @@ func (s *UpsertSuite) TestUpsert() { log.Info("TestUpsert succeed") log.Info("==================") log.Info("==================") - } func TestUpsert(t *testing.T) { + t.Skip("Skip integration test, need to refactor integration test framework") suite.Run(t, new(UpsertSuite)) } diff --git a/tests/integration/util_index.go b/tests/integration/util_index.go index dd30eef8d4274..55620db86ab00 100644 --- a/tests/integration/util_index.go +++ b/tests/integration/util_index.go @@ -178,7 +178,7 @@ func GetSearchParams(indexType string, metricType string) map[string]any { case IndexHNSW: params["ef"] = 200 case IndexDISKANN: - params["search_list"] = 5 + params["search_list"] = 20 default: panic(fmt.Sprintf("unimplemented search param for %s, please help to improve it", indexType)) } diff --git a/tests/integration/util_insert.go b/tests/integration/util_insert.go index c220bbf723576..ea03d853d532f 100644 --- a/tests/integration/util_insert.go +++ b/tests/integration/util_insert.go @@ -26,10 
+26,13 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) -func (s *MiniClusterSuite) WaitForFlush(ctx context.Context, segIDs []int64) { +func (s *MiniClusterSuite) WaitForFlush(ctx context.Context, segIDs []int64, flushTs uint64, dbName, collectionName string) { flushed := func() bool { resp, err := s.Cluster.Proxy.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{ - SegmentIDs: segIDs, + SegmentIDs: segIDs, + FlushTs: flushTs, + DbName: dbName, + CollectionName: collectionName, }) if err != nil { return false @@ -116,6 +119,21 @@ func NewFloatVectorFieldData(fieldName string, numRows, dim int) *schemapb.Field } } +func NewFloat16VectorFieldData(fieldName string, numRows, dim int) *schemapb.FieldData { + return &schemapb.FieldData{ + Type: schemapb.DataType_Float16Vector, + FieldName: fieldName, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: int64(dim), + Data: &schemapb.VectorField_Float16Vector{ + Float16Vector: GenerateFloat16Vectors(numRows, dim), + }, + }, + }, + } +} + func NewBinaryVectorFieldData(fieldName string, numRows, dim int) *schemapb.FieldData { return &schemapb.FieldData{ Type: schemapb.DataType_BinaryVector, @@ -166,6 +184,16 @@ func GenerateBinaryVectors(numRows, dim int) []byte { return ret } +func GenerateFloat16Vectors(numRows, dim int) []byte { + total := numRows * dim * 2 + ret := make([]byte, total) + _, err := rand.Read(ret) + if err != nil { + panic(err) + } + return ret +} + func GenerateHashKeys(numRows int) []uint32 { ret := make([]uint32, 0, numRows) for i := 0; i < numRows; i++ { diff --git a/tests/integration/util_query.go b/tests/integration/util_query.go index f150ba308225d..eed1b3f5ac6f5 100644 --- a/tests/integration/util_query.go +++ b/tests/integration/util_query.go @@ -181,6 +181,17 @@ func constructPlaceholderGroup(nq, dim int, vectorType schemapb.DataType) *commo } values = append(values, ret) } + case schemapb.DataType_Float16Vector: + placeholderType = 
commonpb.PlaceholderType_Float16Vector + for i := 0; i < nq; i++ { + total := dim * 2 + ret := make([]byte, total) + _, err := rand.Read(ret) + if err != nil { + panic(err) + } + values = append(values, ret) + } default: panic("invalid vector data type") } diff --git a/tests/integration/util_schema.go b/tests/integration/util_schema.go index a5054701aeed4..1686bd343b78a 100644 --- a/tests/integration/util_schema.go +++ b/tests/integration/util_schema.go @@ -25,17 +25,18 @@ import ( ) const ( - BoolField = "boolField" - Int8Field = "int8Field" - Int16Field = "int16Field" - Int32Field = "int32Field" - Int64Field = "int64Field" - FloatField = "floatField" - DoubleField = "doubleField" - VarCharField = "varCharField" - JSONField = "jsonField" - FloatVecField = "floatVecField" - BinVecField = "binVecField" + BoolField = "boolField" + Int8Field = "int8Field" + Int16Field = "int16Field" + Int32Field = "int32Field" + Int64Field = "int64Field" + FloatField = "floatField" + DoubleField = "doubleField" + VarCharField = "varCharField" + JSONField = "jsonField" + FloatVecField = "floatVecField" + BinVecField = "binVecField" + Float16VecField = "float16VecField" ) func ConstructSchema(collection string, dim int, autoID bool, fields ...*schemapb.FieldSchema) *schemapb.CollectionSchema { diff --git a/tests/python_client/README.md b/tests/python_client/README.md index f1a583f9d7b8c..98424e61b062a 100644 --- a/tests/python_client/README.md +++ b/tests/python_client/README.md @@ -227,7 +227,7 @@ assert self.partition_wrap.is_empty # drop collection collection_w.drop() # create partition failed - self.partition_wrap.init_partition(collection_w.collection, partition_name, check_task=CheckTasks.err_res, check_items={ct.err_code: 1, ct.err_msg: "can't find collection"}) + self.partition_wrap.init_partition(collection_w.collection, partition_name, check_task=CheckTasks.err_res, check_items={ct.err_code: 4, ct.err_msg: "collection not found"}) ``` - Tips diff --git 
a/tests/python_client/base/collection_wrapper.py b/tests/python_client/base/collection_wrapper.py index 4871145a0d3f6..69d3ffcef8687 100644 --- a/tests/python_client/base/collection_wrapper.py +++ b/tests/python_client/base/collection_wrapper.py @@ -13,7 +13,7 @@ from pymilvus.orm.types import CONSISTENCY_STRONG from common.common_func import param_info -TIMEOUT = 120 +TIMEOUT = 180 INDEX_NAME = "" @@ -272,16 +272,14 @@ def index(self, check_task=None, check_items=None): return res, check_result @trace() - def create_index(self, field_name, index_params=None, index_name=None, check_task=None, check_items=None, **kwargs): - disktimeout = 600 - timeout = kwargs.get("timeout", disktimeout * 2) + def create_index(self, field_name, index_params=None, index_name=None, timeout=None, check_task=None, check_items=None, **kwargs): + timeout = 1200 if timeout is None else timeout index_name = INDEX_NAME if index_name is None else index_name index_name = kwargs.get("index_name", index_name) - kwargs.update({"timeout": timeout, "index_name": index_name}) + kwargs.update({"index_name": index_name}) func_name = sys._getframe().f_code.co_name - res, check = api_request([self.collection.create_index, field_name, index_params], **kwargs) - check_result = ResponseChecker(res, func_name, check_task, check_items, check, - field_name=field_name, index_params=index_params, **kwargs).run() + res, check = api_request([self.collection.create_index, field_name, index_params, timeout], **kwargs) + check_result = ResponseChecker(res, func_name, check_task, check_items, check, **kwargs).run() return res, check_result @trace() diff --git a/tests/python_client/base/partition_wrapper.py b/tests/python_client/base/partition_wrapper.py index c599691eff3ed..3d22978de105d 100644 --- a/tests/python_client/base/partition_wrapper.py +++ b/tests/python_client/base/partition_wrapper.py @@ -9,7 +9,7 @@ from common.common_func import param_info -TIMEOUT = 20 +TIMEOUT = 180 class ApiPartitionWrapper: diff 
--git a/tests/python_client/base/utility_wrapper.py b/tests/python_client/base/utility_wrapper.py index fcd71f314c962..07ab72c8a44ca 100644 --- a/tests/python_client/base/utility_wrapper.py +++ b/tests/python_client/base/utility_wrapper.py @@ -188,6 +188,24 @@ def wait_all_pending_tasks_finished(self): if task.task_id in pending_task_ids: log.info(f"task {task.task_id} state transfer from pending to {task.state_name}") + def wait_index_build_completed(self, collection_name, timeout=None): + start = time.time() + if timeout is not None: + task_timeout = timeout + else: + task_timeout = TIMEOUT + end = time.time() + while end - start <= task_timeout: + time.sleep(0.5) + index_states, _ = self.index_building_progress(collection_name) + log.debug(f"index states: {index_states}") + if index_states["total_rows"] == index_states["indexed_rows"]: + log.info(f"index build completed") + return True + end = time.time() + log.info(f"index build timeout") + return False + def get_query_segment_info(self, collection_name, timeout=None, using="default", check_task=None, check_items=None): timeout = TIMEOUT if timeout is None else timeout func_name = sys._getframe().f_code.co_name @@ -496,14 +514,14 @@ def transfer_replica(self, source, target, collection_name, num_replica, using=" check_result = ResponseChecker(res, func_name, check_task, check_items, check, **kwargs).run() return res, check_result - def rename_collection(self, old_collection_name, new_collection_name, timeout=None, check_task=None, - check_items=None, **kwargs): + def rename_collection(self, old_collection_name, new_collection_name, new_db_name="", timeout=None, + check_task=None, check_items=None, **kwargs): func_name = sys._getframe().f_code.co_name - res, check = api_request([self.ut.rename_collection, old_collection_name, new_collection_name, timeout], - **kwargs) + res, check = api_request([self.ut.rename_collection, old_collection_name, new_collection_name, new_db_name, + timeout], **kwargs) check_result = 
ResponseChecker(res, func_name, check_task, check_items, check, old_collection_name=old_collection_name, new_collection_name=new_collection_name, - timeout=timeout, **kwargs).run() + new_db_name=new_db_name, timeout=timeout, **kwargs).run() return res, check_result def flush_all(self, using="default", timeout=None, check_task=None, check_items=None, **kwargs): diff --git a/tests/python_client/bulk_insert/test_bulk_insert_api.py b/tests/python_client/bulk_insert/test_bulk_insert_api.py index 2d1c2e540dd12..5a2ba2f861352 100644 --- a/tests/python_client/bulk_insert/test_bulk_insert_api.py +++ b/tests/python_client/bulk_insert/test_bulk_insert_api.py @@ -114,6 +114,7 @@ def test_float_vector_only(self, is_row_based, auto_id, dim, entities): ] schema = cf.gen_collection_schema(fields=fields, auto_id=auto_id) self.collection_wrap.init_collection(c_name, schema=schema) + # import data t0 = time.time() task_id, _ = self.utility_wrap.do_bulk_insert( @@ -138,9 +139,11 @@ def test_float_vector_only(self, is_row_based, auto_id, dim, entities): self.collection_wrap.create_index( field_name=df.vec_field, index_params=index_params ) - self.collection_wrap.load() + success = self.utility_wrap.wait_index_build_completed(c_name) + assert success log.info(f"wait for load finished and be ready for search") - time.sleep(10) + self.collection_wrap.load() + self.collection_wrap.load(_refresh=True) log.info( f"query seg info: {self.utility_wrap.get_query_segment_info(c_name)[0]}" ) @@ -187,6 +190,7 @@ def test_str_pk_float_vector_only(self, is_row_based, dim, entities): auto_id=auto_id, str_pk=string_pk, data_fields=default_vec_only_fields, + force=True ) self._connect() c_name = cf.gen_unique_str("bulk_insert") @@ -218,9 +222,11 @@ def test_str_pk_float_vector_only(self, is_row_based, dim, entities): self.collection_wrap.create_index( field_name=df.vec_field, index_params=index_params ) - self.collection_wrap.load() + success = self.utility_wrap.wait_index_build_completed(c_name) + 
assert success log.info(f"wait for load finished and be ready for search") - time.sleep(10) + self.collection_wrap.load() + self.collection_wrap.load(_refresh=True) log.info( f"query seg info: {self.utility_wrap.get_query_segment_info(c_name)[0]}" ) @@ -228,7 +234,7 @@ def test_str_pk_float_vector_only(self, is_row_based, dim, entities): topk = 2 search_data = cf.gen_vectors(nq, dim) search_params = ct.default_search_params - time.sleep(10) + time.sleep(20) res, _ = self.collection_wrap.search( search_data, df.vec_field, @@ -301,7 +307,7 @@ def test_partition_float_vector_int_scalar( ) logging.info(f"bulk insert task ids:{task_id}") success, state = self.utility_wrap.wait_for_bulk_insert_tasks_completed( - task_ids=[task_id], timeout=90 + task_ids=[task_id], timeout=120 ) tt = time.time() - t0 log.info(f"bulk insert state:{success} in {tt}") @@ -310,10 +316,10 @@ def test_partition_float_vector_int_scalar( assert m_partition.num_entities == entities assert self.collection_wrap.num_entities == entities log.debug(state) - res, _ = self.utility_wrap.index_building_progress(c_name) - exp_res = {"total_rows": entities, "indexed_rows": entities} - assert res == exp_res + success = self.utility_wrap.wait_index_build_completed(c_name) + assert success log.info(f"wait for load finished and be ready for search") + self.collection_wrap.load(partition_names=[p_name], _refresh=True) time.sleep(10) log.info( f"query seg info: {self.utility_wrap.get_query_segment_info(c_name)[0]}" ) @@ -395,14 +401,13 @@ def test_binary_vector_only(self, is_row_based, auto_id, dim, entities): tt = time.time() - t0 log.info(f"bulk insert state:{success} in {tt}") assert success - res, _ = self.utility_wrap.index_building_progress(c_name) - exp_res = {'total_rows': entities, 'indexed_rows': entities} - assert res == exp_res - + success = self.utility_wrap.wait_index_build_completed(c_name) + assert success # verify num entities assert self.collection_wrap.num_entities == entities # verify search and 
query log.info(f"wait for load finished and be ready for search") + self.collection_wrap.load(_refresh=True) time.sleep(10) search_data = cf.gen_binary_vectors(1, dim)[1] search_params = {"metric_type": "JACCARD", "params": {"nprobe": 10}} @@ -423,7 +428,7 @@ def test_binary_vector_only(self, is_row_based, auto_id, dim, entities): @pytest.mark.parametrize("is_row_based", [True]) @pytest.mark.parametrize("auto_id", [True, False]) @pytest.mark.parametrize( - "fields_num_in_file", ["equal", "more", "less"] + "fields_num_in_file", ["more", "less", "equal"] ) # "equal", "more", "less" @pytest.mark.parametrize("dim", [16]) @pytest.mark.parametrize("entities", [500]) @@ -491,7 +496,7 @@ def test_float_vector_multi_scalars( assert not success if is_row_based: if fields_num_in_file == "less": - failed_reason = f"field '{additional_field}' missed at the row 0" + failed_reason = f"value of field '{additional_field}' is missed" else: failed_reason = f"field '{df.float_field}' is not defined in collection schema" else: @@ -506,11 +511,14 @@ def test_float_vector_multi_scalars( log.info(f" collection entities: {num_entities}") assert num_entities == entities - # verify no index + # verify index status res, _ = self.collection_wrap.has_index() assert res is True + success = self.utility_wrap.wait_index_build_completed(c_name) + assert success # verify search and query log.info(f"wait for load finished and be ready for search") + self.collection_wrap.load(_refresh=True) time.sleep(10) nq = 3 topk = 10 @@ -608,13 +616,12 @@ def test_insert_before_or_after_bulk_insert(self, insert_before_bulk_insert): num_entities = self.collection_wrap.num_entities log.info(f"collection entities: {num_entities}") assert num_entities == bulk_insert_row + direct_insert_row - # verify no index - res, _ = self.utility_wrap.index_building_progress(c_name) - exp_res = {'total_rows': num_entities, 'indexed_rows': num_entities} - assert res == exp_res + # verify index status + success = 
self.utility_wrap.wait_index_build_completed(c_name) + assert success # verify search and query log.info(f"wait for load finished and be ready for search") - time.sleep(10) + self.collection_wrap.load(_refresh=True) nq = 3 topk = 10 search_data = cf.gen_vectors(nq, dim=dim) @@ -670,13 +677,15 @@ def test_load_before_or_after_bulk_insert(self, loaded_before_bulk_insert, creat schema = cf.gen_collection_schema(fields=fields, auto_id=True) self.collection_wrap.init_collection(c_name, schema=schema) # build index - index_params = ct.default_index - self.collection_wrap.create_index( - field_name=df.vec_field, index_params=index_params - ) + if create_index_before_bulk_insert: + index_params = ct.default_index + self.collection_wrap.create_index( + field_name=df.vec_field, index_params=index_params + ) if loaded_before_bulk_insert: # load collection self.collection_wrap.load() + # import data t0 = time.time() task_id, _ = self.utility_wrap.do_bulk_insert( @@ -689,6 +698,12 @@ def test_load_before_or_after_bulk_insert(self, loaded_before_bulk_insert, creat tt = time.time() - t0 log.info(f"bulk insert state:{success} in {tt}") assert success + if not create_index_before_bulk_insert: + # build index + index_params = ct.default_index + self.collection_wrap.create_index( + field_name=df.vec_field, index_params=index_params + ) if not loaded_before_bulk_insert: # load collection self.collection_wrap.load() @@ -697,12 +712,11 @@ def test_load_before_or_after_bulk_insert(self, loaded_before_bulk_insert, creat log.info(f"collection entities: {num_entities}") assert num_entities == 500 # verify no index - res, _ = self.utility_wrap.index_building_progress(c_name) - exp_res = {'total_rows': num_entities, 'indexed_rows': num_entities} - assert res == exp_res + success = self.utility_wrap.wait_index_build_completed(c_name) + assert success # verify search and query log.info(f"wait for load finished and be ready for search") - time.sleep(10) + self.collection_wrap.load(_refresh=True) 
nq = 3 topk = 10 search_data = cf.gen_vectors(nq, 16) @@ -794,7 +808,7 @@ def test_string_pk_float_vector_multi_scalars( assert not success # TODO: check error msg if is_row_based: if fields_num_in_file == "less": - failed_reason = f"field '{additional_field}' missed at the row 0" + failed_reason = f"value of field '{additional_field}' is missed" else: failed_reason = f"field '{df.float_field}' is not defined in collection schema" else: @@ -806,11 +820,12 @@ def test_string_pk_float_vector_multi_scalars( assert success log.info(f" collection entities: {self.collection_wrap.num_entities}") assert self.collection_wrap.num_entities == entities - # verify no index - res, _ = self.collection_wrap.has_index() - assert res is True + # verify index + success = self.utility_wrap.wait_index_build_completed(c_name) + assert success # verify search and query log.info(f"wait for load finished and be ready for search") + self.collection_wrap.load(_refresh=True) time.sleep(10) search_data = cf.gen_vectors(1, dim) search_params = ct.default_search_params @@ -830,119 +845,11 @@ def test_string_pk_float_vector_multi_scalars( assert len(results) == len(ids) @pytest.mark.tags(CaseLabel.L3) - @pytest.mark.parametrize("is_row_based", [pytest.param(True, marks=pytest.mark.xfail(reason="issue: https://github.com/milvus-io/milvus/issues/19499"))]) # True, False - @pytest.mark.parametrize("auto_id", [True, False]) # True, False - @pytest.mark.parametrize("dim", [16]) # 16 - @pytest.mark.parametrize("entities", [100]) # 3000 - @pytest.mark.parametrize("file_nums", [32]) # 10 - @pytest.mark.parametrize("multi_folder", [True, False]) # True, False - def test_float_vector_from_multi_files( - self, is_row_based, auto_id, dim, entities, file_nums, multi_folder - ): - """ - collection: auto_id - collection schema: [pk, float_vector, - float_scalar, int_scalar, string_scalar, bool_scalar] - Steps: - 1. create collection - 2. build index and load collection - 3. import data from multiple files - 4. 
verify the data entities - 5. verify index status - 6. verify search successfully - 7. verify query successfully - """ - files = prepare_bulk_insert_json_files( - minio_endpoint=self.minio_endpoint, - bucket_name=self.bucket_name, - is_row_based=is_row_based, - rows=entities, - dim=dim, - auto_id=auto_id, - data_fields=default_multi_fields, - file_nums=file_nums, - multi_folder=multi_folder, - force=True, - ) - self._connect() - c_name = cf.gen_unique_str("bulk_insert") - fields = [ - cf.gen_int64_field(name=df.pk_field, is_primary=True), - cf.gen_float_vec_field(name=df.vec_field, dim=dim), - cf.gen_int32_field(name=df.int_field), - cf.gen_string_field(name=df.string_field), - cf.gen_bool_field(name=df.bool_field), - cf.gen_float_field(name=df.float_field) - ] - schema = cf.gen_collection_schema(fields=fields, auto_id=auto_id) - self.collection_wrap.init_collection(c_name, schema=schema) - # build index - index_params = ct.default_index - self.collection_wrap.create_index( - field_name=df.vec_field, index_params=index_params - ) - # load collection - self.collection_wrap.load() - # import data - t0 = time.time() - err_msg = "row-based import, only allow one JSON file each time" - task_id, _ = self.utility_wrap.do_bulk_insert( - collection_name=c_name, files=files, - check_task=CheckTasks.err_res, check_items={"err_code": 1, "err_msg": err_msg}, - ) - - # logging.info(f"bulk insert task ids:{task_id}") - # success, states = self.utility_wrap.wait_for_bulk_insert_tasks_completed( - # task_ids=[task_id], timeout=90 - # ) - # tt = time.time() - t0 - # log.info(f"bulk insert state:{success} in {tt}") - # if not is_row_based: - # assert not success - # failed_reason = "is duplicated" # "the field xxx is duplicated" - # for state in states.values(): - # assert state.state_name in ["Failed", "Failed and cleaned"] - # assert failed_reason in state.infos.get("failed_reason", "") - # else: - # assert success - # num_entities = self.collection_wrap.num_entities - # 
log.info(f" collection entities: {num_entities}") - # assert num_entities == entities * file_nums - # - # # verify index built - # res, _ = self.utility_wrap.index_building_progress(c_name) - # exp_res = {'total_rows': entities * file_nums, 'indexed_rows': entities * file_nums} - # assert res == exp_res - # - # # verify search and query - # log.info(f"wait for load finished and be ready for search") - # time.sleep(10) - # nq = 5 - # topk = 1 - # search_data = cf.gen_vectors(nq, dim) - # search_params = ct.default_search_params - # res, _ = self.collection_wrap.search( - # search_data, - # df.vec_field, - # param=search_params, - # limit=topk, - # check_task=CheckTasks.check_search_results, - # check_items={"nq": nq, "limit": topk}, - # ) - # for hits in res: - # ids = hits.ids - # results, _ = self.collection_wrap.query(expr=f"{df.pk_field} in {ids}") - # assert len(results) == len(ids) - - @pytest.mark.tags(CaseLabel.L3) - @pytest.mark.parametrize("is_row_based", [True]) @pytest.mark.parametrize("auto_id", [True, False]) - @pytest.mark.parametrize("multi_fields", [True, False]) @pytest.mark.parametrize("dim", [15]) @pytest.mark.parametrize("entities", [200]) - @pytest.mark.skip(reason="stop support for numpy files") def test_float_vector_from_numpy_file( - self, is_row_based, auto_id, multi_fields, dim, entities + self, auto_id, dim, entities ): """ collection schema 1: [pk, float_vector] @@ -956,7 +863,10 @@ def test_float_vector_from_numpy_file( 4.1 verify the data entities equal the import data 4.2 verify search and query successfully """ - data_fields = [df.vec_field] + if auto_id: + data_fields = [df.vec_field, df.int_field, df.string_field, df.float_field, df.bool_field] + else: + data_fields = [df.pk_field, df.vec_field, df.int_field, df.string_field, df.float_field, df.bool_field] np_files = prepare_bulk_insert_numpy_files( minio_endpoint=self.minio_endpoint, bucket_name=self.bucket_name, @@ -965,53 +875,15 @@ def test_float_vector_from_numpy_file( 
data_fields=data_fields, force=True, ) - if not multi_fields: - fields = [ - cf.gen_int64_field(name=df.pk_field, is_primary=True), - cf.gen_float_vec_field(name=df.vec_field, dim=dim), - ] - if not auto_id: - scalar_fields = [df.pk_field] - else: - scalar_fields = None - else: - fields = [ - cf.gen_int64_field(name=df.pk_field, is_primary=True), - cf.gen_float_vec_field(name=df.vec_field, dim=dim), - cf.gen_int32_field(name=df.int_field), - cf.gen_string_field(name=df.string_field), - cf.gen_bool_field(name=df.bool_field), - ] - if not auto_id: - scalar_fields = [ - df.pk_field, - df.float_field, - df.int_field, - df.string_field, - df.bool_field, - ] - else: - scalar_fields = [ - df.int_field, - df.string_field, - df.bool_field, - df.float_field, - ] - + fields = [ + cf.gen_int64_field(name=df.pk_field, is_primary=True), + cf.gen_float_vec_field(name=df.vec_field, dim=dim), + cf.gen_int64_field(name=df.int_field), + cf.gen_string_field(name=df.string_field), + cf.gen_float_field(name=df.float_field), + cf.gen_bool_field(name=df.bool_field), + ] files = np_files - if scalar_fields is not None: - json_files = prepare_bulk_insert_json_files( - minio_endpoint=self.minio_endpoint, - bucket_name=self.bucket_name, - is_row_based=is_row_based, - dim=dim, - auto_id=auto_id, - rows=entities, - data_fields=scalar_fields, - force=True, - ) - files = np_files + json_files - self._connect() c_name = cf.gen_unique_str("bulk_insert") schema = cf.gen_collection_schema(fields=fields, auto_id=auto_id) @@ -1024,64 +896,55 @@ def test_float_vector_from_numpy_file( ) logging.info(f"bulk insert task ids:{task_id}") success, states = self.utility_wrap.wait_for_bulk_insert_tasks_completed( - task_ids=[task_id], timeout=90 + task_ids=[task_id], timeout=120 ) tt = time.time() - t0 log.info(f"bulk insert state:{success} in {tt}") - if is_row_based: - assert not success - failed_reason1 = "unsupported file type for row-based mode" - failed_reason2 = ( - f"JSON row validator: field 
{df.vec_field} missed at the row 0" - ) - for state in states.values(): - assert state.state_name in ["Failed", "Failed and cleaned"] - assert failed_reason1 in state.infos.get( - "failed_reason", "" - ) or failed_reason2 in state.infos.get("failed_reason", "") - else: - assert success - log.info(f" collection entities: {self.collection_wrap.num_entities}") - assert self.collection_wrap.num_entities == entities - # create index and load - index_params = ct.default_index - self.collection_wrap.create_index( - field_name=df.vec_field, index_params=index_params - ) - self.collection_wrap.load() - log.info(f"wait for load finished and be ready for search") - time.sleep(10) - log.info( - f"query seg info: {self.utility_wrap.get_query_segment_info(c_name)[0]}" - ) - # verify imported data is available for search - nq = 2 - topk = 5 - search_data = cf.gen_vectors(nq, dim) - search_params = ct.default_search_params - res, _ = self.collection_wrap.search( - search_data, - df.vec_field, - param=search_params, - limit=topk, - check_task=CheckTasks.check_search_results, - check_items={"nq": nq, "limit": topk}, - ) - for hits in res: - ids = hits.ids - results, _ = self.collection_wrap.query(expr=f"{df.pk_field} in {ids}") - assert len(results) == len(ids) + assert success + log.info(f" collection entities: {self.collection_wrap.num_entities}") + assert self.collection_wrap.num_entities == entities + # create index and load + index_params = ct.default_index + self.collection_wrap.create_index( + field_name=df.vec_field, index_params=index_params + ) + result = self.utility_wrap.wait_index_build_completed(c_name) + assert result is True + self.collection_wrap.load() + self.collection_wrap.load(_refresh=True) + log.info(f"wait for load finished and be ready for search") + log.info( + f"query seg info: {self.utility_wrap.get_query_segment_info(c_name)[0]}" + ) + # verify imported data is available for search + nq = 2 + topk = 5 + search_data = cf.gen_vectors(nq, dim) + 
search_params = ct.default_search_params + res, _ = self.collection_wrap.search( + search_data, + df.vec_field, + param=search_params, + limit=topk, + check_task=CheckTasks.check_search_results, + check_items={"nq": nq, "limit": topk}, + ) + for hits in res: + ids = hits.ids + results, _ = self.collection_wrap.query(expr=f"{df.pk_field} in {ids}") + assert len(results) == len(ids) @pytest.mark.tags(CaseLabel.L3) @pytest.mark.parametrize("is_row_based", [True]) + @pytest.mark.parametrize("auto_id", [True, False]) @pytest.mark.parametrize("dim", [8]) @pytest.mark.parametrize("entities", [10]) - def test_data_type_float_on_int_pk(self, is_row_based, dim, entities): + def test_data_type_int_on_float_scalar(self, is_row_based, auto_id, dim, entities): """ collection schema: [pk, float_vector, float_scalar, int_scalar, string_scalar, bool_scalar] - data files: json file that one of entities has float on int pk + data files: json file that one of entities has int on float scalar Steps: 1. create collection 2. 
import data @@ -1094,23 +957,23 @@ def test_data_type_float_on_int_pk(self, is_row_based, dim, entities): is_row_based=is_row_based, rows=entities, dim=dim, - auto_id=False, + auto_id=auto_id, data_fields=default_multi_fields, - err_type=DataErrorType.float_on_int_pk, + err_type=DataErrorType.int_on_float_scalar, force=True, ) + self._connect() c_name = cf.gen_unique_str("bulk_insert") - # TODO: add string pk fields = [ cf.gen_int64_field(name=df.pk_field, is_primary=True), cf.gen_float_vec_field(name=df.vec_field, dim=dim), cf.gen_int32_field(name=df.int_field), + cf.gen_float_field(name=df.float_field), cf.gen_string_field(name=df.string_field), cf.gen_bool_field(name=df.bool_field), - cf.gen_float_field(name=df.float_field), ] - schema = cf.gen_collection_schema(fields=fields, auto_id=False) + schema = cf.gen_collection_schema(fields=fields, auto_id=auto_id) self.collection_wrap.init_collection(c_name, schema=schema) # import data task_id, _ = self.utility_wrap.do_bulk_insert( @@ -1123,60 +986,69 @@ def test_data_type_float_on_int_pk(self, is_row_based, dim, entities): log.info(f"bulk insert state:{success}") assert success assert self.collection_wrap.num_entities == entities + index_params = ct.default_index self.collection_wrap.create_index( field_name=df.vec_field, index_params=index_params ) - self.collection_wrap.load() + success = self.utility_wrap.wait_index_build_completed(c_name) + assert success + # verify imported data is available for search log.info(f"wait for load finished and be ready for search") - time.sleep(10) - # the pk value was automatically convert to int from float + self.collection_wrap.load() + self.collection_wrap.load(_refresh=True) + search_data = cf.gen_vectors(1, dim) + search_params = ct.default_search_params + res, _ = self.collection_wrap.search( + search_data, + df.vec_field, + param=search_params, + limit=1, + check_task=CheckTasks.check_search_results, + check_items={"nq": 1, "limit": 1}, + ) + uids = res[0].ids res, _ = 
self.collection_wrap.query( - expr=f"{df.pk_field} in [3]", output_fields=[df.pk_field] + expr=f"{df.pk_field} in {uids}", output_fields=[df.float_field] ) - assert [{df.pk_field: 3}] == res + assert isinstance(res[0].get(df.float_field, 1), np.float32) @pytest.mark.tags(CaseLabel.L3) - @pytest.mark.parametrize("is_row_based", [True]) - @pytest.mark.parametrize("auto_id", [True, False]) - @pytest.mark.parametrize("dim", [8]) - @pytest.mark.parametrize("entities", [10]) - def test_data_type_int_on_float_scalar(self, is_row_based, auto_id, dim, entities): + @pytest.mark.parametrize("auto_id", [True]) + @pytest.mark.parametrize("dim", [128]) # 128 + @pytest.mark.parametrize("entities", [1000]) # 1000 + def test_with_all_field_numpy(self, auto_id, dim, entities): """ - collection schema: [pk, float_vector, - float_scalar, int_scalar, string_scalar, bool_scalar] - data files: json file that one of entities has int on float scalar + collection schema 1: [pk, int64, float64, string float_vector] + data file: vectors.npy and uid.npy, Steps: 1. create collection 2. import data - 3. verify the data entities - 4. verify query successfully + 3. 
verify """ - files = prepare_bulk_insert_json_files( + data_fields = [df.int_field, df.float_field, df.double_field, df.vec_field] + fields = [ + cf.gen_int64_field(name=df.pk_field, is_primary=True), + cf.gen_int64_field(name=df.int_field), + cf.gen_float_field(name=df.float_field), + cf.gen_double_field(name=df.double_field), + cf.gen_float_vec_field(name=df.vec_field, dim=dim), + ] + files = prepare_bulk_insert_numpy_files( minio_endpoint=self.minio_endpoint, bucket_name=self.bucket_name, - is_row_based=is_row_based, rows=entities, dim=dim, - auto_id=auto_id, - data_fields=default_multi_fields, - err_type=DataErrorType.int_on_float_scalar, + data_fields=data_fields, force=True, ) - self._connect() c_name = cf.gen_unique_str("bulk_insert") - fields = [ - cf.gen_int64_field(name=df.pk_field, is_primary=True), - cf.gen_float_vec_field(name=df.vec_field, dim=dim), - cf.gen_int32_field(name=df.int_field), - cf.gen_float_field(name=df.float_field), - cf.gen_string_field(name=df.string_field), - cf.gen_bool_field(name=df.bool_field), - ] schema = cf.gen_collection_schema(fields=fields, auto_id=auto_id) self.collection_wrap.init_collection(c_name, schema=schema) + # import data + t0 = time.time() task_id, _ = self.utility_wrap.do_bulk_insert( collection_name=c_name, files=files ) @@ -1184,17 +1056,101 @@ def test_data_type_int_on_float_scalar(self, is_row_based, auto_id, dim, entitie success, states = self.utility_wrap.wait_for_bulk_insert_tasks_completed( task_ids=[task_id], timeout=90 ) + tt = time.time() - t0 + log.info(f"bulk insert state:{success} in {tt}") + assert success + num_entities = self.collection_wrap.num_entities + log.info(f" collection entities: {num_entities}") + assert num_entities == entities + # verify imported data is available for search + index_params = ct.default_index + self.collection_wrap.create_index( + field_name=df.vec_field, index_params=index_params + ) + success = self.utility_wrap.wait_index_build_completed(c_name) + log.info(f"wait 
for load finished and be ready for search") + self.collection_wrap.load() + self.collection_wrap.load(_refresh=True) + log.info(f"query seg info: {self.utility_wrap.get_query_segment_info(c_name)[0]}") + search_data = cf.gen_vectors(1, dim) + search_params = ct.default_search_params + res, _ = self.collection_wrap.search( + search_data, + df.vec_field, + param=search_params, + limit=1, + check_task=CheckTasks.check_search_results, + check_items={"nq": 1, "limit": 1}, + ) + for hits in res: + ids = hits.ids + results, _ = self.collection_wrap.query(expr=f"{df.pk_field} in {ids}") + assert len(results) == len(ids) + + @pytest.mark.tags(CaseLabel.L3) + @pytest.mark.parametrize("auto_id", [True, False]) + @pytest.mark.parametrize("dim", [6]) + @pytest.mark.parametrize("entities", [2000]) + @pytest.mark.parametrize("file_nums", [10]) + def test_multi_numpy_files_from_diff_folders( + self, auto_id, dim, entities, file_nums + ): + """ + collection schema 1: [pk, float_vector] + data file: .npy files in different folders + Steps: + 1. create collection, create index and load + 2. import data + 3. 
verify that import numpy files in a loop + """ + self._connect() + c_name = cf.gen_unique_str("bulk_insert") + fields = [ + cf.gen_int64_field(name=df.pk_field, is_primary=True), + cf.gen_int64_field(name=df.int_field), + cf.gen_float_field(name=df.float_field), + cf.gen_double_field(name=df.double_field), + cf.gen_float_vec_field(name=df.vec_field, dim=dim), + ] + schema = cf.gen_collection_schema(fields=fields) + self.collection_wrap.init_collection(c_name, schema=schema) + # build index + index_params = ct.default_index + self.collection_wrap.create_index( + field_name=df.vec_field, index_params=index_params + ) + # load collection + self.collection_wrap.load() + data_fields = [f.name for f in fields if not f.to_dict().get("auto_id", False)] + task_ids = [] + for i in range(file_nums): + files = prepare_bulk_insert_numpy_files( + minio_endpoint=self.minio_endpoint, + bucket_name=self.bucket_name, + rows=entities, + dim=dim, + data_fields=data_fields, + file_nums=1, + force=True, + ) + task_id, _ = self.utility_wrap.do_bulk_insert( + collection_name=c_name, files=files + ) + task_ids.append(task_id) + success, states = self.utility_wrap.wait_for_bulk_insert_tasks_completed( + task_ids=[task_id], timeout=90 + ) log.info(f"bulk insert state:{success}") - assert success - assert self.collection_wrap.num_entities == entities - index_params = ct.default_index - self.collection_wrap.create_index( - field_name=df.vec_field, index_params=index_params - ) - self.collection_wrap.load() + assert success + log.info(f" collection entities: {self.collection_wrap.num_entities}") + assert self.collection_wrap.num_entities == entities * file_nums + # verify imported data is indexed + success = self.utility_wrap.wait_index_build_completed(c_name) + assert success + # verify search and query log.info(f"wait for load finished and be ready for search") - time.sleep(10) + self.collection_wrap.load(_refresh=True) search_data = cf.gen_vectors(1, dim) search_params = 
ct.default_search_params res, _ = self.collection_wrap.search( @@ -1205,62 +1161,68 @@ def test_data_type_int_on_float_scalar(self, is_row_based, auto_id, dim, entitie check_task=CheckTasks.check_search_results, check_items={"nq": 1, "limit": 1}, ) - uids = res[0].ids - res, _ = self.collection_wrap.query( - expr=f"{df.pk_field} in {uids}", output_fields=[df.float_field] - ) - assert isinstance(res[0].get(df.float_field, 1), float) @pytest.mark.tags(CaseLabel.L3) - @pytest.mark.parametrize("auto_id", [True]) - @pytest.mark.parametrize("dim", [128]) # 128 - @pytest.mark.parametrize("entities", [1000]) # 1000 - @pytest.mark.skip(reason="stop support for numpy files") - def test_with_all_field_numpy(self, auto_id, dim, entities): + @pytest.mark.parametrize("is_row_based", [True]) + @pytest.mark.parametrize("auto_id", [True, False]) + @pytest.mark.parametrize("par_key_field", [df.int_field, df.string_field]) + def test_partition_key_on_json_file(self, is_row_based, auto_id, par_key_field): """ - collection schema 1: [pk, int64, float64, string float_vector] - data file: vectors.npy and uid.npy, + collection: auto_id, customized_id + collection schema: [pk, int64, varchar, float_vector] Steps: - 1. create collection + 1. create collection with partition key enabled 2. import data - 3. verify + 3. verify the data entities equal the import data and distributed by values of partition key field + 4. load the collection + 5. verify search successfully + 6. 
verify query successfully """ - data_fields = [df.pk_field, df.int_field, df.float_field, df.double_field, df.vec_field] - fields = [ - cf.gen_int64_field(name=df.pk_field, is_primary=True), - cf.gen_int64_field(name=df.int_field), - cf.gen_float_field(name=df.float_field), - cf.gen_double_field(name=df.double_field), - cf.gen_float_vec_field(name=df.vec_field, dim=dim), - ] - files = prepare_bulk_insert_numpy_files( + dim = 12 + entities = 200 + files = prepare_bulk_insert_json_files( minio_endpoint=self.minio_endpoint, bucket_name=self.bucket_name, + is_row_based=is_row_based, rows=entities, dim=dim, - data_fields=data_fields, + auto_id=auto_id, + data_fields=default_multi_fields, force=True, ) self._connect() - c_name = cf.gen_unique_str("bulk_insert") + c_name = cf.gen_unique_str("bulk_parkey") + fields = [ + cf.gen_int64_field(name=df.pk_field, is_primary=True), + cf.gen_float_vec_field(name=df.vec_field, dim=dim), + cf.gen_int64_field(name=df.int_field, is_partition_key=(par_key_field == df.int_field)), + cf.gen_string_field(name=df.string_field, is_partition_key=(par_key_field == df.string_field)), + cf.gen_bool_field(name=df.bool_field), + cf.gen_float_field(name=df.float_field), + ] schema = cf.gen_collection_schema(fields=fields, auto_id=auto_id) self.collection_wrap.init_collection(c_name, schema=schema) + assert len(self.collection_wrap.partitions) == ct.default_partition_num # import data t0 = time.time() task_id, _ = self.utility_wrap.do_bulk_insert( - collection_name=c_name, files=files + collection_name=c_name, + partition_name=None, + files=files, ) - logging.info(f"bulk insert task ids:{task_id}") - success, states = self.utility_wrap.wait_for_bulk_insert_tasks_completed( + logging.info(f"bulk insert task id:{task_id}") + success, _ = self.utility_wrap.wait_for_bulk_insert_tasks_completed( task_ids=[task_id], timeout=90 ) tt = time.time() - t0 log.info(f"bulk insert state:{success} in {tt}") assert success + num_entities = 
self.collection_wrap.num_entities log.info(f" collection entities: {num_entities}") assert num_entities == entities + # verify imported data is available for search index_params = ct.default_index self.collection_wrap.create_index( @@ -1269,39 +1231,70 @@ def test_with_all_field_numpy(self, auto_id, dim, entities): self.collection_wrap.load() log.info(f"wait for load finished and be ready for search") time.sleep(10) - # log.info(f"query seg info: {self.utility_wrap.get_query_segment_info(c_name)[0]}") - search_data = cf.gen_vectors(1, dim) + log.info( + f"query seg info: {self.utility_wrap.get_query_segment_info(c_name)[0]}" + ) + nq = 2 + topk = 2 + search_data = cf.gen_vectors(nq, dim) search_params = ct.default_search_params res, _ = self.collection_wrap.search( search_data, df.vec_field, param=search_params, - limit=1, + limit=topk, check_task=CheckTasks.check_search_results, - check_items={"nq": 1, "limit": 1}, + check_items={"nq": nq, "limit": topk}, + ) + for hits in res: + ids = hits.ids + results, _ = self.collection_wrap.query(expr=f"{df.pk_field} in {ids}") + assert len(results) == len(ids) + + # verify data was bulk inserted into different partitions + segment_num = len(self.utility_wrap.get_query_segment_info(c_name)[0]) + num_entities = 0 + empty_partition_num = 0 + for p in self.collection_wrap.partitions: + if p.num_entities == 0: + empty_partition_num += 1 + num_entities += p.num_entities + assert num_entities == entities + # as there are not many vectors, one partition should only have one segment after bulk insert + assert segment_num == (ct.default_partition_num - empty_partition_num) + + # verify error when tyring to bulk insert into a specific partition + # TODO: enable the error msg assert after issue #25586 fixed + err_msg = "not allow to set partition name for collection with partition key" + task_id, _ = self.utility_wrap.do_bulk_insert( + collection_name=c_name, + partition_name=self.collection_wrap.partitions[0].name, + files=files, + 
check_task=CheckTasks.err_res, + check_items={"err_code": 99, "err_msg": err_msg}, ) @pytest.mark.tags(CaseLabel.L3) @pytest.mark.parametrize("auto_id", [True, False]) - @pytest.mark.parametrize("dim", [6]) - @pytest.mark.parametrize("entities", [2000]) + @pytest.mark.parametrize("dim", [13]) + @pytest.mark.parametrize("entities", [300]) @pytest.mark.parametrize("file_nums", [10]) - def test_multi_numpy_files_from_diff_folders( - self, auto_id, dim, entities, file_nums + def test_partition_key_on_multi_numpy_files( + self, auto_id, dim, entities, file_nums ): """ - collection schema 1: [pk, float_vector] + collection schema 1: [pk, int64, float_vector, double] data file: .npy files in different folders Steps: - 1. create collection, create index and load + 1. create collection with partition key enabled, create index and load 2. import data 3. verify that import numpy files in a loop """ self._connect() - c_name = cf.gen_unique_str("bulk_insert") + c_name = cf.gen_unique_str("bulk_ins_parkey") fields = [ cf.gen_int64_field(name=df.pk_field, is_primary=True), - cf.gen_int64_field(name=df.int_field), + cf.gen_int64_field(name=df.int_field, is_partition_key=True), cf.gen_float_field(name=df.float_field), cf.gen_double_field(name=df.double_field), cf.gen_float_vec_field(name=df.vec_field, dim=dim), @@ -1339,10 +1332,12 @@ def test_multi_numpy_files_from_diff_folders( assert success log.info(f" collection entities: {self.collection_wrap.num_entities}") assert self.collection_wrap.num_entities == entities * file_nums - + # verify imported data is indexed + success = self.utility_wrap.wait_index_build_completed(c_name) + assert success # verify search and query log.info(f"wait for load finished and be ready for search") - time.sleep(10) + self.collection_wrap.load(_refresh=True) search_data = cf.gen_vectors(1, dim) search_params = ct.default_search_params res, _ = self.collection_wrap.search( @@ -1354,9 +1349,15 @@ def test_multi_numpy_files_from_diff_folders( 
check_items={"nq": 1, "limit": 1}, ) - # TODO: not supported yet - def test_from_customize_bucket(self): - pass + # verify data was bulk inserted into different partitions + segment_num = len(self.utility_wrap.get_query_segment_info(c_name)[0]) + num_entities = 0 + empty_partition_num = 0 + for p in self.collection_wrap.partitions: + if p.num_entities == 0: + empty_partition_num += 1 + num_entities += p.num_entities + assert num_entities == entities * file_nums class TestBulkInsertInvalidParams(TestcaseBaseBulkInsert): @@ -1441,18 +1442,82 @@ def test_empty_json_file(self, is_row_based, auto_id): success, states = self.utility_wrap.wait_for_bulk_insert_tasks_completed( task_ids=[task_id], timeout=90 ) - assert not success - failed_reason = "row count is 0" - for state in states.values(): - assert state.state_name in ["Failed", "Failed and cleaned"] - assert failed_reason in state.infos.get("failed_reason", "") + assert success + # TODO: remove the assert below if issue #25685 was by design + # assert not success + # failed_reason = "row count is 0" + # for state in states.values(): + # assert state.state_name in ["Failed", "Failed and cleaned"] + # assert failed_reason in state.infos.get("failed_reason", "") + + @pytest.mark.tags(CaseLabel.L3) + @pytest.mark.parametrize("is_row_based", [True]) # True, False + @pytest.mark.parametrize("auto_id", [True, False]) # True, False + @pytest.mark.parametrize("dim", [16]) # 16 + @pytest.mark.parametrize("entities", [100]) # 3000 + @pytest.mark.parametrize("file_nums", [32]) # 10 + @pytest.mark.parametrize("multi_folder", [True, False]) # True, False + def test_float_vector_from_multi_files( + self, is_row_based, auto_id, dim, entities, file_nums, multi_folder + ): + """ + collection: auto_id + collection schema: [pk, float_vector, + float_scalar, int_scalar, string_scalar, bool_scalar] + Steps: + 1. create collection + 2. build index and load collection + 3. import data from multiple files + 4. verify the data entities + 5. 
verify index status + 6. verify search successfully + 7. verify query successfully + """ + files = prepare_bulk_insert_json_files( + minio_endpoint=self.minio_endpoint, + bucket_name=self.bucket_name, + is_row_based=is_row_based, + rows=entities, + dim=dim, + auto_id=auto_id, + data_fields=default_multi_fields, + file_nums=file_nums, + multi_folder=multi_folder, + force=True, + ) + self._connect() + c_name = cf.gen_unique_str("bulk_insert") + fields = [ + cf.gen_int64_field(name=df.pk_field, is_primary=True), + cf.gen_float_vec_field(name=df.vec_field, dim=dim), + cf.gen_int32_field(name=df.int_field), + cf.gen_string_field(name=df.string_field), + cf.gen_bool_field(name=df.bool_field), + cf.gen_float_field(name=df.float_field) + ] + schema = cf.gen_collection_schema(fields=fields, auto_id=auto_id) + self.collection_wrap.init_collection(c_name, schema=schema) + # build index + index_params = ct.default_index + self.collection_wrap.create_index( + field_name=df.vec_field, index_params=index_params + ) + # load collection + self.collection_wrap.load() + # import data + t0 = time.time() + err_msg = "row-based import, only allow one JSON file each time" + task_id, _ = self.utility_wrap.do_bulk_insert( + collection_name=c_name, files=files, + check_task=CheckTasks.err_res, check_items={"err_code": 1, "err_msg": err_msg}, + ) + @pytest.mark.tags(CaseLabel.L3) @pytest.mark.parametrize("is_row_based", [True]) @pytest.mark.parametrize("auto_id", [True, False]) @pytest.mark.parametrize("dim", [8]) # 8 @pytest.mark.parametrize("entities", [100]) # 100 - # @pytest.mark.xfail(reason="issue https://github.com/milvus-io/milvus/issues/19658") def test_wrong_file_type(self, is_row_based, auto_id, dim, entities): """ collection schema: [pk, float_vector] @@ -1844,14 +1909,15 @@ def test_non_existing_partition(self, is_row_based, dim, entities): self.collection_wrap.init_collection(c_name, schema=schema) # import data into a non existing partition p_name = "non_existing" - err_msg = 
f"partition ID not found for partition name {p_name}" + err_msg = f"partition ID not found for partition name '{p_name}'" task_id, _ = self.utility_wrap.do_bulk_insert( collection_name=c_name, partition_name=p_name, files=files, check_task=CheckTasks.err_res, - check_items={"err_code": 1, "err_msg": err_msg}, + check_items={"err_code": 11, "err_msg": err_msg}, ) + print(task_id) @pytest.mark.tags(CaseLabel.L3) @pytest.mark.parametrize("is_row_based", [True]) @@ -1913,17 +1979,16 @@ def test_wrong_dim_in_one_entities_of_file( @pytest.mark.parametrize("dim", [16]) @pytest.mark.parametrize("entities", [300]) @pytest.mark.parametrize("file_nums", [10]) # max task nums 32? need improve - @pytest.mark.skip(reason="not support multiple files now") - def test_float_vector_one_of_files_fail( + def test_float_vector_with_multi_json_files( self, is_row_based, auto_id, dim, entities, file_nums ): """ collection schema: [pk, float_vectors, int_scalar], one of entities has wrong dim data - data files: multi files, and there are errors in one of files - 1. import data 11 files(10 correct and 1 with errors) into the collection + data files: multi files, + 1. import data 10 files 2. 
verify that import fails with errors and no data imported """ - correct_files = prepare_bulk_insert_json_files( + multi_files = prepare_bulk_insert_json_files( minio_endpoint=self.minio_endpoint, bucket_name=self.bucket_name, is_row_based=is_row_based, @@ -1934,20 +1999,7 @@ def test_float_vector_one_of_files_fail( file_nums=file_nums, force=True, ) - - # append a file that has errors - dismatch_dim = dim + 1 - err_files = prepare_bulk_insert_json_files( - minio_endpoint=self.minio_endpoint, - bucket_name=self.bucket_name, - is_row_based=is_row_based, - rows=entities, - dim=dismatch_dim, - auto_id=auto_id, - data_fields=default_multi_fields, - file_nums=1, - ) - files = correct_files + err_files + files = multi_files random.shuffle(files) # mix up the file order self._connect() @@ -1964,22 +2016,14 @@ def test_float_vector_one_of_files_fail( self.collection_wrap.init_collection(c_name, schema=schema) # import data - t0 = time.time() task_id, _ = self.utility_wrap.do_bulk_insert( - collection_name=c_name, files=files - ) - logging.info(f"bulk insert task ids:{task_id}") - success, states = self.utility_wrap.wait_for_bulk_insert_tasks_completed( - task_ids=[task_id], timeout=90 + collection_name=c_name, files=files, + check_task=CheckTasks.err_res, + check_items={"err_code": 1, + "err_msg": "row-based import, only allow one JSON file each time"} ) - tt = time.time() - t0 - log.info(f"bulk insert state:{success} in {tt}") - assert not success - if is_row_based: - # all correct files shall be imported successfully - assert self.collection_wrap.num_entities == entities * file_nums - else: - assert self.collection_wrap.num_entities == 0 + assert self.collection_wrap.num_entities == 0 + @pytest.mark.tags(CaseLabel.L3) @pytest.mark.parametrize("auto_id", [True, False]) @@ -2144,6 +2188,7 @@ def test_duplicate_numpy_files(self, auto_id, dim, entities): @pytest.mark.parametrize("is_row_based", [True]) @pytest.mark.parametrize("dim", [8]) @pytest.mark.parametrize("entities", 
[10]) + # @pytest.mark.xfail(reason="https://github.com/milvus-io/milvus/issues/21818") def test_data_type_string_on_int_pk(self, is_row_based, dim, entities): """ collection schema: default multi scalars @@ -2188,7 +2233,61 @@ def test_data_type_string_on_int_pk(self, is_row_based, dim, entities): ) log.info(f"bulk insert state:{success}") assert not success - failed_reason = f"illegal numeric value" + failed_reason = f"illegal value" + for state in states.values(): + assert state.state_name in ["Failed", "Failed and cleaned"] + assert failed_reason in state.infos.get("failed_reason", "") + assert self.collection_wrap.num_entities == 0 + + @pytest.mark.tags(CaseLabel.L3) + @pytest.mark.parametrize("is_row_based", [True]) + @pytest.mark.parametrize("dim", [8]) + @pytest.mark.parametrize("entities", [10]) + def test_data_type_float_on_int_pk(self, is_row_based, dim, entities): + """ + collection schema: [pk, float_vector, + float_scalar, int_scalar, string_scalar, bool_scalar] + data files: json file that one of entities has float on int pk + Steps: + 1. create collection + 2. import data with wrong data type + 3. 
verify import failed + """ + files = prepare_bulk_insert_json_files( + minio_endpoint=self.minio_endpoint, + bucket_name=self.bucket_name, + is_row_based=is_row_based, + rows=entities, + dim=dim, + auto_id=False, + data_fields=default_multi_fields, + err_type=DataErrorType.float_on_int_pk, + force=True, + ) + self._connect() + c_name = cf.gen_unique_str("bulk_insert") + # TODO: add string pk + fields = [ + cf.gen_int64_field(name=df.pk_field, is_primary=True), + cf.gen_float_vec_field(name=df.vec_field, dim=dim), + cf.gen_int32_field(name=df.int_field), + cf.gen_string_field(name=df.string_field), + cf.gen_bool_field(name=df.bool_field), + cf.gen_float_field(name=df.float_field), + ] + schema = cf.gen_collection_schema(fields=fields, auto_id=False) + self.collection_wrap.init_collection(c_name, schema=schema) + # import data + task_id, _ = self.utility_wrap.do_bulk_insert( + collection_name=c_name, files=files + ) + logging.info(f"bulk insert task ids:{task_id}") + success, states = self.utility_wrap.wait_for_bulk_insert_tasks_completed( + task_ids=[task_id], timeout=90 + ) + log.info(f"bulk insert state:{success}") + assert not success + failed_reason = f"failed to convert row value to entity" for state in states.values(): assert state.state_name in ["Failed", "Failed and cleaned"] assert failed_reason in state.infos.get("failed_reason", "") @@ -2209,20 +2308,29 @@ def test_data_type_typo_on_bool(self, is_row_based, auto_id, dim, entities): 2. import data 3. 
verify import failed with errors """ + + multi_fields = [ + df.vec_field, + df.int_field, + df.string_field, + df.bool_field, + df.float_field, + ] + if not auto_id: + multi_fields.insert(0, df.pk_field) files = prepare_bulk_insert_json_files( minio_endpoint=self.minio_endpoint, bucket_name=self.bucket_name, is_row_based=is_row_based, rows=entities, dim=dim, - auto_id=False, + auto_id=auto_id, data_fields=default_multi_fields, err_type=DataErrorType.typo_on_bool, force=True, ) self._connect() c_name = cf.gen_unique_str("bulk_insert") - # TODO: add string pk fields = [ cf.gen_int64_field(name=df.pk_field, is_primary=True), cf.gen_float_vec_field(name=df.vec_field, dim=dim), @@ -2233,6 +2341,7 @@ def test_data_type_typo_on_bool(self, is_row_based, auto_id, dim, entities): ] schema = cf.gen_collection_schema(fields=fields, auto_id=auto_id) self.collection_wrap.init_collection(c_name, schema=schema) + # import data task_id, _ = self.utility_wrap.do_bulk_insert( collection_name=c_name, files=files @@ -2364,7 +2473,7 @@ def test_data_type_str_on_float_scalar(self, is_row_based, auto_id, dim, entitie ) log.info(f"bulk insert state:{success}") assert not success - failed_reason = "illegal numeric value" + failed_reason = "failed to convert row value to entity" for state in states.values(): assert state.state_name in ["Failed", "Failed and cleaned"] assert failed_reason in state.infos.get("failed_reason", "") @@ -2426,15 +2535,16 @@ def test_data_type_str_on_vector_fields( ) log.info(f"bulk insert state:{success}") assert not success - failed_reason = "illegal numeric value" - if not float_vector: - failed_reason = f"the field '{df.vec_field}' value at the row {wrong_position} is invalid" + failed_reason1 = "failed to parse row value" + failed_reason2 = "failed to convert row value to entity" for state in states.values(): assert state.state_name in ["Failed", "Failed and cleaned"] - assert failed_reason in state.infos.get("failed_reason", "") + assert failed_reason1 in 
state.infos.get("failed_reason", "") or \ + failed_reason2 in state.infos.get("failed_reason", "") assert self.collection_wrap.num_entities == 0 +@pytest.mark.skip() class TestBulkInsertAdvanced(TestcaseBaseBulkInsert): @pytest.mark.tags(CaseLabel.L3) @@ -2520,7 +2630,7 @@ def test_float_vector_from_multi_numpy_files(self, auto_id, dim, entities): ) self.collection_wrap.load() log.info(f"wait for load finished and be ready for search") - time.sleep(10) + time.sleep(20) loaded_segs = len(self.utility_wrap.get_query_segment_info(c_name)[0]) log.info(f"query seg info: {loaded_segs} segs loaded.") search_data = cf.gen_vectors(1, dim) @@ -2532,4 +2642,4 @@ def test_float_vector_from_multi_numpy_files(self, auto_id, dim, entities): limit=1, check_task=CheckTasks.check_search_results, check_items={"nq": 1, "limit": 1}, - ) + ) \ No newline at end of file diff --git a/tests/python_client/chaos/chaos_commons.py b/tests/python_client/chaos/chaos_commons.py index 9e02b08719835..bc45c98cc43b5 100644 --- a/tests/python_client/chaos/chaos_commons.py +++ b/tests/python_client/chaos/chaos_commons.py @@ -37,10 +37,22 @@ def gen_experiment_config(yaml): def start_monitor_threads(checkers={}): """start the threads by checkers""" + tasks = [] for k, ch in checkers.items(): ch._keep_running = True t = threading.Thread(target=ch.keep_running, args=(), name=k, daemon=True) t.start() + tasks.append(t) + return tasks + + +def check_thread_status(tasks): + """check the status of all threads""" + for t in tasks: + if t.is_alive(): + log.info(f"thread {t.name} is still running") + else: + log.info(f"thread {t.name} is not running") def get_env_variable_by_name(name): diff --git a/tests/python_client/chaos/checker.py b/tests/python_client/chaos/checker.py index 543b05e66446b..b4571468575ab 100644 --- a/tests/python_client/chaos/checker.py +++ b/tests/python_client/chaos/checker.py @@ -102,6 +102,8 @@ def insert(self, operation_name, collection_name, start_time, time_cost, result) self.buffer = 
[] def sink(self): + if len(self.buffer) == 0: + return df = pd.DataFrame(self.buffer) if not self.created_file: with request_lock: @@ -126,6 +128,9 @@ def __init__(self): df = df.sort_values(by='start_time') self.df = df self.chaos_info = get_chaos_info() + self.chaos_start_time = self.chaos_info['create_time'] if self.chaos_info is not None else None + self.chaos_end_time = self.chaos_info['delete_time'] if self.chaos_info is not None else None + self.recovery_time = self.chaos_info['recovery_time'] if self.chaos_info is not None else None def get_stage_success_rate(self): df = self.df @@ -179,7 +184,9 @@ def get_realtime_success_rate(self, interval=10): def show_result_table(self): table = PrettyTable() - table.field_names = ['operation_name', 'before_chaos', 'during_chaos', 'after_chaos'] + table.field_names = ['operation_name', 'before_chaos', + f'during_chaos\n{self.chaos_start_time}~{self.recovery_time}', + 'after_chaos'] data = self.get_stage_success_rate() for operation, values in data.items(): row = [operation, values['before_chaos'], values['during_chaos'], values['after_chaos']] @@ -335,7 +342,7 @@ def check_result(self): checker_name = self.__class__.__name__ checkers_result = f"{checker_name}, succ_rate: {succ_rate:.2f}, total: {total:03d}, average_time: {average_time:.4f}, max_time: {max_time:.4f}, min_time: {min_time:.4f}" log.info(checkers_result) - log.info(f"{checker_name} rsp times: {self.rsp_times}") + log.debug(f"{checker_name} rsp times: {self.rsp_times}") if len(self.fail_records) > 0: log.info(f"{checker_name} failed at {self.fail_records}") return checkers_result @@ -479,6 +486,7 @@ def keep_running(self): self.initial_entities += constants.DELTA_PER_INS else: self._fail += 1 + sleep(constants.WAIT_PER_OP * 6) class FlushChecker(Checker): @@ -513,7 +521,7 @@ def run_task(self): def keep_running(self): while self._keep_running: self.run_task() - sleep(constants.WAIT_PER_OP / 10) + sleep(constants.WAIT_PER_OP * 6) class 
InsertChecker(Checker): @@ -529,6 +537,7 @@ def __init__(self, collection_name=None, flush=False, shards_num=2, schema=None) self.scale = 1 * 10 ** 6 self.start_time_stamp = int(time.time() * self.scale) # us self.term_expr = f'{self.int64_field_name} >= {self.start_time_stamp}' + self.file_name = f"/tmp/ci_logs/insert_data_{uuid.uuid4()}.parquet" @trace() def insert(self): @@ -546,6 +555,7 @@ def insert(self): enable_traceback=enable_traceback, check_task=CheckTasks.check_nothing) if result: + # TODO: persist data to file self.inserted_data.extend(ts_data) return res, result @@ -617,7 +627,7 @@ def run_task(self): def keep_running(self): while self._keep_running: self.run_task() - sleep(constants.WAIT_PER_OP / 10) + sleep(constants.WAIT_PER_OP) class IndexChecker(Checker): @@ -639,7 +649,6 @@ def create_index(self): res, result = self.c_wrap.create_index(self.float_vector_field_name, constants.DEFAULT_INDEX_PARAM, index_name=self.index_name, - timeout=timeout, enable_traceback=enable_traceback, check_task=CheckTasks.check_nothing) return res, result @@ -654,7 +663,7 @@ def run_task(self): def keep_running(self): while self._keep_running: self.run_task() - sleep(constants.WAIT_PER_OP / 10) + sleep(constants.WAIT_PER_OP * 6) class QueryChecker(Checker): @@ -810,12 +819,24 @@ def __init__(self, collection_name=None, schema=None): if collection_name is None: collection_name = cf.gen_unique_str("DropChecker_") super().__init__(collection_name=collection_name, schema=schema) + self.collection_pool = [] + self.gen_collection_pool(schema=self.schema) + + def gen_collection_pool(self, pool_size=50, schema=None): + for i in range(pool_size): + collection_name = cf.gen_unique_str("DropChecker_") + res, result = self.c_wrap.init_collection(name=collection_name, schema=schema) + if result: + self.collection_pool.append(collection_name) @trace() def drop(self): res, result = self.c_wrap.drop() + if result: + self.collection_pool.remove(self.c_wrap.name) return res, result + 
@exception_handler() def run_task(self): res, result = self.drop() return res, result @@ -824,12 +845,17 @@ def keep_running(self): while self._keep_running: res, result = self.run_task() if result: - self.c_wrap.init_collection( - name=cf.gen_unique_str("DropChecker_"), - schema=cf.gen_default_collection_schema(), - timeout=timeout, - check_task=CheckTasks.check_nothing) - sleep(constants.WAIT_PER_OP / 10) + try: + if len(self.collection_pool) <= 10: + self.gen_collection_pool(schema=self.schema) + except Exception as e: + log.error(f"Failed to generate collection pool: {e}") + try: + c_name = self.collection_pool[0] + self.c_wrap.init_collection(name=c_name) + except Exception as e: + log.error(f"Failed to init new collection: {e}") + sleep(constants.WAIT_PER_OP) class LoadBalanceChecker(Checker): diff --git a/tests/python_client/chaos/cluster-values.yaml b/tests/python_client/chaos/cluster-values.yaml index d6b1ff272da4a..316abd7d92ffd 100644 --- a/tests/python_client/chaos/cluster-values.yaml +++ b/tests/python_client/chaos/cluster-values.yaml @@ -8,10 +8,16 @@ image: tag: master-latest pullPolicy: IfNotPresent +indexNode: + resources: + requests: + cpu: 2 + limits: + cpu: 8 + etcd: replicaCount: 3 image: - debug: true repository: milvusdb/etcd tag: 3.5.5-r2 @@ -131,6 +137,9 @@ pulsar: extraConfigFiles: user.yaml: |+ - dataNode: - memory: - forceSyncEnable: false \ No newline at end of file + dataCoord: + compaction: + indexBasedCompaction: false + indexCoord: + scheduler: + interval: 100 \ No newline at end of file diff --git a/tests/python_client/chaos/conftest.py b/tests/python_client/chaos/conftest.py index 2531c1746bd54..2f75d64e8d5ca 100644 --- a/tests/python_client/chaos/conftest.py +++ b/tests/python_client/chaos/conftest.py @@ -6,10 +6,10 @@ def pytest_addoption(parser): parser.addoption("--role_type", action="store", default="activated", help="role_type") parser.addoption("--target_component", action="store", default="querynode", 
help="target_component") parser.addoption("--target_pod", action="store", default="etcd_leader", help="target_pod") + parser.addoption("--target_scope", action="store", default="all", help="target_scope") parser.addoption("--target_number", action="store", default="1", help="target_number") - parser.addoption("--chaos_duration", action="store", default="1m", help="chaos_duration") - parser.addoption("--chaos_interval", action="store", default="10s", help="chaos_interval") - parser.addoption("--request_duration", action="store", default="5m", help="request_duration") + parser.addoption("--chaos_duration", action="store", default="7m", help="chaos_duration") + parser.addoption("--chaos_interval", action="store", default="2m", help="chaos_interval") parser.addoption("--is_check", action="store", type=bool, default=False, help="is_check") parser.addoption("--wait_signal", action="store", type=bool, default=True, help="wait_signal") @@ -34,6 +34,11 @@ def target_pod(request): return request.config.getoption("--target_pod") +@pytest.fixture +def target_scope(request): + return request.config.getoption("--target_scope") + + @pytest.fixture def target_number(request): return request.config.getoption("--target_number") @@ -49,11 +54,6 @@ def chaos_interval(request): return request.config.getoption("--chaos_interval") -@pytest.fixture -def request_duration(request): - return request.config.getoption("--request_duration") - - @pytest.fixture def is_check(request): return request.config.getoption("--is_check") diff --git a/tests/python_client/chaos/nats-standalone-values.yaml b/tests/python_client/chaos/nats-standalone-values.yaml index 55979cb371aaa..f3440cf818db1 100644 --- a/tests/python_client/chaos/nats-standalone-values.yaml +++ b/tests/python_client/chaos/nats-standalone-values.yaml @@ -27,7 +27,6 @@ kafka: etcd: replicaCount: 3 image: - debug: true repository: milvusdb/etcd tag: 3.5.5-r2 minio: @@ -38,4 +37,10 @@ pulsar: extraConfigFiles: user.yaml: |+ mq: - type: 
natsmq \ No newline at end of file + type: natsmq + dataCoord: + compaction: + indexBasedCompaction: false + indexCoord: + scheduler: + interval: 100 \ No newline at end of file diff --git a/tests/python_client/chaos/standalone-values.yaml b/tests/python_client/chaos/standalone-values.yaml index 52e3f26144379..8a08f058d44b7 100644 --- a/tests/python_client/chaos/standalone-values.yaml +++ b/tests/python_client/chaos/standalone-values.yaml @@ -27,7 +27,6 @@ kafka: etcd: replicaCount: 3 image: - debug: true repository: milvusdb/etcd tag: 3.5.5-r2 minio: @@ -37,6 +36,9 @@ pulsar: extraConfigFiles: user.yaml: |+ - dataNode: - memory: - forceSyncEnable: false \ No newline at end of file + dataCoord: + compaction: + indexBasedCompaction: false + indexCoord: + scheduler: + interval: 100 \ No newline at end of file diff --git a/tests/python_client/chaos/test_chaos_apply.py b/tests/python_client/chaos/test_chaos_apply.py index e162a715aa8df..f1367803bf4c6 100644 --- a/tests/python_client/chaos/test_chaos_apply.py +++ b/tests/python_client/chaos/test_chaos_apply.py @@ -56,15 +56,14 @@ def teardown(self): chaos_res.delete(meta_name, raise_ex=False) sleep(2) - def test_chaos_apply(self, chaos_type, target_component, target_number, chaos_duration, chaos_interval, wait_signal): + def test_chaos_apply(self, chaos_type, target_component, target_scope, target_number, chaos_duration, chaos_interval, wait_signal): # start the monitor threads to check the milvus ops log.info("*********************Chaos Test Start**********************") if wait_signal: log.info("need wait signal to start chaos") ready_for_chaos = wait_signal_to_apply_chaos() if not ready_for_chaos: - log.info("did not get the signal to apply chaos") - raise Exception + log.info("get the signal to apply chaos timeout") else: log.info("get the signal to apply chaos") log.info(connections.get_connection_addr('default')) @@ -78,6 +77,7 @@ def test_chaos_apply(self, chaos_type, target_component, target_number, chaos_du 
update_key_value(chaos_config, "app.kubernetes.io/instance", release_name) update_key_value(chaos_config, "namespaces", [self.milvus_ns]) update_key_value(chaos_config, "value", target_number) + update_key_value(chaos_config, "mode", target_scope) self.chaos_config = chaos_config if "s" in chaos_interval: schedule = f"*/{chaos_interval[:-1]} * * * * *" diff --git a/tests/python_client/chaos/test_chaos_apply_to_determined_pod.py b/tests/python_client/chaos/test_chaos_apply_to_determined_pod.py index 1790e94fa5747..642bd62837482 100644 --- a/tests/python_client/chaos/test_chaos_apply_to_determined_pod.py +++ b/tests/python_client/chaos/test_chaos_apply_to_determined_pod.py @@ -2,6 +2,8 @@ import time from time import sleep from pathlib import Path +from datetime import datetime +import json from pymilvus import connections from common.cus_resource_opts import CustomResourceOperations as CusResource from common.milvus_sys import MilvusSys @@ -9,6 +11,7 @@ import logging as log from utils.util_k8s import (wait_pods_ready, get_milvus_instance_name, get_milvus_deploy_tool, get_etcd_leader, get_etcd_followers) +from utils.util_common import wait_signal_to_apply_chaos import constants @@ -54,9 +57,17 @@ def teardown(self): chaos_res.delete(meta_name, raise_ex=False) sleep(2) - def test_chaos_apply(self, chaos_type, target_pod, chaos_duration, chaos_interval): + def test_chaos_apply(self, chaos_type, target_pod, chaos_duration, chaos_interval, wait_signal): # start the monitor threads to check the milvus ops log.info("*********************Chaos Test Start**********************") + if wait_signal: + log.info("need wait signal to start chaos") + ready_for_chaos = wait_signal_to_apply_chaos() + if not ready_for_chaos: + log.info("did not get the signal to apply chaos") + raise Exception + else: + log.info("get the signal to apply chaos") log.info(connections.get_connection_addr('default')) release_name = self.release_name deploy_tool = get_milvus_deploy_tool(self.milvus_ns, 
self.milvus_sys) @@ -71,6 +82,9 @@ def test_chaos_apply(self, chaos_type, target_pod, chaos_duration, chaos_interva if etcd_followers is None: raise Exception("no etcd followers") target_pod_list.extend(etcd_followers) + if len(target_pod_list) >=2: + # only choose one follower to apply chaos + target_pod_list = target_pod_list[:1] log.info(f"target_pod_list: {target_pod_list}") chaos_type = chaos_type.replace('_', '-') chaos_config = cc.gen_experiment_config(f"{str(Path(__file__).absolute().parent)}/chaos_objects/template/{chaos_type}-by-pod-list.yaml") @@ -85,6 +99,7 @@ def test_chaos_apply(self, chaos_type, target_pod, chaos_duration, chaos_interva version=constants.CHAOS_VERSION, namespace=constants.CHAOS_NAMESPACE) chaos_res.create(chaos_config) + create_time = datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d %H:%M:%S.%f') log.info("chaos injected") res = chaos_res.list_all() chaos_list = [r['metadata']['name'] for r in res['items']] @@ -96,6 +111,7 @@ def test_chaos_apply(self, chaos_type, target_pod, chaos_duration, chaos_interva sleep(chaos_duration) # delete chaos chaos_res.delete(meta_name) + delete_time = datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d %H:%M:%S.%f') log.info("chaos deleted") res = chaos_res.list_all() chaos_list = [r['metadata']['name'] for r in res['items']] @@ -113,6 +129,19 @@ def test_chaos_apply(self, chaos_type, target_pod, chaos_duration, chaos_interva log.info("all pods are ready") pods_ready_time = time.time() - t0 log.info(f"pods ready time: {pods_ready_time}") + + recovery_time = datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d %H:%M:%S.%f') + event_records = { + "chaos_type": chaos_type, + "target_component": target_pod, + "meta_name": meta_name, + "create_time": create_time, + "delete_time": delete_time, + "recovery_time": recovery_time + } + # save event records to json file + with open(constants.CHAOS_INFO_SAVE_PATH, 'w') as f: + json.dump(event_records, f) # reconnect to test the service healthy 
start_time = time.time() end_time = start_time + 120 diff --git a/tests/python_client/chaos/testcases/test_concurrent_operation.py b/tests/python_client/chaos/testcases/test_concurrent_operation.py index 8e2892c1a0f40..e72ddcfbdc3a9 100644 --- a/tests/python_client/chaos/testcases/test_concurrent_operation.py +++ b/tests/python_client/chaos/testcases/test_concurrent_operation.py @@ -5,7 +5,6 @@ from pymilvus import connections from chaos.checker import (InsertChecker, FlushChecker, - CompactChecker, SearchChecker, QueryChecker, DeleteChecker, @@ -20,7 +19,7 @@ from chaos.chaos_commons import assert_statistic from common.common_type import CaseLabel from chaos import constants -from delayed_assert import expect, assert_expectations +from delayed_assert import assert_expectations def get_all_collections(): @@ -74,7 +73,6 @@ def init_health_checkers(self, collection_name=None): Op.flush: FlushChecker(collection_name=c_name), Op.search: SearchChecker(collection_name=c_name), Op.query: QueryChecker(collection_name=c_name), - Op.compact: CompactChecker(collection_name=c_name), Op.delete: DeleteChecker(collection_name=c_name), } self.health_checkers = checkers diff --git a/tests/python_client/chaos/testcases/test_single_request_operation.py b/tests/python_client/chaos/testcases/test_single_request_operation.py index 70f61a043cdc1..b7fa746ebbb3d 100644 --- a/tests/python_client/chaos/testcases/test_single_request_operation.py +++ b/tests/python_client/chaos/testcases/test_single_request_operation.py @@ -82,7 +82,7 @@ def test_operations(self, request_duration, is_check): event_records.insert("init_health_checkers", "start") self.init_health_checkers(collection_name=c_name) event_records.insert("init_health_checkers", "finished") - cc.start_monitor_threads(self.health_checkers) + tasks = cc.start_monitor_threads(self.health_checkers) log.info("*********************Load Start**********************") # wait request_duration request_duration = request_duration.replace("h", 
"*3600+").replace("m", "*60+").replace("s", "") @@ -102,6 +102,7 @@ def test_operations(self, request_duration, is_check): # wait all pod ready wait_pods_ready(self.milvus_ns, f"app.kubernetes.io/instance={self.release_name}") time.sleep(60) + cc.check_thread_status(tasks) for k, v in self.health_checkers.items(): v.pause() ra = ResultAnalyzer() diff --git a/tests/python_client/chaos/testcases/test_single_request_operation_for_standby.py b/tests/python_client/chaos/testcases/test_single_request_operation_for_standby.py index 059b6db1bffcd..ea82d14b81378 100644 --- a/tests/python_client/chaos/testcases/test_single_request_operation_for_standby.py +++ b/tests/python_client/chaos/testcases/test_single_request_operation_for_standby.py @@ -19,7 +19,7 @@ from delayed_assert import assert_expectations from utils.util_k8s import (get_milvus_instance_name, get_milvus_deploy_tool, - reset_healthy_checker_after_standby_activated) + record_time_when_standby_activated) class TestBase: @@ -84,19 +84,22 @@ def test_operations(self, request_duration, target_component, is_check): if request_duration[-1] == "+": request_duration = request_duration[:-1] request_duration = eval(request_duration) - # start a thread to reset health_checkers when standby is activated. 
- t = threading.Thread(target=reset_healthy_checker_after_standby_activated, - args=(self.milvus_ns, self.release_name, target_component, self.health_checkers), + # start a thread to record the time when standby is activated + t = threading.Thread(target=record_time_when_standby_activated, + args=(self.milvus_ns, self.release_name, target_component), kwargs={"timeout": request_duration//2}, daemon=True) t.start() - # t.join() log.info('start a thread to reset health_checkers when standby is activated') for i in range(10): sleep(request_duration//10) for k, v in self.health_checkers.items(): v.check_result() if is_check: - assert_statistic(self.health_checkers) - assert_expectations() + assert_statistic(self.health_checkers, succ_rate_threshold=0.99) + for k, v in self.health_checkers.items(): + log.info(f"{k} rto: {v.get_rto()}") + rto = v.get_rto() + pytest.assume(rto < 30, f"{k} rto expect 30s but get {rto}s") # rto should be less than 30s + log.info("*********************Chaos Test Completed**********************") \ No newline at end of file diff --git a/tests/python_client/common/bulk_insert_data.py b/tests/python_client/common/bulk_insert_data.py index 35ec03598a1a3..9dcf4ac2ee07a 100644 --- a/tests/python_client/common/bulk_insert_data.py +++ b/tests/python_client/common/bulk_insert_data.py @@ -69,7 +69,7 @@ def gen_str_invalid_vectors(nb, dim): def gen_binary_vectors(nb, dim): # binary: each int presents 8 dimension - # so if binary vector dimension is 16,use [x, y], which x and y could be any int between 0 to 255 + # so if binary vector dimension is 16,use [x, y], which x and y could be any int between 0 and 255 vectors = [[random.randint(0, 255) for _ in range(dim)] for _ in range(nb)] return vectors @@ -276,6 +276,21 @@ def gen_string_in_numpy_file(dir, data_field, rows, start=0, force=False): return file_name +def gen_bool_in_numpy_file(dir, data_field, rows, start=0, force=False): + file_name = f"{data_field}.npy" + file = f"{dir}/{file_name}" + if not 
os.path.exists(file) or force: + # non vector columns + data = [] + if rows > 0: + data = [random.choice([True, False]) for i in range(start, rows+start)] + arr = np.array(data) + # print(f"file_name: {file_name} data type: {arr.dtype}") + log.info(f"file_name: {file_name} data type: {arr.dtype} data shape: {arr.shape}") + np.save(file, arr) + return file_name + + def gen_int_or_float_in_numpy_file(dir, data_field, rows, start=0, force=False): file_name = f"{data_field}.npy" file = f"{dir}/{file_name}" @@ -378,6 +393,8 @@ def gen_npy_files(float_vector, rows, dim, data_fields, file_nums=1, err_type="" rows=rows, dim=dim, force=force) elif data_field == DataField.string_field: # string field for numpy not supported yet at 2022-10-17 file_name = gen_string_in_numpy_file(dir=data_source, data_field=data_field, rows=rows, force=force) + elif data_field == DataField.bool_field: + file_name = gen_bool_in_numpy_file(dir=data_source, data_field=data_field, rows=rows, force=force) else: file_name = gen_int_or_float_in_numpy_file(dir=data_source, data_field=data_field, rows=rows, force=force) @@ -468,8 +485,8 @@ def prepare_bulk_insert_json_files(minio_endpoint="", bucket_name="milvus-bucket return files -def prepare_bulk_insert_numpy_files(minio_endpoint="", bucket_name="milvus-bucket", rows=100, dim=128, data_fields=[DataField.vec_field], - float_vector=True, file_nums=1, force=False): +def prepare_bulk_insert_numpy_files(minio_endpoint="", bucket_name="milvus-bucket", rows=100, dim=128, + data_fields=[DataField.vec_field], float_vector=True, file_nums=1, force=False): """ Generate column based files based on params in numpy format and copy them to the minio Note: each field in data_fields would be generated one numpy file. 
@@ -484,25 +501,24 @@ def prepare_bulk_insert_numpy_files(minio_endpoint="", bucket_name="milvus-bucke :type float_vector: boolean :param: data_fields: data fields to be generated in the file(s): - it support one or all of [int_pk, vectors, int, float] - Note: it does not automatically adds pk field + it supports one or all of [int_pk, vectors, int, float] + Note: it does not automatically add pk field :type data_fields: list :param file_nums: file numbers to be generated The file(s) would be generated in data_source folder if file_nums = 1 - The file(s) would be generated in different subfolers if file_nums > 1 + The file(s) would be generated in different sub-folders if file_nums > 1 :type file_nums: int :param force: re-generate the file(s) regardless existing or not :type force: boolean Return: List - File name list or file name with subfolder list + File name list or file name with sub-folder list """ files = gen_npy_files(rows=rows, dim=dim, float_vector=float_vector, data_fields=data_fields, file_nums=file_nums, force=force) copy_files_to_minio(host=minio_endpoint, r_source=data_source, files=files, bucket_name=bucket_name, force=force) - return files - + return files \ No newline at end of file diff --git a/tests/python_client/common/common_func.py b/tests/python_client/common/common_func.py index 10b198e69e5b6..966932784f87c 100644 --- a/tests/python_client/common/common_func.py +++ b/tests/python_client/common/common_func.py @@ -333,7 +333,7 @@ def gen_default_rows_data(nb=ct.default_nb, dim=ct.default_dim, start=0, with_js dict = {ct.default_int64_field_name: i, ct.default_float_field_name: i*1.0, ct.default_string_field_name: str(i), - ct.default_json_field_name: {"number": i}, + ct.default_json_field_name: {"number": i, "float": i*1.0}, ct.default_float_vec_field_name: gen_vectors(1, dim)[0] } if with_json is False: @@ -850,6 +850,12 @@ def gen_invalid_search_params_type(): continue annoy_search_param = {"index_type": index_type, "search_params": 
{"search_k": search_k}} search_params.append(annoy_search_param) + elif index_type == "SCANN": + for reorder_k in ct.get_invalid_ints: + if isinstance(reorder_k, int): + continue + scann_search_param = {"index_type": index_type, "search_params": {"nprobe": 8, "reorder_k": reorder_k}} + search_params.append(scann_search_param) elif index_type == "DISKANN": for search_list in ct.get_invalid_ints: diskann_search_param = {"index_type": index_type, "search_params": {"search_list": search_list}} @@ -883,6 +889,10 @@ def gen_search_param(index_type, metric_type="L2"): for search_k in [1000, 5000]: annoy_search_param = {"metric_type": metric_type, "params": {"search_k": search_k}} search_params.append(annoy_search_param) + elif index_type == "SCANN": + for reorder_k in [1200, 3000]: + scann_search_param = {"metric_type": metric_type, "params": {"nprobe": 64, "reorder_k": reorder_k}} + search_params.append(scann_search_param) elif index_type == "DISKANN": for search_list in [20, 300, 1500]: diskann_search_param = {"metric_type": metric_type, "params": {"search_list": search_list}} @@ -925,7 +935,10 @@ def gen_invalid_search_param(index_type, metric_type="L2"): for search_list in ["-1"]: diskann_search_param = {"metric_type": metric_type, "params": {"search_list": search_list}} search_params.append(diskann_search_param) - + elif index_type == "SCANN": + for reorder_k in [-1]: + scann_search_param = {"metric_type": metric_type, "params": {"reorder_k": reorder_k, "nprobe": 10}} + search_params.append(scann_search_param) else: log.error("Invalid index_type.") raise Exception("Invalid index_type.") @@ -948,54 +961,71 @@ def gen_normal_expressions(): "(int64 > 0 && int64 < 400) or (int64 > 500 && int64 < 1000)", "int64 not in [1, 2, 3]", "int64 in [1, 2, 3] and float != 2", - "int64 == 0 || int64 == 1 || int64 == 2", - "0 < int64 < 400", - "500 <= int64 < 1000", + "int64 == 0 || float == 10**2 || (int64 + 1) == 3", + "0 <= int64 < 400 and int64 % 100 == 0", "200+300 < int64 <= 
500+500", - "int64 in [300/2, 900%40, -10*30+800, 2048/2%200, (100+200)*2]", - "float in [+3**6, 2**10/2]", - "(int64 % 100 == 0) && int64 < 500", - "float <= 4**5/2 && float > 500-1 && float != 500/2+260", "int64 > 400 && int64 < 200", - "float < -2**8", - "(int64 + 1) == 3 || int64 * 2 == 64 || float == 10**2" + "int64 in [300/2, 900%40, -10*30+800, (100+200)*2] or float in [+3**6, 2**10/2]", + "float <= -4**5/2 && float > 500-1 && float != 500/2+260" ] return expressions -def gen_field_compare_expressions(): +def gen_json_field_expressions(): expressions = [ - "int64_1 | int64_2 == 1", - "int64_1 && int64_2 ==1", - "int64_1 + int64_2 == 10", - "int64_1 - int64_2 == 2", - "int64_1 * int64_2 == 8", - "int64_1 / int64_2 == 2", - "int64_1 ** int64_2 == 4", - "int64_1 % int64_2 == 0", - "int64_1 in int64_2", - "int64_1 + int64_2 >= 10" + "json_field['number'] > 0", + "0 <= json_field['number'] < 400 or 1000 > json_field['number'] >= 500", + "json_field['number'] not in [1, 2, 3]", + "json_field['number'] in [1, 2, 3] and json_field['float'] != 2", + "json_field['number'] == 0 || json_field['float'] == 10**2 || json_field['number'] + 1 == 3", + "json_field['number'] < 400 and json_field['number'] >= 100 and json_field['number'] % 100 == 0", + "json_field['float'] > 400 && json_field['float'] < 200", + "json_field['number'] in [300/2, -10*30+800, (100+200)*2] or json_field['float'] in [+3**6, 2**10/2]", + "json_field['float'] <= -4**5/2 && json_field['float'] > 500-1 && json_field['float'] != 500/2+260" ] return expressions -def gen_normal_string_expressions(field): - expressions = [ - f"\"0\"< {field} < \"3\"", - f"{field} >= \"0\"", - f"({field} > \"0\" && {field} < \"100\") or ({field} > \"200\" && {field} < \"300\")", - f"\"0\" <= {field} <= \"100\"", - f"{field} == \"0\"|| {field} == \"1\"|| {field} ==\"2\"", - f"{field} != \"0\"", - f"{field} not in [\"0\", \"1\", \"2\"]", - f"{field} in [\"0\", \"1\", \"2\"]" - ] +def gen_field_compare_expressions(fields1=None, 
fields2=None): + if fields1 is None: + fields1 = ["int64_1"] + fields2 = ["int64_2"] + expressions = [] + for field1, field2 in zip(fields1, fields2): + expression = [ + f"{field1} | {field2} == 1", + f"{field1} + {field2} <= 10 || {field1} - {field2} == 2", + f"{field1} * {field2} >= 8 && {field1} / {field2} < 2", + f"{field1} ** {field2} != 4 and {field1} + {field2} > 5", + f"{field1} not in {field2}", + f"{field1} in {field2}", + ] + expressions.extend(expression) + return expressions + + +def gen_normal_string_expressions(fields=None): + if fields is None: + fields = [ct.default_string_field_name] + expressions = [] + for field in fields: + expression = [ + f"\"0\"< {field} < \"3\"", + f"{field} >= \"0\"", + f"({field} > \"0\" && {field} < \"100\") or ({field} > \"200\" && {field} < \"300\")", + f"\"0\" <= {field} <= \"100\"", + f"{field} == \"0\"|| {field} == \"1\"|| {field} ==\"2\"", + f"{field} != \"0\"", + f"{field} not in [\"0\", \"1\", \"2\"]", + f"{field} in [\"0\", \"1\", \"2\"]" + ] + expressions.extend(expression) return expressions def gen_invalid_string_expressions(): expressions = [ - "varchar in [0, \"1\"]", + "varchar in [0, \"1\"]", "varchar not in [\"0\", 1, 2]" ] return expressions @@ -1187,6 +1217,29 @@ def index_to_dict(index): } +def assert_json_contains(expr, list_data): + result_ids = [] + expr_prefix = expr.split('(', 1)[0] + exp_ids = eval(expr.split(', ', 1)[1].split(')', 1)[0]) + if expr_prefix in ["json_contains", "JSON_CONTAINS"]: + for i in range(len(list_data)): + if exp_ids in list_data[i]: + result_ids.append(i) + elif expr_prefix in ["json_contains_all", "JSON_CONTAINS_ALL"]: + for i in range(len(list_data)): + set_list_data = set(tuple(element) if isinstance(element, list) else element for element in list_data[i]) + if set(exp_ids).issubset(set_list_data): + result_ids.append(i) + elif expr_prefix in ["json_contains_any", "JSON_CONTAINS_ANY"]: + for i in range(len(list_data)): + set_list_data = set(tuple(element) if 
isinstance(element, list) else element for element in list_data[i]) + if set(exp_ids) & set_list_data: + result_ids.append(i) + else: + log.warning("unknown expr: %s" % expr) + return result_ids + + def assert_equal_index(index_1, index_2): return index_to_dict(index_1) == index_to_dict(index_2) diff --git a/tests/python_client/common/common_type.py b/tests/python_client/common/common_type.py index f1d0e16397811..7db80d6d84837 100644 --- a/tests/python_client/common/common_type.py +++ b/tests/python_client/common/common_type.py @@ -21,7 +21,9 @@ default_diskann_index = {"index_type": "DISKANN", "metric_type": "COSINE", "params": {}} default_diskann_search_params = {"metric_type": "COSINE", "params": {"search_list": 30}} max_top_k = 16384 -max_partition_num = 4096 # 256 +max_partition_num = 4096 +max_role_num = 10 +default_partition_num = 64 # default num_partitions for partition key feature default_segment_row_limit = 1000 default_server_segment_row_limit = 1024 * 512 default_alias = "default" @@ -226,11 +228,11 @@ ] """ Specially defined list """ -all_index_types = ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_PQ", "HNSW", "DISKANN", "BIN_FLAT", "BIN_IVF_FLAT", +all_index_types = ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_PQ", "HNSW", "SCANN", "DISKANN", "BIN_FLAT", "BIN_IVF_FLAT", "GPU_IVF_FLAT", "GPU_IVF_PQ"] default_index_params = [{"nlist": 128}, {"nlist": 128}, {"nlist": 128}, {"nlist": 128, "m": 16, "nbits": 8}, - {"M": 48, "efConstruction": 500}, {}, {"nlist": 128}, {"nlist": 128}, + {"M": 48, "efConstruction": 500}, {"nlist": 128}, {}, {"nlist": 128}, {"nlist": 128}, {"nlist": 64}, {"nlist": 64, "m": 16, "nbits": 8}] Handler_type = ["GRPC", "HTTP"] diff --git a/tests/python_client/config/log_config.py b/tests/python_client/config/log_config.py index 8707af4a08068..3819f04b93b03 100644 --- a/tests/python_client/config/log_config.py +++ b/tests/python_client/config/log_config.py @@ -1,5 +1,5 @@ import os -import datetime +from pathlib import Path class LogConfig: def 
__init__(self): @@ -23,10 +23,10 @@ def get_env_variable(var="CI_LOG_PATH"): @staticmethod def create_path(log_path): - if not os.path.isdir(str(log_path)): - print("[create_path] folder(%s) is not exist." % log_path) - print("[create_path] create path now...") - os.makedirs(log_path) + print("[create_path] folder(%s) is not exist." % log_path) + print("[create_path] create path now...") + folder_path = Path(str(log_path)) + folder_path.mkdir(parents=True, exist_ok=True) def get_default_config(self): """ Make sure the path exists """ diff --git a/tests/python_client/conftest.py b/tests/python_client/conftest.py index df40c3fba6150..1119676fab607 100644 --- a/tests/python_client/conftest.py +++ b/tests/python_client/conftest.py @@ -47,6 +47,7 @@ def pytest_addoption(parser): parser.addoption('--minio_host', action='store', default="localhost", help="minio service's ip") parser.addoption('--uri', action='store', default="", help="uri for high level api") parser.addoption('--token', action='store', default="", help="token for high level api") + parser.addoption("--request_duration", action="store", default="10m", help="request_duration") @pytest.fixture @@ -81,7 +82,7 @@ def secure(request): @pytest.fixture def milvus_ns(request): - return request.config.getoption("--milvus_ns") + return request.config.getoption("--milvus_ns") @pytest.fixture @@ -185,6 +186,10 @@ def uri(request): def token(request): return request.config.getoption("--token") +@pytest.fixture +def request_duration(request): + return request.config.getoption("--request_duration") + """ fixture func """ diff --git a/tests/python_client/deploy/cluster-values.yaml b/tests/python_client/deploy/cluster-values.yaml index dd05c412edb50..67ac127d2c277 100644 --- a/tests/python_client/deploy/cluster-values.yaml +++ b/tests/python_client/deploy/cluster-values.yaml @@ -11,7 +11,6 @@ image: etcd: replicaCount: 3 image: - debug: true repository: milvusdb/etcd tag: 3.5.5-r2 @@ -127,8 +126,11 @@ pulsar: 
backlogQuotaDefaultLimitGB: "8" backlogQuotaDefaultRetentionPolicy: producer_exception -# extraConfigFiles: -# user.yaml: |+ -# dataNode: -# memory: -# forceSyncEnable: false \ No newline at end of file +extraConfigFiles: + user.yaml: |+ + dataCoord: + compaction: + indexBasedCompaction: false + indexCoord: + scheduler: + interval: 100 \ No newline at end of file diff --git a/tests/python_client/deploy/milvus_crd.yaml b/tests/python_client/deploy/milvus_crd.yaml index 26b7e6d5f8314..41cab3351122b 100644 --- a/tests/python_client/deploy/milvus_crd.yaml +++ b/tests/python_client/deploy/milvus_crd.yaml @@ -32,13 +32,11 @@ spec: image: milvusdb/milvus:2.2.0-20230208-2e4d64ec disableMetric: false dataNode: - replicas: 2 + replicas: 3 indexNode: - replicas: 2 + replicas: 3 queryNode: - replicas: 2 - mixCoord: - replicas: 1 + replicas: 3 dependencies: msgStreamType: kafka etcd: diff --git a/tests/python_client/deploy/scripts/utils.py b/tests/python_client/deploy/scripts/utils.py index a5bf21e53ad01..85ced42d94214 100644 --- a/tests/python_client/deploy/scripts/utils.py +++ b/tests/python_client/deploy/scripts/utils.py @@ -11,7 +11,7 @@ logger.add(sys.stderr, format= "{time:YYYY-MM-DD HH:mm:ss.SSS} | " "{level: <8} | " "{thread.name} |" - "{name}:{function}:{line} - {message}", + "{name}:{function}:{line} - {message}", level="INFO") pymilvus_version = pymilvus.__version__ @@ -95,7 +95,7 @@ def create_collections_and_insert_data(prefix, flush=True, count=3000, collectio for index_name in all_index_types[:collection_cnt]: logger.info("\nCreate collection...") col_name = prefix + index_name - collection = Collection(name=col_name, schema=default_schema) + collection = Collection(name=col_name, schema=default_schema) logger.info(f"collection name: {col_name}") logger.info(f"begin insert, count: {count} nb: {nb}") times = int(count // nb) @@ -118,7 +118,7 @@ def create_collections_and_insert_data(prefix, flush=True, count=3000, collectio collection.num_entities if j == times - 
3: collection.compact() - + logger.info(f"end insert, time: {total_time:.4f}") if flush: @@ -183,9 +183,17 @@ def create_index(prefix): index["params"] = index_params_map[index_name] if index_name in ["BIN_FLAT", "BIN_IVF_FLAT"]: index["metric_type"] = "HAMMING" - t0 = time.time() - c.create_index(field_name="float_vector", index_params=index) - logger.info(f"create index time: {time.time() - t0:.4f}") + index_info_list = [x.to_dict() for x in c.indexes] + logger.info(index_info_list) + is_indexed = False + for index_info in index_info_list: + if "metric_type" in index_info.keys() or "metric_type" in index_info["index_param"]: + is_indexed = True + logger.info(f"collection {col_name} has been indexed with {index_info}") + if not is_indexed: + t0 = time.time() + c.create_index(field_name="float_vector", index_params=index) + logger.info(f"create index time: {time.time() - t0:.4f}") if replica_number > 0: c.load(replica_number=replica_number) diff --git a/tests/python_client/deploy/standalone-values.yaml b/tests/python_client/deploy/standalone-values.yaml index 34c2075e52ae2..647713311ad5f 100644 --- a/tests/python_client/deploy/standalone-values.yaml +++ b/tests/python_client/deploy/standalone-values.yaml @@ -26,7 +26,6 @@ kafka: etcd: replicaCount: 3 image: - debug: true repository: milvusdb/etcd tag: 3.5.5-r2 minio: @@ -34,8 +33,11 @@ minio: pulsar: enabled: false -# extraConfigFiles: -# user.yaml: |+ -# dataNode: -# memory: -# forceSyncEnable: false \ No newline at end of file +extraConfigFiles: + user.yaml: |+ + dataCoord: + compaction: + indexBasedCompaction: false + indexCoord: + scheduler: + interval: 100 \ No newline at end of file diff --git a/tests/python_client/deploy/testcases/test_action_second_deployment.py b/tests/python_client/deploy/testcases/test_action_second_deployment.py index 75c705d1b9276..76d94b8912e27 100644 --- a/tests/python_client/deploy/testcases/test_action_second_deployment.py +++ 
b/tests/python_client/deploy/testcases/test_action_second_deployment.py @@ -133,7 +133,7 @@ def test_check(self, all_collection_name, data_size): is_vector_indexed = False index_infos = [index.to_dict() for index in collection_w.indexes] for index_info in index_infos: - if "metric_type" in index_info.keys(): + if "metric_type" in index_info.keys() or "metric_type" in index_info["index_param"]: is_vector_indexed = True break if is_vector_indexed is False: @@ -143,7 +143,7 @@ def test_check(self, all_collection_name, data_size): # search and query if "empty" in name: - # if the collection is empty, the search result should be empty, so no need to check + # if the collection is empty, the search result should be empty, so no need to check check_task = None else: check_task = CheckTasks.check_search_results diff --git a/tests/python_client/requirements.txt b/tests/python_client/requirements.txt index 598d9c1aa0a52..b24114dd006de 100644 --- a/tests/python_client/requirements.txt +++ b/tests/python_client/requirements.txt @@ -12,7 +12,7 @@ allure-pytest==2.7.0 pytest-print==0.2.1 pytest-level==0.1.1 pytest-xdist==2.5.0 -pymilvus==2.3.0b0.post1.dev127 +pymilvus==2.3.1.post1.dev8 pytest-rerunfailures==9.1.1 git+https://github.com/Projectplace/pytest-tags ndg-httpsclient @@ -44,10 +44,11 @@ loguru==0.7.0 # util psutil==5.9.4 pandas==1.5.3 +tenacity==8.1.0 # for standby test -etcd-sdk-python==0.0.2 +etcd-sdk-python==0.0.4 # for test result anaylszer prettytable==3.8.0 pyarrow==11.0.0 -fastparquet==2023.7.0 \ No newline at end of file +fastparquet==2023.7.0 diff --git a/tests/python_client/testcases/test_alias.py b/tests/python_client/testcases/test_alias.py index effb6433dda62..7a52e497601be 100644 --- a/tests/python_client/testcases/test_alias.py +++ b/tests/python_client/testcases/test_alias.py @@ -43,7 +43,7 @@ def test_alias_create_alias_with_invalid_name(self, alias_name): collection_w = self.init_collection_wrap(name=c_name, schema=default_schema, 
check_task=CheckTasks.check_collection_property, check_items={exp_name: c_name, exp_schema: default_schema}) - error = {ct.err_code: 1, ct.err_msg: "Invalid collection alias"} + error = {ct.err_code: 1100, ct.err_msg: "Invalid collection alias"} self.utility_wrap.create_alias(collection_w.name, alias_name, check_task=CheckTasks.err_res, check_items=error) @@ -424,8 +424,9 @@ def test_alias_create_duplication_alias(self): collection_2 = self.init_collection_wrap(name=c_2_name, schema=default_schema, check_task=CheckTasks.check_collection_property, check_items={exp_name: c_2_name, exp_schema: default_schema}) - error = {ct.err_code: 1, - ct.err_msg: "Create alias failed: duplicate collection alias"} + error = {ct.err_code: 1602, + ct.err_msg: f"alias exists and already aliased to another collection, alias: {alias_a_name}, " + f"collection: {c_1_name}, other collection: {c_2_name}"} self.utility_wrap.create_alias(collection_2.name, alias_a_name, check_task=CheckTasks.err_res, check_items=error) @@ -453,7 +454,7 @@ def test_alias_alter_not_exist_alias(self): # collection_w.create_alias(alias_name) alias_not_exist_name = cf.gen_unique_str(prefix) - error = {ct.err_code: 1, + error = {ct.err_code: 1600, ct.err_msg: "Alter alias failed: alias does not exist"} self.utility_wrap.alter_alias(collection_w.name, alias_not_exist_name, check_task=CheckTasks.err_res, diff --git a/tests/python_client/testcases/test_bulk_insert.py b/tests/python_client/testcases/test_bulk_insert.py index ba3155157edac..af5f8d092c166 100644 --- a/tests/python_client/testcases/test_bulk_insert.py +++ b/tests/python_client/testcases/test_bulk_insert.py @@ -102,7 +102,7 @@ def test_float_vector_only(self, is_row_based, auto_id, dim, entities): ) logging.info(f"bulk insert task id:{task_id}") success, _ = self.utility_wrap.wait_for_bulk_insert_tasks_completed( - task_ids=[task_id], timeout=90 + task_ids=[task_id], timeout=300 ) tt = time.time() - t0 log.info(f"bulk insert state:{success} in {tt}") @@ 
-118,7 +118,7 @@ def test_float_vector_only(self, is_row_based, auto_id, dim, entities): field_name=df.vec_field, index_params=index_params ) time.sleep(2) - self.utility_wrap.wait_for_index_building_complete(c_name, timeout=120) + self.utility_wrap.wait_for_index_building_complete(c_name, timeout=300) res, _ = self.utility_wrap.index_building_progress(c_name) log.info(f"index building progress: {res}") self.collection_wrap.load() @@ -187,7 +187,7 @@ def test_str_pk_float_vector_only(self, is_row_based, dim, entities): ) logging.info(f"bulk insert task ids:{task_id}") completed, _ = self.utility_wrap.wait_for_bulk_insert_tasks_completed( - task_ids=[task_id], timeout=90 + task_ids=[task_id], timeout=300 ) tt = time.time() - t0 log.info(f"bulk insert state:{completed} in {tt}") @@ -202,7 +202,7 @@ def test_str_pk_float_vector_only(self, is_row_based, dim, entities): self.collection_wrap.create_index( field_name=df.vec_field, index_params=index_params ) - self.utility_wrap.wait_for_index_building_complete(c_name, timeout=120) + self.utility_wrap.wait_for_index_building_complete(c_name, timeout=300) res, _ = self.utility_wrap.index_building_progress(c_name) log.info(f"index building progress: {res}") self.collection_wrap.load() @@ -289,7 +289,7 @@ def test_partition_float_vector_int_scalar( ) logging.info(f"bulk insert task ids:{task_id}") success, state = self.utility_wrap.wait_for_bulk_insert_tasks_completed( - task_ids=[task_id], timeout=90 + task_ids=[task_id], timeout=300 ) tt = time.time() - t0 log.info(f"bulk insert state:{success} in {tt}") @@ -299,7 +299,7 @@ def test_partition_float_vector_int_scalar( assert self.collection_wrap.num_entities == entities log.debug(state) time.sleep(2) - self.utility_wrap.wait_for_index_building_complete(c_name, timeout=120) + self.utility_wrap.wait_for_index_building_complete(c_name, timeout=300) res, _ = self.utility_wrap.index_building_progress(c_name) log.info(f"index building progress: {res}") log.info(f"wait for load 
finished and be ready for search") @@ -380,13 +380,13 @@ def test_binary_vector_only(self, is_row_based, auto_id, dim, entities): files=files) logging.info(f"bulk insert task ids:{task_id}") success, _ = self.utility_wrap.wait_for_bulk_insert_tasks_completed( - task_ids=[task_id], timeout=90 + task_ids=[task_id], timeout=300 ) tt = time.time() - t0 log.info(f"bulk insert state:{success} in {tt}") assert success time.sleep(2) - self.utility_wrap.wait_for_index_building_complete(c_name, timeout=120) + self.utility_wrap.wait_for_index_building_complete(c_name, timeout=300) res, _ = self.utility_wrap.index_building_progress(c_name) log.info(f"index building progress: {res}") @@ -469,7 +469,7 @@ def test_insert_before_or_after_bulk_insert(self, insert_before_bulk_insert): ) logging.info(f"bulk insert task ids:{task_id}") success, states = self.utility_wrap.wait_for_bulk_insert_tasks_completed( - task_ids=[task_id], timeout=90 + task_ids=[task_id], timeout=300 ) tt = time.time() - t0 log.info(f"bulk insert state:{success} in {tt}") @@ -484,7 +484,7 @@ def test_insert_before_or_after_bulk_insert(self, insert_before_bulk_insert): assert num_entities == bulk_insert_row + direct_insert_row # verify index time.sleep(2) - self.utility_wrap.wait_for_index_building_complete(c_name, timeout=120) + self.utility_wrap.wait_for_index_building_complete(c_name, timeout=300) res, _ = self.utility_wrap.index_building_progress(c_name) log.info(f"index building progress: {res}") # verify search and query @@ -560,7 +560,7 @@ def test_load_before_or_after_bulk_insert(self, loaded_before_bulk_insert, creat ) logging.info(f"bulk insert task ids:{task_id}") success, states = self.utility_wrap.wait_for_bulk_insert_tasks_completed( - task_ids=[task_id], timeout=90 + task_ids=[task_id], timeout=300 ) tt = time.time() - t0 log.info(f"bulk insert state:{success} in {tt}") @@ -573,7 +573,7 @@ def test_load_before_or_after_bulk_insert(self, loaded_before_bulk_insert, creat log.info(f"collection 
entities: {num_entities}") assert num_entities == 500 time.sleep(2) - self.utility_wrap.wait_for_index_building_complete(c_name, timeout=120) + self.utility_wrap.wait_for_index_building_complete(c_name, timeout=300) res, _ = self.utility_wrap.index_building_progress(c_name) log.info(f"index building progress: {res}") # verify search and query @@ -641,7 +641,7 @@ def test_with_all_field_numpy(self, auto_id, dim, entities): ) logging.info(f"bulk insert task ids:{task_id}") success, states = self.utility_wrap.wait_for_bulk_insert_tasks_completed( - task_ids=[task_id], timeout=90 + task_ids=[task_id], timeout=300 ) tt = time.time() - t0 log.info(f"bulk insert state:{success} in {tt} with states:{states}") @@ -673,7 +673,7 @@ def test_with_all_field_numpy(self, auto_id, dim, entities): @pytest.mark.parametrize("auto_id", [True, False]) @pytest.mark.parametrize("dim", [128]) @pytest.mark.parametrize("entities", [2000]) - @pytest.mark.parametrize("file_nums", [10]) + @pytest.mark.parametrize("file_nums", [5]) def test_multi_numpy_files_from_diff_folders( self, auto_id, dim, entities, file_nums ): @@ -720,7 +720,7 @@ def test_multi_numpy_files_from_diff_folders( ) task_ids.append(task_id) success, states = self.utility_wrap.wait_for_bulk_insert_tasks_completed( - task_ids=[task_id], timeout=90 + task_ids=[task_id], timeout=300 ) log.info(f"bulk insert state:{success}") @@ -742,3 +742,197 @@ def test_multi_numpy_files_from_diff_folders( check_task=CheckTasks.check_search_results, check_items={"nq": 1, "limit": 1}, ) + + @pytest.mark.tags(CaseLabel.L3) + @pytest.mark.parametrize("is_row_based", [True]) + @pytest.mark.parametrize("auto_id", [True, False]) + @pytest.mark.parametrize("par_key_field", [df.int_field, df.string_field]) + def test_partition_key_on_json_file(self, is_row_based, auto_id, par_key_field): + """ + collection: auto_id, customized_id + collection schema: [pk, int64, varchar, float_vector] + Steps: + 1. create collection with partition key enabled + 2. 
import data + 3. verify the data entities equal the import data and distributed by values of partition key field + 4. load the collection + 5. verify search successfully + 6. verify query successfully + """ + dim = 12 + entities = 200 + files = prepare_bulk_insert_json_files( + minio_endpoint=self.minio_endpoint, + bucket_name=self.bucket_name, + is_row_based=is_row_based, + rows=entities, + dim=dim, + auto_id=auto_id, + data_fields=default_multi_fields, + force=True, + ) + self._connect() + c_name = cf.gen_unique_str("bulk_parkey") + fields = [ + cf.gen_int64_field(name=df.pk_field, is_primary=True), + cf.gen_float_vec_field(name=df.vec_field, dim=dim), + cf.gen_int64_field(name=df.int_field, is_partition_key=(par_key_field == df.int_field)), + cf.gen_string_field(name=df.string_field, is_partition_key=(par_key_field == df.string_field)), + cf.gen_bool_field(name=df.bool_field), + cf.gen_float_field(name=df.float_field), + ] + schema = cf.gen_collection_schema(fields=fields, auto_id=auto_id) + self.collection_wrap.init_collection(c_name, schema=schema) + assert len(self.collection_wrap.partitions) == ct.default_partition_num + + # import data + t0 = time.time() + task_id, _ = self.utility_wrap.do_bulk_insert( + collection_name=c_name, + partition_name=None, + files=files, + ) + logging.info(f"bulk insert task id:{task_id}") + success, _ = self.utility_wrap.wait_for_bulk_insert_tasks_completed( + task_ids=[task_id], timeout=300 + ) + tt = time.time() - t0 + log.info(f"bulk insert state:{success} in {tt}") + assert success + + num_entities = self.collection_wrap.num_entities + log.info(f" collection entities: {num_entities}") + assert num_entities == entities + + # verify imported data is available for search + index_params = ct.default_index + self.collection_wrap.create_index( + field_name=df.vec_field, index_params=index_params + ) + self.collection_wrap.load() + log.info(f"wait for load finished and be ready for search") + time.sleep(10) + log.info( + f"query 
seg info: {self.utility_wrap.get_query_segment_info(c_name)[0]}" + ) + nq = 2 + topk = 2 + search_data = cf.gen_vectors(nq, dim) + search_params = ct.default_search_params + res, _ = self.collection_wrap.search( + search_data, + df.vec_field, + param=search_params, + limit=topk, + check_task=CheckTasks.check_search_results, + check_items={"nq": nq, "limit": topk}, + ) + for hits in res: + ids = hits.ids + results, _ = self.collection_wrap.query(expr=f"{df.pk_field} in {ids}") + assert len(results) == len(ids) + + # verify data was bulk inserted into different partitions + num_entities = 0 + empty_partition_num = 0 + for p in self.collection_wrap.partitions: + if p.num_entities == 0: + empty_partition_num += 1 + num_entities += p.num_entities + assert num_entities == entities + + # verify error when trying to bulk insert into a specific partition + # TODO: enable the error msg assert after issue #25586 fixed + err_msg = "not allow to set partition name for collection with partition key" + task_id, _ = self.utility_wrap.do_bulk_insert( + collection_name=c_name, + partition_name=self.collection_wrap.partitions[0].name, + files=files, + check_task=CheckTasks.err_res, + check_items={"err_code": 99, "err_msg": err_msg}, + ) + + @pytest.mark.tags(CaseLabel.L3) + @pytest.mark.parametrize("auto_id", [True, False]) + @pytest.mark.parametrize("dim", [13]) + @pytest.mark.parametrize("entities", [150]) + @pytest.mark.parametrize("file_nums", [10]) + def test_partition_key_on_multi_numpy_files( + self, auto_id, dim, entities, file_nums + ): + """ + collection schema 1: [pk, int64, float_vector, double] + data file: .npy files in different folders + Steps: + 1. create collection with partition key enabled, create index and load + 2. import data + 3. 
verify that import numpy files in a loop + """ + self._connect() + c_name = cf.gen_unique_str("bulk_ins_parkey") + fields = [ + cf.gen_int64_field(name=df.pk_field, is_primary=True), + cf.gen_int64_field(name=df.int_field, is_partition_key=True), + cf.gen_float_field(name=df.float_field), + cf.gen_double_field(name=df.double_field), + cf.gen_float_vec_field(name=df.vec_field, dim=dim), + ] + schema = cf.gen_collection_schema(fields=fields) + self.collection_wrap.init_collection(c_name, schema=schema) + # build index + index_params = ct.default_index + self.collection_wrap.create_index( + field_name=df.vec_field, index_params=index_params + ) + # load collection + self.collection_wrap.load() + data_fields = [f.name for f in fields if not f.to_dict().get("auto_id", False)] + task_ids = [] + for i in range(file_nums): + files = prepare_bulk_insert_numpy_files( + minio_endpoint=self.minio_endpoint, + bucket_name=self.bucket_name, + rows=entities, + dim=dim, + data_fields=data_fields, + file_nums=1, + force=True, + ) + task_id, _ = self.utility_wrap.do_bulk_insert( + collection_name=c_name, files=files + ) + task_ids.append(task_id) + success, states = self.utility_wrap.wait_for_bulk_insert_tasks_completed( + task_ids=task_ids, timeout=300 + ) + log.info(f"bulk insert state:{success}") + + assert success + log.info(f" collection entities: {self.collection_wrap.num_entities}") + assert self.collection_wrap.num_entities == entities * file_nums + # verify imported data is indexed + success = self.utility_wrap.wait_index_build_completed(c_name) + assert success + # verify search and query + log.info(f"wait for load finished and be ready for search") + self.collection_wrap.load(_refresh=True) + search_data = cf.gen_vectors(1, dim) + search_params = ct.default_search_params + res, _ = self.collection_wrap.search( + search_data, + df.vec_field, + param=search_params, + limit=1, + check_task=CheckTasks.check_search_results, + check_items={"nq": 1, "limit": 1}, + ) + + # verify 
data was bulk inserted into different partitions + num_entities = 0 + empty_partition_num = 0 + for p in self.collection_wrap.partitions: + if p.num_entities == 0: + empty_partition_num += 1 + num_entities += p.num_entities + assert num_entities == entities * file_nums + diff --git a/tests/python_client/testcases/test_collection.py b/tests/python_client/testcases/test_collection.py index 3b6a2313b6bc8..5c24ae847b281 100644 --- a/tests/python_client/testcases/test_collection.py +++ b/tests/python_client/testcases/test_collection.py @@ -114,7 +114,11 @@ def test_collection_illegal_name(self, name): expected: raise exception """ self._connect() - error = {ct.err_code: 1, ct.err_msg: "`collection_name` value {} is illegal".format(name)} + error1 = {ct.err_code: 1, ct.err_msg: "`collection_name` value {} is illegal".format(name)} + error2 = {ct.err_code: 1100, ct.err_msg: "Invalid collection name: 1ns_. the first character of a" + " collection name must be an underscore or letter: invalid" + " parameter".format(name)} + error = error1 if name not in ["1ns_", "qw$_o90"] else error2 self.collection_wrap.init_collection(name, schema=default_schema, check_task=CheckTasks.err_res, check_items=error) @@ -315,7 +319,7 @@ def test_collection_invalid_type_fields(self, get_invalid_type_fields): """ self._connect() fields = get_invalid_type_fields - error = {ct.err_code: 0, ct.err_msg: "The fields of schema must be type list"} + error = {ct.err_code: 1, ct.err_msg: "The fields of schema must be type list."} self.collection_schema_wrap.init_collection_schema(fields=fields, check_task=CheckTasks.err_res, check_items=error) @@ -344,7 +348,7 @@ def test_collection_invalid_type_field(self, name): field, _ = self.field_schema_wrap.init_field_schema(name=name, dtype=5, is_primary=True) vec_field = cf.gen_float_vec_field() schema = cf.gen_collection_schema(fields=[field, vec_field]) - error = {ct.err_code: 1, ct.err_msg: f"bad argument type for built-in"} + error = {ct.err_code: 1701, 
ct.err_msg: f"bad argument type for built-in"} self.collection_wrap.init_collection(c_name, schema=schema, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) @@ -374,7 +378,7 @@ def test_collection_none_field_name(self): c_name = cf.gen_unique_str(prefix) field, _ = self.field_schema_wrap.init_field_schema(name=None, dtype=DataType.INT64, is_primary=True) schema = cf.gen_collection_schema(fields=[field, cf.gen_float_vec_field()]) - error = {ct.err_code: 1, ct.err_msg: "You should specify the name of field"} + error = {ct.err_code: 1701, ct.err_msg: "field name should not be empty"} self.collection_wrap.init_collection(c_name, schema=schema, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) @@ -820,7 +824,7 @@ def test_collection_vector_invalid_dim(self, get_invalid_dim): c_name = cf.gen_unique_str(prefix) float_vec_field = cf.gen_float_vec_field(dim=get_invalid_dim) schema = cf.gen_collection_schema(fields=[cf.gen_int64_field(is_primary=True), float_vec_field]) - error = {ct.err_code: 1, ct.err_msg: f'invalid dim: {get_invalid_dim}'} + error = {ct.err_code: 65535, ct.err_msg: "strconv.ParseInt: parsing \"[]\": invalid syntax"} self.collection_wrap.init_collection(c_name, schema=schema, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) @@ -1068,7 +1072,7 @@ def test_collection_dup_name_drop(self): check_items={exp_name: c_name, exp_schema: default_schema}) self.collection_wrap.drop() assert not self.utility_wrap.has_collection(c_name)[0] - error = {ct.err_code: 1, ct.err_msg: f'HasPartition failed: collection not found: {c_name}'} + error = {ct.err_code: 4, ct.err_msg: 'collection not found'} collection_w.has_partition("p", check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L1) @@ -1097,7 +1101,9 @@ def test_collection_all_datatype_fields(self): self._connect() fields = [] for k, v in DataType.__members__.items(): - if v and v != DataType.UNKNOWN 
and v != DataType.STRING and v != DataType.VARCHAR and v != DataType.FLOAT_VECTOR and v != DataType.BINARY_VECTOR: + if v and v != DataType.UNKNOWN and v != DataType.STRING\ + and v != DataType.VARCHAR and v != DataType.FLOAT_VECTOR\ + and v != DataType.BINARY_VECTOR and v != DataType.ARRAY: field, _ = self.field_schema_wrap.init_field_schema(name=k.lower(), dtype=v) fields.append(field) fields.append(cf.gen_float_vec_field()) @@ -1738,7 +1744,7 @@ def test_create_collection_limit_fields(self): field_name_tmp = gen_unique_str("field_name") field_schema_temp = cf.gen_int64_field(field_name_tmp) field_schema_list.append(field_schema_temp) - error = {ct.err_code: 1, ct.err_msg: "'maximum field\'s number should be limited to 64'"} + error = {ct.err_code: 65535, ct.err_msg: "maximum field's number should be limited to 64"} schema, _ = self.collection_schema_wrap.init_collection_schema(fields=field_schema_list) self.init_collection_wrap(name=c_name, schema=schema, check_task=CheckTasks.err_res, check_items=error) @@ -2219,8 +2225,8 @@ def test_load_collection_not_existed(self): c_name = cf.gen_unique_str() collection_wr = self.init_collection_wrap(name=c_name) collection_wr.drop() - error = {ct.err_code: 1, - ct.err_msg: "DescribeCollection failed: can't find collection: %s" % c_name} + error = {ct.err_code: 100, + ct.err_msg: "collection= : collection not found"} collection_wr.load(check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) @@ -2234,8 +2240,8 @@ def test_release_collection_not_existed(self): c_name = cf.gen_unique_str() collection_wr = self.init_collection_wrap(name=c_name) collection_wr.drop() - error = {ct.err_code: 1, - ct.err_msg: "DescribeCollection failed: can't find collection: %s" % c_name} + error = {ct.err_code: 100, + ct.err_msg: "collection= : collection not found"} collection_wr.release(check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) @@ -2433,7 +2439,7 @@ def 
test_load_collection_after_release_partition_collection(self): self.init_partition_wrap(collection_w, partition2) collection_w.load() partition_w.release() - error = {ct.err_code: 1, ct.err_msg: 'not loaded into memory'} + error = {ct.err_code: 65538, ct.err_msg: 'partition not loaded'} collection_w.query(default_term_expr, partition_names=[partition1], check_task=CheckTasks.err_res, check_items=error) collection_w.release() @@ -2458,7 +2464,7 @@ def test_load_partitions_after_release_partition_collection(self): partition_w1.release() collection_w.release() partition_w1.load() - error = {ct.err_code: 1, ct.err_msg: 'not loaded into memory'} + error = {ct.err_code: 65538, ct.err_msg: 'partition not loaded'} collection_w.query(default_term_expr, partition_names=[partition2], check_task=CheckTasks.err_res, check_items=error) partition_w2.load() @@ -2516,7 +2522,7 @@ def test_load_collection_after_drop_partition_and_release_another(self): partition_w1.release() partition_w1.drop() partition_w2.release() - error = {ct.err_code: 1, ct.err_msg: 'not loaded into memory'} + error = {ct.err_code: 65538, ct.err_msg: 'partition not loaded'} collection_w.query(default_term_expr, partition_names=[partition2], check_task=CheckTasks.err_res, check_items=error) collection_w.load() @@ -2595,8 +2601,7 @@ def test_load_release_collection(self): collection_wr.load() collection_wr.release() collection_wr.drop() - error = {ct.err_code: 1, - ct.err_msg: "DescribeCollection failed: can't find collection: %s" % c_name} + error = {ct.err_code: 100, ct.err_msg: "collection not found"} collection_wr.load(check_task=CheckTasks.err_res, check_items=error) collection_wr.release(check_task=CheckTasks.err_res, check_items=error) @@ -2613,8 +2618,7 @@ def test_release_collection_after_drop(self): collection_wr.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_wr.load() collection_wr.drop() - error = {ct.err_code: 0, - ct.err_msg: "can't find collection"} + 
error = {ct.err_code: 100, ct.err_msg: "collection not found"} collection_wr.release(check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L1) @@ -2726,7 +2730,8 @@ def test_load_replica_greater_than_querynodes(self): assert collection_w.num_entities == ct.default_nb collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) - error = {ct.err_code: 1, ct.err_msg: f"no enough nodes to create replicas"} + error = {ct.err_code: 65535, + ct.err_msg: "failed to load collection: failed to spawn replica for collection: nodes not enough"} collection_w.load(replica_number=3, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.ClusterOnly) @@ -2756,7 +2761,8 @@ def test_load_replica_change(self): assert loading_progress == {'loading_progress': '100%'} # verify load different replicas thrown an exception - error = {ct.err_code: 5, ct.err_msg: f"Should release first then reload with the new number of replicas"} + error = {ct.err_code: 1100, ct.err_msg: "failed to load collection: can't change the replica number for " + "loaded collection: expected=1, actual=2: invalid parameter"} collection_w.load(replica_number=2, check_task=CheckTasks.err_res, check_items=error) one_replica, _ = collection_w.get_replicas() assert len(one_replica.groups) == 1 @@ -2842,7 +2848,7 @@ def test_load_replica_partitions(self): check_task=CheckTasks.check_query_results, check_items={'exp_res': df_2.iloc[:1, :1].to_dict('records')}) - error = {ct.err_code: 1, ct.err_msg: f"not loaded into memory"} + error = {ct.err_code: 65538, ct.err_msg: "partition not loaded"} collection_w.query(expr=f"{ct.default_int64_field_name} in [0]", partition_names=[ct.default_partition_name, ct.default_tag], check_task=CheckTasks.err_res, check_items=error) @@ -2994,8 +3000,9 @@ def test_get_collection_replicas_not_loaded(self): assert collection_w.num_entities == ct.default_nb collection_w.get_replicas(check_task=CheckTasks.err_res, - 
check_items={"err_code": 15, - "err_msg": "collection not found, maybe not loaded"}) + check_items={"err_code": 400, + "err_msg": "failed to get replicas by collection: " + "replica not found"}) @pytest.mark.tags(CaseLabel.L3) def test_count_multi_replicas(self): @@ -3030,7 +3037,7 @@ def test_load_collection_without_creating_index(self): collection_w = self.init_collection_general(prefix, True, is_index=False)[0] collection_w.load(check_task=CheckTasks.err_res, check_items={"err_code": 1, - "err_msg": "index not exist"}) + "err_msg": "index not found"}) class TestDescribeCollection(TestcaseBase): @@ -3041,7 +3048,6 @@ class TestDescribeCollection(TestcaseBase): """ @pytest.mark.tags(CaseLabel.L2) - @pytest.mark.xfail(reason="issue 24493") def test_collection_describe(self): """ target: test describe collection @@ -3052,24 +3058,19 @@ def test_collection_describe(self): c_name = cf.gen_unique_str(prefix) collection_w = self.init_collection_wrap(name=c_name) collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) - description = {'collection_name': c_name, 'auto_id': False, 'num_shards': ct.default_shards_num, 'description': '', - 'fields': [{'field_id': 100, 'name': 'int64', 'description': '', 'type': 5, - 'params': {}, 'is_primary': True, 'auto_id': False, - 'is_partition_key': False, 'default_value': None, 'is_dynamic': False}, - {'field_id': 101, 'name': 'float', 'description': '', 'type': 10, - 'params': {}, 'is_primary': False, 'auto_id': False, - 'is_partition_key': False, 'default_value': None, 'is_dynamic': False}, - {'field_id': 102, 'name': 'varchar', 'description': '', 'type': 21, - 'params': {'max_length': 65535}, 'is_primary': False, 'auto_id': False, - 'is_partition_key': False, 'default_value': None, 'is_dynamic': False}, - {'field_id': 103, 'name': 'json_field', 'description': '', 'type': 23, - 'params': {}, 'is_primary': False, 'auto_id': False, - 'is_partition_key': False, 'default_value': None, 'is_dynamic': 
False}, - {'field_id': 104, 'name': 'float_vector', 'description': '', 'type': 101, - 'params': {'dim': 128}, 'is_primary': False, 'auto_id': False, - 'is_partition_key': False, 'default_value': None, 'is_dynamic': False}], - 'aliases': [], 'consistency_level': 2, 'properties': [], 'num_partitions': 0, - 'enable_dynamic_field': False} + description = \ + {'collection_name': c_name, 'auto_id': False, 'num_shards': ct.default_shards_num, 'description': '', + 'fields': [{'field_id': 100, 'name': 'int64', 'description': '', 'type': 5, + 'params': {}, 'is_primary': True, 'element_type': 0,}, + {'field_id': 101, 'name': 'float', 'description': '', 'type': 10, 'params': {}, + 'element_type': 0,}, + {'field_id': 102, 'name': 'varchar', 'description': '', 'type': 21, + 'params': {'max_length': 65535}, 'element_type': 0,}, + {'field_id': 103, 'name': 'json_field', 'description': '', 'type': 23, 'params': {}, + 'element_type': 0,}, + {'field_id': 104, 'name': 'float_vector', 'description': '', 'type': 101, + 'params': {'dim': 128}, 'element_type': 0}], + 'aliases': [], 'consistency_level': 0, 'properties': [], 'num_partitions': 1} res = collection_w.describe()[0] del res['collection_id'] log.info(res) @@ -3098,7 +3099,7 @@ def test_release_collection_during_searching(self): search_res, _ = collection_wr.search(vectors, default_search_field, default_search_params, default_limit, _async=True) collection_wr.release() - error = {ct.err_code: 1, ct.err_msg: 'collection %s was not loaded into memory' % c_name} + error = {ct.err_code: 65535, ct.err_msg: "collection not loaded"} collection_wr.search(vectors, default_search_field, default_search_params, default_limit, check_task=CheckTasks.err_res, check_items=error) @@ -3125,8 +3126,8 @@ def test_release_partition_during_searching(self): default_search_params, limit, default_search_exp, [par_name], check_task=CheckTasks.err_res, - check_items={"err_code": 1, - "err_msg": "partition has been released"}) + check_items={"err_code": 
65535, + "err_msg": "collection not loaded"}) @pytest.mark.tags(CaseLabel.L0) def test_release_indexed_collection_during_searching(self): @@ -3147,7 +3148,7 @@ def test_release_indexed_collection_during_searching(self): default_search_params, limit, default_search_exp, [par_name], _async=True) collection_w.release() - error = {ct.err_code: 1, ct.err_msg: 'collection %s was not loaded into memory' % collection_w.name} + error = {ct.err_code: 65535, ct.err_msg: "collection not loaded"} collection_w.search(vectors, default_search_field, default_search_params, limit, default_search_exp, [par_name], @@ -3197,8 +3198,8 @@ def test_load_partition_after_index_binary(self, binary_index, metric_type): # for metric_type in ct.binary_metrics: binary_index["metric_type"] = metric_type if binary_index["index_type"] == "BIN_IVF_FLAT" and metric_type in ct.structure_metrics: - error = {ct.err_code: 1, ct.err_msg: 'Invalid metric_type: SUBSTRUCTURE, ' - 'which does not match the index type: BIN_IVF_FLAT'} + error = {ct.err_code: 65535, + ct.err_msg: "metric type not found or not supported, supported: [HAMMING JACCARD]"} collection_w.create_index(ct.default_binary_vec_field_name, binary_index, check_task=CheckTasks.err_res, check_items=error) collection_w.create_index(ct.default_binary_vec_field_name, ct.default_bin_flat_index) @@ -3379,7 +3380,8 @@ def test_load_collection_after_load_loaded_partition(self): partition_w2 = self.init_partition_wrap(collection_w, partition2) partition_w1.load() partition_w1.load() - error = {ct.err_code: 1, ct.err_msg: 'not loaded into memory'} + error = {ct.err_code: 65538, + ct.err_msg: 'partition not loaded'} collection_w.query(default_term_expr, partition_names=[partition2], check_task=CheckTasks.err_res, check_items=error) collection_w.load() @@ -3434,7 +3436,7 @@ def test_load_partitions_release_collection(self): partition_w2 = self.init_partition_wrap(collection_w, partition2) partition_w1.load() collection_w.release() - error = {ct.err_code: 
1, ct.err_msg: 'not loaded into memory'} + error = {ct.err_code: 65535, ct.err_msg: "collection not loaded"} collection_w.query(default_term_expr, partition_names=[partition1], check_task=CheckTasks.err_res, check_items=error) partition_w1.load() @@ -3473,7 +3475,8 @@ def test_load_partitions_after_load_release_partition(self): partition_w2 = self.init_partition_wrap(collection_w, partition2) partition_w1.load() partition_w1.release() - error = {ct.err_code: 1, ct.err_msg: 'not loaded into memory'} + error = {ct.err_code: 65535, + ct.err_msg: 'collection not loaded'} collection_w.query(default_term_expr, partition_names=[partition1], check_task=CheckTasks.err_res, check_items=error) partition_w1.load() @@ -3536,7 +3539,8 @@ def test_load_collection_after_load_partition_release_partitions(self): partition_w1.load() partition_w1.release() partition_w2.release() - error = {ct.err_code: 1, ct.err_msg: 'not loaded into memory'} + error = {ct.err_code: 65535, + ct.err_msg: 'collection not loaded'} collection_w.query(default_term_expr, partition_names=[partition1, partition2], check_task=CheckTasks.err_res, check_items=error) collection_w.load() @@ -3579,7 +3583,7 @@ def test_load_collection_after_load_drop_partition(self): partition_w1.load() partition_w1.release() partition_w1.drop() - error = {ct.err_code: 1, ct.err_msg: 'name not found'} + error = {ct.err_code: 65535, ct.err_msg: f'partition name {partition1} not found'} collection_w.query(default_term_expr, partition_names=[partition1, partition2], check_task=CheckTasks.err_res, check_items=error) partition_w2.drop() @@ -3662,7 +3666,8 @@ def test_release_load_partition_after_load_partition_drop_another(self): partition_w1.load() partition_w2.drop() partition_w1.release() - error = {ct.err_code: 1, ct.err_msg: 'not loaded into memory'} + error = {ct.err_code: 65535, + ct.err_msg: 'collection not loaded'} collection_w.query(default_term_expr, partition_names=[partition1], check_task=CheckTasks.err_res, 
check_items=error) partition_w1.load() @@ -3775,7 +3780,7 @@ def test_collection_string_field_with_exceed_max_len(self): max_length = 100000 string_field = cf.gen_string_field(max_length=max_length) schema = cf.gen_collection_schema([int_field, string_field, vec_field]) - error = {ct.err_code: 1, ct.err_msg: "invalid max_length: %s" % max_length} + error = {ct.err_code: 65535, ct.err_msg: "the maximum length specified for a VarChar should be in (0, 65535]"} self.collection_wrap.init_collection(name=c_name, schema=schema, check_task=CheckTasks.err_res, check_items=error) diff --git a/tests/python_client/testcases/test_compaction.py b/tests/python_client/testcases/test_compaction.py index 5ab8909c349d6..88ef2debb0259 100644 --- a/tests/python_client/testcases/test_compaction.py +++ b/tests/python_client/testcases/test_compaction.py @@ -1152,7 +1152,7 @@ def test_compact_during_insert(self): is_dup=False) collection_w.create_index(ct.default_float_vec_field_name, ct.default_index) log.debug(collection_w.index()) - df = cf.gen_default_dataframe_data() + df = cf.gen_default_dataframe_data(start=ct.default_nb*2) def do_flush(): collection_w.insert(df) diff --git a/tests/python_client/testcases/test_concurrent.py b/tests/python_client/testcases/test_concurrent.py new file mode 100644 index 0000000000000..50da0f5b06c30 --- /dev/null +++ b/tests/python_client/testcases/test_concurrent.py @@ -0,0 +1,100 @@ +import time +import pytest +import json +from time import sleep +from pymilvus import connections +from chaos.checker import (InsertChecker, + SearchChecker, + QueryChecker, + DeleteChecker, + Op, + ResultAnalyzer + ) +from utils.util_log import test_log as log +from chaos import chaos_commons as cc +from common import common_func as cf +from chaos.chaos_commons import assert_statistic +from common.common_type import CaseLabel +from chaos import constants +from delayed_assert import assert_expectations + + +def get_all_collections(): + try: + with 
open("/tmp/ci_logs/all_collections.json", "r") as f: + data = json.load(f) + all_collections = data["all"] + except Exception as e: + log.error(f"get_all_collections error: {e}") + return [None] + return all_collections + + +class TestBase: + expect_create = constants.SUCC + expect_insert = constants.SUCC + expect_flush = constants.SUCC + expect_compact = constants.SUCC + expect_search = constants.SUCC + expect_query = constants.SUCC + host = '127.0.0.1' + port = 19530 + _chaos_config = None + health_checkers = {} + + +class TestOperations(TestBase): + + @pytest.fixture(scope="function", autouse=True) + def connection(self, host, port, user, password, milvus_ns): + if user and password: + connections.connect('default', host=host, port=port, user=user, password=password, secure=True) + else: + connections.connect('default', host=host, port=port) + if connections.has_connection("default") is False: + raise Exception("no connections") + log.info("connect to milvus successfully") + self.host = host + self.port = port + self.user = user + self.password = password + + def init_health_checkers(self, collection_name=None): + c_name = collection_name + checkers = { + Op.insert: InsertChecker(collection_name=c_name), + Op.search: SearchChecker(collection_name=c_name), + Op.query: QueryChecker(collection_name=c_name), + Op.delete: DeleteChecker(collection_name=c_name), + } + self.health_checkers = checkers + + @pytest.fixture(scope="function", params=get_all_collections()) + def collection_name(self, request): + if request.param == [] or request.param == "": + pytest.skip("The collection name is invalid") + yield request.param + + @pytest.mark.tags(CaseLabel.L3) + def test_operations(self, request_duration, collection_name): + # start the monitor threads to check the milvus ops + log.info("*********************Test Start**********************") + log.info(connections.get_connection_addr('default')) + c_name = collection_name if collection_name else 
cf.gen_unique_str("Checker_") + self.init_health_checkers(collection_name=c_name) + cc.start_monitor_threads(self.health_checkers) + log.info("*********************Load Start**********************") + request_duration = request_duration.replace("h", "*3600+").replace("m", "*60+").replace("s", "") + if request_duration[-1] == "+": + request_duration = request_duration[:-1] + request_duration = eval(request_duration) + for i in range(10): + sleep(request_duration//10) + for k, v in self.health_checkers.items(): + v.check_result() + time.sleep(60) + ra = ResultAnalyzer() + ra.get_stage_success_rate() + assert_statistic(self.health_checkers) + assert_expectations() + log.info("*********************Test Completed**********************") diff --git a/tests/python_client/testcases/test_database.py b/tests/python_client/testcases/test_database.py index 4fb0887dce557..61ae66d3e7930 100644 --- a/tests/python_client/testcases/test_database.py +++ b/tests/python_client/testcases/test_database.py @@ -523,6 +523,95 @@ def test_using_db_not_existed(self): colls, _ = self.utility_wrap.list_collections() assert collection_w.name in colls + def test_create_same_collection_name_different_db(self): + """ + target: test create same collection name in different db + method: 1. create a db and create 1 collection in db + 2. 
create the collection in another db + expected: exception + """ + # check default db is empty + self._connect() + assert self.utility_wrap.list_collections()[0] == [] + + # create a collection in default db + c_name = "collection_same" + self.init_collection_wrap(c_name) + assert self.utility_wrap.list_collections()[0] == [c_name] + + # create a new database + db_name = cf.gen_unique_str("db") + self.database_wrap.create_database(db_name) + self.database_wrap.using_database(db_name) + + # create a collection in new db using same name + self.init_collection_wrap(c_name) + assert self.utility_wrap.list_collections()[0] == [c_name] + + def test_rename_existed_collection_name_new_db(self): + """ + target: test create same collection name in different db + method: 1. create a db and create 1 collection in db + 2. create the collection in another db + expected: exception + """ + # check default db is empty + self._connect() + assert self.utility_wrap.list_collections()[0] == [] + + # create a collection in default db + c_name1 = "collection_1" + self.init_collection_wrap(c_name1) + assert self.utility_wrap.list_collections()[0] == [c_name1] + + # create a new database + db_name = cf.gen_unique_str("db") + self.database_wrap.create_database(db_name) + self.database_wrap.using_database(db_name) + + # create a collection in new db + c_name2 = "collection_2" + self.init_collection_wrap(c_name2) + assert self.utility_wrap.list_collections()[0] == [c_name2] + + # rename the collection and move it to default db + error = {ct.err_code: 65535, ct.err_msg: "duplicated new collection name default:collection_1 " + "with other collection name or alias"} + self.utility_wrap.rename_collection(c_name2, c_name1, "default", + check_task=CheckTasks.err_res, check_items=error) + + def test_rename_collection_in_new_db(self): + """ + target: test rename collection in new created db + method: 1. create a db and create 1 collection in db + 2. 
rename the collection + expected: exception + """ + self._connect() + # check default db is empty + assert self.utility_wrap.list_collections()[0] == [] + + # create a new database + db_name = cf.gen_unique_str("db") + self.database_wrap.create_database(db_name) + self.database_wrap.using_database(db_name) + + # create 1 collection in new db + old_name = "old_collection" + self.init_collection_wrap(old_name) + assert self.utility_wrap.list_collections()[0] == [old_name] + + # rename the collection + new_name = "new_collection" + self.utility_wrap.rename_collection(old_name, new_name) + + # check the collection still in new db + assert self.utility_wrap.list_collections()[0] == [new_name] + + # check the collection not in default db + self.database_wrap.using_database("default") + assert self.utility_wrap.list_collections()[0] == [] + @pytest.mark.tags(CaseLabel.RBAC) class TestDatabaseOtherApi(TestcaseBase): diff --git a/tests/python_client/testcases/test_delete.py b/tests/python_client/testcases/test_delete.py index 15c826e3d1b1b..7bfaab9abbf47 100644 --- a/tests/python_client/testcases/test_delete.py +++ b/tests/python_client/testcases/test_delete.py @@ -1,5 +1,6 @@ +import random import time - +import pandas as pd import pytest from base.client_base import TestcaseBase @@ -17,9 +18,9 @@ query_tmp_expr_str = [{f'{ct.default_string_field_name}': "0"}] exp_res = "exp_res" default_string_expr = "varchar in [ \"0\"]" -default_invaild_string_exp = "varchar >= 0" +default_invalid_string_exp = "varchar >= 0" index_name1 = cf.gen_unique_str("float") -index_name2 = cf.gen_unique_str("varhar") +index_name2 = cf.gen_unique_str("varchar") default_search_params = ct.default_search_params @@ -55,6 +56,30 @@ def test_delete_entities(self, is_binary): # query with deleted ids collection_w.query(expr, check_task=CheckTasks.check_query_empty) + @pytest.mark.tags(CaseLabel.L0) + @pytest.mark.parametrize('is_binary', [False, True]) + def test_delete_entities_with_range(self, 
is_binary): + """ + target: test delete data from collection + method: 1.create and insert nb with flush + 2.load collection + 3.delete half of nb + 4.query with deleted ids + expected: Query result is empty + """ + # init collection with default_nb default data + collection_w, _, _, ids = self.init_collection_general(prefix, insert_data=True, auto_id=True, is_binary=is_binary)[0:4] + expr = f'{ct.default_int64_field_name} < {ids[half_nb]}' + + # delete half of data + del_res = collection_w.delete(expr)[0] + assert del_res.delete_count == half_nb + # This flush will not persist the deleted ids, just delay the time to ensure that queryNode consumes deleteMsg + collection_w.num_entities + + # query with deleted ids + collection_w.query(expr, check_task=CheckTasks.check_query_empty) + @pytest.mark.tags(CaseLabel.L2) def test_delete_without_connection(self): """ @@ -159,32 +184,19 @@ def test_delete_expr_all_values(self): collection_w.query(expr, check_task=CheckTasks.check_query_empty) @pytest.mark.tags(CaseLabel.L1) - def test_delete_expr_non_primary_key(self): + def test_delete_expr_with_vector(self): """ - target: test delete with non-pk field - method: delete with expr field not pk + target: test delete with vector field + method: delete with expr vector field expected: raise exception """ - collection_w = self.init_collection_general(prefix, nb=tmp_nb, insert_data=True, is_all_data_type=True, is_index=True)[0] - exprs = [ - f"{ct.default_int32_field_name} in [1]", - f"{ct.default_int16_field_name} in [1]", - f"{ct.default_int8_field_name} in [1]", - f"{ct.default_bool_field_name} in [True]", - f"{ct.default_float_field_name} in [1.0]", - f"{ct.default_double_field_name} in [1.0]", - f"{ct.default_string_field_name} in [ \"0\"]", - f"{ct.default_float_vec_field_name} in [[0.1]]" - ] + collection_w = self.init_collection_general(prefix, nb=tmp_nb, insert_data=True, + is_all_data_type=True, is_index=True)[0] + expr = f"{ct.default_float_vec_field_name} in [[0.1]]" 
error = {ct.err_code: 1, - ct.err_msg: f"invalid expression, we only support to delete by pk"} - for expr in exprs: - collection_w.delete(expr, check_task=CheckTasks.err_res, check_items=error) + ct.err_msg: f"failed to create expr plan, expr = {expr}"} - # query - _query_res_tmp_expr = [{ct.default_int64_field_name: 0, ct.default_string_field_name: '0'}] - collection_w.query(tmp_expr, output_fields=[ct.default_string_field_name], check_task=CheckTasks.check_query_results, - check_items={exp_res: _query_res_tmp_expr}) + collection_w.delete(expr, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_delete_not_existed_values(self): @@ -198,7 +210,7 @@ def test_delete_not_existed_values(self): # No exception expr = f'{ct.default_int64_field_name} in {[tmp_nb]}' - collection_w.delete(expr=expr)[0] + collection_w.delete(expr=expr) collection_w.query(tmp_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: query_res_tmp_expr}) @@ -212,7 +224,7 @@ def test_delete_part_not_existed_values(self): # init collection with tmp_nb default data collection_w = self.init_collection_general(prefix, nb=tmp_nb, insert_data=True)[0] expr = f'{ct.default_int64_field_name} in {[0, tmp_nb]}' - collection_w.delete(expr=expr)[0] + collection_w.delete(expr=expr) collection_w.query(expr, check_task=CheckTasks.check_query_empty) @pytest.mark.tags(CaseLabel.L2) @@ -571,8 +583,8 @@ def test_delete_not_existed_partition(self): collection_w = self.init_collection_general(prefix, nb=tmp_nb, insert_data=True)[0] # raise exception - error = {ct.err_code: 1, - ct.err_msg: f"partitionID of partitionName:{ct.default_tag} can not be find"} + error = {ct.err_code: 200, + ct.err_msg: f"Failed to get partition id: partition={ct.default_tag}: partition not found"} collection_w.delete(tmp_expr, partition_name=ct.default_tag, check_task=CheckTasks.err_res, check_items=error) @@ -666,8 +678,7 @@ def test_delete_query_without_loading(self): assert 
res.delete_count == 1 # query without loading and raise exception - error = {ct.err_code: 1, - ct.err_msg: f"collection {collection_w.name} was not loaded into memory"} + error = {ct.err_code: 65535, ct.err_msg: "collection not loaded"} collection_w.query(expr=tmp_expr, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L1) @@ -952,7 +963,6 @@ def test_delete_flush_loop(self): expr = f'{ct.default_int64_field_name} in {ids[i * batch: (i + 1) * batch]}' res, _ = collection_w.delete(expr) assert res.delete_count == batch - assert collection_w.num_entities == tmp_nb # query with all ids expr = f'{ct.default_int64_field_name} in {ids}' @@ -1140,6 +1150,31 @@ def test_delete_sealed_only(self): collection_w.query(expr, check_task=CheckTasks.check_query_empty) + @pytest.mark.tags(CaseLabel.L2) + @pytest.mark.skip("issue #26820") + @pytest.mark.parametrize("consistency_level", ["Bounded", "Session", "Eventually"]) + def test_delete_flush_query_consistency_not_strong(self, consistency_level): + """ + target: test delete, flush and query with Consistency level not strong + method: 1.delete ids + 2.flush + 3.query with Consistency level not strong + expected: query successfully + """ + # init collection + collection_w = self.init_collection_general(prefix, True)[0] + + # delete and flush + delete_ids = [i for i in range(ct.default_nb // 2)] + delete_expr = f"{ct.default_int64_field_name} in {delete_ids}" + res = collection_w.delete(delete_expr)[0] + assert res.delete_count == ct.default_nb // 2 + collection_w.flush() + + # query with Consistency level not strong + collection_w.query(expr=delete_expr, consistency_level=consistency_level, + check_task=CheckTasks.check_query_empty) + class TestDeleteString(TestcaseBase): """ @@ -1558,7 +1593,6 @@ def test_delete_flush_loop_with_string(self): expr = expr.replace("'", "\"") res, _ = collection_w.delete(expr) assert res.delete_count == batch - assert collection_w.num_entities == tmp_nb # query with all 
ids expr = f'{ct.default_string_field_name} in {ids}' @@ -1729,8 +1763,8 @@ def test_delete_invalid_expr(self): self.init_collection_general(prefix, nb=tmp_nb, insert_data=True, primary_field=ct.default_string_field_name)[0] collection_w.load() error = {ct.err_code: 0, - ct.err_msg: f"failed to create expr plan, expr = {default_invaild_string_exp}"} - collection_w.delete(expr=default_invaild_string_exp, + ct.err_msg: f"failed to create expr plan, expr = {default_invalid_string_exp}"} + collection_w.delete(expr=default_invalid_string_exp, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L1) @@ -1804,6 +1838,7 @@ def test_delete_with_string_field_is_empty(self): df[2] = [""for _ in range(nb)] collection_w.insert(df) + collection_w.flush() collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() assert collection_w.num_entities == nb @@ -1816,3 +1851,366 @@ def test_delete_with_string_field_is_empty(self): # load and query with id collection_w.load() collection_w.query(string_expr, check_task=CheckTasks.check_query_empty) + + +class TestDeleteComplexExpr(TestcaseBase): + """ + Test case of delete interface with complex expr + """ + + @pytest.mark.tags(CaseLabel.L0) + @pytest.mark.parametrize("expression", cf.gen_normal_expressions()[1:]) + @pytest.mark.parametrize("enable_dynamic_field", [True, False]) + def test_delete_normal_expressions(self, expression, enable_dynamic_field): + """ + target: test delete entities using normal expression + method: delete using normal expression + expected: delete successfully + """ + # init collection with nb default data + collection_w, _vectors, _, insert_ids = \ + self.init_collection_general(prefix, True, enable_dynamic_field=enable_dynamic_field)[0:4] + + # filter result with expression in collection + _vectors = _vectors[0] + expression = expression.replace("&&", "and").replace("||", "or") + filter_ids = [] + for i, _id in 
enumerate(insert_ids): + if enable_dynamic_field: + int64 = _vectors[i][ct.default_int64_field_name] + float = _vectors[i][ct.default_float_field_name] + else: + int64 = _vectors.int64[i] + float = _vectors.float[i] + if not expression or eval(expression): + filter_ids.append(_id) + + # delete with expressions + res = collection_w.delete(expression)[0] + assert res.delete_count == len(filter_ids) + + # query to check + collection_w.query(f"int64 in {filter_ids}", check_task=CheckTasks.check_query_empty) + + @pytest.mark.tags(CaseLabel.L1) + @pytest.mark.parametrize("field_name", ["varchar", "json_field['string']", "NewStr"]) + @pytest.mark.parametrize("like", ["like", "LIKE"]) + @pytest.mark.parametrize("enable_dynamic_field", [True, False]) + def test_delete_string_expressions_like(self, field_name, like, enable_dynamic_field): + """ + target: test delete expr like + method: delete using expression like + expected: delete successfully + """ + if field_name == "NewStr" and enable_dynamic_field is False: + pytest.skip("only support when enable_dynamic_filed == True") + # init collection with nb default data + nb = 1000 + collection_w, _vectors, _, insert_ids = \ + self.init_collection_general(prefix, False, enable_dynamic_field=enable_dynamic_field)[0:4] + + # insert + string_list = [cf.gen_str_by_length() for _ in range(nb)] + if enable_dynamic_field: + data = cf.gen_default_rows_data(nb) + for i in range(nb): + data[i][ct.default_json_field_name] = {"string": string_list[i]} + data[i]['NewStr'] = string_list[i] + data[i][ct.default_string_field_name] = string_list[i] + else: + data = cf.gen_default_dataframe_data(nb) + data[ct.default_json_field_name] = [{"string": string_list[i]} for i in range(nb)] + data[ct.default_string_field_name] = string_list + collection_w.insert(data) + collection_w.flush() + collection_w.load() + + # delete with expressions + deleted_str = [s for s in string_list if s.startswith('a')] + expression = f"{field_name} {like} 'a%'" + res = 
collection_w.delete(expression)[0] + assert res.delete_count == len(deleted_str) + + # query to check + collection_w.load() + collection_w.query("int64 >= 0", output_fields=['count(*)'], + check_task=CheckTasks.check_query_results, + check_items={'count(*)': nb - len(deleted_str)}) + + @pytest.mark.tags(CaseLabel.L2) + def test_delete_expr_empty_string(self): + """ + target: test delete with expr empty + method: delete with expr="" + expected: raise exception + """ + # init collection with nb default data + collection_w = self.init_collection_general(prefix, True)[0] + + # delete + error = {ct.err_code: 1, ct.err_msg: "expr cannot be empty"} + collection_w.delete(expr="", check_task=CheckTasks.err_res, check_items=error) + + @pytest.mark.tags(CaseLabel.L2) + def test_delete_complex_expr_before_load(self): + """ + target: test delete before load + method: delete with any complex expr + expected: raise exception + """ + # init collection with nb default data + collection_w = self.init_collection_general(prefix, False)[0] + + # delete + error = {ct.err_code: 1, ct.err_msg: "collection not loaded: unrecoverable error"} + collection_w.delete(expr="int64 >= 0", check_task=CheckTasks.err_res, check_items=error) + + @pytest.mark.tags(CaseLabel.L1) + @pytest.mark.parametrize("expr_prefix", ["json_contains", "JSON_CONTAINS"]) + @pytest.mark.parametrize("field_name", ["json_field['list']", "list"]) + @pytest.mark.parametrize("enable_dynamic_field", [True, False]) + def test_delete_expr_json_contains_base(self, expr_prefix, field_name, enable_dynamic_field): + """ + target: test delete expr using json_contains + method: delete using expression using json_contains + expected: delete successfully + """ + if field_name == "list" and enable_dynamic_field is False: + pytest.skip("only support when enable_dynamic_filed == True") + # init collection with nb default data + collection_w = self.init_collection_general(prefix, False, enable_dynamic_field=enable_dynamic_field)[0] + + # 
insert + listMix = [[i, i + 2] for i in range(ct.default_nb)] # only int + if enable_dynamic_field: + data = cf.gen_default_rows_data() + for i in range(ct.default_nb): + data[i][ct.default_json_field_name] = {"list": listMix[i]} + data[i]['list'] = listMix[i] + else: + data = cf.gen_default_dataframe_data() + data[ct.default_json_field_name] = [{"list": listMix[i]} for i in range(ct.default_nb)] + collection_w.insert(data) + collection_w.flush() + collection_w.load() + + # delete with expressions + delete_ids = random.randint(2, ct.default_nb - 2) + expression = f"{expr_prefix}({field_name}, {delete_ids})" + res = collection_w.delete(expression)[0] + exp_ids = cf.assert_json_contains(expression, listMix) + assert res.delete_count == len(exp_ids) + + # query to check + collection_w.query(expression, check_task=CheckTasks.check_query_empty) + + @pytest.mark.tags(CaseLabel.L2) + @pytest.mark.parametrize("expr_prefix", ["json_contains_all", "JSON_CONTAINS_ALL", + "json_contains_any", "JSON_CONTAINS_ANY"]) + @pytest.mark.parametrize("field_name", ["json_field['list']", "list"]) + @pytest.mark.parametrize("enable_dynamic_field", [True, False]) + def test_delete_expr_json_contains_all(self, expr_prefix, field_name, enable_dynamic_field): + """ + target: test delete expr using json_contains + method: delete using expression using json_contains + expected: delete successfully + """ + if field_name == "list" and enable_dynamic_field is False: + pytest.skip("only support when enable_dynamic_filed == True") + # init collection with nb default data + collection_w = self.init_collection_general(prefix, False, enable_dynamic_field=enable_dynamic_field)[0] + + # insert + listMix = [[i, i * 0.00001, bool(i % 2), [i, str(i)]] for i in range(ct.default_nb)] # mix int, float, list, bool + if enable_dynamic_field: + data = cf.gen_default_rows_data() + for i in range(ct.default_nb): + data[i][ct.default_json_field_name] = {"list": listMix[i]} + data[i]['list'] = listMix[i] + else: + 
data = cf.gen_default_dataframe_data() + data[ct.default_json_field_name] = [{"list": listMix[i]} for i in range(ct.default_nb)] + collection_w.insert(data) + collection_w.flush() + collection_w.load() + + # delete with expressions + ids = random.randint(0, ct.default_nb) + delete_ids = [bool(ids % 2), ids] + expression = f"{expr_prefix}({field_name}, {delete_ids})" + res = collection_w.delete(expression)[0] + exp_ids = cf.assert_json_contains(expression, listMix) + assert res.delete_count == len(exp_ids) + + # query to check + collection_w.query(expression, check_task=CheckTasks.check_query_empty) + + @pytest.mark.tags(CaseLabel.L2) + @pytest.mark.parametrize("expressions", + cf.gen_field_compare_expressions(["int64_1", "json_field['int'][0]"], + ["int64_2", "json_field['int'][1]"])) + def test_delete_expr_compare_two_variables(self, expressions): + """ + target: test delete expr using 2 variables + method: delete with expressions using compare 2 variables + expected: delete successfully + """ + # init collection with nb default data + nb = 1000 + dim = 32 + fields = [cf.gen_int64_field("int64_1"), cf.gen_int64_field("int64_2"), + cf.gen_json_field("json_field"), cf.gen_float_vec_field("float_vector", dim=dim)] + schema = cf.gen_collection_schema(fields=fields, primary_field="int64_1") + collection_w = self.init_collection_wrap(schema=schema) + + # insert + int64_1_values = [i for i in range(nb)] + int64_2_values = [random.randint(0, nb) for _ in range(nb)] + vectors = cf.gen_vectors(nb, dim) + json_values = [[i, int64_2_values[i]] for i in range(nb)] + data = pd.DataFrame({ + "int64_1": int64_1_values, + "int64_2": int64_2_values, + "json_field": [{"int": json_values[i]} for i in range(nb)], + "float_vector": vectors + }) + collection_w.insert(data) + collection_w.flush() + collection_w.create_index("float_vector") + collection_w.load() + + # delete with expressions + error = {ct.err_code: 1, ct.err_msg: f"failed to create expr plan, expr = {expressions}"} + 
collection_w.delete(expressions, check_task=CheckTasks.err_res, check_items=error) + + @pytest.mark.tags(CaseLabel.L2) + @pytest.mark.parametrize("expression", cf.gen_json_field_expressions()) + @pytest.mark.parametrize("enable_dynamic_field", [True, False]) + def test_delete_expr_json_field(self, expression, enable_dynamic_field): + """ + target: test delete entities using normal expression + method: delete using normal expression + expected: delete successfully + """ + # init collection with nb default data + collection_w, _vectors, _, insert_ids = \ + self.init_collection_general(prefix, True, enable_dynamic_field=enable_dynamic_field)[0:4] + + # filter result with expression in collection + _vectors = _vectors[0] + expression = expression.replace("&&", "and").replace("||", "or") + filter_ids = [] + json_field = {} + for i, _id in enumerate(insert_ids): + if enable_dynamic_field: + json_field['number'] = _vectors[i][ct.default_json_field_name]['number'] + json_field['float'] = _vectors[i][ct.default_json_field_name]['float'] + else: + json_field['number'] = _vectors[ct.default_json_field_name][i]['number'] + json_field['float'] = _vectors[ct.default_json_field_name][i]['float'] + if not expression or eval(expression): + filter_ids.append(_id) + + # delete with expressions + res = collection_w.delete(expression)[0] + assert res.delete_count == len(filter_ids) + + # query to check + collection_w.query(f"int64 in {filter_ids}", check_task=CheckTasks.check_query_empty) + + @pytest.mark.tags(CaseLabel.L2) + @pytest.mark.parametrize("normal_expression, json_expression", zip(cf.gen_normal_expressions()[1:4], + cf.gen_json_field_expressions()[6:9])) + @pytest.mark.parametrize("enable_dynamic_field", [True, False]) + def test_delete_expr_complex_mixed(self, normal_expression, json_expression, enable_dynamic_field): + """ + target: test delete entities using normal expression + method: delete using normal expression + expected: delete successfully + """ + # init 
collection with nb default data + collection_w, _vectors, _, insert_ids = \ + self.init_collection_general(prefix, True, enable_dynamic_field=enable_dynamic_field)[0:4] + + # filter result with expression in collection + expression = normal_expression + ' and ' + json_expression + _vectors = _vectors[0] + expression = expression.replace("&&", "and").replace("||", "or") + filter_ids = [] + json_field = {} + for i, _id in enumerate(insert_ids): + if enable_dynamic_field: + json_field['number'] = _vectors[i][ct.default_json_field_name]['number'] + json_field['float'] = _vectors[i][ct.default_json_field_name]['float'] + int64 = _vectors[i][ct.default_int64_field_name] + float = _vectors[i][ct.default_float_field_name] + else: + json_field['number'] = _vectors[ct.default_json_field_name][i]['number'] + json_field['float'] = _vectors[ct.default_json_field_name][i]['float'] + int64 = _vectors.int64[i] + float = _vectors.float[i] + if not expression or eval(expression): + filter_ids.append(_id) + + # delete with expressions + res = collection_w.delete(expression)[0] + assert res.delete_count == len(filter_ids) + + # query to check + collection_w.query(f"int64 in {filter_ids}", check_task=CheckTasks.check_query_empty) + + @pytest.mark.tags(CaseLabel.L1) + @pytest.mark.parametrize("expression", cf.gen_normal_string_expressions(["varchar", "json_field['string']", "NewStr"])) + @pytest.mark.parametrize("enable_dynamic_field", [True, False]) + def test_delete_string_expressions_normal(self, expression, enable_dynamic_field): + """ + target: test delete expr like + method: delete using expression like + expected: delete successfully + """ + if "NewStr" in expression and enable_dynamic_field is False: + pytest.skip("only support when enable_dynamic_filed == True") + # init collection with nb default data + nb = 1000 + collection_w, _vectors, _, insert_ids = \ + self.init_collection_general(prefix, False, enable_dynamic_field=enable_dynamic_field)[0:4] + + # insert + if 
enable_dynamic_field: + data = cf.gen_default_rows_data(nb) + for i in range(nb): + data[i][ct.default_json_field_name] = {"string": str(i)} + data[i]['NewStr'] = str(i) + else: + data = cf.gen_default_dataframe_data(nb) + data[ct.default_json_field_name] = [{"string": str(i)} for i in range(nb)] + collection_w.insert(data) + collection_w.flush() + collection_w.load() + + # calculate the result + _vectors = data + expression = expression.replace("&&", "and").replace("||", "or") + filter_ids = [] + json_field = {} + for i in range(nb): + if enable_dynamic_field: + json_field['string'] = _vectors[i][ct.default_json_field_name]['string'] + varchar = _vectors[i][ct.default_string_field_name] + NewStr = _vectors[i]['NewStr'] + else: + json_field['string'] = _vectors[ct.default_json_field_name][i]['string'] + varchar = _vectors.varchar[i] + if not expression or eval(expression): + filter_ids.append(i) + + # delete with expressions + res = collection_w.delete(expression)[0] + assert res.delete_count == len(filter_ids) + + # query to check + collection_w.load() + collection_w.query("int64 >= 0", output_fields=['count(*)'], + check_task=CheckTasks.check_query_results, + check_items={'count(*)': nb - len(filter_ids)}) + diff --git a/tests/python_client/testcases/test_high_level_api.py b/tests/python_client/testcases/test_high_level_api.py index d4519d51a5274..d5f40b908e4fc 100644 --- a/tests/python_client/testcases/test_high_level_api.py +++ b/tests/python_client/testcases/test_high_level_api.py @@ -86,7 +86,8 @@ def test_high_level_collection_string_auto_id(self): client = self._connect(enable_high_level_api=True) collection_name = cf.gen_unique_str(prefix) # 1. 
create collection - error = {ct.err_code: 1, ct.err_msg: f"The auto_id can only be specified on field with DataType.INT64"} + error = {ct.err_code: 65535, ct.err_msg: f"type param(max_length) should be specified for varChar " + f"field of collection {collection_name}"} client_w.create_collection(client, collection_name, default_dim, id_type="string", auto_id=True, check_task=CheckTasks.err_res, check_items=error) @@ -121,7 +122,8 @@ def test_high_level_collection_invalid_metric_type(self): client = self._connect(enable_high_level_api=True) collection_name = cf.gen_unique_str(prefix) # 1. create collection - error = {ct.err_code: 1, ct.err_msg: f"metric type not found or not supported, supported: [L2 IP COSINE]"} + error = {ct.err_code: 65535, + ct.err_msg: "metric type not found or not supported, supported: [L2 IP COSINE HAMMING JACCARD]"} client_w.create_collection(client, collection_name, default_dim, metric_type="invalid", check_task=CheckTasks.err_res, check_items=error) diff --git a/tests/python_client/testcases/test_index.py b/tests/python_client/testcases/test_index.py index a9eb7924986ea..d5d080a8d8553 100644 --- a/tests/python_client/testcases/test_index.py +++ b/tests/python_client/testcases/test_index.py @@ -1,3 +1,4 @@ +import random from time import sleep import pytest import copy @@ -109,10 +110,10 @@ def test_index_type_invalid(self, index_type): if not isinstance(index_params["index_type"], str): msg = "must be str" else: - msg = "Invalid index_type" + msg = "invalid index type" self.index_wrap.init_index(collection_w.collection, default_field_name, index_params, check_task=CheckTasks.err_res, - check_items={ct.err_code: 1, ct.err_msg: msg}) + check_items={ct.err_code: 65535, ct.err_msg: msg}) @pytest.mark.tags(CaseLabel.L1) def test_index_type_not_supported(self): @@ -212,7 +213,8 @@ def test_index_create_with_different_indexes(self): c_name = cf.gen_unique_str(prefix) collection_w = self.init_collection_wrap(name=c_name) 
self.index_wrap.init_index(collection_w.collection, default_field_name, default_index_params) - error = {ct.err_code: 1, ct.err_msg: f"CreateIndex failed: index already exists"} + error = {ct.err_code: 65535, ct.err_msg: "CreateIndex failed: at most one " + "distinct index is allowed per field"} self.index_wrap.init_index(collection_w.collection, default_field_name, default_index, check_task=CheckTasks.err_res, check_items=error) @@ -241,8 +243,9 @@ def test_index_create_on_scalar_field(self): collection_w = self.init_collection_general(prefix, True, is_index=False)[0] collection_w.create_index(ct.default_int64_field_name, {}) collection_w.load(check_task=CheckTasks.err_res, - check_items={ct.err_code: 1, ct.err_msg: "there is no vector index on collection, " - "please create index firstly"}) + check_items={ct.err_code: 65535, + ct.err_msg: f"there is no vector index on collection: {collection_w.name}, " + f"please create index firstly"}) @pytest.mark.tags(CaseLabel.L1) def test_index_collection_empty(self): @@ -1149,8 +1152,9 @@ def test_create_index_invalid_metric_type_binary(self): binary_index_params = {'index_type': 'BIN_IVF_FLAT', 'metric_type': 'L2', 'params': {'nlist': 64}} collection_w.create_index(default_binary_vec_field_name, binary_index_params, index_name=binary_field_name, check_task=CheckTasks.err_res, - check_items={ct.err_code: 1, - ct.err_msg: "Invalid metric_type: L2, which does not match the index type: BIN_IVF_FLAT"}) + check_items={ct.err_code: 65535, + ct.err_msg: "metric type not found or not supported, supported: " + "[HAMMING JACCARD SUBSTRUCTURE SUPERSTRUCTURE]"}) @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("metric_type", ["L2", "IP", "COSINE", "JACCARD", "HAMMING"]) @@ -1220,10 +1224,10 @@ def test_drop_index_partition(self): assert len(ins_res.primary_keys) == len(df) collection_w.create_index(default_binary_vec_field_name, default_binary_index_params, index_name=binary_field_name) - assert 
collection_w.has_index(index_name=binary_field_name)[0] == True + assert collection_w.has_index(index_name=binary_field_name)[0] is True assert len(collection_w.indexes) == 1 collection_w.drop_index(index_name=binary_field_name) - assert collection_w.has_index(index_name=binary_field_name)[0] == False + assert collection_w.has_index(index_name=binary_field_name)[0] is False assert len(collection_w.indexes) == 0 @@ -1298,7 +1302,7 @@ def test_drop_index_without_release(self): "loaded, please release it first"}) @pytest.mark.tags(CaseLabel.L1) - @pytest.mark.parametrize("n_trees", [-1, 1025, 'a', {34}]) + @pytest.mark.parametrize("n_trees", [-1, 1025, 'a']) def test_annoy_index_with_invalid_params(self, n_trees): """ target: test create index with invalid params @@ -1310,8 +1314,8 @@ def test_annoy_index_with_invalid_params(self, n_trees): index_annoy = {"index_type": "ANNOY", "params": {"n_trees": n_trees}, "metric_type": "L2"} collection_w.create_index("float_vector", index_annoy, check_task=CheckTasks.err_res, - check_items={"err_code": 1, - "err_msg": "invalid index params"}) + check_items={"err_code": 65535, + "err_msg": "invalid index type: ANNOY"}) @pytest.mark.tags(CaseLabel.L1) def test_create_index_json(self): @@ -1849,8 +1853,8 @@ def test_create_diskann_index_with_binary(self): collection_w.insert(data=df) collection_w.create_index(default_binary_vec_field_name, ct.default_diskann_index, index_name=binary_field_name, check_task=CheckTasks.err_res, - check_items={ct.err_code: 1, - ct.err_msg: "field data type BinaryVector don't support the index build type DISKANN"}) + check_items={ct.err_code: 65535, + ct.err_msg: "float or float16 vector are only supported"}) @pytest.mark.tags(CaseLabel.L2) def test_create_diskann_index_multithread(self): @@ -1879,7 +1883,7 @@ def build(collection_w): for t in threads: t.join() - @pytest.mark.skip(reason = "diskann dim range is set to be [1, 32768)") + @pytest.mark.skip(reason="diskann dim range is set to be [1, 32768)") 
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.parametrize("dim", [2, 4, 8]) def test_create_index_with_small_dim(self, dim): @@ -1958,3 +1962,50 @@ def test_create_autoindex_on_binary_vectors(self): collection_w.create_index(binary_field_name, {}) actual_index_params = collection_w.index()[0].params assert default_autoindex_params == actual_index_params + + +@pytest.mark.tags(CaseLabel.GPU) +class TestScaNNIndex(TestcaseBase): + """ Test case of Auto index """ + + @pytest.mark.tags(CaseLabel.L1) + def test_create_scann_index(self): + """ + target: test create scann index + method: create index with only one field name + expected: create successfully + """ + collection_w = self.init_collection_general(prefix, is_index=False)[0] + index_params = {"index_type": "SCANN", "metric_type": "L2", + "params": {"nlist": 1024, "with_raw_data": True}} + collection_w.create_index(default_field_name, index_params) + assert collection_w.has_index()[0] is True + + @pytest.mark.tags(CaseLabel.L2) + @pytest.mark.parametrize("nlist", [0, 65537]) + def test_create_scann_index_nlist_invalid(self, nlist): + """ + target: test create scann index invalid + method: create index with invalid nlist + expected: report error + """ + collection_w = self.init_collection_general(prefix, is_index=False)[0] + index_params = {"index_type": "SCANN", "metric_type": "L2", "params": {"nlist": nlist}} + error = {ct.err_code: 65535, ct.err_msg: "nlist out of range: [1, 65536]"} + collection_w.create_index(default_field_name, index_params, + check_task=CheckTasks.err_res, check_items=error) + + @pytest.mark.tags(CaseLabel.L2) + @pytest.mark.parametrize("dim", [1, 127]) + def test_create_scann_index_dim_invalid(self, dim): + """ + target: test create scann index invalid + method: create index on vector dim % 2 == 1 + expected: report error + """ + collection_w = self.init_collection_general(prefix, is_index=False, dim=dim)[0] + index_params = {"index_type": "SCANN", "metric_type": "L2", "params": {"nlist": 
1024}} + error = {ct.err_code: 65535, + ct.err_msg: f"dimension must be able to be divided by 2, dimension: {dim}"} + collection_w.create_index(default_field_name, index_params, + check_task=CheckTasks.err_res, check_items=error) diff --git a/tests/python_client/testcases/test_insert.py b/tests/python_client/testcases/test_insert.py index 588cbe5ac07b6..518efb88c2eb2 100644 --- a/tests/python_client/testcases/test_insert.py +++ b/tests/python_client/testcases/test_insert.py @@ -118,8 +118,8 @@ def test_insert_dataframe_only_columns(self): columns = [ct.default_int64_field_name, ct.default_float_vec_field_name] df = pd.DataFrame(columns=columns) - error = {ct.err_code: 0, - ct.err_msg: "Cannot infer schema from empty dataframe"} + error = {ct.err_code: 1, + ct.err_msg: "The data don't match with schema fields, expect 5 list, got 0"} collection_w.insert( data=df, check_task=CheckTasks.err_res, check_items=error) @@ -289,6 +289,7 @@ def test_insert_field_name_not_match(self): data=df, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) + @pytest.mark.skip(reason="Currently not check in pymilvus") def test_insert_field_value_not_match(self): """ target: test insert data value not match @@ -299,12 +300,11 @@ def test_insert_field_value_not_match(self): collection_w = self.init_collection_wrap(name=c_name) nb = 10 df = cf.gen_default_dataframe_data(nb) - new_float_value = pd.Series( - data=[float(i) for i in range(nb)], dtype="float64") + new_float_value = pd.Series(data=[float(i) for i in range(nb)], dtype="float64") df[df.columns[1]] = new_float_value - error = {ct.err_code: 5} - collection_w.insert( - data=df, check_task=CheckTasks.err_res, check_items=error) + error = {ct.err_code: 1, + ct.err_msg: "The data type of field float doesn't match, expected: FLOAT, got DOUBLE"} + collection_w.insert(data=df, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_insert_value_less(self): @@ -391,7 +391,8 @@ 
def test_insert_list_order_inconsistent_schema(self): float_values = [np.float32(i) for i in range(nb)] float_vec_values = cf.gen_vectors(nb, ct.default_dim) data = [float_values, int_values, float_vec_values] - error = {ct.err_code: 5} + error = {ct.err_code: 1, + ct.err_msg: "The data type of field int64 doesn't match, expected: INT64, got FLOAT"} collection_w.insert( data=data, check_task=CheckTasks.err_res, check_items=error) @@ -414,7 +415,7 @@ def test_insert_dataframe_order_inconsistent_schema(self): ct.default_float_vec_field_name: float_vec_values, ct.default_int64_field_name: int_values }) - error = {ct.err_code: 5, ct.err_msg: 'Missing param in entities'} + error = {ct.err_code: 1, ct.err_msg: "The fields don't match with schema fields"} collection_w.insert( data=df, check_task=CheckTasks.err_res, check_items=error) @@ -1136,7 +1137,7 @@ def test_insert_async_invalid_data(self): ct.default_float_vec_field_name] df = pd.DataFrame(columns=columns) error = {ct.err_code: 0, - ct.err_msg: "Cannot infer schema from empty dataframe"} + ct.err_msg: "The fields don't match with schema fields"} collection_w.insert(data=df, _async=True, check_task=CheckTasks.err_res, check_items=error) @@ -1263,7 +1264,7 @@ def test_insert_with_invalid_partition_name(self): collection_name = cf.gen_unique_str(prefix) collection_w = self.init_collection_wrap(name=collection_name) df = cf.gen_default_list_data(ct.default_nb) - error = {ct.err_code: 1, 'err_msg': "partition name is illegal"} + error = {ct.err_code: 15, 'err_msg': "partition not found"} mutation_res, _ = collection_w.insert(data=df, partition_name="p", check_task=CheckTasks.err_res, check_items=error) @@ -1313,7 +1314,7 @@ def test_insert_int8_overflow(self, invalid_int8): prefix, is_all_data_type=True)[0] data = cf.gen_dataframe_all_data_type(nb=1) data[ct.default_int8_field_name] = [invalid_int8] - error = {ct.err_code: 1, 'err_msg': "The data type of field int8 doesn't match, " + error = {ct.err_code: 1100, 
'err_msg': "The data type of field int8 doesn't match, " "expected: INT8, got INT64"} collection_w.insert( data, check_task=CheckTasks.err_res, check_items=error) @@ -1330,7 +1331,7 @@ def test_insert_int16_overflow(self, invalid_int16): prefix, is_all_data_type=True)[0] data = cf.gen_dataframe_all_data_type(nb=1) data[ct.default_int16_field_name] = [invalid_int16] - error = {ct.err_code: 1, 'err_msg': "The data type of field int16 doesn't match, " + error = {ct.err_code: 1100, 'err_msg': "The data type of field int16 doesn't match, " "expected: INT16, got INT64"} collection_w.insert( data, check_task=CheckTasks.err_res, check_items=error) @@ -1665,13 +1666,11 @@ def test_upsert_data_pk_exist(self, start): """ upsert_nb = 1000 collection_w = self.init_collection_general(pre_upsert, True)[0] - upsert_data, float_values = cf.gen_default_data_for_upsert( - upsert_nb, start=start) + upsert_data, float_values = cf.gen_default_data_for_upsert(upsert_nb, start=start) collection_w.upsert(data=upsert_data) exp = f"int64 >= {start} && int64 <= {upsert_nb + start}" res = collection_w.query(exp, output_fields=[default_float_name])[0] - assert [res[i][default_float_name] - for i in range(upsert_nb)] == float_values.to_list() + assert [res[i][default_float_name] for i in range(upsert_nb)] == float_values.to_list() @pytest.mark.tags(CaseLabel.L2) def test_upsert_with_primary_key_string(self): @@ -1683,13 +1682,10 @@ def test_upsert_with_primary_key_string(self): expected: raise no exception """ c_name = cf.gen_unique_str(pre_upsert) - fields = [cf.gen_string_field(), cf.gen_float_vec_field( - dim=ct.default_dim)] - schema = cf.gen_collection_schema( - fields=fields, primary_field=ct.default_string_field_name) + fields = [cf.gen_string_field(), cf.gen_float_vec_field(dim=ct.default_dim)] + schema = cf.gen_collection_schema(fields=fields, primary_field=ct.default_string_field_name) collection_w = self.init_collection_wrap(name=c_name, schema=schema) - vectors = [[random.random() 
for _ in range(ct.default_dim)] - for _ in range(2)] + vectors = [[random.random() for _ in range(ct.default_dim)] for _ in range(2)] collection_w.insert([["a", "b"], vectors]) collection_w.upsert([[" a", "b "], vectors]) assert collection_w.num_entities == 4 @@ -1705,14 +1701,12 @@ def test_upsert_binary_data(self): """ nb = 500 c_name = cf.gen_unique_str(pre_upsert) - collection_w = self.init_collection_general( - c_name, True, is_binary=True)[0] + collection_w = self.init_collection_general(c_name, True, is_binary=True)[0] binary_vectors = cf.gen_binary_vectors(nb, ct.default_dim)[1] data = [[i for i in range(nb)], [np.float32(i) for i in range(nb)], [str(i) for i in range(nb)], binary_vectors] collection_w.upsert(data) - res = collection_w.query( - "int64 >= 0", [ct.default_binary_vec_field_name])[0] + res = collection_w.query("int64 >= 0", [ct.default_binary_vec_field_name])[0] assert binary_vectors[0] == res[0][ct. default_binary_vec_field_name][0] @pytest.mark.tags(CaseLabel.L1) @@ -1742,8 +1736,7 @@ def test_upsert_data_is_none(self): 3. 
upsert data=None expected: raise no exception """ - collection_w = self.init_collection_general( - pre_upsert, insert_data=True, is_index=False)[0] + collection_w = self.init_collection_general(pre_upsert, insert_data=True, is_index=False)[0] assert collection_w.num_entities == ct.default_nb collection_w.upsert(data=None) assert collection_w.num_entities == ct.default_nb @@ -1762,8 +1755,7 @@ def test_upsert_in_specific_partition(self): collection_w = self.init_collection_wrap(name=c_name) collection_w.create_partition("partition_new") cf.insert_data(collection_w) - collection_w.create_index( - ct.default_float_vec_field_name, default_index_params) + collection_w.create_index(ct.default_float_vec_field_name, default_index_params) collection_w.load() # check the ids which will be upserted is in partition _default @@ -1772,10 +1764,8 @@ def test_upsert_in_specific_partition(self): res0 = collection_w.query(expr, [default_float_name], ["_default"])[0] assert len(res0) == upsert_nb collection_w.flush() - res1 = collection_w.query( - expr, [default_float_name], ["partition_new"])[0] - assert collection_w.partition('partition_new')[ - 0].num_entities == ct.default_nb // 2 + res1 = collection_w.query(expr, [default_float_name], ["partition_new"])[0] + assert collection_w.partition('partition_new')[0].num_entities == ct.default_nb // 2 # upsert ids in partition _default data, float_values = cf.gen_default_data_for_upsert(upsert_nb) @@ -1784,13 +1774,10 @@ def test_upsert_in_specific_partition(self): # check the result in partition _default(upsert successfully) and others(no missing, nothing new) collection_w.flush() res0 = collection_w.query(expr, [default_float_name], ["_default"])[0] - res2 = collection_w.query( - expr, [default_float_name], ["partition_new"])[0] + res2 = collection_w.query(expr, [default_float_name], ["partition_new"])[0] assert res1 == res2 - assert [res0[i][default_float_name] - for i in range(upsert_nb)] == float_values.to_list() - assert 
collection_w.partition('partition_new')[ - 0].num_entities == ct.default_nb // 2 + assert [res0[i][default_float_name] for i in range(upsert_nb)] == float_values.to_list() + assert collection_w.partition('partition_new')[0].num_entities == ct.default_nb // 2 @pytest.mark.tags(CaseLabel.L2) # @pytest.mark.skip(reason="issue #22592") @@ -1810,15 +1797,13 @@ def test_upsert_in_mismatched_partitions(self): # insert data and load collection cf.insert_data(collection_w) - collection_w.create_index( - ct.default_float_vec_field_name, default_index_params) + collection_w.create_index(ct.default_float_vec_field_name, default_index_params) collection_w.load() # check the ids which will be upserted is not in partition 'partition_1' upsert_nb = 100 expr = f"int64 >= 0 && int64 <= {upsert_nb}" - res = collection_w.query( - expr, [default_float_name], ["partition_1"])[0] + res = collection_w.query(expr, [default_float_name], ["partition_1"])[0] assert len(res) == 0 # upsert in partition 'partition_1' @@ -1826,10 +1811,8 @@ def test_upsert_in_mismatched_partitions(self): collection_w.upsert(data, "partition_1") # check the upserted data in 'partition_1' - res1 = collection_w.query( - expr, [default_float_name], ["partition_1"])[0] - assert [res1[i][default_float_name] - for i in range(upsert_nb)] == float_values.to_list() + res1 = collection_w.query(expr, [default_float_name], ["partition_1"])[0] + assert [res1[i][default_float_name] for i in range(upsert_nb)] == float_values.to_list() @pytest.mark.tags(CaseLabel.L1) def test_upsert_same_pk_concurrently(self): @@ -1843,8 +1826,7 @@ def test_upsert_same_pk_concurrently(self): # initialize a collection upsert_nb = 1000 collection_w = self.init_collection_general(pre_upsert, True)[0] - data1, float_values1 = cf.gen_default_data_for_upsert( - upsert_nb, size=1000) + data1, float_values1 = cf.gen_default_data_for_upsert(upsert_nb, size=1000) data2, float_values2 = cf.gen_default_data_for_upsert(upsert_nb) # upsert at the same time @@ 
-1864,8 +1846,7 @@ def do_upsert2(): # check the result exp = f"int64 >= 0 && int64 <= {upsert_nb}" - res = collection_w.query( - exp, [default_float_name], consistency_level="Strong")[0] + res = collection_w.query(exp, [default_float_name], consistency_level="Strong")[0] res = [res[i][default_float_name] for i in range(upsert_nb)] if not (res == float_values1.to_list() or res == float_values2.to_list()): assert False @@ -1910,13 +1891,31 @@ def test_upsert_pk_string_multiple_times(self): data = cf.gen_default_list_data(upsert_nb, start=i * step) collection_w.upsert(data) # load - collection_w.create_index( - ct.default_float_vec_field_name, default_index_params) + collection_w.create_index(ct.default_float_vec_field_name, default_index_params) collection_w.load() # check the result res = collection_w.query(expr="", output_fields=["count(*)"])[0] assert res[0]["count(*)"] == upsert_nb * 10 - step * 9 + @pytest.mark.tags(CaseLabel.L2) + def test_upsert_enable_dynamic_field(self): + """ + target: test upsert when enable dynamic field is True + method: 1. create a collection and insert data + 2. 
upsert + expected: not raise exception + """ + upsert_nb = ct.default_nb + start = ct.default_nb // 2 + collection_w = self.init_collection_general(pre_upsert, True, enable_dynamic_field=True)[0] + upsert_data = cf.gen_default_rows_data(start=start) + for i in range(start, start + upsert_nb): + upsert_data[i - start]["new"] = [i, i + 1] + collection_w.upsert(data=upsert_data) + exp = f"int64 >= {start} && int64 <= {upsert_nb + start}" + res = collection_w.query(exp, output_fields=["new"])[0] + assert len(res[0]["new"]) == 2 + @pytest.mark.tags(CaseLabel.L1) @pytest.mark.skip("not support default_value now") @pytest.mark.parametrize("default_value", [[], None]) @@ -2034,8 +2033,7 @@ def test_upsert_non_data_type(self, data): collection_w = self.init_collection_wrap(name=c_name) error = {ct.err_code: 1, ct.err_msg: "The fields don't match with schema fields, expected: " "['int64', 'float', 'varchar', 'float_vector']"} - collection_w.upsert( - data=data, check_task=CheckTasks.err_res, check_items=error) + collection_w.upsert(data=data, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_upsert_pk_type_invalid(self): @@ -2050,8 +2048,7 @@ def test_upsert_pk_type_invalid(self): cf.gen_vectors(2, ct.default_dim)] error = {ct.err_code: 1, ct.err_msg: "The data type of field int64 doesn't match, " "expected: INT64, got VARCHAR"} - collection_w.upsert( - data=data, check_task=CheckTasks.err_res, check_items=error) + collection_w.upsert(data=data, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_upsert_data_unmatch(self): @@ -2067,8 +2064,7 @@ def test_upsert_data_unmatch(self): data = [1, "a", 2.0, vector] error = {ct.err_code: 1, ct.err_msg: "The fields don't match with schema fields, " "expected: ['int64', 'float', 'varchar', 'float_vector']"} - collection_w.upsert( - data=[data], check_task=CheckTasks.err_res, check_items=error) + collection_w.upsert(data=[data], 
check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("vector", [[], [1.0, 2.0], "a", 1.0, None]) @@ -2084,8 +2080,7 @@ def test_upsert_vector_unmatch(self, vector): data = [2.0, "a", vector] error = {ct.err_code: 1, ct.err_msg: "The fields don't match with schema fields, " "expected: ['int64', 'float', 'varchar', 'float_vector']"} - collection_w.upsert( - data=[data], check_task=CheckTasks.err_res, check_items=error) + collection_w.upsert(data=[data], check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("dim", [120, 129, 200]) @@ -2096,13 +2091,11 @@ def test_upsert_binary_dim_unmatch(self, dim): 2. upsert with mismatched dim expected: raise exception """ - collection_w = self.init_collection_general( - pre_upsert, True, is_binary=True)[0] + collection_w = self.init_collection_general(pre_upsert, True, is_binary=True)[0] data = cf.gen_default_binary_dataframe_data(dim=dim)[0] error = {ct.err_code: 1, ct.err_msg: f"Collection field dim is 128, but entities field dim is {dim}"} - collection_w.upsert( - data=data, check_task=CheckTasks.err_res, check_items=error) + collection_w.upsert(data=data, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("dim", [127, 129, 200]) @@ -2117,8 +2110,7 @@ def test_upsert_dim_unmatch(self, dim): data = cf.gen_default_data_for_upsert(dim=dim)[0] error = {ct.err_code: 1, ct.err_msg: f"Collection field dim is 128, but entities field dim is {dim}"} - collection_w.upsert( - data=data, check_task=CheckTasks.err_res, check_items=error) + collection_w.upsert(data=data, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("partition_name", ct.get_invalid_strs[7:13]) @@ -2151,7 +2143,7 @@ def test_upsert_partition_name_nonexistent(self): collection_w = self.init_collection_wrap(name=c_name) data = 
cf.gen_default_dataframe_data(nb=2) partition_name = "partition1" - error = {ct.err_code: 1, ct.err_msg: "partition is not exist: partition1"} + error = {ct.err_code: 15, ct.err_msg: f"partition={partition_name}: partition not found"} collection_w.upsert(data=data, partition_name=partition_name, check_task=CheckTasks.err_res, check_items=error) @@ -2183,17 +2175,16 @@ def test_upsert_with_auto_id(self): 2. upsert data no pk expected: raise exception """ - collection_w = self.init_collection_general( - pre_upsert, auto_id=True, is_index=False)[0] + collection_w = self.init_collection_general(pre_upsert, auto_id=True, is_index=False)[0] error = {ct.err_code: 1, ct.err_msg: "Upsert don't support autoid == true"} float_vec_values = cf.gen_vectors(ct.default_nb, ct.default_dim) data = [[np.float32(i) for i in range(ct.default_nb)], [str(i) for i in range(ct.default_nb)], float_vec_values] - collection_w.upsert( - data=data, check_task=CheckTasks.err_res, check_items=error) + collection_w.upsert(data=data, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) + @pytest.mark.skip("not support default_value now") @pytest.mark.parametrize("default_value", [[], None]) def test_upsert_array_using_default_value(self, default_value): """ @@ -2212,6 +2203,7 @@ def test_upsert_array_using_default_value(self, default_value): check_items={ct.err_code: 1, ct.err_msg: "Field varchar don't match in entities[0]"}) @pytest.mark.tags(CaseLabel.L2) + @pytest.mark.skip("not support default_value now") @pytest.mark.parametrize("default_value", [[], None]) def test_upsert_tuple_using_default_value(self, default_value): """ diff --git a/tests/python_client/testcases/test_partition.py b/tests/python_client/testcases/test_partition.py index 7e408568a363d..df0ab556036d6 100644 --- a/tests/python_client/testcases/test_partition.py +++ b/tests/python_client/testcases/test_partition.py @@ -177,8 +177,10 @@ def test_partition_max_length_name(self): partition_name = 
cf.gen_str_by_length(256) self.partition_wrap.init_partition(collection_w.collection, partition_name, check_task=CheckTasks.err_res, - check_items={ct.err_code: 1, 'err_msg': "is illegal"} - ) + check_items={ct.err_code: 65535, + ct.err_msg: f"Invalid partition name: {partition_name}. " + f"The length of a partition name must be less " + f"than 255 characters."}) @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("partition_name", ["_Partiti0n", "pArt1_ti0n"]) @@ -208,10 +210,13 @@ def test_partition_invalid_name(self, partition_name): collection_w = self.init_collection_wrap() # create partition + error1 = {ct.err_code: 1, ct.err_msg: f"`partition_name` value {partition_name} is illegal"} + error2 = {ct.err_code: 65535, ct.err_msg: f"Invalid partition name: {partition_name}. Partition name can" + f" only contain numbers, letters and underscores."} + error = error1 if partition_name in [None, [], 1, [1, "2", 3], (1,), {1: 1}] else error2 self.partition_wrap.init_partition(collection_w.collection, partition_name, check_task=CheckTasks.err_res, - check_items={ct.err_code: 1, 'err_msg': "is illegal"} - ) + check_items=error) # TODO: need an error code issue #5144 and assert independently @pytest.mark.tags(CaseLabel.L2) @@ -370,7 +375,8 @@ def test_load_replica_greater_than_querynodes(self): assert partition_w.num_entities == ct.default_nb # load with 2 replicas - error = {ct.err_code: 1, ct.err_msg: f"no enough nodes to create replicas"} + error = {ct.err_code: 65535, + ct.err_msg: "failed to load partitions: failed to spawn replica for collection: nodes not enough"} collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) partition_w.load(replica_number=3, check_task=CheckTasks.err_res, check_items=error) @@ -397,7 +403,8 @@ def test_load_replica_change(self): partition_w.load(replica_number=1) collection_w.query(expr=f"{ct.default_int64_field_name} in [0]", check_task=CheckTasks.check_query_results, check_items={'exp_res': 
[{'int64': 0}]}) - error = {ct.err_code: 5, ct.err_msg: f"Should release first then reload with the new number of replicas"} + error = {ct.err_code: 1100, ct.err_msg: "failed to load partitions: can't change the replica number for " + "loaded partitions: expected=1, actual=2: invalid parameter"} partition_w.load(replica_number=2, check_task=CheckTasks.err_res, check_items=error) partition_w.release() @@ -499,7 +506,8 @@ def test_partition_release(self): anns_field=ct.default_float_vec_field_name, params={"nprobe": 32}, limit=1, check_task=ct.CheckTasks.err_res, - check_items={ct.err_code: 1, ct.err_msg: "partitions have been released"}) + check_items={ct.err_code: 65535, + ct.err_msg: "collection not loaded"}) @pytest.mark.tags(CaseLabel.L1) @pytest.mark.parametrize("data", [cf.gen_default_dataframe_data(10), @@ -595,7 +603,7 @@ def test_partition_dropped_collection(self): # create partition failed self.partition_wrap.init_partition(collection_w.collection, cf.gen_unique_str(prefix), check_task=CheckTasks.err_res, - check_items={ct.err_code: 1, ct.err_msg: "can't find collection"}) + check_items={ct.err_code: 4, ct.err_msg: "collection not found"}) @pytest.mark.tags(CaseLabel.L2) def test_partition_same_name_in_diff_collections(self): @@ -665,8 +673,9 @@ def create_partition(collection, threads_n): self.partition_wrap.init_partition( collection_w.collection, p_name, check_task=CheckTasks.err_res, - check_items={ct.err_code: 1, - ct.err_msg: "maximum partition's number should be limit to 4096"}) + check_items={ct.err_code: 65535, + ct.err_msg: "partition number (4096) exceeds max configuration (4096), " + "collection: {}".format(collection_w.name)}) # TODO: Try to verify load collection with a large number of partitions. 
#11651 @@ -849,7 +858,7 @@ def test_partition_release_dropped_collection(self): # release the partition and check err response partition_w.release(check_task=CheckTasks.err_res, - check_items={ct.err_code: 1, ct.err_msg: "can't find collection"}) + check_items={ct.err_code: 4, ct.err_msg: "collection not found"}) @pytest.mark.tags(CaseLabel.L1) def test_partition_release_after_collection_released(self): @@ -963,7 +972,7 @@ def test_partition_insert_dropped_collection(self): # insert data to partition partition_w.insert(cf.gen_default_dataframe_data(), check_task=CheckTasks.err_res, - check_items={ct.err_code: 1, ct.err_msg: "None Type"}) + check_items={ct.err_code: 4, ct.err_msg: "collection not found"}) @pytest.mark.tags(CaseLabel.L2) def test_partition_insert_maximum_size_data(self): diff --git a/tests/python_client/testcases/test_partition_key.py b/tests/python_client/testcases/test_partition_key.py index 80ab75822b31f..b2ba42ed23edb 100644 --- a/tests/python_client/testcases/test_partition_key.py +++ b/tests/python_client/testcases/test_partition_key.py @@ -25,6 +25,7 @@ def test_partition_key_on_field_schema(self, par_key_field): schema = cf.gen_collection_schema(fields=[pk_field, int64_field, string_field, vector_field], auto_id=True) c_name = cf.gen_unique_str("par_key") collection_w, _ = self.collection_wrap.init_collection(name=c_name, schema=schema) + assert len(collection_w.partitions) == ct.default_partition_num # insert nb = 1000 diff --git a/tests/python_client/testcases/test_query.py b/tests/python_client/testcases/test_query.py index 0b38223fceb7e..c5a7bc70646e6 100644 --- a/tests/python_client/testcases/test_query.py +++ b/tests/python_client/testcases/test_query.py @@ -27,10 +27,8 @@ default_expr = f'{ct.default_int64_field_name} >= 0' default_invalid_expr = "varchar >= 0" default_string_term_expr = f'{ct.default_string_field_name} in [\"0\", \"1\"]' -default_index_params = {"index_type": "IVF_SQ8", - "metric_type": "L2", "params": {"nlist": 64}} 
-binary_index_params = {"index_type": "BIN_IVF_FLAT", - "metric_type": "JACCARD", "params": {"nlist": 64}} +default_index_params = {"index_type": "IVF_SQ8", "metric_type": "L2", "params": {"nlist": 64}} +binary_index_params = {"index_type": "BIN_IVF_FLAT", "metric_type": "JACCARD", "params": {"nlist": 64}} default_entities = ut.gen_entities(ut.default_nb, is_normal=True) default_pos = 5 @@ -61,12 +59,10 @@ def test_query_invalid(self): method: query with invalid term expr expected: raise exception """ - collection_w, entities = self.init_collection_general( - prefix, insert_data=True, nb=10)[0:2] + collection_w, entities = self.init_collection_general(prefix, insert_data=True, nb=10)[0:2] term_expr = f'{default_int_field_name} in {entities[:default_pos]}' - error = {ct.err_code: 1, ct.err_msg: "unexpected token Identifier"} - collection_w.query( - term_expr, check_task=CheckTasks.err_res, check_items=error) + error = {ct.err_code: 65535, ct.err_msg: "cannot parse expression: int64 in .."} + collection_w.query(term_expr, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L0) def test_query(self, enable_dynamic_field): @@ -84,16 +80,13 @@ def test_query(self, enable_dynamic_field): for vector in vectors[0]: vector = vector[ct.default_int64_field_name] int_values.append(vector) - res = [{ct.default_int64_field_name: int_values[i]} - for i in range(pos)] + res = [{ct.default_int64_field_name: int_values[i]} for i in range(pos)] else: - int_values = vectors[0][ct.default_int64_field_name].values.tolist( - ) + int_values = vectors[0][ct.default_int64_field_name].values.tolist() res = vectors[0].iloc[0:pos, :1].to_dict('records') term_expr = f'{ct.default_int64_field_name} in {int_values[:pos]}' - collection_w.query( - term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) + collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) @pytest.mark.tags(CaseLabel.L1) def 
test_query_no_collection(self): @@ -126,8 +119,7 @@ def test_query_empty_collection(self): """ c_name = cf.gen_unique_str(prefix) collection_w = self.init_collection_wrap(name=c_name) - collection_w.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() res, _ = collection_w.query(default_term_expr) assert len(res) == 0 @@ -149,16 +141,14 @@ def test_query_auto_id_collection(self): ids = insert_res[1].primary_keys pos = 5 res = df.iloc[:pos, :1].to_dict('records') - self.collection_wrap.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + self.collection_wrap.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) self.collection_wrap.load() # query with all primary keys term_expr_1 = f'{ct.default_int64_field_name} in {ids[:pos]}' for i in range(5): res[i][ct.default_int64_field_name] = ids[i] - self.collection_wrap.query( - term_expr_1, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) + self.collection_wrap.query(term_expr_1, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) # query with part primary keys term_expr_2 = f'{ct.default_int64_field_name} in {[ids[0], 0]}' @@ -176,8 +166,7 @@ def test_query_with_dup_primary_key(self, dim, dup_times): expected: query results are de-duplicated """ nb = ct.default_nb - collection_w, insert_data, _, _ = self.init_collection_general( - prefix, True, nb, dim=dim)[0:4] + collection_w, insert_data, _, _ = self.init_collection_general(prefix, True, nb, dim=dim)[0:4] # insert dup data multi times for i in range(dup_times): collection_w.insert(insert_data[0]) @@ -196,14 +185,12 @@ def test_query_auto_id_not_existed_primary_values(self): expected: query result is empty """ schema = cf.gen_default_collection_schema(auto_id=True) - collection_w = self.init_collection_wrap( - 
name=cf.gen_unique_str(prefix), schema=schema) + collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix), schema=schema) df = cf.gen_default_dataframe_data(ct.default_nb) df.drop(ct.default_int64_field_name, axis=1, inplace=True) mutation_res, _ = collection_w.insert(data=df) assert collection_w.num_entities == ct.default_nb - collection_w.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() term_expr = f'{ct.default_int64_field_name} in [0, 1, 2]' res, _ = collection_w.query(term_expr) @@ -216,11 +203,9 @@ def test_query_expr_none(self): method: query with expr None expected: raise exception """ - collection_w, vectors = self.init_collection_general( - prefix, insert_data=True)[0:2] + collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] error = {ct.err_code: 0, ct.err_msg: "The type of expr must be string"} - collection_w.query( - None, check_task=CheckTasks.err_res, check_items=error) + collection_w.query(None, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_query_non_string_expr(self): @@ -229,13 +214,11 @@ def test_query_non_string_expr(self): method: query with non-string expr, eg 1, [] .. 
expected: raise exception """ - collection_w, vectors = self.init_collection_general( - prefix, insert_data=True)[0:2] + collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] exprs = [1, 2., [], {}, ()] error = {ct.err_code: 0, ct.err_msg: "The type of expr must be string"} for expr in exprs: - collection_w.query( - expr, check_task=CheckTasks.err_res, check_items=error) + collection_w.query(expr, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_query_expr_invalid_string(self): @@ -244,13 +227,11 @@ def test_query_expr_invalid_string(self): method: query with invalid string expr expected: raise exception """ - collection_w, vectors = self.init_collection_general( - prefix, insert_data=True)[0:2] - error = {ct.err_code: 1, ct.err_msg: "Invalid expression!"} + collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] + error = {ct.err_code: 65535, ct.err_msg: "cannot parse expression: 12-s, error: field s not exist"} exprs = ["12-s", "中文", "a", " "] for expr in exprs: - collection_w.query( - expr, check_task=CheckTasks.err_res, check_items=error) + collection_w.query(expr, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L1) @pytest.mark.skip(reason="repeat with test_query, waiting for other expr") @@ -260,11 +241,9 @@ def test_query_expr_term(self): method: query with TermExpr expected: query result is correct """ - collection_w, vectors = self.init_collection_general( - prefix, insert_data=True)[0:2] + collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] res = vectors[0].iloc[:2, :1].to_dict('records') - collection_w.query( - default_term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) + collection_w.query(default_term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) @pytest.mark.tags(CaseLabel.L2) def test_query_expr_not_existed_field(self): @@ -275,9 
+254,9 @@ def test_query_expr_not_existed_field(self): """ collection_w = self.init_collection_wrap(cf.gen_unique_str(prefix)) term_expr = 'field in [1, 2]' - error = {ct.err_code: 1, ct.err_msg: "fieldName(field) not found"} - collection_w.query( - term_expr, check_task=CheckTasks.err_res, check_items=error) + error = {ct.err_code: 65535, + ct.err_msg: "cannot parse expression: field in [1, 2], error: field field not exist"} + collection_w.query(term_expr, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_query_expr_non_primary_fields(self): @@ -295,14 +274,12 @@ def test_query_expr_non_primary_fields(self): ct.default_float_field_name: pd.Series(data=[np.float32(i) for i in range(ct.default_nb)], dtype="float32"), ct.default_double_field_name: pd.Series(data=[np.double(i) for i in range(ct.default_nb)], dtype="double"), ct.default_string_field_name: pd.Series(data=[str(i) for i in range(ct.default_nb)], dtype="string"), - ct.default_float_vec_field_name: cf.gen_vectors( - ct.default_nb, ct.default_dim) + ct.default_float_vec_field_name: cf.gen_vectors(ct.default_nb, ct.default_dim) }) self.collection_wrap.construct_from_dataframe(cf.gen_unique_str(prefix), df, primary_field=ct.default_int64_field_name) assert self.collection_wrap.num_entities == ct.default_nb - self.collection_wrap.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + self.collection_wrap.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) self.collection_wrap.load() # query by non_primary non_vector scalar field @@ -333,26 +310,23 @@ def test_query_expr_by_bool_field(self): """ self._connect() df = cf.gen_default_dataframe_data() - bool_values = pd.Series( - data=[True if i % 2 == 0 else False for i in range(ct.default_nb)], dtype="bool") + bool_values = pd.Series(data=[True if i % 2 == 0 else False for i in range(ct.default_nb)], dtype="bool") df.insert(2, ct.default_bool_field_name, 
bool_values) self.collection_wrap.construct_from_dataframe(cf.gen_unique_str(prefix), df, primary_field=ct.default_int64_field_name) assert self.collection_wrap.num_entities == ct.default_nb - self.collection_wrap.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + self.collection_wrap.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) self.collection_wrap.load() # output bool field - res, _ = self.collection_wrap.query(default_term_expr, output_fields=[ - ct.default_bool_field_name]) - assert set(res[0].keys()) == { - ct.default_int64_field_name, ct.default_bool_field_name} + + res, _ = self.collection_wrap.query(default_term_expr, output_fields=[ct.default_bool_field_name]) + assert set(res[0].keys()) == {ct.default_int64_field_name, ct.default_bool_field_name} # not support filter bool field with expr 'bool in [0/ 1]' not_support_expr = f'{ct.default_bool_field_name} in [0]' - error = {ct.err_code: 1, - ct.err_msg: 'error: value \"0\" in list cannot be casted to Bool'} + error = {ct.err_code: 65535, + ct.err_msg: "cannot parse expression: bool in [0], error: value '0' in list cannot be casted to Bool"} self.collection_wrap.query(not_support_expr, output_fields=[ct.default_bool_field_name], check_task=CheckTasks.err_res, check_items=error) @@ -361,8 +335,7 @@ def test_query_expr_by_bool_field(self): exprs = [f'{ct.default_bool_field_name} in [{bool_value}]', f'{ct.default_bool_field_name} == {bool_value}'] for expr in exprs: - res, _ = self.collection_wrap.query( - expr, output_fields=[ct.default_bool_field_name]) + res, _ = self.collection_wrap.query(expr, output_fields=[ct.default_bool_field_name]) assert len(res) == ct.default_nb / 2 for _r in res: assert _r[ct.default_bool_field_name] == bool_value @@ -378,8 +351,7 @@ def test_query_expr_by_int8_field(self): self._connect() # construct collection from dataFrame according to [int64, float, int8, float_vec] df = cf.gen_default_dataframe_data() - 
int8_values = pd.Series(data=[np.int8(i) - for i in range(ct.default_nb)], dtype="int8") + int8_values = pd.Series(data=[np.int8(i) for i in range(ct.default_nb)], dtype="int8") df.insert(2, ct.default_int8_field_name, int8_values) self.collection_wrap.construct_from_dataframe(cf.gen_unique_str(prefix), df, primary_field=ct.default_int64_field_name) @@ -391,8 +363,7 @@ def test_query_expr_by_int8_field(self): # int8 range [-128, 127] so when nb=1200, there are many repeated int8 values equal to 0 for i in range(0, ct.default_nb, 256): res.extend(df.iloc[i:i + 1, :-2].to_dict('records')) - self.collection_wrap.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + self.collection_wrap.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) self.collection_wrap.load() self.collection_wrap.query(term_expr, output_fields=["float", "int64", "int8", "varchar"], check_task=CheckTasks.check_query_results, check_items={exp_res: res}) @@ -442,19 +413,16 @@ def test_query_expr_wrong_term_keyword(self): method: query with wrong keyword term expr expected: raise exception """ - collection_w, vectors = self.init_collection_general( - prefix, insert_data=True)[0:2] + collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] expr_1 = f'{ct.default_int64_field_name} inn [1, 2]' - error_1 = {ct.err_code: 1, - ct.err_msg: f'unexpected token Identifier("inn")'} - collection_w.query( - expr_1, check_task=CheckTasks.err_res, check_items=error_1) + error_1 = {ct.err_code: 65535, ct.err_msg: "cannot parse expression: int64 inn [1, 2], " + "error: invalid expression: int64 inn [1, 2]"} + collection_w.query(expr_1, check_task=CheckTasks.err_res, check_items=error_1) expr_3 = f'{ct.default_int64_field_name} in not [1, 2]' - error_3 = {ct.err_code: 1, - ct.err_msg: 'right operand of the InExpr must be array'} - collection_w.query( - expr_3, check_task=CheckTasks.err_res, check_items=error_3) + error_3 = 
{ct.err_code: 65535, ct.err_msg: "cannot parse expression: int64 in not [1, 2], " + "error: line 1:9 no viable alternative at input 'innot'"} + collection_w.query(expr_3, check_task=CheckTasks.err_res, check_items=error_3) @pytest.mark.tags(CaseLabel.L1) @pytest.mark.parametrize("field", [ct.default_int64_field_name, ct.default_float_field_name]) @@ -469,8 +437,7 @@ def test_query_expr_not_in_term(self, field): self.collection_wrap.construct_from_dataframe(cf.gen_unique_str(prefix), df, primary_field=ct.default_int64_field_name) assert self.collection_wrap.num_entities == ct.default_nb - self.collection_wrap.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + self.collection_wrap.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) self.collection_wrap.load() values = df[field].tolist() pos = 100 @@ -492,14 +459,12 @@ def test_query_expr_not_in_empty_and_all(self, pos): self.collection_wrap.construct_from_dataframe(cf.gen_unique_str(prefix), df, primary_field=ct.default_int64_field_name) assert self.collection_wrap.num_entities == ct.default_nb - self.collection_wrap.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + self.collection_wrap.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) self.collection_wrap.load() int64_values = df[ct.default_int64_field_name].tolist() term_expr = f'{ct.default_int64_field_name} not in {int64_values[pos:]}' res = df.iloc[:pos, :1].to_dict('records') - self.collection_wrap.query( - term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) + self.collection_wrap.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) @pytest.mark.tags(CaseLabel.L1) def test_query_expr_random_values(self): @@ -514,16 +479,14 @@ def test_query_expr_random_values(self): self.collection_wrap.construct_from_dataframe(cf.gen_unique_str(prefix), df, 
primary_field=ct.default_int64_field_name) assert self.collection_wrap.num_entities == 100 - self.collection_wrap.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + self.collection_wrap.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) self.collection_wrap.load() # random_values = [random.randint(0, ct.default_nb) for _ in range(4)] random_values = [0, 2, 4, 3] term_expr = f'{ct.default_int64_field_name} in {random_values}' res = df.iloc[random_values, :1].to_dict('records') - self.collection_wrap.query( - term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) + self.collection_wrap.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) @pytest.mark.tags(CaseLabel.L2) def test_query_expr_not_in_random(self): @@ -538,8 +501,7 @@ def test_query_expr_not_in_random(self): self.collection_wrap.construct_from_dataframe(cf.gen_unique_str(prefix), df, primary_field=ct.default_int64_field_name) assert self.collection_wrap.num_entities == 50 - self.collection_wrap.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + self.collection_wrap.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) self.collection_wrap.load() random_values = [i for i in range(10, 50)] @@ -547,8 +509,7 @@ def test_query_expr_not_in_random(self): random.shuffle(random_values) term_expr = f'{ct.default_int64_field_name} not in {random_values}' res = df.iloc[:10, :1].to_dict('records') - self.collection_wrap.query( - term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) + self.collection_wrap.query(term_expr, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) @pytest.mark.tags(CaseLabel.L2) def test_query_expr_non_array_term(self): @@ -560,13 +521,11 @@ def test_query_expr_non_array_term(self): exprs = [f'{ct.default_int64_field_name} in 1', 
f'{ct.default_int64_field_name} in "in"', f'{ct.default_int64_field_name} in (mn)'] - collection_w, vectors = self.init_collection_general( - prefix, insert_data=True)[0:2] - error = {ct.err_code: 1, - ct.err_msg: "right operand of the InExpr must be array"} + collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] + error = {ct.err_code: 65535, ct.err_msg: "cannot parse expression: int64 in 1, " + "error: line 1:9 no viable alternative at input 'in1'"} for expr in exprs: - collection_w.query( - expr, check_task=CheckTasks.err_res, check_items=error) + collection_w.query(expr, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_query_expr_empty_term_array(self): @@ -576,8 +535,7 @@ def test_query_expr_empty_term_array(self): expected: empty result """ term_expr = f'{ct.default_int64_field_name} in []' - collection_w, vectors = self.init_collection_general( - prefix, insert_data=True)[0:2] + collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] res, _ = collection_w.query(term_expr) assert len(res) == 0 @@ -591,11 +549,12 @@ def test_query_expr_inconsistent_mix_term_array(self): """ collection_w = self.init_collection_wrap(cf.gen_unique_str(prefix)) int_values = [[1., 2.], [1, 2.]] - error = {ct.err_code: 1, ct.err_msg: "type mismatch"} + error = {ct.err_code: 65535, + ct.err_msg: "cannot parse expression: int64 in [1.0, 2.0], error: value '1.0' " + "in list cannot be casted to Int64"} for values in int_values: term_expr = f'{ct.default_int64_field_name} in {values}' - collection_w.query( - term_expr, check_task=CheckTasks.err_res, check_items=error) + collection_w.query(term_expr, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_query_expr_non_constant_array_term(self): @@ -604,14 +563,14 @@ def test_query_expr_non_constant_array_term(self): method: query with non-constant array expr expected: raise exception """ - collection_w, 
vectors = self.init_collection_general( - prefix, insert_data=True)[0:2] + collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] constants = [[1], (), {}] - error = {ct.err_code: 1, ct.err_msg: "unsupported leaf node"} + error = {ct.err_code: 65535, + ct.err_msg: "cannot parse expression: int64 in [[1]], error: value '[1]' in " + "list cannot be casted to Int64"} for constant in constants: term_expr = f'{ct.default_int64_field_name} in [{constant}]' - collection_w.query( - term_expr, check_task=CheckTasks.err_res, check_items=error) + collection_w.query(term_expr, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L1) @pytest.mark.parametrize("expr_prefix", ["json_contains", "JSON_CONTAINS"]) @@ -659,8 +618,7 @@ def test_query_expr_list_json_contains(self, expr_prefix): data = { ct.default_int64_field_name: i, ct.default_json_field_name: [str(m) for m in range(i, i + limit)], - ct.default_float_vec_field_name: cf.gen_vectors(1, ct.default_dim)[ - 0] + ct.default_float_vec_field_name: cf.gen_vectors(1, ct.default_dim)[0] } array.append(data) collection_w.insert(array) @@ -680,15 +638,13 @@ def test_query_expr_json_contains_combined_with_normal(self, enable_dynamic_fiel expected: succeed """ # 1. initialize with data - collection_w = self.init_collection_general( - prefix, enable_dynamic_field=enable_dynamic_field)[0] + collection_w = self.init_collection_general(prefix, enable_dynamic_field=enable_dynamic_field)[0] # 2. insert data array = cf.gen_default_rows_data() limit = ct.default_nb // 3 for i in range(ct.default_nb): - array[i][ct.default_json_field_name] = { - "number": i, "list": [m for m in range(i, i + limit)]} + array[i][ct.default_json_field_name] = {"number": i, "list": [m for m in range(i, i + limit)]} collection_w.insert(array) @@ -708,20 +664,21 @@ def test_query_expr_all_datatype_json_contains_all(self, enable_dynamic_field, e expected: succeed """ # 1. 
initialize with data - collection_w = self.init_collection_general( - prefix, enable_dynamic_field=enable_dynamic_field)[0] + collection_w = self.init_collection_general(prefix, enable_dynamic_field=enable_dynamic_field)[0] # 2. insert data array = cf.gen_default_rows_data() limit = 10 for i in range(ct.default_nb): content = { - "listInt": [m for m in range(i, i + limit)], # test for int + # test for int + "listInt": [m for m in range(i, i + limit)], # test for string "listStr": [str(m) for m in range(i, i + limit)], # test for float "listFlt": [m * 1.0 for m in range(i, i + limit)], - "listBool": [bool(i % 2)], # test for bool + # test for bool + "listBool": [bool(i % 2)], # test for list "listList": [[i, str(i + 1)], [i * 1.0, i + 1]], # test for mixed data @@ -779,24 +736,18 @@ def test_query_expr_list_all_datatype_json_contains_all(self, expr_prefix): expected: succeed """ # 1. initialize with data - collection_w = self.init_collection_general( - prefix, enable_dynamic_field=True)[0] + collection_w = self.init_collection_general(prefix, enable_dynamic_field=True)[0] # 2. 
insert data array = cf.gen_default_rows_data(with_json=False) limit = 10 for i in range(ct.default_nb): - array[i]["listInt"] = [m for m in range( - i, i + limit)] # test for int - array[i]["listStr"] = [str(m) for m in range( - i, i + limit)] # test for string - array[i]["listFlt"] = [ - m * 1.0 for m in range(i, i + limit)] # test for float + array[i]["listInt"] = [m for m in range(i, i + limit)] # test for int + array[i]["listStr"] = [str(m) for m in range(i, i + limit)] # test for string + array[i]["listFlt"] = [m * 1.0 for m in range(i, i + limit)] # test for float array[i]["listBool"] = [bool(i % 2)] # test for bool - array[i]["listList"] = [ - [i, str(i + 1)], [i * 1.0, i + 1]] # test for list - array[i]["listMix"] = [i, i * 1.1, - str(i), bool(i % 2), [i, str(i)]] # test for mixed data + array[i]["listList"] = [[i, str(i + 1)], [i * 1.0, i + 1]] # test for list + array[i]["listMix"] = [i, i * 1.1, str(i), bool(i % 2), [i, str(i)]] # test for mixed data collection_w.insert(array) @@ -849,20 +800,21 @@ def test_query_expr_all_datatype_json_contains_any(self, enable_dynamic_field, e expected: succeed """ # 1. initialize with data - collection_w = self.init_collection_general( - prefix, enable_dynamic_field=enable_dynamic_field)[0] + collection_w = self.init_collection_general(prefix, enable_dynamic_field=enable_dynamic_field)[0] # 2. 
insert data array = cf.gen_default_rows_data() limit = 10 for i in range(ct.default_nb): content = { - "listInt": [m for m in range(i, i + limit)], # test for int + # test for int + "listInt": [m for m in range(i, i + limit)], # test for string "listStr": [str(m) for m in range(i, i + limit)], # test for float "listFlt": [m * 1.0 for m in range(i, i + limit)], - "listBool": [bool(i % 2)], # test for bool + # test for bool + "listBool": [bool(i % 2)], # test for list "listList": [[i, str(i + 1)], [i * 1.0, i + 1]], # test for mixed data @@ -921,24 +873,18 @@ def test_query_expr_list_all_datatype_json_contains_any(self, expr_prefix): expected: succeed """ # 1. initialize with data - collection_w = self.init_collection_general( - prefix, enable_dynamic_field=True)[0] + collection_w = self.init_collection_general(prefix, enable_dynamic_field=True)[0] # 2. insert data array = cf.gen_default_rows_data(with_json=False) limit = 10 for i in range(ct.default_nb): - array[i]["listInt"] = [m for m in range( - i, i + limit)] # test for int - array[i]["listStr"] = [str(m) for m in range( - i, i + limit)] # test for string - array[i]["listFlt"] = [ - m * 1.0 for m in range(i, i + limit)] # test for float + array[i]["listInt"] = [m for m in range(i, i + limit)] # test for int + array[i]["listStr"] = [str(m) for m in range(i, i + limit)] # test for string + array[i]["listFlt"] = [m * 1.0 for m in range(i, i + limit)] # test for float array[i]["listBool"] = [bool(i % 2)] # test for bool - array[i]["listList"] = [ - [i, str(i + 1)], [i * 1.0, i + 1]] # test for list - array[i]["listMix"] = [i, i * 1.1, - str(i), bool(i % 2), [i, str(i)]] # test for mixed data + array[i]["listList"] = [[i, str(i + 1)], [i * 1.0, i + 1]] # test for list + array[i]["listMix"] = [i, i * 1.1, str(i), bool(i % 2), [i, str(i)]] # test for mixed data collection_w.insert(array) @@ -991,14 +937,12 @@ def test_query_expr_json_contains_list_in_list(self, expr_prefix, enable_dynamic expected: succeed """ # 1. 
initialize with data - collection_w = self.init_collection_general( - prefix, enable_dynamic_field=enable_dynamic_field)[0] + collection_w = self.init_collection_general(prefix, enable_dynamic_field=enable_dynamic_field)[0] # 2. insert data array = cf.gen_default_rows_data() for i in range(ct.default_nb): - array[i][json_field] = { - "list": [[i, i + 1], [i, i + 2], [i, i + 3]]} + array[i][json_field] = {"list": [[i, i + 1], [i, i + 2], [i, i + 3]]} collection_w.insert(array) @@ -1030,8 +974,7 @@ def test_query_expr_json_contains_invalid_type(self, expr_prefix, enable_dynamic expected: succeed """ # 1. initialize with data - collection_w = self.init_collection_general( - prefix, enable_dynamic_field=enable_dynamic_field)[0] + collection_w = self.init_collection_general(prefix, enable_dynamic_field=enable_dynamic_field)[0] # 2. insert data array = cf.gen_default_rows_data() @@ -1044,10 +987,9 @@ def test_query_expr_json_contains_invalid_type(self, expr_prefix, enable_dynamic # 3. query collection_w.load() expression = f"{expr_prefix}({json_field}['list'], {get_not_list})" - error = {ct.err_code: 1, ct.err_msg: f"cannot parse expression {expression}, error: " - f"error: {expr_prefix} operation element must be an array"} - collection_w.query( - expression, check_task=CheckTasks.err_res, check_items=error) + error = {ct.err_code: 65535, ct.err_msg: f"cannot parse expression: {expression}, " + f"error: contains_any operation element must be an array"} + collection_w.query(expression, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("expr_prefix", ["json_contains", "JSON_CONTAINS"]) @@ -1058,8 +1000,7 @@ def test_query_expr_json_contains_pagination(self, enable_dynamic_field, expr_pr expected: succeed """ # 1. 
initialize with data - collection_w = self.init_collection_general( - prefix, enable_dynamic_field=enable_dynamic_field)[0] + collection_w = self.init_collection_general(prefix, enable_dynamic_field=enable_dynamic_field)[0] # 2. insert data array = cf.gen_default_rows_data() @@ -1088,14 +1029,11 @@ def test_query_expr_empty_without_limit(self): collection_w = self.init_collection_general(prefix, True)[0] # 2. query with no limit and no offset - error = {ct.err_code: 1, - ct.err_msg: "empty expression should be used with limit"} - collection_w.query( - "", check_task=CheckTasks.err_res, check_items=error) + error = {ct.err_code: 1, ct.err_msg: "empty expression should be used with limit"} + collection_w.query("", check_task=CheckTasks.err_res, check_items=error) # 3. query with offset but no limit - collection_w.query( - "", offset=1, check_task=CheckTasks.err_res, check_items=error) + collection_w.query("", offset=1, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_query_empty(self): @@ -1123,15 +1061,13 @@ def test_query_expr_empty(self, auto_id, limit): expected: return topK results by order """ # 1. initialize with data - collection_w, _, _, insert_ids = self.init_collection_general( - prefix, True, auto_id=auto_id)[0:4] + collection_w, _, _, insert_ids = self.init_collection_general(prefix, True, auto_id=auto_id)[0:4] exp_ids, res = insert_ids[:limit], [] for ids in exp_ids: res.append({ct.default_int64_field_name: ids}) # 2. query with limit - collection_w.query( - "", limit=limit, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) + collection_w.query("", limit=limit, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) @pytest.mark.tags(CaseLabel.L2) def test_query_expr_empty_pk_string(self): @@ -1142,11 +1078,9 @@ def test_query_expr_empty_pk_string(self): """ # 1. 
initialize with data collection_w, _, _, insert_ids = \ - self.init_collection_general( - prefix, True, primary_field=ct.default_string_field_name)[0:4] + self.init_collection_general(prefix, True, primary_field=ct.default_string_field_name)[0:4] # string field is sorted by lexicographical order - exp_ids, res = ['0', '1', '10', '100', '1000', - '1001', '1002', '1003', '1004', '1005'], [] + exp_ids, res = ['0', '1', '10', '100', '1000', '1001', '1002', '1003', '1004', '1005'], [] for ids in exp_ids: res.append({ct.default_string_field_name: ids}) @@ -1156,8 +1090,8 @@ def test_query_expr_empty_pk_string(self): # 2. query with limit + offset res = res[5:] - collection_w.query( - "", limit=5, offset=5, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) + collection_w.query("", limit=5, offset=5, + check_task=CheckTasks.check_query_results, check_items={exp_res: res}) @pytest.mark.tags(CaseLabel.L1) @pytest.mark.parametrize("offset", [100, 1000]) @@ -1170,8 +1104,7 @@ def test_query_expr_empty_with_pagination(self, auto_id, limit, offset): expected: return topK results by order """ # 1. initialize with data - collection_w, _, _, insert_ids = self.init_collection_general( - prefix, True, auto_id=auto_id)[0:4] + collection_w, _, _, insert_ids = self.init_collection_general(prefix, True, auto_id=auto_id)[0:4] exp_ids, res = insert_ids[:limit + offset][offset:], [] for ids in exp_ids: res.append({ct.default_int64_field_name: ids}) @@ -1198,15 +1131,13 @@ def test_query_expr_empty_with_random_pk(self, limit, offset): float_value = [np.float32(i) for i in unordered_ids] string_value = [str(i) for i in unordered_ids] vector_value = cf.gen_vectors(nb=ct.default_nb, dim=ct.default_dim) - collection_w.insert([unordered_ids, float_value, - string_value, vector_value]) + collection_w.insert([unordered_ids, float_value, string_value, vector_value]) collection_w.load() # 3. 
query with empty expr and check the result exp_ids, res = sorted(unordered_ids)[:limit], [] for ids in exp_ids: - res.append({ct.default_int64_field_name: ids, - ct.default_string_field_name: str(ids)}) + res.append({ct.default_int64_field_name: ids, ct.default_string_field_name: str(ids)}) collection_w.query("", limit=limit, output_fields=[ct.default_string_field_name], check_task=CheckTasks.check_query_results, check_items={exp_res: res}) @@ -1214,8 +1145,7 @@ def test_query_expr_empty_with_random_pk(self, limit, offset): # 4. query with pagination exp_ids, res = sorted(unordered_ids)[:limit + offset][offset:], [] for ids in exp_ids: - res.append({ct.default_int64_field_name: ids, - ct.default_string_field_name: str(ids)}) + res.append({ct.default_int64_field_name: ids, ct.default_string_field_name: str(ids)}) collection_w.query("", limit=limit, offset=offset, output_fields=[ct.default_string_field_name], check_task=CheckTasks.check_query_results, check_items={exp_res: res}) @@ -1233,14 +1163,11 @@ def test_query_expr_with_limit_offset_out_of_range(self): # 2. query with limit > 16384 error = {ct.err_code: 1, ct.err_msg: "invalid max query result window, (offset+limit) should be in range [1, 16384]"} - collection_w.query( - "", limit=16385, check_task=CheckTasks.err_res, check_items=error) + collection_w.query("", limit=16385, check_task=CheckTasks.err_res, check_items=error) # 3. query with offset + limit > 16384 - collection_w.query("", limit=1, offset=16384, - check_task=CheckTasks.err_res, check_items=error) - collection_w.query("", limit=16384, offset=1, - check_task=CheckTasks.err_res, check_items=error) + collection_w.query("", limit=1, offset=16384, check_task=CheckTasks.err_res, check_items=error) + collection_w.query("", limit=16384, offset=1, check_task=CheckTasks.err_res, check_items=error) # 4. query with limit < 0 error = {ct.err_code: 1, @@ -1257,16 +1184,15 @@ def test_query_expr_out_of_range(self, expression): expected: """ # 1. 
initialize with data - collection_w = self.init_collection_general( - prefix, is_all_data_type=True)[0] + collection_w = self.init_collection_general(prefix, is_all_data_type=True)[0] start = ct.default_nb // 2 _vectors = cf.gen_dataframe_all_data_type(start=start) # increase the value to cover the int range - _vectors["int16"] = pd.Series(data=[np.int16( - i*40) for i in range(start, start + ct.default_nb)], dtype="int16") - _vectors["int32"] = pd.Series(data=[np.int32( - i*2200000) for i in range(start, start + ct.default_nb)], dtype="int32") + _vectors["int16"] = \ + pd.Series(data=[np.int16(i*40) for i in range(start, start + ct.default_nb)], dtype="int16") + _vectors["int32"] = \ + pd.Series(data=[np.int32(i*2200000) for i in range(start, start + ct.default_nb)], dtype="int32") insert_ids = collection_w.insert(_vectors)[0].primary_keys # filter result with expression in collection @@ -1294,8 +1220,7 @@ def test_query_output_field_none_or_empty(self, enable_dynamic_field): collection_w = self.init_collection_general(prefix, insert_data=True, enable_dynamic_field=enable_dynamic_field)[0] for fields in [None, []]: - res, _ = collection_w.query( - default_term_expr, output_fields=fields) + res, _ = collection_w.query(default_term_expr, output_fields=fields) assert res[0].keys() == {ct.default_int64_field_name} @pytest.mark.tags(CaseLabel.L0) @@ -1307,10 +1232,8 @@ def test_query_output_one_field(self, enable_dynamic_field): """ collection_w, vectors = self.init_collection_general(prefix, insert_data=True, enable_dynamic_field=enable_dynamic_field)[0:2] - res, _ = collection_w.query(default_term_expr, output_fields=[ - ct.default_float_field_name]) - assert set(res[0].keys()) == { - ct.default_int64_field_name, ct.default_float_field_name} + res, _ = collection_w.query(default_term_expr, output_fields=[ct.default_float_field_name]) + assert set(res[0].keys()) == {ct.default_int64_field_name, ct.default_float_field_name} @pytest.mark.tags(CaseLabel.L1) def 
test_query_output_all_fields(self, enable_dynamic_field, random_primary_key): @@ -1333,8 +1256,7 @@ def test_query_output_all_fields(self, enable_dynamic_field, random_primary_key) else: res = [] for id in range(2): - num = df[0][df[0][ct.default_int64_field_name] == id].index.to_list()[ - 0] + num = df[0][df[0][ct.default_int64_field_name] == id].index.to_list()[0] res.append(df[0].iloc[num].to_dict()) log.info(res) collection_w.load() @@ -1350,17 +1272,14 @@ def test_query_output_float_vec_field(self): method: specify vec field as output field expected: return primary field and vec field """ - collection_w = self.init_collection_wrap( - name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) df = cf.gen_default_dataframe_data() collection_w.insert(df) assert collection_w.num_entities == ct.default_nb - fields = [[ct.default_float_vec_field_name], [ - ct.default_int64_field_name, ct.default_float_vec_field_name]] - res = df.loc[:1, [ct.default_int64_field_name, - ct.default_float_vec_field_name]].to_dict('records') - collection_w.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + fields = [[ct.default_float_vec_field_name], + [ct.default_int64_field_name, ct.default_float_vec_field_name]] + res = df.loc[:1, [ct.default_int64_field_name, ct.default_float_vec_field_name]].to_dict('records') + collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() for output_fields in fields: collection_w.query(default_term_expr, output_fields=output_fields, @@ -1376,20 +1295,16 @@ def test_query_output_field_wildcard(self, wildcard_output_fields): method: query with one output_field (wildcard) expected: query success """ - collection_w = self.init_collection_wrap( - name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) df = cf.gen_default_dataframe_data() collection_w.insert(df) assert 
collection_w.num_entities == ct.default_nb - output_fields = cf.get_wildcard_output_field_names( - collection_w, wildcard_output_fields) + output_fields = cf.get_wildcard_output_field_names(collection_w, wildcard_output_fields) output_fields.append(default_int_field_name) - collection_w.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() with_vec = True if ct.default_float_vec_field_name in output_fields else False - actual_res = collection_w.query( - default_term_expr, output_fields=wildcard_output_fields)[0] + actual_res = collection_w.query(default_term_expr, output_fields=wildcard_output_fields)[0] assert set(actual_res[0].keys()) == set(output_fields) @pytest.mark.tags(CaseLabel.L1) @@ -1404,20 +1319,17 @@ def test_query_output_multi_float_vec_field(self, vec_fields): """ # init collection with two float vector fields schema = cf.gen_schema_multi_vector_fields(vec_fields) - collection_w = self.init_collection_wrap( - name=cf.gen_unique_str(prefix), schema=schema) + collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix), schema=schema) df = cf.gen_dataframe_multi_vec_fields(vec_fields=vec_fields) collection_w.insert(df) assert collection_w.num_entities == ct.default_nb # query with two vec output_fields - output_fields = [ct.default_int64_field_name, - ct.default_float_vec_field_name] + output_fields = [ct.default_int64_field_name, ct.default_float_vec_field_name] for vec_field in vec_fields: output_fields.append(vec_field.name) res = df.loc[:1, output_fields].to_dict('records') - collection_w.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() collection_w.query(default_term_expr, output_fields=output_fields, 
check_task=CheckTasks.check_query_results, @@ -1436,20 +1348,17 @@ def test_query_output_mix_float_binary_field(self, vec_fields): """ # init collection with two float vector fields schema = cf.gen_schema_multi_vector_fields(vec_fields) - collection_w = self.init_collection_wrap( - name=cf.gen_unique_str(prefix), schema=schema) + collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix), schema=schema) df = cf.gen_dataframe_multi_vec_fields(vec_fields=vec_fields) collection_w.insert(df) assert collection_w.num_entities == ct.default_nb # query with two vec output_fields - output_fields = [ct.default_int64_field_name, - ct.default_float_vec_field_name] + output_fields = [ct.default_int64_field_name, ct.default_float_vec_field_name] for vec_field in vec_fields: output_fields.append(vec_field.name) res = df.loc[:1, output_fields].to_dict('records') - collection_w.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() collection_w.query(default_term_expr, output_fields=output_fields, check_task=CheckTasks.check_query_results, @@ -1467,13 +1376,11 @@ def test_query_output_binary_vec_field(self): method: specify binary vec field as output field expected: return primary field and binary vec field """ - collection_w, vectors = self.init_collection_general( - prefix, insert_data=True, is_binary=True)[0:2] - fields = [[ct.default_binary_vec_field_name], [ - ct.default_int64_field_name, ct.default_binary_vec_field_name]] + collection_w, vectors = self.init_collection_general(prefix, insert_data=True, is_binary=True)[0:2] + fields = [[ct.default_binary_vec_field_name], + [ct.default_int64_field_name, ct.default_binary_vec_field_name]] for output_fields in fields: - res, _ = collection_w.query( - default_term_expr, output_fields=output_fields) + res, _ = collection_w.query(default_term_expr, output_fields=output_fields) 
assert res[0].keys() == set(fields[-1]) @pytest.mark.tags(CaseLabel.L1) @@ -1483,10 +1390,8 @@ def test_query_output_primary_field(self): method: specify int64 primary field as output field expected: return int64 field """ - collection_w, vectors = self.init_collection_general( - prefix, insert_data=True)[0:2] - res, _ = collection_w.query(default_term_expr, output_fields=[ - ct.default_int64_field_name]) + collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] + res, _ = collection_w.query(default_term_expr, output_fields=[ct.default_int64_field_name]) assert res[0].keys() == {ct.default_int64_field_name} @pytest.mark.tags(CaseLabel.L2) @@ -1496,13 +1401,12 @@ def test_query_output_not_existed_field(self): method: query with not existed output field expected: raise exception """ - collection_w, vectors = self.init_collection_general( - prefix, insert_data=True)[0:2] - error = {ct.err_code: 1, ct.err_msg: 'Field int not exist'} + collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] + error = {ct.err_code: 65535, ct.err_msg: 'field int not exist'} output_fields = [["int"], [ct.default_int64_field_name, "int"]] for fields in output_fields: - collection_w.query(default_term_expr, output_fields=fields, check_task=CheckTasks.err_res, - check_items=error) + collection_w.query(default_term_expr, output_fields=fields, + check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) @pytest.mark.xfail(reason="exception not MilvusException") @@ -1512,11 +1416,9 @@ def test_query_invalid_output_fields(self): method: query with invalid field fields expected: raise exception """ - collection_w, vectors = self.init_collection_general( - prefix, insert_data=True)[0:2] + collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] output_fields = ["12-s", 1, [1, "2", 3], (1,), {1: 1}] - error = {ct.err_code: 0, - ct.err_msg: f'Invalid query format. 
\'output_fields\' must be a list'} + error = {ct.err_code: 0, ct.err_msg: f'Invalid query format. \'output_fields\' must be a list'} for fields in output_fields: collection_w.query(default_term_expr, output_fields=fields, check_task=CheckTasks.err_res, check_items=error) @@ -1531,8 +1433,7 @@ def test_query_output_fields_simple_wildcard(self): """ # init collection with fields: int64, float, float_vec, float_vector1 # collection_w, df = self.init_multi_fields_collection_wrap(cf.gen_unique_str(prefix)) - collection_w, vectors = self.init_collection_general( - prefix, insert_data=True)[0:2] + collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] df = vectors[0] # query with wildcard all fields @@ -1550,14 +1451,12 @@ def test_query_output_fields_part_scale_wildcard(self): expected: verify query result """ # init collection with fields: int64, float, float_vec - collection_w, vectors = self.init_collection_general( - prefix, insert_data=True, is_index=False)[0:2] + collection_w, vectors = self.init_collection_general(prefix, insert_data=True, is_index=False)[0:2] df = vectors[0] # query with output_fields=["*", float_vector) res = df.iloc[:2].to_dict('records') - collection_w.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() collection_w.query(default_term_expr, output_fields=["*", ct.default_float_vec_field_name], check_task=CheckTasks.check_query_results, @@ -1572,13 +1471,11 @@ def test_query_invalid_wildcard(self, output_fields): expected: raise exception """ # init collection with fields: int64, float, float_vec - collection_w = self.init_collection_general( - prefix, insert_data=True, nb=100)[0] + collection_w = self.init_collection_general(prefix, insert_data=True, nb=100)[0] collection_w.load() # query with invalid output_fields - error = {ct.err_code: 1, - ct.err_msg: f"Field 
{output_fields[-1]} not exist"} + error = {ct.err_code: 65535, ct.err_msg: f"field {output_fields[-1]} not exist"} collection_w.query(default_term_expr, output_fields=output_fields, check_task=CheckTasks.err_res, check_items=error) @@ -1589,14 +1486,12 @@ def test_query_partition(self): method: create a partition and query expected: verify query result """ - collection_w = self.init_collection_wrap( - name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) partition_w = self.init_partition_wrap(collection_wrap=collection_w) df = cf.gen_default_dataframe_data() partition_w.insert(df) assert collection_w.num_entities == ct.default_nb - collection_w.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) partition_w.load() res = df.iloc[:2, :1].to_dict('records') collection_w.query(default_term_expr, partition_names=[partition_w.name], @@ -1609,14 +1504,12 @@ def test_query_partition_without_loading(self): method: query on partition and no loading expected: raise exception """ - collection_w = self.init_collection_wrap( - name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) partition_w = self.init_partition_wrap(collection_wrap=collection_w) df = cf.gen_default_dataframe_data() partition_w.insert(df) assert partition_w.num_entities == ct.default_nb - error = {ct.err_code: 1, - ct.err_msg: f'collection {collection_w.name} was not loaded into memory'} + error = {ct.err_code: 65535, ct.err_msg: "collection not loaded"} collection_w.query(default_term_expr, partition_names=[partition_w.name], check_task=CheckTasks.err_res, check_items=error) @@ -1627,8 +1520,7 @@ def test_query_default_partition(self): method: query on default partition expected: verify query result """ - collection_w, vectors = self.init_collection_general( - prefix, 
insert_data=True)[0:2] + collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] res = vectors[0].iloc[:2, :1].to_dict('records') collection_w.query(default_term_expr, partition_names=[ct.default_partition_name], check_task=CheckTasks.check_query_results, check_items={exp_res: res}) @@ -1642,8 +1534,7 @@ def test_query_empty_partition_names(self): """ # insert [0, half) into partition_w, [half, nb) into _default half = ct.default_nb // 2 - collection_w, partition_w, _, _ = self.insert_entities_into_two_partitions_in_half( - half) + collection_w, partition_w, _, _ = self.insert_entities_into_two_partitions_in_half(half) # query from empty partition_names term_expr = f'{ct.default_int64_field_name} in [0, {half}, {ct.default_nb}-1]' @@ -1658,15 +1549,12 @@ def test_query_empty_partition(self): method: query on an empty collection expected: empty query result """ - collection_w = self.init_collection_wrap( - name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) partition_w = self.init_partition_wrap(collection_wrap=collection_w) assert partition_w.is_empty - collection_w.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) partition_w.load() - res, _ = collection_w.query( - default_term_expr, partition_names=[partition_w.name]) + res, _ = collection_w.query(default_term_expr, partition_names=[partition_w.name]) assert len(res) == 0 @pytest.mark.tags(CaseLabel.L2) @@ -1677,12 +1565,10 @@ def test_query_not_existed_partition(self): expected: raise exception """ collection_w = self.init_collection_wrap(cf.gen_unique_str(prefix)) - collection_w.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() partition_names = 
cf.gen_unique_str() - error = {ct.err_code: 1, - ct.err_msg: f'PartitionName: {partition_names} not found'} + error = {ct.err_code: 65535, ct.err_msg: f'partition name {partition_names} not found'} collection_w.query(default_term_expr, partition_names=[partition_names], check_task=CheckTasks.err_res, check_items=error) @@ -1767,8 +1653,7 @@ def test_query_pagination(self, offset): expected: query successfully and verify query result """ # create collection, insert default_nb, load collection - collection_w, vectors = self.init_collection_general( - prefix, insert_data=True)[0:2] + collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] int_values = vectors[0][ct.default_int64_field_name].values.tolist() pos = 10 term_expr = f'{ct.default_int64_field_name} in {int_values[offset: pos + offset]}' @@ -1811,8 +1696,7 @@ def test_query_pagination_with_expression(self, offset, get_normal_expr): """ # 1. initialize with data nb = 1000 - collection_w, _vectors, _, insert_ids = self.init_collection_general( - prefix, True, nb)[0:4] + collection_w, _vectors, _, insert_ids = self.init_collection_general(prefix, True, nb)[0:4] # filter result with expression in collection _vectors = _vectors[0] @@ -1838,14 +1722,12 @@ def test_query_pagination_with_partition(self, offset): method: create a partition and query with different offset expected: verify query result """ - collection_w = self.init_collection_wrap( - name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) partition_w = self.init_partition_wrap(collection_wrap=collection_w) df = cf.gen_default_dataframe_data() partition_w.insert(df) assert collection_w.num_entities == ct.default_nb - collection_w.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) partition_w.load() res = df.iloc[:2, :1].to_dict('records') 
query_params = {"offset": offset, "limit": 10} @@ -1859,13 +1741,11 @@ def test_query_pagination_with_insert_data(self, offset): method: create a partition and query with pagination expected: verify query result """ - collection_w = self.init_collection_wrap( - name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) df = cf.gen_default_dataframe_data() collection_w.insert(df) assert collection_w.num_entities == ct.default_nb - collection_w.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() res = df.iloc[:2, :1].to_dict('records') query_params = {"offset": offset, "limit": 10} @@ -1880,8 +1760,7 @@ def test_query_pagination_without_limit(self, offset): compare the result with query without pagination params expected: query successfully """ - collection_w, vectors = self.init_collection_general( - prefix, insert_data=True)[0:2] + collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] int_values = vectors[0][ct.default_int64_field_name].values.tolist() pos = 10 term_expr = f'{ct.default_int64_field_name} in {int_values[offset: pos + offset]}' @@ -1904,8 +1783,7 @@ def test_query_pagination_with_offset_over_num_entities(self, offset): expected: return an empty list """ # create collection, insert default_nb, load collection - collection_w, vectors = self.init_collection_general( - prefix, insert_data=True)[0:2] + collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] int_values = vectors[0][ct.default_int64_field_name].values.tolist() pos = 10 term_expr = f'{ct.default_int64_field_name} in {int_values[10: pos + 10]}' @@ -1921,8 +1799,7 @@ def test_query_pagination_with_invalid_limit_type(self, limit): expected: raise exception """ # create collection, insert default_nb, load collection - collection_w, vectors = 
self.init_collection_general( - prefix, insert_data=True)[0:2] + collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] int_values = vectors[0][ct.default_int64_field_name].values.tolist() pos = 10 term_expr = f'{ct.default_int64_field_name} in {int_values[10: pos + 10]}' @@ -1940,16 +1817,15 @@ def test_query_pagination_with_invalid_limit_value(self, limit): expected: raise exception """ # create collection, insert default_nb, load collection - collection_w, vectors = self.init_collection_general( - prefix, insert_data=True)[0:2] + collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] int_values = vectors[0][ct.default_int64_field_name].values.tolist() pos = 10 term_expr = f'{ct.default_int64_field_name} in {int_values[10: pos + 10]}' collection_w.query(term_expr, offset=10, limit=limit, check_task=CheckTasks.err_res, - check_items={ct.err_code: 1, - ct.err_msg: "limit [%s] is invalid, should be in range " - "[1, 16384], but got %s" % (limit, limit)}) + check_items={ct.err_code: 65535, + ct.err_msg: f"invalid max query result window, (offset+limit) " + f"should be in range [1, 16384], but got {limit}"}) @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("offset", ["12 s", " ", [0, 1], {2}]) @@ -1960,8 +1836,7 @@ def test_query_pagination_with_invalid_offset_type(self, offset): expected: raise exception """ # create collection, insert default_nb, load collection - collection_w, vectors = self.init_collection_general( - prefix, insert_data=True)[0:2] + collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] int_values = vectors[0][ct.default_int64_field_name].values.tolist() pos = 10 term_expr = f'{ct.default_int64_field_name} in {int_values[10: pos + 10]}' @@ -1979,16 +1854,15 @@ def test_query_pagination_with_invalid_offset_value(self, offset): expected: raise exception """ # create collection, insert default_nb, load collection - collection_w, vectors = 
self.init_collection_general( - prefix, insert_data=True)[0:2] + collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] int_values = vectors[0][ct.default_int64_field_name].values.tolist() pos = 10 term_expr = f'{ct.default_int64_field_name} in {int_values[10: pos + 10]}' collection_w.query(term_expr, offset=offset, limit=10, check_task=CheckTasks.err_res, - check_items={ct.err_code: 1, - ct.err_msg: "offset [%s] is invalid, should be in range " - "[1, 16384], but got %s" % (offset, offset)}) + check_items={ct.err_code: 65535, + ct.err_msg: f"invalid max query result window, (offset+limit) " + f"should be in range [1, 16384], but got {offset}"}) @pytest.mark.tags(CaseLabel.L2) def test_query_during_upsert(self): @@ -2002,8 +1876,7 @@ def test_query_during_upsert(self): upsert_nb = 1000 expr = f"int64 >= 0 && int64 <= {upsert_nb}" collection_w = self.init_collection_general(prefix, True)[0] - res1 = collection_w.query( - expr, output_fields=[default_float_field_name])[0] + res1 = collection_w.query(expr, output_fields=[default_float_field_name])[0] def do_upsert(): data = cf.gen_default_data_for_upsert(upsert_nb)[0] @@ -2011,8 +1884,7 @@ def do_upsert(): t = threading.Thread(target=do_upsert, args=()) t.start() - res2 = collection_w.query( - expr, output_fields=[default_float_field_name])[0] + res2 = collection_w.query(expr, output_fields=[default_float_field_name])[0] t.join() assert [res1[i][default_float_field_name] for i in range(upsert_nb)] == \ [res2[i][default_float_field_name] for i in range(upsert_nb)] @@ -2034,16 +1906,13 @@ def test_query_without_connection(self): """ # init a collection with default connection - collection_w = self.init_collection_wrap( - name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) # remove default connection - self.connection_wrap.remove_connection( - alias=DefaultConfig.DEFAULT_USING) + 
self.connection_wrap.remove_connection(alias=DefaultConfig.DEFAULT_USING) # list connection to check - self.connection_wrap.list_connections( - check_task=ct.CheckTasks.ccr, check_items={ct.list_content: []}) + self.connection_wrap.list_connections(check_task=ct.CheckTasks.ccr, check_items={ct.list_content: []}) # query after remove default connection collection_w.query(default_term_expr, check_task=CheckTasks.err_res, @@ -2069,7 +1938,8 @@ def test_query_without_loading(self): # query without load collection_w.query(default_term_expr, check_task=CheckTasks.err_res, - check_items={ct.err_code: 1, ct.err_msg: clem.CollNotLoaded % collection_name}) + check_items={ct.err_code: 65535, + ct.err_msg: "collection not loaded"}) @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("term_expr", [f'{ct.default_int64_field_name} in [0]']) @@ -2081,13 +1951,12 @@ def test_query_expr_single_term_array(self, term_expr): """ # init a collection and insert data - collection_w, vectors, binary_raw_vectors = self.init_collection_general( - prefix, insert_data=True)[0:3] + collection_w, vectors, binary_raw_vectors = self.init_collection_general(prefix, insert_data=True)[0:3] # query the first row of data check_vec = vectors[0].iloc[:, [0]][0:1].to_dict('records') - collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={ - exp_res: check_vec}) + collection_w.query(term_expr, + check_task=CheckTasks.check_query_results, check_items={exp_res: check_vec}) @pytest.mark.tags(CaseLabel.L1) @pytest.mark.parametrize("term_expr", [f'{ct.default_int64_field_name} in [0]']) @@ -2104,8 +1973,8 @@ def test_query_binary_expr_single_term_array(self, term_expr, check_content): # query the first row of data check_vec = vectors[0].iloc[:, [0]][0:1].to_dict('records') - collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={ - exp_res: check_vec}) + collection_w.query(term_expr, + check_task=CheckTasks.check_query_results, 
check_items={exp_res: check_vec}) @pytest.mark.tags(CaseLabel.L2) def test_query_expr_all_term_array(self): @@ -2116,18 +1985,17 @@ def test_query_expr_all_term_array(self): """ # init a collection and insert data - collection_w, vectors, binary_raw_vectors = self.init_collection_general( - prefix, insert_data=True)[0:3] + collection_w, vectors, binary_raw_vectors = \ + self.init_collection_general(prefix, insert_data=True)[0:3] # data preparation int_values = vectors[0][ct.default_int64_field_name].values.tolist() term_expr = f'{ct.default_int64_field_name} in {int_values}' - check_vec = vectors[0].iloc[:, [0]][0:len( - int_values)].to_dict('records') + check_vec = vectors[0].iloc[:, [0]][0:len(int_values)].to_dict('records') # query all array value - collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={ - exp_res: check_vec}) + collection_w.query(term_expr, + check_task=CheckTasks.check_query_results, check_items={exp_res: check_vec}) @pytest.mark.tags(CaseLabel.L1) def test_query_expr_half_term_array(self): @@ -2138,8 +2006,8 @@ def test_query_expr_half_term_array(self): """ half = ct.default_nb // 2 - collection_w, partition_w, df_partition, df_default = self.insert_entities_into_two_partitions_in_half( - half) + collection_w, partition_w, df_partition, df_default = \ + self.insert_entities_into_two_partitions_in_half(half) int_values = df_default[ct.default_int64_field_name].values.tolist() term_expr = f'{ct.default_int64_field_name} in {int_values}' @@ -2153,8 +2021,7 @@ def test_query_expr_repeated_term_array(self): method: query with repeated array value expected: return hit entities, no repeated """ - collection_w, vectors, binary_raw_vectors = self.init_collection_general( - prefix, insert_data=True)[0:3] + collection_w, vectors, binary_raw_vectors = self.init_collection_general(prefix, insert_data=True)[0:3] int_values = [0, 0, 0, 0] term_expr = f'{ct.default_int64_field_name} in {int_values}' res, _ = 
collection_w.query(term_expr) @@ -2169,15 +2036,12 @@ def test_query_dup_ids_dup_term_array(self): 2.query with dup term array expected: todo """ - collection_w = self.init_collection_wrap( - name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) df = cf.gen_default_dataframe_data(nb=100) df[ct.default_int64_field_name] = 0 mutation_res, _ = collection_w.insert(df) - assert mutation_res.primary_keys == df[ct.default_int64_field_name].tolist( - ) - collection_w.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + assert mutation_res.primary_keys == df[ct.default_int64_field_name].tolist() + collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() term_expr = f'{ct.default_int64_field_name} in {[0, 0, 0]}' res = df.iloc[:, :2].to_dict('records') @@ -2203,10 +2067,9 @@ def test_query_after_index(self): int_values = [0] term_expr = f'{ct.default_int64_field_name} in {int_values}' - check_vec = vectors[0].iloc[:, [0]][0:len( - int_values)].to_dict('records') - collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={ - exp_res: check_vec}) + check_vec = vectors[0].iloc[:, [0]][0:len(int_values)].to_dict('records') + collection_w.query(term_expr, + check_task=CheckTasks.check_query_results, check_items={exp_res: check_vec}) @pytest.mark.tags(CaseLabel.L1) def test_query_after_search(self): @@ -2223,8 +2086,7 @@ def test_query_after_search(self): self.init_collection_general(prefix, True, nb_old)[0:4] # 2. 
search for original data after load - vectors_s = [[random.random() for _ in range(ct.default_dim)] - for _ in range(ct.default_nq)] + vectors_s = [[random.random() for _ in range(ct.default_dim)] for _ in range(ct.default_nq)] collection_w.search(vectors_s[:ct.default_nq], ct.default_float_vec_field_name, ct.default_search_params, limit, "int64 >= 0", check_task=CheckTasks.check_search_results, @@ -2235,8 +2097,8 @@ def test_query_after_search(self): term_expr = f'{ct.default_int64_field_name} in [0, 1]' check_vec = vectors[0].iloc[:, [0]][0:2].to_dict('records') - collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={ - exp_res: check_vec}) + collection_w.query(term_expr, + check_task=CheckTasks.check_query_results, check_items={exp_res: check_vec}) @pytest.mark.tags(CaseLabel.L1) def test_query_output_vec_field_after_index(self): @@ -2245,22 +2107,18 @@ def test_query_output_vec_field_after_index(self): method: create index and specify vec field as output field expected: return primary field and vec field """ - collection_w = self.init_collection_wrap( - name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) df = cf.gen_default_dataframe_data(nb=5000) collection_w.insert(df) assert collection_w.num_entities == 5000 fields = [ct.default_int64_field_name, ct.default_float_vec_field_name] - collection_w.create_index( - ct.default_float_vec_field_name, default_index_params) + collection_w.create_index(ct.default_float_vec_field_name, default_index_params) assert collection_w.has_index()[0] - res = df.loc[:1, [ct.default_int64_field_name, - ct.default_float_vec_field_name]].to_dict('records') + res = df.loc[:1, [ct.default_int64_field_name, ct.default_float_vec_field_name]].to_dict('records') collection_w.load() - error = {ct.err_code: 1, ct.err_msg: 'not allowed'} collection_w.query(default_term_expr, output_fields=fields, - check_task=CheckTasks.err_res, - check_items=error) + 
check_task=CheckTasks.check_query_results, + check_items={exp_res: res, "with_vec": True}) @pytest.mark.tags(CaseLabel.L1) def test_query_output_binary_vec_field_after_index(self): @@ -2269,16 +2127,13 @@ def test_query_output_binary_vec_field_after_index(self): method: create index and specify vec field as output field expected: return primary field and vec field """ - collection_w, vectors = self.init_collection_general(prefix, insert_data=True, is_binary=True, is_index=False)[ - 0:2] - fields = [ct.default_int64_field_name, - ct.default_binary_vec_field_name] - collection_w.create_index( - ct.default_binary_vec_field_name, binary_index_params) + collection_w, vectors = self.init_collection_general(prefix, insert_data=True, + is_binary=True, is_index=False)[0:2] + fields = [ct.default_int64_field_name, ct.default_binary_vec_field_name] + collection_w.create_index(ct.default_binary_vec_field_name, binary_index_params) assert collection_w.has_index()[0] collection_w.load() - res, _ = collection_w.query(default_term_expr, output_fields=[ - ct.default_binary_vec_field_name]) + res, _ = collection_w.query(default_term_expr, output_fields=[ct.default_binary_vec_field_name]) assert res[0].keys() == set(fields) @pytest.mark.tags(CaseLabel.L2) @@ -2293,8 +2148,7 @@ def test_query_partition_repeatedly(self): self._connect() # init collection - collection_w = self.init_collection_wrap( - name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) # init partition partition_w = self.init_partition_wrap(collection_wrap=collection_w) @@ -2307,15 +2161,12 @@ def test_query_partition_repeatedly(self): assert collection_w.num_entities == ct.default_nb # load partition - collection_w.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) partition_w.load() # query twice - res_one, _ = collection_w.query( - 
default_term_expr, partition_names=[partition_w.name]) - res_two, _ = collection_w.query( - default_term_expr, partition_names=[partition_w.name]) + res_one, _ = collection_w.query(default_term_expr, partition_names=[partition_w.name]) + res_two, _ = collection_w.query(default_term_expr, partition_names=[partition_w.name]) assert res_one == res_two @pytest.mark.tags(CaseLabel.L2) @@ -2327,13 +2178,12 @@ def test_query_another_partition(self): expected: query result is empty """ half = ct.default_nb // 2 - collection_w, partition_w, _, _ = self.insert_entities_into_two_partitions_in_half( - half) + collection_w, partition_w, _, _ = self.insert_entities_into_two_partitions_in_half(half) term_expr = f'{ct.default_int64_field_name} in [{half}]' # half entity in _default partition rather than partition_w - collection_w.query(term_expr, partition_names=[partition_w.name], check_task=CheckTasks.check_query_results, - check_items={exp_res: []}) + collection_w.query(term_expr, partition_names=[partition_w.name], + check_task=CheckTasks.check_query_results, check_items={exp_res: []}) @pytest.mark.tags(CaseLabel.L1) def test_query_multi_partitions_multi_results(self): @@ -2344,13 +2194,12 @@ def test_query_multi_partitions_multi_results(self): expected: query results from two partitions """ half = ct.default_nb // 2 - collection_w, partition_w, _, _ = self.insert_entities_into_two_partitions_in_half( - half) + collection_w, partition_w, _, _ = self.insert_entities_into_two_partitions_in_half(half) term_expr = f'{ct.default_int64_field_name} in [{half - 1}, {half}]' # half entity in _default, half-1 entity in partition_w - res, _ = collection_w.query(term_expr, partition_names=[ - ct.default_partition_name, partition_w.name]) + res, _ = collection_w.query(term_expr, + partition_names=[ct.default_partition_name, partition_w.name]) assert len(res) == 2 @pytest.mark.tags(CaseLabel.L2) @@ -2362,13 +2211,13 @@ def test_query_multi_partitions_single_result(self): expected: query 
from two partitions and get single result """ half = ct.default_nb // 2 - collection_w, partition_w, df_partition, df_default = self.insert_entities_into_two_partitions_in_half( - half) + collection_w, partition_w, df_partition, df_default = \ + self.insert_entities_into_two_partitions_in_half(half) term_expr = f'{ct.default_int64_field_name} in [{half}]' # half entity in _default - res, _ = collection_w.query(term_expr, partition_names=[ - ct.default_partition_name, partition_w.name]) + res, _ = collection_w.query(term_expr, + partition_names=[ct.default_partition_name, partition_w.name]) assert len(res) == 1 assert res[0][ct.default_int64_field_name] == half @@ -2382,11 +2231,9 @@ def test_query_growing_segment_data(self): 4.query expected: Data can be queried """ - collection_w = self.init_collection_wrap( - name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) # load collection - collection_w.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() tmp_nb = 100 df = cf.gen_default_dataframe_data(tmp_nb) @@ -2456,15 +2303,14 @@ def test_query_string_is_not_primary(self): expected: query successfully """ - collection_w, vectors = self.init_collection_general( - prefix, insert_data=True)[0:2] + collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] res = vectors[0].iloc[:2, :3].to_dict('records') output_fields = [default_float_field_name, default_string_field_name] collection_w.query(default_string_term_expr, output_fields=output_fields, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) @pytest.mark.tags(CaseLabel.L1) - @pytest.mark.parametrize("expression", cf.gen_normal_string_expressions(default_string_field_name)) + @pytest.mark.parametrize("expression", cf.gen_normal_string_expressions([default_string_field_name])) 
def test_query_string_is_primary(self, expression): """ target: test query with output field only primary field @@ -2473,8 +2319,7 @@ def test_query_string_is_primary(self, expression): """ collection_w, vectors = self.init_collection_general(prefix, insert_data=True, primary_field=ct.default_string_field_name)[0:2] - res, _ = collection_w.query(expression, output_fields=[ - ct.default_string_field_name]) + res, _ = collection_w.query(expression, output_fields=[ct.default_string_field_name]) assert res[0].keys() == {ct.default_string_field_name} @pytest.mark.tags(CaseLabel.L1) @@ -2501,10 +2346,11 @@ def test_query_with_invalid_string_expr(self, expression): query with invalid expr expected: Raise exception """ - collection_w = self.init_collection_general( - prefix, insert_data=True)[0] + collection_w = self.init_collection_general(prefix, insert_data=True)[0] collection_w.query(expression, check_task=CheckTasks.err_res, - check_items={ct.err_code: 1, ct.err_msg: "type mismatch"}) + check_items={ct.err_code: 65535, + ct.err_msg: f"cannot parse expression: {expression}, error: value " + f"'0' in list cannot be casted to VarChar"}) @pytest.mark.tags(CaseLabel.L1) def test_query_string_expr_with_binary(self): @@ -2513,14 +2359,12 @@ def test_query_string_expr_with_binary(self): method: query string expr with binary expected: verify query successfully """ - collection_w, vectors = self.init_collection_general(prefix, insert_data=True, is_binary=True, is_index=False)[ - 0:2] - collection_w.create_index( - ct.default_binary_vec_field_name, binary_index_params) + collection_w, vectors = self.init_collection_general(prefix, insert_data=True, + is_binary=True, is_index=False)[0:2] + collection_w.create_index(ct.default_binary_vec_field_name, binary_index_params) collection_w.load() assert collection_w.has_index()[0] - res, _ = collection_w.query(default_string_term_expr, output_fields=[ - ct.default_binary_vec_field_name]) + res, _ = 
collection_w.query(default_string_term_expr, output_fields=[ct.default_binary_vec_field_name]) assert len(res) == 2 @pytest.mark.tags(CaseLabel.L1) @@ -2534,8 +2378,7 @@ def test_query_string_expr_with_prefixes(self): primary_field=ct.default_string_field_name)[0:2] res = vectors[0].iloc[:1, :3].to_dict('records') expression = 'varchar like "0%"' - output_fields = [default_int_field_name, - default_float_field_name, default_string_field_name] + output_fields = [default_int_field_name, default_float_field_name, default_string_field_name] collection_w.query(expression, output_fields=output_fields, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) @@ -2546,13 +2389,13 @@ def test_query_string_with_invalid_prefix_expr(self): method: specify string primary field, use invalid prefix string expr expected: raise error """ - collection_w = self.init_collection_general( - prefix, insert_data=True)[0] + collection_w = self.init_collection_general(prefix, insert_data=True)[0] expression = 'float like "0%"' - collection_w.query(expression, check_task=CheckTasks.err_res, - check_items={ - ct.err_code: 1, ct.err_msg: "like operation on non-string field is unsupported"} - ) + collection_w.query(expression, + check_task=CheckTasks.err_res, + check_items={ct.err_code: 65535, + ct.err_msg: f"cannot parse expression: {expression}, error: like " + f"operation on non-string or no-json field is unsupported"}) @pytest.mark.tags(CaseLabel.L1) def test_query_compare_two_fields(self): @@ -2561,13 +2404,11 @@ def test_query_compare_two_fields(self): method: specify string primary field, compare two fields expected: verify query successfully """ - collection_w = \ - self.init_collection_general( - prefix, insert_data=True, primary_field=ct.default_string_field_name)[0] + collection_w = self.init_collection_general(prefix, insert_data=True, + primary_field=ct.default_string_field_name)[0] res = [] expression = 'float > int64' - output_fields = [default_int_field_name, - 
default_float_field_name, default_string_field_name] + output_fields = [default_int_field_name, default_float_field_name, default_string_field_name] collection_w.query(expression, output_fields=output_fields, check_task=CheckTasks.check_query_results, check_items={exp_res: res}) @@ -2578,12 +2419,13 @@ def test_query_compare_invalid_fields(self): method: specify string primary field, compare string and int field expected: raise error """ - collection_w = \ - self.init_collection_general( - prefix, insert_data=True, primary_field=ct.default_string_field_name)[0] + collection_w = self.init_collection_general(prefix, insert_data=True, + primary_field=ct.default_string_field_name)[0] expression = 'varchar == int64' collection_w.query(expression, check_task=CheckTasks.err_res, - check_items={ct.err_code: 1, ct.err_msg: f' cannot parse expression:{expression}'}) + check_items={ct.err_code: 65535, ct.err_msg: + f"cannot parse expression: {expression}, error: comparisons between VarChar, " + f"element_type: None and Int64 elementType: None are not supported"}) @pytest.mark.tags(CaseLabel.L1) @pytest.mark.xfail(reason="issue 24637") @@ -2593,8 +2435,7 @@ def test_query_after_insert_multi_threading(self): method: multi threads insert, and query, compare queried data with original expected: verify data consistency """ - collection_w = self.init_collection_wrap( - name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) thread_num = 4 threads = [] primary_keys = [] @@ -2602,8 +2443,7 @@ def test_query_after_insert_multi_threading(self): # prepare original data for parallel insert for i in range(thread_num): - df = cf.gen_default_dataframe_data( - ct.default_nb, start=i * ct.default_nb) + df = cf.gen_default_dataframe_data(ct.default_nb, start=i * ct.default_nb) df_list.append(df) primary_key = df[ct.default_int64_field_name].values.tolist() primary_keys.append(primary_key) @@ -2623,8 +2463,7 @@ def insert(thread_i): assert 
collection_w.num_entities == ct.default_nb * thread_num # Check data consistency after parallel insert - collection_w.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() df_dict_list = [] for df in df_list: @@ -2648,10 +2487,8 @@ def test_query_string_field_pk_is_empty(self): """ # 1. create a collection schema = cf.gen_string_pk_default_collection_schema() - collection_w = self.init_collection_wrap( - cf.gen_unique_str(prefix), schema=schema) - collection_w.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w = self.init_collection_wrap(cf.gen_unique_str(prefix), schema=schema) + collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() nb = 3000 @@ -2662,8 +2499,7 @@ def test_query_string_field_pk_is_empty(self): assert collection_w.num_entities == nb string_exp = "varchar >= \"\"" - output_fields = [default_int_field_name, - default_float_field_name, default_string_field_name] + output_fields = [default_int_field_name, default_float_field_name, default_string_field_name] res, _ = collection_w.query(string_exp, output_fields=output_fields) assert len(res) == 1 @@ -2678,8 +2514,7 @@ def test_query_string_field_not_primary_is_empty(self): expected: query successfully """ # 1. 
create a collection - collection_w, vectors = self.init_collection_general( - prefix, insert_data=False, is_index=False)[0:2] + collection_w, vectors = self.init_collection_general(prefix, insert_data=False, is_index=False)[0:2] nb = 3000 df = cf.gen_default_list_data(nb) @@ -2688,13 +2523,11 @@ def test_query_string_field_not_primary_is_empty(self): collection_w.insert(df) assert collection_w.num_entities == nb - collection_w.create_index( - ct.default_float_vec_field_name, default_index_params) + collection_w.create_index(ct.default_float_vec_field_name, default_index_params) assert collection_w.has_index()[0] collection_w.load() - output_fields = [default_int_field_name, - default_float_field_name, default_string_field_name] + output_fields = [default_int_field_name, default_float_field_name, default_string_field_name] expr = "varchar == \"\"" res, _ = collection_w.query(expr, output_fields=output_fields) @@ -2708,21 +2541,19 @@ def test_query_with_create_diskann_index(self): method: create a collection and build diskann index expected: verify query result """ - collection_w, vectors = self.init_collection_general( - prefix, insert_data=True, is_index=False)[0:2] + collection_w, vectors = self.init_collection_general(prefix, insert_data=True, is_index=False)[0:2] - collection_w.create_index( - ct.default_float_vec_field_name, ct.default_diskann_index) + collection_w.create_index(ct.default_float_vec_field_name, ct.default_diskann_index) assert collection_w.has_index()[0] collection_w.load() int_values = [0] term_expr = f'{ct.default_int64_field_name} in {int_values}' - check_vec = vectors[0].iloc[:, [0]][0:len( - int_values)].to_dict('records') - collection_w.query(term_expr, check_task=CheckTasks.check_query_results, check_items={ - exp_res: check_vec}) + check_vec = vectors[0].iloc[:, [0]][0:len(int_values)].to_dict('records') + collection_w.query(term_expr, + check_task=CheckTasks.check_query_results, + check_items={exp_res: check_vec}) 
@pytest.mark.tags(CaseLabel.L2) def test_query_with_create_diskann_with_string_pk(self): @@ -2734,8 +2565,7 @@ def test_query_with_create_diskann_with_string_pk(self): collection_w, vectors = self.init_collection_general(prefix, insert_data=True, primary_field=ct.default_string_field_name, is_index=False)[0:2] - collection_w.create_index( - ct.default_float_vec_field_name, ct.default_diskann_index) + collection_w.create_index(ct.default_float_vec_field_name, ct.default_diskann_index) assert collection_w.has_index()[0] collection_w.load() res = vectors[0].iloc[:, 1:3].to_dict('records') @@ -2753,8 +2583,7 @@ def test_query_with_scalar_field(self): expected: query successfully """ # 1. create a collection - collection_w, vectors = self.init_collection_general( - prefix, insert_data=False, is_index=False)[0:2] + collection_w, vectors = self.init_collection_general(prefix, insert_data=False, is_index=False)[0:2] nb = 3000 df = cf.gen_default_list_data(nb) @@ -2763,12 +2592,10 @@ def test_query_with_scalar_field(self): collection_w.insert(df) assert collection_w.num_entities == nb - collection_w.create_index( - ct.default_float_vec_field_name, default_index_params) + collection_w.create_index(ct.default_float_vec_field_name, default_index_params) assert collection_w.has_index()[0] index_params = {} - collection_w.create_index( - ct.default_int64_field_name, index_params=index_params) + collection_w.create_index(ct.default_int64_field_name, index_params=index_params) collection_w.load() @@ -2796,11 +2623,10 @@ def test_count_consistency_level(self, consistency_level): 4. 
verify count expected: expected count """ - collection_w = self.init_collection_wrap( - name=cf.gen_unique_str(prefix), consistency_level=consistency_level) + collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix), + consistency_level=consistency_level) # load collection - collection_w.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() df = cf.gen_default_dataframe_data() @@ -2825,11 +2651,9 @@ def test_count_invalid_output_field(self, invalid_output_field): method: expected: """ - collection_w = self.init_collection_wrap( - name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) # load collection - collection_w.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() # insert @@ -2848,12 +2672,11 @@ def test_count_without_loading(self): method: count without loading expected: exception """ - collection_w = self.init_collection_wrap( - name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) collection_w.query(expr=default_term_expr, output_fields=[ct.default_count_output], check_task=CheckTasks.err_res, - check_items={"err_code": 1, - "err_msg": f"has not been loaded to memory or load failed"}) + check_items={"err_code": 65535, + "err_msg": "collection not loaded"}) @pytest.mark.tags(CaseLabel.L1) def test_count_duplicate_ids(self): @@ -2866,10 +2689,8 @@ def test_count_duplicate_ids(self): expected: verify count """ # create - collection_w = self.init_collection_wrap( - name=cf.gen_unique_str(prefix)) - collection_w.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w = 
self.init_collection_wrap(name=cf.gen_unique_str(prefix)) + collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() # insert duplicate ids @@ -2903,8 +2724,7 @@ def test_count_multi_partitions(self): """ half = ct.default_nb // 2 # insert [0, half) into partition_w, [half, nb) into _default - collection_w, p1, _, _ = self.insert_entities_into_two_partitions_in_half( - half=half) + collection_w, p1, _, _ = self.insert_entities_into_two_partitions_in_half(half=half) # query count p1, [p1, _default] for p_name in [p1.name, ct.default_partition_name]: @@ -2918,14 +2738,11 @@ def test_count_multi_partitions(self): collection_w.query(expr=default_expr, output_fields=[ct.default_count_output], partition_names=[ct.default_partition_name], check_task=CheckTasks.check_query_results, - check_items={exp_res: [{count: 0}]} - ) + check_items={exp_res: [{count: 0}]}) collection_w.query(expr=default_expr, output_fields=[ct.default_count_output], - partition_names=[ - p1.name, ct.default_partition_name], + partition_names=[p1.name, ct.default_partition_name], check_task=CheckTasks.check_query_results, - check_items={exp_res: [{count: half}]} - ) + check_items={exp_res: [{count: half}]}) # drop p1 partition p1.release() @@ -2933,14 +2750,12 @@ def test_count_multi_partitions(self): collection_w.query(expr=default_expr, output_fields=[ct.default_count_output], partition_names=[p1.name], check_task=CheckTasks.err_res, - check_items={"err_code": 1, - "err_msg": f'partition name: {p1.name} not found'} - ) + check_items={"err_code": 65535, + "err_msg": f'partition name {p1.name} not found'}) collection_w.query(expr=default_expr, output_fields=[ct.default_count_output], partition_names=[ct.default_partition_name], check_task=CheckTasks.check_query_results, - check_items={exp_res: [{count: 0}]} - ) + check_items={exp_res: [{count: 0}]}) @pytest.mark.tags(CaseLabel.L2) def test_count_partition_duplicate(self): @@ -2953,8 +2768,7 @@ 
def test_count_partition_duplicate(self): """ # init partitions: _default and p1 p1 = "p1" - collection_w = self.init_collection_wrap( - name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) collection_w.create_partition(p1) df = cf.gen_default_dataframe_data() @@ -2962,8 +2776,7 @@ def test_count_partition_duplicate(self): collection_w.insert(df, partition_name=p1) # index and load - collection_w.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() # count @@ -2977,8 +2790,7 @@ def test_count_partition_duplicate(self): collection_w.query(expr=default_expr, output_fields=[ct.default_count_output], partition_names=[p1], check_task=CheckTasks.check_query_results, - check_items={ - exp_res: [{count: ct.default_nb - delete_res.delete_count}]} + check_items={exp_res: [{count: ct.default_nb - delete_res.delete_count}]} ) @pytest.mark.tags(CaseLabel.L1) @@ -2993,8 +2805,7 @@ def test_count_growing_sealed_segment(self): """ tmp_nb = 100 # create -> insert -> index -> load -> count sealed - collection_w = self.init_collection_general( - insert_data=True, nb=tmp_nb)[0] + collection_w = self.init_collection_general(insert_data=True, nb=tmp_nb)[0] collection_w.query(expr=default_expr, output_fields=[ct.default_count_output], check_task=CheckTasks.check_query_results, check_items={exp_res: [{count: tmp_nb}]} @@ -3017,10 +2828,8 @@ def test_count_during_handoff(self): expected: verify count """ # create -> index -> load - collection_w = self.init_collection_wrap( - name=cf.gen_unique_str(prefix)) - collection_w.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) + collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) 
collection_w.load() # flush while count @@ -3057,8 +2866,7 @@ def test_count_delete_insert_duplicate_ids(self): insert_res, _ = collection_w.insert(df) # delete growing and sealed ids -> count - collection_w.delete( - f"{ct.default_int64_field_name} in {[i for i in range(ct.default_nb)]}") + collection_w.delete(f"{ct.default_int64_field_name} in {[i for i in range(ct.default_nb)]}") collection_w.query(expr=default_expr, output_fields=[ct.default_count_output], check_task=CheckTasks.check_query_results, check_items={exp_res: [{count: tmp_nb}]} @@ -3069,8 +2877,7 @@ def test_count_delete_insert_duplicate_ids(self): collection_w.insert(df_same) collection_w.query(expr=default_expr, output_fields=[ct.default_count_output], check_task=CheckTasks.check_query_results, - check_items={ - exp_res: [{count: ct.default_nb + tmp_nb}]} + check_items={exp_res: [{count: ct.default_nb + tmp_nb}]} ) @pytest.mark.tags(CaseLabel.L1) @@ -3082,8 +2889,7 @@ def test_count_compact_merge(self): 3. count expected: verify count """ - collection_w = self.init_collection_wrap( - name=cf.gen_unique_str(prefix), shards_num=1) + collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix), shards_num=1) # init two segments tmp_nb = 100 @@ -3093,14 +2899,12 @@ def test_count_compact_merge(self): collection_w.insert(df) collection_w.flush() - collection_w.create_index( - ct.default_float_vec_field_name, ct.default_index) + collection_w.create_index(ct.default_float_vec_field_name, ct.default_index) collection_w.compact() collection_w.wait_for_compaction_completed() collection_w.load() - segment_info, _ = self.utility_wrap.get_query_segment_info( - collection_w.name) + segment_info, _ = self.utility_wrap.get_query_segment_info(collection_w.name) assert len(segment_info) == 1 # count after compact @@ -3118,10 +2922,8 @@ def test_count_compact_delete(self): expected: verify count """ # create -> index -> insert - collection_w = self.init_collection_wrap( - cf.gen_unique_str(prefix), 
shards_num=1) - collection_w.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w = self.init_collection_wrap(cf.gen_unique_str(prefix), shards_num=1) + collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) df = cf.gen_default_dataframe_data() insert_res, _ = collection_w.insert(df) @@ -3149,8 +2951,7 @@ def test_count_during_compact(self): 2. compact while count expected: verify count """ - collection_w = self.init_collection_wrap( - name=cf.gen_unique_str(prefix), shards_num=1) + collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix), shards_num=1) # init 2 segments tmp_nb = 100 @@ -3160,8 +2961,7 @@ def test_count_during_compact(self): collection_w.flush() # compact while count - collection_w.create_index( - ct.default_float_vec_field_name, ct.default_index) + collection_w.create_index(ct.default_float_vec_field_name, ct.default_index) collection_w.load() t_compact = threading.Thread(target=collection_w.compact, args=()) @@ -3170,7 +2970,7 @@ def test_count_during_compact(self): "output_fields": [ct.default_count_output], "check_task": CheckTasks.check_query_results, "check_items": {exp_res: [{count: tmp_nb * 10}]} - }) + }) t_compact.start() t_count.start() @@ -3196,6 +2996,34 @@ def test_count_with_expr(self): check_task=CheckTasks.check_query_results, check_items={exp_res: [{count: 2}]}) + @pytest.mark.tags(CaseLabel.L1) + def test_query_count_expr_json(self): + """ + target: test query with part json key value + method: 1. insert data and some entities doesn't have number key + 2. query count with number expr filet + expected: succeed + """ + # 1. initialize with data + collection_w = self.init_collection_general(prefix, enable_dynamic_field=True, with_json=True)[0] + + # 2. 
insert data + array = cf.gen_default_rows_data( with_json=False) + for i in range(ct.default_nb): + if i % 2 == 0: + array[i][json_field] = {"string": str(i), "bool": bool(i)} + else: + array[i][json_field] = {"string": str(i), "bool": bool(i), "number": i} + + collection_w.insert(array) + + # 3. query + collection_w.load() + expression = f'{ct.default_json_field_name}["number"] < 100' + collection_w.query(expression, output_fields=[ct.default_count_output], + check_task=CheckTasks.check_query_results, + check_items={exp_res: [{count: 50}]}) + @pytest.mark.tags(CaseLabel.L2) def test_count_with_pagination_param(self): """ @@ -3214,8 +3042,7 @@ def test_count_with_pagination_param(self): # count with limit collection_w.query(expr=default_expr, output_fields=[ct.default_count_output], limit=10, check_task=CheckTasks.err_res, - check_items={ - ct.err_code: 1, ct.err_msg: "count entities with pagination is not allowed"} + check_items={ct.err_code: 1, ct.err_msg: "count entities with pagination is not allowed"} ) # count with pagination params collection_w.query(default_expr, output_fields=[ct.default_count_output], offset=10, limit=10, @@ -3243,18 +3070,17 @@ def test_count_alias_insert_delete_drop(self): # new insert partitions and count p_name = cf.gen_unique_str("p_alias") collection_w_alias.create_partition(p_name) - collection_w_alias.insert(cf.gen_default_dataframe_data( - start=ct.default_nb), partition_name=p_name) + collection_w_alias.insert(cf.gen_default_dataframe_data(start=ct.default_nb), partition_name=p_name) collection_w_alias.query(expr=default_expr, output_fields=[ct.default_count_output], check_task=CheckTasks.check_query_results, check_items={exp_res: [{count: ct.default_nb * 2}]}) # release collection and alias drop partition collection_w_alias.drop_partition(p_name, check_task=CheckTasks.err_res, - check_items={ct.err_code: 1, - ct.err_msg: "cannot drop the collection via alias"}) - self.partition_wrap.init_partition( - 
collection_w_alias.collection, p_name) + check_items={ct.err_code: 65535, + ct.err_msg: "partition cannot be dropped, " + "partition is loaded, please release it first"}) + self.partition_wrap.init_partition(collection_w_alias.collection, p_name) self.partition_wrap.release() collection_w_alias.drop_partition(p_name) @@ -3265,8 +3091,7 @@ def test_count_alias_insert_delete_drop(self): check_items={exp_res: [{count: ct.default_nb}]}) # alias delete and count - collection_w_alias.delete( - f"{ct.default_int64_field_name} in {[i for i in range(ct.default_nb)]}") + collection_w_alias.delete(f"{ct.default_int64_field_name} in {[i for i in range(ct.default_nb)]}") collection_w_alias.query(expr=default_expr, output_fields=[ct.default_count_output], check_task=CheckTasks.check_query_results, check_items={exp_res: [{count: 0}]}) @@ -3290,8 +3115,7 @@ def test_count_upsert_growing_sealed(self, is_growing): if is_growing: # create -> index -> load -> insert -> delete collection_w = self.init_collection_wrap(cf.gen_unique_str(prefix)) - collection_w.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() collection_w.insert(cf.gen_default_dataframe_data()) @@ -3307,8 +3131,7 @@ def test_count_upsert_growing_sealed(self, is_growing): single_expr = f'{ct.default_int64_field_name} in [0]' collection_w.delete(single_expr) - collection_w.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() # upsert deleted id @@ -3344,14 +3167,12 @@ def test_count_upsert_duplicate(self): """ # init collection and insert same ids tmp_nb = 100 - collection_w = self.init_collection_wrap( - name=cf.gen_unique_str(prefix)) + collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) df = 
cf.gen_default_dataframe_data(nb=tmp_nb) df[ct.default_int64_field_name] = 0 collection_w.insert(df) - collection_w.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() # upsert id and count @@ -3369,8 +3190,7 @@ def test_count_upsert_duplicate(self): check_items={exp_res: [{count: tmp_nb - delete_res.delete_count}]}) # upsert deleted id and count - df_deleted = cf.gen_default_dataframe_data( - nb=delete_res.delete_count, start=0) + df_deleted = cf.gen_default_dataframe_data(nb=delete_res.delete_count, start=0) collection_w.upsert(df_deleted) collection_w.query(expr=default_expr, output_fields=[ct.default_count_output], check_task=CheckTasks.check_query_results, @@ -3403,10 +3223,8 @@ def test_count_disable_growing_segments(self): expected: verify count 0 """ # create -> index -> load - collection_w = self.init_collection_wrap( - name=cf.gen_unique_str(prefix)) - collection_w.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w = self.init_collection_wrap(name=cf.gen_unique_str(prefix)) + collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() # insert @@ -3424,8 +3242,7 @@ def test_count_expressions(self, expression): expected: verify count """ # create -> insert -> index -> load - collection_w, _vectors, _, insert_ids = self.init_collection_general( - insert_data=True)[0:4] + collection_w, _vectors, _, insert_ids = self.init_collection_general(insert_data=True)[0:4] # filter result with expression in collection _vectors = _vectors[0] @@ -3453,8 +3270,7 @@ def test_count_bool_expressions(self, bool_type): """ # create -> insert -> index -> load collection_w, _vectors, _, insert_ids = \ - self.init_collection_general( - insert_data=True, is_all_data_type=True)[0:4] + self.init_collection_general(insert_data=True, 
is_all_data_type=True)[0:4] # filter result with expression in collection filter_ids = [] @@ -3483,8 +3299,7 @@ def test_count_expression_auto_field(self, expression): expected: verify count """ # create -> insert -> index -> load - collection_w, _vectors, _, insert_ids = self.init_collection_general( - insert_data=True)[0:4] + collection_w, _vectors, _, insert_ids = self.init_collection_general(insert_data=True)[0:4] # filter result with expression in collection _vectors = _vectors[0] @@ -3510,8 +3325,7 @@ def test_count_expression_all_datatype(self): expected: verify count """ # create -> insert -> index -> load - collection_w = self.init_collection_general( - insert_data=True, is_all_data_type=True)[0] + collection_w = self.init_collection_general(insert_data=True, is_all_data_type=True)[0] # count with expr expression = "int64 >= 0 && int32 >= 1999 && int16 >= 0 && int8 >= 0 && float <= 1999.0 && double >= 0" @@ -3530,14 +3344,12 @@ def test_count_expression_comparative(self): # create -> insert -> index -> load fields = [cf.gen_int64_field("int64_1"), cf.gen_int64_field("int64_2"), cf.gen_float_vec_field()] - schema = cf.gen_collection_schema( - fields=fields, primary_field="int64_1") + schema = cf.gen_collection_schema(fields=fields, primary_field="int64_1") collection_w = self.init_collection_wrap(schema=schema) nb, res = 10, 0 int_values = [random.randint(0, nb) for _ in range(nb)] - data = [[i for i in range(nb)], int_values, - cf.gen_vectors(nb, ct.default_dim)] + data = [[i for i in range(nb)], int_values, cf.gen_vectors(nb, ct.default_dim)] collection_w.insert(data) collection_w.create_index(ct.default_float_vec_field_name) collection_w.load() @@ -3569,10 +3381,8 @@ def test_query_iterator_normal(self): """ # 1. 
initialize with data batch_size = 100 - collection_w = self.init_collection_general( - prefix, True, is_index=False)[0] - collection_w.create_index( - ct.default_float_vec_field_name, {"metric_type": "L2"}) + collection_w = self.init_collection_general(prefix, True, is_index=False)[0] + collection_w.create_index(ct.default_float_vec_field_name, {"metric_type": "L2"}) collection_w.load() # 2. search iterator expr = "int64 >= 0" @@ -3607,10 +3417,8 @@ def test_query_iterator_with_offset(self, offset): """ # 1. initialize with data batch_size = 300 - collection_w = self.init_collection_general( - prefix, True, is_index=False)[0] - collection_w.create_index( - ct.default_float_vec_field_name, {"metric_type": "L2"}) + collection_w = self.init_collection_general(prefix, True, is_index=False)[0] + collection_w.create_index(ct.default_float_vec_field_name, {"metric_type": "L2"}) collection_w.load() # 2. search iterator expr = "int64 >= 0" @@ -3630,10 +3438,8 @@ def test_query_iterator_with_different_batch_size(self, batch_size): """ # 1. initialize with data offset = 500 - collection_w = self.init_collection_general( - prefix, True, is_index=False)[0] - collection_w.create_index( - ct.default_float_vec_field_name, {"metric_type": "L2"}) + collection_w = self.init_collection_general(prefix, True, is_index=False)[0] + collection_w.create_index(ct.default_float_vec_field_name, {"metric_type": "L2"}) collection_w.load() # 2. search iterator expr = "int64 >= 0" @@ -3643,7 +3449,6 @@ def test_query_iterator_with_different_batch_size(self, batch_size): "batch_size": batch_size}) @pytest.mark.tags(CaseLabel.L2) - @pytest.mark.skip("issue #26767") @pytest.mark.parametrize("offset", [0, 10, 100, 1000]) @pytest.mark.parametrize("limit", [0, 100, 1500, 2000, 10000]) def test_query_iterator_with_different_limit(self, limit, offset): @@ -3656,7 +3461,7 @@ def test_query_iterator_with_different_limit(self, limit, offset): # 1. 
initialize with data collection_w = self.init_collection_general(prefix, True)[0] # 2. query iterator - Count = limit - offset if limit <= ct.default_nb else ct.default_nb - offset + Count = limit if limit + offset <= ct.default_nb else ct.default_nb - offset collection_w.query_iterator(limit=limit, expr="", offset=offset, check_task=CheckTasks.check_query_iterator, check_items={"count": max(Count, 0), @@ -3675,8 +3480,7 @@ def test_query_iterator_invalid_batch_size(self): # 2. search iterator expr = "int64 >= 0" error = {"err_code": 1, "err_msg": "batch size cannot be less than zero"} - collection_w.query_iterator( - batch_size=-1, expr=expr, check_task=CheckTasks.err_res, check_items=error) + collection_w.query_iterator(batch_size=-1, expr=expr, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L0) @pytest.mark.parametrize("batch_size", [100, 500]) @@ -3688,8 +3492,7 @@ def test_query_iterator_empty_expr(self, auto_id, batch_size): expected: return topK results by order """ # 1. initialize with data - collection_w, _, _, insert_ids = self.init_collection_general( - prefix, True, auto_id=auto_id)[0:4] + collection_w, _, _, insert_ids = self.init_collection_general(prefix, True, auto_id=auto_id)[0:4] # 2. query with limit collection_w.query_iterator(batch_size=batch_size, @@ -3708,16 +3511,15 @@ def test_query_iterator_expr_empty_with_random_pk_pagination(self, batch_size, o expected: return topK results by order """ # 1. initialize with data - collection_w, _, _, insert_ids = self.init_collection_general( - prefix, True, random_primary_key=True)[0:4] + collection_w, _, _, insert_ids = self.init_collection_general(prefix, True, random_primary_key=True)[0:4] - # 3. query with empty expr and check the result + # 2. 
query with empty expr and check the result exp_ids = sorted(insert_ids) collection_w.query_iterator(batch_size, output_fields=[ct.default_string_field_name], check_task=CheckTasks.check_query_iterator, check_items={"batch_size": batch_size, "count": ct.default_nb, "exp_ids": exp_ids}) - # 4. query with pagination + # 3. query with pagination exp_ids = sorted(insert_ids)[offset:] collection_w.query_iterator(batch_size, offset=offset, output_fields=[ct.default_string_field_name], check_task=CheckTasks.check_query_iterator, diff --git a/tests/python_client/testcases/test_search.py b/tests/python_client/testcases/test_search.py index 91ce4c2e3a457..a676e552310cf 100644 --- a/tests/python_client/testcases/test_search.py +++ b/tests/python_client/testcases/test_search.py @@ -45,12 +45,10 @@ default_bool_field_name = ct.default_bool_field_name default_string_field_name = ct.default_string_field_name default_json_field_name = ct.default_json_field_name -default_index_params = {"index_type": "IVF_SQ8", - "metric_type": "COSINE", "params": {"nlist": 64}} -vectors = [[random.random() for _ in range(default_dim)] - for _ in range(default_nq)] -range_search_supported_index = ct.all_index_types[:6] -range_search_supported_index_params = ct.default_index_params[:6] +default_index_params = {"index_type": "IVF_SQ8", "metric_type": "COSINE", "params": {"nlist": 64}} +vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)] +range_search_supported_index = ct.all_index_types[:7] +range_search_supported_index_params = ct.default_index_params[:7] uid = "test_search" nq = 1 epsilon = 0.001 @@ -60,8 +58,7 @@ entity = gen_entities(1, is_normal=True) entities = gen_entities(default_nb, is_normal=True) raw_vectors, binary_entities = gen_binary_entities(default_nb) -default_query, _ = gen_search_vectors_params( - field_name, entities, default_top_k, nq) +default_query, _ = gen_search_vectors_params(field_name, entities, default_top_k, nq) index_name1 = 
cf.gen_unique_str("float") index_name2 = cf.gen_unique_str("varhar") half_nb = ct.default_nb // 2 @@ -250,14 +247,12 @@ def test_search_param_invalid_dim(self): # 2. search with invalid dim log.info("test_search_param_invalid_dim: searching with invalid dim") wrong_dim = 129 - vectors = [[random.random() for _ in range(wrong_dim)] - for _ in range(default_nq)] + vectors = [[random.random() for _ in range(wrong_dim)] for _ in range(default_nq)] collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, default_search_exp, check_task=CheckTasks.err_res, - check_items={"err_code": 1, - "err_msg": "The dimension of query entities " - "is different from schema"}) + check_items={"err_code": 65538, + "err_msg": 'failed to search'}) @pytest.mark.tags(CaseLabel.L2) def test_search_param_invalid_field_type(self, get_invalid_fields_type): @@ -270,13 +265,14 @@ def test_search_param_invalid_field_type(self, get_invalid_fields_type): collection_w = self.init_collection_general(prefix)[0] # 2. 
search with invalid field invalid_search_field = get_invalid_fields_type - log.info("test_search_param_invalid_field_type: searching with " - "invalid field: %s" % invalid_search_field) + log.info("test_search_param_invalid_field_type: searching with invalid field: %s" + % invalid_search_field) + error1 = {"err_code": 65535, "err_msg": "collection not loaded"} + error2 = {"err_code": 1, "err_msg": f"`anns_field` value {get_invalid_fields_type} is illegal"} + error = error2 if get_invalid_fields_type in [[], 1, [1, "2", 3], (1,), {1: 1}] else error1 collection_w.search(vectors[:default_nq], invalid_search_field, default_search_params, default_limit, default_search_exp, - check_task=CheckTasks.err_res, - check_items={"err_code": 1, - "err_msg": "`anns_field` value {} is illegal".format(invalid_search_field)}) + check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_search_param_invalid_field_value(self, get_invalid_fields_value): @@ -294,9 +290,9 @@ def test_search_param_invalid_field_value(self, get_invalid_fields_value): collection_w.search(vectors[:default_nq], invalid_search_field, default_search_params, default_limit, default_search_exp, check_task=CheckTasks.err_res, - check_items={"err_code": 1, - "err_msg": "Field %s doesn't exist in schema" - % invalid_search_field}) + check_items={"err_code": 65535, + "err_msg": "failed to create query plan: failed to get field schema " + "by name: %s not found" % invalid_search_field}) @pytest.mark.tags(CaseLabel.L1) def test_search_param_invalid_metric_type(self, get_invalid_metric_type): @@ -308,21 +304,19 @@ def test_search_param_invalid_metric_type(self, get_invalid_metric_type): # 1. initialize with data collection_w = self.init_collection_general(prefix)[0] # 2. 
search with invalid metric_type - log.info( - "test_search_param_invalid_metric_type: searching with invalid metric_type") + log.info("test_search_param_invalid_metric_type: searching with invalid metric_type") invalid_metric = get_invalid_metric_type - search_params = {"metric_type": invalid_metric, - "params": {"nprobe": 10}} + search_params = {"metric_type": invalid_metric, "params": {"nprobe": 10}} collection_w.search(vectors[:default_nq], default_search_field, search_params, default_limit, default_search_exp, check_task=CheckTasks.err_res, - check_items={"err_code": 1, - "err_msg": "metric type not found"}) + check_items={"err_code": 65535, + "err_msg": "collection not loaded"}) @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("index, params", - zip(ct.all_index_types[:6], - ct.default_index_params[:6])) + zip(ct.all_index_types[:7], + ct.default_index_params[:7])) def test_search_invalid_params_type(self, index, params): """ target: test search with invalid search params @@ -335,13 +329,11 @@ def test_search_invalid_params_type(self, index, params): collection_w, _, _, insert_ids = self.init_collection_general(prefix, True, 5000, is_index=False)[0:4] # 2. create index and load - default_index = {"index_type": index, - "params": params, "metric_type": "L2"} + default_index = {"index_type": index, "params": params, "metric_type": "L2"} collection_w.create_index("float_vector", default_index) collection_w.load() # 3. 
search invalid_search_params = cf.gen_invalid_search_params_type() - message = "Search params check failed" for invalid_search_param in invalid_search_params: if index == invalid_search_param["index_type"]: search_params = {"metric_type": "L2", @@ -350,8 +342,8 @@ def test_search_invalid_params_type(self, index, params): search_params, default_limit, default_search_exp, check_task=CheckTasks.err_res, - check_items={"err_code": 1, - "err_msg": message}) + check_items={"err_code": 65538, + "err_msg": "failed to search"}) @pytest.mark.skip("not fixed yet") @pytest.mark.tags(CaseLabel.L1) @@ -412,13 +404,13 @@ def test_search_param_invalid_limit_value(self, limit): # 2. search with invalid limit (topK) log.info("test_search_param_invalid_limit_value: searching with " "invalid limit (topK) = %s" % limit) - err_msg = "limit %d is too large!" % limit + err_msg = f"topk [{limit}] is invalid, top k should be in range [1, 16384], but got {limit}" if limit == 0: err_msg = "`limit` value 0 is illegal" collection_w.search(vectors[:default_nq], default_search_field, default_search_params, limit, default_search_exp, check_task=CheckTasks.err_res, - check_items={"err_code": 1, + check_items={"err_code": 65535, "err_msg": err_msg}) @pytest.mark.tags(CaseLabel.L2) @@ -455,8 +447,7 @@ def test_search_with_expression_join_two_fields(self, expression): dim = 1 fields = [cf.gen_int64_field("int64_1"), cf.gen_int64_field("int64_2"), cf.gen_float_vec_field(dim=dim)] - schema = cf.gen_collection_schema( - fields=fields, primary_field="int64_1") + schema = cf.gen_collection_schema(fields=fields, primary_field="int64_1") collection_w = self.init_collection_wrap(schema=schema) # 2. insert data @@ -466,14 +457,11 @@ def test_search_with_expression_join_two_fields(self, expression): collection_w.insert(dataframe) # 3. 
search with expression - log.info( - "test_search_with_expression: searching with expression: %s" % expression) - collection_w.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + log.info("test_search_with_expression: searching with expression: %s" % expression) + collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() expression = expression.replace("&&", "and").replace("||", "or") - vectors = [[random.random() for _ in range(dim)] - for _ in range(default_nq)] + vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] collection_w.search(vectors[:default_nq], default_search_field, default_search_params, nb, expression, check_task=CheckTasks.err_res, @@ -497,8 +485,8 @@ def test_search_param_invalid_expr_value(self, get_invalid_expr_value): collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, invalid_search_expr, check_task=CheckTasks.err_res, - check_items={"err_code": 1, - "err_msg": "invalid expression %s" + check_items={"err_code": 65535, + "err_msg": "failed to create query plan: cannot parse expression: %s" % invalid_search_expr}) @pytest.mark.tags(CaseLabel.L2) @@ -529,16 +517,14 @@ def test_search_with_expression_invalid_bool(self, expression): method: test search invalid bool expected: searched failed """ - collection_w = self.init_collection_general( - prefix, True, is_all_data_type=True)[0] - log.info( - "test_search_with_expression: searching with expression: %s" % expression) + collection_w = self.init_collection_general(prefix, True, is_all_data_type=True)[0] + log.info("test_search_with_expression: searching with expression: %s" % expression) collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, expression, check_task=CheckTasks.err_res, - check_items={"err_code": 1, - "err_msg": "failed to create query plan: cannot parse " - "expression: %s" 
% expression}) + check_items={"err_code": 65535, + "err_msg": "failed to create query plan: predicate is not a " + "boolean expression: %s, data type: Bool" % expression}) @pytest.mark.tags(CaseLabel.L1) @pytest.mark.parametrize("expression", ["int64 like 33", "float LIKE 33"]) @@ -622,9 +608,8 @@ def test_search_release_collection(self): collection_w.search(vectors, default_search_field, default_search_params, default_limit, default_search_exp, check_task=CheckTasks.err_res, - check_items={"err_code": 1, - "err_msg": "collection %s was not loaded " - "into memory" % collection_w.name}) + check_items={"err_code": 65535, + "err_msg": "collection not loaded"}) @pytest.mark.tags(CaseLabel.L2) def test_search_release_partition(self): @@ -637,10 +622,8 @@ def test_search_release_partition(self): """ # 1. initialize with data partition_num = 1 - collection_w = self.init_collection_general( - prefix, True, 10, partition_num, is_index=False)[0] - collection_w.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w = self.init_collection_general(prefix, True, 10, partition_num, is_index=False)[0] + collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) par = collection_w.partitions par_name = par[partition_num].name par[partition_num].load() @@ -653,8 +636,8 @@ def test_search_release_partition(self): default_search_params, limit, default_search_exp, [par_name], check_task=CheckTasks.err_res, - check_items={"err_code": 1, - "err_msg": "partition has been released"}) + check_items={"err_code": 65535, + "err_msg": "collection not loaded"}) @pytest.mark.skip("enable this later using session/strong consistency") @pytest.mark.tags(CaseLabel.L1) @@ -744,16 +727,14 @@ def test_search_partition_deleted(self): """ # 1. 
initialize with data partition_num = 1 - collection_w = self.init_collection_general( - prefix, True, 1000, partition_num, is_index=False)[0] + collection_w = self.init_collection_general(prefix, True, 1000, partition_num, is_index=False)[0] # 2. delete partitions log.info("test_search_partition_deleted: deleting a partition") par = collection_w.partitions deleted_par_name = par[partition_num].name collection_w.drop_partition(deleted_par_name) log.info("test_search_partition_deleted: deleted a partition") - collection_w.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() # 3. search after delete partitions log.info("test_search_partition_deleted: searching deleted partition") @@ -761,13 +742,13 @@ def test_search_partition_deleted(self): default_search_params, default_limit, default_search_exp, [deleted_par_name], check_task=CheckTasks.err_res, - check_items={"err_code": 1, - "err_msg": "PartitonName: %s not found" % deleted_par_name}) + check_items={"err_code": 65535, + "err_msg": "partition name search_partition_0 not found"}) - @pytest.mark.tags(CaseLabel.L1) + @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("index, params", - zip(ct.all_index_types[1:5], - ct.default_index_params[1:5])) + zip(ct.all_index_types[1:7], + ct.default_index_params[1:7])) def test_search_different_index_invalid_params(self, index, params): """ target: test search with different index @@ -782,24 +763,20 @@ def test_search_different_index_invalid_params(self, index, params): if params.get("m"): if (default_dim % params["m"]) != 0: params["m"] = default_dim // 4 - log.info( - "test_search_different_index_invalid_params: Creating index-%s" % index) - default_index = {"index_type": index, - "params": params, "metric_type": "L2"} + log.info("test_search_different_index_invalid_params: Creating index-%s" % index) + default_index = 
{"index_type": index, "params": params, "metric_type": "L2"} collection_w.create_index("float_vector", default_index) - log.info( - "test_search_different_index_invalid_params: Created index-%s" % index) + log.info("test_search_different_index_invalid_params: Created index-%s" % index) collection_w.load() # 3. search - log.info( - "test_search_different_index_invalid_params: Searching after creating index-%s" % index) + log.info("test_search_different_index_invalid_params: Searching after " + "creating index-%s" % index) search_params = cf.gen_invalid_search_param(index) collection_w.search(vectors, default_search_field, search_params[0], default_limit, default_search_exp, check_task=CheckTasks.err_res, - check_items={"err_code": 1, - "err_msg": "Search params check failed"}) + check_items={"err_code": 65538, "err_msg": "failed to search"}) @pytest.mark.tags(CaseLabel.L2) def test_search_index_partition_not_existed(self): @@ -811,17 +788,42 @@ def test_search_index_partition_not_existed(self): # 1. initialize with data collection_w = self.init_collection_general(prefix, is_index=False)[0] # 2. create index - default_index = {"index_type": "IVF_FLAT", - "params": {"nlist": 128}, "metric_type": "L2"} + default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"} collection_w.create_index("float_vector", default_index) # 3. 
search the non exist partition partition_name = "search_non_exist" collection_w.search(vectors[:default_nq], default_search_field, default_search_params, - default_limit, default_search_exp, [ - partition_name], + default_limit, default_search_exp, [partition_name], check_task=CheckTasks.err_res, - check_items={"err_code": 1, - "err_msg": "PartitonName: %s not found" % partition_name}) + check_items={"err_code": 65535, + "err_msg": "partition name %s not found" % partition_name}) + + @pytest.mark.tags(CaseLabel.L2) + @pytest.mark.parametrize("reorder_k", [100]) + def test_search_scann_with_invalid_reorder_k(self, reorder_k): + """ + target: test search with invalid nq + method: search with invalid nq + expected: raise exception and report the error + """ + # initialize with data + collection_w = self.init_collection_general(prefix, True, is_index=False)[0] + index_params = {"index_type": "SCANN", "metric_type": "L2", "params": {"nlist": 1024}} + collection_w.create_index(default_search_field, index_params) + # search + search_params = {"metric_type": "L2", "params": {"nprobe": 10, "reorder_k": reorder_k}} + collection_w.load() + collection_w.search(vectors[:default_nq], default_search_field, + search_params, reorder_k + 1, + check_task=CheckTasks.err_res, + check_items={"err_code": 65538, + "err_msg": "failed to search: attempt #0: failed to search/query " + "delegator 1 for channel by-dev-rootcoord-dml_12_44501" + "8735380972010v0: fail to Search, QueryNode ID=1, reaso" + "n=worker(1) query failed: UnknownError: => failed to " + "search: out of range in json: reorder_k(100) should be" + " larger than k(101): attempt #1: no available shard de" + "legator found: service unavailable"}) @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("nq", [16385]) @@ -876,17 +878,15 @@ def test_search_binary_flat_with_L2(self): expected: raise exception and report error """ # 1. 
initialize with binary data - collection_w = self.init_collection_general( - prefix, True, is_binary=True)[0] + collection_w = self.init_collection_general(prefix, True, is_binary=True)[0] # 2. search and assert - query_raw_vector, binary_vectors = cf.gen_binary_vectors( - 2, default_dim) + query_raw_vector, binary_vectors = cf.gen_binary_vectors(2, default_dim) search_params = {"metric_type": "L2", "params": {"nprobe": 10}} collection_w.search(binary_vectors[:default_nq], "binary_vector", search_params, default_limit, "int64 >= 0", check_task=CheckTasks.err_res, - check_items={"err_code": 1, - "err_msg": "Data type and metric type mis-match"}) + check_items={"err_code": 65538, "err_msg": "metric type not match: " + "expected=JACCARD, actual=L2"}) @pytest.mark.tags(CaseLabel.L2) def test_search_with_output_fields_not_exist(self): @@ -904,8 +904,8 @@ def test_search_with_output_fields_not_exist(self): default_search_params, default_limit, default_search_exp, output_fields=["int63"], check_task=CheckTasks.err_res, - check_items={ct.err_code: 1, - ct.err_msg: "Field int63 not exist"}) + check_items={ct.err_code: 65535, + ct.err_msg: "field int63 not exist"}) @pytest.mark.tags(CaseLabel.L1) @pytest.mark.skip(reason="Now support output vector field") @@ -937,12 +937,10 @@ def test_search_output_field_vector_after_gpu_index(self, index, param): expected: raise exception and report the error """ # 1. create a collection and insert data - collection_w = self.init_collection_general( - prefix, True, is_index=False)[0] + collection_w = self.init_collection_general(prefix, True, is_index=False)[0] # 2. create an index which doesn't output vectors - default_index = {"index_type": index, - "params": param, "metric_type": "L2"} + default_index = {"index_type": index, "params": param, "metric_type": "L2"} collection_w.create_index(field_name, default_index) # 3. 
load and search @@ -971,8 +969,8 @@ def test_search_output_field_invalid_wildcard(self, output_fields): default_search_params, default_limit, default_search_exp, output_fields=output_fields, check_task=CheckTasks.err_res, - check_items={"err_code": 1, - "err_msg": f"Field {output_fields[-1]} not exist"}) + check_items={"err_code": 65535, + "err_msg": f"field {output_fields[-1]} not exist"}) @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("ignore_growing", ct.get_invalid_strs[2:8]) @@ -1058,14 +1056,14 @@ def test_range_search_invalid_radius(self, get_invalid_range_search_paras): log.info("test_range_search_invalid_radius: Range searching collection %s" % collection_w.name) radius = get_invalid_range_search_paras - range_search_params = {"metric_type": "L2", "params": { - "nprobe": 10, "radius": radius, "range_filter": 0}} + range_search_params = {"metric_type": "L2", + "params": {"nprobe": 10, "radius": radius, "range_filter": 0}} collection_w.search(vectors[:default_nq], default_search_field, range_search_params, default_limit, default_search_exp, check_task=CheckTasks.err_res, - check_items={"err_code": 1, - "err_msg": f"type must be number"}) + check_items={"err_code": 65535, + "err_msg": "collection not loaded"}) @pytest.mark.tags(CaseLabel.L2) def test_range_search_invalid_range_filter(self, get_invalid_range_search_paras): @@ -1080,14 +1078,14 @@ def test_range_search_invalid_range_filter(self, get_invalid_range_search_paras) log.info("test_range_search_invalid_range_filter: Range searching collection %s" % collection_w.name) range_filter = get_invalid_range_search_paras - range_search_params = {"metric_type": "L2", "params": { - "nprobe": 10, "radius": 1, "range_filter": range_filter}} + range_search_params = {"metric_type": "L2", + "params": {"nprobe": 10, "radius": 1, "range_filter": range_filter}} collection_w.search(vectors[:default_nq], default_search_field, range_search_params, default_limit, default_search_exp, check_task=CheckTasks.err_res, 
- check_items={"err_code": 1, - "err_msg": f"type must be number"}) + check_items={"err_code": 65535, + "err_msg": "collection not loaded"}) @pytest.mark.tags(CaseLabel.L1) def test_range_search_invalid_radius_range_filter_L2(self): @@ -1101,14 +1099,13 @@ def test_range_search_invalid_radius_range_filter_L2(self): # 2. range search log.info("test_range_search_invalid_radius_range_filter_L2: Range searching collection %s" % collection_w.name) - range_search_params = {"metric_type": "L2", "params": { - "nprobe": 10, "radius": 1, "range_filter": 10}} + range_search_params = {"metric_type": "L2", "params": {"nprobe": 10, "radius": 1, "range_filter": 10}} collection_w.search(vectors[:default_nq], default_search_field, range_search_params, default_limit, default_search_exp, check_task=CheckTasks.err_res, - check_items={"err_code": 1, - "err_msg": f"range_filter must less than radius except IP"}) + check_items={"err_code": 65535, + "err_msg": "collection not loaded"}) @pytest.mark.tags(CaseLabel.L1) def test_range_search_invalid_radius_range_filter_IP(self): @@ -1122,14 +1119,14 @@ def test_range_search_invalid_radius_range_filter_IP(self): # 2. 
range search log.info("test_range_search_invalid_radius_range_filter_IP: Range searching collection %s" % collection_w.name) - range_search_params = {"metric_type": "IP", "params": { - "nprobe": 10, "radius": 10, "range_filter": 1}} + range_search_params = {"metric_type": "IP", + "params": {"nprobe": 10, "radius": 10, "range_filter": 1}} collection_w.search(vectors[:default_nq], default_search_field, range_search_params, default_limit, default_search_exp, check_task=CheckTasks.err_res, - check_items={"err_code": 1, - "err_msg": f"range_filter must more than radius when IP"}) + check_items={"err_code": 65535, + "err_msg": "collection not loaded"}) @pytest.mark.tags(CaseLabel.L2) @pytest.mark.skip(reason="annoy not supported any more") @@ -1242,26 +1239,23 @@ def test_search_dynamic_compare_two_fields(self): enable_dynamic_field=True)[0] # create index - index_params_one = {"index_type": "IVF_SQ8", - "metric_type": "COSINE", "params": {"nlist": 64}} + index_params_one = {"index_type": "IVF_SQ8", "metric_type": "COSINE", "params": {"nlist": 64}} collection_w.create_index( ct.default_float_vec_field_name, index_params_one, index_name=index_name1) index_params_two = {} - collection_w.create_index( - ct.default_string_field_name, index_params=index_params_two, index_name=index_name2) + collection_w.create_index(ct.default_string_field_name, index_params=index_params_two, index_name=index_name2) assert collection_w.has_index(index_name=index_name2) collection_w.load() # delete entity expr = 'float >= int64' # search with id 0 vectors - vectors = [[random.random() for _ in range(default_dim)] - for _ in range(default_nq)] + vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)] collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, expr, check_task=CheckTasks.err_res, - check_items={"err_code": 1, - "err_msg": f"unsupported left datatype of compare expr"}) + check_items={"err_code": 65538, + 
"err_msg": "UnknownError: unsupported right datatype JSON of compare expr"}) class TestCollectionSearch(TestcaseBase): @@ -2086,8 +2080,8 @@ def test_search_HNSW_index_with_min_ef(self, M, efConstruction, limit, auto_id, @pytest.mark.tags(CaseLabel.L2) @pytest.mark.tags(CaseLabel.GPU) @pytest.mark.parametrize("index, params", - zip(ct.all_index_types[:6], - ct.default_index_params[:6])) + zip(ct.all_index_types[:7], + ct.default_index_params[:7])) def test_search_after_different_index_with_params(self, dim, index, params, auto_id, _async, enable_dynamic_field): """ target: test search after different index @@ -2107,14 +2101,12 @@ def test_search_after_different_index_with_params(self, dim, index, params, auto if params.get("PQM"): if (dim % params["PQM"]) != 0: params["PQM"] = dim // 4 - default_index = {"index_type": index, - "params": params, "metric_type": "COSINE"} + default_index = {"index_type": index, "params": params, "metric_type": "COSINE"} collection_w.create_index("float_vector", default_index) collection_w.load() # 3. 
search search_params = cf.gen_search_param(index, "COSINE") - vectors = [[random.random() for _ in range(dim)] - for _ in range(default_nq)] + vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] for search_param in search_params: log.info("Searching with search params: {}".format(search_param)) limit = default_limit @@ -2135,8 +2127,8 @@ def test_search_after_different_index_with_params(self, dim, index, params, auto @pytest.mark.tags(CaseLabel.GPU) @pytest.mark.parametrize("index, params", - zip(ct.all_index_types[8:10], - ct.default_index_params[8:10])) + zip(ct.all_index_types[9:11], + ct.default_index_params[9:11])) def test_search_after_different_index_with_params_gpu(self, dim, index, params, auto_id, _async, enable_dynamic_field): """ @@ -2157,14 +2149,12 @@ def test_search_after_different_index_with_params_gpu(self, dim, index, params, if params.get("PQM"): if (dim % params["PQM"]) != 0: params["PQM"] = dim // 4 - default_index = {"index_type": index, - "params": params, "metric_type": "L2"} + default_index = {"index_type": index, "params": params, "metric_type": "L2"} collection_w.create_index("float_vector", default_index) collection_w.load() # 3. 
search search_params = cf.gen_search_param(index) - vectors = [[random.random() for _ in range(dim)] - for _ in range(default_nq)] + vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] for search_param in search_params: log.info("Searching with search params: {}".format(search_param)) collection_w.search(vectors[:default_nq], default_search_field, @@ -2203,9 +2193,10 @@ def test_search_default_search_params_fit_for_autoindex(self, search_params, aut @pytest.mark.tags(CaseLabel.L2) @pytest.mark.tags(CaseLabel.GPU) + @pytest.mark.skip("issue #27252") @pytest.mark.parametrize("index, params", - zip(ct.all_index_types[:5], - ct.default_index_params[:5])) + zip(ct.all_index_types[:7], + ct.default_index_params[:7])) def test_search_after_different_index_with_min_dim(self, index, params, auto_id, _async): """ target: test search after different index with min dim @@ -2222,14 +2213,12 @@ def test_search_after_different_index_with_min_dim(self, index, params, auto_id, params["m"] = min_dim if params.get("PQM"): params["PQM"] = min_dim - default_index = {"index_type": index, - "params": params, "metric_type": "L2"} + default_index = {"index_type": index, "params": params, "metric_type": "L2"} collection_w.create_index("float_vector", default_index) collection_w.load() # 3. 
search search_params = cf.gen_search_param(index) - vectors = [[random.random() for _ in range(min_dim)] - for _ in range(default_nq)] + vectors = [[random.random() for _ in range(min_dim)] for _ in range(default_nq)] for search_param in search_params: log.info("Searching with search params: {}".format(search_param)) collection_w.search(vectors[:default_nq], default_search_field, @@ -2243,8 +2232,8 @@ def test_search_after_different_index_with_min_dim(self, index, params, auto_id, @pytest.mark.tags(CaseLabel.GPU) @pytest.mark.parametrize("index, params", - zip(ct.all_index_types[8:10], - ct.default_index_params[8:10])) + zip(ct.all_index_types[9:11], + ct.default_index_params[9:11])) def test_search_after_different_index_with_min_dim_gpu(self, index, params, auto_id, _async): """ target: test search after different index with min dim @@ -2261,14 +2250,12 @@ def test_search_after_different_index_with_min_dim_gpu(self, index, params, auto params["m"] = min_dim if params.get("PQM"): params["PQM"] = min_dim - default_index = {"index_type": index, - "params": params, "metric_type": "L2"} + default_index = {"index_type": index, "params": params, "metric_type": "L2"} collection_w.create_index("float_vector", default_index) collection_w.load() # 3. 
search search_params = cf.gen_search_param(index) - vectors = [[random.random() for _ in range(min_dim)] - for _ in range(default_nq)] + vectors = [[random.random() for _ in range(min_dim)] for _ in range(default_nq)] for search_param in search_params: log.info("Searching with search params: {}".format(search_param)) collection_w.search(vectors[:default_nq], default_search_field, @@ -2283,8 +2270,8 @@ def test_search_after_different_index_with_min_dim_gpu(self, index, params, auto @pytest.mark.tags(CaseLabel.L2) @pytest.mark.tags(CaseLabel.GPU) @pytest.mark.parametrize("index, params", - zip(ct.all_index_types[:6], - ct.default_index_params[:6])) + zip(ct.all_index_types[:7], + ct.default_index_params[:7])) def test_search_after_index_different_metric_type(self, dim, index, params, auto_id, _async, enable_dynamic_field, metric_type): """ @@ -2317,18 +2304,14 @@ def test_search_after_index_different_metric_type(self, dim, index, params, auto if params.get("PQM"): if (dim % params["PQM"]) != 0: params["PQM"] = dim // 4 - log.info( - "test_search_after_index_different_metric_type: Creating index-%s" % index) - default_index = {"index_type": index, - "params": params, "metric_type": metric_type} + log.info("test_search_after_index_different_metric_type: Creating index-%s" % index) + default_index = {"index_type": index, "params": params, "metric_type": metric_type} collection_w.create_index("float_vector", default_index) - log.info( - "test_search_after_index_different_metric_type: Created index-%s" % index) + log.info("test_search_after_index_different_metric_type: Created index-%s" % index) collection_w.load() # 4. 
search search_params = cf.gen_search_param(index, metric_type) - vectors = [[random.random() for _ in range(dim)] - for _ in range(default_nq)] + vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] for search_param in search_params: log.info("Searching with search params: {}".format(search_param)) limit = default_limit @@ -2353,8 +2336,8 @@ def test_search_after_index_different_metric_type(self, dim, index, params, auto @pytest.mark.tags(CaseLabel.L2) @pytest.mark.skip(reason="issue 24957") @pytest.mark.parametrize("index, params", - zip(ct.all_index_types[:6], - ct.default_index_params[:6])) + zip(ct.all_index_types[:7], + ct.default_index_params[:7])) def test_search_after_release_recreate_index(self, dim, index, params, auto_id, _async, enable_dynamic_field, metric_type): """ @@ -2386,17 +2369,14 @@ def test_search_after_release_recreate_index(self, dim, index, params, auto_id, if params.get("PQM"): if (dim % params["PQM"]) != 0: params["PQM"] = dim // 4 - log.info( - "test_search_after_release_recreate_index: Creating index-%s" % index) - default_index = {"index_type": index, - "params": params, "metric_type": "COSINE"} + log.info("test_search_after_release_recreate_index: Creating index-%s" % index) + default_index = {"index_type": index, "params": params, "metric_type": "COSINE"} collection_w.create_index("float_vector", default_index) log.info("test_search_after_release_recreate_index: Created index-%s" % index) collection_w.load() # 4. search search_params = cf.gen_search_param(index, "COSINE") - vectors = [[random.random() for _ in range(dim)] - for _ in range(default_nq)] + vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] for search_param in search_params: log.info("Searching with search params: {}".format(search_param)) collection_w.search(vectors[:default_nq], default_search_field, @@ -2405,8 +2385,7 @@ def test_search_after_release_recreate_index(self, dim, index, params, auto_id, # 5. 
re-create index collection_w.release() collection_w.drop_index() - default_index = {"index_type": index, - "params": params, "metric_type": metric_type} + default_index = {"index_type": index, "params": params, "metric_type": metric_type} collection_w.create_index("float_vector", default_index) collection_w.load() for search_param in search_params: @@ -2425,8 +2404,8 @@ def test_search_after_release_recreate_index(self, dim, index, params, auto_id, @pytest.mark.tags(CaseLabel.GPU) @pytest.mark.parametrize("index, params", - zip(ct.all_index_types[8:10], - ct.default_index_params[8:10])) + zip(ct.all_index_types[9:11], + ct.default_index_params[9:11])) def test_search_after_index_different_metric_type_gpu(self, dim, index, params, auto_id, _async, enable_dynamic_field): """ target: test search with different metric type @@ -2446,18 +2425,14 @@ def test_search_after_index_different_metric_type_gpu(self, dim, index, params, if params.get("PQM"): if (dim % params["PQM"]) != 0: params["PQM"] = dim // 4 - log.info( - "test_search_after_index_different_metric_type: Creating index-%s" % index) - default_index = {"index_type": index, - "params": params, "metric_type": "IP"} + log.info("test_search_after_index_different_metric_type: Creating index-%s" % index) + default_index = {"index_type": index, "params": params, "metric_type": "IP"} collection_w.create_index("float_vector", default_index) - log.info( - "test_search_after_index_different_metric_type: Created index-%s" % index) + log.info("test_search_after_index_different_metric_type: Created index-%s" % index) collection_w.load() # 3. 
search search_params = cf.gen_search_param(index, "IP") - vectors = [[random.random() for _ in range(dim)] - for _ in range(default_nq)] + vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] for search_param in search_params: log.info("Searching with search params: {}".format(search_param)) collection_w.search(vectors[:default_nq], default_search_field, @@ -2611,13 +2586,11 @@ def test_search_index_partitions(self, nb, nq, dim, auto_id, _async): is_index=False)[0:4] vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] # 2. create index - default_index = {"index_type": "IVF_FLAT", - "params": {"nlist": 128}, "metric_type": "L2"} + default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"} collection_w.create_index("float_vector", default_index) collection_w.load() # 3. search through partitions - log.info( - "test_search_index_partitions: searching (1000 entities) through partitions") + log.info("test_search_index_partitions: searching (1000 entities) through partitions") par = collection_w.partitions log.info("test_search_index_partitions: partitions: %s" % par) search_params = {"metric_type": "L2", "params": {"nprobe": 64}} @@ -2831,7 +2804,6 @@ def test_search_binary_tanimoto_flat_index(self, nq, dim, auto_id, _async, index min(distance_0, distance_1)) <= epsilon @pytest.mark.tags(CaseLabel.L2) - @pytest.mark.skip("substructure obsolete") @pytest.mark.parametrize("index", ["BIN_FLAT"]) def test_search_binary_substructure_flat_index(self, auto_id, _async, index, is_flush): """ @@ -2850,15 +2822,13 @@ def test_search_binary_substructure_flat_index(self, auto_id, _async, index, is_ = self.init_collection_general(prefix, True, default_nb, is_binary=True, auto_id=auto_id, dim=dim, is_index=False, is_flush=is_flush)[0:5] # 2. 
create index - default_index = {"index_type": index, "params": { - "nlist": 128}, "metric_type": "SUBSTRUCTURE"} + default_index = {"index_type": index, "params": {"nlist": 128}, "metric_type": "SUBSTRUCTURE"} collection_w.create_index("binary_vector", default_index) collection_w.load() # 3. generate search vectors _, binary_vectors = cf.gen_binary_vectors(nq, dim) # 4. search and compare the distance - search_params = {"metric_type": "SUBSTRUCTURE", - "params": {"nprobe": 10}} + search_params = {"metric_type": "SUBSTRUCTURE", "params": {"nprobe": 10}} res = collection_w.search(binary_vectors[:nq], "binary_vector", search_params, default_limit, "int64 >= 0", _async=_async)[0] @@ -2869,7 +2839,6 @@ def test_search_binary_substructure_flat_index(self, auto_id, _async, index, is_ assert len(res) <= default_limit @pytest.mark.tags(CaseLabel.L2) - @pytest.mark.skip("superstructure obsolete") @pytest.mark.parametrize("index", ["BIN_FLAT"]) def test_search_binary_superstructure_flat_index(self, auto_id, _async, index, is_flush): """ @@ -2888,15 +2857,13 @@ def test_search_binary_superstructure_flat_index(self, auto_id, _async, index, i = self.init_collection_general(prefix, True, default_nb, is_binary=True, auto_id=auto_id, dim=dim, is_index=False, is_flush=is_flush)[0:5] # 2. create index - default_index = {"index_type": index, "params": { - "nlist": 128}, "metric_type": "SUPERSTRUCTURE"} + default_index = {"index_type": index, "params": {"nlist": 128}, "metric_type": "SUPERSTRUCTURE"} collection_w.create_index("binary_vector", default_index) collection_w.load() # 3. generate search vectors _, binary_vectors = cf.gen_binary_vectors(nq, dim) # 4. 
search and compare the distance - search_params = {"metric_type": "SUPERSTRUCTURE", - "params": {"nprobe": 10}} + search_params = {"metric_type": "SUPERSTRUCTURE", "params": {"nprobe": 10}} res = collection_w.search(binary_vectors[:nq], "binary_vector", search_params, default_limit, "int64 >= 0", _async=_async)[0] @@ -2967,16 +2934,13 @@ def test_search_with_expression(self, dim, expression, _async, enable_dynamic_fi filter_ids.append(_id) # 2. create index - index_param = {"index_type": "IVF_FLAT", - "metric_type": "COSINE", "params": {"nlist": 100}} + index_param = {"index_type": "IVF_FLAT", "metric_type": "COSINE", "params": {"nlist": 100}} collection_w.create_index("float_vector", index_param) collection_w.load() # 3. search with expression - log.info( - "test_search_with_expression: searching with expression: %s" % expression) - vectors = [[random.random() for _ in range(dim)] - for _ in range(default_nq)] + log.info("test_search_with_expression: searching with expression: %s" % expression) + vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] search_res, _ = collection_w.search(vectors[:default_nq], default_search_field, default_search_params, nb, expression, _async=_async, @@ -3152,7 +3116,6 @@ def test_search_expression_all_data_type(self, nb, nq, dim, auto_id, _async, ena in res[0][0].entity._row_data @pytest.mark.tags(CaseLabel.L1) - @pytest.mark.skip(reason="issue #23646") @pytest.mark.parametrize("field", ct.all_scalar_data_types[:3]) def test_search_expression_different_data_type(self, field): """ @@ -3165,27 +3128,27 @@ def test_search_expression_different_data_type(self, field): offset = 2 ** (num - 1) default_schema = cf.gen_collection_schema_all_datatype() collection_w = self.init_collection_wrap(schema=default_schema) - collection_w = cf.insert_data( - collection_w, is_all_data_type=True, insert_offset=offset-1000)[0] + collection_w = cf.insert_data(collection_w, is_all_data_type=True, insert_offset=offset-1000)[0] # 2. 
create index and load collection_w.create_index(field_name, default_index_params) collection_w.load() - # 3. search - log.info("test_search_expression_different_data_type: Searching collection %s" % - collection_w.name) + # 3. search using expression which field value is out of bound + log.info("test_search_expression_different_data_type: Searching collection %s" % collection_w.name) expression = f"{field} >= {offset}" - res = collection_w.search(vectors, default_search_field, default_search_params, - default_limit, expression, output_fields=[ - field], - check_task=CheckTasks.check_search_results, - check_items={"nq": default_nq, - "limit": default_limit})[0] - - # 4. check the result - for ids in res[0].ids: - assert ids >= offset + collection_w.search(vectors, default_search_field, default_search_params, + default_limit, expression, output_fields=[field], + check_task=CheckTasks.check_search_results, + check_items={"nq": default_nq, + "limit": 0})[0] + # 4. search normal using all the scalar type as output fields + collection_w.search(vectors, default_search_field, default_search_params, + default_limit, output_fields=[field], + check_task=CheckTasks.check_search_results, + check_items={"nq": default_nq, + "limit": default_limit, + "output_fields": [field]}) @pytest.mark.tags(CaseLabel.L1) def test_search_with_comparative_expression(self, _async): @@ -3199,10 +3162,8 @@ def test_search_with_comparative_expression(self, _async): dim = 1 fields = [cf.gen_int64_field("int64_1"), cf.gen_int64_field("int64_2"), cf.gen_float_vec_field(dim=dim)] - schema = cf.gen_collection_schema( - fields=fields, primary_field="int64_1") - collection_w = self.init_collection_wrap( - name=cf.gen_unique_str("comparison"), schema=schema) + schema = cf.gen_collection_schema(fields=fields, primary_field="int64_1") + collection_w = self.init_collection_wrap(name=cf.gen_unique_str("comparison"), schema=schema) # 2. 
inset data values = pd.Series(data=[i for i in range(0, nb)]) @@ -3217,8 +3178,7 @@ def test_search_with_comparative_expression(self, _async): filter_ids.extend(_id) # 3. search with expression - collection_w.create_index( - ct.default_float_vec_field_name, index_params=ct.default_flat_index) + collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.load() expression = "int64_1 <= int64_2" vectors = [[random.random() for _ in range(dim)] @@ -3251,8 +3211,7 @@ def test_search_with_output_fields_empty(self, nb, nq, dim, auto_id, _async): auto_id=auto_id, dim=dim)[0:4] # 2. search - log.info("test_search_with_output_fields_empty: Searching collection %s" % - collection_w.name) + log.info("test_search_with_output_fields_empty: Searching collection %s" % collection_w.name) vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] collection_w.search(vectors[:nq], default_search_field, default_search_params, default_limit, @@ -3277,19 +3236,18 @@ def test_search_with_output_field(self, auto_id, _async, enable_dynamic_field): auto_id=auto_id, enable_dynamic_field=enable_dynamic_field)[0:4] # 2. 
search - log.info("test_search_with_output_field: Searching collection %s" % - collection_w.name) + log.info("test_search_with_output_field: Searching collection %s" % collection_w.name) - res = collection_w.search(vectors[:default_nq], default_search_field, - default_search_params, default_limit, - default_search_exp, _async=_async, - output_fields=[default_int64_field_name], - check_task=CheckTasks.check_search_results, - check_items={"nq": default_nq, - "ids": insert_ids, - "limit": default_limit, - "_async": _async, - "output_fields": [default_int64_field_name]})[0] + collection_w.search(vectors[:default_nq], default_search_field, + default_search_params, default_limit, + default_search_exp, _async=_async, + output_fields=[default_int64_field_name], + check_task=CheckTasks.check_search_results, + check_items={"nq": default_nq, + "ids": insert_ids, + "limit": default_limit, + "_async": _async, + "output_fields": [default_int64_field_name]}) @pytest.mark.tags(CaseLabel.L1) def test_search_with_output_vector_field(self, auto_id, _async, enable_dynamic_field): @@ -3303,8 +3261,7 @@ def test_search_with_output_vector_field(self, auto_id, _async, enable_dynamic_f auto_id=auto_id, enable_dynamic_field=enable_dynamic_field)[0:4] # 2. search - log.info("test_search_with_output_field: Searching collection %s" % - collection_w.name) + log.info("test_search_with_output_field: Searching collection %s" % collection_w.name) collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, default_search_exp, _async=_async, @@ -3329,8 +3286,7 @@ def test_search_with_output_fields(self, nb, nq, dim, auto_id, _async): auto_id=auto_id, dim=dim)[0:4] # 2. 
search - log.info("test_search_with_output_fields: Searching collection %s" % - collection_w.name) + log.info("test_search_with_output_fields: Searching collection %s" % collection_w.name) vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] output_fields = [default_int64_field_name, default_float_field_name] collection_w.search(vectors[:nq], default_search_field, @@ -3344,10 +3300,10 @@ def test_search_with_output_fields(self, nb, nq, dim, auto_id, _async): "_async": _async, "output_fields": output_fields}) - @pytest.mark.tags(CaseLabel.L2) + @pytest.mark.tags(CaseLabel.L1) @pytest.mark.parametrize("index, params", - zip(ct.all_index_types[:6], - ct.default_index_params[:6])) + zip(ct.all_index_types[:7], + ct.default_index_params[:7])) @pytest.mark.parametrize("metrics", ct.float_metrics) @pytest.mark.parametrize("limit", [20, 1200]) def test_search_output_field_vector_after_different_index_metrics(self, index, params, metrics, limit): @@ -3359,19 +3315,10 @@ def test_search_output_field_vector_after_different_index_metrics(self, index, p 4. check the result vectors should be equal to the inserted expected: search success """ - if index in ["IVF_SQ8", "IVF_PQ"]: - pytest.skip("IVF_SQ8 and IVF_PQ do not support output vector now") - if index == "DISKANN" and metrics == "IP": - pytest.skip("DISKANN(IP) does not support output vector now") - if metrics == "COSINE": - pytest.skip("COSINE does not support output vector now") - # 1. create a collection and insert data - collection_w, _vectors = self.init_collection_general( - prefix, True, is_index=False)[:2] + collection_w, _vectors = self.init_collection_general(prefix, True, is_index=False)[:2] # 2. 
create index and load - default_index = {"index_type": index, - "params": params, "metric_type": metrics} + default_index = {"index_type": index, "params": params, "metric_type": metrics} collection_w.create_index(field_name, default_index) collection_w.load() @@ -3394,9 +3341,10 @@ def test_search_output_field_vector_after_different_index_metrics(self, index, p "original_entities": _vectors, "output_fields": [field_name]}) - @pytest.mark.tags(CaseLabel.L2) - @pytest.mark.parametrize("index", ["HNSW", "BIN_FLAT", "BIN_IVF_FLAT"]) - def test_search_output_field_vector_after_binary_index(self, index): + @pytest.mark.tags(CaseLabel.L1) + @pytest.mark.parametrize("metrics", ct.binary_metrics[:2]) + @pytest.mark.parametrize("index", ["BIN_FLAT", "BIN_IVF_FLAT", "HNSW"]) + def test_search_output_field_vector_after_binary_index(self, metrics, index): """ target: test search with output vector field after binary index method: 1. create a collection and insert data @@ -3406,27 +3354,59 @@ def test_search_output_field_vector_after_binary_index(self, index): expected: search success """ # 1. create a collection and insert data - collection_w = self.init_collection_general( - prefix, is_binary=True, is_index=False)[0] + collection_w = self.init_collection_general(prefix, is_binary=True, is_index=False)[0] data = cf.gen_default_binary_dataframe_data()[0] collection_w.insert(data) # 2. create index and load - default_index = {"index_type": index, "metric_type": "JACCARD", - "params": {"nlist": 128, "efConstruction": 64, "M": 10}} + params = {"M": 48, "efConstruction": 500} if index == "HNSW" else {"nlist": 128} + default_index = {"index_type": index, "metric_type": metrics, "params": params} collection_w.create_index(binary_field_name, default_index) collection_w.load() # 3. 
search with output field vector - search_params = {"metric_type": "JACCARD"} + search_params = cf.gen_search_param(index, metrics) binary_vectors = cf.gen_binary_vectors(1, default_dim)[1] + for search_param in search_params: + res = collection_w.search(binary_vectors, binary_field_name, + search_param, 2, default_search_exp, + output_fields=[binary_field_name])[0] + + # 4. check the result vectors should be equal to the inserted + assert res[0][0].entity.binary_vector == [data[binary_field_name][res[0][0].id]] + + @pytest.mark.tags(CaseLabel.L2) + @pytest.mark.parametrize("metrics", ct.structure_metrics) + @pytest.mark.parametrize("index", ["BIN_FLAT"]) + def test_search_output_field_vector_after_structure_metrics(self, metrics, index): + """ + target: test search with output vector field after binary index + method: 1. create a collection and insert data + 2. create index and load + 3. search with output field vector + 4. check the result vectors should be equal to the inserted + expected: search success + """ + dim = 8 + # 1. create a collection and insert data + collection_w = self.init_collection_general(prefix, dim=dim, is_binary=True, is_index=False)[0] + data = cf.gen_default_binary_dataframe_data(dim=dim)[0] + collection_w.insert(data) + + # 2. create index and load + default_index = {"index_type": index, "metric_type": metrics, "params": {"nlist": 128}} + collection_w.create_index(binary_field_name, default_index) + collection_w.load() + + # 3. search with output field vector + search_params = {"metric_type": metrics, "params": {"nprobe": 10}} + binary_vectors = cf.gen_binary_vectors(ct.default_nq, dim)[1] res = collection_w.search(binary_vectors, binary_field_name, search_params, 2, default_search_exp, output_fields=[binary_field_name])[0] # 4. 
check the result vectors should be equal to the inserted - assert res[0][0].entity.binary_vector == [ - data[binary_field_name][res[0][0].id]] + assert res[0][0].entity.binary_vector == [data[binary_field_name][res[0][0].id]] @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("dim", [32, 128, 768]) @@ -3440,8 +3420,7 @@ def test_search_output_field_vector_with_different_dim(self, dim): expected: search success """ # 1. create a collection and insert data - collection_w, _vectors = self.init_collection_general(prefix, True, dim=dim)[ - :2] + collection_w, _vectors = self.init_collection_general(prefix, True, dim=dim)[:2] # 2. search with output field vector vectors = cf.gen_vectors(default_nq, dim=dim) @@ -3468,8 +3447,7 @@ def test_search_output_vector_field_and_scalar_field(self, enable_dynamic_field) enable_dynamic_field=enable_dynamic_field)[:2] # 2. search with output field vector - output_fields = [default_float_field_name, - default_string_field_name, default_search_field] + output_fields = [default_float_field_name, default_string_field_name, default_search_field] original_entities = [] if enable_dynamic_field: entities = [] @@ -3500,12 +3478,11 @@ def test_search_output_vector_field_and_pk_field(self, enable_dynamic_field): expected: search success """ # 1. initialize a collection - collection_w = self.init_collection_general( - prefix, True, enable_dynamic_field=enable_dynamic_field)[0] + collection_w = self.init_collection_general(prefix, True, + enable_dynamic_field=enable_dynamic_field)[0] # 2. 
search with output field vector - output_fields = [default_int64_field_name, - default_string_field_name, default_search_field] + output_fields = [default_int64_field_name, default_string_field_name, default_search_field] collection_w.search(vectors[:1], default_search_field, default_search_params, default_limit, default_search_exp, output_fields=output_fields, @@ -3557,10 +3534,8 @@ def test_search_with_output_field_wildcard(self, wildcard_output_fields, auto_id collection_w, _, _, insert_ids = self.init_collection_general(prefix, True, auto_id=auto_id)[0:4] # 2. search - log.info("test_search_with_output_field_wildcard: Searching collection %s" % - collection_w.name) - output_fields = cf.get_wildcard_output_field_names( - collection_w, wildcard_output_fields) + log.info("test_search_with_output_field_wildcard: Searching collection %s" % collection_w.name) + output_fields = cf.get_wildcard_output_field_names(collection_w, wildcard_output_fields) collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, default_search_exp, _async=_async, @@ -3581,18 +3556,17 @@ def test_search_with_invalid_output_fields(self, invalid_output_fields, auto_id) expected: search success """ # 1. initialize with data - collection_w, _, _, insert_ids = self.init_collection_general( - prefix, True, auto_id=auto_id)[0:4] + collection_w, _, _, insert_ids = self.init_collection_general(prefix, True, auto_id=auto_id)[0:4] # 2. 
search - log.info("test_search_with_output_field_wildcard: Searching collection %s" % - collection_w.name) + log.info("test_search_with_output_field_wildcard: Searching collection %s" % collection_w.name) + error1 = {"err_code": 65535, "err_msg": "field %s not exist" % invalid_output_fields[0]} + error2 = {"err_code": 1, "err_msg": "`output_fields` value %s is illegal" % invalid_output_fields[0]} + error = error2 if invalid_output_fields == [""] else error1 collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, default_search_exp, output_fields=invalid_output_fields, - check_task=CheckTasks.err_res, - check_items={"err_code": 1, - "err_msg": "field %s is not exist" % invalid_output_fields[0]}) + check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) def test_search_multi_collections(self, nb, nq, dim, auto_id, _async): @@ -3610,8 +3584,7 @@ def test_search_multi_collections(self, nb, nq, dim, auto_id, _async): auto_id=auto_id, dim=dim)[0:4] # 2. search - vectors = [[random.random() for _ in range(dim)] - for _ in range(nq)] + vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] log.info("test_search_multi_collections: searching %s entities (nq = %s) from collection %s" % (default_limit, nq, collection_w.name)) collection_w.search(vectors[:nq], default_search_field, @@ -3633,10 +3606,10 @@ def test_search_concurrent_multi_threads(self, nb, nq, dim, auto_id, _async, ena # 1. 
initialize with data threads_num = 10 threads = [] - collection_w, _, _, insert_ids, time_stamp = self.init_collection_general(prefix, True, nb, - auto_id=auto_id, - dim=dim, - enable_dynamic_field=enable_dynamic_field)[0:5] + collection_w, _, _, insert_ids = self.init_collection_general(prefix, True, nb, + auto_id=auto_id, + dim=dim, + enable_dynamic_field=enable_dynamic_field)[0:4] def search(collection_w): vectors = [[random.random() for _ in range(dim)] @@ -3651,8 +3624,7 @@ def search(collection_w): "_async": _async}) # 2. search with multi-processes - log.info( - "test_search_concurrent_multi_threads: searching with %s processes" % threads_num) + log.info("test_search_concurrent_multi_threads: searching with %s processes" % threads_num) for i in range(threads_num): t = threading.Thread(target=search, args=(collection_w,)) threads.append(t) @@ -3671,18 +3643,15 @@ def test_search_insert_in_parallel(self): """ c_name = cf.gen_unique_str(prefix) collection_w = self.init_collection_wrap(name=c_name) - default_index = {"index_type": "IVF_FLAT", - "params": {"nlist": 128}, "metric_type": "L2"} - collection_w.create_index( - ct.default_float_vec_field_name, default_index) + default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 128}, "metric_type": "L2"} + collection_w.create_index(ct.default_float_vec_field_name, default_index) collection_w.load() def do_insert(): df = cf.gen_default_dataframe_data(10000) for i in range(11): collection_w.insert(df) - log.info( - f'Collection num entities is : {collection_w.num_entities}') + log.info(f'Collection num entities is : {collection_w.num_entities}') def do_search(): while True: @@ -3696,8 +3665,7 @@ def do_search(): timeout=30) p_insert = multiprocessing.Process(target=do_insert, args=()) - p_search = multiprocessing.Process( - target=do_search, args=(), daemon=True) + p_search = multiprocessing.Process(target=do_search, args=(), daemon=True) p_insert.start() p_search.start() @@ -3720,8 +3688,7 @@ def 
test_search_round_decimal(self, round_decimal, enable_dynamic_field): collection_w = self.init_collection_general(prefix, True, nb=tmp_nb, enable_dynamic_field=enable_dynamic_field)[0] # 2. search - log.info("test_search_round_decimal: Searching collection %s" % - collection_w.name) + log.info("test_search_round_decimal: Searching collection %s" % collection_w.name) res, _ = collection_w.search(vectors[:tmp_nq], default_search_field, default_search_params, tmp_limit) @@ -3735,8 +3702,7 @@ def test_search_round_decimal(self, round_decimal, enable_dynamic_field): dis_actual = res_round[0][i].distance # log.debug(f'actual: {dis_actual}, expect: {dis_expect}') # abs(a-b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol) - assert math.isclose(dis_actual, dis_expect, - rel_tol=0, abs_tol=abs_tol) + assert math.isclose(dis_actual, dis_expect, rel_tol=0, abs_tol=abs_tol) @pytest.mark.tags(CaseLabel.L1) def test_search_with_expression_large(self, dim, enable_dynamic_field): @@ -3754,26 +3720,22 @@ def test_search_with_expression_large(self, dim, enable_dynamic_field): with_json=False)[0:4] # 2. create index - index_param = {"index_type": "IVF_FLAT", - "metric_type": "COSINE", "params": {"nlist": 100}} + index_param = {"index_type": "IVF_FLAT", "metric_type": "COSINE", "params": {"nlist": 100}} collection_w.create_index("float_vector", index_param) collection_w.load() # 3. 
search with expression expression = f"0 < {default_int64_field_name} < 5001" - log.info( - "test_search_with_expression: searching with expression: %s" % expression) + log.info("test_search_with_expression: searching with expression: %s" % expression) nums = 5000 vectors = [[random.random() for _ in range(dim)] for _ in range(nums)] search_res, _ = collection_w.search(vectors, default_search_field, default_search_params, default_limit, expression, check_task=CheckTasks.check_search_results, - check_items={ - "nq": nums, - "ids": insert_ids, - "limit": default_limit, - }) + check_items={"nq": nums, + "ids": insert_ids, + "limit": default_limit}) @pytest.mark.tags(CaseLabel.L1) def test_search_with_expression_large_two(self, dim, enable_dynamic_field): @@ -3791,8 +3753,7 @@ def test_search_with_expression_large_two(self, dim, enable_dynamic_field): with_json=False)[0:4] # 2. create index - index_param = {"index_type": "IVF_FLAT", - "metric_type": "COSINE", "params": {"nlist": 100}} + index_param = {"index_type": "IVF_FLAT", "metric_type": "COSINE", "params": {"nlist": 100}} collection_w.create_index("float_vector", index_param) collection_w.load() @@ -4004,8 +3965,7 @@ def test_search_ignore_growing(self, nq, dim, _async): collection_w.insert(data) # 3. 
search with param ignore_growing=True - search_params = {"metric_type": "COSINE", "params": { - "nprobe": 10}, "ignore_growing": True} + search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}, "ignore_growing": True} vector = [[random.random() for _ in range(dim)] for _ in range(nq)] res = collection_w.search(vector[:nq], default_search_field, search_params, default_limit, default_search_exp, _async=_async, @@ -4237,6 +4197,58 @@ def test_search_using_all_types_of_default_value(self, auto_id): assert res[ct.default_bool_field_name] is False assert res[ct.default_string_field_name] == "abc" + @pytest.mark.tags(CaseLabel.L2) + @pytest.mark.parametrize("index, params", zip(ct.all_index_types[1:4], ct.default_index_params[1:4])) + def test_search_repeatedly_ivf_index_same_limit(self, index, params): + """ + target: test create collection repeatedly + method: search twice, check the results is the same + expected: search results are as expected + """ + nb = 5000 + limit = 30 + # 1. create a collection + collection_w = self.init_collection_general(prefix, True, nb, is_index=False)[0] + + # 2. insert data again + index_params = {"metric_type": "COSINE", "index_type": index, "params": params} + collection_w.create_index(default_search_field, index_params) + + # 3. 
search with param ignore_growing=True + collection_w.load() + search_params = cf.gen_search_param(index, "COSINE")[0] + vector = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)] + res1 = collection_w.search(vector[:default_nq], default_search_field, search_params, limit)[0] + res2 = collection_w.search(vector[:default_nq], default_search_field, search_params, limit)[0] + for i in range(default_nq): + res1[i].ids == res2[i].ids + + @pytest.mark.tags(CaseLabel.L2) + @pytest.mark.parametrize("index, params", zip(ct.all_index_types[1:4], ct.default_index_params[1:4])) + def test_search_repeatedly_ivf_index_different_limit(self, index, params): + """ + target: test create collection repeatedly + method: search twice, check the results is the same + expected: search results are as expected + """ + nb = 5000 + limit = random.randint(10, 100) + # 1. create a collection + collection_w = self.init_collection_general(prefix, True, nb, is_index=False)[0] + + # 2. insert data again + index_params = {"metric_type": "COSINE", "index_type": index, "params": params} + collection_w.create_index(default_search_field, index_params) + + # 3. 
search with param ignore_growing=True + collection_w.load() + search_params = cf.gen_search_param(index, "COSINE")[0] + vector = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)] + res1 = collection_w.search(vector, default_search_field, search_params, limit)[0] + res2 = collection_w.search(vector, default_search_field, search_params, limit * 2)[0] + for i in range(default_nq): + res1[i].ids == res2[i].ids[limit:] + class TestSearchBase(TestcaseBase): @pytest.fixture( @@ -4274,25 +4286,22 @@ def test_search_flat_top_k(self, get_nq): """ top_k = 16385 # max top k is 16384 nq = get_nq - collection_w, data, _, insert_ids = self.init_collection_general( - prefix, insert_data=True, nb=nq)[0:4] + collection_w, data, _, insert_ids = self.init_collection_general(prefix, insert_data=True, nb=nq)[0:4] collection_w.load() if top_k <= max_top_k: - res, _ = collection_w.search(vectors[:nq], default_search_field, default_search_params, - top_k) + res, _ = collection_w.search(vectors[:nq], default_search_field, default_search_params, top_k) assert len(res[0]) <= top_k else: - collection_w.search(vectors[:nq], default_search_field, default_search_params, - top_k, + collection_w.search(vectors[:nq], default_search_field, default_search_params, top_k, check_task=CheckTasks.err_res, - check_items={"err_code": 1, - "err_msg": "no Available QueryNode result, " - "filter reason limit %s is too large," % top_k}) + check_items={"err_code": 65535, + "err_msg": f"topk [{top_k}] is invalid, top k should be in range" + f" [1, 16384], but got {top_k}"}) @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("index, params", - zip(ct.all_index_types[:6], - ct.default_index_params[:6])) + zip(ct.all_index_types[:7], + ct.default_index_params[:7])) def test_search_index_empty_partition(self, index, params): """ target: test basic search function, all the search params are correct, test all index params, and build @@ -4309,8 +4318,7 @@ def 
test_search_index_empty_partition(self, index, params): vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] # 2. create partition partition_name = "search_partition_empty" - collection_w.create_partition( - partition_name=partition_name, description="search partition empty") + collection_w.create_partition(partition_name=partition_name, description="search partition empty") par = collection_w.partitions # collection_w.load() # 3. create different index @@ -4320,8 +4328,7 @@ def test_search_index_empty_partition(self, index, params): if params.get("PQM"): if (dim % params["PQM"]) != 0: params["PQM"] = dim // 4 - default_index = {"index_type": index, - "params": params, "metric_type": "COSINE"} + default_index = {"index_type": index, "params": params, "metric_type": "COSINE"} collection_w.create_index("float_vector", default_index) collection_w.load() @@ -4342,8 +4349,8 @@ def test_search_index_empty_partition(self, index, params): @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("index, params", - zip(ct.all_index_types[:6], - ct.default_index_params[:6])) + zip(ct.all_index_types[:7], + ct.default_index_params[:7])) def test_search_index_partitions(self, index, params, get_top_k): """ target: test basic search function, all the search params are correct, test all index params, and build @@ -4353,30 +4360,31 @@ def test_search_index_partitions(self, index, params, get_top_k): top_k = get_top_k nq = ct.default_nq dim = ct.default_dim - # 1. initialize with data - collection_w, _, _, insert_ids, time_stamp = self.init_collection_general(prefix, True, nq, + # 1. initialize with data in 2 partitions + collection_w, _, _, insert_ids, time_stamp = self.init_collection_general(prefix, True, partition_num=1, dim=dim, is_index=False)[0:5] vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] - # 2. create patition - partition_name = ct.default_partition_name - par = collection_w.partitions - # collection_w.load() - # 3. 
create different index + # 2. create different index if params.get("m"): if (dim % params["m"]) != 0: params["m"] = dim // 4 if params.get("PQM"): if (dim % params["PQM"]) != 0: params["PQM"] = dim // 4 - default_index = {"index_type": index, - "params": params, "metric_type": "COSINE"} + default_index = {"index_type": index, "params": params, "metric_type": "COSINE"} collection_w.create_index("float_vector", default_index) + + # 3. load and search collection_w.load() - res, _ = collection_w.search(vectors[:nq], default_search_field, - ct.default_search_params, top_k, - default_search_exp, [partition_name]) - assert len(res[0]) <= top_k + par = collection_w.partitions + collection_w.search(vectors[:nq], default_search_field, + ct.default_search_params, top_k, + default_search_exp, [par[0].name, par[1].name], + check_task=CheckTasks.check_search_results, + check_items={"nq": nq, + "limit": top_k, + "ids": insert_ids}) @pytest.mark.tags(CaseLabel.L2) def test_search_ip_flat(self, get_top_k): @@ -4405,8 +4413,8 @@ def test_search_ip_flat(self, get_top_k): @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("index, params", - zip(ct.all_index_types[:6], - ct.default_index_params[:6])) + zip(ct.all_index_types[:7], + ct.default_index_params[:7])) def test_search_ip_after_index(self, index, params): """ target: test basic search function, all the search params are correct, test all index params, and build @@ -4422,8 +4430,7 @@ def test_search_ip_after_index(self, index, params): dim=dim, is_index=False)[0:5] vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] # 2. 
create ip index - default_index = {"index_type": index, - "params": params, "metric_type": "IP"} + default_index = {"index_type": index, "params": params, "metric_type": "IP"} collection_w.create_index("float_vector", default_index) collection_w.load() search_params = {"metric_type": "IP", "params": {"nprobe": 10}} @@ -4467,8 +4474,8 @@ def test_search_ip_brute_force(self, nb, dim): @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("index, params", - zip(ct.all_index_types[:6], - ct.default_index_params[:6])) + zip(ct.all_index_types[:7], + ct.default_index_params[:7])) def test_search_ip_index_empty_partition(self, index, params): """ target: test basic search function, all the search params are correct, test all index params, and build @@ -4485,13 +4492,11 @@ def test_search_ip_index_empty_partition(self, index, params): vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] # 2. create partition partition_name = "search_partition_empty" - collection_w.create_partition( - partition_name=partition_name, description="search partition empty") + collection_w.create_partition(partition_name=partition_name, description="search partition empty") par = collection_w.partitions # collection_w.load() # 3. 
create different index - default_index = {"index_type": index, - "params": params, "metric_type": "IP"} + default_index = {"index_type": index, "params": params, "metric_type": "IP"} collection_w.create_index("float_vector", default_index) collection_w.load() @@ -4513,8 +4518,8 @@ def test_search_ip_index_empty_partition(self, index, params): @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("index, params", - zip(ct.all_index_types[:6], - ct.default_index_params[:6])) + zip(ct.all_index_types[:7], + ct.default_index_params[:7])) def test_search_ip_index_partitions(self, index, params): """ target: test basic search function, all the search params are correct, test all index params, and build @@ -4533,8 +4538,7 @@ def test_search_ip_index_partitions(self, index, params): par_name = collection_w.partitions[0].name # collection_w.load() # 3. create different index - default_index = {"index_type": index, - "params": params, "metric_type": "IP"} + default_index = {"index_type": index, "params": params, "metric_type": "IP"} collection_w.create_index("float_vector", default_index) collection_w.load() @@ -4544,7 +4548,119 @@ def test_search_ip_index_partitions(self, index, params): search_params, top_k, default_search_exp, [par_name]) - assert len(res[0]) <= top_k + @pytest.mark.tags(CaseLabel.L2) + @pytest.mark.parametrize("index, params", zip(ct.all_index_types[:7], ct.default_index_params[:7])) + def test_search_cosine_all_indexes(self, index, params): + """ + target: test basic search function, all the search params are correct, test all index params, and build + method: search collection with the given vectors and tags, check the result + expected: the length of the result is top_k + """ + # 1. initialize with data + collection_w, _, _, insert_ids, time_stamp = self.init_collection_general(prefix, True, + is_index=False)[0:5] + # 2. 
create index + default_index = {"index_type": index, "params": params, "metric_type": "COSINE"} + collection_w.create_index("float_vector", default_index) + collection_w.load() + + # 3. search + search_params = {"metric_type": "COSINE"} + res, _ = collection_w.search(vectors[:default_nq], default_search_field, + search_params, default_limit, default_search_exp, + check_task=CheckTasks.check_search_results, + check_items={"nq": default_nq, + "ids": insert_ids, + "limit": default_limit}) + + # 4. check cosine distance + for i in range(default_nq): + for distance in res[i].distances: + assert 1 >= distance >= -1 + + @pytest.mark.tags(CaseLabel.L2) + def test_search_cosine_results_same_as_l2(self): + """ + target: test search results of l2 and cosine keep the same + method: 1. search L2 + 2. search cosine + 3. compare the results + expected: raise no exception + """ + nb = ct.default_nb + # 1. prepare original data and normalized data + original_vec = [[random.random() for _ in range(ct.default_dim)] for _ in range(nb)] + normalize_vec = preprocessing.normalize(original_vec, axis=1, norm='l2') + normalize_vec = normalize_vec.tolist() + data = cf.gen_default_dataframe_data() + + # 2. create L2 collection and insert normalized data + collection_w1 = self.init_collection_general(prefix, is_index=False)[0] + data[ct.default_float_vec_field_name] = normalize_vec + collection_w1.insert(data) + + # 2. create index L2 + default_index = {"index_type": "IVF_SQ8", "params": {"nlist": 64}, "metric_type": "L2"} + collection_w1.create_index("float_vector", default_index) + collection_w1.load() + + # 3. search L2 + search_params = {"params": {"nprobe": 10}, "metric_type": "L2"} + res_l2, _ = collection_w1.search(vectors[:default_nq], default_search_field, + search_params, default_limit, default_search_exp) + + # 4. 
create cosine collection and insert original data + collection_w2 = self.init_collection_general(prefix, is_index=False)[0] + data[ct.default_float_vec_field_name] = original_vec + collection_w2.insert(data) + + # 5. create index cosine + default_index = {"index_type": "IVF_SQ8", "params": {"nlist": 64}, "metric_type": "COSINE"} + collection_w2.create_index("float_vector", default_index) + collection_w2.load() + + # 6. search cosine + search_params = {"params": {"nprobe": 10}, "metric_type": "COSINE"} + res_cosine, _ = collection_w2.search(vectors[:default_nq], default_search_field, + search_params, default_limit, default_search_exp) + + # 7. check the search results + for i in range(default_nq): + assert res_l2[i].ids == res_cosine[i].ids + + @pytest.mark.tags(CaseLabel.L2) + def test_search_cosine_results_same_as_ip(self): + """ + target: test search results of ip and cosine keep the same + method: 1. search IP + 2. search cosine + 3. compare the results + expected: raise no exception + """ + # 1. create collection and insert data + collection_w = self.init_collection_general(prefix, True, is_index=False)[0] + + # 2. search IP + default_index = {"index_type": "IVF_SQ8", "params": {"nlist": 64}, "metric_type": "IP"} + collection_w.create_index("float_vector", default_index) + collection_w.load() + search_params = {"params": {"nprobe": 10}, "metric_type": "IP"} + res_ip, _ = collection_w.search(vectors[:default_nq], default_search_field, + search_params, default_limit, default_search_exp) + + # 3. search cosine + collection_w.release() + collection_w.drop_index() + default_index = {"index_type": "IVF_SQ8", "params": {"nlist": 64}, "metric_type": "COSINE"} + collection_w.create_index("float_vector", default_index) + collection_w.load() + search_params = {"params": {"nprobe": 10}, "metric_type": "COSINE"} + res_cosine, _ = collection_w.search(vectors[:default_nq], default_search_field, + search_params, default_limit, default_search_exp) + + # 4. 
check the search results + for i in range(default_nq): + assert res_ip[i].ids == res_cosine[i].ids @pytest.mark.tags(CaseLabel.L2) def test_search_without_connect(self): @@ -4825,23 +4941,22 @@ def test_search_string_with_invalid_expr(self, auto_id): """ # 1. initialize with data collection_w, _, _, insert_ids = \ - self.init_collection_general( - prefix, True, auto_id=auto_id, dim=default_dim)[0:4] + self.init_collection_general(prefix, True, auto_id=auto_id, dim=default_dim)[0:4] # 2. search log.info("test_search_string_with_invalid_expr: searching collection %s" % collection_w.name) - vectors = [[random.random() for _ in range(default_dim)] - for _ in range(default_nq)] + vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)] collection_w.search(vectors[:default_nq], default_search_field, default_search_params, default_limit, default_invaild_string_exp, check_task=CheckTasks.err_res, - check_items={"err_code": 1, - "err_msg": "failed to create query plan: type mismatch"} - ) + check_items={"err_code": 65535, + "err_msg": "failed to create query plan: cannot parse expression: " + "varchar >= 0, error: comparisons between VarChar, " + "element_type: None and Int64 elementType: None are not supported"}) @pytest.mark.tags(CaseLabel.L2) - @pytest.mark.parametrize("expression", cf.gen_normal_string_expressions(ct.default_string_field_name)) + @pytest.mark.parametrize("expression", cf.gen_normal_string_expressions([ct.default_string_field_name])) def test_search_with_different_string_expr(self, dim, expression, _async, enable_dynamic_field): """ target: test search with different string expressions @@ -4870,16 +4985,13 @@ def test_search_with_different_string_expr(self, dim, expression, _async, enable filter_ids.append(_id) # 2. 
create index - index_param = {"index_type": "IVF_FLAT", - "metric_type": "COSINE", "params": {"nlist": 100}} + index_param = {"index_type": "IVF_FLAT", "metric_type": "COSINE", "params": {"nlist": 100}} collection_w.create_index("float_vector", index_param) collection_w.load() # 3. search with expression - log.info( - "test_search_with_expression: searching with expression: %s" % expression) - vectors = [[random.random() for _ in range(dim)] - for _ in range(default_nq)] + log.info("test_search_with_expression: searching with expression: %s" % expression) + vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] search_res, _ = collection_w.search(vectors[:default_nq], default_search_field, default_search_params, nb, expression, _async=_async, @@ -5584,8 +5696,8 @@ def test_search_pagination_with_offset_over_num_entities(self, offset): @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("index, params", - zip(ct.all_index_types[:6], - ct.default_index_params[:6])) + zip(ct.all_index_types[:7], + ct.default_index_params[:7])) def test_search_pagination_after_different_index(self, index, params, auto_id, offset, _async): """ target: test search pagination after different index @@ -5605,14 +5717,12 @@ def test_search_pagination_after_different_index(self, index, params, auto_id, o if params.get("PQM"): if (dim % params["PQM"]) != 0: params["PQM"] = dim // 4 - default_index = {"index_type": index, - "params": params, "metric_type": "L2"} + default_index = {"index_type": index, "params": params, "metric_type": "L2"} collection_w.create_index("float_vector", default_index) collection_w.load() # 3. 
search search_params = cf.gen_search_param(index) - vectors = [[random.random() for _ in range(dim)] - for _ in range(default_nq)] + vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] for search_param in search_params: res = collection_w.search(vectors[:default_nq], default_search_field, search_param, default_limit + offset, default_search_exp, _async=_async)[0] @@ -5694,7 +5804,7 @@ def test_search_pagination_with_invalid_offset_type(self, offset): "err_msg": "offset [%s] is invalid" % offset}) @pytest.mark.tags(CaseLabel.L1) - @pytest.mark.parametrize("offset", [-1, 16386]) + @pytest.mark.parametrize("offset", [-1, 16385]) def test_search_pagination_with_invalid_offset_value(self, offset): """ target: test search pagination with invalid offset value @@ -5702,20 +5812,17 @@ def test_search_pagination_with_invalid_offset_value(self, offset): expected: raise exception """ # 1. initialize - collection_w = self.init_collection_general( - prefix, True, dim=default_dim)[0] + collection_w = self.init_collection_general(prefix, True, dim=default_dim)[0] # 2. search - search_param = {"metric_type": "COSINE", - "params": {"nprobe": 10}, "offset": offset} - vectors = [[random.random() for _ in range(default_dim)] - for _ in range(default_nq)] + search_param = {"metric_type": "COSINE", "params": {"nprobe": 10}, "offset": offset} + vectors = [[random.random() for _ in range(default_dim)] for _ in range(default_nq)] collection_w.search(vectors[:default_nq], default_search_field, search_param, default_limit, default_search_exp, check_task=CheckTasks.err_res, - check_items={"err_code": 1, + check_items={"err_code": 65535, "err_msg": "offset [%d] is invalid, should be in range " - "[1, 16385], but got %d" % (offset, offset)}) + "[1, 16384], but got %d" % (offset, offset)}) class TestSearchDiskann(TestcaseBase): @@ -5826,18 +5933,13 @@ def test_search_invalid_params_with_diskann_A(self, dim, auto_id, search_list, l """ # 1. 
initialize with data collection_w, _, _, insert_ids = \ - self.init_collection_general( - prefix, True, auto_id=auto_id, dim=dim, is_index=False)[0:4] + self.init_collection_general(prefix, True, auto_id=auto_id, dim=dim, is_index=False)[0:4] # 2. create index - default_index = {"index_type": "DISKANN", - "metric_type": "L2", "params": {}} - collection_w.create_index( - ct.default_float_vec_field_name, default_index) + default_index = {"index_type": "DISKANN", "metric_type": "L2", "params": {}} + collection_w.create_index(ct.default_float_vec_field_name, default_index) collection_w.load() - default_search_params = {"metric_type": "L2", - "params": {"search_list": search_list}} - vectors = [[random.random() for _ in range(dim)] - for _ in range(default_nq)] + default_search_params = {"metric_type": "L2", "params": {"search_list": search_list}} + vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] output_fields = [default_int64_field_name, default_float_field_name, default_string_field_name] collection_w.search(vectors[:default_nq], default_search_field, @@ -5845,9 +5947,8 @@ def test_search_invalid_params_with_diskann_A(self, dim, auto_id, search_list, l default_search_exp, output_fields=output_fields, check_task=CheckTasks.err_res, - check_items={"err_code": 1, - "err_msg": "fail to search on all shard leaders"} - ) + check_items={"err_code": 65538, + "err_msg": "fail to search on all shard leaders"}) @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("limit", [20]) @@ -5862,27 +5963,21 @@ def test_search_invalid_params_with_diskann_B(self, dim, auto_id, search_list, l """ # 1. initialize with data collection_w, _, _, insert_ids = \ - self.init_collection_general( - prefix, True, auto_id=auto_id, dim=dim, is_index=False)[0:4] + self.init_collection_general(prefix, True, auto_id=auto_id, dim=dim, is_index=False)[0:4] # 2. 
create index - default_index = {"index_type": "DISKANN", - "metric_type": "L2", "params": {}} - collection_w.create_index( - ct.default_float_vec_field_name, default_index) + default_index = {"index_type": "DISKANN", "metric_type": "L2", "params": {}} + collection_w.create_index(ct.default_float_vec_field_name, default_index) collection_w.load() - default_search_params = {"metric_type": "L2", - "params": {"search_list": search_list}} - vectors = [[random.random() for _ in range(dim)] - for _ in range(default_nq)] - output_fields = [default_int64_field_name, - default_float_field_name, default_string_field_name] + default_search_params = {"metric_type": "L2", "params": {"search_list": search_list}} + vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] + output_fields = [default_int64_field_name, default_float_field_name, default_string_field_name] collection_w.search(vectors[:default_nq], default_search_field, default_search_params, limit, default_search_exp, output_fields=output_fields, check_task=CheckTasks.err_res, - check_items={"err_code": 1, - "err_msg": "fail to search on all shard leaders"}) + check_items={"err_code": 65538, + "err_msg": "UnknownError"}) @pytest.mark.tags(CaseLabel.L2) @pytest.mark.parametrize("limit", [6553]) @@ -5897,27 +5992,21 @@ def test_search_invalid_params_with_diskann_C(self, dim, auto_id, search_list, l """ # 1. initialize with data collection_w, _, _, insert_ids = \ - self.init_collection_general( - prefix, True, auto_id=auto_id, dim=dim, is_index=False)[0:4] + self.init_collection_general(prefix, True, auto_id=auto_id, dim=dim, is_index=False)[0:4] # 2. 
create index - default_index = {"index_type": "DISKANN", - "metric_type": "L2", "params": {}} - collection_w.create_index( - ct.default_float_vec_field_name, default_index) + default_index = {"index_type": "DISKANN", "metric_type": "L2", "params": {}} + collection_w.create_index(ct.default_float_vec_field_name, default_index) collection_w.load() - default_search_params = {"metric_type": "L2", - "params": {"search_list": search_list}} - vectors = [[random.random() for _ in range(dim)] - for _ in range(default_nq)] - output_fields = [default_int64_field_name, - default_float_field_name, default_string_field_name] + default_search_params = {"metric_type": "L2", "params": {"search_list": search_list}} + vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] + output_fields = [default_int64_field_name, default_float_field_name, default_string_field_name] collection_w.search(vectors[:default_nq], default_search_field, default_search_params, limit, default_search_exp, output_fields=output_fields, check_task=CheckTasks.err_res, - check_items={"err_code": 1, - "err_msg": "fail to search on all shard leaders"}) + check_items={"err_code": 65538, + "err_msg": "failed to search"}) @pytest.mark.tags(CaseLabel.L2) def test_search_with_diskann_with_string_pk(self, dim, enable_dynamic_field): @@ -6262,6 +6351,30 @@ def test_range_search_normal(self, nq, dim, auto_id, is_flush, radius, range_fil # distances_tmp = list(hits.distances) # assert distances_tmp.count(1.0) == 1 + @pytest.mark.tags(CaseLabel.L1) + def test_range_search_cosine(self): + """ + target: test range search normal case + method: create connection, collection, insert and search + expected: search successfully with limit(topK) + """ + # 1. initialize with data + collection_w = self.init_collection_general(prefix, True)[0] + range_filter = random.uniform(0, 1) + radius = random.uniform(-1, range_filter) + + # 2. 
range search + range_search_params = {"metric_type": "COSINE", + "params": {"radius": radius, "range_filter": range_filter}} + search_res = collection_w.search(vectors[:nq], default_search_field, + range_search_params, default_limit, + default_search_exp)[0] + + # 3. check search results + for hits in search_res: + for distance in hits.distances: + assert range_filter >= distance > radius + @pytest.mark.tags(CaseLabel.L2) def test_range_search_only_range_filter(self): """ @@ -6751,19 +6864,20 @@ def test_range_search_after_different_index_with_params(self, dim, index, params if params.get("PQM"): if (dim % params["PQM"]) != 0: params["PQM"] = dim // 4 - default_index = {"index_type": index, - "params": params, "metric_type": "L2"} + default_index = {"index_type": index, "params": params, "metric_type": "L2"} collection_w.create_index("float_vector", default_index) collection_w.load() # 3. range search search_params = cf.gen_search_param(index) - vectors = [[random.random() for _ in range(dim)] - for _ in range(default_nq)] + vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] for search_param in search_params: search_param["params"]["radius"] = 1000 search_param["params"]["range_filter"] = 0 if index.startswith("IVF_"): search_param["params"].pop("nprobe") + if index == "SCANN": + search_param["params"].pop("nprobe") + search_param["params"].pop("reorder_k") log.info("Searching with search params: {}".format(search_param)) collection_w.search(vectors[:default_nq], default_search_field, search_param, default_limit, @@ -6794,18 +6908,14 @@ def test_range_search_after_index_different_metric_type(self, dim, index, params if params.get("PQM"): if (dim % params["PQM"]) != 0: params["PQM"] = dim // 4 - log.info( - "test_range_search_after_index_different_metric_type: Creating index-%s" % index) - default_index = {"index_type": index, - "params": params, "metric_type": "IP"} + log.info("test_range_search_after_index_different_metric_type: Creating 
index-%s" % index) + default_index = {"index_type": index, "params": params, "metric_type": "IP"} collection_w.create_index("float_vector", default_index) - log.info( - "test_range_search_after_index_different_metric_type: Created index-%s" % index) + log.info("test_range_search_after_index_different_metric_type: Created index-%s" % index) collection_w.load() # 3. search search_params = cf.gen_search_param(index, "IP") - vectors = [[random.random() for _ in range(dim)] - for _ in range(default_nq)] + vectors = [[random.random() for _ in range(dim)] for _ in range(default_nq)] for search_param in search_params: search_param["params"]["radius"] = 0 search_param["params"]["range_filter"] = 1000 @@ -7312,28 +7422,33 @@ def test_range_search_with_expression_large(self, dim): is_index=False)[0:4] # 2. create index - index_param = {"index_type": "IVF_FLAT", - "metric_type": "L2", "params": {"nlist": 100}} + index_param = {"index_type": "IVF_FLAT", "metric_type": "L2", "params": {"nlist": 100}} collection_w.create_index("float_vector", index_param) collection_w.load() # 3. 
search with expression expression = f"0 < {default_int64_field_name} < 5001" - log.info( - "test_search_with_expression: searching with expression: %s" % expression) + log.info("test_search_with_expression: searching with expression: %s" % expression) nums = 5000 vectors = [[random.random() for _ in range(dim)] for _ in range(nums)] - range_search_params = {"metric_type": "L2", "params": {"radius": 1000, - "range_filter": 0}} + # calculate the distance to make sure in range(0, 1000) + search_params = {"metric_type": "L2"} search_res, _ = collection_w.search(vectors, default_search_field, - range_search_params, default_limit, expression, - check_task=CheckTasks.check_search_results, - check_items={ - "nq": nums, - "ids": insert_ids, - "limit": default_limit, - }) + search_params, 500, expression) + for i in range(nums): + if len(search_res[i]) < 10: + assert False + for j in range(len(search_res[i])): + if search_res[i][j].distance < 0 or search_res[i][j].distance >= 1000: + assert False + # range search + range_search_params = {"metric_type": "L2", "params": {"radius": 1000, "range_filter": 0}} + search_res, _ = collection_w.search(vectors, default_search_field, + range_search_params, default_limit, expression) + for i in range(nums): + log.info(i) + assert len(search_res[i]) == default_limit @pytest.mark.tags(CaseLabel.L2) def test_range_search_with_consistency_bounded(self, nq, dim, auto_id, _async): @@ -7904,8 +8019,7 @@ def test_load_partition_drop_partition_delete(self): expected: No exception """ # insert data - collection_w = self.init_collection_general( - prefix, True, 200, partition_num=1, is_index=False)[0] + collection_w = self.init_collection_general(prefix, True, 200, partition_num=1, is_index=False)[0] partition_w1, partition_w2 = collection_w.partitions collection_w.create_index(default_search_field, default_index_params) # load && release @@ -7917,18 +8031,17 @@ def test_load_partition_drop_partition_delete(self): collection_w.delete(f"int64 in 
{delete_ids}") # search on collection, partition1, partition2 collection_w.search(vectors[:1], field_name, default_search_params, 200, - partition_names=[ - partition_w1.name, partition_w2.name], + partition_names=[partition_w1.name, partition_w2.name], check_task=CheckTasks.err_res, - check_items={ct.err_code: 1, ct.err_msg: 'not loaded'}) + check_items={ct.err_code: 65535, ct.err_msg: 'not loaded'}) collection_w.search(vectors[:1], field_name, default_search_params, 200, partition_names=[partition_w1.name], check_task=CheckTasks.err_res, - check_items={ct.err_code: 1, ct.err_msg: 'not loaded'}) + check_items={ct.err_code: 65535, ct.err_msg: 'not loaded'}) collection_w.search(vectors[:1], field_name, default_search_params, 200, partition_names=[partition_w2.name], check_task=CheckTasks.err_res, - check_items={ct.err_code: 1, ct.err_msg: 'not found'}) + check_items={ct.err_code: 65535, ct.err_msg: 'not found'}) @pytest.mark.tags(CaseLabel.L2) def test_compact_load_collection_release_partition(self): @@ -8649,8 +8762,7 @@ def test_load_collection_release_all_partitions(self): expected: No exception """ # init the collection - collection_w = self.init_collection_general( - prefix, True, 200, partition_num=1, is_index=False)[0] + collection_w = self.init_collection_general(prefix, True, 200, partition_num=1, is_index=False)[0] partition_w1, partition_w2 = collection_w.partitions collection_w.create_index(default_search_field, default_index_params) # load and release @@ -8660,8 +8772,8 @@ def test_load_collection_release_all_partitions(self): # search on collection collection_w.search(vectors[:1], field_name, default_search_params, 200, check_task=CheckTasks.err_res, - check_items={ct.err_code: 1, - ct.err_msg: "fail to get shard leaders from QueryCoord: collection not loaded"}) + check_items={ct.err_code: 65535, + ct.err_msg: "collection not loaded"}) @pytest.mark.tags(CaseLabel.L2) @pytest.mark.xfail(reason="issue #24446") @@ -9051,8 +9163,8 @@ def 
test_range_search_iterator_only_radius(self): @pytest.mark.tags(CaseLabel.L2) @pytest.mark.skip("issue #25145") @pytest.mark.parametrize("index, params", - zip(ct.all_index_types[:6], - ct.default_index_params[:6])) + zip(ct.all_index_types[:7], + ct.default_index_params[:7])) @pytest.mark.parametrize("metrics", ct.float_metrics) def test_search_iterator_after_different_index_metrics(self, index, params, metrics): """ @@ -9063,10 +9175,8 @@ def test_search_iterator_after_different_index_metrics(self, index, params, metr """ # 1. initialize with data batch_size = 100 - collection_w = self.init_collection_general( - prefix, True, is_index=False)[0] - default_index = {"index_type": index, - "params": params, "metric_type": metrics} + collection_w = self.init_collection_general(prefix, True, is_index=False)[0] + default_index = {"index_type": index, "params": params, "metric_type": metrics} collection_w.create_index(field_name, default_index) collection_w.load() # 2. search iterator diff --git a/tests/python_client/testcases/test_utility.py b/tests/python_client/testcases/test_utility.py index 889b1ef20be90..3b66fd0e961df 100644 --- a/tests/python_client/testcases/test_utility.py +++ b/tests/python_client/testcases/test_utility.py @@ -94,7 +94,8 @@ def test_has_collection_name_invalid(self, get_invalid_collection_name): self.utility_wrap.has_collection( c_name, check_task=CheckTasks.err_res, - check_items={ct.err_code: 1, ct.err_msg: "Invalid collection name"}) + check_items={ct.err_code: 1100, + ct.err_msg: "collection name should not be empty: invalid parameter"}) # elif not isinstance(c_name, str): self.utility_wrap.has_collection(c_name, check_task=CheckTasks.err_res, # check_items={ct.err_code: 1, ct.err_msg: "illegal"}) @@ -112,7 +113,8 @@ def test_has_partition_collection_name_invalid(self, get_invalid_collection_name self.utility_wrap.has_partition( c_name, p_name, check_task=CheckTasks.err_res, - check_items={ct.err_code: 1, ct.err_msg: "Invalid"}) + 
check_items={ct.err_code: 1100, + ct.err_msg: "collection name should not be empty: invalid parameter"}) @pytest.mark.tags(CaseLabel.L2) def test_has_partition_name_invalid(self, get_invalid_partition_name): @@ -134,9 +136,11 @@ def test_has_partition_name_invalid(self, get_invalid_partition_name): @pytest.mark.tags(CaseLabel.L2) def test_drop_collection_name_invalid(self, get_invalid_collection_name): self._connect() - error = f'`collection_name` value {get_invalid_collection_name} is illegal' + error1 = {ct.err_code: 1, ct.err_msg: f"`collection_name` value {get_invalid_collection_name} is illegal"} + error2 = {ct.err_code: 1100, ct.err_msg: f"Invalid collection name: {get_invalid_collection_name}."} + error = error1 if get_invalid_collection_name in [[], 1, [1, '2', 3], (1,), {1: 1}, None, ""] else error2 self.utility_wrap.drop_collection(get_invalid_collection_name, check_task=CheckTasks.err_res, - check_items={ct.err_code: 1, ct.err_msg: error}) + check_items=error) # TODO: enable @pytest.mark.tags(CaseLabel.L2) @@ -241,7 +245,7 @@ def test_loading_progress_not_existed_collection_name(self): self.collection_wrap.construct_from_dataframe(c_name, df, primary_field=ct.default_int64_field_name) self.collection_wrap.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) self.collection_wrap.load() - error = {ct.err_code: 1, ct.err_msg: "describe collection failed: can't find collection"} + error = {ct.err_code: 4, ct.err_msg: "collection default:not_existed_name: collection not found"} self.utility_wrap.loading_progress("not_existed_name", check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L2) @@ -285,7 +289,7 @@ def test_wait_for_loading_collection_not_existed(self): self.utility_wrap.wait_for_loading_complete( c_name, check_task=CheckTasks.err_res, - check_items={ct.err_code: 1, ct.err_msg: "can't find collection"}) + check_items={ct.err_code: 4, ct.err_msg: f"collection default:{c_name}: collection not 
found"}) @pytest.mark.tags(CaseLabel.L2) def test_wait_for_loading_partition_not_existed(self): @@ -299,7 +303,7 @@ def test_wait_for_loading_partition_not_existed(self): self.utility_wrap.wait_for_loading_complete( collection_w.name, partition_names=[ct.default_tag], check_task=CheckTasks.err_res, - check_items={ct.err_code: 15, ct.err_msg: f'partitionID of partitionName:{ct.default_tag} can not be find'}) + check_items={ct.err_code: 200, ct.err_msg: f'partition={ct.default_tag}: partition not found'}) @pytest.mark.tags(CaseLabel.L2) def test_drop_collection_not_existed(self): @@ -584,9 +588,11 @@ def test_rename_collection_new_invalid_value(self, get_invalid_value_collection_ new_collection_name = get_invalid_value_collection_name self.utility_wrap.rename_collection(old_collection_name, new_collection_name, check_task=CheckTasks.err_res, - check_items={"err_code": 9, - "err_msg": "collection {} was not " - "loaded into memory)".format(collection_w.name)}) + check_items={"err_code": 1100, + "err_msg": "Invalid collection name: %s. 
the first " + "character of a collection name must be an " + "underscore or letter: invalid parameter" + % new_collection_name}) @pytest.mark.tags(CaseLabel.L2) def test_rename_collection_not_existed_collection(self): @@ -601,9 +607,9 @@ def test_rename_collection_not_existed_collection(self): new_collection_name = cf.gen_unique_str(prefix) self.utility_wrap.rename_collection(old_collection_name, new_collection_name, check_task=CheckTasks.err_res, - check_items={"err_code": 1, - "err_msg": "can't find collection: {}".format( - collection_w.name)}) + check_items={"err_code": 4, + "err_msg": "collection 1:test_collection_non_exist: " + "collection not found"}) @pytest.mark.tags(CaseLabel.L1) def test_rename_collection_existed_collection_name(self): @@ -617,10 +623,10 @@ def test_rename_collection_existed_collection_name(self): old_collection_name = collection_w.name self.utility_wrap.rename_collection(old_collection_name, old_collection_name, check_task=CheckTasks.err_res, - check_items={"err_code": 1, - "err_msg": "duplicated new collection name :{} with other " - "collection name or alias".format( - collection_w.name)}) + check_items={"err_code": 65535, + "err_msg": "duplicated new collection name default:{}" + " with other collection name or" + " alias".format(collection_w.name)}) @pytest.mark.tags(CaseLabel.L1) def test_rename_collection_existed_collection_alias(self): @@ -636,8 +642,8 @@ def test_rename_collection_existed_collection_alias(self): self.utility_wrap.create_alias(old_collection_name, alias) self.utility_wrap.rename_collection(old_collection_name, alias, check_task=CheckTasks.err_res, - check_items={"err_code": 1, - "err_msg": "duplicated new collection name :{} with " + check_items={"err_code": 65535, + "err_msg": "duplicated new collection name default:{} with " "other collection name or alias".format(alias)}) @pytest.mark.tags(CaseLabel.L1) @@ -809,7 +815,7 @@ def test_index_process_collection_not_existed(self): 
self.utility_wrap.index_building_progress( c_name, check_task=CheckTasks.err_res, - check_items={ct.err_code: 1, ct.err_msg: "can't find collection"}) + check_items={ct.err_code: 4, ct.err_msg: "collection not found"}) @pytest.mark.tags(CaseLabel.L1) def test_index_process_collection_empty(self): @@ -837,7 +843,7 @@ def test_index_process_collection_insert_no_index(self): cw = self.init_collection_wrap(name=c_name) data = cf.gen_default_list_data(nb) cw.insert(data=data) - error = {ct.err_code: 25, ct.err_msg: "there is no index on collection"} + error = {ct.err_code: 700, ct.err_msg: f"{c_name}: index not found"} self.utility_wrap.index_building_progress(c_name, check_task=CheckTasks.err_res, check_items=error) @pytest.mark.tags(CaseLabel.L1) @@ -898,7 +904,7 @@ def test_wait_index_collection_not_existed(self): self.utility_wrap.wait_for_index_building_complete( c_name, check_task=CheckTasks.err_res, - check_items={ct.err_code: 1, ct.err_msg: "can't find collection"}) + check_items={ct.err_code: 4, ct.err_msg: "collection not found"}) @pytest.mark.tags(CaseLabel.L1) def test_wait_index_collection_empty(self): @@ -948,9 +954,8 @@ def test_loading_progress_without_loading(self): assert collection_w.num_entities == ct.default_nb self.utility_wrap.loading_progress(collection_w.name, check_task=CheckTasks.err_res, - check_items={ct.err_code: 1, - ct.err_msg: 'fail to show collections from ' - 'the querycoord, no data'}) + check_items={ct.err_code: 101, + ct.err_msg: 'collection not loaded'}) @pytest.mark.tags(CaseLabel.L1) @pytest.mark.parametrize("nb", [ct.default_nb, 5000]) @@ -1924,7 +1929,7 @@ def test_load_balance_with_all_dst_node_not_exist(self): # load balance self.utility_wrap.load_balance(collection_w.name, src_node_id, dst_node_ids, sealed_segment_ids, check_task=CheckTasks.err_res, - check_items={ct.err_code: 1, ct.err_msg: "no available queryNode to allocate"}) + check_items={ct.err_code: 1, ct.err_msg: "destination node not found in the same replica"}) 
@pytest.mark.tags(CaseLabel.L1) @pytest.mark.xfail(reason="issue: https://github.com/milvus-io/milvus/issues/19441") @@ -2760,7 +2765,7 @@ def test_revoke_public_role_privilege(self, host, port): self.utility_wrap.init_role("public") self.utility_wrap.role_grant("Collection", c_name, "Insert") - @pytest.mark.tags(CaseLabel.L2) + @pytest.mark.tags(CaseLabel.RBAC) def test_revoke_user_after_delete_user(self, host, port): """ target: test revoke user with deleted user @@ -4721,7 +4726,7 @@ def test_grant_privilege_with_collection_not_exist(self, host, port): # collection flush with db_b permission self.database_wrap.using_database(db_b) collection_w.flush(check_task=CheckTasks.err_res, - check_items={ct.err_code: 1, ct.err_msg: "can't find collection"}) + check_items={ct.err_code: 4, ct.err_msg: "collection not found"}) self.database_wrap.using_database(db_a) collection_w.flush(check_task=CheckTasks.check_permission_deny) @@ -4893,6 +4898,29 @@ def test_grant_not_existed_collection_privilege(self, host, port): self.database_wrap.using_database(ct.default_db) collection_w.flush(check_task=CheckTasks.check_permission_deny) + @pytest.mark.tags(CaseLabel.RBAC) + def test_create_over_max_roles(self, host, port): + """ + target: test create roles over max num + method: test create role with random name + expected: raise exception + """ + self.connection_wrap.connect(host=host, port=port, user=ct.default_user, + password=ct.default_password, check_task=ct.CheckTasks.ccr) + # 2 original roles: admin, public + for i in range(ct.max_role_num - 2): + role_name = "role_" + str(i) + self.utility_wrap.init_role(role_name, check_task=CheckTasks.check_role_property, + check_items={exp_name: role_name}) + self.utility_wrap.create_role() + assert self.utility_wrap.role_is_exist()[0] + + # now total 10 roles, create a new one will report error + self.utility_wrap.init_role("role_11") + error = {ct.err_code: 35, + ct.err_msg: "unable to create role because the number of roles has reached 
the limit"} + self.utility_wrap.create_role(check_task=CheckTasks.err_res, check_items=error) + @pytest.mark.tags(CaseLabel.L3) class TestUtilityFlushAll(TestcaseBase): diff --git a/tests/python_client/utils/util_common.py b/tests/python_client/utils/util_common.py index de65f1ee70179..064a758e38801 100644 --- a/tests/python_client/utils/util_common.py +++ b/tests/python_client/utils/util_common.py @@ -98,6 +98,7 @@ def wait_signal_to_apply_chaos(): while True and (time.time() - t0 < timeout): try: df = pd.read_parquet(f) + log.debug(f"read {f}:result\n {df}") result = df[(df['event_name'] == 'init_chaos') & (df['event_status'] == 'ready')] if len(result) > 0: log.info(f"{f}: {result}") @@ -108,6 +109,8 @@ def wait_signal_to_apply_chaos(): except Exception as e: log.error(f"read_parquet error: {e}") ready_apply_chaos = False + time.sleep(10) + return ready_apply_chaos diff --git a/tests/python_client/utils/util_k8s.py b/tests/python_client/utils/util_k8s.py index 286aee45c6406..ffaba8bcc1ff0 100644 --- a/tests/python_client/utils/util_k8s.py +++ b/tests/python_client/utils/util_k8s.py @@ -56,8 +56,8 @@ def wait_pods_ready(namespace, label_selector, expected_num=None, timeout=360): api_instance = client.CoreV1Api() try: all_pos_ready_flag = False - time_cnt = 0 - while not all_pos_ready_flag and time_cnt < timeout: + t0 = time.time() + while not all_pos_ready_flag and time.time() - t0 < timeout: api_response = api_instance.list_namespaced_pod(namespace=namespace, label_selector=label_selector) all_pos_ready_flag = True if expected_num is not None and len(api_response.items) < expected_num: @@ -74,8 +74,7 @@ def wait_pods_ready(namespace, label_selector, expected_num=None, timeout=360): break if not all_pos_ready_flag: log.debug("all pods are not ready, please wait") - time.sleep(30) - time_cnt += 30 + time.sleep(5) if all_pos_ready_flag: log.info(f"all pods in namespace {namespace} with label {label_selector} are ready") else: @@ -423,12 +422,14 @@ def 
find_activate_standby_coord_pod(namespace, release_name, coord_type): return activate_pod_list, standby_pod_list -def reset_healthy_checker_after_standby_activated(namespace, release_name, coord_type, health_checkers, timeout=360): +def record_time_when_standby_activated(namespace, release_name, coord_type, timeout=360): activate_pod_list_before, standby_pod_list_before = find_activate_standby_coord_pod(namespace, release_name, coord_type) log.info(f"check standby switch: activate_pod_list_before {activate_pod_list_before}, " f"standby_pod_list_before {standby_pod_list_before}") standby_activated = False + activate_pod_list_after, standby_pod_list_after = find_activate_standby_coord_pod(namespace, release_name, + coord_type) start_time = time.time() end_time = time.time() while not standby_activated and end_time - start_time < timeout: @@ -443,14 +444,10 @@ def reset_healthy_checker_after_standby_activated(namespace, release_name, coord break except Exception as e: log.error(f"Exception when check standby switch: {e}") - time.sleep(10) + time.sleep(1) end_time = time.time() if standby_activated: - time.sleep(30) - cc.reset_counting(health_checkers) - for k, v in health_checkers.items(): - log.info("reset health checkers") - v.check_result() + log.info(f"Standby {coord_type} pod {activate_pod_list_after[0]} activated") else: log.info(f"Standby {coord_type} pod does not switch standby mode") diff --git a/tests/restful_client/requirements.txt b/tests/restful_client/requirements.txt index d653cc19838d7..91ad942cc62b6 100644 --- a/tests/restful_client/requirements.txt +++ b/tests/restful_client/requirements.txt @@ -1,5 +1,5 @@ requests==2.31.0 -urllib3==1.26.16 +urllib3==1.26.18 loguru~=0.5.3 pytest~=7.2.0 pyyaml~=6.0 diff --git a/tests/scripts/ci_e2e_4am.sh b/tests/scripts/ci_e2e_4am.sh index 0640270ab6d51..c692e7c9b8c36 100755 --- a/tests/scripts/ci_e2e_4am.sh +++ b/tests/scripts/ci_e2e_4am.sh @@ -61,13 +61,13 @@ if [ ! 
-d "${CI_LOG_PATH}" ]; then mkdir -p ${CI_LOG_PATH} fi -echo "prepare e2e test" -install_pytest_requirements +echo "prepare e2e test" +install_pytest_requirements # Pytest is not able to have both --timeout & --workers, so do not add --timeout or --workers in the shell script if [[ -n "${TEST_TIMEOUT:-}" ]]; then - + timeout "${TEST_TIMEOUT}" pytest --host ${MILVUS_SERVICE_NAME} --port ${MILVUS_SERVICE_PORT} \ --html=${CI_LOG_PATH}/report.html --self-contained-html ${@:-} else @@ -77,10 +77,20 @@ fi # Run bulk insert test if [[ -n "${TEST_TIMEOUT:-}" ]]; then - + timeout "${TEST_TIMEOUT}" pytest testcases/test_bulk_insert.py --host ${MILVUS_SERVICE_NAME} --port ${MILVUS_SERVICE_PORT} --minio_host ${MINIO_SERVICE_NAME} \ --html=${CI_LOG_PATH}/report_bulk_insert.html --self-contained-html else pytest testcases/test_bulk_insert.py --host ${MILVUS_SERVICE_NAME} --port ${MILVUS_SERVICE_PORT} --minio_host ${MINIO_SERVICE_NAME} \ --html=${CI_LOG_PATH}/report_bulk_insert.html --self-contained-html -fi \ No newline at end of file +fi + +# Run concurrent test with 10 processes +if [[ -n "${TEST_TIMEOUT:-}" ]]; then + + timeout "${TEST_TIMEOUT}" pytest testcases/test_concurrent.py --host ${MILVUS_SERVICE_NAME} --port ${MILVUS_SERVICE_PORT} --count 10 -n 10 \ + --html=${CI_LOG_PATH}/report_concurrent.html --self-contained-html +else + pytest testcases/test_concurrent.py --host ${MILVUS_SERVICE_NAME} --port ${MILVUS_SERVICE_PORT} --count 10 -n 10 \ + --html=${CI_LOG_PATH}/report_concurrent.html --self-contained-html +fi diff --git a/tests/scripts/ci_logs.sh b/tests/scripts/ci_logs.sh index af03bcc223a99..b02e7b97b7fc4 100755 --- a/tests/scripts/ci_logs.sh +++ b/tests/scripts/ci_logs.sh @@ -86,28 +86,35 @@ if [[ ! 
-d ${RELEASE_LOG_DIR} ]] ;then mkdir -p ${RELEASE_LOG_DIR} fi # Try to found logs file from mount disk /volume1/ci-logs -log_files=$(find ${LOG_DIR} -type f -name "*${RELEASE_NAME}*" ) +# log_files=$(find ${LOG_DIR} -type f -name "*${RELEASE_NAME}*" ) +log_files=$(ls ${LOG_DIR} | grep ${RELEASE_NAME} || echo nonexistent) -if [ -z "${log_files:-}" ]; then +if [ "${log_files}" == "nonexistent" ]; then echo "No log files find" else for log_file in ${log_files} do file_name=$(basename ${log_file}) - mv ${log_file} ${RELEASE_LOG_DIR}/`echo ${file_name} | sed 's/ci.var.log.containers.//g' ` + mv ${LOG_DIR}/${log_file} ${RELEASE_LOG_DIR}/`echo ${file_name} | sed 's/ci.var.log.containers.//g' ` done - tar -zcvf ${ARTIFACTS_NAME:-artifacts}.tar.gz ${RELEASE_LOG_DIR}/* rm -rf ${RELEASE_LOG_DIR} fi -remain_log_files=$(find ${LOG_DIR} -type f -name "*${RELEASE_NAME}*") +# remain_log_files=$(find ${LOG_DIR} -type f -name "*${RELEASE_NAME}*") + +# ls ${LOG_DIR} | grep ${RELEASE_NAME} + +remain_log_files=$(ls ${LOG_DIR} | grep ${RELEASE_NAME} || echo nonexistent) -if [ -z "${remain_log_files:-}" ]; then +if [ "${remain_log_files}" == "nonexistent" ]; then echo "No remain log files" else echo "Still have log files & Remove again" - find ${LOG_DIR} -type f -name "*${RELEASE_NAME}*" -exec rm -rf {} + + for log_file in ${remain_log_files} + do + rm -rf ${LOG_DIR}/${log_file} + done echo "Check if any remain log files after using rm to delete again " - find ${LOG_DIR} -type f -name "*${RELEASE_NAME}*" + ls ${LOG_DIR} | grep ${RELEASE_NAME} || echo nonexistent fi \ No newline at end of file diff --git a/tests/scripts/get_image_tag_by_short_name.py b/tests/scripts/get_image_tag_by_short_name.py new file mode 100644 index 0000000000000..c3da4e52ef867 --- /dev/null +++ b/tests/scripts/get_image_tag_by_short_name.py @@ -0,0 +1,68 @@ +import requests +import argparse +from tenacity import retry, stop_after_attempt + +@retry(stop=stop_after_attempt(7)) +def 
get_image_tag_by_short_name(repository, tag, arch): + + # Send API request to get all tags start with prefix + # ${branch}-latest means the tag is a dev build + # master-latest -> master-$date-$commit + # 2.3.0-latest -> 2.3.0-$date-$commit + # latest means the tag is a release build + # latest -> v$version + splits = tag.split("-") + prefix = splits[0] if len(splits) > 1 else "v" + url = f"https://hub.docker.com/v2/repositories/{repository}/tags?name={prefix}&ordering=last_updated" + response = requests.get(url) + data = response.json() + + # Get the latest tag with the same arch and prefix + sorted_images = sorted(data["results"], key=lambda x: x["last_updated"], reverse=True) + candidate_tag = None + for tag_info in sorted_images: + # print(tag_info) + if arch in [x["architecture"] for x in tag_info["images"]]: + candidate_tag = tag_info["name"] + if candidate_tag == tag: + continue + else: + # print(f"candidate_tag: {candidate_tag}") + break + # Get the DIGEST of the short tag + url = f"https://hub.docker.com/v2/repositories/{repository}/tags/{tag}" + response = requests.get(url) + cur_tag_info = response.json() + digest = cur_tag_info["images"][0]["digest"] + res = [] + # Iterate through all tags and find the ones with the same DIGEST + for tag_info in data["results"]: + if "digest" in tag_info["images"][0] and tag_info["images"][0]["digest"] == digest: + # Extract the image name + image_name = tag_info["name"].split(":")[0] + if image_name != tag and arch in [x["architecture"] for x in tag_info["images"]]: + res.append(image_name) + # In case of no match, try to find the latest tag with the same arch + # there is a case: push master-xxx-arm64 and master-latest, but master-latest-amd64 is not pushed, + # then there will be no tag matched, so we need to find the latest tag with the same arch even it is not the latest tag + for tag_info in data["results"]: + image_name = tag_info["name"].split(":")[0] + if image_name != tag and arch in image_name: + 
res.append(image_name) + # print(res) + if len(res) == 0 or candidate_tag > res[0]: + if candidate_tag is None: + return tag + return candidate_tag + else: + return res[0] + + +if __name__ == "__main__": + argparse = argparse.ArgumentParser() + argparse.add_argument("--repository", type=str, default="milvusdb/milvus") + argparse.add_argument("--tag", type=str, default="master-latest") + argparse.add_argument("--arch", type=str, default="amd64") + args = argparse.parse_args() + res = get_image_tag_by_short_name(args.repository, args.tag, args.arch) + print(res) diff --git a/tests/scripts/values/ci/pr-gpu.yaml b/tests/scripts/values/ci/pr-gpu.yaml index 1cd1490a9c95f..2bbff8e0fe520 100644 --- a/tests/scripts/values/ci/pr-gpu.yaml +++ b/tests/scripts/values/ci/pr-gpu.yaml @@ -28,10 +28,16 @@ queryCoordinator: queryNode: nodeSelector: nvidia.com/gpu.present: 'true' + extraEnv: + - name: CUDA_VISIBLE_DEVICES + value: "0,1" resources: requests: + nvidia.com/gpu: 2 cpu: "0.5" memory: "500Mi" + limits: + nvidia.com/gpu: 2 indexCoordinator: nodeSelector: @@ -44,10 +50,16 @@ indexCoordinator: indexNode: nodeSelector: nvidia.com/gpu.present: 'true' + extraEnv: + - name: CUDA_VISIBLE_DEVICES + value: "0,1" resources: requests: + nvidia.com/gpu: 2 cpu: "0.5" memory: "500Mi" + limits: + nvidia.com/gpu: 2 dataCoordinator: nodeSelector: